In [None]:
import scipy.io as sio


from IPython.display import clear_output

import torch

from torch_geometric.data import DataLoader
import torch_geometric.transforms as T

# Dataset Function
from datasets import CuboidDataset

from models.HarmonicResNet import SiameseHSN

# Transforms
from transforms import (VectorHeat, MultiscaleRadiusGraph)

In [None]:
# Compute device for the model and optimizer. Prefer CUDA, but fall back to CPU so the
# notebook still runs (slowly) on machines without a GPU instead of failing at `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Pooling ratios, one entry per scale of the multi-scale graph
# (presumably paired index-wise with `radii` below — confirm in MultiscaleRadiusGraph).
ratios = [1, 0.5, 0.25, 0.1]

# Radii for the multi-scale radius graph, one entry per scale.
radii = [0.1, 0.2, 0.4, 0.8]

# Dataset pre-transform; T.Compose applies its transforms in order:
#   1. MultiscaleRadiusGraph — computes a multi-scale radius graph
#      (the positional 256 is presumably a per-point neighbour cap — TODO confirm).
#   2. VectorHeat — precomputes the logarithmic map.
_multiscale_graph = MultiscaleRadiusGraph(ratios, radii, 256, loop=True, flow='target_to_source')
_log_map = VectorHeat()
pre_transform = T.Compose((_multiscale_graph, _log_map))

In [None]:
# DeepCuboidSeg12K shape collection. `pre_transform` is handed to the dataset class
# (in PyG datasets it is typically applied once at processing time — confirm for CuboidDataset).
DATA_ROOT = 'data/DeepCuboidSeg12K/'
dataset = CuboidDataset(root=DATA_ROOT, pre_transform=pre_transform)

# One dataset item per mini-batch, served in dataset order (DataLoader's default: no shuffling).
loader = DataLoader(dataset, batch_size=1)

In [None]:
# Build the Siamese network and place its parameters on `device`.
model = SiameseHSN()
model = model.to(device)

# Adam optimizer over everything returned by model.parameters().
LEARNING_RATE = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
# Per-mini-batch progress line: batch index, total loss, then the four loss terms E1..E4.
loss_template = (
    "Mini-Batch: {} Loss {:f}, "
    "E1: {:f}, E2: {:f}, E3: {:f}, E4: {:f}"
)

In [None]:
# Switch the model to training mode (matters for layers such as dropout / batch-norm, if present).
# Module.train returns the module itself, so this cell displays the model's repr.
model.train(mode=True)

In [None]:
def _to_numpy(tensor):
    """Return `tensor` detached from the autograd graph as a NumPy array on the CPU."""
    return tensor.detach().cpu().numpy()


NUM_EPOCHS = 50
SAVE_EVERY = 10  # write a .mat snapshot every SAVE_EVERY mini-batches

# Train for NUM_EPOCHS passes over `loader`. Each mini-batch: feed the pair data[0], data[1]
# through the Siamese model, sum the four loss terms, backpropagate, and take an optimizer step.
# Cell output is cleared at the start of every epoch, so only the current epoch's log is shown.
for epoch in range(NUM_EPOCHS):
    clear_output()
    for batch, data in enumerate(loader):
        # Reset gradients first: .backward() accumulates into .grad, so without this
        # every optimizer step would use the sum of all gradients seen so far.
        optimizer.zero_grad()

        # NOTE(review): data[0]/data[1] are passed as-is (not moved to `device`);
        # presumably SiameseHSN handles device placement — confirm.
        (C12, C21, src_feat, tar_feat, src_verts, tar_verts,
         E1, E2, E3, E4, P12, P21) = model(data[0], data[1])

        final_loss = E1 + E2 + E3 + E4
        final_loss.backward()
        optimizer.step()

        # float() accepts both 0-d tensors and plain Python numbers.
        print(loss_template.format(batch, float(final_loss),
                                   float(E1), float(E2), float(E3), float(E4)))

        if batch % SAVE_EVERY == 0:
            mat_dict = dict(F=_to_numpy(src_feat),
                            G=_to_numpy(tar_feat),
                            P_est_AB=_to_numpy(P12),
                            P_est_BA=_to_numpy(P21),
                            src_names=data[0].name[0],
                            tar_names=data[1].name[0],
                            src_vertices=_to_numpy(src_verts),
                            tar_vertices=_to_numpy(tar_verts),
                            C_est_AB=_to_numpy(C12),
                            C_est_BA=_to_numpy(C21))

            # The filename depends only on the within-epoch batch index, so every epoch
            # overwrites the previous epoch's files; the last epoch's snapshots remain.
            mat_filename = 'map_' + str(batch) + ".mat"
            sio.savemat(mat_filename, mat_dict)