In [None]:
!pip install --upgrade matplotlib

In [None]:
import torch
import numpy as np
import os
from numpy.random import shuffle
#%%
from tqdm import tqdm

# Reproducibility: torch.manual_seed seeds torch on CPU and all CUDA devices;
# numpy's legacy global RNG (numpy.random is imported above) is seeded too.
rand_seed = 1234
torch.manual_seed(rand_seed)
np.random.seed(rand_seed)
#%%
# Hyper-parameters read by the cells below.
params = {'N_units': 30,   # size of the latent vector (Encoder output, N input/output)
          'L_rate_AE': 4e-5,  # Adam learning rate
          'B_size': 8,        # mini-batch size of every DataLoader
          'Epochs': 1000,     # number of training epochs
          'W_N_f_p': 1e-1,    # weight of the full->partial latent-mapping loss
          'W_N_p_f': 1e-1,    # weight of the partial->full latent-mapping loss
          'N_evals': 10,      # NOTE(review): not read by the visible cells
          'enc_layers': [512, 256],  # hidden widths of Encoder
          'dec_layers': [256],       # hidden widths of Decoder
          'AE_activation': 'tanh',   # NOTE(review): Encoder/Decoder hard-code Tanh and do not read this key
          'datastep': 10,     # NOTE(review): not read by the visible cells
          'bat_n_all': True,  # append BatchNorm1d after every hidden layer of N
          'actual_epoch': 0,  # NOTE(review): not read by the visible cells
          'N_layers': [64, 128, 256, 128, 64],  # hidden widths of the latent mappers N (Pi layers)
          }

In [None]:
# Paired point-cloud matrices: sample k of the partial matrix is matched with
# sample k of the full matrix (the training loop zips the two loaders).
partial_shapes = np.load('./partial_shapes_matrix.npy')
full_shapes = np.load('./total_shapes_matrix.npy')

# Cast to float32 so the inputs match the float32 weights of the torch.nn
# layers; if the .npy files hold float64 (numpy's default float type), the
# first Linear layer would otherwise raise a dtype-mismatch error.
pix_f = torch.tensor(full_shapes, dtype=torch.float32)
pix_p = torch.tensor(partial_shapes, dtype=torch.float32)
assert pix_f.shape[0] == pix_p.shape[0], 'full and partial shape matrices must contain the same number of samples'

In [None]:
# Cadences (in epochs) used by the training loop: checkpointing and evaluation.
save_every, test_every = 100, 5
#%%

def distance_matrix(array1, array2):
    """
    Pairwise Euclidean distances between two batched point sets.

    arguments:
        array1: reference points, size: (batch_size, num_point1, num_feature)
        array2: query points, size: (batch_size, num_point2, num_feature)
    returns:
        distances: size (batch_size, num_point2, num_point1), where
            distances[b, i, j] = ||array2[b, i] - array1[b, j]||.
            num_point1 and num_point2 may differ.
    """
    # (B, n2, 1, F) - (B, 1, n1, F) broadcasts to (B, n2, n1, F). This gives the
    # same values as tiling both inputs to (B, n*n, F) and reshaping, without
    # the two tiled copies and without requiring equal point counts.
    diff = torch.unsqueeze(array2, 2) - torch.unsqueeze(array1, 1)
    return torch.linalg.norm(diff, dim=-1)


def av_dist(array1, array2):
    """
    One-directional average nearest-neighbour distance.

    arguments:
        array1, array2: both size: (batch_size, num_points, num_feature)
    returns:
        distances: size (batch_size,) -- for each batch element, the mean over
            the points of array2 of the distance to the closest point of array1
    """
    # distance_matrix(...)[b, i, j] compares array2 point i with array1 point j,
    # so reducing the last axis picks each array2 point's nearest array1 point.
    nearest = distance_matrix(array1, array2).min(dim=-1).values
    return nearest.mean(dim=-1)

def av_dist_sum(array1, array2):
    """
    Symmetric average nearest-neighbour distance.

    arguments:
        array1, array2: batched point sets accepted by av_dist
    returns:
        av_dist(array1, array2) + av_dist(array2, array1), size (batch_size,)
    """
    return av_dist(array1, array2) + av_dist(array2, array1)

def chamfer_distance(array1, array2):
    """Scalar Chamfer distance: av_dist_sum averaged over the batch."""
    per_sample = av_dist_sum(array1, array2)
    return per_sample.mean()

def save_decoded_shape(shape, epoch):
    """Write `shape` with np.save to 'shape<epoch>.npy' in the current working directory."""
    target = 'shape{}.npy'.format(epoch)
    np.save(target, shape)

