In [1]:
import sys
sys.path.append('../')
from lib import mesh_sampling
import numpy as np
import json
import os
import copy
from facemesh import FaceData
import time
import pickle
import trimesh

try:
    import psbody.mesh
    found = True
except ImportError:
    found = False
if found:
    from psbody.mesh import Mesh, MeshViewer, MeshViewers

from autoencoder_dataset import autoencoder_dataset
from torch.utils.data import DataLoader

from spiral_utils import get_adj_trigs, generate_spirals
from models import SpiralAutoencoder, SpiralAutoencoder_extra_conv

from train_funcs import train_autoencoder, train_autoencoder_dataloader


import torch
from tensorboardX import SummaryWriter

from sklearn.metrics.pairwise import euclidean_distances
# ---- Run configuration -------------------------------------------------------
# Mesh backend used by every later cell: 'trimesh' or 'mpi-mesh' (psbody.mesh).
meshpackage = 'trimesh'
# NOTE(review): absolute, machine-specific data root; adjust per host.
root_dir = '/data/gb318/datasets/'

name = 'sliced'   # data sub-folder: <root_dir>/<dataset>/preprocessed/<name>
dataset = 'COMA'

GPU = True        # request CUDA; the torch.device itself is chosen in the torch-setup cell
device_idx = 8    # CUDA ordinal to use when GPU is requested

# Show which GPU will be used. Guarded so this first cell does not raise on a
# CPU-only host or when `device_idx` is not among the visible CUDA devices.
gpu_visible = GPU and torch.cuda.is_available() and device_idx < torch.cuda.device_count()
torch.cuda.get_device_name(device_idx) if gpu_visible else 'cuda:{} not available'.format(device_idx)

'GeForce RTX 2080 Ti'

In [2]:
# ---- Experiment configuration -------------------------------------------------
generative_model = 'autoencoder'
dilation_flag = False
hardcode_down_ref = False
downsample_method = 'COMA_downsample'  # choose 'COMA_downsample' or 'meshlab_downsample'
downsample_config = ''

# Only the COMA dataset is configured below; fail fast here instead of with a
# NameError on `reference_mesh_file` / `ds_factors` further down.
if dataset != 'COMA':
    raise NotImplementedError('No configuration for dataset {!r}'.format(dataset))

reference_mesh_file = os.path.join(root_dir, dataset, 'preprocessed/templates/template.obj')
downsample_directory = os.path.join(root_dir, dataset, 'preprocessed/templates',
                                    downsample_method, downsample_config)
ds_factors = [4, 4, 4, 4]        # one factor per downsampling step (4 steps -> 5 mesh levels)
step_sizes = [1, 1, 1, 1, 1]     # one entry per mesh level
filter_sizes_enc = [[3, 16, 16, 16, 32], [[], [], [], [], []]]
filter_sizes_dec = [[32, 32, 16, 16, 3], [[], [], [], [], []]]
dilation = [2, 2, 2, 1, 1] if dilation_flag else None

# NOTE: '3nd_order_full' is a historical typo kept on purpose -- renaming it
# would orphan the existing results/checkpoints directory.
args = {'generative_model': generative_model,
        'name': name, 'data': os.path.join(root_dir, dataset, 'preprocessed', name),
        'results_folder': os.path.join(root_dir, dataset, 'results/higher_order_' + generative_model,
                                       downsample_method, downsample_config, '3nd_order_full'),
        'reference_mesh_file': reference_mesh_file, 'downsample_directory': downsample_directory,
        'checkpoint_file': 'checkpoint',
        'seed': 2, 'loss': 'l1',
        'batch_size': 16, 'num_epochs': 300, 'eval_frequency': 200, 'num_workers': 4,
        'filter_sizes_enc': filter_sizes_enc, 'filter_sizes_dec': filter_sizes_dec,
        'nz': 16,
        'ds_factors': ds_factors, 'step_sizes': step_sizes, 'dilation': dilation,
        'injection': True, 'residual': True,

        'lr': 1e-3,
        'regularization': 5e-5,
        'scheduler': True, 'decay_rate': 0.99, 'decay_steps': 1,
        'resume': True,

        'mode': 'train', 'shuffle': True, 'nVal': 100, 'normalization': True}

