This script is used to train and run a GCN model for classifying basketball activity.<br>
Users must manually provide the paths to the parquet data files, where each row is a separate basketball event that tracks player joints over a 21-frame window. Note that this script takes in flattened data values.<br>
For example, say we are analyzing a combination with 10 joints: the pose_data column then contains a single flattened row of 21 frames * 10 joints * 3 (x, y, z) positions = 630 values. This data needs to be unflattened before it is processed so that each coordinate can easily be mapped to the correct joint.

*The input dataframe has the following columns of interest*:

 ['game_id', 'stadium_id', 'player_id', 'team_id',
'event_seq_id', 'event_type', 'ball_position', 'player_com', 'pose_data', 'label']

# Data processing and Graph Creation

In [None]:
#Import libraries
import os
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as GeometricDataLoader
from torch_geometric.nn import GCNConv
import torch.nn as nn
from torch.nn import GRU

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import KFold, StratifiedKFold
import matplotlib.pyplot as plt
import optuna
import seaborn as sns


# Define Global Parameters

# Number of tracked joints per frame.
# NOTE: Adjust if your data uses a different number of joints
# (keep it in sync with the length of `joint_columns` below).
NUM_JOINTS = 16

# Joint names, in the order in which they appear in the flattened pose data.
# A joint's position in this list is its node index within a frame.
joint_columns = [
    'rShoulder', 'lShoulder',
    'rElbow', 'lElbow',
    'rWrist', 'lWrist',
    'neck',
    'rKnee', 'lKnee',
    'rAnkle', 'lAnkle',
    'rHip', 'lHip', 'midHip',
    'rHeel', 'lHeel',
]

# Unlike the transformer, the GCN also needs an explicit graph structure:
# every (joint_a, joint_b) limb pair below becomes an edge of the skeleton graph.
limb_pairs = [
    # arms
    ('lElbow', 'lWrist'),
    ('rElbow', 'rWrist'),
    ('lShoulder', 'lElbow'),
    ('rShoulder', 'rElbow'),
    # torso
    ('lShoulder', 'neck'),
    ('rShoulder', 'neck'),
    ('neck', 'midHip'),
    ('midHip', 'lHip'),
    ('midHip', 'rHip'),
    ('lShoulder', 'lHip'),
    ('rShoulder', 'rHip'),
    # legs
    ('lHip', 'lKnee'),
    ('rHip', 'rKnee'),
    ('lKnee', 'lAnkle'),
    ('rKnee', 'rAnkle'),
    ('lAnkle', 'lHeel'),
    ('rAnkle', 'rHeel'),
]
 


# Create Graph Edge Index

def create_edge_index():
    """
    Build the skeleton edge list from `limb_pairs`.

    Each limb contributes two directed edges (a -> b and b -> a) so the graph
    is effectively undirected. Joint names are translated to node indices via
    their position in `joint_columns`.

    Returns a contiguous long tensor of shape [2, 2 * len(limb_pairs)].
    """
    index_of = {name: position for position, name in enumerate(joint_columns)}
    directed_edges = []
    for first_joint, second_joint in limb_pairs:
        a, b = index_of[first_joint], index_of[second_joint]
        directed_edges.extend([(a, b), (b, a)])  # both directions
    return torch.tensor(directed_edges, dtype=torch.long).t().contiguous()


# Skeleton edges, computed once at import time and shared by all graphs.
EDGE_INDEX = create_edge_index()


# Load Data

