In [11]:
import sys
sys.path.insert(1, '../utils')
sys.path.insert(1, '../scripts')
import torch
import numpy as np
import open3d as o3d
import torch.nn as nn
from dataset import RigNetDataset
from models import JointNet
from train_attention import train as train_attention 
from train_jointnet import train as train_jointnet
from visualization_utils import visualize_mesh_graph, visualize_attention_heatmap
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score

In [12]:
# Validation split of the preprocessed RigNet mesh graphs. num_samples=1 with a
# fixed seed presumably yields a single, reproducible graph to overfit on.
# NOTE(review): the file is a .pkl — only load trusted data if it is unpickled.
path = '../data/ModelResource_RigNetv1_preproccessed/mesh_graphs/val.pkl'
dataset = RigNetDataset(
    path,
    num_samples=1,
    seed=42,
)

In [13]:
# Pick the first mesh graph; with num_samples=1 it is presumably the only one.
G = dataset[0]

In [14]:
# Positive rate of the attention mask (sum / length, assuming a 0/1 mask);
# the same quantity is used below to derive the BCE pos_weight.
G['attn_mask'].sum() / len(G['attn_mask'])

tensor(0.3028)

In [43]:
# First 5 columns of the one-ring connectivity: a 2-row tensor in which
# column k holds the two endpoint vertex indices of edge k.
G['one_ring'][:, :5]

tensor([[   0,  576,    0,  578,    0],
        [ 358, 1493, 1491, 1493, 1499]])

In [44]:
# Same slice transposed: one (i, j) endpoint pair per row — the layout that is
# passed as `edge_list` to visualize_mesh_graph below.
G['one_ring'][:, :5].T

tensor([[   0,  358],
        [ 576, 1493],
        [   0, 1491],
        [ 578, 1493],
        [   0, 1499]])

In [46]:
# Caution: this is NOT an edge list. `.view(-1, 2)` reshapes the (2, E) tensor
# row-major, so each output row pairs two consecutive entries of the SAME
# one_ring row (compare [0, 576] here with [0, 358] from `.T` above).
# Use `.T` to get (i, j) edge pairs.
G['one_ring'].view(-1, 2)

tensor([[   0,  576],
        [   0,  578],
        [   0,  574],
        ...,
        [2104, 2104],
        [2105, 2105],
        [2106, 2106]])

In [48]:
# Mesh with its one-ring edges (`.T` -> one (i, j) pair per row) and the
# ground-truth joints overlaid.
visualize_mesh_graph(
    vertices=G['vertices'].numpy(),
    edge_list=G['one_ring'].T.numpy(),
    joints_gt=G['joints'].numpy(),
)

In [16]:
# Build JointNet and keep a handle on its attention sub-module, which is
# optimized on its own first (see the AdamW over attn_head.parameters()).
model = JointNet(
    train_iters=10,
    infer_iters=50,
    edge_dropout=15,  # NOTE(review): unit unclear (percent? edge count?) — confirm
)
attn_head = model.attn_head

In [17]:
# Stage 1 setup: optimize only the attention head, with a BCE loss whose
# positive class is up-weighted by the negative/positive ratio of G's mask
# (positives are the minority, ~30% — see the positive-rate cell above).
epochs = 1
lr = 1e-3  # 1e-4 was too slow; 1e-3 is jittery (see TL;DR notes)
wd = 1e-5
optimizer = torch.optim.AdamW(
    attn_head.parameters(),
    lr=lr,
    weight_decay=wd,
)

pos_frac = G['attn_mask'].sum() / len(G['attn_mask'])
pos_weight = (1 - pos_frac) / pos_frac  # == #neg / #pos for a 0/1 mask
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

pos_frac, pos_weight

(tensor(0.3028), tensor(2.3025))

In [18]:
# Fit the attention head alone with the weighted BCE loss; no validation
# set, logs written under ../runs/attention.
train_attention(
    attn_head,
    optimizer,
    dataset,
    loss_fn=loss_fn,
    val_ds=None,
    epochs=epochs,
    logdir='../runs/attention',
)

Epoch 1/1: 100%|██████████| 1/1 [00:01<00:00,  1.03s/it]

Epoch 1: train loss = 9.6696e-01





In [19]:
# Inference-only pass of the attention head: logits -> probabilities -> 0/1
# predictions at a 0.5 threshold. Run under no_grad so no autograd graph is
# built (every downstream use detaches the result anyway).
# NOTE(review): attn_head is not switched to eval mode here; if it contains
# dropout/batch-norm, consider attn_head.eval() — confirm.
with torch.no_grad():
    attn_pred_probs = torch.sigmoid(attn_head(G['vertices'], G['one_ring'], G['geodesic']))
