In [2]:
import sys
# caution: path[0] is reserved for script path (or '' in REPL)
sys.path.insert(1, '../utils')

import open3d as o3d
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_scatter import scatter_max

import matplotlib.pyplot as plt

import random
import pickle
from tqdm import tqdm

In [3]:
nums = torch.tensor([10, 20])
graphs = [torch.zeros(n, 3) for n in nums]
for g in graphs:
    print(g.size())

torch.Size([10, 3])
torch.Size([20, 3])


In [4]:
G = torch.concat(graphs)
G.shape

torch.Size([30, 3])

In [5]:
indices = torch.tensor([i for i, num in enumerate(nums) for _ in range(num)])
indices

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1])

In [6]:
# Faster with repeat_interleave
idxs = torch.arange(len(nums))
indices = idxs.repeat_interleave(torch.tensor(nums))
indices

  indices = idxs.repeat_interleave(torch.tensor(nums))


tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1])

In [7]:
torch.randint(low=0, high=10, size=(2, 10))

tensor([[7, 4, 9, 7, 3, 2, 9, 8, 0, 1],
        [9, 4, 2, 3, 1, 4, 7, 4, 5, 6]])

In [8]:
all_edges = []
for i, n in enumerate(nums):
    edges = torch.randint(low=0, high=n, size=(2, n))
    if i > 0:
        edges += nums[i - 1]
    all_edges.append(edges)
all_edges

[tensor([[6, 2, 2, 9, 5, 5, 8, 6, 5, 2],
         [5, 2, 7, 7, 9, 4, 8, 8, 4, 1]]),
 tensor([[11, 12, 22, 23, 13, 14, 10, 17, 26, 25, 24, 20, 24, 19, 28, 16, 11, 21,
          15, 28],
         [11, 11, 27, 14, 13, 20, 21, 13, 21, 16, 17, 29, 23, 13, 16, 27, 20, 21,
          18, 14]])]

In [9]:
# 
offsets = torch.cat([
    torch.tensor([0]), 
    torch.cumsum(nums, dim=0)[:-1]
])
offsets

tensor([ 0, 10])

In [10]:
nums

tensor([10, 20])

In [11]:
all_edges = [
    torch.randint(low=0, high=n, size=(2, n)) + offset
    for n, offset in zip(nums, offsets)
]
E = torch.concat(all_edges, dim=1)

In [12]:
G[indices].shape

torch.Size([30, 3])

In [13]:
b = 0
# indices == b
G[indices == b]

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [99]:
# nums is verts per graph
vpg = nums
G.size(0), vpg

(30, tensor([10, 20]))

In [100]:
# Splitting and indexing on a contiguous range is faster than using boolean masks
torch.split(torch.arange(G.size(0)), vpg.tolist())

(tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
 tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
         28, 29]))

In [17]:
all_joints = []
for b in range(len(graphs)):
    mask = (indices == b)
    verts = G[mask]
    all_joints.append(verts)
print(all_joints)

[tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]), tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])]


In [26]:
from dataset import VAL_FILE_PATH, RigNetDataset
from torch.utils.data import Dataset, DataLoader
VAL_FILE_PATH

'../data/ModelResource_RigNetv1_preproccessed/mesh_graphs/val.pkl'

In [47]:
dataset = RigNetDataset(VAL_FILE_PATH, 2, seed=42)
dataset[0]