def load_data(dribble_file, pass_file, shot_file, rebound_file):
    """
    Load the four per-event parquet files into one labelled dataframe.

    Steps:
      1. Read each parquet file and concatenate them (index is reset).
      2. Keep only rows whose 'pose_data' is a list/ndarray of exactly
         21 * NUM_JOINTS * 3 values.
      3. Map the string 'event_type' to a numeric 'label':
         0=dribble, 1=pass, 2=shot, 3=rebound.
      4. Drop rows whose 'event_type' is not one of those four, so the
         'label' column is always an integer column.

    Args:
        dribble_file, pass_file, shot_file, rebound_file: paths passed
            straight to `pd.read_parquet`.

    Returns:
        A new dataframe (an independent copy, safe to modify) with an
        int64 'label' column.
    """
    source_files = (dribble_file, pass_file, shot_file, rebound_file)
    all_data = pd.concat(
        [pd.read_parquet(path) for path in source_files],
        ignore_index=True,
    )

    # Filter out rows where pose_data is not the right type/length
    expected_pose_len = 21 * NUM_JOINTS * 3
    has_valid_pose = all_data['pose_data'].apply(
        lambda x: isinstance(x, (list, np.ndarray)) and len(x) == expected_pose_len
    ).astype(bool)  # keeps the mask boolean even when the frame is empty
    # .copy() so that adding 'label' below writes to a real dataframe and not
    # to a view of the unfiltered one (avoids SettingWithCopyWarning).
    all_data = all_data[has_valid_pose].copy()

    # Map event_type to label
    event_mapping = {'dribble': 0, 'pass': 1, 'shot': 2, 'rebound': 3}
    all_data['label'] = all_data['event_type'].map(event_mapping)

    # Unknown event types map to NaN, which would turn 'label' into a float
    # column and break np.bincount / stratified splitting downstream.
    unknown_event = all_data['label'].isna()
    if unknown_event.any():
        print(f"Dropping {int(unknown_event.sum())} rows with unrecognized event_type")
        all_data = all_data[~unknown_event].copy()
    all_data['label'] = all_data['label'].astype('int64')

    return all_data

# Split Dataframe into Train, Val, Test

def split_dataframe(all_data):
    """
    Stratified 60% / 20% / 20% split into train, validation and test frames.

    The split is done in two passes: 40% of the rows are first held out, and
    that hold-out set is then halved into validation and test. Both passes
    stratify on the 'label' column and use a fixed random seed.

    Returns:
        (train_df, val_df, test_df)
    """
    seed = 42

    # Pass 1: 60% train, 40% held out
    train_df, holdout_df = train_test_split(
        all_data,
        test_size=0.4,
        random_state=seed,
        stratify=all_data['label'],
    )

    # Pass 2: split the held-out 40% evenly into validation and test
    val_df, test_df = train_test_split(
        holdout_df,
        test_size=0.5,
        random_state=seed,
        stratify=holdout_df['label'],
    )

    print(f"Train size: {len(train_df)}, Val size: {len(val_df)}, Test size: {len(test_df)}")
    return train_df, val_df, test_df


# Balance Training Data
# Downsample passes if bigger than dribbles,
# Upsample shots & rebounds if smaller than dribbles.

def _upsample_with_noise(class_df, target_size, noise_level, display_name):
    """
    Grow `class_df` to `target_size` rows by resampling rows with replacement
    and jittering the resampled copies with `noise_augment_pose`.

    Returns `class_df` unchanged when it already has at least `target_size`
    rows, or when it is empty (there is nothing to resample from).
    """
    deficit = target_size - len(class_df)
    if deficit <= 0:
        return class_df
    if class_df.empty:
        # np.random.choice raises on an empty population; skip instead of crashing.
        print(f"No {display_name} available to upsample (skipping)")
        return class_df

    # Randomly select rows to duplicate. Sampling by position (not by index
    # label) stays correct even if the dataframe index contains duplicates.
    picked_positions = np.random.choice(len(class_df), size=deficit, replace=True)
    duplicates = class_df.iloc[picked_positions]
    # Add random jitter to the duplicated joint positions
    jittered = duplicates.apply(
        lambda row: noise_augment_pose(row, noise_level=noise_level), axis=1
    )
    grown = pd.concat([class_df, jittered], ignore_index=True)
    print(f"Upsampled {display_name} to {len(grown)}")
    return grown


def balance_train_df(train_df, noise_for_upsampled=0.02):
    """
    Balance the training set so every class matches the dribble class size.

    Label convention:
    - Dribble -> label=0 (reference class, left untouched)
    - Pass    -> label=1 (downsampled without replacement if larger)
    - Shot    -> label=2 (upsampled with noisy copies if smaller)
    - Rebound -> label=3 (upsampled with noisy copies if smaller)

    Args:
        train_df: dataframe with a 'label' column and a 'pose_data' column.
        noise_for_upsampled: standard deviation of the Gaussian jitter applied
            to the duplicated shot/rebound rows.

    Returns:
        A new dataframe with a fresh 0..n-1 index containing the balanced data.
    """
    dribble_df = train_df[train_df['label'] == 0]
    pass_df    = train_df[train_df['label'] == 1]
    shot_df    = train_df[train_df['label'] == 2]
    rebound_df = train_df[train_df['label'] == 3]

    # Target size is the dribble class size
    target_size = len(dribble_df)
    print(f"Dribble size => {target_size}")

    # Downsample Pass if pass_df is larger than dribble_df
    if len(pass_df) > target_size:
        pass_df = pass_df.sample(n=target_size, replace=False, random_state=42)
        print(f"Downsampled pass to {len(pass_df)}")
    else:
        print(f"Pass count => {len(pass_df)} (no downsampling needed)")

    # Upsample the minority classes (shots first, then rebounds)
    shot_df = _upsample_with_noise(shot_df, target_size, noise_for_upsampled, "shots")
    rebound_df = _upsample_with_noise(rebound_df, target_size, noise_for_upsampled, "rebounds")

    balanced_train = pd.concat([dribble_df, pass_df, shot_df, rebound_df], ignore_index=True)
    print(f"Balanced train size => {len(balanced_train)}")
    return balanced_train


