In [1]:
import sys
import os

# The notebook runs from inside training/, two directories below the project
# root; put that root on sys.path so the project's `core` package is importable.
_cwd = os.getcwd()
project_root = os.path.abspath(os.path.join(_cwd, os.pardir, os.pardir))
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
import numpy as np
import torch
import torch.nn as nn
import gc
import random
from core.models.hierarchical_transformer import HierarchicalTransformer
from core.utils import create_transformer_dataset, TransformerLRScheduler
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import precision_recall_fscore_support


In [3]:
def initial_cleanup(seed: int = 69) -> None:
    """Free cached memory and seed every RNG this notebook relies on.

    Args:
        seed: Seed applied to Python's ``random``, NumPy's global RNG and
            PyTorch (CPU, plus every CUDA device when CUDA is available).
            Defaults to 69, the value this function previously hard-coded,
            so existing ``initial_cleanup()`` calls behave exactly as before.

    Side effects:
        - Empties the CUDA caching allocator and collects CUDA IPC handles
          (only when CUDA is available), then runs Python's garbage collector.
        - Forces cuDNN into deterministic, non-benchmarked mode.
        - Prints a confirmation line.
    """
    # Memory cleanup: release cached GPU blocks, then unreachable Python objects
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    gc.collect()

    # Seed every RNG source so reruns of the notebook are reproducible
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    print("Clean slate initialized!")

# Call this at the top of your notebook
initial_cleanup()

Clean slate initialized!


In [4]:
# Load the keypoint sequences with their labels, attention masks and raw lengths.
# NOTE(review): the recorded output below still prints one "Filtering out
# sample ..." line per sample despite verbose=False — check the helper.
(
    X_np,
    y_np,
    attention_masks_np,
    sequence_lengths_np,
) = create_transformer_dataset(
    data_dir="../../data/keypoints",
    verbose=False,
)

Using max_frames = 200 (95.0th percentile)
Sequence length stats - Min: 0, Max: 404
Filtering out sample with length 236
Filtering out sample with length 236
Filtering out sample with length 267
Filtering out sample with length 267
Filtering out sample with length 237
Filtering out sample with length 237
Filtering out sample with length 248
Filtering out sample with length 248
Filtering out sample with length 207
Filtering out sample with length 207
Filtering out sample with length 219
Filtering out sample with length 219
Filtering out sample with length 280
Filtering out sample with length 280
Filtering out sample with length 249
Filtering out sample with length 249
Filtering out sample with length 247
Filtering out sample with length 247
Filtering out sample with length 210
Filtering out sample with length 210
Filtering out sample with length 215
Filtering out sample with length 215
Filtering out sample with length 283
Filtering out sample with length 283
Filtering out sample with le

In [5]:
# Keep only the first three per-joint channels -> shape (N, F, J, 3).
# NOTE(review): presumably the x, y, z coordinates, dropping any extra channel
# (e.g. a visibility score) — confirm against create_transformer_dataset.
# Re-running this cell is harmless: slicing a 3-channel array to :3 is a no-op.
_COORD_CHANNELS = slice(None, 3)
X_np = X_np[:, :, :, _COORD_CHANNELS]

In [6]:
from sklearn.model_selection import StratifiedKFold

In [7]:
# Model hyper-parameters, passed verbatim as HierarchicalTransformer(**parameters).
parameters = {
    "num_joints": 33,
    # NOTE(review): the dataset loader reports max_frames = 200; 201 presumably
    # leaves room for one extra token (e.g. a CLS token) — confirm in the model.
    "num_frames": 201,
    "d_model": 64,
    "nhead": 2,
    "num_spatial_layers": 1,
    "num_temporal_layers": 1,
    "num_classes": 3,
    "dim_feedforward": 2048,
    "dropout": 0.1,
}

# Training schedule
batch_size, epochs = 16, 64
k = 5  # number of cross-validation folds

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Stratified folds preserve class proportions; the fixed random_state keeps splits reproducible.
skf = StratifiedKFold(n_splits=k, random_state=69, shuffle=True)

# Per-fold accumulators, filled by the training loop.
fold_results = []
precision_scores = []
recall_scores = []
f1_scores = []

