In [1]:
import numpy as np
import json
import os
import copy
import pickle

import mesh_sampling
import trimesh
from shape_data import ShapeData

from autoencoder_dataset import autoencoder_dataset, cached_autoencoder_dataset
from torch.utils.data import DataLoader

from spiral_utils import get_adj_trigs, generate_spirals
from models import SpiralAutoencoderVariationalLoss
from train_funcs import train_variational_autoencoder_dataloader
from test_funcs import test_variational_autoencoder_dataloader


import torch
from tensorboardX import SummaryWriter

from sklearn.metrics.pairwise import euclidean_distances
# ## Configuration
# Mesh backend used by ShapeData / mesh loading: 'mpi-mesh' or 'trimesh'.
meshpackage = 'mpi-mesh'
# Root of the data tree; override with the MESH_DATA_ROOT environment variable
# instead of editing this absolute path.
root_dir = os.environ.get('MESH_DATA_ROOT', '/home/jingwang/Data/data/')

dataset = 'FaceWarehouse'
name = 'real_variational'

torch.backends.cudnn.benchmark = True

GPU = True
# Expose only physical GPU 2 to this process; it is then addressed as device 0.
# This only takes effect if no CUDA call has been made yet in this kernel
# (importing torch alone does not initialise CUDA).
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
device_idx = 0
# Querying the device name requires a working CUDA device; don't crash on CPU-only machines.
if GPU and torch.cuda.is_available():
    print(torch.cuda.get_device_name(device_idx))
else:
    print('CUDA not used; running on CPU')

GeForce GTX TITAN X


In [2]:
# ## Experiment arguments and output directories
generative_model = 'variational_autoencoder'  # 'autoencoder' or 'variational_autoencoder'
downsample_method = 'COMA_downsample'  # choose 'COMA_downsample' or 'meshlab_downsample'

# Arguments for this run (dataset / name are chosen in the configuration cell).
reference_mesh_file = os.path.join(root_dir, dataset, 'template', 'template.obj')
downsample_directory = os.path.join(root_dir, dataset, 'template', downsample_method)
ds_factors = [4, 4, 4, 4]  # passed to mesh_sampling.generate_transform_matrices
step_sizes = [2, 2, 1, 1, 1]
filter_sizes_enc = [[3, 16, 32, 64, 128], [[], [], [], [], []]]
filter_sizes_dec = [[128, 64, 32, 32, 16], [[], [], [], [], 3]]
dilation_flag = True
dilation = [2, 2, 1, 1, 1] if dilation_flag else None
# Vertex index(es) on the full-resolution template used as spiral reference points.
# Alternatives previously tried: [[3567, 4051, 4597]] (for a COMA template with
# 3 disconnected components) and [[414]].
reference_points = [[5930]]

args = {'generative_model': generative_model,
        'name': name, 'data': os.path.join(root_dir, dataset, 'preprocessed', name),
        'results_folder': os.path.join(root_dir, dataset, 'results/spirals_' + generative_model),
        'reference_mesh_file': reference_mesh_file, 'downsample_directory': downsample_directory,
        'checkpoint_file': 'checkpoint',
        'seed': 2, 'loss': 'l1',
        'batch_size': 16, 'num_epochs': 300, 'eval_frequency': 200, 'num_workers': 40,
        'filter_sizes_enc': filter_sizes_enc, 'filter_sizes_dec': filter_sizes_dec,
        'nz': 128,  # latent dimensionality
        'ds_factors': ds_factors, 'step_sizes': step_sizes, 'dilation': dilation,

        'lr': 1e-3,
        'regularization': 5e-5,  # used as Adam weight_decay
        'scheduler': True, 'decay_rate': 0.99, 'decay_steps': 1,  # StepLR gamma / step_size
        'resume': True,

        'mode': 'test', 'shuffle': False, 'nVal': 100, 'normalization': True,
        'write_mesh': True,
        'lambda_var': 1e-5,  # passed to the VAE trainer with var_loss_fn (presumably the KL weight)
        'worst_face_num': 20,
        'use_cache': True}

args['results_folder'] = os.path.join(args['results_folder'], 'latent_' + str(args['nz']))


def _ensure_dir(path):
    """Create ``path`` (including parents) if it does not exist yet and return it."""
    if not os.path.exists(path):
        os.makedirs(path)
    return path


_ensure_dir(args['results_folder'])
summary_path = _ensure_dir(os.path.join(args['results_folder'], 'summaries', args['name']))
checkpoint_path = _ensure_dir(os.path.join(args['results_folder'], 'checkpoints', args['name']))
samples_path = _ensure_dir(os.path.join(args['results_folder'], 'samples', args['name']))
prediction_path = _ensure_dir(os.path.join(args['results_folder'], 'predictions', args['name']))
_ensure_dir(downsample_directory)
downsample_mesh_path = _ensure_dir(os.path.join(args['results_folder'], 'downsample_mesh', args['name']))
# Create the worst-test directory itself (it was previously never created).
worst_mesh_path = _ensure_dir(os.path.join(args['results_folder'], 'worst_test', args['name']))

In [3]:
# ## Load data, normalisation statistics and the mesh hierarchy
np.random.seed(args['seed'])
print("Loading data .. ")

# Mean/std are computed once (load_flag=True) and cached next to the data.
# When the cache exists, load_flag=False is passed and the statistics
# (plus n_vertex / n_features derived from them) are restored from disk.
mean_file = os.path.join(args['data'], 'mean.npy')
std_file = os.path.join(args['data'], 'std.npy')
stats_cached = os.path.exists(mean_file) and os.path.exists(std_file)
shapedata = ShapeData(nVal=args['nVal'],
                      train_file=os.path.join(args['data'], 'train.npy'),
                      test_file=os.path.join(args['data'], 'test.npy'),
                      reference_mesh_file=args['reference_mesh_file'],
                      normalization=args['normalization'],
                      meshpackage=meshpackage, load_flag=not stats_cached)
if stats_cached:
    shapedata.mean = np.load(mean_file)
    shapedata.std = np.load(std_file)
    shapedata.n_vertex = shapedata.mean.shape[0]
    shapedata.n_features = shapedata.mean.shape[1]
else:
    np.save(mean_file, shapedata.mean)
    np.save(std_file, shapedata.std)

matrices_file = os.path.join(args['downsample_directory'], 'downsampling_matrices.pkl')
if not os.path.exists(matrices_file):
    if shapedata.meshpackage == 'trimesh':
        raise NotImplementedError('Rerun with mpi-mesh as meshpackage')
    print("Generating Transform Matrices ..")
    if downsample_method != 'COMA_downsample':
        # Only the COMA path can build the matrices here; any other method needs
        # a precomputed downsampling_matrices.pkl in downsample_directory.
        raise NotImplementedError('Cannot generate transform matrices for %s; provide %s'
                                  % (downsample_method, matrices_file))
    M, A, D, U, F = mesh_sampling.generate_transform_matrices(shapedata.reference_mesh, args['ds_factors'])
    if args['write_mesh']:
        import openmesh
        for level, level_mesh in enumerate(M):
            om_mesh = openmesh.TriMesh(points=level_mesh.v, face_vertex_indices=level_mesh.f)
            openmesh.write_mesh(os.path.join(downsample_mesh_path, '%d.obj' % level), om_mesh)
    with open(matrices_file, 'wb') as fp:
        M_verts_faces = [(level_mesh.v, level_mesh.f) for level_mesh in M]
        pickle.dump({'M_verts_faces': M_verts_faces, 'A': A, 'D': D, 'U': U, 'F': F}, fp)
else:
    print("Loading Transform Matrices ..")
    # NOTE: pickle.load can execute arbitrary code - only load files produced by this pipeline.
    # Under Python 3, a pickle written by Python 2 needs pickle.load(fp, encoding='latin1').
    with open(matrices_file, 'rb') as fp:
        downsampling_matrices = pickle.load(fp)

    M_verts_faces = downsampling_matrices['M_verts_faces']
    if shapedata.meshpackage == 'mpi-mesh':
        from psbody.mesh import Mesh
        M = [Mesh(v=verts, f=faces) for verts, faces in M_verts_faces]
    elif shapedata.meshpackage == 'trimesh':
        M = [trimesh.base.Trimesh(vertices=verts, faces=faces, process=False)
             for verts, faces in M_verts_faces]
    else:
        raise ValueError('Unknown meshpackage: %s' % shapedata.meshpackage)
    A = downsampling_matrices['A']
    D = downsampling_matrices['D']
    U = downsampling_matrices['U']
    F = downsampling_matrices['F']


def _vertices(mesh):
    """Return the vertex array of ``mesh`` for the configured mesh package."""
    if shapedata.meshpackage == 'mpi-mesh':
        return mesh.v
    if shapedata.meshpackage == 'trimesh':
        return mesh.vertices
    raise ValueError('Unknown meshpackage: %s' % shapedata.meshpackage)


# For every coarser level, pick the vertex nearest to each full-resolution reference point.
# TODO: also enforce that points belong to different disconnected components at each level.
print("Calculating reference points for downsampled versions..")
# Restart from the full-resolution seed so re-running this cell does not keep appending levels.
reference_points = reference_points[:1]
full_res_points = _vertices(M[0])[reference_points[0]]
for level in range(1, len(args['ds_factors']) + 1):
    dist = euclidean_distances(_vertices(M[level]), full_res_points)
    reference_points.append(np.argmin(dist, axis=0).tolist())



Loading data .. 
Loading Transform Matrices ..
Calculating reference points for downsampled versions..


In [4]:
# ## Spiral ordering and padded sampling matrices
# Vertex count at each level of the hierarchy.
if shapedata.meshpackage == 'mpi-mesh':
    sizes = [level_mesh.v.shape[0] for level_mesh in M]
elif shapedata.meshpackage == 'trimesh':
    sizes = [level_mesh.vertices.shape[0] for level_mesh in M]

Adj, Trigs = get_adj_trigs(A, F, shapedata.reference_mesh, meshpackage=shapedata.meshpackage)

spirals_np, spiral_sizes, spirals = generate_spirals(args['step_sizes'],
                                                     M, Adj, Trigs,
                                                     reference_points=reference_points,
                                                     dilation=args['dilation'], random=False,
                                                     meshpackage=shapedata.meshpackage,
                                                     counter_clockwise=True)


def _pad_with_dummy(sparse_mat):
    """Densify ``sparse_mat`` into a (1, rows+1, cols+1) array with 1 in the extra corner cell."""
    n_rows, n_cols = sparse_mat.shape
    padded = np.zeros((1, n_rows + 1, n_cols + 1))
    padded[0, :n_rows, :n_cols] = sparse_mat.todense()
    padded[0, n_rows, n_cols] = 1
    return padded


# One padded down- and up-sampling matrix per level (indexed by the levels of D).
bD = [_pad_with_dummy(D[level]) for level in range(len(D))]
bU = [_pad_with_dummy(U[level]) for level in range(len(D))]


spiral generation for hierarchy 0 (11510 vertices) finished
spiral generation for hierarchy 1 (2878 vertices) finished
spiral generation for hierarchy 2 (720 vertices) finished
spiral generation for hierarchy 3 (180 vertices) finished
spiral generation for hierarchy 4 (45 vertices) finished
spiral sizes for hierarchy 0:  14
spiral sizes for hierarchy 1:  13
spiral sizes for hierarchy 2:  9
spiral sizes for hierarchy 3:  9
spiral sizes for hierarchy 4:  9


In [5]:
torch.manual_seed(args['seed'])

# Use the configured GPU only when it was requested and CUDA is actually present.
use_cuda = GPU and torch.cuda.is_available()
device = torch.device("cuda:" + str(device_idx) if use_cuda else "cpu")
print(device)


def _as_float_tensors(arrays):
    """Convert numpy arrays to float32 tensors on ``device``."""
    return [torch.from_numpy(arr).float().to(device) for arr in arrays]


# Spiral indices become int64 index tensors; sampling matrices become float tensors.
tspirals = [torch.from_numpy(arr).long().to(device) for arr in spirals_np]
tD = _as_float_tensors(bD)
tU = _as_float_tensors(bU)

cuda:0


In [6]:
# ## Datasets and data loaders
def _make_split_loader(split, shuffle, use_cache):
    """Build the dataset and DataLoader for one data split.

    Args:
        split: value passed as ``points_dataset`` ('train', 'val' or 'test').
        shuffle: whether the DataLoader shuffles samples.
        use_cache: if True, use ``cached_autoencoder_dataset`` (constructed with
            ``device=device``) and load in the main process (``num_workers=0``);
            otherwise use ``autoencoder_dataset`` with ``args['num_workers']``
            worker processes and ``pin_memory=True``.

    Returns:
        (dataset, loader) tuple.
    """
    if use_cache:
        split_dataset = cached_autoencoder_dataset(root_dir=args['data'], points_dataset=split,
                                                   shapedata=shapedata,
                                                   normalization=args['normalization'],
                                                   device=device)
        # NOTE(review): presumably num_workers=0 because the cached data may already
        # live on `device` - confirm against cached_autoencoder_dataset.
        split_loader = DataLoader(split_dataset, batch_size=args['batch_size'],
                                  shuffle=shuffle, num_workers=0)
    else:
        split_dataset = autoencoder_dataset(root_dir=args['data'], points_dataset=split,
                                            shapedata=shapedata,
                                            normalization=args['normalization'])
        split_loader = DataLoader(split_dataset, batch_size=args['batch_size'],
                                  shuffle=shuffle, num_workers=args['num_workers'],
                                  pin_memory=True)
    return split_dataset, split_loader


dataset_train, dataloader_train = _make_split_loader('train', args['shuffle'], args['use_cache'])
dataset_val, dataloader_val = _make_split_loader('val', False, args['use_cache'])
# The test split is always read through the non-cached dataset, regardless of use_cache.
dataset_test, dataloader_test = _make_split_loader('test', False, False)



# ## Model, optimizer and loss functions
if args['generative_model'] == 'autoencoder':
    # Imported only when needed so the VAE path does not depend on this class.
    from models import SpiralAutoencoder
    model_cls = SpiralAutoencoder
elif args['generative_model'] == 'variational_autoencoder':
    model_cls = SpiralAutoencoderVariationalLoss
else:
    raise ValueError('Unknown generative_model: %s' % args['generative_model'])

model = model_cls(filters_enc=args['filter_sizes_enc'],
                  filters_dec=args['filter_sizes_dec'],
                  latent_size=args['nz'],
                  sizes=sizes,
                  spiral_sizes=spiral_sizes,
                  spirals=tspirals,
                  D=tD, U=tU).to(device)

optim = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['regularization'])
if args['scheduler']:
    # Multiply the learning rate by decay_rate every decay_steps scheduler steps.
    scheduler = torch.optim.lr_scheduler.StepLR(optim, args['decay_steps'], gamma=args['decay_rate'])
else:
    scheduler = None


def loss_l1(outputs, targets):
    """Mean absolute error between ``outputs`` and ``targets`` over all elements."""
    return torch.abs(outputs - targets).mean()


if args['loss'] == 'l1':
    loss_fn = loss_l1
else:
    raise ValueError("Unsupported loss: %s (only 'l1' is implemented)" % args['loss'])


def var_loss_fn(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)), summed over every element (not averaged)."""
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())



In [7]:
# Count only parameters that will be updated by the optimizer (requires_grad).
trainable = [param for param in model.parameters() if param.requires_grad]
params = sum(param.numel() for param in trainable)
print("Total number of parameters is: {}".format(params))
print(model)

Total number of parameters is: 2480323
SpiralAutoencoderVariationalLoss(
  (conv): ModuleList(
    (0): SpiralConv(
      (conv): Linear(in_features=42, out_features=16, bias=True)
      (activation): ELU(alpha=1.0)
    )
    (1): SpiralConv(
      (conv): Linear(in_features=208, out_features=32, bias=True)
      (activation): ELU(alpha=1.0)
    )
    (2): SpiralConv(
      (conv): Linear(in_features=288, out_features=64, bias=True)
      (activation): ELU(alpha=1.0)
    )
    (3): SpiralConv(
      (conv): Linear(in_features=576, out_features=128, bias=True)
      (activation): ELU(alpha=1.0)
    )
  )
  (fc_latent_enc_mu): Linear(in_features=5888, out_features=128, bias=True)
  (fc_latent_enc_logvar): Linear(in_features=5888, out_features=128, bias=True)
  (fc_latent_dec): Linear(in_features=128, out_features=5888, bias=True)
  (dconv): ModuleList(
    (0): SpiralConv(
      (conv): Linear(in_features=1152, out_features=64, bias=True)
      (activation): ELU(alpha=1.0)
    )
    (1):

In [8]:
# ## Training (only when args['mode'] == 'train')
if args['mode'] == 'train':
    writer = SummaryWriter(summary_path)
    # Snapshot the run arguments next to the checkpoints for reproducibility.
    with open(os.path.join(args['results_folder'], 'checkpoints', args['name'] + '_params.json'), 'w') as fp:
        saveparams = copy.deepcopy(args)
        json.dump(saveparams, fp)

    if args['resume']:
        resume_file = os.path.join(checkpoint_path, args['checkpoint_file'] + '.pth.tar')
        print('loading checkpoint from file %s' % resume_file)
        checkpoint_dict = torch.load(resume_file, map_location=device)
        start_epoch = checkpoint_dict['epoch'] + 1
        model.load_state_dict(checkpoint_dict['autoencoder_state_dict'])
        optim.load_state_dict(checkpoint_dict['optimizer_state_dict'])
        # scheduler is None when args['scheduler'] is False.
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint_dict['scheduler_state_dict'])
        print('Resuming from epoch %s' % str(start_epoch))
    else:
        start_epoch = 0

    if args['generative_model'] == 'autoencoder':
        # Imported only when needed so the VAE path does not depend on it.
        from train_funcs import train_autoencoder_dataloader
        train_autoencoder_dataloader(dataloader_train, dataloader_val,
                                     device, model, optim, loss_fn,
                                     bsize=args['batch_size'],
                                     start_epoch=start_epoch,
                                     n_epochs=args['num_epochs'],
                                     eval_freq=args['eval_frequency'],
                                     scheduler=scheduler,
                                     writer=writer,
                                     save_recons=True,
                                     shapedata=shapedata,
                                     metadata_dir=checkpoint_path, samples_dir=samples_path,
                                     checkpoint_path=args['checkpoint_file'])
    elif args['generative_model'] == 'variational_autoencoder':
        train_variational_autoencoder_dataloader(dataloader_train, dataloader_val,
                                                 device, model, optim, loss_fn, var_loss_fn, args['lambda_var'],
                                                 bsize=args['batch_size'],
                                                 start_epoch=start_epoch,
                                                 n_epochs=args['num_epochs'],
                                                 eval_freq=args['eval_frequency'],
                                                 scheduler=scheduler,
                                                 writer=writer,
                                                 save_recons=True,
                                                 shapedata=shapedata,
                                                 metadata_dir=checkpoint_path, samples_dir=samples_path,
                                                 checkpoint_path=args['checkpoint_file'])
    else:
        raise ValueError('Unknown generative_model: %s' % args['generative_model'])

loading checkpoint from file /home/jingwang/Data/data/FaceWarehouse/results/spirals_variational_autoencoder/latent_128/checkpoints/real_variational/checkpoint


  0%|          | 0/580 [00:00<?, ?it/s]

Resuming from epoch 120


100%|██████████| 580/580 [06:04<00:00,  1.59it/s]
100%|██████████| 7/7 [00:00<00:00,  7.65it/s]


epoch 120 | tr 0.0935972279644 | val 0.12669129163


100%|██████████| 580/580 [05:56<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 121 | tr 0.0913352404175 | val 0.125673815608


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 122 | tr 0.0901174583826 | val 0.126921301782


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 123 | tr 0.0893398786927 | val 0.124335560799


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 124 | tr 0.0892811511868 | val 0.12744318068


100%|██████████| 580/580 [05:55<00:00,  1.64it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 125 | tr 0.0884004267768 | val 0.126046479046


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 126 | tr 0.0882731916447 | val 0.127746262252


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 127 | tr 0.0877273355707 | val 0.128338495195


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 128 | tr 0.087765673381 | val 0.126531115174


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 129 | tr 0.0871206572087 | val 0.124528908432


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 130 | tr 0.086957635692 | val 0.126261770427


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 131 | tr 0.0865166120488 | val 0.12572383821


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 132 | tr 0.0864467153261 | val 0.12890542984


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 133 | tr 0.0862603593489 | val 0.124186766744


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 134 | tr 0.085896241819 | val 0.127513096929


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.09it/s]


epoch 135 | tr 0.086112780201 | val 0.126776534319


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 136 | tr 0.0854330744723 | val 0.125587154329


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 137 | tr 0.0852663879636 | val 0.126974284649


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 138 | tr 0.0850088105505 | val 0.12556846261


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 139 | tr 0.0848608441394 | val 0.127779404223


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 140 | tr 0.084686493591 | val 0.127618067265


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 141 | tr 0.0847079316732 | val 0.127387380302


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 142 | tr 0.0842977338181 | val 0.128841271996


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 143 | tr 0.0838566522917 | val 0.129471877217


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 144 | tr 0.0843997863988 | val 0.127794386744


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 145 | tr 0.0838031536932 | val 0.126327231526


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 146 | tr 0.0836464648488 | val 0.126126395762


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 147 | tr 0.0834622500015 | val 0.125546510518


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 148 | tr 0.0833791050803 | val 0.128741071224


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 149 | tr 0.0830107125486 | val 0.128401395679


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 150 | tr 0.0827607629756 | val 0.126347078681


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 151 | tr 0.0826669955048 | val 0.127203900516


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 152 | tr 0.0828287194387 | val 0.129330022037


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 153 | tr 0.0827080810121 | val 0.131046710908


100%|██████████| 580/580 [05:55<00:00,  1.64it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 154 | tr 0.082114789378 | val 0.125595413148


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 155 | tr 0.0821181095109 | val 0.129833897352


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 156 | tr 0.0822378508875 | val 0.126279719174


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 157 | tr 0.0817456749878 | val 0.12827134639


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 158 | tr 0.0817082302838 | val 0.126649749875


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 159 | tr 0.0817975445811 | val 0.12727527082


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.11it/s]


epoch 160 | tr 0.0813754815994 | val 0.128032041788


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 161 | tr 0.0813587938275 | val 0.12840057224


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.09it/s]


epoch 162 | tr 0.0813308443874 | val 0.128878329396


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 163 | tr 0.0811190221459 | val 0.126294587553


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 164 | tr 0.0806963372205 | val 0.125161167681


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 165 | tr 0.0808768860353 | val 0.126601489186


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 166 | tr 0.0806600966459 | val 0.128456626534


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 167 | tr 0.0807317702164 | val 0.124753981531


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 168 | tr 0.0805051937057 | val 0.125815992057


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 169 | tr 0.080312535565 | val 0.125816102922


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 170 | tr 0.0800865649021 | val 0.126286758482


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 171 | tr 0.0799387988602 | val 0.128109132648


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 172 | tr 0.0797663990023 | val 0.125209099054


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 173 | tr 0.0797497565505 | val 0.12594635874


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 174 | tr 0.0797028376496 | val 0.12976885736


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 175 | tr 0.0795872466574 | val 0.126011635959


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 176 | tr 0.0794586300465 | val 0.12650963366


100%|██████████| 580/580 [05:55<00:00,  1.64it/s]
100%|██████████| 7/7 [00:00<00:00,  8.09it/s]


epoch 177 | tr 0.0793565129925 | val 0.12573453635


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.09it/s]


epoch 178 | tr 0.0791845409649 | val 0.128726720512


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 179 | tr 0.0790446617865 | val 0.124491409659


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 180 | tr 0.0788213660491 | val 0.127066037357


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 181 | tr 0.0788670497722 | val 0.126664746106


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 182 | tr 0.0788414641701 | val 0.126416695714


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 183 | tr 0.0786496347009 | val 0.125971558988


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 184 | tr 0.0785685076539 | val 0.12863658309


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 185 | tr 0.0785346501603 | val 0.127537178993


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.11it/s]


epoch 186 | tr 0.0783074166112 | val 0.125892575979


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.11it/s]


epoch 187 | tr 0.0781448099762 | val 0.125719977617


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 188 | tr 0.0779699227805 | val 0.125759177208


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.11it/s]


epoch 189 | tr 0.0779712836804 | val 0.128553471863


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 190 | tr 0.0777741661103 | val 0.125525082946


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 191 | tr 0.0779059798553 | val 0.127822740674


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 192 | tr 0.0777264970772 | val 0.128249840736


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 193 | tr 0.0774533278983 | val 0.127113197446


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 194 | tr 0.0773858966755 | val 0.12731711328


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 195 | tr 0.0774635881314 | val 0.126749370098


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 196 | tr 0.0772142034281 | val 0.126965379119


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 197 | tr 0.0771785249099 | val 0.126899375021


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 198 | tr 0.0769879036037 | val 0.126757373214


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 199 | tr 0.0770588014275 | val 0.127516851425


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 200 | tr 0.0770271060153 | val 0.126788114607


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 201 | tr 0.0768473504175 | val 0.126611439586


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 202 | tr 0.0767409796216 | val 0.128273918927


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 203 | tr 0.0765146295701 | val 0.126559756398


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 204 | tr 0.0765438966582 | val 0.126051509678


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 205 | tr 0.076446434393 | val 0.126202525496


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 206 | tr 0.0764046294679 | val 0.126996386945


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 207 | tr 0.0762459460903 | val 0.125927505791


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 208 | tr 0.0762136175972 | val 0.127116561234


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 209 | tr 0.076121080898 | val 0.125459685922


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 210 | tr 0.0760457974056 | val 0.124879873097


100%|██████████| 580/580 [05:55<00:00,  1.64it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 211 | tr 0.0759236869242 | val 0.125811567307


100%|██████████| 580/580 [05:55<00:00,  1.62it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 212 | tr 0.075852663404 | val 0.12544067502


100%|██████████| 580/580 [05:55<00:00,  1.64it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 213 | tr 0.075808577779 | val 0.126500205696


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 214 | tr 0.0757291557588 | val 0.124770555496


100%|██████████| 580/580 [05:55<00:00,  1.64it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 215 | tr 0.0755793549377 | val 0.125991169214


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 216 | tr 0.0755357187884 | val 0.12667963028


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 217 | tr 0.0752973137102 | val 0.126486702561


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 218 | tr 0.0753857966127 | val 0.126406072378


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.09it/s]


epoch 219 | tr 0.075325434899 | val 0.127099686265


100%|██████████| 580/580 [05:55<00:00,  1.64it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 220 | tr 0.0752708598329 | val 0.127271085978


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 221 | tr 0.0750408127262 | val 0.125698310435


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 222 | tr 0.0750743852864 | val 0.125086376965


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 223 | tr 0.0749590035014 | val 0.125726017058


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 224 | tr 0.0748558319334 | val 0.124701615572


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 225 | tr 0.0748306132981 | val 0.124636489451


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 226 | tr 0.0747984331743 | val 0.125046389997


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 227 | tr 0.0746982903958 | val 0.124719335139


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 228 | tr 0.0745955652975 | val 0.12552313447


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 229 | tr 0.0745003053726 | val 0.124526467323


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.09it/s]


epoch 230 | tr 0.0743836468535 | val 0.125060845912


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 231 | tr 0.0744648042809 | val 0.12596981585


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 232 | tr 0.0743616394827 | val 0.124261137843


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 233 | tr 0.074268277529 | val 0.124415925145


100%|██████████| 580/580 [05:55<00:00,  1.62it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 234 | tr 0.0742325203429 | val 0.124812020063


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 235 | tr 0.074034720428 | val 0.125182133913


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 236 | tr 0.0740647334092 | val 0.126199449301


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 237 | tr 0.0740378779841 | val 0.126683093905


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 238 | tr 0.0738046930782 | val 0.126140496731


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 239 | tr 0.0738911519919 | val 0.126868778765


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 240 | tr 0.0737141460437 | val 0.126441178024


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 241 | tr 0.0737213368431 | val 0.125625659823


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 242 | tr 0.0736605254869 | val 0.125148750842


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 243 | tr 0.0735761712337 | val 0.124772302508


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 244 | tr 0.0735725524482 | val 0.126012691557


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 245 | tr 0.0733930674221 | val 0.125349895358


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 246 | tr 0.0734201940098 | val 0.124251668453


100%|██████████| 580/580 [05:55<00:00,  1.64it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 247 | tr 0.0732974079277 | val 0.12510948658


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.11it/s]


epoch 248 | tr 0.0733169900446 | val 0.12582611531


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.09it/s]


epoch 249 | tr 0.0732395557751 | val 0.125025185347


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 250 | tr 0.0730811616984 | val 0.124700248837


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 251 | tr 0.0730218694899 | val 0.127046827972


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 252 | tr 0.0730235268972 | val 0.125921624303


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 253 | tr 0.0728960785116 | val 0.126529980004


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 254 | tr 0.0729816787587 | val 0.125391917825


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 255 | tr 0.0728241158328 | val 0.1262780267


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 256 | tr 0.072914211619 | val 0.12432972908


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 257 | tr 0.0727278788028 | val 0.125762955546


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 258 | tr 0.0726991102356 | val 0.126445642412


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 259 | tr 0.0726150561124 | val 0.127402655482


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 260 | tr 0.0725716702383 | val 0.126344363987


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 261 | tr 0.0725143952359 | val 0.12634203434


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 262 | tr 0.0724217004946 | val 0.126370007992


100%|██████████| 580/580 [05:55<00:00,  1.64it/s]
100%|██████████| 7/7 [00:00<00:00,  8.09it/s]


epoch 263 | tr 0.0724175876071 | val 0.125116153657


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 264 | tr 0.0723724500501 | val 0.124523094594


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.11it/s]


epoch 265 | tr 0.0722607970109 | val 0.124905199409


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 266 | tr 0.0722738785862 | val 0.125774158537


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 267 | tr 0.0722155652426 | val 0.125756448805


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 268 | tr 0.0721770735265 | val 0.125656427145


100%|██████████| 580/580 [05:55<00:00,  1.64it/s]
100%|██████████| 7/7 [00:00<00:00,  8.09it/s]


epoch 269 | tr 0.0721059902089 | val 0.126136979163


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 270 | tr 0.0720100818523 | val 0.124737569988


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 271 | tr 0.0720198792245 | val 0.124737045765


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 272 | tr 0.0719001466344 | val 0.125862728655


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 273 | tr 0.0718809844862 | val 0.126136627793


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 274 | tr 0.0719444268479 | val 0.125792791545


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 275 | tr 0.0717499486074 | val 0.12581820637


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 276 | tr 0.0717329240693 | val 0.126561007202


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 277 | tr 0.0717425702846 | val 0.125194103718


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 278 | tr 0.0716331135225 | val 0.127590142488


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 279 | tr 0.0715677490907 | val 0.124682282805


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 280 | tr 0.0715384514681 | val 0.126314548254


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 281 | tr 0.0715042720709 | val 0.126141423583


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 282 | tr 0.0714864759867 | val 0.126006202996


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 283 | tr 0.0713801364575 | val 0.125183651149


100%|██████████| 580/580 [05:55<00:00,  1.62it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 284 | tr 0.0713524203876 | val 0.125070055425


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 285 | tr 0.0712736987862 | val 0.126298620105


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 286 | tr 0.0713085501487 | val 0.12596480161


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 287 | tr 0.0712268155967 | val 0.126706005931


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 288 | tr 0.0711783676826 | val 0.125191653967


100%|██████████| 580/580 [05:55<00:00,  1.64it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 289 | tr 0.0711323280144 | val 0.126164340079


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 290 | tr 0.0710960313292 | val 0.126703008115


100%|██████████| 580/580 [05:55<00:00,  1.64it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 291 | tr 0.0710831587299 | val 0.126987924278


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 292 | tr 0.0710489946835 | val 0.126148284078


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 293 | tr 0.0709210478412 | val 0.125757313669


100%|██████████| 580/580 [05:55<00:00,  1.64it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 294 | tr 0.0709373486967 | val 0.125989741981


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 295 | tr 0.0709066725625 | val 0.124821308255


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 296 | tr 0.0708482611796 | val 0.126449755132


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 297 | tr 0.0708246250219 | val 0.125583043098


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 298 | tr 0.0707376234105 | val 0.125319937766


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 299 | tr 0.0707067936787 | val 0.126154121757


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 300 | tr 0.0707101332473 | val 0.125371993482


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 301 | tr 0.0706459562198 | val 0.126500203609


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 302 | tr 0.0705815352114 | val 0.126795408428


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 303 | tr 0.0706188007547 | val 0.126654640734


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 304 | tr 0.0705168687321 | val 0.126870725453


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 305 | tr 0.070499331093 | val 0.125136205256


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 306 | tr 0.0704839127588 | val 0.125927200913


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 307 | tr 0.070381807356 | val 0.126062508523


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 308 | tr 0.0704158423276 | val 0.126422215998


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 309 | tr 0.0703547338859 | val 0.125470393002


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 310 | tr 0.0703175487575 | val 0.125587899685


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 311 | tr 0.0702230943046 | val 0.12566847831


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 312 | tr 0.0702083020246 | val 0.127161608934


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 313 | tr 0.0702060948415 | val 0.125264793336


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 314 | tr 0.0701460674148 | val 0.125821827352


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 315 | tr 0.0701451830694 | val 0.127538530827


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.09it/s]


epoch 316 | tr 0.0700689092535 | val 0.126152086258


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 317 | tr 0.0699773145133 | val 0.125601114929


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 318 | tr 0.0700515101311 | val 0.126463345289


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.11it/s]


epoch 319 | tr 0.0700200240674 | val 0.125347690582


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 320 | tr 0.0699583492017 | val 0.125313014686


100%|██████████| 580/580 [05:55<00:00,  1.64it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 321 | tr 0.0699269647485 | val 0.126739203334


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 322 | tr 0.0698834715475 | val 0.126257829964


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.11it/s]


epoch 323 | tr 0.0698637931768 | val 0.12575891763


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 324 | tr 0.0698036185883 | val 0.126274283528


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 325 | tr 0.0697933421181 | val 0.12672375828


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 326 | tr 0.0697546898291 | val 0.125805251598


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 327 | tr 0.0697587345309 | val 0.124800254703


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 328 | tr 0.0696425692403 | val 0.12661898911


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 329 | tr 0.0696347154292 | val 0.125929565728


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 330 | tr 0.0696430529757 | val 0.125633936524


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 331 | tr 0.0695995225603 | val 0.125830941796


100%|██████████| 580/580 [05:54<00:00,  1.64it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 332 | tr 0.0695593301592 | val 0.125192117989


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.11it/s]


epoch 333 | tr 0.0695437350664 | val 0.126333985925


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 334 | tr 0.0695080683149 | val 0.125679027736


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 335 | tr 0.0694728460934 | val 0.125698602796


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 336 | tr 0.0694221386429 | val 0.125385129452


100%|██████████| 580/580 [05:54<00:00,  1.64it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 337 | tr 0.0694406057098 | val 0.126262603402


100%|██████████| 580/580 [05:55<00:00,  1.64it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 338 | tr 0.0693960348594 | val 0.126050354838


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 339 | tr 0.0693816938727 | val 0.126323172152


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 340 | tr 0.0693305124754 | val 0.126411953568


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 341 | tr 0.0693239208696 | val 0.125804611146


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 342 | tr 0.0692809742201 | val 0.125997051299


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 343 | tr 0.0692394615504 | val 0.126697075963


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 344 | tr 0.0692453726869 | val 0.126810880303


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.11it/s]


epoch 345 | tr 0.0691945655849 | val 0.126045403183


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 346 | tr 0.0691668771464 | val 0.125955170095


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 347 | tr 0.0691024605956 | val 0.12588629663


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 348 | tr 0.0691193533869 | val 0.127104579508


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 349 | tr 0.0690443966152 | val 0.126450093389


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 350 | tr 0.069027043712 | val 0.127187848389


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 351 | tr 0.069013124653 | val 0.126279476881


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 352 | tr 0.0690317950244 | val 0.126480930448


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.11it/s]


epoch 353 | tr 0.0690345647905 | val 0.126283698678


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.11it/s]


epoch 354 | tr 0.0689865933539 | val 0.127503915727


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 355 | tr 0.0689316942517 | val 0.126733381152


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 356 | tr 0.0688434758577 | val 0.127076848149


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 357 | tr 0.0688898181671 | val 0.126800936162


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 358 | tr 0.0688960872205 | val 0.126666752398


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 359 | tr 0.0688140448954 | val 0.126142551005


100%|██████████| 580/580 [05:55<00:00,  1.62it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 360 | tr 0.0688281873948 | val 0.126671007872


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 361 | tr 0.0688350415756 | val 0.127262001336


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 362 | tr 0.0687573303115 | val 0.126174449027


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 363 | tr 0.0687404742143 | val 0.127038908303


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 364 | tr 0.0687065460944 | val 0.126997961104


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.09it/s]


epoch 365 | tr 0.0687047539086 | val 0.125871758759


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 366 | tr 0.0686947175526 | val 0.126495281458


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 367 | tr 0.0686546038194 | val 0.126218557954


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 368 | tr 0.0686531202803 | val 0.125877706707


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 369 | tr 0.0685773194543 | val 0.127534774244


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 370 | tr 0.0686110633201 | val 0.126948643327


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 371 | tr 0.0685607945611 | val 0.126557149291


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 372 | tr 0.0685569384489 | val 0.126178435981


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 373 | tr 0.0685177414839 | val 0.126515457034


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 374 | tr 0.0685279195679 | val 0.126832596362


100%|██████████| 580/580 [05:55<00:00,  1.64it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 375 | tr 0.0685137347029 | val 0.12607257545


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.11it/s]


epoch 376 | tr 0.0684640418032 | val 0.127140611112


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 377 | tr 0.0684257396722 | val 0.126580443382


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 378 | tr 0.0684299375348 | val 0.126510769427


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 379 | tr 0.0684062492167 | val 0.126992024481


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 380 | tr 0.0683664715367 | val 0.126395225823


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 381 | tr 0.0683590652742 | val 0.127085318267


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 382 | tr 0.068329817108 | val 0.1264921242


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 383 | tr 0.0683372716195 | val 0.126560344696


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 384 | tr 0.0682956174273 | val 0.126943388879


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 385 | tr 0.068293615649 | val 0.127532443404


100%|██████████| 580/580 [05:55<00:00,  1.64it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 386 | tr 0.0682777991647 | val 0.126336897016


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 387 | tr 0.0682576247973 | val 0.127297624052


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 388 | tr 0.0682472153599 | val 0.126878826022


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 389 | tr 0.0682406686953 | val 0.126097805798


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 390 | tr 0.0681985943803 | val 0.126939208806


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 391 | tr 0.0681503986297 | val 0.126395356655


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 392 | tr 0.0681873807362 | val 0.126691091657


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 393 | tr 0.0681273298081 | val 0.127273091376


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 394 | tr 0.0681454330875 | val 0.126659329832


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 395 | tr 0.0680958915759 | val 0.12666795969


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 396 | tr 0.0680783032343 | val 0.12701872766


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 397 | tr 0.0680566154099 | val 0.126565287411


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 398 | tr 0.0680607265504 | val 0.126024469435


100%|██████████| 580/580 [05:55<00:00,  1.63it/s]
100%|██████████| 7/7 [00:00<00:00,  8.10it/s]


epoch 399 | tr 0.06799378485 | val 0.126273058951
~FIN~


lambda_var=1e-5 | epoch 399 | tr 0.06799378485 | val 0.126273058951

In [None]:
# Sanity check: report which device the DataParallel-wrapped model's decoder
# latent projection weight currently lives on.
dec_latent_weight = model.module.fc_latent_dec.weight
dec_latent_weight.device

In [9]:
if args['mode'] == 'test':  # lambda_var = 1e-5
    # Restore the saved autoencoder weights (mapped onto the active device).
    ckpt_path = os.path.join(checkpoint_path, args['checkpoint_file'] + '.pth.tar')
    print('loading checkpoint from file %s' % (ckpt_path))
    checkpoint_dict = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint_dict['autoencoder_state_dict'])

    # Evaluate on the held-out test loader. worst_path / worst_face_num are
    # forwarded to the helper (presumably to dump the worst reconstructions).
    # NOTE(review): mm_constant=100 is assumed to scale data units to mm for the
    # reported euclidean distance; confirm in test_funcs.
    predictions, norm_l1_loss, l2_loss = test_variational_autoencoder_dataloader(
        device, model, dataloader_test, shapedata,
        worst_path=worst_mesh_path,
        worst_face_num=args['worst_face_num'],
        mm_constant=100,
    )
    np.save(os.path.join(prediction_path, 'predictions'), predictions)

    print('autoencoder: normalized loss', norm_l1_loss)
    print('autoencoder: euclidean distance in mm=', l2_loss)

loading checkpoint from file /home/jingwang/Data/data/FaceWarehouse/results/spirals_variational_autoencoder/latent_128/checkpoints/real_variational/checkpoint.pth.tar


100%|██████████| 42/42 [00:08<00:00,  7.63it/s]


('loss for ', 402, tensor(3.7335, device='cuda:0'))
('loss for ', 169, tensor(3.7734, device='cuda:0'))
('loss for ', 364, tensor(3.7402, device='cuda:0'))
('loss for ', 537, tensor(3.7736, device='cuda:0'))
('loss for ', 370, tensor(3.8049, device='cuda:0'))
('loss for ', 612, tensor(3.7892, device='cuda:0'))
('loss for ', 194, tensor(3.8605, device='cuda:0'))
('loss for ', 499, tensor(3.8726, device='cuda:0'))
('loss for ', 8, tensor(3.8141, device='cuda:0'))
('loss for ', 628, tensor(3.8478, device='cuda:0'))
('loss for ', 389, tensor(3.9959, device='cuda:0'))
('loss for ', 50, tensor(3.8525, device='cuda:0'))
('loss for ', 560, tensor(3.8142, device='cuda:0'))
('loss for ', 583, tensor(3.8838, device='cuda:0'))
('loss for ', 125, tensor(3.8742, device='cuda:0'))
('loss for ', 341, tensor(3.8894, device='cuda:0'))
('loss for ', 602, tensor(3.9182, device='cuda:0'))
('loss for ', 626, tensor(3.9905, device='cuda:0'))
('loss for ', 559, tensor(3.8887, device='cuda:0'))
('loss for ', 3

In [8]:
if args['mode'] == 'test':
    # Load the checkpoint once the path has been resolved a single time.
    ckpt_path = os.path.join(checkpoint_path, args['checkpoint_file'] + '.pth.tar')
    print('loading checkpoint from file %s' % (ckpt_path))
    checkpoint_dict = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint_dict['autoencoder_state_dict'])

    # Run the evaluation helper over the test loader and persist its predictions.
    # NOTE(review): mm_constant=100 is assumed to convert to millimetres for the
    # reported euclidean distance; confirm in test_funcs.
    predictions, norm_l1_loss, l2_loss = test_variational_autoencoder_dataloader(
        device, model, dataloader_test, shapedata,
        worst_path=worst_mesh_path,
        worst_face_num=args['worst_face_num'],
        mm_constant=100,
    )
    np.save(os.path.join(prediction_path, 'predictions'), predictions)

    print('autoencoder: normalized loss', norm_l1_loss)
    print('autoencoder: euclidean distance in mm=', l2_loss)

loading checkpoint from file /home/jingwang/Data/data/FaceWarehouse/results/spirals_variational_autoencoder/latent_128/checkpoints/real_variational/checkpoint.pth.tar


100%|██████████| 42/42 [00:08<00:00,  7.62it/s]


('loss for ', 560, tensor(3.3856, device='cuda:0'))
('loss for ', 169, tensor(3.3871, device='cuda:0'))
('loss for ', 499, tensor(3.3962, device='cuda:0'))
('loss for ', 402, tensor(3.4166, device='cuda:0'))
('loss for ', 537, tensor(3.3890, device='cuda:0'))
('loss for ', 364, tensor(3.4537, device='cuda:0'))
('loss for ', 612, tensor(3.4090, device='cuda:0'))
('loss for ', 370, tensor(3.4655, device='cuda:0'))
('loss for ', 390, tensor(3.4357, device='cuda:0'))
('loss for ', 389, tensor(3.5490, device='cuda:0'))
('loss for ', 628, tensor(3.4595, device='cuda:0'))
('loss for ', 602, tensor(3.4861, device='cuda:0'))
('loss for ', 583, tensor(3.4888, device='cuda:0'))
('loss for ', 50, tensor(3.4964, device='cuda:0'))
('loss for ', 626, tensor(3.4939, device='cuda:0'))
('loss for ', 341, tensor(3.5481, device='cuda:0'))
('loss for ', 125, tensor(3.5830, device='cuda:0'))
('loss for ', 8, tensor(3.4934, device='cuda:0'))
('loss for ', 194, tensor(3.5734, device='cuda:0'))
('loss for ', 5

In [10]:
if args['mode'] == 'test':  # test with train set
    # Evaluate the loaded checkpoint on the *training* loader (dataloader_train)
    # instead of the test loader, to compare train vs. test reconstruction error.
    ckpt_path = os.path.join(checkpoint_path, args['checkpoint_file'] + '.pth.tar')
    print('loading checkpoint from file %s' % (ckpt_path))
    checkpoint_dict = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint_dict['autoencoder_state_dict'])

    # NOTE(review): worst_path is the same folder the test-set evaluation uses;
    # if the helper writes meshes there, this run may overwrite the test-set
    # worst meshes as well -- confirm in test_funcs.
    predictions, norm_l1_loss, l2_loss = test_variational_autoencoder_dataloader(
        device, model, dataloader_train, shapedata,
        worst_path=worst_mesh_path,
        worst_face_num=args['worst_face_num'],
        mm_constant=100,
    )
    # Save under a distinct name: the test-set cells write 'predictions' (.npy)
    # into the same prediction_path, and reusing that name here silently
    # overwrote the test-set predictions with train-set ones.
    np.save(os.path.join(prediction_path, 'predictions_train'), predictions)

    print('autoencoder: normalized loss', norm_l1_loss)
    print('autoencoder: euclidean distance in mm=', l2_loss)

loading checkpoint from file /home/jingwang/Data/data/FaceWarehouse/results/spirals_variational_autoencoder/latent_128/checkpoints/real_variational/checkpoint.pth.tar


100%|██████████| 580/580 [01:30<00:00,  6.25it/s]


('loss for ', 7306, tensor(0.6339, device='cuda:0'))
('loss for ', 1443, tensor(0.6429, device='cuda:0'))
('loss for ', 5314, tensor(0.6481, device='cuda:0'))
('loss for ', 5167, tensor(0.6456, device='cuda:0'))
('loss for ', 1052, tensor(0.6705, device='cuda:0'))
('loss for ', 1533, tensor(0.6584, device='cuda:0'))
('loss for ', 4689, tensor(0.6725, device='cuda:0'))
('loss for ', 5009, tensor(0.6617, device='cuda:0'))
('loss for ', 7357, tensor(0.6693, device='cuda:0'))
('loss for ', 7875, tensor(0.8278, device='cuda:0'))
('loss for ', 3114, tensor(0.7798, device='cuda:0'))
('loss for ', 3660, tensor(0.8747, device='cuda:0'))
('loss for ', 9272, tensor(0.6907, device='cuda:0'))
('loss for ', 6394, tensor(0.7031, device='cuda:0'))
('loss for ', 8612, tensor(0.6756, device='cuda:0'))
('loss for ', 2876, tensor(0.7130, device='cuda:0'))
('loss for ', 8032, tensor(0.6918, device='cuda:0'))
('loss for ', 737, tensor(0.6941, device='cuda:0'))
('loss for ', 4046, tensor(0.7354, device='cuda

In [12]:
# Look up the dataset entry behind sample index 7357, one of the indices listed
# in the train-set evaluation's per-sample loss printout.
inspected_train_idx = 7357
dataset_train.paths[inspected_train_idx]

'5705'

In [37]:
from pprint import pprint

# Show the network architecture once. `model.modules()` yields the root module
# followed by every submodule, so pprint(list(model.modules())) re-printed each
# sub-tree once per ancestor; the root module's repr already contains the
# complete nested tree.
pprint(model)

[DataParallel(
  (module): SpiralAutoencoder(
    (conv_0): SpiralConv(
      (conv): Linear(in_features=42, out_features=16, bias=True)
      (activation): ELU(alpha=1.0)
    )
    (conv_1): SpiralConv(
      (conv): Linear(in_features=208, out_features=32, bias=True)
      (activation): ELU(alpha=1.0)
    )
    (conv_2): SpiralConv(
      (conv): Linear(in_features=288, out_features=64, bias=True)
      (activation): ELU(alpha=1.0)
    )
    (conv_3): SpiralConv(
      (conv): Linear(in_features=576, out_features=128, bias=True)
      (activation): ELU(alpha=1.0)
    )
    (fc_latent_enc): Linear(in_features=5888, out_features=146, bias=True)
    (fc_latent_dec): Linear(in_features=146, out_features=5888, bias=True)
    (dconv): ModuleList(
      (0): SpiralConv(
        (conv): Linear(in_features=1152, out_features=64, bias=True)
        (activation): ELU(alpha=1.0)
      )
      (1): SpiralConv(
        (conv): Linear(in_features=576, out_features=32, bias=True)
        (activation

In [10]:
# optimize_to_get_result is defined inline in this cell rather than imported via
# `from optimize_to_get_result import optimize_to_get_result`.

import torch
import torch.optim as optim
import torch.nn as nn


def optimize_to_get_result(model, loss_fun, z_dim, device, targets, n_iter=10, output_loss=True):
    """Fit a latent code to `targets` by optimising it through the decoder.

    The encoder (``model.only_encode(True)``) provides the initial latent code.
    That code then becomes the only free variable of an L-BFGS optimisation,
    which minimises ``loss_fun(decoded, targets)`` using the decoder
    (``model.only_decode(True)``). Model weights are kept fixed.

    Args:
        model: autoencoder exposing ``only_encode(flag)`` / ``only_decode(flag)``
            mode switches. It is put into eval mode and left there.
        loss_fun: callable ``(outputs, targets) -> scalar tensor``.
        z_dim: unused; kept for call compatibility.
        device: unused; kept for call compatibility.
        targets: batch of target meshes. Decoder outputs are reshaped to
            ``targets.shape``.
        n_iter: number of ``LBFGS.step`` calls. Each step may evaluate the
            closure several times, so more than ``n_iter`` losses can be printed.
        output_loss: if True, print the loss at every closure evaluation.

    Returns:
        ``(z, outputs)``: the optimised latent code (detached from autograd)
        and the matching decoder output reshaped to ``targets.shape``
        (computed without an autograd graph).

    The mode switches and the parameters' ``requires_grad`` flags are restored
    even if the optimisation raises.
    """
    model.eval()

    # Initial latent code from the encoder. No graph is needed here: the code
    # is re-wrapped as an independent leaf parameter below.
    model.only_encode(True)
    try:
        with torch.no_grad():
            z = model(targets)
    finally:
        # Never leave the model stuck in encode-only mode for later cells.
        model.only_encode(False)

    z_param = nn.Parameter(z)
    optimizer = optim.LBFGS(params=[z_param])

    def closure():
        optimizer.zero_grad()
        # Decoder output layout is not guaranteed; match the target layout.
        outputs = model(z_param).reshape(targets.shape)
        loss = loss_fun(outputs, targets)
        loss.backward()
        if output_loss:
            # One formatted argument prints the same under Python 2 and 3
            # (the two-argument form printed a tuple under Python 2).
            print('loss  {}'.format(loss.item()))
        return loss

    # Freeze the weights while optimising z. Otherwise every closure call
    # accumulates gradients into the model's .grad buffers, and
    # optimizer.zero_grad() (which only covers z_param) never clears them.
    params = list(model.parameters())
    saved_requires_grad = [p.requires_grad for p in params]
    model.only_decode(True)
    try:
        for p in params:
            p.requires_grad_(False)
        for _ in range(n_iter):
            optimizer.step(closure)
        with torch.no_grad():
            outputs = model(z_param).reshape(targets.shape)
    finally:
        for p, flag in zip(params, saved_requires_grad):
            p.requires_grad_(flag)
        model.only_decode(False)

    return z_param.detach(), outputs

def work_optimize(test_ids=None, n_iter=10):
    """Refine latent codes for selected test meshes, report errors and save meshes.

    Uses the notebook globals `dataset_test`, `device`, `shapedata`, `model`,
    `loss_fn`, `args` and `worst_mesh_path`.

    Args:
        test_ids: indices into `dataset_test`. Defaults to the hand-picked list
            of 20 test meshes used in the original experiment.
        n_iter: number of L-BFGS steps, forwarded to `optimize_to_get_result`.

    Side effects:
        Prints the batch shape and the optimised latent codes. For every mesh,
        prints the mean per-vertex Euclidean distance between the
        de-normalised input and its reconstruction. Passes the normalised
        inputs and outputs to `shapedata.save_meshes` under
        `worst_mesh_path/opt_input` and `worst_mesh_path/opt_output`.

    Raises:
        ValueError: if `test_ids` is empty.
    """
    from pprint import pprint

    if test_ids is None:
        test_ids = [560, 169, 499, 402, 537, 364, 612, 370, 390, 389,
                    628, 602, 583, 50, 626, 341, 125, 8, 194, 559]
    test_ids = list(test_ids)
    if not test_ids:
        raise ValueError('work_optimize: test_ids must not be empty')

    # All selected meshes as one batch (leading dim = len(test_ids)).
    inputs = torch.stack([dataset_test[idx]['points'] for idx in test_ids]).to(device)
    print(inputs.shape)  # FIXME: original marker, reason not recorded

    shapedata_mean = torch.Tensor(shapedata.mean).to(device)
    shapedata_std = torch.Tensor(shapedata.std).to(device)

    z, outputs = optimize_to_get_result(model, loss_fn, args['nz'], device, inputs,
                                        n_iter=n_iter)
    pprint(z)

    # Evaluation only: no autograd graph is needed from here on.
    with torch.no_grad():
        # Drop the last vertex of every mesh before comparing. It is presumably
        # a padding/dummy vertex added by the dataset. TODO confirm.
        outputs_norm = outputs[:, :-1].detach()
        inputs_norm = inputs[:, :-1]

        # Undo the dataset normalisation so distances are in data units.
        # NOTE(review): the evaluation call passes mm_constant=100 to report mm;
        # confirm whether the same scale should be applied here.
        outputs_denorm = outputs_norm * shapedata_std + shapedata_mean
        inputs_denorm = inputs_norm * shapedata_std + shapedata_mean

        # Per-vertex Euclidean distance over the last axis, one row per mesh.
        per_l2_loss = torch.sqrt(torch.sum((outputs_denorm - inputs_denorm) ** 2, dim=2))

    for idx, vertex_dists in zip(test_ids, per_l2_loss):
        # One formatted argument prints the same under Python 2 and 3.
        print('{} {}'.format(idx, torch.mean(vertex_dists)))

    shapedata.save_meshes(os.path.join(worst_mesh_path, 'opt_input'),
                          inputs_norm.detach().cpu().numpy(), test_ids)
    shapedata.save_meshes(os.path.join(worst_mesh_path, 'opt_output'),
                          outputs_norm.cpu().numpy(), test_ids)

work_optimize()

torch.Size([20, 11511, 3])
('loss ', 0.4986802637577057)
('loss ', 0.49858540296554565)
('loss ', 0.5245462656021118)
('loss ', 0.42691150307655334)
('loss ', 0.43026289343833923)
('loss ', 0.4063626825809479)
('loss ', 0.40158355236053467)
('loss ', 0.39085909724235535)
('loss ', 0.3910890221595764)
('loss ', 0.38281163573265076)
('loss ', 0.38037484884262085)
('loss ', 0.3779156506061554)
('loss ', 0.375139057636261)
('loss ', 0.37313753366470337)
('loss ', 0.36904338002204895)
('loss ', 0.3667895197868347)
('loss ', 0.36510035395622253)
('loss ', 0.3637033998966217)
('loss ', 0.3623831570148468)
('loss ', 0.36083289980888367)
('loss ', 0.35960763692855835)
('loss ', 0.35859280824661255)
('loss ', 0.357584148645401)
('loss ', 0.35685405135154724)
('loss ', 0.35606375336647034)
('loss ', 0.3552369475364685)
('loss ', 0.35456404089927673)
('loss ', 0.35376542806625366)
('loss ', 0.3530873954296112)
('loss ', 0.35253989696502686)
('loss ', 0.35190171003341675)
('loss ', 0.35117784142494

In [10]:
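# optimize_to_get_result is defined inline in this cell rather than imported via
# `from optimize_to_get_result import optimize_to_get_result`.
# NOTE(review): `lambda_var = 1e-5` was left here commented out and is not
# referenced anywhere in this cell.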

import torch
import torch.optim as optim
import torch.nn as nn


def optimize_to_get_result(model, loss_fun, z_dim, device, targets, n_iter=10, output_loss=True):
    """Fit a latent code to `targets` by optimising it through the decoder.

    The encoder (``model.only_encode(True)``) provides the initial latent code.
    That code then becomes the only free variable of an L-BFGS optimisation,
    which minimises ``loss_fun(decoded, targets)`` using the decoder
    (``model.only_decode(True)``). Model weights are kept fixed.

    Args:
        model: autoencoder exposing ``only_encode(flag)`` / ``only_decode(flag)``
            mode switches. It is put into eval mode and left there.
        loss_fun: callable ``(outputs, targets) -> scalar tensor``.
        z_dim: unused; kept for call compatibility.
        device: unused; kept for call compatibility.
        targets: batch of target meshes. Decoder outputs are reshaped to
            ``targets.shape``.
        n_iter: number of ``LBFGS.step`` calls. Each step may evaluate the
            closure several times, so more than ``n_iter`` losses can be printed.
        output_loss: if True, print the loss at every closure evaluation.

    Returns:
        ``(z, outputs)``: the optimised latent code (detached from autograd)
        and the matching decoder output reshaped to ``targets.shape``
        (computed without an autograd graph).

    The mode switches and the parameters' ``requires_grad`` flags are restored
    even if the optimisation raises.
    """
    model.eval()

    # Initial latent code from the encoder. No graph is needed here: the code
    # is re-wrapped as an independent leaf parameter below.
    model.only_encode(True)
    try:
        with torch.no_grad():
            z = model(targets)
    finally:
        # Never leave the model stuck in encode-only mode for later cells.
        model.only_encode(False)

    z_param = nn.Parameter(z)
    optimizer = optim.LBFGS(params=[z_param])

    def closure():
        optimizer.zero_grad()
        # Decoder output layout is not guaranteed; match the target layout.
        outputs = model(z_param).reshape(targets.shape)
        loss = loss_fun(outputs, targets)
        loss.backward()
        if output_loss:
            # One formatted argument prints the same under Python 2 and 3
            # (the two-argument form printed a tuple under Python 2).
            print('loss  {}'.format(loss.item()))
        return loss

    # Freeze the weights while optimising z. Otherwise every closure call
    # accumulates gradients into the model's .grad buffers, and
    # optimizer.zero_grad() (which only covers z_param) never clears them.
    params = list(model.parameters())
    saved_requires_grad = [p.requires_grad for p in params]
    model.only_decode(True)
    try:
        for p in params:
            p.requires_grad_(False)
        for _ in range(n_iter):
            optimizer.step(closure)
        with torch.no_grad():
            outputs = model(z_param).reshape(targets.shape)
    finally:
        for p, flag in zip(params, saved_requires_grad):
            p.requires_grad_(flag)
        model.only_decode(False)

    return z_param.detach(), outputs

def work_optimize(test_ids=None, n_iter=10):
    """Refine latent codes for selected test meshes, report errors and save meshes.

    Uses the notebook globals `dataset_test`, `device`, `shapedata`, `model`,
    `loss_fn`, `args` and `worst_mesh_path`.

    Args:
        test_ids: indices into `dataset_test`. Defaults to the hand-picked list
            of 20 test meshes used in the original experiment.
        n_iter: number of L-BFGS steps, forwarded to `optimize_to_get_result`.

    Side effects:
        Prints the batch shape and the optimised latent codes. For every mesh,
        prints the mean per-vertex Euclidean distance between the
        de-normalised input and its reconstruction. Passes the normalised
        inputs and outputs to `shapedata.save_meshes` under
        `worst_mesh_path/opt_input` and `worst_mesh_path/opt_output`.

    Raises:
        ValueError: if `test_ids` is empty.
    """
    from pprint import pprint

    if test_ids is None:
        test_ids = [560, 169, 499, 402, 537, 364, 612, 370, 390, 389,
                    628, 602, 583, 50, 626, 341, 125, 8, 194, 559]
    test_ids = list(test_ids)
    if not test_ids:
        raise ValueError('work_optimize: test_ids must not be empty')

    # All selected meshes as one batch (leading dim = len(test_ids)).
    inputs = torch.stack([dataset_test[idx]['points'] for idx in test_ids]).to(device)
    print(inputs.shape)  # FIXME: original marker, reason not recorded

    shapedata_mean = torch.Tensor(shapedata.mean).to(device)
    shapedata_std = torch.Tensor(shapedata.std).to(device)

    z, outputs = optimize_to_get_result(model, loss_fn, args['nz'], device, inputs,
                                        n_iter=n_iter)
    pprint(z)

    # Evaluation only: no autograd graph is needed from here on.
    with torch.no_grad():
        # Drop the last vertex of every mesh before comparing. It is presumably
        # a padding/dummy vertex added by the dataset. TODO confirm.
        outputs_norm = outputs[:, :-1].detach()
        inputs_norm = inputs[:, :-1]

        # Undo the dataset normalisation so distances are in data units.
        # NOTE(review): the evaluation call passes mm_constant=100 to report mm;
        # confirm whether the same scale should be applied here.
        outputs_denorm = outputs_norm * shapedata_std + shapedata_mean
        inputs_denorm = inputs_norm * shapedata_std + shapedata_mean

        # Per-vertex Euclidean distance over the last axis, one row per mesh.
        per_l2_loss = torch.sqrt(torch.sum((outputs_denorm - inputs_denorm) ** 2, dim=2))

    for idx, vertex_dists in zip(test_ids, per_l2_loss):
        # One formatted argument prints the same under Python 2 and 3.
        print('{} {}'.format(idx, torch.mean(vertex_dists)))

    shapedata.save_meshes(os.path.join(worst_mesh_path, 'opt_input'),
                          inputs_norm.detach().cpu().numpy(), test_ids)
    shapedata.save_meshes(os.path.join(worst_mesh_path, 'opt_output'),
                          outputs_norm.cpu().numpy(), test_ids)

work_optimize()

torch.Size([20, 11511, 3])
('loss ', 0.5567759275436401)
('loss ', 0.5563845634460449)
('loss ', 0.5783909559249878)
('loss ', 0.496934175491333)
('loss ', 0.48533686995506287)
('loss ', 0.47421765327453613)
('loss ', 0.47161945700645447)
('loss ', 0.46805039048194885)
('loss ', 0.4635757803916931)
('loss ', 0.4615626633167267)
('loss ', 0.4599860608577728)
('loss ', 0.4563556909561157)
('loss ', 0.45438069105148315)
('loss ', 0.4522996246814728)
('loss ', 0.4508543014526367)
('loss ', 0.44892674684524536)
('loss ', 0.4474303722381592)
('loss ', 0.4460495412349701)
('loss ', 0.44529640674591064)
('loss ', 0.44481122493743896)
('loss ', 0.44385385513305664)
('loss ', 0.4430656433105469)
('loss ', 0.44228968024253845)
('loss ', 0.4414037764072418)
('loss ', 0.44081467390060425)
('loss ', 0.4402202367782593)
('loss ', 0.4397233724594116)
('loss ', 0.4393344819545746)
('loss ', 0.43894192576408386)
('loss ', 0.4384746253490448)
('loss ', 0.4380544126033783)
('loss ', 0.43765899538993835)
(