if generative_model == 'autoencoder':
    args['results_folder'] = os.path.join(args['results_folder'], 'latent_' + str(args['nz']))

summary_path = os.path.join(args['results_folder'], 'summaries', args['name'])
checkpoint_path = os.path.join(args['results_folder'], 'checkpoints', args['name'])
samples_path = os.path.join(args['results_folder'], 'samples', args['name'])
prediction_path = os.path.join(args['results_folder'], 'predictions', args['name'])

# exist_ok=True avoids the check-then-create race and keeps the cell re-runnable.
for output_dir in (args['results_folder'], summary_path, checkpoint_path,
                   samples_path, prediction_path, downsample_directory):
    os.makedirs(output_dir, exist_ok=True)

# Landmark vertex indices (three per mesh level), later passed to
# generate_spirals(reference_points=...).
if hardcode_down_ref:
    if downsample_method == 'COMA_downsample':
        reference_points = [[3567, 4051, 4597],
                            [1010, 1081, 1170],
                            [256, 276, 295],
                            [11, 69, 74],
                            [17, 17, 17]]
    elif (downsample_method == 'meshlab_downsample'
          and downsample_config == 'preserve_topology=True_preserve_boundary=False'):
        reference_points = [[3567, 4051, 4597],
                            [1105, 1214, 1241],
                            [289, 310, 318],
                            [70, 80, 85],
                            [2, 19, 24]]
    else:
        raise NotImplementedError('No hard-coded reference points for {!r} / {!r}'.format(
            downsample_method, downsample_config))
else:
    # Only the full-resolution landmarks are given; the data-loading cell derives
    # the coarser levels by nearest-vertex search.
    reference_points = [[3567, 4051, 4597]]
        

In [3]:
# ---- Load data and the mesh hierarchy -------------------------------------------
np.random.seed(args['seed'])
print("Loading data .. ")

mean_file = args['data'] + '/mean.npy'
std_file = args['data'] + '/std.npy'
# Normalization statistics are cached as .npy files. Without the cache, FaceData
# is built with load_flag=True and its mean/std are saved; with the cache it is
# built with load_flag=False and mean/std (and derived sizes) are restored.
stats_cached = os.path.exists(mean_file) and os.path.exists(std_file)
facedata = FaceData(nVal=args['nVal'], train_file=args['data'] + '/train.npy',
                    test_file=args['data'] + '/test.npy',
                    reference_mesh_file=args['reference_mesh_file'],
                    pca_n_comp=args['nz'], normalization=args['normalization'],
                    meshpackage=meshpackage, load_flag=not stats_cached)
if stats_cached:
    facedata.mean = np.load(mean_file)
    facedata.std = np.load(std_file)
    # assumes mean is laid out as (n_vertex, n_features) -- TODO confirm in FaceData
    facedata.n_vertex = facedata.mean.shape[0]
    facedata.n_features = facedata.mean.shape[1]
else:
    np.save(mean_file, facedata.mean)
    np.save(std_file, facedata.std)

matrices_file = os.path.join(args['downsample_directory'], 'downsampling_matrices.pkl')
if not os.path.exists(matrices_file):
    if facedata.meshpackage == 'trimesh':
        # Generation below reads psbody attributes (mesh.v / mesh.f).
        raise NotImplementedError(
            "Generating transform matrices is only supported with meshpackage='mpi-mesh'; "
            "provide a precomputed " + matrices_file)
    print("Generating Transform Matrices ..")

    if downsample_method == 'COMA_downsample':
        M, A, D, U, F = mesh_sampling.generate_transform_matrices(facedata.reference_mesh,
                                                                  args['ds_factors'])
    elif downsample_method == 'meshlab_downsample':
        M, A, D, U, F = mesh_sampling.generate_transform_matrices_given_downsamples(
            facedata.reference_mesh, args['downsample_directory'], len(args['ds_factors']))
    else:
        raise NotImplementedError(downsample_method)

    # Build the payload before opening the file so a failure cannot leave a
    # truncated cache behind that later runs would try to load.
    M_verts_faces = [(mesh.v, mesh.f) for mesh in M]
    with open(matrices_file, 'wb') as fp:
        pickle.dump({'M_verts_faces': M_verts_faces, 'A': A, 'D': D, 'U': U, 'F': F}, fp)
