In [12]:
import numpy as np
import json
import os
import copy
import pickle

import mesh_sampling
import trimesh
from shape_data import ShapeData

from autoencoder_dataset import autoencoder_dataset
from torch.utils.data import DataLoader

from spiral_utils import get_adj_trigs, generate_spirals
from models import SpiralAutoencoder
from train_funcs import train_autoencoder_dataloader
from test_funcs import test_autoencoder_dataloader


import torch
from tensorboardX import SummaryWriter

from sklearn.metrics.pairwise import euclidean_distances
# ---- Global run configuration ----
# Mesh backend passed to ShapeData and the spiral utilities: 'mpi-mesh' or 'trimesh'.
meshpackage = 'mpi-mesh'
# NOTE(review): absolute, user-specific data root -- adjust for your machine.
root_dir = '/home/jingwang/Data/data/'

dataset = 'FaceWarehouse'
name = 'align_pose'

# Enable the cuDNN benchmark mode.
torch.backends.cudnn.benchmark = True


GPU = True
device_idx = 0 # index of the primary CUDA device: 0, 1, 2, 3
# Keep only the requested GPU indices that actually exist on this host. On the
# original 4-GPU machine this is still [0, 1, 2, 3]; on smaller or CPU-only
# machines the cell no longer crashes in get_device_name().
device_ids = [idx for idx in [0, 1, 2, 3] if idx < torch.cuda.device_count()]
for idx in device_ids:
    print(torch.cuda.get_device_name(idx))

GeForce GTX TITAN X
GeForce GTX TITAN X
GeForce GTX TITAN X
GeForce GTX TITAN X


In [13]:
# ---- Experiment configuration ----
generative_model = 'simple_autoencoder'
downsample_method = 'COMA_downsample' # choose 'COMA_downsample' or 'meshlab_downsample'


# Template mesh and mesh-hierarchy settings for the dataset selected above.
# NOTE(review): an earlier comment here said "arguments for the DFAUST run",
# but `dataset` is 'FaceWarehouse' -- confirm these values are the intended ones.
reference_mesh_file = os.path.join(root_dir, dataset, 'template', 'template.obj')
downsample_directory = os.path.join(root_dir, dataset,'template', downsample_method)
ds_factors = [4]
step_sizes = [1, 1]
filter_sizes_enc = [[3, 32],[[],[]]]
filter_sizes_dec = [[32, 16],[[],3]]
dilation_flag = True
if dilation_flag:
    dilation = [1, 1]
else:
    dilation = None
# Seed vertex index (on the full-resolution template) used as the spiral
# reference point; the equivalents on coarser levels are appended later.
# Previous values kept for reference: [[3567,4051,4597]], [[414]]
# (original note: "used for COMA with 3 disconnected components").
reference_points = [[5930]]

args = {'generative_model': generative_model,
        'name': name, 'data': os.path.join(root_dir, dataset, 'preprocessed',name),
        'results_folder':  os.path.join(root_dir, dataset,'results/spirals_'+ generative_model),
        'reference_mesh_file':reference_mesh_file, 'downsample_directory': downsample_directory,
        'checkpoint_file': 'checkpoint',
        'seed':2, 'loss':'l1',
        'batch_size': 96, 'num_epochs':300, 'eval_frequency':200, 'num_workers': 40,
        'filter_sizes_enc': filter_sizes_enc, 'filter_sizes_dec': filter_sizes_dec,
        # Latent size. NOTE(review): the previous comment read
        # "100 identity + 46 expression" (= 146), which does not match 128.
        'nz': 128,
        'ds_factors': ds_factors, 'step_sizes' : step_sizes, 'dilation': dilation,

        'lr':1e-3,
        'regularization': 5e-5,
        'scheduler': True, 'decay_rate': 0.99,'decay_steps':1,
        'resume': False,

        'mode':'train', 'shuffle': True, 'nVal': 100, 'normalization': True,
        'write_mesh': True,
        'lambda_var': 0.5,
        'worst_face_num': 20}

# Results are grouped per latent size, e.g. .../latent_128.
args['results_folder'] = os.path.join(args['results_folder'],'latent_'+str(args['nz']))


def _ensure_dir(path):
    """Create `path` (including parents) if it is missing and return it."""
    if not os.path.exists(path):
        os.makedirs(path)
    return path


_ensure_dir(args['results_folder'])
summary_path = _ensure_dir(os.path.join(args['results_folder'],'summaries',args['name']))
checkpoint_path = _ensure_dir(os.path.join(args['results_folder'],'checkpoints', args['name']))
samples_path = _ensure_dir(os.path.join(args['results_folder'],'samples', args['name']))
prediction_path = _ensure_dir(os.path.join(args['results_folder'],'predictions', args['name']))
_ensure_dir(downsample_directory)
downsample_mesh_path = _ensure_dir(os.path.join(args['results_folder'],'downsample_mesh', args['name']))
# The original cell re-checked downsample_mesh_path here, so the worst_test
# directory was never created.
worst_mesh_path = _ensure_dir(os.path.join(args['results_folder'],'worst_test', args['name']))

In [14]:
np.random.seed(args['seed'])
print("Loading data .. ")
# mean/std are cached as .npy files next to the preprocessed data. When both
# files exist, ShapeData is built with load_flag=False and the statistics are
# read from disk; otherwise it is built with load_flag=True and the statistics
# it exposes are written to the cache.
# NOTE(review): presumably load_flag=True makes ShapeData read the training data
# and compute mean/std -- confirm in shape_data.py.
mean_file = args['data']+'/mean.npy'
std_file = args['data']+'/std.npy'
stats_cached = os.path.exists(mean_file) and os.path.exists(std_file)
shapedata = ShapeData(nVal=args['nVal'],
                      train_file=args['data']+'/train.npy',
                      test_file=args['data']+'/test.npy',
                      reference_mesh_file=args['reference_mesh_file'],
                      normalization = args['normalization'],
                      meshpackage = meshpackage, load_flag = not stats_cached)
if stats_cached:
    shapedata.mean = np.load(mean_file)
    shapedata.std = np.load(std_file)
    shapedata.n_vertex = shapedata.mean.shape[0]
    shapedata.n_features = shapedata.mean.shape[1]
else:
    np.save(mean_file, shapedata.mean)
    np.save(std_file, shapedata.std)

matrices_file = os.path.join(args['downsample_directory'],'downsampling_matrices.pkl')
if not os.path.exists(matrices_file):
    if shapedata.meshpackage == 'trimesh':
        raise NotImplementedError('Rerun with mpi-mesh as meshpackage')
    if downsample_method != 'COMA_downsample':
        # Only the COMA path can generate the matrices here. Previously any
        # other method fell through to a NameError on `M` while pickling.
        raise NotImplementedError(
            "Cannot generate transform matrices for downsample_method=%r; "
            "provide a precomputed %s" % (downsample_method, matrices_file))
    print("Generating Transform Matrices ..")
    M,A,D,U,F = mesh_sampling.generate_transform_matrices(shapedata.reference_mesh, args['ds_factors'])
    if args['write_mesh']:
        # Imported lazily: openmesh is only needed when dumping the hierarchy.
        import openmesh
        for i in range(len(M)):
            mesh = openmesh.TriMesh(points=M[i].v,face_vertex_indices=M[i].f)
            openmesh.write_mesh(os.path.join(args['results_folder'],'downsample_mesh',args['name'], '%d.obj'%i),mesh)
    with open(matrices_file, 'wb') as fp:
        # Store plain (vertices, faces) pairs rather than mesh objects.
        M_verts_faces = [(M[i].v, M[i].f) for i in range(len(M))]
        pickle.dump({'M_verts_faces':M_verts_faces,'A':A,'D':D,'U':U,'F':F}, fp)
