# 🙌 ModalX v2 - Gesture ST-GCN Training

**All fixes applied:**
- ✅ Fixed BatchNorm dimensions
- ✅ No external data needed (trains on a synthetic dataset)
- ✅ Local save with auto-download

In [None]:
# Notebook shell escape (IPython `!`): installs the packages used by this notebook.
# NOTE(review): torchvision is never imported in the visible cells, and matplotlib
# (imported as plt) is not used in them either — confirm before trimming this list.
!pip install -q torch torchvision tqdm matplotlib

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device: {device}')

# Checkpoint directory. Defaults to the Colab working area; set the
# MODALX_SAVE_DIR environment variable to run this notebook elsewhere
# (a hard-coded /content path fails where that directory is not writable).
SAVE_DIR = os.environ.get('MODALX_SAVE_DIR', '/content/modalx_weights')
os.makedirs(SAVE_DIR, exist_ok=True)

## ST-GCN Model (Fixed)

In [None]:
class SpatialGraphConv(nn.Module):
    """Graph convolution over joints with a fully learnable adjacency matrix.

    The raw ``(num_joints, num_joints)`` matrix ``A`` is softmax-normalised
    along its last axis and used to mix joint features. A 1x1 convolution
    then maps ``in_ch -> out_ch`` channels, followed by BatchNorm and ReLU.
    Input and output use the ``(B, C, T, V)`` layout.
    """

    def __init__(self, in_ch, out_ch, num_joints):
        super().__init__()
        # Small random init: softmax(A) starts out close to uniform mixing.
        self.A = nn.Parameter(0.01 * torch.randn(num_joints, num_joints))
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        weights = self.A.softmax(dim=-1)
        # out[..., w] = sum_v x[..., v] * weights[v, w]  (broadcast over B, C, T)
        mixed = torch.matmul(x, weights)
        projected = self.conv(mixed)
        return F.relu(self.bn(projected))


class TemporalConv(nn.Module):
    """Temporal convolution (kernel ``kernel_size x 1``) + BatchNorm + ReLU.

    Convolves along the time axis of a ``(B, C, T, V)`` tensor only, with
    ``(kernel_size - 1) // 2`` zero padding on each side of time and an
    optional temporal ``stride``; the joint axis is left untouched.
    """

    def __init__(self, in_ch, out_ch, kernel_size=9, stride=1):
        super().__init__()
        half_window = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=(kernel_size, 1),
            stride=(stride, 1),
            padding=(half_window, 0),
        )
        self.bn = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return F.relu(out)


class STGCNBlock(nn.Module):
    """One spatial-temporal graph block with a residual shortcut.

    ``spatial`` (SpatialGraphConv, in_ch -> out_ch) feeds ``temporal``
    (TemporalConv, out_ch -> out_ch with the given temporal stride). The
    shortcut is the identity when channels and stride already match, and
    otherwise a strided 1x1 conv + BatchNorm. The sum goes through a ReLU.
    """

    def __init__(self, in_ch, out_ch, num_joints, stride=1):
        super().__init__()
        self.spatial = SpatialGraphConv(in_ch, out_ch, num_joints)
        self.temporal = TemporalConv(out_ch, out_ch, stride=stride)

        shapes_match = in_ch == out_ch and stride == 1
        if shapes_match:
            self.residual = nn.Identity()
        else:
            # Project channels / subsample time so the shortcut can be added.
            self.residual = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=(stride, 1)),
                nn.BatchNorm2d(out_ch),
            )

    def forward(self, x):
        main = self.temporal(self.spatial(x))
        shortcut = self.residual(x)
        return F.relu(main + shortcut)


class GestureSTGCN(nn.Module):
    """Spatio-temporal graph convolutional classifier for pose sequences.

    The input is normalised per (channel, joint), then passed through six
    STGCNBlock layers that widen channels ``in_ch -> hidden -> 2*hidden ->
    4*hidden`` (layer3 and layer5 use temporal stride 2). Features are
    averaged over time and joints and mapped to ``num_classes`` logits.

    Input:  ``(B, C, T, V)`` = (batch, channels, time, joints).
    Output: ``(B, num_classes)`` logits.
    """

    GESTURES = ['neutral', 'open_palm', 'pointing', 'counting', 'steepling',
                'arms_crossed', 'fidgeting', 'hand_on_face', 'power_pose', 'shrug']

    def __init__(self, in_ch=3, num_joints=33, num_classes=10, hidden=64):
        super().__init__()
        self.num_joints = num_joints

        # Normalises each (channel, joint) feature; frames are treated as samples.
        self.data_bn = nn.BatchNorm1d(in_ch * num_joints)

        narrow, mid, wide = hidden, 2 * hidden, 4 * hidden
        self.layer1 = STGCNBlock(in_ch, narrow, num_joints)
        self.layer2 = STGCNBlock(narrow, narrow, num_joints)
        self.layer3 = STGCNBlock(narrow, mid, num_joints, stride=2)
        self.layer4 = STGCNBlock(mid, mid, num_joints)
        self.layer5 = STGCNBlock(mid, wide, num_joints, stride=2)
        self.layer6 = STGCNBlock(wide, wide, num_joints)

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(wide, num_classes)

    def forward(self, x):
        batch, channels, frames, joints = x.shape

        # (B, C, T, V) -> (B*T, C*V) for data_bn, then back to (B, C, T, V).
        rows = x.transpose(1, 2).reshape(batch * frames, channels * joints)
        rows = self.data_bn(rows)
        x = rows.reshape(batch, frames, channels, joints).transpose(1, 2)

        blocks = (self.layer1, self.layer2, self.layer3,
                  self.layer4, self.layer5, self.layer6)
        for block in blocks:
            x = block(x)

        # Average over (T, V) -> (B, 4*hidden, 1, 1), drop singleton dims, classify.
        return self.fc(self.pool(x).flatten(1))

