This script trains and runs a transformer model for classifying basketball activity.<br>
Users must manually provide the parquet data files, in which each row is a separate basketball event that tracks player joints over a 21-frame window. Note that this script takes in flattened data values.<br>
For example, when analyzing a combination of 10 joints, the pose_data column of each row contains a single flattened vector of 21 frames * 10 joints * 3 (x, y, z) coordinates = 630 values. This data needs to be unflattened before processing so that each coordinate can be mapped to the correct joint.

*The input dataframe has the following columns of interest*:

 ['game_id', 'stadium_id', 'player_id', 'team_id',
'event_seq_id', 'event_type', 'ball_position', 'player_com', 'pose_data', 'label']

# Libraries and Datasets

In [None]:
# Imports
import torch
import torch.nn as nn
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# import data files
# add file path location of the parquet files storing the joint data per event type.

dribble_file = ''
pass_file = ''
shot_file = ''
rebound_file = ''

# NOTE update this variable for how many joints we are looking at
num_joints = 7
# Load parquet files (one per event type)
dribble_df = pd.read_parquet(dribble_file)
passing_df = pd.read_parquet(pass_file)
shooting_df = pd.read_parquet(shot_file)
rebound_df = pd.read_parquet(rebound_file)

# Combine data from all event types into one dataframe
all_data = pd.concat([dribble_df, passing_df, shooting_df, rebound_df])

# Define the expected pose vector length NOTE this changes if I'm using 5 joints vs 6 vs 7 etc
# 21 frame pose window * number of joints * 3 (x, y, z coordinates)
expected_pose_length = 21 * num_joints * 3

# Filter out rows with inconsistent pose_data lengths (This removes incomplete pose windows).
# .copy() makes the filtered frame independent, so the 'label' assignment below writes to
# a real dataframe instead of a slice (avoids pandas' chained-assignment warning).
all_data = all_data[
    all_data['pose_data'].apply(
        lambda x: isinstance(x, (list, np.ndarray)) and len(x) == expected_pose_length
    )
].copy()

# Map event_type to integer labels
event_type_mapping = {'dribble': 0, 'pass': 1, 'shot': 2, 'rebound': 3}
all_data['label'] = all_data['event_type'].map(event_type_mapping)

# Fail fast on event types missing from the mapping: they would become NaN labels,
# which turn into meaningless integers once cast to torch.long further down.
unmapped_events = all_data.loc[all_data['label'].isna(), 'event_type'].unique()
if len(unmapped_events) > 0:
    raise ValueError(
        f"Found event_type values without a label mapping: {list(unmapped_events)}"
    )

# Per-row lengths of each flattened column; not used below, kept for ad-hoc inspection.
shapes = all_data['pose_data'].apply(lambda x: len(x) if isinstance(x, (list, np.ndarray)) else None)

com_shapes = all_data['player_com'].apply(lambda x: len(x) if isinstance(x, (list, np.ndarray)) else None)

ball_shapes = all_data['ball_position'].apply(lambda x: len(x) if isinstance(x, (list, np.ndarray)) else None)

# Convert pose_data, ball_position, and player_com to tensors.
# NOTE(review): only pose_data is length-filtered above; player_com and ball_position are
# presumably flattened 21 x 3 rows of consistent length -- confirm against the parquet schema.
pose_data = torch.tensor(all_data['pose_data'].tolist(), dtype=torch.float32)  # Shape: [num_samples, 21 * num_joints * 3]
ball_data = torch.tensor(all_data['ball_position'].tolist(), dtype=torch.float32)  # Not used in the combined features below
com_data = torch.tensor(all_data['player_com'].tolist(), dtype=torch.float32)

sequence_length = 21 # we have 21 frame windows
num_features = 3 # x,y,z

# We have to unflatten data here before concatenating to put into correct shape
pose_data = pose_data.view(-1, sequence_length, num_joints * num_features)  # [num_samples, 21, num_joints * 3]
com_data = com_data.view(-1, sequence_length, num_features)  # [num_samples, 21, 3]

#combine pose and com data
combined_data = torch.cat((pose_data, com_data), dim=2)  # Shape: [num_samples, sequence_length, num_joints x features + 3]
labels = torch.tensor(all_data['label'].values, dtype=torch.long)  # Shape: [num_samples]