else:
    print("Loading Transform Matrices ..")
    # SECURITY: pickle.load executes arbitrary code; only load caches you created.
    with open(matrices_file, 'rb') as fp:
        #downsampling_matrices = pickle.load(fp,encoding = 'latin1')
        downsampling_matrices = pickle.load(fp)

    M_verts_faces = downsampling_matrices['M_verts_faces']
    if shapedata.meshpackage == 'mpi-mesh':
        from psbody.mesh import Mesh
        M = [Mesh(v=M_verts_faces[i][0], f=M_verts_faces[i][1]) for i in range(len(M_verts_faces))]
    elif shapedata.meshpackage == 'trimesh':
        M = [trimesh.base.Trimesh(vertices=M_verts_faces[i][0], faces=M_verts_faces[i][1], process = False) for i in range(len(M_verts_faces))]
    A = downsampling_matrices['A']
    D = downsampling_matrices['D']
    U = downsampling_matrices['U']
    F = downsampling_matrices['F']

# Needs also an extra check to enforce points to belong to different disconnected component at each hierarchy level
print("Calculating reference points for downsampled versions..")
# Drop entries appended by a previous execution of this cell so that re-running
# it does not keep growing the list (no-op on the first run).
del reference_points[1:]
for i in range(len(args['ds_factors'])):
    # For each coarser level, pick the vertex closest to the level-0 reference point(s).
    if shapedata.meshpackage == 'mpi-mesh':
        dist = euclidean_distances(M[i+1].v, M[0].v[reference_points[0]])
    elif shapedata.meshpackage == 'trimesh':
        dist = euclidean_distances(M[i+1].vertices, M[0].vertices[reference_points[0]])
    reference_points.append(np.argmin(dist,axis=0).tolist())



Loading data .. 
Loading Transform Matrices ..
Calculating reference points for downsampled versions..


In [15]:
# Number of vertices at every level of the mesh hierarchy.
if shapedata.meshpackage == 'mpi-mesh':
    sizes = [level_mesh.v.shape[0] for level_mesh in M]
elif shapedata.meshpackage == 'trimesh':
    sizes = [level_mesh.vertices.shape[0] for level_mesh in M]

Adj, Trigs = get_adj_trigs(A, F, shapedata.reference_mesh, meshpackage = shapedata.meshpackage)

spirals_np, spiral_sizes,spirals = generate_spirals(
    args['step_sizes'], M, Adj, Trigs,
    reference_points = reference_points,
    dilation = args['dilation'],
    random = False,
    meshpackage = shapedata.meshpackage,
    counter_clockwise = True)


def _pad_transform(sparse_mat):
    """Densify `sparse_mat` into a (1, rows+1, cols+1) array with a 1 in the extra corner."""
    n_rows, n_cols = sparse_mat.shape
    padded = np.zeros((1, n_rows + 1, n_cols + 1))
    padded[0, :-1, :-1] = sparse_mat.todense()
    padded[0, -1, -1] = 1
    return padded


# One padded down-/up-sampling matrix per entry of D (U is indexed in lockstep).
# NOTE(review): the extra row/column presumably corresponds to a padding vertex
# used by the spiral indices -- confirm in spiral_utils / models.
bD = [_pad_transform(D[level]) for level in range(len(D))]
bU = [_pad_transform(U[level]) for level in range(len(D))]


spiral generation for hierarchy 0 (11510 vertices) finished
spiral generation for hierarchy 1 (2878 vertices) finished
spiral sizes for hierarchy 0:  10
spiral sizes for hierarchy 1:  9


In [16]:
torch.manual_seed(args['seed'])

# Use the configured CUDA device only when GPU use is requested and CUDA is
# actually available; otherwise run on the CPU.
device = torch.device("cuda:"+str(device_idx) if (GPU and torch.cuda.is_available()) else "cpu")
print(device)

# Move the spiral indices (as int64) and the padded transforms (as float32) to the device.
tspirals = [torch.from_numpy(spiral_idx).long().to(device) for spiral_idx in spirals_np]
tD = [torch.from_numpy(down_mat).float().to(device) for down_mat in bD]
tU = [torch.from_numpy(up_mat).float().to(device) for up_mat in bU]

cuda:0


In [17]:
# Sanity check: shape of the spiral index tensor for hierarchy level 0.
tspirals[0].shape

torch.Size([1, 11511, 10])

In [18]:
# Building model, optimizer, and loss function

def _build_split(split, shuffle):
    """Create the dataset for `split` ('train'/'val'/'test') and its DataLoader."""
    split_dataset = autoencoder_dataset(root_dir = args['data'], points_dataset = split,
                                        shapedata = shapedata,
                                        normalization = args['normalization'])
    split_loader = DataLoader(split_dataset, batch_size=args['batch_size'],
                              shuffle = shuffle, num_workers = args['num_workers'],
                              pin_memory=True)
    return split_dataset, split_loader


# Only the training split follows the configured shuffle flag; validation and
# test data are always iterated in order.
dataset_train, dataloader_train = _build_split('train', args['shuffle'])
dataset_val, dataloader_val = _build_split('val', False)
dataset_test, dataloader_test = _build_split('test', False)



# Build the spiral autoencoder and wrap it in DataParallel over `device_ids`.
if 'autoencoder' in args['generative_model']:
    model = torch.nn.DataParallel(
        SpiralAutoencoder(filters_enc = args['filter_sizes_enc'],
                          filters_dec = args['filter_sizes_dec'],
                          latent_size=args['nz'],
                          sizes=sizes,
                          spiral_sizes=spiral_sizes,
                          spirals=tspirals,
                          D=tD, U=tU).to(device),
        device_ids=device_ids)
    if torch.cuda.device_count() > 1:
        print('Let\'s use %d GPUs!'%torch.cuda.device_count())
else:
    # Previously `model` was simply left undefined here, which surfaced as a
    # NameError on the optimizer line below.
    raise ValueError("Unsupported generative_model: %r" % (args['generative_model'],))


# Adam with weight decay; the learning rate is optionally decayed by
# `decay_rate` every `decay_steps` scheduler steps.
optim = torch.optim.Adam(model.parameters(),lr=args['lr'],weight_decay=args['regularization'])
if args['scheduler']:
    scheduler=torch.optim.lr_scheduler.StepLR(optim, args['decay_steps'],gamma=args['decay_rate'])
else:
    scheduler = None