# Data Augmentation Helper
# Added random noise to pose data when upsampling
def noise_augment_pose(row, noise_level=0.02):
    """
    Return a jittered copy of a dataframe row.

    row['pose_data'] must hold 21 * NUM_JOINTS * 3 values. Zero-mean Gaussian
    noise with standard deviation `noise_level` is added to every coordinate,
    and the result is stored back as a flat array. The input row itself is
    not modified.
    """
    clip_shape = (21, NUM_JOINTS, 3)

    clean_pose = np.array(row['pose_data']).reshape(clip_shape)
    jitter = np.random.normal(loc=0, scale=noise_level, size=clip_shape)

    noisy_row = row.copy()
    noisy_row['pose_data'] = (clean_pose + jitter).flatten()
    return noisy_row


# Convert Dataframe Rows to PyTorch Geometric Data
# This defines the graph structure for the GCN

def convert_df_to_pyg_dataset(df):
    """
    Convert every dataframe row into a PyTorch Geometric `Data` graph.

    Each graph has 21 * NUM_JOINTS nodes, one per (frame, joint), ordered
    frame-major: node index = frame * NUM_JOINTS + joint. Node features are
    the 3D position plus first and second frame-to-frame differences
    (velocity and acceleration), i.e. 9 values per node.

    The skeleton edges in EDGE_INDEX only reference node indices
    0..NUM_JOINTS-1, so they are replicated for every frame with the matching
    node offset. Without this only the first frame's joints would be connected
    and all other nodes would be isolated.

    Args:
        df: dataframe with 'pose_data' (21 * NUM_JOINTS * 3 values per row)
            and an integer 'label' column.

    Returns:
        A list with one `Data(x, edge_index, y)` object per row.
    """
    num_frames = 21

    # Replicate the skeleton for each frame: [T, 1, 1] offsets broadcast over
    # the [1, 2, E] edge list, then flatten to the usual [2, T * E] layout.
    frame_offsets = torch.arange(num_frames, dtype=torch.long).view(-1, 1, 1) * NUM_JOINTS
    per_frame_edges = EDGE_INDEX.unsqueeze(0) + frame_offsets
    edge_index = per_frame_edges.permute(1, 0, 2).reshape(2, -1).contiguous()

    data_list = []
    for _, row in df.iterrows():
        # Reshape -> [21, NUM_JOINTS, 3]
        pose_data = np.array(row['pose_data']).reshape(num_frames, NUM_JOINTS, 3)
        # prepend= repeats the first frame so the diffs keep 21 frames
        # (the first velocity/acceleration entry is therefore zero).
        velocity = np.diff(pose_data, axis=0, prepend=pose_data[0:1])
        acceleration = np.diff(velocity, axis=0, prepend=velocity[0:1])
        # Reshape -> [21, NUM_JOINTS, 9] if using pose + velocity + acceleration
        combined = np.concatenate([pose_data, velocity, acceleration], axis=-1)
        node_feats = torch.tensor(combined.reshape(-1, 9), dtype=torch.float)
        # NOTE: To just use pose_data uncomment below:
        # node_feats = torch.tensor(pose_data.reshape(-1, 3), dtype=torch.float)

        label = torch.tensor(row['label'], dtype=torch.long)
        # All graphs share the same edge_index tensor (it is never modified).
        graph = Data(x=node_feats, edge_index=edge_index, y=label)
        data_list.append(graph)
    return data_list



# GCN Model Architecture

In [None]:

