In [7]:
import numpy as np
import torch
import torch.nn as nn
import math


In [2]:
class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding as in "Attention Is All You Need".
    Adds positional information to token embeddings.
    """
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        # Create matrix of shape (max_len, d_model)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)  # shape (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (seq_len, batch_size, d_model)
        Returns:
            x + positional encoding: same shape as input
        """
        seq_len = x.size(0)
        return x + self.pe[:seq_len]


In [3]:
class SpatialTransformerEncoder(nn.Module):
    """
    Captures spatial correlations between joints within each frame.
    Processes all frames in parallel for efficiency.
    """
    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        num_joints: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1
    ):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.pos_encoder = PositionalEncoding(d_model, max_len=num_joints)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch_size, num_frames, num_joints, d_model)
        Returns:
            y: Tensor of same shape as input
        """
        b, f, j, d = x.shape
        # Combine batch and frame dims: treat each frame as independent sequence
        x = x.view(b * f, j, d).transpose(0, 1)  # (seq_len=j, batch=b*f, d_model=d)
        x = self.pos_encoder(x)
        y = self.transformer(x)  # (j, b*f, d)
        y = y.transpose(0, 1).view(b, f, j, d)
        return y


In [4]:
class TemporalTransformerEncoder(nn.Module):
    """
    Models temporal dependencies across frames.
    """
    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        num_frames: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1
    ):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        self.pos_encoder = PositionalEncoding(d_model, max_len=num_frames)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch_size, num_frames, d_model)
        Returns:
            y: Tensor of same shape as input
        """
        # Transformer expects (seq_len, batch_size, d_model)
        x = x.transpose(0, 1)  # (num_frames, batch, d_model)
        x = self.pos_encoder(x)
        y = self.transformer(x)
        y = y.transpose(0, 1)  # (batch, num_frames, d_model)
        return y

