In [1]:
# Load (or reload) the autoreload extension and switch it to mode 2, so edits to
# imported project modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [2]:
import sys
# Add the parent directory to the module search path; the `src` imports in later
# cells presumably rely on this.
# NOTE: '..' is resolved against the kernel's working directory, so the notebook
# has to be started from the folder it lives in.
sys.path.append('..')

In [3]:
from src.infra import config

# Experiment directory handed to the project config loader.
path = '../experiment/test-250605'
opt = config.load_config(path)
# Record the experiment directory on the options object itself.
opt.path = path
# Print the loaded options for a quick visual check.
print(opt)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
device_select: auto
model:
  feature_extractor:
    args:
      cache_dir: ../data/FAUST_r/diffusion
      in_channels: 128
      input_type: wks
      out_channels: 256
    name: DiffusionNet
  fm_solver:
    args:
      
    name: RegularizedFMNet
  name: URSSM
  permutation:
    args:
      tau: 0.07
    name: Similarity
path: ../experiment/test-250605
test:
  dataset:
    aug:
      args:
        footprint_wrap_folder: data/processed/footprint-wrap/
        img_size: 5
        l_mask_path: data/processed/left_foot_mask.png
        pedar_dynamic_path: data/processed/pedar_dynamic.pkl
        sense_range: 600
        stack_range: 50
      dataloader:
        args:
          batch_size: 128
          shuffle: True
        name: DataLoader
      name: Footprint2Pressure_SensorStack_Blend
    wo_aug:
      args:
        

In [4]:
import torch

# Use the GPU when PyTorch reports one as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

## Dataset

In [5]:
# NOTE(review): importing the top-level `src` package presumably triggers registration
# of the project classes in the registries — confirm.
import src
from src.infra.registry import DATASET_REGISTRY
# Display the name -> dataset class mapping that can be looked up below.
DATASET_REGISTRY