In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, scheduler, device):
    """Run one training pass over ``loader``.

    The LR scheduler is stepped after every optimiser step (once per batch),
    consistent with ``warmup_steps`` being computed from ``len(train_loader)``.

    Returns:
        ``(mean_loss, accuracy)`` averaged over every sample in ``loader.dataset``.
    """
    model.train()
    total_loss = 0.0
    correct = 0

    for X_batch, y_batch, mask_batch in loader:
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)
        mask_batch = mask_batch.to(device)

        optimizer.zero_grad()
        outputs = model(x=X_batch, temporal_mask=mask_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
        scheduler.step()

        # loss.item() is a batch mean; re-weight by batch size for a per-sample average
        total_loss += loss.item() * X_batch.size(0)
        correct += (outputs.argmax(1) == y_batch).sum().item()

    n_samples = len(loader.dataset)
    return total_loss / n_samples, correct / n_samples


def _evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` with gradients disabled.

    Returns:
        ``(mean_loss, accuracy, preds, labels)``. ``preds`` and ``labels`` are
        flat lists of class indices in loader order.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    preds_all = []
    labels_all = []

    with torch.no_grad():
        for X_batch, y_batch, mask_batch in loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)
            mask_batch = mask_batch.to(device)

            outputs = model(x=X_batch, temporal_mask=mask_batch)
            loss = criterion(outputs, y_batch)

            total_loss += loss.item() * X_batch.size(0)
            preds = outputs.argmax(1)
            correct += (preds == y_batch).sum().item()

            preds_all.extend(preds.cpu().numpy())
            labels_all.extend(y_batch.cpu().numpy())

    n_samples = len(loader.dataset)
    return total_loss / n_samples, correct / n_samples, preds_all, labels_all


for fold, (train_idx, val_idx) in enumerate(skf.split(X_np, y_np)):
    print(f"\n========== Fold {fold+1}/{k} ==========")

    # Split the data
    X_train_fold, X_val_fold = X_np[train_idx], X_np[val_idx]
    y_train_fold, y_val_fold = y_np[train_idx], y_np[val_idx]
    mask_train_fold, mask_val_fold = attention_masks_np[train_idx], attention_masks_np[val_idx]

    # Convert to tensors
    X_train_tensor = torch.from_numpy(X_train_fold).float()
    y_train_tensor = torch.from_numpy(y_train_fold).long()
    mask_train_tensor = torch.from_numpy(mask_train_fold).float()

    X_val_tensor = torch.from_numpy(X_val_fold).float()
    y_val_tensor = torch.from_numpy(y_val_fold).long()
    mask_val_tensor = torch.from_numpy(mask_val_fold).float()

    # Create datasets and loaders
    train_dataset = TensorDataset(X_train_tensor, y_train_tensor, mask_train_tensor)
    val_dataset = TensorDataset(X_val_tensor, y_val_tensor, mask_val_tensor)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Fresh model / optimiser / scheduler per fold so folds stay independent
    model = HierarchicalTransformer(**parameters).to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
    optimizer = optim.AdamW(model.parameters(), lr=5e-3, betas=(0.9, 0.98), eps=1e-9, weight_decay=1e-4)
    num_steps_per_epoch = len(train_loader)
    # Warm up over the first 10% of all optimiser steps
    warmup_steps = int(0.1 * num_steps_per_epoch * epochs)
    scheduler = TransformerLRScheduler(optimizer, d_model=parameters['d_model'], warmup_steps=warmup_steps)

    # Early-stopping state. The epoch with the lowest validation loss (its
    # weights, predictions and metrics) is remembered, so the fold is scored on
    # that model rather than on the last epoch, which early stopping only
    # reaches after `patience` epochs without improvement.
    patience = 5
    epochs_no_improve = 0
    best_val_loss = float('inf')
    best_epoch = None
    best_val_acc = None
    best_preds, best_labels = [], []
    best_state = None

    val_loss_history = []
    val_acc_history = []

    for epoch in range(epochs):
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, scheduler, device
        )
        val_loss, val_acc, all_preds, all_labels = _evaluate(model, val_loader, criterion, device)

        val_loss_history.append(val_loss)
        val_acc_history.append(val_acc)

        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

        # `best_epoch is None` ensures the first epoch is always recorded
        if best_epoch is None or val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            best_val_acc = val_acc
            best_preds, best_labels = all_preds, all_labels
            # Clone tensors so later optimiser steps cannot overwrite the snapshot in place
            best_state = {
                name: value.detach().clone() if torch.is_tensor(value) else value
                for name, value in model.state_dict().items()
            }
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f"Early stopping triggered at epoch {epoch+1}.")
            break

    if best_epoch is None:
        # Only reachable when epochs == 0: nothing was trained or evaluated
        print("No epochs were run for this fold; skipping metrics.")
        continue

    # Restore the best checkpoint so `model` matches the metrics reported below
    model.load_state_dict(best_state)
    print(f"Best epoch: {best_epoch+1} | Val Loss: {best_val_loss:.4f}, Acc: {best_val_acc:.4f}")

    # NOTE: the same validation fold both selects the best epoch and scores it,
    # so these numbers are somewhat optimistic; a separate held-out split removes that bias.
    precision, recall, f1, _ = precision_recall_fscore_support(
        best_labels, best_preds, average='macro', zero_division=0
    )

    precision_scores.append(precision)
    recall_scores.append(recall)
    f1_scores.append(f1)

    fold_results.append({
        'val_acc': best_val_acc,
        'val_loss': best_val_loss,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'best_epoch': best_epoch + 1,
    })

# --- Final Summary ---
val_accs = [res['val_acc'] for res in fold_results]
val_losses = [res['val_loss'] for res in fold_results]

print("\n========== K-Fold Cross-Validation Summary ==========")
print(f"Average Validation Accuracy: {np.mean(val_accs):.4f} ± {np.std(val_accs):.4f}")
print(f"Average Validation Loss:     {np.mean(val_losses):.4f} ± {np.std(val_losses):.4f}")
print(f"Average Precision:           {np.mean(precision_scores):.4f} ± {np.std(precision_scores):.4f}")
print(f"Average Recall:              {np.mean(recall_scores):.4f} ± {np.std(recall_scores):.4f}")
print(f"Average F1 Score:            {np.mean(f1_scores):.4f} ± {np.std(f1_scores):.4f}")
print("=====================================================")


