In [9]:
import torch
from core.models.hierarchical_transformer import HierarchicalTransformer
from core.keypoint_extractor import KeypointExtractorV2
from core.utils import process_sample
import numpy as np

In [10]:
# Checkpoint to evaluate. Its name appears to encode the architecture it was trained
# with (f201 / d64 / h2 / s1 / t1 / do0.1 line up with the constructor arguments used
# when building the model) — keep the two in sync when swapping checkpoints.
htformer_weights = "models/final/hierarchical_transformer_f201_d64_h2_s1_t1_do0.1_20250701_2251.pth"

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [11]:
# Build the hierarchical transformer with the hyperparameters the checkpoint was
# trained with. They must match the saved state dict: load_state_dict is strict by
# default and raises on missing, unexpected, or mis-shaped parameters.
htformer = HierarchicalTransformer(
    num_joints=33,          # NOTE(review): presumably the 33 MediaPipe pose landmarks — confirm
    num_frames=201,         # keep in sync with the max_frames used when padding samples for inference
    d_model=64,
    nhead=2,
    num_spatial_layers=1,
    num_temporal_layers=1,
    num_classes=3,          # squats / deadlifts / shoulder_press
    dim_feedforward=2048,   # not encoded in the checkpoint filename — verify if retraining
    dropout=0.1
).to(device)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
htformer.load_state_dict(torch.load(htformer_weights, map_location=device, weights_only=True))

<All keys matched successfully>

In [12]:
# Keypoint extractor backed by the MediaPipe "full" pose-landmarker task file;
# it turns a video file into per-frame pose keypoints for the classifier.
extractor = KeypointExtractorV2(
    model_path="models/mediapipe/pose_landmarker_full.task",
)

In [38]:
# Inference helper: video file -> predicted exercise label.
def infer_from_video(video_path, keypoint_extractor, model, *, max_frames=201,
                     labels=None, device=None):
    """Classify the exercise performed in a single video.

    Extracts keypoints with ``keypoint_extractor``, pads/masks them with
    ``process_sample``, and runs one batch element through ``model``.

    Args:
        video_path: Path to the video file to classify.
        keypoint_extractor: Object exposing ``extract(video_path)``
            (e.g. ``KeypointExtractorV2``).
        model: ``torch.nn.Module`` called as ``model(x, mask)``; must return
            class scores of shape ``(1, num_classes)`` for a batch of one.
        max_frames: Passed to ``process_sample`` as the target sequence length.
            Should equal the ``num_frames`` the model was built with
            (201 for the checkpoint used in this notebook).
        labels: Mapping from class index to label name. Defaults to
            ``{0: "squats", 1: "deadlifts", 2: "shoulder_press"}``.
        device: Device to run on. Defaults to the device of the model's first
            parameter, or CPU if the model has no parameters.

    Returns:
        The predicted label (also printed, for backward compatibility).

    Note:
        Switches ``model`` to eval mode and leaves it there.
    """
    if labels is None:
        labels = {0: "squats", 1: "deadlifts", 2: "shoulder_press"}
    if device is None:
        # Follow the model's placement rather than a notebook-global `device`.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # Inference mode: disables dropout (and other train-only behaviour).
    model.eval()

    keypoints = keypoint_extractor.extract(video_path)
    padded_sample, attention_mask, _ = process_sample(keypoints, max_frames=max_frames)

    # Keep the first 3 channels of the last axis — presumably x, y, z, dropping any
    # extra per-landmark fields (e.g. visibility). TODO confirm against process_sample.
    frames = np.array(padded_sample)[:, :, :3]
    mask = np.array(attention_mask)

    # Convert to tensors on the target device and add a leading batch dimension of 1.
    x = torch.tensor(frames, dtype=torch.float32, device=device).unsqueeze(0)
    mask_t = torch.tensor(mask, dtype=torch.float32, device=device).unsqueeze(0)

    with torch.no_grad():
        scores = model(x, mask_t)
    pred_idx = int(scores.argmax(dim=1).item())
    label = labels[pred_idx]
    print("Predicted class:", label)
    return label

In [42]:
# Run the full pipeline on a clip from data/unseen/ (presumably held out from training).
infer_from_video(
    model=htformer,
    keypoint_extractor=extractor,
    video_path="data/unseen/squat_neil_3.mp4",
)

Processing data/unseen/squat_neil_3.mp4: 534x720, 31 frames
Extracted and normalized 30 frames from data/unseen/squat_neil_3.mp4
Predicted class: squats