# Build the classifier with its default configuration and report its size.
model = GestureSTGCN().to(device)
param_count = sum(param.numel() for param in model.parameters())
print(f'Parameters: {param_count:,}')

## Synthetic Dataset

In [None]:
class SyntheticGestureDataset(Dataset):
    """Synthetic pose sequences whose labels can actually be learned.

    Each class ``c`` owns a fixed random template of shape
    ``(in_ch, seq_len, num_joints)``. A sample of class ``c`` is
    ``signal_scale * template[c]`` plus Gaussian noise with std ``noise_std``.
    Labels must depend on the inputs: with labels drawn independently of
    pure noise, no model can beat chance and "best" checkpoints are picked at
    random. Templates come from a private generator seeded with
    ``template_seed``, so separately built datasets (e.g. train and val) with
    the same geometry share class templates and validation accuracy measures
    generalisation.

    Items are ``(x, y)``, where ``x`` has shape ``(in_ch, seq_len, num_joints)``
    (the model's (C, T, V) layout) and ``y`` is an int64 class index.

    Args:
        num_samples: Number of sequences.
        seq_len: Frames per sequence (T).
        num_joints: Joints per frame (V).
        in_ch: Channels per joint (C).
        num_classes: Number of classes; labels lie in ``[0, num_classes)``.
        noise_std: Standard deviation of the additive Gaussian noise.
        signal_scale: Multiplier applied to the class template.
        seed: Seed for labels and noise; ``None`` uses torch's global RNG.
        template_seed: Seed for the class templates (shared across instances).

    Raises:
        ValueError: If ``num_samples`` is negative or ``num_classes`` < 1.
    """

    def __init__(self, num_samples=1000, seq_len=30, num_joints=33, in_ch=3,
                 num_classes=10, noise_std=0.1, signal_scale=0.1, seed=None,
                 template_seed=0):
        if num_samples < 0:
            raise ValueError(f'num_samples must be >= 0, got {num_samples}')
        if num_classes < 1:
            raise ValueError(f'num_classes must be >= 1, got {num_classes}')

        # Shape: (N, C, T, V) = (samples, channels, time, vertices)
        sample_shape = (in_ch, seq_len, num_joints)

        # A private generator keeps templates identical across instances and
        # leaves the global RNG untouched.
        template_gen = torch.Generator().manual_seed(template_seed)
        self.templates = torch.randn(num_classes, *sample_shape, generator=template_gen)

        sample_gen = None if seed is None else torch.Generator().manual_seed(seed)
        self.labels = torch.randint(0, num_classes, (num_samples,), generator=sample_gen)
        noise = torch.randn(num_samples, *sample_shape, generator=sample_gen) * noise_std
        self.data = noise + signal_scale * self.templates[self.labels]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# 800 training / 200 validation sequences, served in mini-batches of 32;
# only the training loader reshuffles each epoch.
train_dataset, val_dataset = SyntheticGestureDataset(800), SyntheticGestureDataset(200)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)
val_loader = DataLoader(val_dataset, batch_size=32)

print(f'Train: {len(train_dataset)}, Val: {len(val_dataset)}')

## Training

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
NUM_EPOCHS = 20


def _evaluate(net, loader, dev):
    """Return top-1 accuracy of ``net`` over ``loader`` (0.0 for an empty loader)."""
    net.eval()
    correct = total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(dev), targets.to(dev)
            correct += (net(inputs).argmax(1) == targets).sum().item()
            total += targets.size(0)
    # Avoid ZeroDivisionError when the validation set is empty.
    return correct / total if total else 0.0


# Start below any reachable accuracy so the first epoch always writes a
# checkpoint. Starting at 0 means a run whose validation accuracy never
# exceeds 0.0 saves nothing, and the download cell would then fail.
best_acc = -1.0
for epoch in range(NUM_EPOCHS):
    model.train()
    train_loss = 0.0
    for x, y in tqdm(train_loader, desc=f'Epoch {epoch+1}'):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    val_acc = _evaluate(model, val_loader, device)
    # Mean per-batch training loss; max() guards against an empty loader.
    mean_loss = train_loss / max(len(train_loader), 1)
    print(f'Epoch {epoch+1}: Loss={mean_loss:.4f}, Val Acc={val_acc:.4f}')

    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), os.path.join(SAVE_DIR, 'gesture_stgcn.pt'))
        print('  Saved!')

print(f'Best acc: {best_acc:.4f}')

## Download

In [None]:
# Colab-only convenience: push the checkpoint to the browser. Outside Colab
# the google.colab package is absent, so only report where the file lives
# instead of crashing with ImportError.
checkpoint_file = f'{SAVE_DIR}/gesture_stgcn.pt'
try:
    from google.colab import files
except ImportError:
    files = None

if not os.path.exists(checkpoint_file):
    print(f'No checkpoint at {checkpoint_file}; run the training cell first.')
elif files is None:
    print(f'Not running in Colab; checkpoint saved at {checkpoint_file}')
else:
    files.download(checkpoint_file)
print('Put gesture_stgcn.pt in modalx_v2/weights/')