attn_preds = (attn_pred_probs >= 0.5).long()
attn_preds[:5], attn_pred_probs[:5], attn_pred_probs.mean()

(tensor([[1],
         [1],
         [1],
         [1],
         [1]]),
 tensor([[0.5015],
         [0.5015],
         [0.5015],
         [0.5015],
         [0.5015]], grad_fn=<SliceBackward0>),
 tensor(0.5005, grad_fn=<MeanBackward0>))

In [20]:
# Plain accuracy — misleading on its own since only ~30% of the mask is positive.
accuracy_score(y_true=G['attn_mask'], y_pred=attn_preds)

0.4712861888941623

In [21]:
# sklearn convention: rows = true label, columns = predicted label, order [0, 1].
confusion_matrix(y_true=G['attn_mask'], y_pred=attn_preds, labels=[0, 1])

array([[ 469, 1000],
       [ 114,  524]])

In [22]:
# Heat-map of predicted attention probabilities over the mesh graph.
# one_ring is a (2, E) tensor, so `.T` yields one (i, j) edge per row (as in
# the first mesh plot). `.view(-1, 2)` would reshape row-major and pair
# consecutive entries of the same row, drawing edges that do not exist.
visualize_attention_heatmap(G['vertices'].detach().numpy(),
                            G['one_ring'].T.detach().numpy(),
                            attn_pred_probs.detach().numpy(),
                            joints_gt=G['joints'].detach().numpy())

#### Attention module TL;DR
- Use a *weighted* `BCEWithLogitsLoss` (`pos_weight` = negatives / positives) to counter the class imbalance.
- `lr=1e-4` converges too slowly; `lr=1e-3` is jittery and still settles into a local minimum. Use an LR scheduler.

In [48]:
# Stage 2 setup: optimize all of model.parameters() for the full JointNet.
epochs = 50
lr = 1e-4
wd = 1e-5
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    weight_decay=wd,
)

In [63]:
# Set JointNet's is_training flag (presumably read by its forward pass),
# then train on the dataset for `epochs` epochs.
model.is_training = True
train_jointnet(
    model,
    optimizer,
    dataset,
    epochs=epochs,
)

Epoch 1/50: 100%|██████████| 1/1 [00:03<00:00,  3.52s/it]


[Epoch 1] train=1.9497e-01 (disp 9.2131e-02, joint 1.0284e-01)


Epoch 2/50: 100%|██████████| 1/1 [00:03<00:00,  3.44s/it]


[Epoch 2] train=1.9445e-01 (disp 9.2178e-02, joint 1.0227e-01)


Epoch 3/50: 100%|██████████| 1/1 [00:03<00:00,  3.72s/it]


[Epoch 3] train=1.9383e-01 (disp 9.2175e-02, joint 1.0165e-01)


Epoch 4/50: 100%|██████████| 1/1 [00:03<00:00,  3.27s/it]


[Epoch 4] train=1.9320e-01 (disp 9.2120e-02, joint 1.0108e-01)


Epoch 5/50: 100%|██████████| 1/1 [00:03<00:00,  3.18s/it]


[Epoch 5] train=1.9253e-01 (disp 9.2060e-02, joint 1.0047e-01)


Epoch 6/50: 100%|██████████| 1/1 [00:03<00:00,  3.12s/it]


[Epoch 6] train=1.9188e-01 (disp 9.2007e-02, joint 9.9872e-02)


Epoch 7/50: 100%|██████████| 1/1 [00:03<00:00,  3.14s/it]


[Epoch 7] train=1.9120e-01 (disp 9.1878e-02, joint 9.9326e-02)


Epoch 8/50: 100%|██████████| 1/1 [00:03<00:00,  3.07s/it]


[Epoch 8] train=1.9047e-01 (disp 9.1650e-02, joint 9.8821e-02)


Epoch 9/50: 100%|██████████| 1/1 [00:03<00:00,  3.07s/it]


[Epoch 9] train=1.8973e-01 (disp 9.1392e-02, joint 9.8342e-02)


Epoch 10/50: 100%|██████████| 1/1 [00:03<00:00,  3.04s/it]


[Epoch 10] train=1.8903e-01 (disp 9.1168e-02, joint 9.7859e-02)


Epoch 11/50: 100%|██████████| 1/1 [00:03<00:00,  3.10s/it]


[Epoch 11] train=1.8835e-01 (disp 9.1016e-02, joint 9.7335e-02)


Epoch 12/50: 100%|██████████| 1/1 [00:03<00:00,  3.11s/it]


[Epoch 12] train=1.8485e-01 (disp 9.1020e-02, joint 9.3830e-02)