# Define GCN Model (EnhancedActionClassifierGCN)
class EnhancedActionClassifierGCN(nn.Module):
    """
    GCN + GRU classifier for pose sequences.

    Four GCNConv layers process the node features, the joint dimension is
    mean-pooled per frame, a GRU runs over the resulting frame sequence, and
    two linear layers map the last GRU output to class logits.

    Args:
        node_feature_dim: number of features per node.
        hidden_dim: width of the GCN/GRU layers.
        num_classes: number of output logits.
        dropout: dropout probability, applied after every GCN layer and
            after the first linear layer.
    """

    def __init__(self, node_feature_dim, hidden_dim, num_classes, dropout):
        super().__init__()
        # NOTE: attribute names are part of the saved state_dict keys; keep them.
        self.conv1 = GCNConv(node_feature_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        self.conv4 = GCNConv(hidden_dim, hidden_dim)
        self.gru = GRU(hidden_dim, hidden_dim, batch_first=True)
        self.dropout = nn.Dropout(p=dropout)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc2 = nn.Linear(hidden_dim // 2, num_classes)

    def forward(self, data):
        """
        Args:
            data: a PyG `Data`/`Batch` with `x` and `edge_index`. The nodes of
                each graph must be ordered frame-major with NUM_JOINTS nodes
                per frame. A single un-batched graph (no `batch` vector) is
                treated as a batch of one.

        Returns:
            Logits of shape [B, num_classes].
        """
        x, edge_index = data.x, data.edge_index

        # Graph convolutions: conv -> ReLU -> dropout, four times
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = torch.relu(conv(x, edge_index))
            x = self.dropout(x)

        # Reshape for GRU. `batch` is only present when the graphs come from a
        # DataLoader; fall back to a single graph otherwise.
        batch_vector = getattr(data, "batch", None)
        batch_size = 1 if batch_vector is None else int(batch_vector.max().item()) + 1
        seq_length = x.size(0) // batch_size // NUM_JOINTS
        x = x.view(batch_size, seq_length, NUM_JOINTS, -1)  # [B, seq_len, num_joints, hidden_dim]
        x = x.mean(dim=2)  # average across joints -> [B, seq_len, hidden_dim]

        # GRU
        x, _ = self.gru(x)       # [B, seq_len, hidden_dim]
        x = x[:, -1, :]          # take last timestep -> [B, hidden_dim]

        # Fully connected layers
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)



# Model Training Logic

In [None]:

# Train and Evaluate the GCN Model

def train_epoch(loader, model, optimizer, criterion, device):
    """
    Run one optimisation pass over `loader` and return the mean batch loss.

    For each batch: move it to `device`, zero the gradients, run the forward
    pass, compute `criterion(output, batch.y)`, backpropagate and step the
    optimizer. The returned value is the sum of the per-batch losses divided
    by the number of batches.
    """
    model.train()
    batch_losses = []
    for graphs in loader:
        graphs = graphs.to(device)
        optimizer.zero_grad()
        loss = criterion(model(graphs), graphs.y)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    return sum(batch_losses) / len(loader)

def evaluate(loader, model, device, num_classes=4):
    """
    Compute overall and per-class accuracy (in percent) of `model` on `loader`.

    Args:
        loader: iterable of batches exposing `.to(device)` and integer labels `.y`.
        model: callable returning logits of shape [B, num_classes] for a batch.
        device: device the batches are moved to before the forward pass.
        num_classes: number of classes to report; labels are expected to lie
            in [0, num_classes).

    Returns:
        (overall_acc, class_accuracies) where `class_accuracies` is a list of
        length `num_classes`. A class without samples reports 0.0, and an
        empty loader yields (0.0, [0.0, ...]) instead of dividing by zero.
    """
    model.eval()
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            pred = model(batch).argmax(dim=1)
            labels = batch.y.view(-1)

            # Per-class counts for the whole batch at once (no per-sample
            # Python loop over tensors).
            hit_labels = labels[pred == labels]
            class_total += torch.bincount(labels, minlength=num_classes)[:num_classes].cpu()
            class_correct += torch.bincount(hit_labels, minlength=num_classes)[:num_classes].cpu()

    totals = class_total.tolist()
    corrects = class_correct.tolist()
    total = sum(totals)
    correct = sum(corrects)

    overall_acc = 100 * correct / total if total > 0 else 0.0
    class_accuracies = [
        100 * corrects[i] / totals[i] if totals[i] > 0 else 0.0
        for i in range(num_classes)
    ]

    return overall_acc, class_accuracies


# Generate Confusion Matrix to evaluate model performance

def generate_confusion_matrix(model, test_data, device, batch_size=32):
    """
    Evaluate `model` on `test_data`, print accuracy figures and plot the
    confusion matrix twice: once with raw counts, once row-normalised.

    Args:
        model: trained classifier returning logits per graph.
        test_data: list of PyG graphs with integer labels in `y`.
        device: device the batches are moved to.
        batch_size: evaluation batch size (does not affect the results).

    Classes are fixed to labels 0..3 (dribble, pass, shot, rebound). A class
    with no test samples gets an all-zero row in the normalised matrix rather
    than NaNs.
    """
    test_loader = GeometricDataLoader(test_data, batch_size=batch_size, shuffle=False)
    all_preds = []
    all_labels = []

    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            out = model(batch)  # was `out =+ model(batch)`, a typo for plain assignment
            all_preds.extend(out.argmax(dim=1).cpu().numpy())
            all_labels.extend(batch.y.cpu().numpy())

    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1, 2, 3])
    # Normalize the confusion matrix by dividing each row by the sum of its
    # elements (to display percentages). Rows that sum to zero stay zero.
    row_totals = cm.sum(axis=1, keepdims=True)
    conf_matrix_normalized = np.divide(
        cm.astype('float'),
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    )

    # Calculate overall accuracy
    total_correct = np.trace(cm)
    total = np.sum(cm)
    overall_accuracy = total_correct / total if total > 0 else 0.0
    print(f"Overall Accuracy: {overall_accuracy:.4f}")

    # Calculate per class accuracy
    # The diagonal of the normalized matrix contains the per-class accuracy
    per_class_acc = np.diagonal(conf_matrix_normalized)
    class_names = ["dribble","pass","shot","rebound"]
    print("\n Per-Class Accuracy (from normalized diagonal):")
    for i, acc in enumerate(per_class_acc):
        print(f"  {class_names[i]}: {acc:.4f}")

    # Display the confusion matrix
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    disp.plot(cmap=plt.cm.Blues)
    plt.title(f"Confusion Matrix J{NUM_JOINTS}")
    plt.show()

    # Display the percentage confusion matrix
    disp_percentage = ConfusionMatrixDisplay(confusion_matrix=conf_matrix_normalized, display_labels=class_names)
    disp_percentage.plot(cmap='Blues', values_format='.02%')  # Format to show percentages with 2 decimal places
    plt.title(f"Confusion Matrix (Percentages) J{NUM_JOINTS}")
    plt.show()