else:
    print("Loading Transform Matrices ..")
    # NOTE: pickle.load can execute arbitrary code -- only load caches this project wrote.
    with open(matrices_file, 'rb') as fp:
        downsampling_matrices = pickle.load(fp, encoding='latin1')

    M_verts_faces = downsampling_matrices['M_verts_faces']
    if facedata.meshpackage == 'mpi-mesh':
        M = [Mesh(v=verts, f=faces) for verts, faces in M_verts_faces]
    elif facedata.meshpackage == 'trimesh':
        M = [trimesh.base.Trimesh(vertices=verts, faces=faces, process=False)
             for verts, faces in M_verts_faces]
    else:
        raise ValueError('Unknown meshpackage {!r}'.format(facedata.meshpackage))
    A = downsampling_matrices['A']
    D = downsampling_matrices['D']
    U = downsampling_matrices['U']
    F = downsampling_matrices['F']

if not hardcode_down_ref:
    print("Calculating reference points for downsampled versions..")
    # Keep only the full-resolution landmarks so re-running this cell does not
    # keep appending duplicate levels to `reference_points`.
    reference_points = reference_points[:1]
    if facedata.meshpackage == 'mpi-mesh':
        level_vertices = [mesh.v for mesh in M]
    elif facedata.meshpackage == 'trimesh':
        level_vertices = [mesh.vertices for mesh in M]
    else:
        raise ValueError('Unknown meshpackage {!r}'.format(facedata.meshpackage))
    # For each coarser level, pick the vertex nearest to each full-resolution landmark.
    for i in range(len(args['ds_factors'])):
        dist = euclidean_distances(level_vertices[i + 1], level_vertices[0][reference_points[0]])
        reference_points.append(np.argmin(dist, axis=0).tolist())



Loading data .. 
Loading Transform Matrices ..
Calculating reference points for downsampled versions..


In [4]:
# ---- Spiral orderings -------------------------------------------------------------
# Vertex count of every level in the mesh hierarchy `M`.
if facedata.meshpackage == 'mpi-mesh':
    sizes = [mesh.v.shape[0] for mesh in M]
elif facedata.meshpackage == 'trimesh':
    sizes = [mesh.vertices.shape[0] for mesh in M]
else:
    # Without this, `sizes` would be silently undefined until the model is built.
    raise ValueError('Unknown meshpackage {!r}'.format(facedata.meshpackage))

Adj, Trigs = get_adj_trigs(A, F, facedata.reference_mesh, meshpackage=facedata.meshpackage)

spirals_np, spiral_sizes, spirals = generate_spirals(args['step_sizes'], M, Adj, Trigs,
                                                     reference_points=reference_points,
                                                     dilation=args['dilation'], random=False,
                                                     meshpackage=facedata.meshpackage,
                                                     counter_clockwise=True)

# Turn each sparse down/up-sampling matrix into a dense array with a leading
# axis of size 1 and one extra row and column. The extra row/column is zero
# except for a 1 in the new bottom-right corner.
# NOTE(review): presumably this carries a dummy vertex through resampling -- confirm in the model.
def _pad_sampling_matrix(sparse_mat):
    n_rows, n_cols = sparse_mat.shape
    padded = np.zeros((1, n_rows + 1, n_cols + 1))
    padded[0, :n_rows, :n_cols] = sparse_mat.toarray()
    padded[0, n_rows, n_cols] = 1
    return padded


bD = [_pad_sampling_matrix(D[level]) for level in range(len(D))]
bU = [_pad_sampling_matrix(U[level]) for level in range(len(D))]


spiral generation for hierarchy 0 (5023 vertices) finished
spiral generation for hierarchy 1 (1256 vertices) finished
spiral generation for hierarchy 2 (314 vertices) finished
spiral generation for hierarchy 3 (79 vertices) finished
spiral generation for hierarchy 4 (20 vertices) finished
spiral sizes for hierarchy 0:  9
spiral sizes for hierarchy 1:  9
spiral sizes for hierarchy 2:  9
spiral sizes for hierarchy 3:  9
spiral sizes for hierarchy 4:  8


In [5]:
# ---- Torch setup: seed, device, and tensors on the device ----------------------------
torch.manual_seed(args['seed'])