Epoch 13/50: 100%|██████████| 1/1 [00:03<00:00,  3.08s/it]


[Epoch 13] train=1.8725e-01 (disp 9.1082e-02, joint 9.6172e-02)


Epoch 14/50: 100%|██████████| 1/1 [00:03<00:00,  3.13s/it]


[Epoch 14] train=1.8731e-01 (disp 9.1114e-02, joint 9.6193e-02)


Epoch 15/50: 100%|██████████| 1/1 [00:03<00:00,  3.08s/it]


[Epoch 15] train=1.8641e-01 (disp 9.1059e-02, joint 9.5355e-02)


Epoch 16/50: 100%|██████████| 1/1 [00:03<00:00,  3.01s/it]


[Epoch 16] train=1.8581e-01 (disp 9.0914e-02, joint 9.4899e-02)


Epoch 17/50: 100%|██████████| 1/1 [00:03<00:00,  3.01s/it]


[Epoch 17] train=1.8530e-01 (disp 9.0737e-02, joint 9.4560e-02)


Epoch 18/50: 100%|██████████| 1/1 [00:03<00:00,  3.02s/it]


[Epoch 18] train=1.8494e-01 (disp 9.0598e-02, joint 9.4339e-02)


Epoch 19/50: 100%|██████████| 1/1 [00:03<00:00,  3.09s/it]


[Epoch 19] train=1.8381e-01 (disp 9.0289e-02, joint 9.3525e-02)


Epoch 20/50: 100%|██████████| 1/1 [00:03<00:00,  3.11s/it]


[Epoch 20] train=1.8400e-01 (disp 8.9507e-02, joint 9.4495e-02)


Epoch 21/50: 100%|██████████| 1/1 [00:03<00:00,  3.22s/it]


[Epoch 21] train=1.8139e-01 (disp 8.8714e-02, joint 9.2672e-02)


Epoch 22/50: 100%|██████████| 1/1 [00:03<00:00,  3.31s/it]


[Epoch 22] train=1.7927e-01 (disp 8.8123e-02, joint 9.1149e-02)


Epoch 23/50: 100%|██████████| 1/1 [00:03<00:00,  3.20s/it]


[Epoch 23] train=1.7507e-01 (disp 8.7740e-02, joint 8.7331e-02)


Epoch 24/50: 100%|██████████| 1/1 [00:03<00:00,  3.26s/it]


[Epoch 24] train=1.7555e-01 (disp 8.7409e-02, joint 8.8138e-02)


Epoch 25/50: 100%|██████████| 1/1 [00:03<00:00,  3.33s/it]


[Epoch 25] train=1.7842e-01 (disp 8.7144e-02, joint 9.1274e-02)


Epoch 26/50: 100%|██████████| 1/1 [00:03<00:00,  3.32s/it]


[Epoch 26] train=1.7828e-01 (disp 8.7001e-02, joint 9.1283e-02)


Epoch 27/50: 100%|██████████| 1/1 [00:03<00:00,  3.33s/it]


[Epoch 27] train=1.7782e-01 (disp 8.6832e-02, joint 9.0985e-02)


Epoch 28/50: 100%|██████████| 1/1 [00:03<00:00,  3.35s/it]


[Epoch 28] train=1.7713e-01 (disp 8.6848e-02, joint 9.0283e-02)


Epoch 29/50: 100%|██████████| 1/1 [00:03<00:00,  3.33s/it]


[Epoch 29] train=1.7804e-01 (disp 8.7168e-02, joint 9.0868e-02)


Epoch 30/50: 100%|██████████| 1/1 [00:03<00:00,  3.31s/it]


[Epoch 30] train=1.7842e-01 (disp 8.7412e-02, joint 9.1006e-02)


Epoch 31/50: 100%|██████████| 1/1 [00:03<00:00,  3.32s/it]


[Epoch 31] train=1.7827e-01 (disp 8.7490e-02, joint 9.0781e-02)


Epoch 32/50: 100%|██████████| 1/1 [00:03<00:00,  3.30s/it]


[Epoch 32] train=1.7787e-01 (disp 8.7330e-02, joint 9.0541e-02)


Epoch 33/50: 100%|██████████| 1/1 [00:03<00:00,  3.29s/it]


[Epoch 33] train=1.7731e-01 (disp 8.7114e-02, joint 9.0195e-02)


Epoch 34/50: 100%|██████████| 1/1 [00:03<00:00,  3.32s/it]


[Epoch 34] train=1.7407e-01 (disp 8.6868e-02, joint 8.7198e-02)


Epoch 35/50: 100%|██████████| 1/1 [00:03<00:00,  3.29s/it]


