In [1]:
# Reload edited project modules (e.g. src/...) automatically before each cell executes.
%reload_ext autoreload
%autoreload 2

In [2]:
import sys

# Make the repository root importable so `src.*` resolves. Guarded so that
# re-running this cell does not keep appending duplicate '..' entries.
if '..' not in sys.path:
    sys.path.append('..')

In [3]:
from src.infra import config

# Load the experiment configuration from its folder and record that folder
# on the config object itself.
path = '../experiment/test-250605'
opt = config.load_config(path)
opt.path = path
print(opt)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
benchmark:
  dataset:
    aug:
      args:
        footprint_wrap_folder: data/processed/footprint-wrap/
        img_size: 5
        l_mask_path: data/processed/left_foot_mask.png
        pedar_dynamic_path: data/processed/pedar_dynamic.pkl
        sense_range: 600
        stack_range: 50
      dataloader:
        args:
          batch_size: 128
          shuffle: True
        name: DataLoader
      name: Footprint2Pressure_SensorStack_Blend
    wo_aug:
      args:
        footprint_wrap_folder: data/processed/footprint-wrap/
        img_size: 5
        l_mask_path: data/processed/left_foot_mask.png
        pedar_dynamic_path: data/processed/pedar_dynamic.pkl
        sense_range: 600
        stack_range: 50
      dataloader:
        args:
          batch_size: 128
          shuffle: True
        name: DataLoader
      n

In [4]:
import torch

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

## Dataset

In [5]:
import src  # presumably imported for its side effect of populating the registries — confirm
from src.infra.registry import DATASET_REGISTRY
# Display the name -> dataset-class mapping available to configs.
DATASET_REGISTRY

