## Inference Pipeline

In [8]:
import torch
import numpy as np
from omegaconf import OmegaConf
from scripts.joint_regressor import JointRegressor
from scripts.checkpoint_utils import load_full_checkpoint
from scripts.dlrhand2_joint_datamodule import _sample_mesh_as_pc


def load_model(
    joint_ckpt: str,
    config_path: str,
    backbone_ckpt: str | None = None,
    device: str = "cuda:0"
) -> JointRegressor:
    """Build a JointRegressor from a training config and load trained weights.

    The model is constructed from the hyperparameters stored in the config,
    optionally initialised from a backbone checkpoint, and then updated with
    the weights found in ``joint_ckpt``.

    Args:
        joint_ckpt: path to the joint-regression checkpoint. Either a dict
            with a ``'state_dict'`` entry or a bare state dict is accepted.
        config_path: path to the OmegaConf config used for training; its
            ``data.num_points`` and ``model.*`` entries are read.
        backbone_ckpt: optional checkpoint passed to ``load_full_checkpoint``
            before the joint checkpoint is applied.
        device: device used both as ``map_location`` for the checkpoint and
            as the final location of the model.

    Returns:
        The model after ``eval()``, ``freeze()`` and ``to(device)``.

    Warns:
        UserWarning: if the joint checkpoint is missing keys the model
            expects, or contains keys the model does not have. Loading is
            non-strict, so without this warning a mismatched checkpoint
            would leave those parameters at their previous values silently.
    """
    # Function-scope import keeps this notebook cell self-contained.
    import warnings

    # Load training config
    cfg = OmegaConf.load(config_path)

    # Build model with training hyperparams
    model = JointRegressor(
        num_points=cfg.data.num_points,
        tokenizer_groups=cfg.model.tokenizer_groups,
        tokenizer_group_size=cfg.model.tokenizer_group_size,
        tokenizer_radius=cfg.model.tokenizer_radius,
        encoder_dim=cfg.model.encoder_dim,
        encoder_depth=cfg.model.encoder_depth,
        encoder_heads=cfg.model.encoder_heads,
        encoder_dropout=cfg.model.encoder_dropout,
        encoder_attn_dropout=cfg.model.encoder_attn_dropout,
        encoder_drop_path_rate=cfg.model.encoder_drop_path_rate,
        encoder_mlp_ratio=cfg.model.encoder_mlp_ratio,
        pooling_type=cfg.model.pooling_type,
        pooling_heads=cfg.model.pooling_heads,
        pooling_dropout=cfg.model.pooling_dropout,
        head_hidden_dims=cfg.model.head_hidden_dims,
        lr_backbone=cfg.model.lr_backbone,
        lr_head=cfg.model.lr_head,
    )

    # Load backbone pretrain if provided
    if backbone_ckpt:
        load_full_checkpoint(model, backbone_ckpt)

    # Apply the joint-regression checkpoint. Loading is non-strict, so only
    # parameters whose names match the model are overwritten.
    checkpoint = torch.load(joint_ckpt, map_location=device)
    state = checkpoint.get('state_dict', checkpoint)
    load_result = model.load_state_dict(state, strict=False)

    # Surface key mismatches instead of dropping them silently: a checkpoint
    # with a different naming scheme would otherwise load nothing at all.
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            f"Checkpoint {joint_ckpt!r} does not fully match the model: "
            f"{len(load_result.missing_keys)} missing key(s) "
            f"(e.g. {list(load_result.missing_keys[:5])}), "
            f"{len(load_result.unexpected_keys)} unexpected key(s) "
            f"(e.g. {list(load_result.unexpected_keys[:5])}).",
            stacklevel=2,
        )

    model.eval()
    model.freeze()
    model.to(device)
    return model