def loss_l2(outputs, targets):
    """Root-mean-square error over all elements (not selected by this cell)."""
    L = torch.sqrt(torch.mean((outputs - targets)**2))
    return L

# Select the training loss from args['loss'].
if args['loss']=='l1':
    def loss_l1(outputs, targets):
        """Mean absolute error over all elements."""
        L = torch.abs(outputs - targets).mean()
        return L
    loss_fn = loss_l1
elif args['loss']=='l1_var':  # was `arg[...]`, a NameError whenever loss != 'l1'
    lambda_var = args['lambda_var']
    def variational_loss(tx,tx_hat):
        """L1 reconstruction term plus `lambda_var` times the KL-style term.

        The original body unpacked `tx` and then subtracted the tuple itself
        (`x - tx`), which cannot work.
        NOTE(review): assumes `tx` is the target tensor and `tx_hat` is the
        model output tuple (reconstruction, mu, logvar) -- confirm against the
        call site in train_funcs and the variational model's forward().
        """
        x_hat,mu,logvar = tx_hat
        l1_loss = torch.mean(torch.abs(x_hat-tx))
        var_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return l1_loss + lambda_var * var_loss
    loss_fn = variational_loss
else:
    # Fail here instead of leaving `loss_fn` undefined until training starts.
    raise ValueError("Unsupported loss: %r (expected 'l1' or 'l1_var')" % (args['loss'],))



Let's use 4 GPUs!


    There is an imbalance between your GPUs. You may want to exclude GPU 0 which
    has less than 75% of the memory or cores of GPU 1. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable.


In [19]:
# Count the parameters that require gradients, then print the architecture.
params = 0
for weight in model.parameters():
    if weight.requires_grad:
        params += weight.numel()
print("Total number of parameters is: {}".format(params))
print(model)

Total number of parameters is: 23683635
DataParallel(
  (module): SpiralAutoencoder(
    (conv): ModuleList(
      (0): SpiralConv(
        (conv): Linear(in_features=30, out_features=32, bias=True)
        (activation): ELU(alpha=1.0)
      )
    )
    (fc_latent_enc): Linear(in_features=92128, out_features=128, bias=True)
    (fc_latent_dec): Linear(in_features=128, out_features=92128, bias=True)
    (dconv): ModuleList(
      (0): SpiralConv(
        (conv): Linear(in_features=320, out_features=16, bias=True)
        (activation): ELU(alpha=1.0)
      )
      (1): SpiralConv(
        (conv): Linear(in_features=160, out_features=3, bias=True)
      )
    )
  )
)


In [10]:
# Scratch arithmetic from an earlier session (execution count In [10] predates
# the cells above). NOTE(review): 5023 does not match this run's 11510-vertex
# template -- likely stale; confirm or delete.
128*5023

642944

In [11]:
# Scratch arithmetic from an earlier session (execution count In [11]).
# The recorded result 21 is an integer, i.e. `/` performed integer division here.
2688/128

21

In [14]:
if args['mode'] == 'train':
    writer = SummaryWriter(summary_path)
    # Persist the run configuration next to the checkpoints.
    with open(os.path.join(args['results_folder'],'checkpoints', args['name'] +'_params.json'),'w') as fp:
        saveparams = copy.deepcopy(args)
        json.dump(saveparams, fp)

    if args['resume']:
        print('loading checkpoint from file %s'%(os.path.join(checkpoint_path,args['checkpoint_file'])))
        checkpoint_dict = torch.load(os.path.join(checkpoint_path,args['checkpoint_file']+'.pth.tar'),map_location=device)
        start_epoch = checkpoint_dict['epoch'] + 1
        model.load_state_dict(checkpoint_dict['autoencoder_state_dict'])
        optim.load_state_dict(checkpoint_dict['optimizer_state_dict'])
        # `scheduler` is None when args['scheduler'] is False; the original
        # called load_state_dict on it unconditionally.
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint_dict['scheduler_state_dict'])
        print('Resuming from epoch %s'%(str(start_epoch)))
    else:
        start_epoch = 0

    # Use the same substring test as the model-building cell. The former
    # equality check (== 'autoencoder') silently skipped training for the
    # configured 'simple_autoencoder'.
    if 'autoencoder' in args['generative_model']:
        train_autoencoder_dataloader(dataloader_train, dataloader_val,
                          device, model, optim, loss_fn,
                          bsize = args['batch_size'],
                          start_epoch = start_epoch,
                          n_epochs = args['num_epochs'],
                          eval_freq = args['eval_frequency'],
                          scheduler = scheduler,
                          writer = writer,
                          save_recons=True,
                          shapedata=shapedata,
                          metadata_dir=checkpoint_path, samples_dir=samples_path,
                          checkpoint_path = args['checkpoint_file'])

100%|██████████| 97/97 [01:23<00:00,  1.48it/s]
100%|██████████| 2/2 [00:05<00:00,  3.92s/it]


epoch 0 | tr 0.513800279745 | val 0.335121604204


100%|██████████| 97/97 [01:16<00:00,  1.48it/s]
100%|██████████| 2/2 [00:05<00:00,  2.94s/it]


epoch 1 | tr 0.276958546967 | val 0.25284075439


100%|██████████| 97/97 [01:16<00:00,  1.48it/s]
100%|██████████| 2/2 [00:05<00:00,  2.99s/it]


epoch 2 | tr 0.220896459499 | val 0.21656611383


100%|██████████| 97/97 [01:16<00:00,  1.48it/s]
100%|██████████| 2/2 [00:06<00:00,  3.01s/it]


epoch 3 | tr 0.192009133563 | val 0.190629543066


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.01s/it]


epoch 4 | tr 0.174960222738 | val 0.182885987163


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.84s/it]


epoch 5 | tr 0.156118331336 | val 0.167732248902


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  2.99s/it]


epoch 6 | tr 0.152241090248 | val 0.164357124567


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.97s/it]


epoch 7 | tr 0.141347240631 | val 0.156274980307


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.80s/it]


epoch 8 | tr 0.131721924451 | val 0.148227498531


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  2.96s/it]


epoch 9 | tr 0.131757073079 | val 0.138187095523


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.95s/it]


epoch 10 | tr 0.125128592448 | val 0.140960928798


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.94s/it]


epoch 11 | tr 0.121042308905 | val 0.136604729891


100%|██████████| 97/97 [01:17<00:00,  1.48it/s]
100%|██████████| 2/2 [00:05<00:00,  3.84s/it]


epoch 12 | tr 0.118327864388 | val 0.131004008949


100%|██████████| 97/97 [01:16<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  3.81s/it]


epoch 13 | tr 0.112974578322 | val 0.132992919385


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.03s/it]


epoch 14 | tr 0.110354766841 | val 0.128003394008


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.98s/it]


epoch 15 | tr 0.108574256918 | val 0.124137710929


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.79s/it]


epoch 16 | tr 0.104700444125 | val 0.121392237544


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.04s/it]


epoch 17 | tr 0.103483253049 | val 0.1215057832


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.87s/it]


epoch 18 | tr 0.0999110979253 | val 0.115056543648


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.03s/it]