{'Dataset': torch.utils.data.dataset.Dataset,
 'SingleFaustDataset': src.dataset.shape_cor.SingleFaustDataset,
 'SingleScapeDataset': src.dataset.shape_cor.SingleScapeDataset,
 'SingleShrec19Dataset': src.dataset.shape_cor.SingleShrec19Dataset,
 'SingleSmalDataset': src.dataset.shape_cor.SingleSmalDataset,
 'SingleDT4DDataset': src.dataset.shape_cor.SingleDT4DDataset,
 'SingleShrec20Dataset': src.dataset.shape_cor.SingleShrec20Dataset,
 'SingleTopKidsDataset': src.dataset.shape_cor.SingleTopKidsDataset,
 'PairDataset': src.dataset.shape_cor.PairDataset,
 'PairFaustDataset': src.dataset.shape_cor.PairFaustDataset,
 'PairScapeDataset': src.dataset.shape_cor.PairScapeDataset,
 'PairShrec19Dataset': src.dataset.shape_cor.PairShrec19Dataset,
 'PairSmalDataset': src.dataset.shape_cor.PairSmalDataset,
 'PairDT4DDataset': src.dataset.shape_cor.PairDT4DDataset,
 'PairShrec20Dataset': src.dataset.shape_cor.PairShrec20Dataset,
 'PairShrec16Dataset': src.dataset.shape_cor.PairShrec16Dataset,
 'Pai

In [6]:
# Training split of the FAUST_r pair dataset.
# Flags are real booleans: the strings 'true'/'false' are both non-empty and
# therefore truthy, so `return_corr='false'` would be treated as enabled by any
# `if return_corr:` check inside the dataset.
dataset = DATASET_REGISTRY['PairFaustDataset'](
    phase='train',
    data_root='../data/FAUST_r/',
    return_evecs=True,
    return_faces=True,
    num_evecs=200,
    return_corr=False,
    return_dist=False,
)

In [19]:
from src.utils.tensor import to_device

# Fetch one pair sample and move it to the device. `to_device` already walks the
# nested dict (both 'first' and 'second'), so the per-shape dicts can be taken
# directly without a second transfer.
data = to_device(dataset[1], device)
data_x, data_y = data['first'], data['second']
data_x['name'], data_y['name']

('tr_reg_000', 'tr_reg_001')

In [8]:
from src.infra.registry import DATALOADER_REGISTRY
# Display the name -> dataloader-class mapping available to configs.
DATALOADER_REGISTRY

{'DataLoader': torch.utils.data.dataloader.DataLoader}

In [9]:
# One shuffled pair per batch, built through the registry like the configs do.
loader_kwargs = dict(dataset=dataset, batch_size=1, shuffle=True)
dataloader = DATALOADER_REGISTRY['DataLoader'](**loader_kwargs)
dataloader

<torch.utils.data.dataloader.DataLoader at 0x7c9861de8ee0>

In [10]:
# Peek at one shuffled batch. The whole batch is moved to the device, so `data`,
# `data_x` and `data_y` all live on the same device for the cells below
# (forward pass and loss both consume them).
data = to_device(next(iter(dataloader)), device)
data_x, data_y = data['first'], data['second']
print(data_x['name'], data_y['name'])

['tr_reg_032'] ['tr_reg_019']


## Model

In [11]:
from src.infra.registry import MODEL_REGISTRY
# Display the name -> model-class mapping available to configs.
MODEL_REGISTRY

{'Similarity': src.model.permutation.Similarity,
 'RegularizedFMNet': src.model.fmap.RegularizedFMNet,
 'URSSM': src.model.urssm.URSSM,
 'DiffusionNet': src.model.diffusionnet.DiffusionNet}

In [12]:
# Instantiate the model named in the config (`opt.model.name`), passing it the
# full config, and move it to the selected device.
urssm = MODEL_REGISTRY[opt.model.name](opt).to(device)
urssm

URSSM(
  (feature_extractor): DiffusionNet(
    (first_linear): Linear(in_features=128, out_features=128, bias=True)
    (last_linear): Linear(in_features=128, out_features=256, bias=True)
    (blocks): ModuleList(
      (0): DiffusionNetBlock(
        (diffusion): LearnedTimeDiffusion()
        (gradient_features): SpatialGradientFeatures(
          (A_re): Linear(in_features=128, out_features=128, bias=False)
          (A_im): Linear(in_features=128, out_features=128, bias=False)
        )
        (mlp): MiniMLP(
          (miniMLP_linear_000): Linear(in_features=384, out_features=128, bias=True)
          (miniMLP_activation_000): ReLU()
          (miniMLP_dropout_001): Dropout(p=0.5, inplace=False)
          (miniMLP_linear_001): Linear(in_features=128, out_features=128, bias=True)
          (miniMLP_activation_001): ReLU()
          (miniMLP_dropout_002): Dropout(p=0.5, inplace=False)
          (miniMLP_linear_002): Linear(in_features=128, out_features=128, bias=True)
        )
  

In [13]:
network_path = '../checkpoints/faust.pth'
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# load_state_dict then copies the tensors into the module's own parameters.
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(network_path, map_location=device)
urssm.feature_extractor.load_state_dict(
    checkpoint['networks']['feature_extractor']
)
print(f'Loaded pretrain weights from {network_path}')

Loaded pretrain weights from ../checkpoints/faust.pth


In [14]:
# Forward pass on the peeked pair; the displayed output is a dict that includes
# the functional maps 'Cxy' and 'Cyx'.
infer = urssm(data_x, data_y)
infer

{'Cxy': tensor([[[-9.7482e-01,  1.2067e-02,  3.8330e-04,  ..., -4.3933e-04,
            7.8162e-04,  1.7684e-05],
          [-1.3796e-02, -9.4858e-01, -5.8940e-02,  ...,  3.0590e-04,
            7.0716e-04, -1.6155e-04],
          [ 1.0811e-02, -4.1229e-02,  9.3997e-01,  ...,  3.0392e-04,
           -2.1220e-04, -1.9038e-04],
          ...,
          [ 1.3636e-03,  1.3223e-03, -1.1529e-03,  ...,  7.6819e-02,
            1.6766e-01, -2.0771e-02],
          [-3.3330e-04,  1.2754e-03,  2.9656e-04,  ...,  6.3550e-02,
           -1.1424e-01, -2.4210e-02],
          [-5.1405e-04,  8.1348e-05, -1.1189e-03,  ...,  1.5569e-01,
           -2.5701e-02, -1.1376e-01]]], device='cuda:0', grad_fn=<CatBackward0>),
 'Cyx': tensor([[[-9.7188e-01, -5.2084e-04,  1.8461e-04,  ...,  5.0212e-04,
           -3.1207e-04, -9.2881e-04],
          [ 9.0058e-03, -9.5826e-01, -5.6782e-02,  ...,  2.4080e-04,
           -3.6017e-04,  8.6898e-05],
          [ 1.0530e-02, -4.4639e-02,  1.0322e+00,  ..., -5.3495e-04,
  

## Loss

In [15]:
from src.infra.registry import LOSS_REGISTRY
# Display the name -> loss-class mapping available to configs.
LOSS_REGISTRY

{'MSELoss': torch.nn.modules.loss.MSELoss,
 'SquaredFrobeniusLoss': src.loss.fmap.SquaredFrobeniusLoss,
 'SURFMNetLoss': src.loss.fmap.SURFMNetLoss,
 'SURFMNetLoss_wrap': src.loss.fmap.SURFMNetLoss_wrap,
 'SpatialSpectralAlignmentLoss': src.loss.fmap.SpatialSpectralAlignmentLoss,
 'SpatialSpectralAlignmentLoss_wrap': src.loss.fmap.SpatialSpectralAlignmentLoss_wrap,
 'PartialFmapsLoss': src.loss.fmap.PartialFmapsLoss,
 'DirichletLoss': src.loss.dirichlet.DirichletLoss,
 'CompositeLoss': src.loss.composite.CompositeLoss}

In [16]:
# Build the training loss named in the config (`opt.train.loss.name`) and show
# its sub-losses together with their weights.
loss = LOSS_REGISTRY[opt.train.loss.name](opt).to(device)
loss, loss.weights

(CompositeLoss(
   (losses): ModuleList(
     (0): SURFMNetLoss_wrap(
       (squared_frobenius): SquaredFrobeniusLoss()
     )
     (1): SpatialSpectralAlignmentLoss_wrap(
       (squared_frobenius): SquaredFrobeniusLoss()
     )
   )
 ),
 [1.0, 1.0])

In [17]:
# Evaluate the loss on the forward-pass output and the peeked batch.
# NOTE(review): `data` comes from the dataloader peek cell, where only
# data['first'] / data['second'] were copied to the device under new names —
# confirm the loss tolerates `data` possibly still being on CPU.
loss(infer, data)

tensor(138.9938, device='cuda:0', grad_fn=<AddBackward0>)

In [18]:
# Display the batch after moving it to the device.
# NOTE(review): the result is only displayed, not assigned; `data` itself is
# unchanged unless `to_device` mutates in place (not visible here).
to_device(data, device)

{'first': {'name': ['tr_reg_032'],
  'verts': tensor([[[ 0.0261,  0.5284,  0.1571],
           [ 0.0381,  0.5354,  0.1466],
           [ 0.0382,  0.4832,  0.1354],
           ...,
           [-0.1363,  0.3480, -0.0092],
           [-0.1369,  0.3391,  0.0033],
           [-0.1373,  0.3516, -0.0145]]], device='cuda:0'),
  'faces': tensor([[[  70,    1,   14],
           [  70,   14,   67],
           [  70,    0,   69],
           ...,
           [3374, 2941, 4995],
           [2941, 4994, 4995],
           [3373, 4995, 4994]]], device='cuda:0'),
  'evecs': tensor([[[ 1.0000,  1.0089,  0.0344,  ..., -1.1998, -1.2748, -0.3302],
           [ 1.0000,  1.0100,  0.0344,  ..., -1.1494, -1.5042,  0.7864],
           [ 1.0000,  0.9872,  0.0331,  ..., -0.8419, -0.8645,  0.9074],
           ...,
           [ 1.0000,  0.8921,  0.0472,  ...,  0.1738, -0.5052, -0.3958],
           [ 1.0000,  0.8955,  0.0481,  ...,  0.5150, -0.7056, -0.1909],
           [ 1.0000,  0.8968,  0.0470,  ..., -0.0034, -0.36

## Training loop

In [32]:
import torch.optim as optim

# Plain Adam over every parameter of the URSSM model, learning rate 1e-3.
optimizer = optim.Adam(urssm.parameters(), lr=1e-3)
optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 0.001
    maximize: False
    weight_decay: 0
)

In [37]:
import itertools

# Short training smoke test: a few optimizer steps over shuffled batches,
# printing each batch's loss. The previous stop condition `batch > 10` ran
# 12 batches (indices 0..11); that count is now explicit, and islice stops
# without pulling an extra batch from the loader.
NUM_SMOKE_BATCHES = 12

urssm.train()  # explicit training mode: the model contains Dropout layers
for batch, data in enumerate(itertools.islice(dataloader, NUM_SMOKE_BATCHES)):
    data = to_device(data, device)
    infer = urssm(data['first'], data['second'])
    loss_value = loss(infer, data)

    print(f'batch {batch} loss: {loss_value.item()}')

    optimizer.zero_grad()
    loss_value.backward()
    optimizer.step()

batch 0 loss: 168.72573852539062
batch 1 loss: 161.0816650390625
batch 2 loss: 158.69309997558594
batch 3 loss: 140.21925354003906
batch 4 loss: 118.90010833740234
batch 5 loss: 168.7462158203125
batch 6 loss: 147.72103881835938
batch 7 loss: 130.97396850585938
batch 8 loss: 135.06170654296875
batch 9 loss: 149.27003479003906
batch 10 loss: 163.2886962890625
batch 11 loss: 167.516845703125


In [None]:
# Overfit sanity check: repeatedly optimize on the single batch left in `data`
# by the previous loop; the loss is expected to trend downward.
# The former `if batch > 10: break` was unreachable under range(10) and is gone.
NUM_OVERFIT_STEPS = 10

data = to_device(data, device)  # loop-invariant: transfer once, not every step
for batch in range(NUM_OVERFIT_STEPS):
    infer = urssm(data['first'], data['second'])
    loss_value = loss(infer, data)

    print(f'batch {batch} loss: {loss_value.item()}')

    optimizer.zero_grad()
    loss_value.backward()
    optimizer.step()

batch 0 loss: 153.46319580078125
batch 1 loss: 152.51292419433594
batch 2 loss: 150.9139404296875
batch 3 loss: 150.45840454101562
batch 4 loss: 149.34332275390625
batch 5 loss: 150.97003173828125
batch 6 loss: 150.73577880859375
batch 7 loss: 148.7731170654297
batch 8 loss: 144.04058837890625
batch 9 loss: 147.0554656982422