{'Dataset': torch.utils.data.dataset.Dataset,
 'SingleFaustDataset': src.dataset.shape_cor.SingleFaustDataset,
 'SingleScapeDataset': src.dataset.shape_cor.SingleScapeDataset,
 'SingleShrec19Dataset': src.dataset.shape_cor.SingleShrec19Dataset,
 'SingleSmalDataset': src.dataset.shape_cor.SingleSmalDataset,
 'SingleDT4DDataset': src.dataset.shape_cor.SingleDT4DDataset,
 'SingleShrec20Dataset': src.dataset.shape_cor.SingleShrec20Dataset,
 'SingleTopKidsDataset': src.dataset.shape_cor.SingleTopKidsDataset,
 'PairDataset': src.dataset.shape_cor.PairDataset,
 'PairFaustDataset': src.dataset.shape_cor.PairFaustDataset,
 'PairScapeDataset': src.dataset.shape_cor.PairScapeDataset,
 'PairShrec19Dataset': src.dataset.shape_cor.PairShrec19Dataset,
 'PairSmalDataset': src.dataset.shape_cor.PairSmalDataset,
 'PairDT4DDataset': src.dataset.shape_cor.PairDT4DDataset,
 'PairShrec20Dataset': src.dataset.shape_cor.PairShrec20Dataset,
 'PairShrec16Dataset': src.dataset.shape_cor.PairShrec16Dataset,
 'Pai

In [6]:
# Look up the pair dataset class by name and build it with phase='train' on the
# data under ../data/FAUST_r/.
# NOTE(review): the return_* flags are passed as the strings 'true'/'false', not as
# booleans. Any non-empty string (including 'false') is truthy in Python, so unless
# the dataset parses these strings explicitly, return_corr/return_dist may end up
# enabled — confirm against the dataset implementation.
dataset = DATASET_REGISTRY['PairFaustDataset'](
    phase='train',
    data_root='../data/FAUST_r/',
    return_evecs='true',
    return_faces='true',
    num_evecs=200,
    return_corr='false',
    return_dist='false',
)

In [7]:
from src.infra.registry import DATALOADER_REGISTRY
# Display the registered data loader classes.
DATALOADER_REGISTRY

{'DataLoader': torch.utils.data.dataloader.DataLoader}

In [8]:
# Wrap the dataset in the registered DataLoader: one sample per batch, shuffled.
# NOTE: no random seed is set in the visible cells, so the batch order (and every
# number printed further down) will differ between runs.
dataloader = DATALOADER_REGISTRY['DataLoader'](
    dataset=dataset,
    batch_size=1,
    shuffle=True,
)
dataloader

<torch.utils.data.dataloader.DataLoader at 0x7faa24a83100>

In [9]:
from src.utils.tensor import to_device

# Pull a single batch for the inspection cells that follow. next(iter(...)) states the
# intent directly and fails right here (StopIteration) when the loader is empty,
# instead of silently leaving data/data_x/data_y undefined for later cells.
data = to_device(next(iter(dataloader)), device)
# Each batch carries the two elements of the pair under the 'first' and 'second' keys.
data_x, data_y = data['first'], data['second']
print(data_x['name'], data_y['name'])

['tr_reg_026'] ['tr_reg_016']


## Model

In [10]:
from src.infra.registry import MODEL_REGISTRY
# Display the registered model classes.
MODEL_REGISTRY

{'Similarity': src.model.permutation.Similarity,
 'RegularizedFMNet': src.model.fmap.RegularizedFMNet,
 'URSSM': src.model.urssm.URSSM,
 'DiffusionNet': src.model.diffusionnet.DiffusionNet}

In [11]:
# Instantiate the model class named by opt.model.name, passing it the whole options
# object, and move it to the selected device.
urssm = MODEL_REGISTRY[opt.model.name](opt).to(device)
# Bare name so the notebook displays the module structure.
urssm

URSSM(
  (feature_extractor): DiffusionNet(
    (first_linear): Linear(in_features=128, out_features=128, bias=True)
    (last_linear): Linear(in_features=128, out_features=256, bias=True)
    (blocks): ModuleList(
      (0): DiffusionNetBlock(
        (diffusion): LearnedTimeDiffusion()
        (gradient_features): SpatialGradientFeatures(
          (A_re): Linear(in_features=128, out_features=128, bias=False)
          (A_im): Linear(in_features=128, out_features=128, bias=False)
        )
        (mlp): MiniMLP(
          (miniMLP_linear_000): Linear(in_features=384, out_features=128, bias=True)
          (miniMLP_activation_000): ReLU()
          (miniMLP_dropout_001): Dropout(p=0.5, inplace=False)
          (miniMLP_linear_001): Linear(in_features=128, out_features=128, bias=True)
          (miniMLP_activation_001): ReLU()
          (miniMLP_dropout_002): Dropout(p=0.5, inplace=False)
          (miniMLP_linear_002): Linear(in_features=128, out_features=128, bias=True)
        )
  

In [12]:
network_path = '../checkpoints/faust.pth'
# map_location remaps the stored tensors onto the active device, so a checkpoint that
# was saved on a GPU can still be loaded when the notebook runs on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
urssm.feature_extractor.load_state_dict(
    torch.load(network_path, map_location=device)['networks']['feature_extractor']
)
print(f'Loaded pretrain weights from {network_path}')

Loaded pretrain weights from ../checkpoints/faust.pth


In [13]:
# Forward pass of the model on the sample pair; the bare name displays the result.
# NOTE(review): no eval() call or no_grad() context is visible before this point, so
# dropout is presumably active and the outputs carry autograd history — confirm that
# is intended for this inspection step.
infer = urssm(data_x, data_y)
infer

{'Cxy': tensor([[[ 9.8195e-01, -2.7259e-02, -2.3912e-02,  ...,  1.0097e-03,
           -7.2692e-04, -4.9136e-04],
          [-2.7238e-02, -9.9226e-01, -1.9070e-02,  ..., -6.5598e-05,
           -1.0207e-04, -9.1797e-04],
          [ 1.3511e-02,  1.8678e-02, -9.7463e-01,  ...,  3.5356e-04,
           -2.2927e-04,  3.7693e-04],
          ...,
          [ 1.9589e-03,  1.5950e-03,  1.0679e-04,  ...,  1.4746e-01,
            4.8278e-02,  1.9952e-01],
          [-5.8314e-06, -1.9600e-04, -4.7687e-04,  ...,  3.2928e-02,
           -6.9109e-02,  6.7890e-02],
          [-7.9379e-04, -9.3708e-04, -4.4316e-04,  ...,  1.7181e-01,
            1.4124e-01,  3.1555e-01]]], device='cuda:0', grad_fn=<CatBackward0>),
 'Cyx': tensor([[[ 9.4994e-01, -3.1547e-02, -3.1429e-02,  ...,  8.6229e-05,
            5.8255e-04, -7.6243e-05],
          [-1.3198e-02, -9.3600e-01,  2.1863e-02,  ..., -4.6648e-04,
            3.1874e-04, -3.3835e-04],
          [ 1.5444e-02, -1.8294e-02, -9.9469e-01,  ...,  5.3709e-04,
  

## Loss

In [14]:
from src.infra.registry import LOSS_REGISTRY
# Display the registered loss classes.
LOSS_REGISTRY

{'MSELoss': torch.nn.modules.loss.MSELoss,
 'SquaredFrobeniusLoss': src.loss.fmap.SquaredFrobeniusLoss,
 'SURFMNetLoss': src.loss.fmap.SURFMNetLoss,
 'SURFMNetLoss_wrap': src.loss.fmap.SURFMNetLoss_wrap,
 'SpatialSpectralAlignmentLoss': src.loss.fmap.SpatialSpectralAlignmentLoss,
 'SpatialSpectralAlignmentLoss_wrap': src.loss.fmap.SpatialSpectralAlignmentLoss_wrap,
 'PartialFmapsLoss': src.loss.fmap.PartialFmapsLoss,
 'DirichletLoss': src.loss.dirichlet.DirichletLoss,
 'CompositeLoss': src.loss.composite.CompositeLoss}

In [15]:
# For every loss listed under opt.train.loss, build the registered loss module from
# its configured arguments, move it to the device, and keep its weight next to it.
loss_dict = {
    loss_name: {
        'fn': LOSS_REGISTRY[spec['name']](**spec['args']).to(device),
        'weight': spec['weight'],
    }
    for loss_name, spec in opt.train.loss.items()
}

loss_dict

{'surfm_loss': {'fn': SURFMNetLoss_wrap(
    (squared_frobenius): SquaredFrobeniusLoss()
  ),
  'weight': 1.0},
 'align_loss': {'fn': SpatialSpectralAlignmentLoss_wrap(
    (squared_frobenius): SquaredFrobeniusLoss()
  ),
  'weight': 1.0}}

In [16]:
# Evaluate each configured loss on the sample batch; the weights are not applied here.
for loss_name, entry in loss_dict.items():
    print(loss_name, entry['fn'](infer, data))

surfm_loss tensor(99.2401, device='cuda:0', grad_fn=<AddBackward0>)
align_loss tensor(50.5304, device='cuda:0', grad_fn=<AddBackward0>)


## Metrics

In [17]:
from src.infra.registry import METRIC_REGISTRY
# Display the registered metrics (the mapping holds both classes and plain functions).
METRIC_REGISTRY

{'L1Loss': torch.nn.modules.loss.L1Loss,
 'MSELoss': torch.nn.modules.loss.MSELoss,
 'MeanDiffRatio': src.metric.stats.MeanDiffRatio,
 'StdDiffRatio': src.metric.stats.StdDiffRatio,
 'calculate_geodesic_error': <function src.metric.geodist.calculate_geodesic_error(dist_x, corr_x, corr_y, p2p, return_mean=True)>,
 'GeodesicDist': src.metric.geodist.GeodesicDist,
 'plot_pck': <function src.metric.geodist.plot_pck(geo_err, threshold=0.1, steps=40)>}

In [18]:
# Build one metric object per entry under opt.train.metric from its configured
# arguments and move it to the device.
metric_dict = {
    metric_name: METRIC_REGISTRY[spec['name']](**spec['args']).to(device)
    for metric_name, spec in opt.train.metric.items()
}

metric_dict

{'geodist': GeodesicDist()}

In [21]:
# Evaluate each configured metric on the sample batch and its model output.
for metric_name, metric_fn in metric_dict.items():
    print(metric_name, metric_fn(infer, data))

geodist 0.020181167870759964


## Training loop

In [22]:
from src.infra.registry import OPTIMIZER_REGISTRY
# Display the registered optimizer classes.
OPTIMIZER_REGISTRY

{'SGD': torch.optim.sgd.SGD,
 'Adam': torch.optim.adam.Adam,
 'AdamW': torch.optim.adamw.AdamW,
 'RMSprop': torch.optim.rmsprop.RMSprop,
 'Adagrad': torch.optim.adagrad.Adagrad,
 'Adadelta': torch.optim.adadelta.Adadelta}

In [23]:
# Build the optimizer named by opt.train.optimizer.name over all parameters of the
# model, forwarding the configured keyword arguments.
optimizer = OPTIMIZER_REGISTRY[opt.train.optimizer.name](
    params=urssm.parameters(),
    **opt.train.optimizer.args,
)
optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: False
    lr: 0.001
    maximize: False
    weight_decay: 0
)

In [24]:
# Smoke-test the optimisation loop on a fixed, explicit number of batches.
# 12 steps (batches 0..11) are run and reported.
MAX_TRAIN_STEPS = 12

# Put the model in training mode explicitly (it contains dropout layers), so the
# loop does not depend on whatever mode an earlier cell left it in.
urssm.train()

# range() is the first argument to zip(), so the loop stops as soon as the step
# budget is used up without drawing an extra batch from the loader.
for step, train_data in zip(range(MAX_TRAIN_STEPS), dataloader):
    # Loop-local names keep the `data` / `infer` objects inspected in the earlier
    # cells intact, so those cells give the same result when re-run.
    train_data = to_device(train_data, device)
    train_infer = urssm(train_data['first'], train_data['second'])

    # Weighted sum of every configured loss term.
    total_loss = sum(
        term['fn'](train_infer, train_data) * term['weight']
        for term in loss_dict.values()
    )

    print(f'batch {step} loss: {total_loss.item()}')

    # Clear stale gradients, back-propagate, and apply one optimizer update.
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

batch 0 loss: 170.54812622070312
batch 1 loss: 116.75547790527344
batch 2 loss: 175.21177673339844
batch 3 loss: 141.9226837158203
batch 4 loss: 271.80230712890625
batch 5 loss: 182.36788940429688
batch 6 loss: 157.69952392578125
batch 7 loss: 261.27020263671875
batch 8 loss: 311.74053955078125
batch 9 loss: 313.6375427246094
batch 10 loss: 353.85198974609375
batch 11 loss: 254.96670532226562