In [None]:
#%%
# Folder that receives the periodic training checkpoints; created on first run,
# left alone when something already exists at that path.
model_path = 'models'
model_path_exists = os.path.exists(model_path)
if not model_path_exists:
    os.mkdir(model_path)
#%%
# Model definitions: point-cloud autoencoder (Encoder/Decoder/AE), latent mappers (N) and the joint Model

class Encoder(torch.nn.Module):
    """MLP encoder mapping a flattened point cloud to a params['N_units'] latent.

    Hidden widths come from params['enc_layers']; every hidden Linear is
    followed by Tanh (hard-coded -- params['AE_activation'] is not read).
    Attribute names and module creation order are part of the state_dict /
    seeded-initialisation contract and must stay as they are.
    """

    def __init__(self, n_points):
        super().__init__()
        self.n_points = n_points
        # first width is n_points * 3: the flattened per-point xyz coordinates
        widths = [self.n_points * 3, *params['enc_layers']]
        self.layers = torch.nn.ModuleList([])
        self.flat = torch.nn.Flatten()
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.layers.append(torch.nn.Linear(fan_in, fan_out))
            self.layers.append(torch.nn.Tanh())
        self.out = torch.nn.Linear(widths[-1], params['N_units'])

    def forward(self, x):
        """Encode x (presumably (N, n_points, 3)) into latents of shape (N, N_units)."""
        h = self.flat(x)
        for layer in self.layers:
            h = layer(h)
        return self.out(h)

