# Test the trained prototype model

In [1]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader


## Create a simple pipeline
This pipeline will:
- Load raw video data.
- Process the raw video data.
- Pass the processed video data into the model.
- Display the results.

In [2]:
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
import cv2 as cv

In [3]:
def resize(frame, target_size=(256, 256)):
    """Letterbox ``frame`` onto a black canvas, preserving its aspect ratio.

    The frame is scaled uniformly so that it fits inside ``target_size`` and is
    then centred on a zero-filled canvas.

    Args:
        frame: Image array of shape ``(H, W, 3)``.
        target_size: Output size as ``(width, height)`` in pixels, i.e. the
            same order as the ``dsize`` argument of ``cv.resize``.

    Returns:
        ``np.uint8`` array of shape ``(height, width, 3)``.

    Raises:
        ValueError: If ``frame`` has a zero-sized height or width.
    """
    target_w, target_h = target_size
    h, w = frame.shape[:2]
    if h == 0 or w == 0:
        raise ValueError("Cannot resize an empty frame")

    # Use one (width, height) convention throughout; mixing the two orders
    # only happens to work for square targets.
    scale = min(target_w / w, target_h / h)

    # Clamp so very thin frames never collapse to 0 px and float rounding
    # can never exceed the canvas.
    new_w = min(target_w, max(1, int(w * scale)))
    new_h = min(target_h, max(1, int(h * scale)))
    resized = cv.resize(frame, (new_w, new_h))

    # Create a blank canvas and paste the resized frame in its centre.
    canvas = np.zeros((target_h, target_w, 3), dtype=np.uint8)
    y_offset = (target_h - new_h) // 2
    x_offset = (target_w - new_w) // 2
    canvas[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = resized
    return canvas

In [4]:
def _pad_or_truncate(keypoints: np.ndarray, max_frames: int) -> np.ndarray:
    """Force the leading (frame) axis of ``keypoints`` to exactly ``max_frames``."""
    n_frames = keypoints.shape[0]
    if n_frames >= max_frames:
        return keypoints[:max_frames]
    pad = np.zeros((max_frames - n_frames,) + keypoints.shape[1:], dtype=keypoints.dtype)
    return np.concatenate((keypoints, pad), axis=0)


def load_and_extract(path: str, max_frames: int = 331) -> np.ndarray:
    """Run MediaPipe pose landmarking over a video and return fixed-length keypoints.

    Every frame is letterboxed with ``resize`` and passed to a
    ``PoseLandmarker`` in VIDEO running mode. For each frame the first detected
    pose contributes one ``[x, y, z, visibility]`` row per landmark; frames
    with no detection contribute an all-zero block.

    Args:
        path: Path of the video file to read.
        max_frames: Number of frames in the returned array. Shorter videos are
            zero-padded at the end; longer videos are truncated.

    Returns:
        Array of shape ``(max_frames, 33, 4)``.

    Raises:
        IOError: If the video cannot be opened.
        ValueError: If the video opens but yields no frames.
    """
    num_landmarks = 33  # MediaPipe pose has 33 landmarks
    num_features = 4    # x, y, z, visibility

    options = vision.PoseLandmarkerOptions(
        base_options=python.BaseOptions(
            model_asset_path=r"models/mediapipe/pose_landmarker_heavy.task",
            delegate=python.BaseOptions.Delegate.CPU,
        ),
        running_mode=vision.RunningMode.VIDEO,
    )

    cap = cv.VideoCapture(path)
    if not cap.isOpened():
        cap.release()
        raise IOError(f"Could not open video: {path}")

    keypoints = []
    try:
        with vision.PoseLandmarker.create_from_options(options) as landmarker:
            frame_idx = 0
            while True:
                ok, frame = cap.read()
                if not ok:
                    break

                rf = resize(frame)
                # NOTE(review): OpenCV decodes frames as BGR while the image is
                # tagged SRGB here. Left unchanged so the features stay
                # consistent with however the model was trained - confirm
                # against the training preprocessing before converting.
                mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rf)
                # NOTE(review): the frame index is used as the timestamp. It is
                # monotonically increasing, but it is not real milliseconds -
                # confirm this matches the training pipeline.
                detection_result = landmarker.detect_for_video(mp_image, frame_idx)
                pose_landmarks_list = detection_result.pose_landmarks

                if pose_landmarks_list:
                    # Only the first detected pose is used.
                    frame_keypoints = np.array(
                        [[lm.x, lm.y, lm.z, lm.visibility] for lm in pose_landmarks_list[0]],
                        dtype=np.float64,
                    )
                else:
                    # No landmarks detected: all-zero placeholder.
                    frame_keypoints = np.zeros((num_landmarks, num_features))

                keypoints.append(frame_keypoints)
                frame_idx += 1
    finally:
        # Always free the decoder handle, even if landmarking raises.
        cap.release()

    if not keypoints:
        raise ValueError(f"No frames could be read from video: {path}")

    # Pad or truncate to a fixed length so the output matches the model input.
    return _pad_or_truncate(np.array(keypoints), max_frames)
        

In [6]:
from hierarchical_transformer_prototype import HierarchicalTransformer

# These hyper-parameters have to agree with the saved checkpoint and with the
# preprocessing: 33 pose landmarks per frame, 331 frames per clip, and two
# output classes.
model = HierarchicalTransformer(
    num_joints=33,
    num_frames=331,
    d_model=128,
    nhead=8,
    num_spatial_layers=2,
    num_temporal_layers=2,
    num_classes=2
)
# map_location="cpu" lets a checkpoint that was saved from a GPU session load
# on a CPU-only machine; the model itself is never moved off the CPU here.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
model.load_state_dict(
    torch.load("hierarchical_transformer_weights.pth", map_location="cpu")
)

<All keys matched successfully>

In [16]:
# Extract the padded per-frame pose keypoints for one unseen clip.
video = load_and_extract('data/unseen/squat_neil_3.mp4')

# Keep only the first three values of each landmark (x, y, z) and drop the
# fourth (visibility).
x_sample = video[..., :3]
x_sample.shape

(331, 33, 3)

In [17]:
# Build the model input: float32 tensor with a leading batch axis of size 1.
x_tensor = torch.tensor(x_sample, dtype=torch.float32)[None]
x_tensor.shape

torch.Size([1, 331, 33, 3])

In [18]:
# Put the model in evaluation mode before running inference.
model.eval()

# Forward pass with gradient tracking switched off.
with torch.no_grad():
    logits = model(x_tensor)

# The index of the largest logit is the predicted label.
predicted_class = logits.argmax(dim=1).item()

# Label mapping - Squats: 0, Deadlifts: 1
print("Predicted class:", predicted_class)

Predicted class: 0