# Model Architecture


In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

# Define the transformer model
class PoseTransformer(nn.Module):
    """Transformer classifier over a window of per-frame pose features.

    Components:
        * An embedding layer mapping each frame's features to ``d_model``.
        * A learnable positional encoding (one vector per frame position).
        * A learnable classification token. It is prepended to the encoder
          input and is also used as the single-token decoder target.
        * An ``nn.Transformer`` with ``num_layers`` encoder layers and one
          decoder layer.
        * A linear classification head applied to the decoder output.

    Args:
        input_dim (int): Number of features per frame.
        sequence_length (int): Number of positional-encoding slots, i.e. the
            longest window the model can take.
        num_joints (int): Not used by the model; kept so existing callers
            keep working.
        num_features (int): Not used by the model; kept so existing callers
            keep working.
        d_model (int): Transformer embedding size.
        nhead (int): Number of attention heads.
        num_layers (int): Number of encoder layers.
        num_classes (int): Number of output classes.
        dropout (float): Dropout used inside the transformer.
    """

    def __init__(self, input_dim, sequence_length, num_joints, num_features, d_model, nhead, num_layers, num_classes, dropout=0.1):
        super().__init__()
        self.embedding = nn.Linear(input_dim, d_model)  # Map input to transformer embedding size
        self.positional_encoding = nn.Parameter(torch.zeros(sequence_length, d_model))  # Learnable positional encoding
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))  # Learnable classification token.
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=1,
            dropout=dropout,
            batch_first=True  # Input shape: [Batch Size, Sequence Length, Features]
        )
        self.classifier = nn.Linear(d_model, num_classes)  # Classification head

    def forward(self, x):
        """Classify a batch of pose windows.

        Args:
            x (torch.Tensor): Input of shape [Batch Size, Sequence Length, Input Dim].
                Sequence Length may be shorter than the ``sequence_length`` the
                model was built with.

        Returns:
            torch.Tensor: Class logits of shape [Batch Size, num_classes].
        """
        batch_size, seq_len, _ = x.shape

        # Embed the input features: [Batch Size, Sequence Length, d_model]
        x = self.embedding(x)

        # Add positional encoding. Slicing to seq_len lets shorter windows reuse the
        # leading positions; broadcasting adds the same encoding to every batch item.
        x = x + self.positional_encoding[:seq_len]

        # Prepare the classification token: [Batch Size, 1, d_model]
        cls_token = self.cls_token.expand(batch_size, -1, -1)

        # Concatenate cls_token to the beginning of the source sequence
        src = torch.cat([cls_token, x], dim=1)

        # Single-token decoder target: the decoder (1 layer) attends over the encoded
        # source and produces one output vector per sample.
        tgt = cls_token

        # Pass through Transformer; output has the target's shape: [Batch Size, 1, d_model]
        transformer_output = self.transformer(src=src, tgt=tgt)

        # Use the (only) decoder output position for classification
        cls_output = transformer_output[:, 0, :]

        # Classification head
        return self.classifier(cls_output)



The cell below defines the Gaussian-noise helper that is used when balancing the dataset by augmenting the smaller classes

In [6]:
def add_noise(data, noise_std=0.01):
    """
    Return a jittered copy of the input (the input itself is not modified).

    Args:
    data (torch.Tensor): Tensor to perturb
    noise_std (float): Scale applied to standard-normal samples, i.e. the
        standard deviation of the added zero-mean Gaussian noise

    Returns:
    torch.Tensor: New tensor equal to data plus element-wise noise
    """
    return data + torch.randn_like(data) * noise_std

# Defining Model Augmentation Functions and KFold Process

In [None]:
from torch.utils.data import DataLoader, TensorDataset, Subset
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
import seaborn as sns


### Save and Load Data ###
def save_checkpoint(model, optimizer, epoch, filename="best_model.pth"):
    """Write model weights, optimizer state and the epoch number to ``filename``.

    Args:
    model: Object exposing ``state_dict()`` (the network to save)
    optimizer: Object exposing ``state_dict()`` (the optimizer to save)
    epoch: Epoch value stored alongside the two state dicts
    filename (str): Destination path passed to ``torch.save``
    """
    torch.save(
        {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
        },
        filename,
    )
    print(f"Model saved as {filename}")