epoch 19 | tr 0.0988929265275 | val 0.117165701091


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.98s/it]


epoch 20 | tr 0.0972873931044 | val 0.115165958405


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.05s/it]


epoch 21 | tr 0.0969804232747 | val 0.110871402621


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  3.85s/it]


epoch 22 | tr 0.0963694453753 | val 0.114491105974


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.91s/it]


epoch 23 | tr 0.0943089206157 | val 0.110424720645


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.94s/it]


epoch 24 | tr 0.0948657818909 | val 0.107327494621


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.94s/it]


epoch 25 | tr 0.0885547685726 | val 0.105570139289


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.95s/it]


epoch 26 | tr 0.0871771808585 | val 0.108905961215


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  3.80s/it]


epoch 27 | tr 0.0885929788238 | val 0.103905022144


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  2.96s/it]


epoch 28 | tr 0.0857047013168 | val 0.111413721144


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.87s/it]


epoch 29 | tr 0.0871363077441 | val 0.112896489799


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.95s/it]


epoch 30 | tr 0.0847681838891 | val 0.103460202217


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  2.90s/it]


epoch 31 | tr 0.0857295431967 | val 0.105345871747


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.07s/it]


epoch 32 | tr 0.0829253701557 | val 0.102092578709


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.81s/it]


epoch 33 | tr 0.0844706362691 | val 0.103589070141


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.01s/it]


epoch 34 | tr 0.0805756643414 | val 0.102762301862


100%|██████████| 97/97 [01:16<00:00,  1.48it/s]
100%|██████████| 2/2 [00:06<00:00,  3.06s/it]


epoch 35 | tr 0.0803270888483 | val 0.0991189852357


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  2.96s/it]


epoch 36 | tr 0.0801903401983 | val 0.108584632277


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.98s/it]


epoch 37 | tr 0.0786487989128 | val 0.0990088689327


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.03s/it]


epoch 38 | tr 0.0763705218147 | val 0.100557996333


100%|██████████| 97/97 [01:17<00:00,  1.48it/s]
100%|██████████| 2/2 [00:05<00:00,  2.82s/it]


epoch 39 | tr 0.0791962770809 | val 0.102594838142


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.93s/it]


epoch 40 | tr 0.0790191060767 | val 0.101214667261


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.95s/it]


epoch 41 | tr 0.0767901282115 | val 0.0973853135109


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  2.96s/it]


epoch 42 | tr 0.0733429383615 | val 0.0959478378296


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.97s/it]


epoch 43 | tr 0.0775003012398 | val 0.100433691442


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  2.94s/it]


epoch 44 | tr 0.0738170081942 | val 0.0972260433435


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.89s/it]


epoch 45 | tr 0.0718915398265 | val 0.0950180581212


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.95s/it]


epoch 46 | tr 0.0734505271089 | val 0.0983893111348


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.01s/it]


epoch 47 | tr 0.0730127130089 | val 0.101156622171


100%|██████████| 97/97 [01:16<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.02s/it]


epoch 48 | tr 0.0724734119043 | val 0.0946831774712


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.02s/it]


epoch 49 | tr 0.0703534423791 | val 0.0950577184558


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.83s/it]


epoch 50 | tr 0.0709805386077 | val 0.0940340796113


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.85s/it]


epoch 51 | tr 0.0723444790933 | val 0.093690263927


100%|██████████| 97/97 [01:17<00:00,  1.48it/s]
100%|██████████| 2/2 [00:05<00:00,  2.99s/it]


epoch 52 | tr 0.0692796472075 | val 0.0966285794973


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.04s/it]


epoch 53 | tr 0.0680940414811 | val 0.0910237535834


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.01s/it]


epoch 54 | tr 0.0683870873831 | val 0.0891500276327


100%|██████████| 97/97 [01:16<00:00,  1.48it/s]
100%|██████████| 2/2 [00:05<00:00,  2.98s/it]


epoch 55 | tr 0.0675116892262 | val 0.0931577268243


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  2.96s/it]


epoch 56 | tr 0.0691985542918 | val 0.0923396411538


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.95s/it]


epoch 57 | tr 0.0665992095296 | val 0.0921356731653


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.99s/it]


epoch 58 | tr 0.068219100427 | val 0.0876381412148


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.90s/it]


epoch 59 | tr 0.0664227854204 | val 0.0931127288938


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  2.95s/it]


epoch 60 | tr 0.0673470342211 | val 0.0881303203106


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.97s/it]


epoch 61 | tr 0.0652612595337 | val 0.0888529220223


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.04s/it]


epoch 62 | tr 0.0648579602868 | val 0.0895755526423


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.97s/it]


epoch 63 | tr 0.0651003413545 | val 0.0871379181743


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.98s/it]


epoch 64 | tr 0.0646103653543 | val 0.0890098166466


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.99s/it]


epoch 65 | tr 0.0645993092451 | val 0.0890401789546


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.95s/it]


epoch 66 | tr 0.0636784036098 | val 0.0914680179954


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.04s/it]


epoch 67 | tr 0.0630666130704 | val 0.0873731401563


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.00s/it]


epoch 68 | tr 0.0632755833196 | val 0.0880865639448


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.82s/it]


epoch 69 | tr 0.0620367746286 | val 0.0912196516991


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.95s/it]


epoch 70 | tr 0.0621185887733 | val 0.0879690054059


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.94s/it]


epoch 71 | tr 0.061710526473 | val 0.086039044261


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  2.98s/it]


epoch 72 | tr 0.0612018877062 | val 0.0883620604873


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  4.43s/it]


epoch 73 | tr 0.0624775114758 | val 0.0863000246882


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.98s/it]


epoch 74 | tr 0.0606706571862 | val 0.0854777187109


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.98s/it]


epoch 75 | tr 0.061238202427 | val 0.0856770351529


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  3.95s/it]


epoch 76 | tr 0.0600160404141 | val 0.0893356016278


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.09s/it]


epoch 77 | tr 0.0606567973721 | val 0.0848695865273


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.04s/it]


epoch 78 | tr 0.0602011043835 | val 0.0833241665363


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.10s/it]


epoch 79 | tr 0.0596473264669 | val 0.0870857480168


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.88s/it]


epoch 80 | tr 0.0605109778972 | val 0.0872258359194


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.82s/it]


epoch 81 | tr 0.0584978110061 | val 0.0853892213106


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.75s/it]


epoch 82 | tr 0.0587184519218 | val 0.0847591269016


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.95s/it]


epoch 83 | tr 0.0592104845134 | val 0.0829868745804


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.05s/it]


epoch 84 | tr 0.0580294880374 | val 0.0848208707571


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.05s/it]


epoch 85 | tr 0.0579493360787 | val 0.0859024205804


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.95s/it]


epoch 86 | tr 0.0573820202644 | val 0.0826702946424


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  2.92s/it]


epoch 87 | tr 0.0579679606932 | val 0.0823867097497


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.95s/it]


epoch 88 | tr 0.0573913198221 | val 0.0849234491587


100%|██████████| 97/97 [01:17<00:00,  1.48it/s]
100%|██████████| 2/2 [00:05<00:00,  2.96s/it]


