In [1]:
import os
import time as timer
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
import open3d as o3d
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader

from load_3d_data import load_data
from visualization_utils import visualize
import train_facetalk_utils as tu
import vae

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


## Define Constants

In [2]:
DATA_DIR = './data'
MODEL_DIR = './models'

# Training hyperparameters
EPOCHS = 100
BATCH_SIZE = 16
LEARNING_RATE = 1e-3
dh = 5000    # hidden-layer width, passed to the VAE as d_h
dz = 500     # latent size, passed to the VAE as d_z
beta = 0.75  # weight passed as `beta` to the train/test helpers (presumably the KL weight)

# Seed the global RNGs so weight initialisation, batch shuffling and any
# sampling inside the VAE are reproducible across Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

## Helper Functions

In [3]:
"""
Utility functions
"""
def reshape_vertices(inversed):
    """
    Function to reshape the inversed principal components
    to vertices
    """
    vertices = []
    for i in range(len(inversed)):
        vert = np.reshape(inversed[i], (-1, 3))
        vertices.append(vert)
    
    return vertices

def vertices_to_meshes(vertices, ori_meshes):
    """
    Build Open3D triangle meshes from vertex arrays.

    The mesh built from ``vertices[k]`` reuses the triangle connectivity of
    ``ori_meshes[k]``. Each mesh gets vertex and triangle normals computed,
    is painted a uniform grey, and has its normals normalized.

    Returns a list with one mesh per entry in `vertices`.
    """
    grey = [0.5, 0.5, 0.5]
    result = []
    for idx, verts in enumerate(vertices):
        faces = o3d.utility.Vector3iVector(np.asarray(ori_meshes[idx].triangles))
        points = o3d.utility.Vector3dVector(verts)
        mesh = o3d.geometry.TriangleMesh(points, faces)
        mesh.compute_vertex_normals()
        mesh.compute_triangle_normals()
        mesh.paint_uniform_color(grey)
        mesh.normalize_normals()
        result.append(mesh)
    return result

## Data Preparation

In [4]:
# Load every FaceTalk mesh from under the configured DATA_DIR rather than a
# second hard-coded copy of the path. `expressions` is returned by load_data
# but is not used further in the cells shown here.
dirname = os.path.join(DATA_DIR, 'FaceTalk')
files, expressions = load_data(dirname)
len(files)

1183

In [5]:
"""
Split to train and test set (90:10)
"""
X_train, X_test = train_test_split(files, test_size=.1, random_state=42)
len(X_train), len(X_test)

(1064, 119)

### Get Vertices from 3D Mesh

In [6]:
def get_vertices(files):
    """
    Gather the vertices of each 3D mesh in `files`.

    Each mesh's vertex array is flattened into a single row; the rows are
    returned together as one NumPy array (one row per mesh).
    """
    flattened = [np.asarray(mesh.vertices).reshape(-1) for mesh in files]
    return np.asarray(flattened)

In [7]:
# Vertices per mesh, read from the data instead of hard-coding it. All meshes
# are assumed to share one topology, which the reshape below requires.
N_VERTICES = np.asarray(X_train[0].vertices).shape[0]

X_train_v = get_vertices(X_train).reshape(-1, N_VERTICES, 3)
X_test_v = get_vertices(X_test).reshape(-1, N_VERTICES, 3)
# Guard against a reshape that silently regroups meshes with differing vertex counts
assert len(X_train_v) == len(X_train) and len(X_test_v) == len(X_test)
X_train_v.shape, X_test_v.shape

((1064, 5023, 3), (119, 5023, 3))

In [8]:
# Shuffle the training set so each epoch sees a different batch order;
# without it every epoch replays the exact same mini-batches. The test set
# is only evaluated, so it is kept in a fixed order.
train_dl = DataLoader(X_train_v, batch_size=BATCH_SIZE, shuffle=True)
test_dl = DataLoader(X_test_v, batch_size=BATCH_SIZE)

In [9]:
n, dx1, dx2 = X_train_v.shape  # samples, vertices per mesh, coordinates per vertex

## Train Variational Autoencoder

In [10]:
print(f"Device: {DEVICE}")

Device: cuda


In [12]:
# Each mesh is fed to the VAE as one flat vector of dx1 * dx2 coordinates.
# NOTE(review): the printed decoder ends in Sigmoid, which limits reconstructions
# to (0, 1); confirm the vertex coordinates are normalized to that range.
model = vae.VAE_FT_sigm(d_in=dx1*dx2, d_z=dz, d_h=dh).to(DEVICE)
print(model)

# Persistent file to store the model. The epoch count in the name follows
# EPOCHS, so the filename stays accurate if EPOCHS changes.
os.makedirs(MODEL_DIR, exist_ok=True)  # torch.save does not create missing directories
model_path = os.path.join(MODEL_DIR, f'vae_sigm_ep{EPOCHS}_facetalk.pth')

VAE_FT_sigm(
  (encoder): Sequential(
    (0): Linear(in_features=15069, out_features=5000, bias=True)
    (1): ReLU()
    (2): Linear(in_features=5000, out_features=1000, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=500, out_features=5000, bias=True)
    (1): ReLU()
    (2): Linear(in_features=5000, out_features=15069, bias=True)
    (3): Sigmoid()
  )
)