def load_checkpoint(model, optimizer, filename="best_model.pth", device='cuda'):
    """Load a checkpoint into ``model`` (and, if given, ``optimizer``).

    The checkpoint must contain the keys 'model_state_dict',
    'optimizer_state_dict' and 'epoch'.

    Args:
    model: Network whose weights are restored in place
    optimizer: Optimizer whose state is restored in place, or None to skip
        restoring optimizer state (e.g. inference-only use)
    filename (str): Path of the checkpoint file
    device: Device the stored tensors are mapped to. A CUDA device is
        replaced by the CPU when CUDA is not available on this machine.

    Returns:
    tuple: (model, optimizer, start_epoch) where start_epoch is the epoch
        value stored in the checkpoint
    """
    # The default is 'cuda'; on a CPU-only machine torch.load would raise when asked
    # to place tensors on a GPU, so map to the CPU instead.
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        print("CUDA is not available; loading checkpoint on CPU instead")
        device = 'cpu'

    # NOTE(security): torch.load unpickles the file. Only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(filename, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    print(f"Model loaded from {filename} (Resuming from epoch {start_epoch})")
    return model, optimizer, start_epoch



def _balance_class(class_data, target_size, noise_std=0.02):
    """Resize one class to target_size: random subset if too large, jittered duplicates if too small."""
    num_samples = len(class_data)
    if num_samples >= target_size:
        # Downsample: keep a random subset of exactly target_size samples
        return class_data[torch.randperm(num_samples)[:target_size]]
    if num_samples == 0:
        # Nothing to duplicate; an empty class cannot be oversampled
        return class_data
    # Oversample: draw the missing samples with replacement and jitter the copies
    extra_indices = torch.randint(0, num_samples, (target_size - num_samples,))
    jittered = add_noise(class_data[extra_indices], noise_std=noise_std)
    return torch.cat([class_data, jittered], dim=0)


def augment_data(train_data, train_labels):
    """Balance the four classes to the size of the dribble class (label 0).

    Each of the pass (1), shot (2) and rebound (3) classes is brought to the
    dribble count: classes with more samples are randomly downsampled, and
    classes with fewer samples keep all originals and are topped up with
    randomly chosen duplicates that have Gaussian noise (std 0.02) added.
    A class with no samples at all stays empty.

    Args:
    train_data (torch.Tensor): Samples, indexed along dim 0
    train_labels (torch.Tensor): Integer class label per sample

    Returns:
    tuple: (augmented_train_data, augmented_train_labels), ordered by class
        (all dribbles, then passes, shots, rebounds)

    Raises:
    ValueError: If there are no dribble samples, since the dribble count
        defines the target size
    """
    train_dribble_data = train_data[train_labels == 0]

    # Define target size as the dribble class size (the reference class)
    target_size = len(train_dribble_data)
    if target_size == 0:
        raise ValueError("augment_data needs at least one dribble sample (label 0) to define the target class size")

    # NOTE: after mirroring data across y axis there were more passes than dribbles,
    # so passes are typically downsampled while shots and rebounds are oversampled.
    balanced_train_pass_data = _balance_class(train_data[train_labels == 1], target_size)
    balanced_train_shot_data = _balance_class(train_data[train_labels == 2], target_size)
    balanced_train_rebound_data = _balance_class(train_data[train_labels == 3], target_size)

    print("Dribble data size:", len(train_dribble_data))
    print("Pass data size:", len(balanced_train_pass_data))
    print("Shot data size:", len(balanced_train_shot_data))
    print("Rebound data size:", len(balanced_train_rebound_data))

    # Combine all classes back together; list position doubles as the class label
    per_class_data = [
        train_dribble_data,
        balanced_train_pass_data,
        balanced_train_shot_data,
        balanced_train_rebound_data,
    ]
    augmented_train_data = torch.cat(per_class_data, dim=0)
    augmented_train_labels = torch.cat([
        torch.full((len(class_data),), label, dtype=torch.long)
        for label, class_data in enumerate(per_class_data)
    ])

    return augmented_train_data, augmented_train_labels