# Stratified Kfold processing
def _plot_k_fold_results(fold_accuracies, fold_class_accuracies, all_training_losses,
                         class_names, k, ci_range):
    """Draw the four summary figures for a finished K-fold run."""
    fold_numbers = range(1, k + 1)

    # Per-Class Accuracy Plot
    plt.figure(figsize=(10, 6))
    for class_idx, class_name in enumerate(class_names):
        class_acc = fold_class_accuracies[class_idx]
        class_ci = 1.96 * (np.std(class_acc) / np.sqrt(k))
        plt.errorbar(
            x=fold_numbers,
            y=class_acc,
            yerr=class_ci,
            fmt='o-', capsize=5, label=f'Class {class_name} ± CI'
        )
    plt.title(f"Per-Class Accuracy with Confidence Intervals J{NUM_JOINTS}")
    plt.xlabel("Fold")
    plt.ylabel("Accuracy (%)")
    plt.grid(False)
    plt.legend()
    plt.show()

    # Training Loss per Epoch for Each Fold
    plt.figure(figsize=(10, 6))
    for i, losses in enumerate(all_training_losses):
        plt.plot(range(1, len(losses) + 1), losses, label=f'Fold {i + 1}')
    plt.title(f"Training Loss per Epoch J{NUM_JOINTS}")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(False)
    plt.show()

    # Accuracy Distribution (Box Plot)
    plt.figure(figsize=(8, 5))
    sns.boxplot(fold_accuracies)
    plt.title(f"Accuracy Distribution Across Folds J{NUM_JOINTS}")
    plt.xlabel("K-Fold Results")
    plt.ylabel("Accuracy (%)")
    plt.show()

    # Visualizing Confidence Intervals
    plt.figure(figsize=(8, 5))
    plt.errorbar(
        x=fold_numbers,
        y=fold_accuracies,
        yerr=ci_range,
        fmt='o', capsize=5, label='Fold Accuracy ± CI'
    )
    plt.title(f"K-Fold Accuracy with Confidence Intervals J{NUM_JOINTS}")
    plt.xlabel("Fold")
    plt.ylabel("Accuracy (%)")
    plt.grid(False)
    plt.legend()
    plt.show()


