# MLSS 2019: Neural Networks for Surface Meshes

This tutorial is based on the ICCV 2019 paper "Neural 3D Morphable Models: Spiral Convolutional Networks for 3D Shape Representation Learning and Generation" by Bouritsas, G., Bokhnyak, S., Bronstein, M., and Zafeiriou, S.

It presents a novel graph convolutional operator, acting directly on the 3D mesh, that explicitly models the inductive bias of the fixed underlying graph.

Using this operator, we will construct an autoencoder neural network for meshes and test its performance on a set of human poses.

<b>We are going to use Google Colab to run this notebook. To install all the necessary files, run the following cells:</b>

## Data loading

In [None]:
!pip install -q --upgrade git+https://github.com/mlss-skoltech/tutorials_week2.git#subdirectory=graph_neural_networks2

In [1]:
import os
!pip -q install trimesh==3.2 tensorboardX
!wget -nc -O DFAUST.zip https://box.skoltech.ru/index.php/s/vcTrg71n94HkqjX/download
!unzip -qn DFAUST.zip

You should consider upgrading via the 'pip install --upgrade pip' command.
File `DFAUST.zip' already there; not retrieving.
File `mesh_sampling.py' already there; not retrieving.
File `utils.py' already there; not retrieving.


In [2]:
import numpy as np
import json
import copy
import pickle
import pdb
import time
from tqdm import tqdm

from sklearn.metrics.pairwise import euclidean_distances
import trimesh
from trimesh.exchange.export import export_mesh

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from tensorboardX import SummaryWriter

import noneuclideanlearning.mesh_sampling as mesh_sampling
from noneuclideanlearning.utils import get_adj_trigs, generate_spirals, get_settings, show_mesh

root_dir = './'
mesh_template = trimesh.load('DFAUST/template/template.obj', process = False)
dataset = 'DFAUST'
name = ''

specified material ()  not loaded!


For this tutorial, we are going to need a GPU. Let's make sure we have one.

In [3]:
GPU = True
device_idx = 0
torch.cuda.get_device_name(device_idx)

'GeForce GTX 1080 Ti'

We use DFAUST for our experiments. This dynamic human body shape dataset consists of 40K+ 3D scans (6890 vertices each) of ten unique identities performing actions such as leg and arm raises, jumps, etc. When we trained the network, we randomly split the data into a test set of 5000 samples, a validation set of 500, and a training set of 34.5K+. Here, we only use 1000 training, 500 validation and 1000 testing samples.

We use the `show_mesh` helper to visualize the template mesh and see what our data samples look like. We chose to do our analysis on human pose meshes.

In [4]:
show_mesh(mesh_template.vertices)

specified material ()  not loaded!


As you know, neural networks come with many hyperparameters. We set them in the ``utils.py`` file and import them here. Feel free to check that file to see the settings.

In [5]:
args, reference_mesh_file, ds_factors, step_sizes,\
filter_sizes_enc, filter_sizes_dec, reference_points,\
summary_path, checkpoint_path, samples_path, prediction_path = get_settings(root_dir, dataset, name)

In applications dealing with 3D data, the key challenge of geometric deep learning is a meaningful definition of intrinsic operations analogous to convolution and pooling on meshes or point clouds. Among the numerous advantages of working directly on mesh or point cloud data is that invariance to shape transformations (both rigid and nonrigid) can be built into the architecture, which allows the use of significantly simpler models and much less training data.

`PyTorch`, like a number of other deep learning frameworks, needs special classes for the `dataset` and `dataloader` objects. For the network we are going to implement, the dataset takes the data directory, the vertices of the meshes, and some other inputs, normalizes the data internally, and outputs a training sample.

In [6]:
class autoencoder_dataset(Dataset):

    def __init__(self, root_dir, points_dataset, shapedata, normalization = True, dummy_node = True):
        
        self.shapedata = shapedata
        self.normalization = normalization
        self.root_dir = root_dir
        self.points_dataset = points_dataset
        self.dummy_node = dummy_node
        self.paths = np.load(os.path.join(root_dir, 'paths_'+points_dataset+'.npy'))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        basename = self.paths[idx]
        
        verts_init = np.load(os.path.join(self.root_dir,'points'+'_'+self.points_dataset, basename+'.npy'))
        if self.normalization:
            verts_init = verts_init - self.shapedata.mean
            verts_init = verts_init / self.shapedata.std
        verts_init[np.where(np.isnan(verts_init))]=0.0
        
        verts_init = verts_init.astype('float32')
        if self.dummy_node:
            verts = np.zeros((verts_init.shape[0]+1,verts_init.shape[1]),dtype=np.float32)
            verts[:-1,:] = verts_init
            verts_init = verts
        verts = torch.Tensor(verts_init)
        

        sample = {'points': verts}

        return sample
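
To make the two preprocessing steps in `__getitem__` concrete, here is a minimal NumPy sketch on a toy mesh with four vertices. The helper name `normalize_and_pad` and the toy statistics are assumptions made for illustration; they are not part of the tutorial code.

```python
import numpy as np

def normalize_and_pad(verts, mean, std):
    """Standardize per-vertex coordinates and append one all-zero dummy vertex."""
    with np.errstate(divide="ignore", invalid="ignore"):
        normed = (verts - mean) / std
    # A coordinate with zero spread gives 0/0 = NaN; like the dataset class, map it to 0.
    normed[np.isnan(normed)] = 0.0
    padded = np.zeros((normed.shape[0] + 1, normed.shape[1]), dtype=np.float32)
    padded[:-1] = normed          # the last row stays zero: this is the dummy node
    return padded

# Toy data (hypothetical): 4 vertices, the y coordinate never varies.
verts = np.array([[0.0, 1.0, 2.0],
                  [2.0, 1.0, 0.0],
                  [1.0, 1.0, 1.0],
                  [1.0, 1.0, 3.0]])
mean = np.full((4, 3), 1.0)
std = np.ones((4, 3))
std[:, 1] = 0.0

out = normalize_and_pad(verts, mean, std)
print(out.shape)   # one more row than the input
print(out[-1])     # the dummy vertex
```

The dummy vertex gives padded spiral entries something harmless to point at, so gathering features along a spiral never needs a special case for short neighbourhoods.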

## Neural networks for meshes

One of the key challenges in developing convolution-like operators on graphs or manifolds is the lack of a global system of coordinates that can be associated with each point. In this tutorial, we will focus on fixed topology meshes. We will define an ordering-based graph convolutional operator, contrary to the permutation invariant operators in the literature of Graph Neural Networks.

This way we obtain anisotropic filters without sacrificing computational complexity, while simultaneously we explicitly encode the fixed graph connectivity. The operator can potentially be generalised to other domains that accept implicit local orderings, such as arbitrary mesh topologies and point clouds, while it is naturally equivalent to traditional grid convolutions. Via this equivalence, common CNN practices, such as dilated convolutions, can be easily formulated for meshes.

![spiral_conv](https://box.skoltech.ru/index.php/s/Fudde1kfbSvglJV/download)

The issues of the absence of a global ordering and insensitivity to graph topology are irrelevant when dealing with fixed topology meshes. In particular, one can locally order the vertices and keep the order fixed. Then, graph convolution can be defined as follows:
$$(f * g)_x = \sum_{l=1}^L g_l f(x_l)$$
where $\{x_1, \ldots, x_L\}$ denote the neighbours of vertex $x$ ordered in a fixed way. 
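
The formula above can be sketched in a few lines of NumPy for scalar features. The function name `ordered_conv`, the toy neighbour table and the weights are assumptions made for this illustration only.

```python
import numpy as np

def ordered_conv(f, neighbours, g):
    """(f * g)_x = sum_l g[l] * f[neighbours[x, l]] for scalar vertex features."""
    # neighbours has shape (num_vertices, L); row x lists x_1..x_L in a fixed order.
    return (f[neighbours] * g).sum(axis=1)

f = np.array([1.0, 2.0, 3.0, 4.0])        # one feature per vertex
neighbours = np.array([[1, 2, 3],
                       [2, 3, 0],
                       [3, 0, 1],
                       [0, 1, 2]])         # hypothetical fixed ordering
g = np.array([1.0, 10.0, 100.0])          # one weight per position in the ordering

out = ordered_conv(f, neighbours, g)

# Reversing the order of the same neighbours changes the result,
# which is what makes the filter anisotropic.
out_reversed = ordered_conv(f, neighbours[:, ::-1], g)
print(out, out_reversed)
```

A permutation-invariant operator would return the same value for both orderings; here each position in the ordering has its own weight.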

For example, it is possible to get the neighborhoods of mesh vertices through spiral ordering. Let $x \in V$ be a mesh vertex, and let $R^d(x)$ be the $d$-ring, i.e. an ordered set of vertices whose shortest (graph) path to $x$ is exactly $d$ hops long; $R^d_j(x)$ denotes the $j$th element in the $d$-ring (trivially, $R^0_1(x) = x$). Then, one can define the spiral patch operator as the ordered sequence 
$$S(x) = \{x, R^1_1(x), R^1_2(x), \ldots , R^h_{|R^h|}(x)\},$$
where $h$ denotes the patch radius, similar to the size of the kernel in classical CNNs. Then, spiral convolution is:
$$(f * g)_x = \sum_{l=1}^L g_l f(S_l(x)).$$
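
As a rough illustration of the $d$-ring definition, the sketch below groups the vertices of a toy graph by hop distance using breadth-first search. It is a simplification: inside each ring the vertices are simply sorted by index, whereas a real spiral walks around the ring in a consistent direction. The function name `rings` and the toy adjacency are assumptions, not the tutorial's `generate_spirals`.

```python
from collections import deque

def rings(adjacency, x, h):
    """Return [R^0(x), ..., R^h(x)]: vertices grouped by hop distance from x."""
    dist = {x: 0}
    queue = deque([x])
    while queue:
        v = queue.popleft()
        if dist[v] == h:          # do not expand beyond the patch radius
            continue
        for w in adjacency[v]:
            if w not in dist:
                dist[w] = dist[v] + 1
                queue.append(w)
    out = [[] for _ in range(h + 1)]
    for v, d in dist.items():
        out[d].append(v)
    # Placeholder ordering inside each ring (by index), not a true spiral order.
    return [sorted(r) for r in out]

# Toy graph: vertex 0 surrounded by a cycle 1..6, plus vertex 7 attached to 1.
adjacency = {
    0: [1, 2, 3, 4, 5, 6],
    1: [0, 2, 6, 7], 2: [0, 1, 3], 3: [0, 2, 4],
    4: [0, 3, 5], 5: [0, 4, 6], 6: [0, 5, 1], 7: [1],
}
patch = rings(adjacency, 0, 2)
spiral = [v for ring in patch for v in ring]   # concatenate rings into one sequence
print(patch, spiral)
```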

In [7]:
class SpiralConv(nn.Module):
    def __init__(self, in_c, spiral_size, out_c, activation='elu', bias=True, device=None):
        super(SpiralConv,self).__init__()
        self.in_c = in_c
        self.out_c = out_c
        self.device = device

        self.conv = nn.Linear(in_c * spiral_size, out_c, bias=bias)

        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'elu':
            self.activation = nn.ELU()
        elif activation == 'leaky_relu':
            self.activation = nn.LeakyReLU(0.02)
        elif activation == 'sigmoid':
            self.activation = nn.Sigmoid()
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'identity':
            self.activation = lambda x: x
        else:
            raise NotImplementedError()

    def forward(self, x, spiral_adj):
        bsize, num_pts, feats = x.size()
        _, _, spiral_size = spiral_adj.size()
  
        spirals_index = spiral_adj.view(bsize * num_pts * spiral_size)  # [1d array of batch,vertx,vertx-adj]
        batch_index = torch.arange(bsize, device=self.device) \
                           .view(-1, 1) \
                           .repeat([1,num_pts * spiral_size]) \
                           .view(-1).long()  # [0*numpt,1*numpt,etc.]
        spirals = x[batch_index, spirals_index, :] \
                  .view(bsize * num_pts, spiral_size * feats) # [bsize*numpt, spiral*feats]

        out_feat = self.conv(spirals)
        out_feat = self.activation(out_feat)

        out_feat = out_feat.view(bsize,num_pts,self.out_c)
        zero_padding = torch.ones((1,x.size(1),1), device=self.device)
        zero_padding[0,-1,0] = 0.0
        out_feat = out_feat * zero_padding

        return out_feat

The uniqueness of the ordering is given by fixing two degrees of freedom: the direction of the rings and the first vertex $R^1_1(x)$. The rest of the vertices of the spiral are ordered
inductively. The direction is chosen by moving clockwise or counterclockwise, while the choice of the first vertex is based on the underlying geometry of the shape to ensure the robustness of the method. In particular, fixing a reference vertex $x_0$ on a template shape and choosing the initial point for each spiral to be in the direction of the shortest geodesic path to $x_0$, gives us
$$R^1_1(x) = \underset{y\in R^1(x)}{\operatorname{argmin}} d_{\mathcal M}(x_0, y),$$
where $d_{\mathcal M}$ is the geodesic distance between two vertices on the mesh $\mathcal M$. In order to allow for fixed-sized spirals, we choose a fixed length $L$ as a hyper-parameter and then either truncate or zero-pad each spiral depending on its size.
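
The two choices described above, picking the first neighbour and fixing the spiral length, can be sketched as follows. Both helper names and the distance table are hypothetical, and we assume that padding entries point at the extra all-zero dummy vertex appended by the dataset class.

```python
def first_neighbour(ring1, dist_to_ref):
    """Pick R^1_1(x): the 1-ring vertex closest to the reference vertex x_0."""
    # dist_to_ref maps a vertex index to its (precomputed) distance to x_0.
    return min(ring1, key=lambda y: dist_to_ref[y])

def fix_spiral_length(spiral, L, dummy_index):
    """Truncate a spiral to L entries, or pad it with the dummy vertex index."""
    return spiral[:L] + [dummy_index] * max(0, L - len(spiral))

start = first_neighbour([4, 7, 2], {4: 3.0, 7: 1.5, 2: 2.0})
long_spiral = fix_spiral_length([0, 1, 2, 3, 4], L=3, dummy_index=9)
short_spiral = fix_spiral_length([0, 1], L=4, dummy_index=9)
print(start, long_spiral, short_spiral)
```

With every spiral cut or padded to the same length, the whole neighbourhood table fits in one integer tensor of shape (number of vertices, $L$).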

The spiral convolution operator comes by construction with desirable properties (anisotropic, topology-aware, lightweight, easy-to-optimise), and it can be used as a building block for traditional deep generative architectures.

## Neural3DMM: mesh autoencoder

In essence, a Neural 3D Morphable Model is a deep convolutional mesh autoencoder that learns hierarchical representations of a shape. This way one manages to learn semantically meaningful representations. Similar to traditional convolutional autoencoders, this one consists of a series of convolutional layers with small receptive fields, followed by pooling in the encoder and unpooling in the decoder. A decimated or upsampled version of the mesh is obtained each time, and the features of the existing vertices are either aggregated or extrapolated. The features of the vertices added by upsampling are computed by interpolation, weighting the nearby vertices with barycentric coordinates. The network is trained by minimising the $\mathcal L_1$ norm of the difference between the input and the predicted output.

![image](https://box.skoltech.ru/index.php/s/k1xaTf5VlLYf2Na/download)
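
To illustrate the barycentric interpolation used for upsampling, here is a minimal NumPy sketch for a single new vertex inside one triangle. In the tutorial these weights are precomputed and stored in the upsampling matrices; the function name `barycentric_weights` and the toy triangle are assumptions.

```python
import numpy as np

def barycentric_weights(p, a, b, c):
    """Barycentric weights of point p with respect to triangle (a, b, c)."""
    # Solve p - a = u * (b - a) + v * (c - a) in the least-squares sense.
    T = np.stack([b - a, c - a], axis=1)                  # shape (3, 2)
    (u, v), *_ = np.linalg.lstsq(T, p - a, rcond=None)
    return np.array([1.0 - u - v, u, v])

a = np.array([0.0, 0.0, 0.0])
b = np.array([1.0, 0.0, 0.0])
c = np.array([0.0, 1.0, 0.0])
p = np.array([0.25, 0.25, 0.0])          # position of the vertex to be added

w = barycentric_weights(p, a, b, c)

# Features of the three coarse vertices (2 channels each, toy values).
feats = np.array([[1.0, 0.0],
                  [0.0, 1.0],
                  [2.0, 2.0]])
new_feat = w @ feats                      # interpolated feature of the new vertex
print(w, new_feat)
```

Stacking one such row of weights per fine vertex gives a sparse matrix, so unpooling reduces to a single matrix product.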

In [8]:
class SpiralAutoencoder(nn.Module):
    def __init__(self, filters_enc, filters_dec, latent_size, sizes, spiral_sizes, spirals, D, U, device, activation = 'elu'):
        super(SpiralAutoencoder,self).__init__()
        self.latent_size = latent_size
        self.sizes = sizes
        self.spirals = spirals
        self.filters_enc = filters_enc
        self.filters_dec = filters_dec
        self.spiral_sizes = spiral_sizes
        self.D = D
        self.U = U
        self.device = device
        self.activation = activation
        
        self.conv = []
        input_size = filters_enc[0][0]
        for i in range(len(spiral_sizes)-1):
            if filters_enc[1][i]:
                self.conv.append(SpiralConv(input_size, spiral_sizes[i], filters_enc[1][i],
                                            activation=self.activation, device=device).to(device))
                input_size = filters_enc[1][i]

            self.conv.append(SpiralConv(input_size, spiral_sizes[i], filters_enc[0][i+1],
                                        activation=self.activation, device=device).to(device))
            input_size = filters_enc[0][i+1]

        self.conv = nn.ModuleList(self.conv)   
        
        self.fc_latent_enc = nn.Linear((sizes[-1]+1)*input_size, latent_size)
        self.fc_latent_dec = nn.Linear(latent_size, (sizes[-1]+1)*filters_dec[0][0])
        
        self.dconv = []
        input_size = filters_dec[0][0]
        for i in range(len(spiral_sizes)-1):
            if i != len(spiral_sizes)-2:
                self.dconv.append(SpiralConv(input_size, spiral_sizes[-2-i], filters_dec[0][i+1],
                                             activation=self.activation, device=device).to(device))
                input_size = filters_dec[0][i+1]  
                
                if filters_dec[1][i+1]:
                    self.dconv.append(SpiralConv(input_size,spiral_sizes[-2-i], filters_dec[1][i+1],
                                                 activation=self.activation, device=device).to(device))
                    input_size = filters_dec[1][i+1]
            else:
                if filters_dec[1][i+1]:
                    self.dconv.append(SpiralConv(input_size, spiral_sizes[-2-i], filters_dec[0][i+1],
                                                 activation=self.activation, device=device).to(device))
                    input_size = filters_dec[0][i+1]                      
                    self.dconv.append(SpiralConv(input_size,spiral_sizes[-2-i], filters_dec[1][i+1],
                                                 activation='identity', device=device).to(device)) 
                    input_size = filters_dec[1][i+1] 
                else:
                    self.dconv.append(SpiralConv(input_size, spiral_sizes[-2-i], filters_dec[0][i+1],
                                                 activation='identity', device=device).to(device))
                    input_size = filters_dec[0][i+1]                      
                    
        self.dconv = nn.ModuleList(self.dconv)

    def encode(self,x):
        bsize = x.size(0)
        S = self.spirals
        D = self.D
        
        j = 0
        for i in range(len(self.spiral_sizes)-1):
            x = self.conv[j](x,S[i].repeat(bsize,1,1))
            j+=1
            if self.filters_enc[1][i]:
                x = self.conv[j](x,S[i].repeat(bsize,1,1))
                j+=1
            x = torch.matmul(D[i],x)
        x = x.view(bsize,-1)
        return self.fc_latent_enc(x)
    
    def decode(self,z):
        bsize = z.size(0)
        S = self.spirals
        U = self.U
        
        x = self.fc_latent_dec(z)
        x = x.view(bsize,self.sizes[-1]+1,-1)
        j=0
        for i in range(len(self.spiral_sizes)-1):
            x = torch.matmul(U[-1-i],x)
            x = self.dconv[j](x,S[-2-i].repeat(bsize,1,1))
            j+=1
            if self.filters_dec[1][i+1]: 
                x = self.dconv[j](x,S[-2-i].repeat(bsize,1,1))
                j+=1
        return x

    def forward(self,x):
        bsize = x.size(0)
        z = self.encode(x)
        x = self.decode(z)
        return x 

Let's make a class for the data we have.

In [9]:
class ShapeData(object):
    def __init__(self, nVal, train_file, test_file, reference_mesh_file, normalization = True, load_flag = True, mean_subtraction_only = False):
        self.nVal = nVal
        self.train_file = train_file
        self.test_file = test_file
        self.vertices_train = None
        self.vertices_val = None
        self.vertices_test = None
        self.n_vertex = None
        self.n_features = None
        self.normalization = normalization
        self.load_flag = load_flag
        self.mean_subtraction_only = mean_subtraction_only
        
        if self.load_flag:
            self.load()
        self.reference_mesh = trimesh.load(reference_mesh_file, process = False)
        
        if self.load_flag:
            self.mean = np.mean(self.vertices_train, axis=0)
            self.std = np.std(self.vertices_train, axis=0)
        else:
            self.mean = None
            self.std = None
        self.normalize()
        
    def load(self):
        vertices_train = np.load(self.train_file)
        self.vertices_train = vertices_train[:-self.nVal]
        self.vertices_val = vertices_train[-self.nVal:]

        self.n_vertex = self.vertices_train.shape[1]
        self.n_features = self.vertices_train.shape[2]

        if os.path.exists(self.test_file):
            self.vertices_test = np.load(self.test_file)

    def normalize(self):
        if self.load_flag:
            if self.normalization:
                if self.mean_subtraction_only:
                    self.std = np.ones_like((self.std))
                self.vertices_train = self.vertices_train - self.mean
                self.vertices_train = self.vertices_train/self.std
                self.vertices_train[np.where(np.isnan(self.vertices_train))]=0.0

                self.vertices_val = self.vertices_val - self.mean
                self.vertices_val = self.vertices_val/self.std
                self.vertices_val[np.where(np.isnan(self.vertices_val))]=0.0

                if self.vertices_test is not None:
                    self.vertices_test = self.vertices_test - self.mean
                    self.vertices_test = self.vertices_test/self.std
                    self.vertices_test[np.where(np.isnan(self.vertices_test))]=0.0
                
                self.N = self.vertices_train.shape[0]

                print('Vertices normalized')
            else:
                print('Vertices not normalized')


    def save_meshes(self, filename, meshes, mesh_indices):
        for i in range(meshes.shape[0]):
            if self.normalization:
                vertices = meshes[i].reshape((self.n_vertex, self.n_features))*self.std + self.mean
            else:
                vertices = meshes[i].reshape((self.n_vertex, self.n_features))
            new_mesh = self.reference_mesh
            if self.n_features == 3:
                new_mesh.vertices = vertices
            elif self.n_features == 6:
                new_mesh.vertices = vertices[:,0:3]
                colors = vertices[:,3:]
                colors[np.where(colors<0)]=0
                colors[np.where(colors>1)]=1
                vertices[:,3:] = colors
                new_mesh.visual = trimesh.visual.create_visual(vertex_colors = vertices[:,3:])
            else:
                raise NotImplementedError
            new_mesh.export(filename+'.'+str(mesh_indices[i]).zfill(6)+'.ply','ply')   
        return 0

Now, we define the training and testing routines that iterate over the dataloaders for the training, validation and testing subsets.

In [10]:
def train_autoencoder_dataloader(dataloader_train, dataloader_val,
                                 device, model, optim, loss_fn, 
                                 bsize, start_epoch, n_epochs, eval_freq, scheduler = None,
                                 writer=None, save_recons=True, shapedata = None,
                                 metadata_dir=None, samples_dir = None, checkpoint_path = None):
    if not shapedata.normalization:
        shapedata_mean = torch.Tensor(shapedata.mean).to(device)
        shapedata_std = torch.Tensor(shapedata.std).to(device)
    
    total_steps = start_epoch*len(dataloader_train)

    for epoch in range(start_epoch, n_epochs):
        model.train()

        tloss = []
        for b, sample_dict in enumerate(tqdm(dataloader_train)):
            optim.zero_grad()
                
            tx = sample_dict['points'].to(device)
            cur_bsize = tx.shape[0]
            
            tx_hat = model(tx)
            loss = loss_fn(tx, tx_hat)

            loss.backward()
            optim.step()
            
            if shapedata.normalization:
                tloss.append(cur_bsize * loss.item())
            else:
                with torch.no_grad():
                    if shapedata.mean.shape[0]!=tx.shape[1]:
                        tx_norm = tx[:,:-1,:]
                        tx_hat_norm = tx_hat[:,:-1,:]
                    else:
                        tx_norm = tx
                        tx_hat_norm = tx_hat
                    tx_norm = (tx_norm - shapedata_mean)/shapedata_std
                    tx_norm = torch.cat((tx_norm,torch.zeros(tx.shape[0],1,tx.shape[2]).to(device)),1)
                    
                    tx_hat_norm = (tx_hat_norm -shapedata_mean)/shapedata_std
                    tx_hat_norm = torch.cat((tx_hat_norm,torch.zeros(tx.shape[0],1,tx.shape[2]).to(device)),1)
                    
                    loss_norm = loss_fn(tx_norm, tx_hat_norm)
                    tloss.append(cur_bsize * loss_norm.item())
            if writer and total_steps % eval_freq == 0:
                writer.add_scalar('loss/loss/data_loss',loss.item(),total_steps)
                writer.add_scalar('training/learning_rate', optim.param_groups[0]['lr'],total_steps)
            total_steps += 1

        # validate
        model.eval()
        vloss = []
        with torch.no_grad():
            for b, sample_dict in enumerate(tqdm(dataloader_val)):

                tx = sample_dict['points'].to(device)
                cur_bsize = tx.shape[0]

                tx_hat = model(tx)               
                loss = loss_fn(tx, tx_hat)
                
                if shapedata.normalization:
                    vloss.append(cur_bsize * loss.item())
                else:
                    with torch.no_grad():
                        if shapedata.mean.shape[0]!=tx.shape[1]:
                            tx_norm = tx[:,:-1,:]
                            tx_hat_norm = tx_hat[:,:-1,:]
                        else:
                            tx_norm = tx
                            tx_hat_norm = tx_hat
                        tx_norm = (tx_norm - shapedata_mean)/shapedata_std
                        tx_norm = torch.cat((tx_norm,torch.zeros(tx.shape[0],1,tx.shape[2]).to(device)),1)
                    
                        tx_hat_norm = (tx_hat_norm - shapedata_mean)/shapedata_std
                        tx_hat_norm = torch.cat((tx_hat_norm,torch.zeros(tx.shape[0],1,tx.shape[2]).to(device)),1)
                    
                        loss_norm = loss_fn(tx_norm, tx_hat_norm)
                        vloss.append(cur_bsize * loss_norm.item())   

        if scheduler:
            scheduler.step()
            
        epoch_tloss = sum(tloss) / float(len(dataloader_train.dataset))
        # writer is optional (default None), so guard every use of it
        if writer:
            writer.add_scalar('avg_epoch_train_loss',epoch_tloss,epoch)
        if len(dataloader_val.dataset) > 0:
            epoch_vloss = sum(vloss) / float(len(dataloader_val.dataset))
            if writer:
                writer.add_scalar('avg_epoch_valid_loss', epoch_vloss,epoch)
            print('epoch {0} | tr {1} | val {2}'.format(epoch,epoch_tloss,epoch_vloss))
        else:
            print('epoch {0} | tr {1} '.format(epoch,epoch_tloss))
        model = model.cpu()
  
        # scheduler is optional (default None), so only store its state when one is given
        scheduler_state = scheduler.state_dict() if scheduler else None
        torch.save({'epoch': epoch,
            'autoencoder_state_dict': model.state_dict(),
            'optimizer_state_dict' : optim.state_dict(),
            'scheduler_state_dict': scheduler_state,
        },os.path.join(metadata_dir, checkpoint_path+'.pth.tar'))
        
        if epoch % 10 == 0:
            torch.save({'epoch': epoch,
            'autoencoder_state_dict': model.state_dict(),
            'optimizer_state_dict' : optim.state_dict(),
            'scheduler_state_dict': scheduler_state,
            },os.path.join(metadata_dir, checkpoint_path+'%s.pth.tar'%(epoch)))

        model = model.to(device)

        if save_recons:
            with torch.no_grad():
                if epoch == 0:
                    mesh_ind = [0]
                    msh = tx[mesh_ind[0]:1,0:-1,:].detach().cpu().numpy()
                    shapedata.save_meshes(os.path.join(samples_dir,'input_epoch_{0}'.format(epoch)),
                                                     msh, mesh_ind)
                mesh_ind = [0]
                msh = tx_hat[mesh_ind[0]:1,0:-1,:].detach().cpu().numpy()
                shapedata.save_meshes(os.path.join(samples_dir,'epoch_{0}'.format(epoch)),
                                                 msh, mesh_ind)

    print('~FIN~')

In [11]:
def test_autoencoder_dataloader(device, model, dataloader_test, shapedata, mm_constant = 1000):
    model.eval()
    l1_loss = 0
    l2_loss = 0
    shapedata_mean = torch.Tensor(shapedata.mean).to(device)
    shapedata_std = torch.Tensor(shapedata.std).to(device)
    with torch.no_grad():
        for i, sample_dict in enumerate(tqdm(dataloader_test)):
            tx = sample_dict['points'].to(device)
            prediction = model(tx)  
            if i==0:
                predictions = copy.deepcopy(prediction)
            else:
                predictions = torch.cat([predictions,prediction],0) 
                
            if dataloader_test.dataset.dummy_node:
                x_recon = prediction[:,:-1]
                x = tx[:,:-1]
            else:
                x_recon = prediction
                x = tx
            l1_loss+= torch.mean(torch.abs(x_recon-x))*x.shape[0]/float(len(dataloader_test.dataset))
            
            x_recon = (x_recon * shapedata_std + shapedata_mean) * mm_constant
            x = (x * shapedata_std + shapedata_mean) * mm_constant
            l2_loss+= torch.mean(torch.sqrt(torch.sum((x_recon - x)**2,dim=2)))*x.shape[0]/float(len(dataloader_test.dataset))
            
        predictions = predictions.cpu().numpy()
        l1_loss = l1_loss.item()
        l2_loss = l2_loss.item()
    
    return predictions, l1_loss, l2_loss
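
The second metric returned by `test_autoencoder_dataloader` is the mean per-vertex Euclidean distance after undoing the normalization. The NumPy sketch below reproduces that computation on a toy batch. It assumes the raw coordinates are in metres, which the default `mm_constant = 1000` suggests; the function name and the toy numbers are made up.

```python
import numpy as np

def mean_euclidean_error_mm(x_norm, x_hat_norm, mean, std, mm_constant=1000):
    """Mean per-vertex Euclidean distance, computed in de-normalized units."""
    x = (x_norm * std + mean) * mm_constant
    x_hat = (x_hat_norm * std + mean) * mm_constant
    # Norm over the coordinate axis, then average over vertices and batch.
    return np.sqrt(((x_hat - x) ** 2).sum(axis=2)).mean()

# Toy batch: 1 mesh, 2 vertices, 3 coordinates.
x_norm = np.zeros((1, 2, 3))
x_hat_norm = np.array([[[0.0015, 0.002, 0.0],
                        [0.0,    0.0,   0.0]]])
err = mean_euclidean_error_mm(x_norm, x_hat_norm, mean=0.0, std=2.0)
print(err)   # first vertex is off by a 3-4-5 triangle in mm, second is exact
```

Unlike the L1 value computed on normalized data, this number is expressed in physical units and is therefore easier to interpret.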

## Training the network
### Preliminary computations and data loading

Load the data.

In [12]:
np.random.seed(args['seed'])
print("Loading data .. ")

shapedata = ShapeData(nVal=args['nVal'], 
                     train_file=args['data']+'/train.npy',
                     test_file=args['data']+'/test.npy', 
                     reference_mesh_file=args['reference_mesh_file'],
                     normalization = args['normalization'], load_flag = False)
shapedata.mean = np.load(args['data']+'/mean.npy')
shapedata.std = np.load(args['data']+'/std.npy')
shapedata.n_vertex = shapedata.mean.shape[0]
shapedata.n_features = shapedata.mean.shape[1]

specified material ()  not loaded!


Loading data .. 


Load the precalculated transform matrices for the mesh models we have.

In [13]:
print("Loading Transform Matrices ..")
with open(os.path.join(args['downsample_directory'],'downsampling_matrices.pkl'), 'rb') as fp:
    downsampling_matrices = pickle.load(fp,encoding = 'latin1')
#         downsampling_matrices = pickle.load(fp)

M_verts_faces = downsampling_matrices['M_verts_faces']
M = [trimesh.base.Trimesh(vertices=M_verts_faces[i][0], faces=M_verts_faces[i][1], process = False) for i in range(len(M_verts_faces))]
A = downsampling_matrices['A']
D = downsampling_matrices['D']
U = downsampling_matrices['U']
F = downsampling_matrices['F']

Loading Transform Matrices ..


Precompute the reference points.

In [14]:
print("Calculating reference points for downsampled versions..")
for i in range(len(args['ds_factors'])):
    dist = euclidean_distances(M[i+1].vertices, M[0].vertices[reference_points[0]])
    reference_points.append(np.argmin(dist,axis=0).tolist())

Calculating reference points for downsampled versions..
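
The cell above transfers the reference vertex to each coarser mesh by picking the coarse vertex nearest to it in Euclidean space. A minimal NumPy sketch of the same idea, with a hypothetical helper name and toy coordinates:

```python
import numpy as np

def nearest_vertices(coarse_verts, query_points):
    """Index of the closest coarse vertex for every query point."""
    diff = coarse_verts[:, None, :] - query_points[None, :, :]
    dist = np.linalg.norm(diff, axis=2)        # shape (n_coarse, n_query)
    return np.argmin(dist, axis=0).tolist()

coarse = np.array([[0.0, 0.0, 0.0],
                   [1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0]])
reference = np.array([[0.9, 0.1, 0.0]])        # reference vertex on the full mesh
idx = nearest_vertices(coarse, reference)
print(idx)
```

Taking `argmin` along axis 0 mirrors the cell above, where rows index the coarse vertices and columns index the reference points.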


Generate the spiral paths for the convolutions.

In [15]:
sizes = [x.vertices.shape[0] for x in M]
Adj, Trigs = get_adj_trigs(A, F, shapedata.reference_mesh)

spirals_np, spiral_sizes,spirals = generate_spirals(args['step_sizes'], 
                                                    M, Adj, Trigs, 
                                                    reference_points = reference_points, 
                                                    dilation = args['dilation'], random = False, 
                                                    counter_clockwise = True)

bU = []
bD = []
for i in range(len(D)):
    d = np.zeros((1,D[i].shape[0]+1,D[i].shape[1]+1))
    u = np.zeros((1,U[i].shape[0]+1,U[i].shape[1]+1))
    d[0,:-1,:-1] = D[i].todense()
    u[0,:-1,:-1] = U[i].todense()
    d[0,-1,-1] = 1
    u[0,-1,-1] = 1
    bD.append(d)
    bU.append(u)

spiral generation for hierarchy 0 (6890 vertices) finished
spiral generation for hierarchy 1 (1723 vertices) finished
spiral generation for hierarchy 2 (431 vertices) finished
spiral generation for hierarchy 3 (108 vertices) finished
spiral generation for hierarchy 4 (27 vertices) finished
spiral sizes for hierarchy 0:  12
spiral sizes for hierarchy 1:  14
spiral sizes for hierarchy 2:  9
spiral sizes for hierarchy 3:  9
spiral sizes for hierarchy 4:  9
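
The loop at the end of the cell above enlarges every `D` and `U` matrix by one row and one column so that the dummy vertex survives pooling and unpooling. The sketch below shows the effect on a toy downsampling matrix; the helper name `pad_for_dummy` and the toy values are assumptions.

```python
import numpy as np

def pad_for_dummy(mat):
    """Embed mat in a matrix one row and one column larger, with a trailing 1."""
    out = np.zeros((mat.shape[0] + 1, mat.shape[1] + 1))
    out[:-1, :-1] = mat
    out[-1, -1] = 1.0          # the dummy vertex maps to the dummy vertex
    return out

# Toy downsampling: keep vertices 0 and 2 out of 4.
D_toy = np.array([[1.0, 0.0, 0.0, 0.0],
                  [0.0, 0.0, 1.0, 0.0]])
x = np.array([[1.0], [2.0], [3.0], [4.0], [0.0]])   # 4 vertices + dummy row
y = pad_for_dummy(D_toy) @ x
print(y.ravel())
```

The last row of the output is again the zero dummy vertex, so the next spiral convolution can keep using it for padding.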


In [16]:
torch.manual_seed(args['seed'])

if GPU:
    device = torch.device("cuda:"+str(device_idx) if torch.cuda.is_available() else "cpu")
else:
    device = torch.device("cpu")
print(device)

tspirals = [torch.from_numpy(s).long().to(device) for s in spirals_np]
tD = [torch.from_numpy(s).float().to(device) for s in bD]
tU = [torch.from_numpy(s).float().to(device) for s in bU]

cuda:0


The last steps before we can use the network: initialize the datasets and dataloaders.

In [17]:
# Building the datasets and dataloaders

dataset_train = autoencoder_dataset(root_dir = args['data'], points_dataset = 'train',
                                           shapedata = shapedata,
                                           normalization = args['normalization'])

dataloader_train = DataLoader(dataset_train, batch_size=args['batch_size'],\
                                     shuffle = args['shuffle'], num_workers = args['num_workers'])

dataset_val = autoencoder_dataset(root_dir = args['data'], points_dataset = 'val', 
                                         shapedata = shapedata,
                                         normalization = args['normalization'])

dataloader_val = DataLoader(dataset_val, batch_size=args['batch_size'],\
                                     shuffle = False, num_workers = args['num_workers'])


dataset_test = autoencoder_dataset(root_dir = args['data'], points_dataset = 'test',
                                          shapedata = shapedata,
                                          normalization = args['normalization'])

dataloader_test = DataLoader(dataset_test, batch_size=args['batch_size'],\
                                     shuffle = False, num_workers = args['num_workers'])

Here is the neural network itself.

In [18]:
model = SpiralAutoencoder(filters_enc = args['filter_sizes_enc'],   
                          filters_dec = args['filter_sizes_dec'],
                          latent_size=args['nz'],
                          sizes=sizes,
                          spiral_sizes=spiral_sizes,
                          spirals=tspirals,
                          D=tD, U=tU,device=device).to(device)
 
    
optim = torch.optim.Adam(model.parameters(),lr=args['lr'],weight_decay=args['regularization'])
scheduler=torch.optim.lr_scheduler.StepLR(optim, args['decay_steps'],gamma=args['decay_rate'])

def loss_l1(outputs, targets):
    L = torch.abs(outputs - targets).mean()
    return L 
loss_fn = loss_l1
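
`StepLR` multiplies the learning rate by `gamma` every `step_size` calls to `scheduler.step()`, and our training loop calls it once per epoch. The pure-Python sketch below shows the resulting schedule; the numeric values are made up for illustration and are not the ones stored in `args`.

```python
def step_lr(base_lr, step_size, gamma, epoch):
    """Learning rate after `epoch` scheduler steps under a step decay."""
    return base_lr * gamma ** (epoch // step_size)

# Hypothetical settings: halve the rate every 2 epochs.
schedule = [step_lr(0.1, 2, 0.5, e) for e in range(6)]
print(schedule)
```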

In [19]:
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total number of parameters is: {}".format(params)) 
print(model)

Total number of parameters is: 331795
SpiralAutoencoder(
  (conv): ModuleList(
    (0): SpiralConv(
      (conv): Linear(in_features=36, out_features=16, bias=True)
      (activation): ELU(alpha=1.0)
    )
    (1): SpiralConv(
      (conv): Linear(in_features=224, out_features=32, bias=True)
      (activation): ELU(alpha=1.0)
    )
    (2): SpiralConv(
      (conv): Linear(in_features=288, out_features=64, bias=True)
      (activation): ELU(alpha=1.0)
    )
    (3): SpiralConv(
      (conv): Linear(in_features=576, out_features=128, bias=True)
      (activation): ELU(alpha=1.0)
    )
  )
  (fc_latent_enc): Linear(in_features=3584, out_features=16, bias=True)
  (fc_latent_dec): Linear(in_features=16, out_features=3584, bias=True)
  (dconv): ModuleList(
    (0): SpiralConv(
      (conv): Linear(in_features=1152, out_features=64, bias=True)
      (activation): ELU(alpha=1.0)
    )
    (1): SpiralConv(
      (conv): Linear(in_features=576, out_features=32, bias=True)
      (activation): EL
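
The parameter count can be checked by hand for the layers that are visible in the printout: a `Linear` layer with a bias holds `in_features * out_features + out_features` parameters. The sketch below only covers the layers listed above (the printout is cut off), and the helper name `linear_params` is introduced here for illustration.

```python
def linear_params(in_features, out_features):
    # weight matrix plus bias vector
    return in_features * out_features + out_features

# (in_features, out_features) of the layers visible in the printout
visible_layers = [
    (36, 16), (224, 32), (288, 64), (576, 128),   # encoder SpiralConv layers
    (3584, 16), (16, 3584),                       # fc_latent_enc, fc_latent_dec
    (1152, 64), (576, 32),                        # first two decoder layers
]
counts = [linear_params(i, o) for i, o in visible_layers]
visible_total = sum(counts)
remaining = 331795 - visible_total                # not accounted for by the visible layers
```

The visible layers account for 310688 parameters; the remaining 21107 presumably belong to the decoder layers that are cut off in the printout.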

### Training the network

Here we provide the code for training the network, but since training it properly takes a lot of data and time, we will use the pretrained model to see the main results. For this, we set the parameter `mode` to `test`.

In [20]:
args['mode'] = 'test'
args['shuffle'] = True
args['resume'] = True
if args['mode'] == 'train':
    writer = SummaryWriter(summary_path)
    with open(os.path.join(args['results_folder'],'checkpoints', args['name'] +'_params.json'),'w') as fp:
        saveparams = copy.deepcopy(args)
        json.dump(saveparams, fp)
        
    if args['resume']:
        print('loading checkpoint from file %s'%(os.path.join(checkpoint_path,args['checkpoint_file']+'.pth.tar')))
        checkpoint_dict = torch.load(os.path.join(checkpoint_path,args['checkpoint_file']+'.pth.tar'),map_location=device)
        start_epoch = checkpoint_dict['epoch'] + 1
        model.load_state_dict(checkpoint_dict['autoencoder_state_dict'])
        optim.load_state_dict(checkpoint_dict['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint_dict['scheduler_state_dict'])
        print('Resuming from epoch %s'%(str(start_epoch)))
    else:
        start_epoch = 0
        

    train_autoencoder_dataloader(dataloader_train, dataloader_val,
                          device, model, optim, loss_fn,
                          bsize = args['batch_size'],
                          start_epoch = start_epoch,
                          n_epochs = args['num_epochs'],
                          eval_freq = args['eval_frequency'],
                          scheduler = scheduler,
                          writer = writer,
                          save_recons=True,
                          shapedata=shapedata,
                          metadata_dir=checkpoint_path, samples_dir=samples_path,
                          checkpoint_path = args['checkpoint_file'])
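
The resume logic relies on a checkpoint dictionary with the keys `epoch`, `autoencoder_state_dict`, `optimizer_state_dict` and `scheduler_state_dict`. A minimal sketch of such a round trip is shown below; the toy `nn.Linear` model, the hyperparameters, and the in-memory buffer standing in for the `.pth.tar` file are assumptions made for illustration, and how the tutorial's training function actually writes its checkpoints is not shown in this notebook.

```python
import io
import torch
import torch.nn as nn

toy_model = nn.Linear(4, 2)
toy_optim = torch.optim.Adam(toy_model.parameters(), lr=1e-3)
toy_sched = torch.optim.lr_scheduler.StepLR(toy_optim, step_size=10, gamma=0.9)

# same keys as the checkpoint dictionary read when resuming
checkpoint = {
    'epoch': 7,
    'autoencoder_state_dict': toy_model.state_dict(),
    'optimizer_state_dict': toy_optim.state_dict(),
    'scheduler_state_dict': toy_sched.state_dict(),
}

buffer = io.BytesIO()          # in-memory stand-in for a checkpoint file
torch.save(checkpoint, buffer)
buffer.seek(0)

loaded = torch.load(buffer, map_location='cpu')
fresh_model = nn.Linear(4, 2)
fresh_model.load_state_dict(loaded['autoencoder_state_dict'])
resume_epoch = loaded['epoch'] + 1   # training continues with the next epoch
```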

### Testing the performance

In [21]:
if args['mode'] == 'test':
    print('loading checkpoint from file %s'%(os.path.join(checkpoint_path,args['checkpoint_file']+'.pth.tar')))
    checkpoint_dict = torch.load(os.path.join(checkpoint_path,args['checkpoint_file']+'.pth.tar'),map_location=device)
    model.load_state_dict(checkpoint_dict['autoencoder_state_dict'])
        
    predictions, norm_l1_loss, l2_loss = test_autoencoder_dataloader(device, model, dataloader_test, 
                                                                     shapedata, mm_constant = 1000)    
    np.save(os.path.join(prediction_path,'predictions'), predictions)   
        
    print('autoencoder: normalized loss', norm_l1_loss)
    
    print('autoencoder: euclidean distance in mm=', l2_loss)

loading checkpoint from file ./DFAUST/results/spirals_autoencoder/latent_16/checkpoints/checkpoint.pth.tar


100%|██████████| 63/63 [00:02<00:00, 22.27it/s]


autoencoder: normalized loss 0.0630233958363533
autoencoder: euclidean distance in mm= 13.751459121704102
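
To make the second number easier to interpret, here is a minimal sketch of a per-vertex Euclidean error reported in millimetres. It assumes that vertex coordinates are stored in metres (which would explain the constant of 1000) and that errors are averaged over all vertices and meshes; the helper name `mean_euclidean_error_mm` is hypothetical, and the actual `test_autoencoder_dataloader` may aggregate differently.

```python
import numpy as np

def mean_euclidean_error_mm(predictions, targets, mm_constant=1000):
    # predictions, targets: [n_meshes, n_vertices, 3], assumed to be in metres
    per_vertex = np.linalg.norm(predictions - targets, axis=-1)
    return mm_constant * per_vertex.mean()

toy_targets = np.zeros((2, 5, 3))
# every vertex is displaced by 3 mm along x and 4 mm along y
toy_predictions = toy_targets + np.array([0.003, 0.004, 0.0])
error_mm = mean_euclidean_error_mm(toy_predictions, toy_targets)
```

In this toy case every vertex is off by exactly 5 mm, so the mean error is 5 mm as well.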


## Linear algebra in the latent space

### Sampling from the latent space
_The purpose of this part is to learn how to sample from a latent space of mesh encodings that we have created._

The autoencoder architecture is designed so that it can reconstruct shapes from any point in the latent space. This means that we can sample shapes from the latent space. For example:

**Exercise:** generate a random 16-dimensional vector, convert it to a PyTorch tensor, and use it to decode the underlying shape.

**Hint:** use `numpy.random.RandomState` and a lucky seed is `34563`. 

**Hint:** use `model.decode` to restore the mesh structure using the learned filters.

**Hint:** map the decoded mesh back from its standardized form using the `mesh_mean` and `mesh_std` stored in the `shapedata` object.

**Solution:**

In [22]:
def sample_from_latent_space(model, mesh_mean, mesh_std, dim=16, rand=None):
    """
    Samples from the latent space of the model. 
    
    model: the shape autoencoder model having ".decode" method 
           (note that this method only accepts 2D tensors)
    
    mesh_mean: 6890-dim vector of mean vertex parameters

    mesh_std: 6890-dim vector of vertex scales
    
    dim: dimensionality of the latent space
    
    rand: an object providing the `.uniform` method
    """
    if None is rand:
        r = np.random
    else:
        r = rand
    
    # Step 1: generate a uniformly random vector of shape `dim`. 
    # Note: turn this vector into a matrix of size [1, dim]
    # Note: `r` variable has the `uniform` method
    latent_code_np = r.uniform(-5, 5, size=dim)
    latent_code_np = np.expand_dims(latent_code_np, axis=0)
    latent_code = torch.Tensor(latent_code_np).to(device)
    
    # Step 2: decode the latent code into a mesh
    # Note: the model outputs a batch of size [batch_size, n_vertices, 3]
    # Note: remove the dummy vertex appended at the end (the last row)
    restored_np = model.decode(latent_code).detach().cpu().numpy()
    restored_np = np.squeeze(restored_np)[:-1]  # drop the batch axis (batch_size == 1) and the dummy vertex
    
    # Step 3: scale by mesh_std and add mesh_mean 
    restored_np = restored_np * mesh_std + mesh_mean
    return restored_np

In [23]:
restored = sample_from_latent_space(model, shapedata.mean, shapedata.std)

show_mesh(restored)
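
The hint about `numpy.random.RandomState` matters for reproducibility: a dedicated generator with a fixed seed always produces the same latent code, regardless of the global NumPy state. The sketch below illustrates this; the helper name `sample_latent_code` is introduced here for illustration only.

```python
import numpy as np

def sample_latent_code(seed, dim=16, low=-5.0, high=5.0):
    # a dedicated RandomState keeps the draw independent of the global NumPy state
    rng = np.random.RandomState(seed)
    return np.expand_dims(rng.uniform(low, high, size=dim), axis=0)

code_a = sample_latent_code(34563)
code_b = sample_latent_code(34563)
same = np.array_equal(code_a, code_b)   # identical seeds give identical codes
```

To get the same effect with the function defined above, pass a seeded generator explicitly, for example `rand=np.random.RandomState(34563)`.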



### Exploring individual latent transformations
_The purpose of this part is to explore separate directions in the latent space._

In [24]:
latent_code_np = np.random.RandomState(34563).uniform(-5, 5, size=16)
latent_code_np = np.expand_dims(latent_code_np, axis=0)
latent_code = torch.Tensor(latent_code_np).to(device)

**Exercise:** explore the specific transformations of the latent code to interpret the meaning of its parts. Identify:
 - which transformations correspond to subjects with raised/lowered hands? 
 - .. to subjects leaning back? 
 - .. to subject gender?

**Hint:** vary separate coordinates in the `latent_code` variable.

**Solution:** 

 - 7: hands raised (neg)
 - 11! (neg): 8, -7
 - 13 m/f 10
 - 15 leaning back, elbow bending

In [25]:
### backup the previous latent code 
latent_copy = copy.deepcopy(latent_code)

# manually edit the code 
latent_copy[0, 13] = -10

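def decode_mesh(latent_code, model, mesh_mean, mesh_std):
    """Decode a [1, dim] latent code into a de-standardized vertex array."""
    restored_np = model.decode(latent_code).detach().cpu().numpy()
    restored_np = np.squeeze(restored_np)[:-1]  # drop the batch axis and the dummy last vertex
    restored_np = restored_np * mesh_std + mesh_mean  # undo the standardization
    return restored_np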

# decode the mesh using the helper
restored_copy = decode_mesh(latent_copy, model, shapedata.mean, shapedata.std)

show_mesh(restored_copy)
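
Instead of editing one value at a time, a whole range of values for a single coordinate can be prepared at once. The following sketch builds such a batch of latent codes with NumPy; the helper name `sweep_coordinate`, the zero base code, and the value range are assumptions chosen for illustration.

```python
import numpy as np

def sweep_coordinate(code, index, values):
    # code: [1, dim]; returns [len(values), dim] with one coordinate replaced
    batch = np.repeat(code, len(values), axis=0)   # np.repeat returns a copy
    batch[:, index] = values
    return batch

base_code = np.zeros((1, 16))
sweep_values = np.linspace(-10, 10, 5)
codes = sweep_coordinate(base_code, index=13, values=sweep_values)
```

Each row `codes[i:i+1]` can then be converted with `torch.Tensor(...).to(device)` and passed to `decode_mesh` to see how the shape changes along that coordinate.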



### Temporal relations between poses

Let us load the test data.

In [26]:
test = np.load('DFAUST/preprocessed/test.npy')

mesh_female_pose = test[31]
mesh_female_neutral = test[781]

Let's see the properties of the latent space. For example, here is the reference female pose:

In [27]:
show_mesh(mesh_female_pose)



Here is the neutral female pose:

In [28]:
show_mesh(mesh_female_neutral)



Let's encode these two shapes and perform some algebraic operations on them in the latent space:

In [29]:
def encode_mesh(mesh, model, mesh_mean, mesh_std):
    mesh = (mesh - mesh_mean) / mesh_std  # standardize mesh
    
    mesh = np.vstack([mesh, np.array([0, 0, 0])])  # append the required dummy all-zero vertex as the last row
    mesh = np.expand_dims(mesh, axis=0)  # batch of size 1
    
    mesh_tensor = torch.Tensor(mesh).to(model.device)  # convert to torch tensor

    latent_code = model.encode(mesh_tensor)  # do a forward pass
    
    return latent_code
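
The preprocessing in `encode_mesh` is the exact inverse of the postprocessing in `decode_mesh`. The sketch below checks this round trip with NumPy only, without the model; the toy mesh size and the per-vertex shapes of `mesh_mean` and `mesh_std` are assumptions made for illustration.

```python
import numpy as np

n_vertices = 4                                   # toy size for illustration
rng = np.random.RandomState(0)
toy_mesh = rng.normal(size=(n_vertices, 3))
toy_mean = rng.normal(size=(n_vertices, 3))      # assumed per-vertex mean
toy_std = rng.uniform(0.5, 2.0, size=(n_vertices, 3))

# forward: standardize, append the dummy all-zero vertex, add a batch axis
standardized = (toy_mesh - toy_mean) / toy_std
padded = np.vstack([standardized, np.zeros((1, 3))])
batch = np.expand_dims(padded, axis=0)           # shape [1, n_vertices + 1, 3]

# inverse: drop the batch axis and the dummy vertex, undo the standardization
recovered = np.squeeze(batch, axis=0)[:-1] * toy_std + toy_mean
```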

In [30]:
latent_female_pose = encode_mesh(mesh_female_pose, model, shapedata.mean, shapedata.std)
latent_female_neutral = encode_mesh(mesh_female_neutral, model, shapedata.mean, shapedata.std)

The latent space allows us to slightly alter shapes. For example, we can interpolate between two poses:

**Exercise:** program a pose interpolation function to build intermediate shapes.

**Hint:** use a simple convex combination of the two.

**Solution:**

In [31]:
def interpolate_between(latent_1, latent_2, w1=0., w2=1.):
    """
    Weighs latent_1 and latent_2 by w1 and w2, respectively.
    
    Compute w1 * latent_1 + w2 * latent_2
    
    """
    return w1 * latent_1 + w2 * latent_2

In [32]:
interpolated_code = interpolate_between(latent_female_pose, latent_female_neutral, w1=0.9, w2=0.1)

interpolated_mesh = decode_mesh(interpolated_code, model, shapedata.mean, shapedata.std)

show_mesh(interpolated_mesh)
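
A single pair of weights gives one intermediate shape; a sequence of convex weights gives a whole path between the two poses. The following NumPy sketch builds such a path for toy latent codes; the helper name `interpolation_path` and the toy codes are assumptions for illustration.

```python
import numpy as np

def interpolation_path(latent_1, latent_2, n_steps=5):
    # weights t go from 0 to 1; each row is (1 - t) * latent_1 + t * latent_2
    ts = np.linspace(0.0, 1.0, n_steps).reshape(-1, 1)
    return (1.0 - ts) * latent_1 + ts * latent_2

z1 = np.zeros((1, 16))
z2 = np.ones((1, 16))
path = interpolation_path(z1, z2, n_steps=5)   # shape [5, 16]
```

With the torch latent codes from this notebook, the same path can be obtained by calling `interpolate_between(latent_1, latent_2, w1=1 - t, w2=t)` for each `t` and decoding every result with `decode_mesh`.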



And, having some neutral pose and reference pose, we can extrapolate the movement:

**Exercise:** program a pose **extra**polation function to build shapes that continue the movement beyond the reference pose.

**Hint:** reuse the linear combination from the interpolation, but drop the requirement that the weights be non-negative.

**Solution:**

In [33]:
extrapolated_code = interpolate_between(latent_female_pose, latent_female_neutral, w1=1.5, w2=-0.5)

extrapolated_mesh = decode_mesh(extrapolated_code, model, shapedata.mean, shapedata.std)

show_mesh(extrapolated_mesh)
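
Why weights such as 1.5 and -0.5 continue the movement can be seen by rewriting the combination, assuming the weights still sum to one (as they do in the cell above):

$$
w_1 z_{\text{pose}} + w_2 z_{\text{neutral}} = z_{\text{neutral}} + w_1 \left(z_{\text{pose}} - z_{\text{neutral}}\right), \qquad w_1 + w_2 = 1 .
$$

For $0 \le w_1 \le 1$ the result lies on the segment between the neutral and the reference code. For $w_1 = 1.5$ it equals $z_{\text{pose}} + 0.5 \left(z_{\text{pose}} - z_{\text{neutral}}\right)$, that is, it overshoots the reference pose by half of the displacement from the neutral pose.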




### Pose transfer

Lastly, let's look at a pose transfer example. Given the reference and neutral female poses, let's transfer the reference pose to a different body shape, using this neutral male pose:

In [34]:
mesh_male_neutral = test[888]

In [35]:
show_mesh(mesh_male_neutral)



Now, using simple algebraic relations, we subtract the neutral female pose from the reference pose and add the neutral male pose. Let's see the result:

**Exercise:** perform pose transfer between female and male meshes:
 - encode female neutral and pose meshes into latent space
 - encode neutral male mesh into latent space
 - perform latent space arithmetic to approximate the latent code of the male mesh in the reference pose

**Hint:** first compute the difference between the latent codes of the posed and neutral female meshes, then add this difference to the neutral male latent code.

**Solution:**

In [36]:
latent_female_pose = encode_mesh(test[31], model, shapedata.mean, shapedata.std)
latent_female_neutral = encode_mesh(test[781], model, shapedata.mean, shapedata.std)
latent_male_neutral = encode_mesh(test[888], model, shapedata.mean, shapedata.std)

In [37]:
difference_in_pose = latent_female_pose - latent_female_neutral
latent_male_pose = latent_male_neutral + difference_in_pose

In [38]:
mesh_male_pose = decode_mesh(latent_male_pose, model, shapedata.mean, shapedata.std)

In [39]:
show_mesh(mesh_male_pose)
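
The arithmetic above works exactly when a latent code decomposes additively into an identity part and a pose part. The toy sketch below illustrates this under that assumption; the vectors `identity_female`, `identity_male` and `pose_offset` are hypothetical and are not produced by the trained model, whose latent space is only approximately additive at best.

```python
import numpy as np

rng = np.random.RandomState(0)
identity_female = rng.normal(size=16)   # hypothetical identity components
identity_male = rng.normal(size=16)
pose_offset = rng.normal(size=16)       # hypothetical pose component

# assumption: a latent code is identity + pose
z_female_neutral = identity_female
z_female_pose = identity_female + pose_offset
z_male_neutral = identity_male

# same arithmetic as in the solution cells
z_male_pose = z_male_neutral + (z_female_pose - z_female_neutral)
```

Under this assumption the female identity cancels out and only the pose offset is carried over to the male code.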