epoch 89 | tr 0.0576112286027 | val 0.0840978959203


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.06s/it]


epoch 90 | tr 0.0571090320318 | val 0.0834539216757


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.00s/it]


epoch 91 | tr 0.0561830147202 | val 0.0835397914052


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.05s/it]


epoch 92 | tr 0.0567982992222 | val 0.085640167594


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.06s/it]


epoch 93 | tr 0.0557449756254 | val 0.081491432786


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.24s/it]


epoch 94 | tr 0.0564371332012 | val 0.0839791631699


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  2.88s/it]


epoch 95 | tr 0.0557472694794 | val 0.0837182462215


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.96s/it]


epoch 96 | tr 0.0550871528942 | val 0.082709518671


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.02s/it]


epoch 97 | tr 0.0545989758881 | val 0.0836737778783


100%|██████████| 97/97 [01:17<00:00,  1.44it/s]
100%|██████████| 2/2 [00:06<00:00,  3.05s/it]


epoch 98 | tr 0.0551566685197 | val 0.0833191302419


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.75s/it]


epoch 99 | tr 0.0546748550673 | val 0.0838296976686


100%|██████████| 97/97 [01:16<00:00,  1.48it/s]
100%|██████████| 2/2 [00:06<00:00,  3.02s/it]


epoch 100 | tr 0.0544670195158 | val 0.0818810129166


100%|██████████| 97/97 [01:16<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.02s/it]


epoch 101 | tr 0.055506487666 | val 0.081730542779


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.98s/it]


epoch 102 | tr 0.0539771685955 | val 0.0822180664539


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  3.85s/it]


epoch 103 | tr 0.0537900714011 | val 0.0819152125716


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  4.06s/it]


epoch 104 | tr 0.0541855352064 | val 0.0843082004786


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.98s/it]


epoch 105 | tr 0.0544590070715 | val 0.0847873979807


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.99s/it]


epoch 106 | tr 0.0543028850879 | val 0.0837322482467


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.96s/it]


epoch 107 | tr 0.0537593138244 | val 0.0820875179768


100%|██████████| 97/97 [01:17<00:00,  1.45it/s]
100%|██████████| 2/2 [00:05<00:00,  2.98s/it]


epoch 108 | tr 0.0528712451972 | val 0.0814978203177


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  2.92s/it]


epoch 109 | tr 0.0524014643169 | val 0.0818753686547


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.84s/it]


epoch 110 | tr 0.0531212457295 | val 0.0817824172974


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.03s/it]


epoch 111 | tr 0.05253004925 | val 0.0819513863325


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.90s/it]


epoch 112 | tr 0.0528285017173 | val 0.0809443834424


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.00s/it]


epoch 113 | tr 0.0525773321503 | val 0.0815666034818


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  3.88s/it]


epoch 114 | tr 0.0533221821461 | val 0.0827595576644


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.01s/it]


epoch 115 | tr 0.0523477437424 | val 0.0815759733319


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.98s/it]


epoch 116 | tr 0.0516530590325 | val 0.0821881455183


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.04s/it]


epoch 117 | tr 0.0518682603435 | val 0.0819729378819


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.96s/it]


epoch 118 | tr 0.0518247302879 | val 0.0822987809777


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  2.98s/it]


epoch 119 | tr 0.0510227641283 | val 0.0823409950733


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.96s/it]


epoch 120 | tr 0.0507352735076 | val 0.0819452899694


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.93s/it]


epoch 121 | tr 0.0504376232753 | val 0.0824178293347


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.94s/it]


epoch 122 | tr 0.0508318605608 | val 0.0799914494157


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.97s/it]


epoch 123 | tr 0.0515717535055 | val 0.0809717571735


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.89s/it]


epoch 124 | tr 0.0509327776859 | val 0.0807193443179


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.96s/it]


epoch 125 | tr 0.0504691418003 | val 0.0821026480198


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.88s/it]


epoch 126 | tr 0.0500171009837 | val 0.0811383041739


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.02s/it]


epoch 127 | tr 0.0503276660781 | val 0.0848111996055


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  3.86s/it]


epoch 128 | tr 0.0500316419221 | val 0.0812847587466


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.92s/it]


epoch 129 | tr 0.0495718402082 | val 0.0816916844249


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.02s/it]


epoch 130 | tr 0.0495407554353 | val 0.0822428068519


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.05s/it]


epoch 131 | tr 0.0496408722267 | val 0.0824397140741


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  2.98s/it]


epoch 132 | tr 0.0496421236416 | val 0.0788401144743


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.12s/it]


epoch 133 | tr 0.0492863344221 | val 0.0813623270392


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.92s/it]


epoch 134 | tr 0.0491234937747 | val 0.0812745544314


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.06s/it]


epoch 135 | tr 0.0491591027071 | val 0.0812903633714


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.02s/it]


epoch 136 | tr 0.04927857418 | val 0.0821302828193


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.00s/it]


epoch 137 | tr 0.049081345275 | val 0.0821248164773


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.97s/it]


epoch 138 | tr 0.0487549237136 | val 0.0801691144705


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.08s/it]


epoch 139 | tr 0.0489161602254 | val 0.0810024783015


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.93s/it]


epoch 140 | tr 0.0486242692018 | val 0.0812080782652


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.07s/it]


epoch 141 | tr 0.0486474095105 | val 0.0830516597629


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.76s/it]


epoch 142 | tr 0.0481588004221 | val 0.0809144067764


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  2.96s/it]


epoch 143 | tr 0.0478417363403 | val 0.0811122047901


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.15s/it]


epoch 144 | tr 0.0484810113393 | val 0.080960008502


100%|██████████| 97/97 [01:17<00:00,  1.44it/s]
100%|██████████| 2/2 [00:06<00:00,  3.01s/it]


epoch 145 | tr 0.0479782150352 | val 0.0804876026511


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  3.83s/it]


epoch 146 | tr 0.0480051718909 | val 0.0809251672029


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.07s/it]


epoch 147 | tr 0.0475192888287 | val 0.0809379518032


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.99s/it]


epoch 148 | tr 0.0473585977637 | val 0.0797298666835


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.13s/it]


epoch 149 | tr 0.0477638957701 | val 0.0797364276648


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.94s/it]


epoch 150 | tr 0.0472013197582 | val 0.0816277834773


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  3.91s/it]


epoch 151 | tr 0.0475721420528 | val 0.0797000494599


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.97s/it]


epoch 152 | tr 0.0472324659084 | val 0.0799727171659


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.02s/it]


epoch 153 | tr 0.0472273498003 | val 0.0812162402272


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.03s/it]


epoch 154 | tr 0.0470241204161 | val 0.0791539546847


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.74s/it]


epoch 155 | tr 0.0467513727859 | val 0.0798414325714


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.87s/it]


epoch 156 | tr 0.0470137107218 | val 0.0810721036792


100%|██████████| 97/97 [01:17<00:00,  1.32it/s]
100%|██████████| 2/2 [00:06<00:00,  3.02s/it]


epoch 157 | tr 0.0465080775449 | val 0.080331992209


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  2.99s/it]