def train(model, train_loader, criterion, optimizer, device):
    """Run a single pass over ``train_loader`` and update the model.

    Args:
    model: Network to optimise; it is switched to training mode
    train_loader: Iterable of (data, labels) batches that also supports len()
    criterion: Loss function called as criterion(outputs, labels)
    optimizer: Optimizer stepped once per batch
    device: Device each batch is moved to before the forward pass

    Returns:
    float: Sum of the per-batch losses divided by the number of batches
    """
    model.train()
    running_loss = 0.0
    for batch_inputs, batch_targets in train_loader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        # Standard step: clear old gradients, forward, backward, update
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_inputs), batch_targets)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
    return running_loss / len(train_loader)

def evaluate(model, val_loader, criterion, device, num_classes=None):
    """Evaluate the model on a validation set.

    Args:
    model: Network to evaluate; it is switched to eval mode
    val_loader: Iterable of (data, labels) batches that also supports len()
    criterion: Loss function called as criterion(outputs, labels)
    device: Device each batch is moved to before the forward pass
    num_classes (int, optional): Number of classes to report per-class
        accuracy for. When None it is taken from the second dimension of
        the model output (one logit per class).

    Returns:
    tuple: (val_loss, val_accuracy, class_accuracies) where val_loss is the
        mean per-batch loss, val_accuracy the overall accuracy in percent and
        class_accuracies a list with the accuracy in percent of each class
        (0.0 for classes with no validation samples)

    Raises:
    ValueError: If val_loader yields no samples
    """
    model.eval()
    total_loss = 0.0
    correct_preds = 0
    total_preds = 0
    # Per-class counters; allocated on the first batch once num_classes is known
    class_correct = None
    class_total = None

    with torch.no_grad():
        for data, labels in val_loader:
            data, labels = data.to(device), labels.to(device)
            outputs = model(data)

            if class_correct is None:
                if num_classes is None:
                    num_classes = outputs.shape[1]
                class_correct = torch.zeros(num_classes, dtype=torch.long)
                class_total = torch.zeros(num_classes, dtype=torch.long)

            total_loss += criterion(outputs, labels).item()
            predicted = outputs.argmax(dim=1)
            hits = predicted == labels
            correct_preds += hits.sum().item()
            total_preds += len(labels)

            # Track per-class accuracy: count labels overall and among correct predictions
            class_total += torch.bincount(labels, minlength=num_classes).cpu()
            class_correct += torch.bincount(labels[hits], minlength=num_classes).cpu()

    if total_preds == 0:
        raise ValueError("evaluate() received an empty validation loader")

    # Overall loss and accuracy
    val_loss = total_loss / len(val_loader)
    val_accuracy = correct_preds / total_preds * 100

    # Per-class accuracy
    class_accuracies = [
        100 * n_correct / n_total if n_total > 0 else 0.0
        for n_correct, n_total in zip(class_correct.tolist(), class_total.tolist())
    ]
    return val_loss, val_accuracy, class_accuracies

def _report_cv_metrics(fold_accuracies, fold_class_accuracies, class_names, k):
    """Print overall and per-class accuracy summaries; return the overall CI half-width."""
    mean_accuracy = np.mean(fold_accuracies)
    std_accuracy = np.std(fold_accuracies)
    min_accuracy = np.min(fold_accuracies)
    max_accuracy = np.max(fold_accuracies)

    print(f"Mean Accuracy: {mean_accuracy:.2f}%")
    # The value printed here is the standard deviation across folds, not the variance
    print(f"Accuracy Std Dev: {std_accuracy:.2f}")
    print(f"Worst Fold Accuracy: {min_accuracy:.2f}%")
    print(f"Best Fold Accuracy: {max_accuracy:.2f}%")

    # 95% confidence interval using the normal approximation (z = 1.96).
    # NOTE(review): with only k folds a t-distribution quantile and the sample
    # standard deviation (ddof=1) would give a wider interval.
    confidence_interval = 1.96 * std_accuracy / np.sqrt(k)
    ci_lower = mean_accuracy - confidence_interval
    ci_upper = mean_accuracy + confidence_interval
    print(f"Confidence Interval: {confidence_interval}")
    print(f"95% Confidence Range: ({ci_lower:.2f} - {ci_upper:.2f})")

    # Per-Class Accuracy Metrics
    for class_idx, class_acc in fold_class_accuracies.items():
        class_mean = np.mean(class_acc)
        class_std = np.std(class_acc)
        class_ci_range = 1.96 * (class_std / np.sqrt(k))

        print(f"Class {class_names[class_idx]} Accuracy: {class_mean:.2f}% ± {class_ci_range:.2f}%"
              f" (Range: {class_mean - class_ci_range:.2f}% to {class_mean + class_ci_range:.2f}%)")

    return confidence_interval