class Decoder(torch.nn.Module):
    """MLP decoder mapping a latent vector to a point cloud of shape (N, n_points, 3).

    Hidden widths come from params['dec_layers']; every hidden Linear is
    followed by Tanh (hard-coded -- params['AE_activation'] is not read).
    Attribute names and module creation order are part of the state_dict /
    seeded-initialisation contract and must stay as they are.
    """

    def __init__(self, n_points, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.n_points = n_points
        widths = [self.latent_dim, *params['dec_layers']]
        self.layers = torch.nn.ModuleList([])
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.layers.append(torch.nn.Linear(fan_in, fan_out))
            self.layers.append(torch.nn.Tanh())
        # final projection to n_points * 3 values, reshaped to xyz in forward()
        self.out = torch.nn.Linear(widths[-1], self.n_points * 3)

    def forward(self, x):
        """Decode latents x into a (N, n_points, 3) tensor."""
        h = x
        for layer in self.layers:
            h = layer(h)
        return self.out(h).reshape(-1, self.n_points, 3)


class N(torch.nn.Module):
    """Latent-to-latent mapper (params['N_units'] -> params['N_units']).

    Input is batch-normalised, then passed through Linear + SELU blocks with
    widths params['N_layers'] (each optionally followed by BatchNorm1d when
    params['bat_n_all'] is truthy), then a final Linear back to N_units.
    Attribute names and module creation order are part of the state_dict /
    seeded-initialisation contract and must stay as they are.
    """

    def __init__(self):
        super().__init__()
        self.n_features = params['N_units']
        self.bn1 = torch.nn.BatchNorm1d(self.n_features)
        self.layers = torch.nn.ModuleList()
        widths = [self.n_features, *params['N_layers']]
        add_batch_norm = params['bat_n_all']
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.layers.append(torch.nn.Linear(fan_in, fan_out))
            self.layers.append(torch.nn.SELU())
            if add_batch_norm:
                self.layers.append(torch.nn.BatchNorm1d(fan_out))
        self.out = torch.nn.Linear(widths[-1], self.n_features)

    def forward(self, x):
        """Map latents x of shape (batch, N_units) to (batch, N_units)."""
        h = self.bn1(x)
        for layer in self.layers:
            h = layer(h)
        return self.out(h)

class AE(torch.nn.Module):
    """Point-cloud autoencoder: Encoder to a params['N_units'] latent, Decoder back.

    n_points: number of points per input shape. Defaults to pix_f.shape[1]
    (the full-shape point count), which is what every AE used before.
    """

    def __init__(self, n_points=None):
        super().__init__()
        if n_points is None:
            n_points = pix_f.shape[1]
        self.encoder = Encoder(n_points=n_points)
        self.decoder = Decoder(n_points=n_points, latent_dim=params['N_units'])

    def forward(self, x):
        """Return (reconstruction with the same point count as x, latent)."""
        lat = self.encoder(x)
        mesh_rec = self.decoder(lat)
        return mesh_rec, lat

    def encode(self, x):
        """Latent code of x."""
        return self.encoder(x)

    def decode(self, lat):
        """Point cloud decoded from latent lat."""
        return self.decoder(lat)


class Model(torch.nn.Module):
    """Joint model: one AE for full shapes, one for partial shapes, and two
    latent mappers (partial->full: n_p_f, full->partial: n_f_p)."""

    def __init__(self):
        super().__init__()
        self.ae_f = AE(n_points=pix_f.shape[1])
        # The partial autoencoder must be sized by the partial data. Reusing the
        # full-shape point count only works when both matrices happen to have
        # the same number of points per shape.
        self.ae_p = AE(n_points=pix_p.shape[1])
        self.n_p_f = N()
        self.n_f_p = N()

    def forward(self, x_f, x_p):
        """Return (mesh_rec_f, mesh_rec_p, lat_f, lat_p, lat_rec_f, lat_rec_p).

        lat_rec_f is the full-shape latent predicted from lat_p, and lat_rec_p
        is the partial-shape latent predicted from lat_f.
        """
        mesh_rec_f, lat_f = self.ae_f(x_f)
        mesh_rec_p, lat_p = self.ae_p(x_p)
        lat_rec_f = self.map_p_to_f(lat_p)
        lat_rec_p = self.map_f_to_p(lat_f)
        return mesh_rec_f, mesh_rec_p, lat_f, lat_p, lat_rec_f, lat_rec_p

    def map_p_to_f(self, lat_p):
        """Partial-shape latent -> predicted full-shape latent."""
        return self.n_p_f(lat_p)

    def map_f_to_p(self, lat_f):
        """Full-shape latent -> predicted partial-shape latent."""
        return self.n_f_p(lat_f)

In [None]:
#%%
# Datasets: the first N_train samples are used for training, the rest for evaluation.
N_train = 800
device = torch.device("cuda:0")
batch_size = int(params['B_size'])

train_f = pix_f[:N_train, :, :]
test_f = pix_f[N_train:, :, :]
train_p = pix_p[:N_train, :, :]
test_p = pix_p[N_train:, :, :]

# .to(device) already copies the CPU slices to the GPU. Wrapping a tensor in
# torch.tensor() again only added a redundant copy and a PyTorch UserWarning.
train_dataset_f = torch.utils.data.TensorDataset(train_f.to(device))
test_dataset_f = torch.utils.data.TensorDataset(test_f.to(device))
train_dataset_p = torch.utils.data.TensorDataset(train_p.to(device))
test_dataset_p = torch.utils.data.TensorDataset(test_p.to(device))

# shuffle stays off: the full and partial loaders are zipped batch-by-batch,
# so both must visit the samples in the same order to keep the pairs aligned.
train_dataloader_f = torch.utils.data.DataLoader(train_dataset_f, batch_size=batch_size, num_workers=0)
test_dataloader_f = torch.utils.data.DataLoader(test_dataset_f, batch_size=batch_size, num_workers=0)
train_dataloader_p = torch.utils.data.DataLoader(train_dataset_p, batch_size=batch_size, num_workers=0)
test_dataloader_p = torch.utils.data.DataLoader(test_dataset_p, batch_size=batch_size, num_workers=0)

# Model
model = Model().cuda()

In [None]:
# Optimizer: one Adam over every model parameter. model.parameters() is passed
# directly rather than through set(): a set has no stable order, so the
# optimizer's parameter order (and its state_dict layout) varied between runs.
opt = torch.optim.Adam(model.parameters(), lr=params['L_rate_AE'])

# Per-epoch histories of the individual loss terms (python floats).
train_loss_f = []
eval_loss_f = []
train_loss_p = []
eval_loss_p = []
train_loss_n_f_p = []
train_loss_n_p_f = []
eval_loss_n_f_p = []
eval_loss_n_p_f = []
epochs = []

# Per-epoch total losses as 0-d tensors (moved to the CPU once training ends).
losses_train = []
losses_test = []

_LOSS_TERMS = ('ae_f', 'ae_p', 'n_p_f', 'n_f_p', 'total')


def _compute_losses(x_f, x_p):
    """Weighted loss terms for one paired batch of full (x_f) and partial (x_p) shapes.

    Training and evaluation both use this function, so they report the same
    quantities on the same scale.
    """
    mesh_rec_f, mesh_rec_p, lat_f, lat_p, lat_rec_f, lat_rec_p = model(x_f, x_p)
    # squared error per point (summed over the last axis), averaged over points and batch
    loss_ae_f = torch.sum((x_f - mesh_rec_f) ** 2, dim=-1).mean()
    loss_ae_p = torch.sum((x_p - mesh_rec_p) ** 2, dim=-1).mean()
    # latent-mapping losses: partial->full prediction and full->partial prediction
    loss_n_p_f = params['W_N_p_f'] * torch.sum((lat_f - lat_rec_f) ** 2, dim=-1).mean()
    loss_n_f_p = params['W_N_f_p'] * torch.sum((lat_p - lat_rec_p) ** 2, dim=-1).mean()
    total = loss_ae_f + loss_ae_p + loss_n_p_f + loss_n_f_p
    return {'ae_f': loss_ae_f, 'ae_p': loss_ae_p, 'n_p_f': loss_n_p_f,
            'n_f_p': loss_n_f_p, 'total': total}


def _run_epoch(loader_f, loader_p, train):
    """One pass over a pair of loaders; returns each loss term averaged over batches.

    train=True: the model is in training mode and Adam steps after every batch.
    train=False: gradients are disabled and the model is in eval mode, so the
    BatchNorm layers of N use their running statistics instead of updating them.
    The previous training/eval mode is restored afterwards. Returned values
    are detached 0-d tensors; zip stops at the shorter of the two loaders.
    """
    was_training = model.training
    model.train(train)
    totals = {k: torch.tensor(0.).cuda() for k in _LOSS_TERMS}
    n_batches = 0
    try:
        with torch.set_grad_enabled(train):
            n_expected = min(len(loader_f), len(loader_p))
            for (x_f,), (x_p,) in tqdm(zip(loader_f, loader_p), total=n_expected):
                losses = _compute_losses(x_f, x_p)
                if train:
                    losses['total'].backward()
                    opt.step()
                    opt.zero_grad()
                for k in _LOSS_TERMS:
                    # detach: the running sums must not keep each batch's autograd graph alive
                    totals[k] += losses[k].detach()
                n_batches += 1
    finally:
        model.train(was_training)
    # max(..., 1) guards against an empty loader (e.g. no held-out samples)
    return {k: v / max(n_batches, 1) for k, v in totals.items()}


def _format_losses(avg):
    """One-line summary of the averages returned by _run_epoch."""
    return (f"avg_loss_ae_f = {avg['ae_f'].item()}; avg_loss_ae_p = {avg['ae_p'].item()}; "
            f"avg_loss_n_p_f = {avg['n_p_f'].item()}; avg_loss_n_f_p = {avg['n_f_p'].item()}; "
            f"avg_loss = {avg['total'].item()}")


for epoch in range(params['Epochs']):
    epochs.append(epoch)
    train_avg = _run_epoch(train_dataloader_f, train_dataloader_p, train=True)
    losses_train.append(train_avg['total'])
    train_loss_f.append(train_avg['ae_f'].item())
    train_loss_p.append(train_avg['ae_p'].item())
    train_loss_n_p_f.append(train_avg['n_p_f'].item())
    train_loss_n_f_p.append(train_avg['n_f_p'].item())
    print('Epoch {} of {}, Train Loss Full and Partial: {:.4f}, {:.4f}'.format(
        epoch + 1, params['Epochs'], train_loss_f[-1], train_loss_p[-1]))
    print(f'Train Avg: {_format_losses(train_avg)}')

    # checkpoint after epochs 0, save_every, 2*save_every, ... (file named by the 0-based epoch)
    if epoch % save_every == 0:
        torch.save(model.state_dict(), model_path + '/ae_' + str(epoch) + '.pt')

    if (epoch + 1) % test_every == 0:
        # score the held-out split batch by batch with the freshly trained model
        eval_avg = _run_epoch(test_dataloader_f, test_dataloader_p, train=False)
        losses_test.append(eval_avg['total'])
        eval_loss_f.append(eval_avg['ae_f'].item())
        eval_loss_p.append(eval_avg['ae_p'].item())
        eval_loss_n_p_f.append(eval_avg['n_p_f'].item())
        eval_loss_n_f_p.append(eval_avg['n_f_p'].item())
        print(f'Eval Avg: {_format_losses(eval_avg)}')

losses_train = [x.detach().cpu() for x in losses_train]
losses_test = [x.detach().cpu() for x in losses_test]

In [None]:
# Save the final weights. '/content' is the Colab working directory; when it
# does not exist (running outside Colab) fall back to the checkpoint folder
# instead of failing with FileNotFoundError.
final_model_dir = '/content' if os.path.isdir('/content') else model_path
torch.save(model.state_dict(), os.path.join(final_model_dir, 'ae_model.pt'))