epoch 158 | tr 0.0469195251578 | val 0.0819525131583


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.01s/it]


epoch 159 | tr 0.0475641274118 | val 0.0798272448778


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.01s/it]


epoch 160 | tr 0.0466283605017 | val 0.0799839937687


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.00s/it]


epoch 161 | tr 0.0461517550437 | val 0.0807928541303


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.03s/it]


epoch 162 | tr 0.0461591037055 | val 0.0812677133083


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.96s/it]


epoch 163 | tr 0.0466052077711 | val 0.0804954856634


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.02s/it]


epoch 164 | tr 0.0460815508047 | val 0.0802111878991


100%|██████████| 97/97 [01:17<00:00,  1.35it/s]
100%|██████████| 2/2 [00:05<00:00,  3.86s/it]


epoch 165 | tr 0.0460915263228 | val 0.0797876673937


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.05s/it]


epoch 166 | tr 0.0458889733615 | val 0.0801259598136


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.05s/it]


epoch 167 | tr 0.0459220889057 | val 0.0787330156565


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  3.87s/it]


epoch 168 | tr 0.0463312172684 | val 0.0806934544444


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.05s/it]


epoch 169 | tr 0.0458405612101 | val 0.0817112511396


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.92s/it]


epoch 170 | tr 0.0457010189007 | val 0.0804140585661


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.98s/it]


epoch 171 | tr 0.0453123283412 | val 0.0793489128351


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.03s/it]


epoch 172 | tr 0.0455892373776 | val 0.0789650312066


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  3.90s/it]


epoch 173 | tr 0.0455432417064 | val 0.0799448296428


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.00s/it]


epoch 174 | tr 0.0454526579714 | val 0.0805144742131


100%|██████████| 97/97 [01:16<00:00,  1.48it/s]
100%|██████████| 2/2 [00:05<00:00,  3.85s/it]


epoch 175 | tr 0.0456335348559 | val 0.082800039947


100%|██████████| 97/97 [01:16<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  2.97s/it]


epoch 176 | tr 0.0453452645191 | val 0.079732452631


100%|██████████| 97/97 [01:16<00:00,  1.48it/s]
100%|██████████| 2/2 [00:05<00:00,  2.74s/it]


epoch 177 | tr 0.0448019974576 | val 0.080589722693


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.97s/it]


epoch 178 | tr 0.0449023708701 | val 0.0794488078356


100%|██████████| 97/97 [01:16<00:00,  1.48it/s]
100%|██████████| 2/2 [00:05<00:00,  2.96s/it]


epoch 179 | tr 0.0452224799785 | val 0.0811701521277


100%|██████████| 97/97 [01:16<00:00,  1.49it/s]
100%|██████████| 2/2 [00:06<00:00,  3.11s/it]


epoch 180 | tr 0.0443791696095 | val 0.0810977292061


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.98s/it]


epoch 181 | tr 0.0444695837153 | val 0.0785774725676


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.99s/it]


epoch 182 | tr 0.0446719126969 | val 0.0809445130825


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  3.83s/it]


epoch 183 | tr 0.0445976705901 | val 0.0813769769669


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.01s/it]


epoch 184 | tr 0.0446146301026 | val 0.0805625742674


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.09s/it]


epoch 185 | tr 0.0445354639456 | val 0.0796847113967


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.06s/it]


epoch 186 | tr 0.0443267217612 | val 0.0802139154077


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.01s/it]


epoch 187 | tr 0.0445886394587 | val 0.0795047777891


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.06s/it]


epoch 188 | tr 0.0441853779786 | val 0.0803125622869


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  4.10s/it]


epoch 189 | tr 0.0442246679345 | val 0.0806492355466


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:06<00:00,  3.05s/it]


epoch 190 | tr 0.0444306440138 | val 0.0803663027287


100%|██████████| 97/97 [01:17<00:00,  1.44it/s]
100%|██████████| 2/2 [00:06<00:00,  3.00s/it]


epoch 191 | tr 0.0440830403104 | val 0.0803855222464


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:05<00:00,  3.00s/it]


epoch 192 | tr 0.0442004784556 | val 0.0797406247258


100%|██████████| 97/97 [01:17<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.97s/it]


epoch 193 | tr 0.0440965398761 | val 0.0796398848295


100%|██████████| 97/97 [01:17<00:00,  1.46it/s]
100%|██████████| 2/2 [00:06<00:00,  3.06s/it]


epoch 194 | tr 0.0439418689701 | val 0.0802734678984


100%|██████████| 97/97 [01:16<00:00,  1.47it/s]
100%|██████████| 2/2 [00:05<00:00,  2.77s/it]


epoch 195 | tr 0.0440829874883 | val 0.0803302931786


100%|██████████| 97/97 [01:16<00:00,  1.48it/s]
100%|██████████| 2/2 [00:05<00:00,  3.56s/it]


epoch 196 | tr 0.0436791906454 | val 0.0794854050875


  0%|          | 0/97 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [13]:
# Debug probe: read the `only_decode` flag on the module wrapped by `model`.
model.module.only_decode

False

In [None]:
# Debug probe: show which device holds the weights of the `fc_latent_dec` layer.
model.module.fc_latent_dec.weight.device

In [32]:
if args['mode'] == 'test':
    # Build the checkpoint location once; it is used for the log line and the load.
    ckpt_file_path = os.path.join(checkpoint_path, args['checkpoint_file'] + '.pth.tar')
    print('loading checkpoint from file %s' % (ckpt_file_path))

    # Restore the trained autoencoder weights onto the active device.
    checkpoint_dict = torch.load(ckpt_file_path, map_location=device)
    model.load_state_dict(checkpoint_dict['autoencoder_state_dict'])

    # Evaluate on the held-out test loader.
    predictions, norm_l1_loss, l2_loss = test_autoencoder_dataloader(
        device, model, dataloader_test, shapedata,
        worst_path=worst_mesh_path,
        worst_face_num=args['worst_face_num'],
        mm_constant=100)

    # Persist the reconstructions next to the other prediction artefacts.
    np.save(os.path.join(prediction_path, 'predictions'), predictions)

    print('autoencoder: normalized loss', norm_l1_loss)
    print('autoencoder: euclidean distance in mm=', l2_loss)

loading checkpoint from file /home/jingwang/Data/data/FaceWarehouse/results/spirals_autoencoder/latent_128/checkpoints/align_pose/checkpoint.pth.tar


100%|██████████| 7/7 [00:07<00:00,  1.02it/s]