def _plot_cv_results(fold_accuracies, fold_class_accuracies, all_training_losses,
                     class_names, confidence_interval, k, epochs):
    """Show the fold accuracy, per-class accuracy, training loss, box and CI plots."""
    folds = range(1, k + 1)

    # Accuracy per Fold
    plt.figure(figsize=(8, 5))
    plt.plot(folds, fold_accuracies, marker='o', linestyle='-', label='Fold Accuracy')
    plt.title("Fold-Wise Accuracy")
    plt.xlabel("Fold")
    plt.ylabel("Accuracy (%)")
    plt.xticks(folds)
    plt.grid(True)
    plt.show()

    # Per-Class Accuracy Plot
    plt.figure(figsize=(10, 6))
    for class_idx, class_acc in fold_class_accuracies.items():
        class_ci = 1.96 * (np.std(class_acc) / np.sqrt(k))

        plt.errorbar(
            x=folds,
            y=class_acc,
            yerr=class_ci,
            fmt='o-', capsize=5, label=f'Class {class_names[class_idx]} ± CI'
        )

    plt.title("Per-Class Accuracy with Confidence Intervals J4")
    plt.xlabel("Fold")
    plt.ylabel("Accuracy (%)")
    plt.grid(False)
    plt.legend()
    plt.show()

    # Training Loss per Epoch for Each Fold
    plt.figure(figsize=(10, 6))
    for i, losses in enumerate(all_training_losses):
        plt.plot(range(1, epochs + 1), losses, label=f'Fold {i + 1}')
    plt.title("Training Loss per Epoch")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(False)
    plt.show()

    # Accuracy Distribution (Box Plot)
    plt.figure(figsize=(8, 5))
    sns.boxplot(fold_accuracies)
    plt.title("Accuracy Distribution Across Folds J4")
    plt.xlabel("K-Fold Results")
    plt.ylabel("Accuracy (%)")
    plt.show()

    # Visualizing Confidence Intervals
    plt.figure(figsize=(8, 5))
    plt.errorbar(
        x=folds,
        y=fold_accuracies,
        yerr=confidence_interval,
        fmt='o', capsize=5, label='Fold Accuracy ± CI'
    )
    plt.title("K-Fold Accuracy with Confidence Intervals J4")
    plt.xlabel("Fold")
    plt.ylabel("Accuracy (%)")
    plt.grid(False)
    plt.legend()
    plt.show()


