In [1]:
import sys
# caution: path[0] is reserved for script path (or '' in REPL)
sys.path.insert(1, '../utils')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from dataset import RigNetDataset, collate_fn, FILE_PATHS, POS_ATTN_AVG
from models import JointNet, JointDisplacementModule, VertexAttentionModule, GMEdgeConv, GMEdgeNet
from visualization_utils import visualize_mesh_graph, visualize_attention_heatmap
from training_utils import chamfer_loss, save_model, dict_to_device

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, confusion_matrix, precision_recall_fscore_support, precision_recall_curve

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cpu')

### Test Dataset

In [3]:
# Build the dataset from the 'val' file paths with a fixed seed.
# NOTE(review): the exact meaning of num_samples is defined by RigNetDataset — confirm there.
dataset = RigNetDataset(FILE_PATHS['val'], num_samples=16, seed=42)

# One item per batch, in dataset order (no shuffling), so that a given
# position in the loader refers to the same item on every pass.
dl = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

In [4]:
def index_iterator(iterator, index):
    """Return the element at position ``index`` of an iterable.

    The iterable is consumed only up to and including the requested element,
    so this works for objects that cannot be subscripted (e.g. a DataLoader).

    Args:
        iterator: Any iterable.
        index: Zero-based, non-negative position of the element to return.

    Returns:
        The ``index``-th element produced by iterating over ``iterator``.

    Raises:
        IndexError: If ``index`` is negative or the iterable yields fewer
            than ``index + 1`` elements (including the empty iterable).
    """
    from itertools import islice

    if index < 0:
        raise IndexError(f"index must be non-negative, got {index}")

    # A private sentinel distinguishes "iterable exhausted" from a legitimate
    # element that happens to be None.
    missing = object()
    item = next(islice(iterator, index, None), missing)
    if item is missing:
        raise IndexError(f"index {index} is out of range for the given iterable")
    return item

### Vertex Attention Module

In [5]:
# Restore the trained vertex-attention weights from disk.
model_path = 'models/attn_module_20250702-030900.pt'
attn_module = VertexAttentionModule()
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
# map_location remaps the stored tensors onto whichever device was selected.
attn_module.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True)
)

<All keys matched successfully>

In [6]:
# Position (in loader order) of the mesh to visualise.
mesh_idx = 1

batch = index_iterator(dl, mesh_idx)

attn_module.to(device)
# Inference only: switch off training-time behaviour of any dropout / norm layers.
attn_module.eval()
batch = dict_to_device(batch, device)

# predict
# no_grad avoids building an autograd graph that is never used here.
with torch.no_grad():
    attn_pred_probs = torch.sigmoid(attn_module(
        batch['vertices'],
        batch['one_ring'],
        batch['geodesic']
    ).squeeze())

# .cpu() is required before .numpy(): converting a CUDA tensor directly raises,
# and `device` may be a GPU. On CPU tensors it is a no-op.
visualize_attention_heatmap(
    verts=batch['vertices'].detach().cpu().numpy(),
    edges=batch['one_ring'].T.detach().cpu().numpy(),
    attn_pred=attn_pred_probs.detach().cpu().numpy(),
    joints_gt=batch['joints_list'][0],
)


### Vertex Displacement

In [5]:
# Restore the trained joint-displacement weights from disk.
model_path = 'models/disp_module_20250702-031409.pt'
disp_module = JointDisplacementModule()
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
# map_location remaps the stored tensors onto whichever device was selected.
disp_module.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True)
)

<All keys matched successfully>

In [7]:
# Position (in loader order) of the mesh to visualise.
mesh_idx = 7

batch = index_iterator(dl, mesh_idx)

disp_module.to(device)
# Inference only: switch off training-time behaviour of any dropout / norm layers.
disp_module.eval()
batch = dict_to_device(batch, device)

# predict
# no_grad avoids building an autograd graph that is never used here.
with torch.no_grad():
    disp = disp_module(
        batch['vertices'],
        batch['one_ring'],
        batch['geodesic']
    ).squeeze()
    # Displaced points: each input vertex shifted by its predicted offset.
    q = batch['vertices'] + disp

# .cpu() is required before .numpy(): converting a CUDA tensor directly raises,
# and `device` may be a GPU. On CPU tensors it is a no-op.
visualize_mesh_graph(
    vertices=batch['vertices'].detach().cpu().numpy(),
    edge_list=batch['one_ring'].T.detach().cpu().numpy(),
    joints_gt=batch['joints_list'][0],
    displaced_verts=q.detach().cpu().numpy()
)
