# Test the trained prototype model

In [1]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader


## Create a simple pipeline
This will include:
- using raw video data.
- Processing the raw video data.
- Passing the processed video data into the model.
- Displaying the results.

In [2]:

from keypoint_extractor import KeypointExtractor, KeypointExtractorV2

In [17]:
def load_and_extract(path: str) -> np.ndarray:
    extractor = KeypointExtractorV2(r"models/mediapipe/pose_landmarker_full.task")
    keypoints = extractor.extract(path)
    
    max_frames = 331
    pad_len = max_frames - len(keypoints)
    if pad_len > 0:
        pad = np.zeros((pad_len, keypoints.shape[1], keypoints.shape[2]))  # Preserve all dimensions
        padded_sample = np.concatenate((keypoints, pad), axis=0)
    else:
        padded_sample = keypoints

    return np.array(padded_sample)
        

In [4]:
from hierarchical_transformer_prototype import HierarchicalTransformer

model = HierarchicalTransformer(
    num_joints=33,
    num_frames=331,
    d_model=128,
    nhead=8,
    num_spatial_layers=1,
    num_temporal_layers=1,
    num_classes=2
)
model.load_state_dict(torch.load("hierarchical_transformer_weights_2025-05-30.pth"))



<All keys matched successfully>

In [18]:
video = load_and_extract('data/raw/deadlifts/deadlift_42_rep_1.mp4')
x_sample = video[:, :, :3]
x_sample.shape

Processing data/raw/deadlifts/deadlift_42_rep_1.mp4: 360x430, 183 frames
Extracted and normalized 182 frames from data/raw/deadlifts/deadlift_42_rep_1.mp4


(331, 33, 3)

In [20]:
# inference
x_tensor = torch.tensor(x_sample, dtype=torch.float32).unsqueeze(0)
x_tensor.shape

torch.Size([1, 331, 33, 3])

In [21]:
model.eval()
with torch.no_grad():
    logits = model(x_tensor)
    predicted_class = torch.argmax(logits, dim=1).item()

# Squats: 0 , Deadlifts: 1
print("Predicted class:", predicted_class)

Predicted class: 1


In [23]:
oldmodel = HierarchicalTransformer(
    num_joints=33,
    num_frames=331,
    d_model=128,
    nhead=8,
    num_spatial_layers=2,
    num_temporal_layers=2,
    num_classes=2
)
model.load_state_dict(torch.load("hierarchical_transformer_weights.pth"))

RuntimeError: Error(s) in loading state_dict for HierarchicalTransformer:
	Unexpected key(s) in state_dict: "spatial_encoder.transformer.layers.1.self_attn.in_proj_weight", "spatial_encoder.transformer.layers.1.self_attn.in_proj_bias", "spatial_encoder.transformer.layers.1.self_attn.out_proj.weight", "spatial_encoder.transformer.layers.1.self_attn.out_proj.bias", "spatial_encoder.transformer.layers.1.linear1.weight", "spatial_encoder.transformer.layers.1.linear1.bias", "spatial_encoder.transformer.layers.1.linear2.weight", "spatial_encoder.transformer.layers.1.linear2.bias", "spatial_encoder.transformer.layers.1.norm1.weight", "spatial_encoder.transformer.layers.1.norm1.bias", "spatial_encoder.transformer.layers.1.norm2.weight", "spatial_encoder.transformer.layers.1.norm2.bias", "temporal_encoder.transformer.layers.1.self_attn.in_proj_weight", "temporal_encoder.transformer.layers.1.self_attn.in_proj_bias", "temporal_encoder.transformer.layers.1.self_attn.out_proj.weight", "temporal_encoder.transformer.layers.1.self_attn.out_proj.bias", "temporal_encoder.transformer.layers.1.linear1.weight", "temporal_encoder.transformer.layers.1.linear1.bias", "temporal_encoder.transformer.layers.1.linear2.weight", "temporal_encoder.transformer.layers.1.linear2.bias", "temporal_encoder.transformer.layers.1.norm1.weight", "temporal_encoder.transformer.layers.1.norm1.bias", "temporal_encoder.transformer.layers.1.norm2.weight", "temporal_encoder.transformer.layers.1.norm2.bias". 

In [33]:
oldmodel.eval()
with torch.no_grad():
    logits = oldmodel(x_tensor)
    predicted_class = torch.argmax(logits, dim=1).item()

# Squats: 0 , Deadlifts: 1
print("Predicted class:", predicted_class)

Predicted class: 1