# CUDA is used only when requested via GPU *and* actually available.
use_cuda = GPU and torch.cuda.is_available()
device = torch.device('cuda:{}'.format(device_idx) if use_cuda else 'cpu')
print(device)

# Spiral index arrays become int64 tensors; padded sampling matrices become float32.
tspirals = [torch.from_numpy(arr).long().to(device) for arr in spirals_np]
tD = [torch.from_numpy(mat).float().to(device) for mat in bD]
tU = [torch.from_numpy(mat).float().to(device) for mat in bU]

cuda:8


In [6]:
# ---- Datasets, model, optimizer and LR scheduler -------------------------------------
def _make_split_loader(split, shuffle):
    """Build the autoencoder_dataset for `split` and wrap it in a DataLoader.

    Returns (dataset, dataloader) so both remain reachable by name in the notebook.
    """
    split_dataset = autoencoder_dataset(root_dir=args['data'], points_dataset=split,
                                        facedata=facedata,
                                        normalization=args['normalization'])
    split_loader = DataLoader(split_dataset, batch_size=args['batch_size'],
                              shuffle=shuffle, num_workers=args['num_workers'])
    return split_dataset, split_loader


# Only the training split is shuffled (when args['shuffle'] is set).
dataset_train, dataloader_train = _make_split_loader('train', args['shuffle'])
dataset_val, dataloader_val = _make_split_loader('val', False)
dataset_test, dataloader_test = _make_split_loader('test', False)

# Fail here rather than with a NameError on `model` at optimizer construction.
if 'autoencoder' not in args['generative_model']:
    raise NotImplementedError('Unsupported generative_model {!r}'.format(args['generative_model']))
model = SpiralAutoencoder_extra_conv(filters_enc=args['filter_sizes_enc'],
                                     filters_dec=args['filter_sizes_dec'],
                                     latent_size=args['nz'],
                                     sizes=sizes,
                                     spiral_sizes=spiral_sizes,
                                     spirals=tspirals,
                                     D=tD, U=tU, device=device,
                                     injection=args['injection'],
                                     residual=args['residual']).to(device)

# Adam with L2 weight decay; optional StepLR decay by `decay_rate` every `decay_steps` steps.
optim = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['regularization'])
if args['scheduler']:
    scheduler = torch.optim.lr_scheduler.StepLR(optim, args['decay_steps'], gamma=args['decay_rate'])
else:
    scheduler = None

def loss_l1(outputs, targets):
    """Mean absolute error between two tensors of the same shape.

    Averages |outputs - targets| over every element and returns a 0-dim tensor.
    """
    abs_error = (outputs - targets).abs()
    return abs_error.mean()
# Select the reconstruction loss by name; unknown names fail here instead of
# leaving `loss_fn` undefined until training starts.
if args['loss'] == 'l1':
    loss_fn = loss_l1
else:
    raise NotImplementedError('Unsupported loss {!r}'.format(args['loss']))


In [7]:
# Report the number of trainable parameters, then the full module tree.
trainable_params = [p for p in model.parameters() if p.requires_grad]
params = sum(p.numel() for p in trainable_params)
print("Total number of parameters is: {}".format(params))
print(model)