('loss for ', 402, tensor(2.7376, device='cuda:0'))
('loss for ', 560, tensor(2.7437, device='cuda:0'))
('loss for ', 50, tensor(2.7965, device='cuda:0'))
('loss for ', 537, tensor(2.7487, device='cuda:0'))
('loss for ', 390, tensor(2.8354, device='cuda:0'))
('loss for ', 364, tensor(2.8189, device='cuda:0'))
('loss for ', 169, tensor(2.8539, device='cuda:0'))
('loss for ', 616, tensor(2.7999, device='cuda:0'))
('loss for ', 341, tensor(2.7889, device='cuda:0'))
('loss for ', 590, tensor(2.8523, device='cuda:0'))
('loss for ', 499, tensor(2.9018, device='cuda:0'))
('loss for ', 389, tensor(2.8896, device='cuda:0'))
('loss for ', 125, tensor(2.8902, device='cuda:0'))
('loss for ', 370, tensor(2.9148, device='cuda:0'))
('loss for ', 614, tensor(2.8942, device='cuda:0'))
('loss for ', 600, tensor(2.9359, device='cuda:0'))
('loss for ', 194, tensor(2.8669, device='cuda:0'))
('loss for ', 8, tensor(2.8613, device='cuda:0'))
('loss for ', 653, tensor(2.8374, device='cuda:0'))
('loss for ', 5

In [18]:
if args['mode'] == 'test': # test with train set
    # Build the checkpoint location once; it is used for the log line and the load.
    ckpt_file_path = os.path.join(checkpoint_path, args['checkpoint_file'] + '.pth.tar')
    print('loading checkpoint from file %s' % (ckpt_file_path))

    # Restore the trained autoencoder weights onto the active device.
    checkpoint_dict = torch.load(ckpt_file_path, map_location=device)
    model.load_state_dict(checkpoint_dict['autoencoder_state_dict'])

    # Same evaluation as the test-set cell, but fed with the training loader.
    predictions, norm_l1_loss, l2_loss = test_autoencoder_dataloader(
        device, model, dataloader_train, shapedata,
        worst_path=worst_mesh_path,
        worst_face_num=args['worst_face_num'],
        mm_constant=100)

    # Save under a distinct name: the test-set cell writes 'predictions', and
    # reusing that name here would silently overwrite the test-set results
    # with train-set reconstructions.
    np.save(os.path.join(prediction_path, 'predictions_train'), predictions)

    print('autoencoder: normalized loss', norm_l1_loss)
    print('autoencoder: euclidean distance in mm=', l2_loss)

loading checkpoint from file /home/jingwang/Data/data/FaceWarehouse/results/spirals_autoencoder/latent_128/checkpoints/align_pose/checkpoint.pth.tar



  0%|          | 0/97 [00:00<?, ?it/s][A
  1%|          | 1/97 [00:05<09:34,  5.99s/it][A
  2%|▏         | 2/97 [00:06<06:48,  4.30s/it][A
  3%|▎         | 3/97 [00:06<04:52,  3.11s/it][A
  4%|▍         | 4/97 [00:07<03:31,  2.27s/it][A
  5%|▌         | 5/97 [00:07<02:34,  1.68s/it][A
  6%|▌         | 6/97 [00:07<01:55,  1.27s/it][A
  7%|▋         | 7/97 [00:07<01:28,  1.02it/s][A
  8%|▊         | 8/97 [00:08<01:09,  1.28it/s][A
  9%|▉         | 9/97 [00:08<00:56,  1.56it/s][A
 10%|█         | 10/97 [00:08<00:47,  1.84it/s][A
 11%|█▏        | 11/97 [00:09<00:40,  2.11it/s][A
 12%|█▏        | 12/97 [00:09<00:40,  2.11it/s][A
 13%|█▎        | 13/97 [00:09<00:35,  2.37it/s][A
 14%|█▍        | 14/97 [00:10<00:31,  2.59it/s][A
 15%|█▌        | 15/97 [00:10<00:29,  2.78it/s][A
 16%|█▋        | 16/97 [00:10<00:27,  2.93it/s][A
 18%|█▊        | 17/97 [00:11<00:26,  3.05it/s][A
 19%|█▊        | 18/97 [00:11<00:25,  3.14it/s][A
 20%|█▉        | 19/97 [00:11<00:24,  3.18it/s]

('loss for ', 586, tensor(0.6500, device='cuda:0'))
('loss for ', 2114, tensor(0.6520, device='cuda:0'))
('loss for ', 7367, tensor(0.6512, device='cuda:0'))
('loss for ', 2851, tensor(0.6557, device='cuda:0'))
('loss for ', 6893, tensor(0.6568, device='cuda:0'))
('loss for ', 8053, tensor(0.6522, device='cuda:0'))
('loss for ', 6181, tensor(0.6723, device='cuda:0'))
('loss for ', 2481, tensor(0.6690, device='cuda:0'))
('loss for ', 6458, tensor(0.6588, device='cuda:0'))
('loss for ', 6215, tensor(0.6871, device='cuda:0'))
('loss for ', 4388, tensor(0.7073, device='cuda:0'))
('loss for ', 3775, tensor(0.8443, device='cuda:0'))
('loss for ', 3728, tensor(0.6557, device='cuda:0'))
('loss for ', 645, tensor(0.7630, device='cuda:0'))
('loss for ', 4929, tensor(0.7139, device='cuda:0'))
('loss for ', 1636, tensor(0.6774, device='cuda:0'))
('loss for ', 2921, tensor(0.7664, device='cuda:0'))
('loss for ', 8807, tensor(0.6903, device='cuda:0'))
('loss for ', 2249, tensor(0.6622, device='cuda:

In [39]:
# from optimize_to_get_result import optimize_to_get_result

import torch
import torch.optim as optim
import torch.nn as nn


def optimize_to_get_result(model, loss_fun, z_dim, device, targets, n_iter=1, output_loss=True):
    """Refine the latent code of `targets` by optimising it through the decoder.

    The targets are first encoded to obtain an initial latent code, which is
    then treated as a free parameter and optimised with L-BFGS so that the
    decoded output matches `targets` under `loss_fun`.

    Args:
        model: autoencoder exposing boolean `only_encode` / `only_decode`
            switches, either directly or on `model.module` when it is wrapped
            (the notebook uses the wrapped form).
        loss_fun: callable `loss_fun(outputs, targets)` returning a scalar loss.
        z_dim: unused; kept for interface compatibility.
        device: unused; kept for interface compatibility.
        targets: batch fed to the encoder and used as reconstruction target.
        n_iter: number of `optimizer.step` calls (each L-BFGS step may evaluate
            the closure several times).
        output_loss: when true, print the loss at every closure evaluation.

    Returns:
        Tuple `(z, outputs)`: the optimised latent code (detached data) and the
        decoder output for it.

    Note:
        The model is switched to eval mode and left there. The
        `only_encode` / `only_decode` switches are always reset to False on
        exit, even if the optimisation raises or is interrupted, so a failed
        run no longer leaves the model stuck in decode-only mode.
    """
    model.eval()

    # The switches live on the wrapped module when `model` is a wrapper such as
    # DataParallel; fall back to the model itself otherwise.
    net = getattr(model, 'module', model)

    try:
        # Initial latent code from the encoder; no autograd graph is needed
        # because the code is re-wrapped as a fresh leaf parameter below.
        net.only_encode = True
        with torch.no_grad():
            z = model(targets)
        net.only_encode = False

        net.only_decode = True
        z_param = nn.Parameter(z.detach())
        optimizer = optim.LBFGS(params=[z_param])

        def closure():
            # L-BFGS re-evaluates the objective, so the full forward/backward
            # pass has to live inside the closure.
            optimizer.zero_grad()
            outputs = model(z_param)
            loss = loss_fun(outputs, targets)
            loss.backward()
            if output_loss:
                print('loss ', loss.item())
            return loss

        for _ in range(n_iter):
            optimizer.step(closure)

        # Decode the final latent code.
        outputs = model(z_param)
    finally:
        net.only_encode = False
        net.only_decode = False

    return z_param.data, outputs

def work_optimize(test_ids=None):
    """Optimise latent codes for selected test samples and save the meshes.

    Relies on notebook globals: `dataset_test`, `device`, `shapedata`, `model`,
    `loss_fn`, `args`, `worst_mesh_path` and `optimize_to_get_result`.

    Args:
        test_ids: iterable of indices into `dataset_test`. Defaults to `[8]`.
            A larger set used during experiments was
            [402, 560, 50, 537, 390, 364, 169, 616, 341, 590,
             499, 389, 125, 370, 614, 600, 194, 8, 653, 559].

    Raises:
        ValueError: if `test_ids` is empty.

    Side effects:
        Prints the input batch shape, the optimised latent code and the mean
        per-vertex Euclidean error of every sample, and writes the input and
        output meshes under `worst_mesh_path` as 'opt_input' / 'opt_output'.
    """
    test_ids = [8] if test_ids is None else list(test_ids)
    if not test_ids:
        raise ValueError('test_ids must contain at least one sample index')

    test_dataset = dataset_test
    # Stack the selected samples into one batch that already lives on `device`.
    inputs = torch.cat([test_dataset[idx]['points'].to(device).unsqueeze(0) for idx in test_ids])

    print(inputs.shape) # fixme

    shapedata_mean = torch.Tensor(shapedata.mean).to(device)
    shapedata_std = torch.Tensor(shapedata.std).to(device)

    z, outputs = optimize_to_get_result(model, loss_fn, args['nz'], device,
                                        inputs)
    from pprint import pprint
    pprint(z)

    # Drop the last vertex of every mesh.
    # NOTE(review): presumably a padding/dummy vertex appended by the data
    # pipeline - confirm against the dataset code.
    outputs_normalized = outputs[:, :-1]
    inputs_normalized = inputs[:, :-1]

    # Undo the dataset normalisation before measuring distances.
    outputs_denorm = outputs_normalized * shapedata_std + shapedata_mean
    inputs_denorm = inputs_normalized * shapedata_std + shapedata_mean

    # Euclidean distance per vertex (reduction over the coordinate axis).
    per_l2_loss = torch.sqrt(torch.sum((outputs_denorm - inputs_denorm)**2, dim=2))

    for sample_id, sample_err in zip(test_ids, per_l2_loss):
        print(sample_id, torch.mean(sample_err))

    # Meshes are saved from the normalised tensors, as `save_meshes` receives
    # the pre-denormalisation values.
    shapedata.save_meshes(os.path.join(worst_mesh_path, 'opt_input'),
                          inputs_normalized.detach().cpu().numpy(), test_ids)
    shapedata.save_meshes(os.path.join(worst_mesh_path, 'opt_output'),
                          outputs_normalized.detach().cpu().numpy(), test_ids)

# Run the latent-code optimisation defined above: prints the losses, the latent
# code and the per-sample error, and saves the input/output meshes.
work_optimize()

torch.Size([1, 11511, 3])
('loss ', 0.39739155769348145)
('loss ', 0.39277511835098267)
('loss ', 0.35588905215263367)
('loss ', 0.33115866780281067)
('loss ', 0.3164112865924835)
('loss ', 0.3067980408668518)
('loss ', 0.2976946234703064)
('loss ', 0.2918466627597809)
('loss ', 0.2884432375431061)
('loss ', 0.28520214557647705)
('loss ', 0.2826074957847595)
('loss ', 0.27986764907836914)
('loss ', 0.27789127826690674)
('loss ', 0.276755154132843)
('loss ', 0.275785356760025)
('loss ', 0.27493956685066223)
('loss ', 0.2742707133293152)
('loss ', 0.27380314469337463)
('loss ', 0.2734643816947937)
('loss ', 0.2732097804546356)
tensor([[-1.2232,  0.2258,  0.2435, -0.7786,  2.4340, -0.1475, -0.0689,  0.4264,
         -0.0836, -0.2713,  0.8224,  0.8065,  0.2474,  0.4163, -0.6270, -0.9571,
          4.9992, -0.8653,  1.1199, -0.8996, -3.6033, -0.4942,  0.4912, -0.3746,
         -1.2572,  1.9690, -0.7755, -0.2962,  0.1613, -1.5031,  1.9733,  0.6295,
         -1.8684, -1.2721,  1.0964,  0.7270

In [19]:
# Debug probe: check the Python type of `inputs`.
# NOTE(review): in the visible cells `inputs` is only assigned inside
# work_optimize(); this cell presumably relies on a notebook-global left over
# from an earlier run - confirm it survives Restart & Run All.
type(inputs)

torch.Tensor

In [20]:
# Debug probe: check the Python type of `outputs`.
# NOTE(review): in the visible cells `outputs` is only assigned inside
# functions; this cell presumably relies on a notebook-global left over from an
# earlier run - confirm it survives Restart & Run All.
type(outputs)

torch.Tensor

In [21]:
# Debug probe: check the Python type of `old_inputs`.
# NOTE(review): in the visible cells `old_inputs` is only assigned inside
# work_optimize(); this cell presumably relies on a notebook-global left over
# from an earlier run - confirm it survives Restart & Run All.
type(old_inputs)

torch.Tensor

In [29]:
# Sanity check: batching two samples along a new leading axis gives (2, V, 3).
torch.stack([test_dataset[0]['points'], test_dataset[1]['points']]).shape

torch.Size([2, 11511, 3])

In [21]:
# Debug probe: the value that work_optimize() passes as `z_dim`.
args['nz']

128

In [23]:
# Debug probe: read the `only_decode` flag directly on `model`.
# NOTE(review): other cells read/write this flag on `model.module`; confirm
# which object is intended here.
model.only_decode

True

In [27]:
# Debug probe: run the model on an all-zero batch of shape (2, 11511, 3).
z1 = torch.zeros((2,11511,3))
# NOTE(review): `z2` is created but not used in this cell.
z2 = torch.zeros((2,128))

model(z1)

tensor([[[-0.0581,  0.0713, -0.0968],
         [-0.0582,  0.0715, -0.0978],
         [-0.0570,  0.0703, -0.0970],
         ...,
         [-0.0562,  0.0718, -0.0985],
         [-0.0539,  0.0685, -0.1012],
         [-0.0000,  0.0000, -0.0000]],

        [[-0.0581,  0.0713, -0.0968],
         [-0.0582,  0.0715, -0.0978],
         [-0.0570,  0.0703, -0.0970],
         ...,
         [-0.0562,  0.0718, -0.0985],
         [-0.0539,  0.0685, -0.1012],
         [-0.0000,  0.0000, -0.0000]]], device='cuda:0',
       grad_fn=<GatherBackward>)