[Epoch 35] train=1.7551e-01 (disp 8.6688e-02, joint 8.8827e-02)


Epoch 36/50: 100%|██████████| 1/1 [00:03<00:00,  3.31s/it]


[Epoch 36] train=1.7584e-01 (disp 8.6757e-02, joint 8.9087e-02)


Epoch 37/50: 100%|██████████| 1/1 [00:03<00:00,  3.28s/it]


[Epoch 37] train=1.7536e-01 (disp 8.6848e-02, joint 8.8516e-02)


Epoch 38/50: 100%|██████████| 1/1 [00:03<00:00,  3.29s/it]


[Epoch 38] train=1.7490e-01 (disp 8.6953e-02, joint 8.7952e-02)


Epoch 39/50: 100%|██████████| 1/1 [00:03<00:00,  3.21s/it]


[Epoch 39] train=1.7442e-01 (disp 8.7244e-02, joint 8.7172e-02)


Epoch 40/50: 100%|██████████| 1/1 [00:03<00:00,  3.31s/it]


[Epoch 40] train=1.7434e-01 (disp 8.7497e-02, joint 8.6843e-02)


Epoch 41/50: 100%|██████████| 1/1 [00:03<00:00,  3.29s/it]


[Epoch 41] train=1.7423e-01 (disp 8.7435e-02, joint 8.6793e-02)


Epoch 42/50: 100%|██████████| 1/1 [00:03<00:00,  3.23s/it]


[Epoch 42] train=1.7402e-01 (disp 8.7146e-02, joint 8.6877e-02)


Epoch 43/50: 100%|██████████| 1/1 [00:03<00:00,  3.28s/it]


[Epoch 43] train=1.7390e-01 (disp 8.6736e-02, joint 8.7162e-02)


Epoch 44/50: 100%|██████████| 1/1 [00:03<00:00,  3.46s/it]


[Epoch 44] train=1.7321e-01 (disp 8.6389e-02, joint 8.6821e-02)


Epoch 45/50: 100%|██████████| 1/1 [00:03<00:00,  3.33s/it]


[Epoch 45] train=1.7262e-01 (disp 8.6211e-02, joint 8.6411e-02)


Epoch 46/50: 100%|██████████| 1/1 [00:03<00:00,  3.28s/it]


[Epoch 46] train=1.7165e-01 (disp 8.6138e-02, joint 8.5512e-02)


Epoch 47/50: 100%|██████████| 1/1 [00:03<00:00,  3.28s/it]


[Epoch 47] train=1.7234e-01 (disp 8.6091e-02, joint 8.6249e-02)


Epoch 48/50: 100%|██████████| 1/1 [00:03<00:00,  3.19s/it]


[Epoch 48] train=1.7341e-01 (disp 8.6273e-02, joint 8.7134e-02)


Epoch 49/50: 100%|██████████| 1/1 [00:03<00:00,  3.39s/it]


[Epoch 49] train=1.7417e-01 (disp 8.6618e-02, joint 8.7549e-02)


Epoch 50/50: 100%|██████████| 1/1 [00:03<00:00,  3.38s/it]

[Epoch 50] train=1.7504e-01 (disp 8.7064e-02, joint 8.7977e-02)





In [64]:
# Full JointNet forward pass on G; compare predicted vs. ground-truth joint
# counts and inspect the learnable parameter `h`.
# NOTE(review): model.is_training is still True from the training cell; if it
# selects train_iters vs. infer_iters, set it to False before inference — confirm.
joints_pred = model(
    G['vertices'],
    G['one_ring'],
    G['geodesic'],
)

# FIXME: h is not updating!
(len(joints_pred), len(G['joints']), model.h)

(22,
 18,
 Parameter containing:
 tensor(0.0670, requires_grad=True))

In [67]:
# Displacements
d = model.disp_head(G['vertices'], G['one_ring'], G['geodesic'])
d.mean(axis=0), d.std(axis=0)
q = d + G['vertices']
joints = model(G['vertices'], G['one_ring'], G['geodesic'])

In [68]:
# Ground-truth vs. predicted joints on the mesh graph. one_ring is (2, E), so
# `.T` gives one (i, j) edge per row, matching the first mesh plot;
# `.view(-1, 2)` would pair consecutive entries of the same row (wrong edges).
visualize_mesh_graph(
    vertices=G['vertices'].detach().numpy(),
    edge_list=G['one_ring'].T.detach().numpy(),
    joints_gt=G['joints'].detach().numpy(),
    joints_pred=joints.detach().numpy()
)

- Debug the displacement module in isolation.
- Overfit on a larger number of meshes.