# Stratified Kfold processing
def regular_k_fold_evaluation(dataframe, k=5, best_params=None, num_epochs=10):
    """
    Train and validate the GCN with stratified K-fold cross-validation.

    For every fold the training split is class-balanced with
    `balance_train_df` (the validation split is left untouched), a fresh
    model is trained for `num_epochs` epochs and then evaluated on the
    validation split. Whenever a fold beats the best validation accuracy so
    far, its weights are saved to
    "temporal_gcn_kfold_best_model_J{NUM_JOINTS}.pth".

    Despite the name, the folds are stratified on the 'label' column.

    Args:
        dataframe: data with 'pose_data' and an integer 'label' column (0..3).
        k: number of folds.
        best_params: dict with keys "hidden_dim", "dropout", "lr" and
            "batch_size". Required.
        num_epochs: training epochs per fold.

    Returns:
        The model object of the last fold, loaded with the weights of the
        best-scoring fold.

    Raises:
        ValueError: if `best_params` is not provided.
    """
    if best_params is None:
        raise ValueError(
            "best_params is required and must contain 'hidden_dim', 'dropout', 'lr' and 'batch_size'."
        )

    # Setup GPU for training
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    class_names = ['Dribble', 'Pass', 'Shot', 'Rebound']
    num_classes = len(class_names)

    # define model parameters
    hidden_dim    = best_params["hidden_dim"]
    dropout       = best_params["dropout"]
    learning_rate = best_params["lr"]
    batch_size    = best_params["batch_size"]

    # need to generate labels for stratified k fold
    labels = dataframe['label'].values

    # use stratified k fold to maintain class distribution
    skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=42)

    fold_accuracies = []
    fold_class_accuracies = {i: [] for i in range(num_classes)}
    best_model_state = None
    # Start below any reachable accuracy so the first fold always registers,
    # even if it scores 0% (otherwise best_model_state could stay None).
    best_fold_accuracy = float('-inf')
    all_training_losses = []   # track training loss per epoch for each fold

    for fold_idx, (train_idx, val_idx) in enumerate(skf.split(dataframe, labels)):
        print(f"\nFold {fold_idx + 1}/{k}")

        # Pre-Augmentation sample sizes and class mix
        tr_counts = np.bincount(labels[train_idx], minlength=num_classes)
        va_counts = np.bincount(labels[val_idx], minlength=num_classes)

        print(f"[Pre-aug] train_n={len(train_idx)}  val_n={len(val_idx)}")
        print("train per-class:", {cls: int(tr_counts[i]) for i, cls in enumerate(class_names)})
        print("val   per-class:", {cls: int(va_counts[i]) for i, cls in enumerate(class_names)})

        train_df = dataframe.iloc[train_idx].reset_index(drop=True)
        val_df = dataframe.iloc[val_idx].reset_index(drop=True)

        # Only the training split is balanced/augmented
        balanced_train_df = balance_train_df(train_df)
        train_data = convert_df_to_pyg_dataset(balanced_train_df)
        val_data = convert_df_to_pyg_dataset(val_df)

        train_loader = GeometricDataLoader(train_data, batch_size=batch_size, shuffle=True)
        val_loader = GeometricDataLoader(val_data, batch_size=batch_size, shuffle=False)

        # If using only pose data then node_feature_dim = 3
        # If using pose + velocity + acceleration then node_feature_dim = 9
        model = EnhancedActionClassifierGCN(
            node_feature_dim=9,
            hidden_dim=hidden_dim,
            num_classes=num_classes,
            dropout=dropout
        ).to(device)

        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        criterion = nn.CrossEntropyLoss()

        # Training loop
        training_losses = []
        for epoch in range(num_epochs):
            train_loss = train_epoch(train_loader, model, optimizer, criterion, device)
            training_losses.append(train_loss)
            print(f"Epoch {epoch+1}/{num_epochs} - Loss: {train_loss:.4f}")

        fold_accuracy, class_accuracies = evaluate(val_loader, model, device)
        fold_accuracies.append(fold_accuracy)
        all_training_losses.append(training_losses)

        if fold_accuracy > best_fold_accuracy:
            best_fold_accuracy = fold_accuracy
            # Each fold builds a new model, so this state is not touched by later folds.
            best_model_state = model.state_dict()
            # Save the best model
            torch.save(best_model_state, f"temporal_gcn_kfold_best_model_J{NUM_JOINTS}.pth")
            print("Best model saved after K-Fold.")

        for class_idx in range(num_classes):
            fold_class_accuracies[class_idx].append(class_accuracies[class_idx])

    # Performance Metrics
    mean_accuracy = np.mean(fold_accuracies)
    std_accuracy = np.std(fold_accuracies)
    ci_range = 1.96 * (std_accuracy / np.sqrt(k))  # 95% CI

    print(f"\nMean Accuracy: {mean_accuracy:.2f}% ± {ci_range:.2f}%")
    # The reported value is the standard deviation, so label it as such.
    print(f"Accuracy Std Dev: {std_accuracy:.2f}")

    # Per-Class Accuracy
    for class_idx in range(num_classes):
        class_mean = np.mean(fold_class_accuracies[class_idx])
        class_std = np.std(fold_class_accuracies[class_idx])
        class_ci = 1.96 * class_std / np.sqrt(k)
        print(f"Class {class_names[class_idx]} Accuracy: {class_mean:.2f}% ± {class_ci:.2f}%")

    _plot_k_fold_results(
        fold_accuracies, fold_class_accuracies, all_training_losses,
        class_names, k, ci_range,
    )

    # Return the best model: it may not be the last fold, so load its weights
    # into the last model object.
    model.load_state_dict(best_model_state)
    return model