@torch.inference_mode()
def predict_joint_angles(
    model: JointRegressor,
    mesh_path: str,
    pose7d: list[float] | np.ndarray,
    num_points: int | None = None,
    device: str = "cuda:0",
) -> np.ndarray:
    """Predict joint angles for one mesh and one 7-D hand pose.

    Args:
        model: regressor called as ``model(points, pose)``.
        mesh_path: mesh file handed to ``_sample_mesh_as_pc``.
        pose7d: the 7-D hand pose; converted to a float32 tensor.
        num_points: number of points to sample from the mesh. Any falsy
            value (``None`` or ``0``) falls back to
            ``model.hparams.num_points``.
        device: device the input tensors are moved to.
            NOTE(review): this is independent of the device the model lives
            on; the caller presumably has to keep the two in sync.

    Returns:
        The prediction for the single sample as a NumPy array with the batch
        dimension removed (12 values, per the original docstring).
    """
    # Resolve the sample count, preferring the explicit argument.
    if num_points:
        sample_count = num_points
    else:
        sample_count = model.hparams.num_points

    cloud = _sample_mesh_as_pc(mesh_path, n=sample_count)

    # Add a leading batch dimension of size 1 to both inputs.
    cloud_batch = torch.from_numpy(cloud)[None, ...].to(device)
    pose_batch = torch.as_tensor(pose7d, dtype=torch.float32)[None, ...].to(device)

    prediction = model(cloud_batch, pose_batch)

    # Drop the batch dimension and hand back a host-side array.
    single = prediction.squeeze(0)
    return single.cpu().numpy()

# Example usage:
# model = load_model(
#     joint_ckpt="checkpoints/joint-final.ckpt",
#     config_path="configs/train_joint.yaml",
#     backbone_ckpt="checkpoints/pretrain-pointjepa.ckpt",  # optional
#     device="cuda:0"
# )
# angles = predict_joint_angles(model, mesh_path, hand_pose)
# print(angles)


In [15]:
# Inputs for the demo run: the trained regressor checkpoint and its config.
demo_ckpt = "../configs/checkpoints/jepa_no_FT.ckpt"
demo_cfg = "../configs/train_joint.yaml"

# Sample object mesh and the 7-D hand pose to evaluate.
demo_mesh = "../data/grasp_sample/02818832/4bc7ad3dbb8fc8747d8864caa856253b/0/mesh.obj"
demo_pose = [0.035, -0.01, 0.07, 0.0, 0.0, 0.0, 1.0]

# To initialise the backbone from a pretrain checkpoint first, additionally
# pass backbone_ckpt="checkpoints/pretrain_pointjepa.ckpt".
model = load_model(
    joint_ckpt=demo_ckpt,
    config_path=demo_cfg,
    device="cuda:0",
)

angles = predict_joint_angles(model, mesh_path=demo_mesh, pose7d=demo_pose)
print("Predicted joint angles:", angles)


Predicted joint angles: [-0.149     0.1322   -0.2136    0.09985  -0.1733   -0.04956   0.03072
 -0.09595   0.008736  0.028    -0.10754  -0.02014 ]


## Helper: Checkpoint Inspector

In [11]:
def inspect_checkpoint(
    ckpt_path: str,
    prefix_filter: str | None = None
) -> None:
    """
    Print all parameter keys saved in a checkpoint. Optionally filter by prefix.

    Args:
        ckpt_path: path to the .ckpt or .pth file
        prefix_filter: only show keys containing this substring
    """
    # Load checkpoint on CPU
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state = ckpt.get('state_dict', ckpt)
    keys = list(state.keys())

    if prefix_filter:
        keys = [k for k in keys if prefix_filter in k]

    print(f"Found {len(keys)} parameters{' with filter ' + prefix_filter if prefix_filter else ''}:")
    for k in sorted(keys):
        print(k)

In [12]:
inspect_checkpoint("../configs/checkpoints/jepa_no_FT.ckpt")


Found 546 parameters:
mask_token
positional_encoding.0.bias
positional_encoding.0.weight
positional_encoding.2.bias
positional_encoding.2.weight
predictor.mask_token
predictor.positional_encoding.0.bias
predictor.positional_encoding.0.weight
predictor.positional_encoding.2.bias
predictor.positional_encoding.2.weight
predictor.predictor.blocks.0.attn.proj.bias
predictor.predictor.blocks.0.attn.proj.weight
predictor.predictor.blocks.0.attn.qkv.bias
predictor.predictor.blocks.0.attn.qkv.weight
predictor.predictor.blocks.0.mlp.fc1.bias
predictor.predictor.blocks.0.mlp.fc1.weight
predictor.predictor.blocks.0.mlp.fc2.bias
predictor.predictor.blocks.0.mlp.fc2.weight
predictor.predictor.blocks.0.norm1.bias
predictor.predictor.blocks.0.norm1.weight
predictor.predictor.blocks.0.norm2.bias
predictor.predictor.blocks.0.norm2.weight
predictor.predictor.blocks.1.attn.proj.bias
predictor.predictor.blocks.1.attn.proj.weight
predictor.predictor.blocks.1.attn.qkv.bias
predictor.predictor.blocks.1.attn.q