Total number of parameters is: 734516
SpiralAutoencoder_extra_conv(
  (conv): ModuleList(
    (0): SpiralConv(
      (conva2): Linear(in_features=27, out_features=16, bias=False)
      (convs2): Linear(in_features=27, out_features=16, bias=False)
      (conva3): Linear(in_features=27, out_features=16, bias=False)
      (convs3): Linear(in_features=144, out_features=16, bias=False)
      (normalizer3): BatchNorm1d(80384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (normalizer2): BatchNorm1d(80384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ELU(alpha=1.0)
    )
    (1): SpiralConv(
      (conva2): Linear(in_features=144, out_features=16, bias=False)
      (convs2): Linear(in_features=144, out_features=16, bias=False)
      (conva3): Linear(in_features=144, out_features=16, bias=False)
      (convs3): Linear(in_features=144, out_features=16, bias=False)
      (normalizer3): BatchNorm1d(20112, eps=1e-05, momentum=0.1, affin

In [None]:
# ---- Training ------------------------------------------------------------------------
if args['mode'] == 'train':
    writer = SummaryWriter(summary_path)
    # Persist the run configuration alongside the checkpoints.
    with open(os.path.join(args['results_folder'], 'checkpoints', args['name'] + '_params.json'), 'w') as fp:
        json.dump(args, fp)

    if args['resume']:
        checkpoint_file = os.path.join(checkpoint_path, args['checkpoint_file'] + '.pth.tar')
        # Log the exact file being loaded (including the .pth.tar extension).
        print('loading checkpoint from file %s' % checkpoint_file)
        checkpoint_dict = torch.load(checkpoint_file, map_location=device)
        start_epoch = checkpoint_dict['epoch'] + 1
        model.load_state_dict(checkpoint_dict['autoencoder_state_dict'])
        optim.load_state_dict(checkpoint_dict['optimizer_state_dict'])
        # `scheduler` is None when args['scheduler'] is False; calling
        # load_state_dict on it would raise AttributeError.
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint_dict['scheduler_state_dict'])
        print('Resuming from epoch %s' % str(start_epoch))
    else:
        start_epoch = 0

    try:
        if args['generative_model'] == 'autoencoder':
            # NOTE(review): `checkpoint_path=` receives the checkpoint file stem
            # ('checkpoint'), while the directory goes in `metadata_dir`; presumably
            # train_autoencoder_dataloader joins the two -- verify there.
            train_autoencoder_dataloader(dataloader_train, dataloader_val,
                                         device, model, optim, loss_fn,
                                         bsize=args['batch_size'],
                                         start_epoch=start_epoch,
                                         n_epochs=args['num_epochs'],
                                         eval_freq=args['eval_frequency'],
                                         scheduler=scheduler,
                                         writer=writer,
                                         save_recons=True,
                                         facedata=facedata,
                                         metadata_dir=checkpoint_path, samples_dir=samples_path,
                                         checkpoint_path=args['checkpoint_file'])
    finally:
        # Flush and release the TensorBoard event file even if training is interrupted.
        writer.close()

loading checkpoint from file /data/gb318/datasets/COMA/results/higher_order_autoencoder/COMA_downsample/3nd_order_full/latent_16/checkpoints/sliced/checkpoint


  0%|          | 0/1145 [00:00<?, ?it/s]

Resuming from epoch 6


100%|██████████| 1145/1145 [01:17<00:00, 14.81it/s]
100%|██████████| 7/7 [00:00<00:00, 16.68it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 6 | tr 0.1362568846327862 | val 0.1461655193567276


100%|██████████| 1145/1145 [01:17<00:00, 14.80it/s]
100%|██████████| 7/7 [00:00<00:00, 16.65it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 7 | tr 0.13333970615602145 | val 0.13564552009105682


100%|██████████| 1145/1145 [01:19<00:00, 14.43it/s]
100%|██████████| 7/7 [00:00<00:00, 14.70it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 8 | tr 0.13015815598731023 | val 0.141049542427063


100%|██████████| 1145/1145 [01:20<00:00, 14.25it/s]
100%|██████████| 7/7 [00:00<00:00, 14.60it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 9 | tr 0.1275568053130403 | val 0.12928445160388946


100%|██████████| 1145/1145 [01:20<00:00, 14.23it/s]
100%|██████████| 7/7 [00:00<00:00, 17.04it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 10 | tr 0.12565184177928956 | val 0.1307503604888916


100%|██████████| 1145/1145 [01:21<00:00, 14.10it/s]
100%|██████████| 7/7 [00:00<00:00, 16.60it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 11 | tr 0.12429249801581718 | val 0.12440707564353942


100%|██████████| 1145/1145 [01:20<00:00, 14.30it/s]
100%|██████████| 7/7 [00:00<00:00, 15.15it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 12 | tr 0.12246768927186882 | val 0.12320172190666198


100%|██████████| 1145/1145 [01:20<00:00, 14.15it/s]
100%|██████████| 7/7 [00:00<00:00, 14.93it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 13 | tr 0.12123445690150082 | val 0.12379265785217285


100%|██████████| 1145/1145 [01:19<00:00, 14.40it/s]
100%|██████████| 7/7 [00:00<00:00, 16.93it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 14 | tr 0.12056587549984667 | val 0.11963440865278244


100%|██████████| 1145/1145 [01:20<00:00, 14.31it/s]
100%|██████████| 7/7 [00:00<00:00, 15.16it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 15 | tr 0.1195166419352497 | val 0.1185815954208374


100%|██████████| 1145/1145 [01:19<00:00, 14.34it/s]
100%|██████████| 7/7 [00:00<00:00, 16.67it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 16 | tr 0.11838678113634131 | val 0.11833416163921356


100%|██████████| 1145/1145 [01:19<00:00, 14.41it/s]
100%|██████████| 7/7 [00:00<00:00, 16.46it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 17 | tr 0.11756018296864526 | val 0.12227528423070907


100%|██████████| 1145/1145 [01:19<00:00, 14.32it/s]
100%|██████████| 7/7 [00:00<00:00, 16.80it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 18 | tr 0.11677420397998205 | val 0.12639546990394593


100%|██████████| 1145/1145 [01:20<00:00, 14.26it/s]
100%|██████████| 7/7 [00:00<00:00, 16.59it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 19 | tr 0.11625908705304953 | val 0.11961199462413788


100%|██████████| 1145/1145 [01:19<00:00, 14.33it/s]
100%|██████████| 7/7 [00:00<00:00, 17.55it/s]


epoch 20 | tr 0.1152293301111943 | val 0.11744223654270172


100%|██████████| 1145/1145 [01:18<00:00, 14.52it/s]
100%|██████████| 7/7 [00:00<00:00, 17.19it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 21 | tr 0.11488153133526091 | val 0.11693404793739319


100%|██████████| 1145/1145 [01:19<00:00, 14.33it/s]
100%|██████████| 7/7 [00:00<00:00, 16.81it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 22 | tr 0.11502412412576977 | val 0.11516235828399658


100%|██████████| 1145/1145 [01:20<00:00, 14.26it/s]
100%|██████████| 7/7 [00:00<00:00, 15.92it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 23 | tr 0.11406268913491059 | val 0.11687652081251144


100%|██████████| 1145/1145 [01:20<00:00, 14.30it/s]
100%|██████████| 7/7 [00:00<00:00, 16.35it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 24 | tr 0.11305438018775917 | val 0.11346208363771439


100%|██████████| 1145/1145 [01:19<00:00, 14.43it/s]
100%|██████████| 7/7 [00:00<00:00, 16.99it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 25 | tr 0.11293778827720544 | val 0.12101413756608963


100%|██████████| 1145/1145 [01:19<00:00, 14.40it/s]
100%|██████████| 7/7 [00:00<00:00, 16.49it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 26 | tr 0.11217843207006666 | val 0.11382825911045075


100%|██████████| 1145/1145 [01:20<00:00, 14.23it/s]
100%|██████████| 7/7 [00:00<00:00, 16.89it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 27 | tr 0.11162844375203029 | val 0.11626580238342285


100%|██████████| 1145/1145 [01:21<00:00, 14.11it/s]
100%|██████████| 7/7 [00:00<00:00, 16.22it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 28 | tr 0.11141290499267115 | val 0.11114266157150268


100%|██████████| 1145/1145 [01:19<00:00, 14.34it/s]
100%|██████████| 7/7 [00:00<00:00, 17.24it/s]


epoch 29 | tr 0.11120222368153788 | val 0.1338459938764572


100%|██████████| 1145/1145 [01:20<00:00, 14.30it/s]
100%|██████████| 7/7 [00:00<00:00, 16.62it/s]


epoch 30 | tr 0.11043777743584791 | val 0.11669428914785385


100%|██████████| 1145/1145 [01:19<00:00, 14.38it/s]
100%|██████████| 7/7 [00:00<00:00, 16.90it/s]
  0%|          | 0/1145 [00:00<?, ?it/s]

epoch 31 | tr 0.1100776700522123 | val 0.12065846621990203


100%|██████████| 1145/1145 [01:19<00:00, 14.32it/s]
100%|██████████| 7/7 [00:00<00:00, 16.19it/s]


epoch 32 | tr 0.11013647114119177 | val 0.1192743244767189


 99%|█████████▉| 1131/1145 [01:18<00:01, 13.87it/s]