In [5]:
class HierarchicalTransformer(nn.Module):
    """
    Hierarchical Transformer combining spatial and temporal encoders for exercise recognition.
    Input: X of shape (batch_size, num_frames, num_joints, 3)
    Output: logits over exercise classes
    """
    def __init__(
        self,
        num_joints: int,
        num_frames: int,
        d_model: int,
        nhead: int,
        num_spatial_layers: int,
        num_temporal_layers: int,
        num_classes: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1
    ):
        super().__init__()
        # Embed 3D coordinates into d_model dims
        self.embedding = nn.Linear(3, d_model)

        # Spatial transformer to capture joint correlations per frame
        self.spatial_encoder = SpatialTransformerEncoder(
            d_model=d_model,
            nhead=nhead,
            num_layers=num_spatial_layers,
            num_joints=num_joints,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )

        # Temporal transformer to capture motion across frames
        self.temporal_encoder = TemporalTransformerEncoder(
            d_model=d_model,
            nhead=nhead,
            num_layers=num_temporal_layers,
            num_frames=num_frames,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )

        # Classification head
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch_size, num_frames, num_joints, 3)
        Returns:
            logits: Tensor of shape (batch_size, num_classes)
        """
        # Embed coordinates
        x = self.embedding(x)  # (batch, F, J, d_model)

        # Spatial encoding
        x = self.spatial_encoder(x)  # (batch, F, J, d_model)

        # Aggregate joints per frame (mean pooling)
        x = x.mean(dim=2)  # (batch, F, d_model)

        # Temporal encoding
        x = self.temporal_encoder(x)  # (batch, F, d_model)

        # Global pooling across frames
        x = x.mean(dim=1)  # (batch, d_model)

        # Classification
        logits = self.classifier(x)  # (batch, num_classes)
        return logits

In [11]:
batch_size = 8
num_frames = 100
num_joints = 33
num_classes = 2
model = HierarchicalTransformer(
    num_joints=num_joints,
    num_frames=num_frames,
    d_model=128,
    nhead=8,
    num_spatial_layers=1,
    num_temporal_layers=1,
    num_classes=num_classes
)
# Dummy input: batch of sequences of 3D joint coords
x = torch.rand(batch_size, num_frames, num_joints, 3)


In [13]:
x.shape

torch.Size([8, 100, 33, 3])

In [14]:
logits = model(x)
print("Logits shape:", logits.shape)  # expected (8, 3)

Logits shape: torch.Size([8, 2])


In [24]:
# Test load
data = np.load('data/keypoints/deadlifts_squats.npz')
X = data['X'][:, :, :, :3]  # remove visibility => shape: (1146, 331, 33, 3)
y = data['y']
print(X.shape, y.shape)

(1146, 331, 33, 3) (1146,)


In [34]:
# Assume `model` is already loaded and on the correct device
x_sample = torch.tensor(X[1000]).unsqueeze(0).float()  # shape: (1, 331, 33, 3)
y_sample = y[0]

# Ensure model matches the input shape (e.g., num_frames=331)
model = HierarchicalTransformer(
    num_joints=33,
    num_frames=331,
    d_model=128,
    nhead=8,
    num_spatial_layers=1,
    num_temporal_layers=1,
    num_classes=len(np.unique(y))
)

model.eval()
with torch.no_grad():
    logits = model(x_sample)
    pred = torch.argmax(logits, dim=1).item()

print(f"True label: {y_sample}, Predicted: {pred}")

True label: 0, Predicted: 1


  x_sample = torch.tensor(X[1000]).unsqueeze(0).float()  # shape: (1, 331, 33, 3)


In [36]:
# Training
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

In [37]:
data = np.load('data/keypoints/deadlifts_squats.npz')
X = data['X'][:, :, :, :3]  # drop visibility
y = data['y']

# Convert to PyTorch tensors
X_tensor = torch.tensor(X).float()
y_tensor = torch.tensor(y).long()

# Split train/test (simple split)
from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(X_tensor, y_tensor, test_size=0.2, random_state=42)

# Create dataloaders
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

In [38]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = HierarchicalTransformer(
    num_joints=33,
    num_frames=331,
    d_model=128,
    nhead=8,
    num_spatial_layers=1,
    num_temporal_layers=1,
    num_classes=len(torch.unique(y_tensor))
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)



In [39]:
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    correct = 0

    for X_batch, y_batch in train_loader:
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)

        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * X_batch.size(0)
        correct += (outputs.argmax(1) == y_batch).sum().item()

    train_loss /= len(train_loader.dataset)
    train_acc = correct / len(train_loader.dataset)

    # Validation
    model.eval()
    val_loss = 0
    val_correct = 0
    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)

            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)

            val_loss += loss.item() * X_batch.size(0)
            val_correct += (outputs.argmax(1) == y_batch).sum().item()

    val_loss /= len(val_loader.dataset)
    val_acc = val_correct / len(val_loader.dataset)

    print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

Epoch 1/10 | Train Loss: 0.6812, Acc: 0.5841 | Val Loss: 0.5110, Acc: 0.9043
Epoch 2/10 | Train Loss: 0.2301, Acc: 0.9345 | Val Loss: 0.2265, Acc: 0.9174
Epoch 3/10 | Train Loss: 0.1384, Acc: 0.9487 | Val Loss: 0.2068, Acc: 0.9391
Epoch 4/10 | Train Loss: 0.1176, Acc: 0.9531 | Val Loss: 0.2206, Acc: 0.9174
Epoch 5/10 | Train Loss: 0.1198, Acc: 0.9498 | Val Loss: 0.1913, Acc: 0.9391
Epoch 6/10 | Train Loss: 0.0976, Acc: 0.9651 | Val Loss: 0.2661, Acc: 0.9043
Epoch 7/10 | Train Loss: 0.0865, Acc: 0.9683 | Val Loss: 0.1993, Acc: 0.9391
Epoch 8/10 | Train Loss: 0.0860, Acc: 0.9683 | Val Loss: 0.2831, Acc: 0.9043
Epoch 9/10 | Train Loss: 0.0750, Acc: 0.9771 | Val Loss: 0.2265, Acc: 0.9304
Epoch 10/10 | Train Loss: 0.0713, Acc: 0.9727 | Val Loss: 0.1894, Acc: 0.9391


In [40]:
# save
torch.save(model.state_dict(), 'hierarchical_transformer_weights.pth')