def train_best_model(best_params, train_data=None, device=None, num_epochs=10):
    """
    Retrain a final model with the best params.

    Args:
        best_params: dict with keys "hidden_dim", "dropout", "lr" and
            "batch_size".
        train_data: list of PyG graphs to train on. If omitted, a module-level
            `train_data` variable is used when one exists (the previous
            behaviour of this function).
        device: torch device to train on. If omitted, a module-level `device`
            is used when one exists (main() sets it), otherwise CUDA when
            available, else CPU.
        num_epochs: number of training epochs.

    Returns:
        The trained model.

    Raises:
        ValueError: if no training data is passed and no module-level
            `train_data` exists.
    """
    hidden_dim    = best_params["hidden_dim"]
    dropout       = best_params["dropout"]
    learning_rate = best_params["lr"]
    batch_size    = best_params["batch_size"]

    # Resolve the inputs that used to be read implicitly from globals.
    if train_data is None:
        train_data = globals().get("train_data")
    if train_data is None:
        raise ValueError(
            "train_best_model needs training graphs: pass train_data=... "
            "(e.g. the output of convert_df_to_pyg_dataset)."
        )
    if device is None:
        device = globals().get("device")
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = EnhancedActionClassifierGCN(
        node_feature_dim=9,
        hidden_dim=hidden_dim,
        num_classes=4,
        dropout=dropout
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    train_loader = GeometricDataLoader(train_data, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epochs):
        train_loss = train_epoch(train_loader, model, optimizer, criterion, device)
        print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f}")

    return model



# Run Main Function to Process and Evaluate Model

In [None]:

# Main function
def main():
    """
    Entry point: load the four event files, pick the compute device and run
    the 5-fold evaluation with the tuned hyperparameters.

    The selected device is also published as the module-level `device`.
    """
    global device

    # Load your data from parquet files of joint data for each event type.
    # Fill in one path per event type before running.
    dribble_file, pass_file, shot_file, rebound_file = " ", " ", " ", " "

    all_data = load_data(dribble_file, pass_file, shot_file, rebound_file)
    print(f"All data size (filtered) => {len(all_data)}")

    # Setup GPU for training
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")

    # Define best parameters.
    # NOTE: these were chosen after running optuna and tuning for each joint combination.
    best_params = dict(
        hidden_dim=256,
        dropout=0.1017837074458662,
        lr=0.00022862438008880502,
        batch_size=64,
    )

    # KFold evaluation to get best model. We are running this on the entire dataset
    best_model = regular_k_fold_evaluation(all_data, k=5, best_params=best_params)
    

# Run the full pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