In [13]:
"""Define loss function"""
loss_fn = tu.vae_loss

"""Define optimizer"""
optimizer = optim.Adam(
    model.parameters(), 
    lr=LEARNING_RATE
)

In [15]:
"""Train the Autoencoder"""

for t in range(EPOCHS):
    print(f"Epoch {t+1} out of {EPOCHS}\n ------------")
    
    start = timer.time()
    tu.train_vae(train_dl, model, loss_fn, optimizer, beta=beta)
    elapsed_time = timer.time() - start # this timing method ONLY works for CPU computation, not for GPU/cuda calls
    print(f" > Training time: {elapsed_time:>.2f} seconds")
    
    test_loss = tu.test_vae(test_dl, model, loss_fn, beta=beta)
    print(f" > Test reconstruction loss: {test_loss:>.2f}")
    
    # Save model
    torch.save(model.state_dict(), model_path)
    print(f"Model {model_path} stored!")
    
print("Done!")

Epoch 1 out of 100
 ------------
Loss: 61825.105469 [    0]/ 1064
 > Training time: 3.62 seconds
 > Test reconstruction loss: 33726.02
Model ./models\vae_sigm_ep100_facetalk.pth stored!
Epoch 2 out of 100
 ------------
Loss: 602.678467 [    0]/ 1064
 > Training time: 3.59 seconds
 > Test reconstruction loss: 45719.12
Model ./models\vae_sigm_ep100_facetalk.pth stored!
Epoch 3 out of 100
 ------------
Loss: 571.043030 [    0]/ 1064
 > Training time: 3.59 seconds
 > Test reconstruction loss: 47696.65
Model ./models\vae_sigm_ep100_facetalk.pth stored!
Epoch 4 out of 100
 ------------
Loss: 510.360840 [    0]/ 1064
 > Training time: 3.59 seconds
 > Test reconstruction loss: 38158.49
Model ./models\vae_sigm_ep100_facetalk.pth stored!
Epoch 5 out of 100
 ------------
Loss: 443.824310 [    0]/ 1064
 > Training time: 3.59 seconds
 > Test reconstruction loss: 30248.98
Model ./models\vae_sigm_ep100_facetalk.pth stored!
Epoch 6 out of 100
 ------------
Loss: 419.475189 [    0]/ 1064
 > Training ti

 > Training time: 3.59 seconds
 > Test reconstruction loss: 38984.43
Model ./models\vae_sigm_ep100_facetalk.pth stored!
Epoch 46 out of 100
 ------------
Loss: 376.855591 [    0]/ 1064
 > Training time: 3.60 seconds
 > Test reconstruction loss: 39437.43
Model ./models\vae_sigm_ep100_facetalk.pth stored!
Epoch 47 out of 100
 ------------
Loss: 388.809418 [    0]/ 1064
 > Training time: 3.60 seconds
 > Test reconstruction loss: 40375.37
Model ./models\vae_sigm_ep100_facetalk.pth stored!
Epoch 48 out of 100
 ------------
Loss: 387.838745 [    0]/ 1064
 > Training time: 3.60 seconds
 > Test reconstruction loss: 40771.25
Model ./models\vae_sigm_ep100_facetalk.pth stored!
Epoch 49 out of 100
 ------------
Loss: 380.869293 [    0]/ 1064
 > Training time: 3.60 seconds
 > Test reconstruction loss: 40425.17
Model ./models\vae_sigm_ep100_facetalk.pth stored!
Epoch 50 out of 100
 ------------
Loss: 404.983002 [    0]/ 1064
 > Training time: 3.59 seconds
 > Test reconstruction loss: 42809.11
Model 

Model ./models\vae_sigm_ep100_facetalk.pth stored!
Epoch 90 out of 100
 ------------
Loss: 401.192657 [    0]/ 1064
 > Training time: 3.59 seconds
 > Test reconstruction loss: 51222.20
Model ./models\vae_sigm_ep100_facetalk.pth stored!
Epoch 91 out of 100
 ------------
Loss: 405.815491 [    0]/ 1064
 > Training time: 3.60 seconds
 > Test reconstruction loss: 51132.15
Model ./models\vae_sigm_ep100_facetalk.pth stored!
Epoch 92 out of 100
 ------------
Loss: 403.631012 [    0]/ 1064
 > Training time: 3.60 seconds
 > Test reconstruction loss: 50954.33
Model ./models\vae_sigm_ep100_facetalk.pth stored!
Epoch 93 out of 100
 ------------
Loss: 399.832489 [    0]/ 1064
 > Training time: 3.60 seconds
 > Test reconstruction loss: 50875.60
Model ./models\vae_sigm_ep100_facetalk.pth stored!
Epoch 94 out of 100
 ------------
Loss: 411.799103 [    0]/ 1064
 > Training time: 3.60 seconds
 > Test reconstruction loss: 50796.18
Model ./models\vae_sigm_ep100_facetalk.pth stored!
Epoch 95 out of 100
 ---