{'vertices': tensor([[-0.1510,  0.2746, -0.0119],
         [-0.1707,  0.2635, -0.0522],
         [-0.1711,  0.2387, -0.0084],
         ...,
         [ 0.0686,  0.3920, -0.0187],
         [ 0.0619,  0.4156, -0.0122],
         [ 0.0721,  0.4075, -0.0188]]),
 'num_faces': 3920,
 'one_ring': tensor([[   0,  576,    0,  ..., 1493,  357, 1493],
         [ 358, 1493, 1491,  ..., 2105, 2106, 2106]]),
 'centroid': array([-4.67432633e-05,  4.85459646e-01, -1.37504095e-02]),
 'geodesic': tensor([[   0,    0,    0,  ..., 1491, 1486, 1491],
         [1490, 1491,  356,  ..., 2102, 2105,  570]]),
 'joints': tensor([[ 4.6743e-05,  7.7347e-02,  9.1861e-03],
         [ 4.6743e-05,  3.4989e-02,  9.0736e-03],
         [ 4.6743e-05,  1.5142e-01,  9.4437e-03],
         [ 5.9097e-02, -4.7390e-02,  6.9835e-03],
         [-5.8958e-02, -4.7382e-02,  6.9431e-03],
         [ 5.2238e-02,  2.5606e-01, -3.6187e-03],
         [ 4.6743e-05,  3.4177e-01, -1.1409e-02],
         [-5.2113e-02,  2.5606e-01, -3.6187e-03],
 

In [113]:
def collate_fn(batch: list[dict]):

    verts_list = [b['vertices'] for b in batch]
    topo_list = [b['one_ring']  for b in batch]
    geodesic_list = [b['geodesic']   for b in batch]

    # No need to concatenate these
    # they are only used after the batched graph is processed and unbatched
    attn_mask_list = [b['attn_mask'] for b in batch]
    joints_list = [b['joints'] for b in batch]

    verts_per_graph = torch.tensor([verts.size(0) for verts in verts_list])
    
    # Tensor of all vertices in batch
    V = torch.concat(verts_list)

    # Vertex-Graph Mapping (maps each vertex to its graph)
    graph_idxs = torch.arange(len(verts_per_graph))
    vertex_graph_indices = graph_idxs.repeat_interleave(verts_per_graph)

    # Edge Index Offsets
    # Each edge is represented by a pair of vertex indices
    # Edge indices of each graph must be offset by the number of vertices that came before
    # The offset of the first graph is zero,
    # second graph is len(G1_vertices), third is len(G1_verts) + len(G2_verts), and so on...
    offsets = torch.cat([
        torch.tensor([0]), 
        torch.cumsum(verts_per_graph, dim=0)[:-1]
    ])

    topo_offset_list = []
    geo_offset_list = []
    for topo_b, geo_b, offset in zip(topo_list, geodesic_list, offsets):
        topo_offset_list.append(topo_b + offset)
        geo_offset_list.append(geo_b + offset)

    E_topo = torch.concat(topo_offset_list, dim=1)
    E_geo = torch.concat(geo_offset_list, dim=1)

    return {
        "vertices": V,
        "one_ring": E_topo,
        "geodesic": E_geo,
        "graph_idxs": vertex_graph_indices, # for testing
        "vertices_per_graph": verts_per_graph,
        "attn_mask_list": attn_mask_list,
        "joints_list": joints_list 
    }


In [114]:
dl = DataLoader(
    dataset=dataset,
    batch_size=2,
    collate_fn=collate_fn,
    shuffle=False
)

In [115]:
G1 = dataset[0]
G2 = dataset[1]

In [116]:
G1['one_ring'].size(1)

14087

In [117]:
edge_mask = torch.zeros(G1['one_ring'].size(1) + G2['one_ring'].size(1), dtype=torch.long)
edge_mask[G1['one_ring'].size(1):] = 1
edge_mask = edge_mask == 1
edge_mask

tensor([False, False, False,  ...,  True,  True,  True])

In [118]:
# Very basic test
for b in dl:
    mask = b['graph_idxs'] == 1
    verts_g2 = b['vertices'][mask]
    # print(b['vertices_per_graph'][0])
    # print(b['one_ring'][:, edge_mask].shape, b['vertices_per_graph'][0].shape)
    edges_g2 = b['one_ring'][:, edge_mask] - b['vertices_per_graph'][0]
    assert torch.all(torch.eq(G2['vertices'], verts_g2))
    assert torch.all(torch.eq(G2['one_ring'], edges_g2))

### Rewritten jointnet

In [127]:
from models import JointFeatureNet, MeanShiftClusterer

class JointNet(nn.Module):
    def __init__(self, 
                 edge_dropout=15,
                 initial_h=0.05,
                 train_iters=10,
                 infer_iters=50):
        super().__init__()
        self.feature_extractor = JointFeatureNet(edge_dropout=edge_dropout)
        self.clustering_head = MeanShiftClusterer(
            initial_h=initial_h,
            train_iters=train_iters,
            infer_iters=infer_iters
        )

    def forward(self, G: dict):
        """
        Run joint prediction on either a single graph or a batch of graphs.

        Args:
            G (dict): Graph data, must contain:
                - "vertices": Tensor of shape [N_total, 3]
                - "one_ring": LongTensor of shape [2, E_topo_total]
                - "geodesic": LongTensor of shape [2, E_geo_total]
            Optionally:
                - "vertices_per_graph": 1D Tensor of ints of length N_total
                  If present, enables batched mode; otherwise, treats input as a single graph.

        Returns:
            List[Tensor]: A list of length M (number of graphs).
                Each entry is a Tensor of shape [K_i, 3], the predicted joint locations
                for graph i. In unbatched mode, returns a single-element list.
        """
        verts = G["vertices"]
        E_topo = G["one_ring"]
        E_geo = G["geodesic"]

        q, attn = self.feature_extractor(verts, E_topo, E_geo)

        # Determine batching: if vertices_per_graph is provided, split accordingly.
        vpg = G.get("vertices_per_graph", None)
        if vpg is None:
            # Unbatched: single graph
            return [self.clustering_head(q, attn)]

        # Batched: split both q and attn by the given vertex counts
        # Splitting and indexing on a contigutous range is faster than using boolean masks
        splits = torch.split(torch.arange(verts.size(0)), vpg.tolist())
        outs = []
        for idxs in splits:
            q_b = q[idxs]
            attn_b = attn[idxs]
            outs.append(self.clustering_head(q_b, attn_b))
        return outs

In [123]:
dataset = RigNetDataset(VAL_FILE_PATH, 2, seed=42)
dl = DataLoader(
    dataset=dataset,
    batch_size=2,
    collate_fn=collate_fn,
    shuffle=False
)

In [126]:
net = JointNet(initial_h=0.2)
for batch in dl: 
    print(net(batch))

[tensor([[ 0.0151,  0.4341,  0.0787],
        [ 0.0152, -0.3700,  0.0790],
        [ 0.0153, -0.0285,  0.0653],
        [-0.3040,  0.3368,  0.0599],
        [ 0.3353,  0.3375,  0.0587]], grad_fn=<StackBackward0>), tensor([[ 0.0133,  0.2772,  0.1303],
        [ 0.0145, -0.2822,  0.0752],
        [-0.1496, -0.0361,  0.0641],
        [ 0.1858, -0.0378,  0.0625]], grad_fn=<StackBackward0>)]