def k_fold_strat_cross_validation(model_class, dataset, input_dim, k=5, epochs=15, batch_size=64,
                            d_model=256, nhead=8, num_layers=6, dropout=0.1,
                            learning_rate=5e-5, num_classes=4, checkpoint_path='best_model.pth'):
    """Train and validate a fresh model on each fold of a stratified k-fold split.

    For every fold the training part is balanced with ``augment_data``, a new
    model is trained for ``epochs`` epochs and then scored on the untouched
    validation part. The fold with the highest validation accuracy is saved to
    ``checkpoint_path``. Summary metrics are printed and plots are shown.

    Relies on the module-level names ``sequence_length``, ``num_joints`` and
    ``num_features`` when constructing the model.

    Args:
    model_class: Callable building the model; called as
        model_class(input_dim, sequence_length, num_joints, num_features,
        d_model, nhead, num_layers, num_classes, dropout)
    dataset: Indexable dataset whose items are (data, label) pairs
    input_dim (int): Number of features per frame
    k (int): Number of folds
    epochs (int): Training epochs per fold
    batch_size (int): Batch size for the training and validation loaders
    d_model, nhead, num_layers, dropout: Model hyperparameters passed through
    learning_rate (float): Adam learning rate
    num_classes (int): Number of classes
    checkpoint_path (str): Where the best fold's checkpoint is written

    Returns:
    The model object of the last fold, loaded with the weights of the fold
    that achieved the highest validation accuracy
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    class_names = ['Dribble', 'Pass', 'Shot', 'Rebound']
    # Labels array for stratification (one plain int per dataset item)
    labels = np.array([int(dataset[i][1]) for i in range(len(dataset))])

    # stratified Kfold
    skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=42)

    # Start below any reachable accuracy so the first fold always becomes the
    # initial best, even when its accuracy is 0%.
    best_accuracy = float('-inf')
    best_model_state = None  # State dict of the best fold's model

    fold_accuracies = []
    fold_class_accuracies = {i: [] for i in range(num_classes)}
    fold_losses = []
    all_training_losses = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(labels)), labels)):
        print(f"Fold {fold + 1}/{k}")

        # pre-augmentation sample sizes and class mix
        tr_counts = np.bincount(labels[train_idx], minlength=num_classes)  # training class counts
        va_counts = np.bincount(labels[val_idx], minlength=num_classes)  # validation class counts

        print(f"[Pre-aug] train_n={len(train_idx)}  val_n={len(val_idx)}")
        print("train per-class:", {cls: int(tr_counts[i]) for i, cls in enumerate(class_names)})
        print("val   per-class:", {cls: int(va_counts[i]) for i, cls in enumerate(class_names)})

        train_data = torch.stack([dataset[i][0] for i in train_idx])
        train_labels = torch.tensor(labels[train_idx], dtype=torch.long)

        # We still need to augment as there are more dribbles than other classes.
        # Only the training part is augmented; validation data stays untouched.
        augmented_train_data, augmented_train_labels = augment_data(train_data, train_labels)

        train_loader = DataLoader(TensorDataset(augmented_train_data, augmented_train_labels), batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(Subset(dataset, val_idx), batch_size=batch_size, shuffle=False)

        # Initialize a fresh model, criterion, and optimizer for this fold
        model = model_class(input_dim, sequence_length, num_joints, num_features,
                            d_model, nhead, num_layers, num_classes, dropout).to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        # Training loop
        training_losses = []
        for epoch in range(epochs):
            train_loss = train(model, train_loader, criterion, optimizer, device)
            training_losses.append(train_loss)
            print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss:.4f}")

        # Evaluate on validation set
        val_loss, val_accuracy, class_accuracies = evaluate(model, val_loader, criterion, device)
        print(f"Fold {fold + 1} Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%\n")

        fold_accuracies.append(val_accuracy)
        fold_losses.append(val_loss)
        all_training_losses.append(training_losses)

        for class_idx in range(num_classes):
            fold_class_accuracies[class_idx].append(class_accuracies[class_idx])

        # Save the best model based on highest validation accuracy
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            best_model_state = model.state_dict()
            # epochs - 1 is the index of the last epoch trained (defined even if epochs == 0)
            save_checkpoint(model, optimizer, epochs - 1, filename=checkpoint_path)

    # Performance metrics (printed) and plots
    confidence_interval = _report_cv_metrics(fold_accuracies, fold_class_accuracies, class_names, k)
    _plot_cv_results(fold_accuracies, fold_class_accuracies, all_training_losses,
                     class_names, confidence_interval, k, epochs)

    # Lets return the best model state
    model.load_state_dict(best_model_state)  # it may not be the last fold which is why we have to load
    return model



# Train, Evaluate, and save Best Model

In [None]:

# Model checkpoint path: location where the best fold's model is saved
checkpoint_path = ' '

# Data dimensions: per-frame features are the joint coordinates plus 3 centre-of-mass values
sequence_length = 21
input_dim = num_joints * num_features + 3
num_classes = 4

# Model hyperparameters. Input the best parameters from the hyperparameter tuning
d_model = 128
nhead = 8
num_layers = 6
dropout = 0.1

# Optimisation settings
lr = 0.00010236630901374436
num_epochs = 15
batch_size = 64

# Number of folds for the cross validation
k_test = 5

# The whole train/val set goes into one TensorDataset; the K-Fold routine splits it
train_val_dataset = TensorDataset(combined_data, labels)

# Run KFold cross validation and keep the best-performing model
best_model = k_fold_strat_cross_validation(
    PoseTransformer,
    train_val_dataset,
    input_dim=input_dim,
    k=k_test,
    epochs=num_epochs,
    batch_size=batch_size,
    d_model=d_model,
    nhead=nhead,
    num_layers=num_layers,
    dropout=dropout,
    learning_rate=lr,
    num_classes=num_classes,
    checkpoint_path=checkpoint_path,
)