# Test the trained prototype model

In [1]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader


## Create a simple pipeline
This will include:
- using raw video data.
- Processing the raw video data.
- Passing the processed video data into the model.
- Displaying the results.

In [2]:

from keypoint_extractor import KeypointExtractor, KeypointExtractorV2

In [3]:
def load_and_extract(path: str) -> np.ndarray:
    extractor = KeypointExtractorV2(r"models/mediapipe/pose_landmarker_heavy.task")
    keypoints = extractor.extract(path)
    
    max_frames = 331
    pad_len = max_frames - len(keypoints)
    if pad_len > 0:
        pad = np.zeros((pad_len, keypoints.shape[1], keypoints.shape[2]))  # Preserve all dimensions
        padded_sample = np.concatenate((keypoints, pad), axis=0)
    else:
        padded_sample = keypoints

    return np.array(padded_sample)
        

In [4]:
from hierarchical_transformer_prototype import HierarchicalTransformer

model = HierarchicalTransformer(
    num_joints=33,
    num_frames=331,
    d_model=128,
    nhead=8,
    num_spatial_layers=1,
    num_temporal_layers=1,
    num_classes=2
)
model.load_state_dict(torch.load("hierarchical_transformer_weights_2025-05-30.pth"))



<All keys matched successfully>

In [29]:
video = load_and_extract('data/unseen/squat_neil_2.mp4')
x_sample = video[:, :, :3]
x_sample.shape

Processing data/unseen/squat_neil_2.mp4: 534x720, 39 frames
Extracted and normalized 39 frames from data/unseen/squat_neil_2.mp4


(331, 33, 3)

In [30]:
# inference
x_tensor = torch.tensor(x_sample, dtype=torch.float32).unsqueeze(0)
x_tensor.shape

torch.Size([1, 331, 33, 3])

In [31]:
model.eval()
with torch.no_grad():
    logits = model(x_tensor)
    predicted_class = torch.argmax(logits, dim=1).item()

# Squats: 0 , Deadlifts: 1
print("Predicted class:", predicted_class)

Predicted class: 0


In [32]:
oldmodel = HierarchicalTransformer(
    num_joints=33,
    num_frames=331,
    d_model=128,
    nhead=8,
    num_spatial_layers=2,
    num_temporal_layers=2,
    num_classes=2
)
model.load_state_dict(torch.load("hierarchical_transformer_weights.pth"))

<All keys matched successfully>

In [33]:
oldmodel.eval()
with torch.no_grad():
    logits = oldmodel(x_tensor)
    predicted_class = torch.argmax(logits, dim=1).item()

# Squats: 0 , Deadlifts: 1
print("Predicted class:", predicted_class)

Predicted class: 1
