In [None]:
# Cell 1: Import libraries
import os
import json
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
import torch.optim as optim

In [None]:
class KeypointAugmentation:
    """Random jitter-and-rescale augmentation for keypoint arrays.

    Calling an instance adds zero-mean Gaussian noise (standard deviation
    ``noise_std``) to every element and then multiplies the whole array by a
    single factor drawn uniformly from ``scale_range``. Randomness comes from
    NumPy's global RNG, so results are reproducible under ``np.random.seed``.
    """

    def __init__(self, noise_std=0.01, scale_range=(0.9, 1.1)):
        """Store the noise standard deviation and the (low, high) scale bounds."""
        self.noise_std, self.scale_range = noise_std, scale_range

    def __call__(self, keypoints):
        """Return an augmented copy of ``keypoints``; the input is not modified.

        Args:
            keypoints: array of keypoints, documented by the caller as
                (seq_length, num_keypoints, keypoint_dim).

        Returns:
            A new array of the same shape: ``(keypoints + noise) * scale``.
        """
        # The noise is drawn before the scale factor; keeping that order means
        # a fixed seed always yields the same augmentation.
        jitter = np.random.randn(*keypoints.shape) * self.noise_std
        factor = np.random.uniform(*self.scale_range)
        return (keypoints + jitter) * factor

In [None]:
class ASLDataset(Dataset):
    """Dataset of keypoint sequences stored one-per-file as JSON.

    Every ``*.json`` file in ``json_dir`` is one sample. The class label is the
    part of the file name before the first underscore (naming convention
    ``label_videoId.json``). Each item is returned as a float tensor of shape
    ``(max_seq_length, default_num_keypoints, default_keypoint_dim)`` together
    with the label index as a long tensor.

    Malformed content inside a file (missing/ragged/non-numeric poses, a
    ``keypoints`` entry that is neither a list nor a dict) is replaced by zeros
    instead of raising, so a single bad sample cannot abort a training run.
    """

    def __init__(self, json_dir, max_seq_length=100, default_num_keypoints=31, default_keypoint_dim=4, transform=None):
        """Index the JSON files of ``json_dir`` and build the label vocabulary.

        Args:
            json_dir: directory containing the ``*.json`` keypoint files.
            max_seq_length: number of frames every sequence is padded/trimmed to.
            default_num_keypoints: number of keypoints (rows) enforced per frame.
            default_keypoint_dim: number of values (columns) enforced per keypoint.
            transform: optional callable applied to the normalized NumPy array.
        """
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping stable across runs and machines.
        self.json_files = sorted(
            os.path.join(json_dir, f) for f in os.listdir(json_dir) if f.endswith('.json')
        )
        self.max_seq_length = max_seq_length
        self.default_num_keypoints = default_num_keypoints
        self.default_keypoint_dim = default_keypoint_dim
        self.transform = transform
        # Extract unique labels from filenames (assumes naming "label_videoId.json")
        self.labels = sorted({self._get_label(f) for f in self.json_files})
        self.label_to_idx = {label: idx for idx, label in enumerate(self.labels)}

    def _get_label(self, filepath):
        """Return the label encoded in a file name (text before the first '_')."""
        # Strip the extension first so "label.json" (no underscore) yields
        # "label" rather than "label.json".
        stem = os.path.splitext(os.path.basename(filepath))[0]
        return stem.split('_')[0]

    def _zero_pose(self):
        """Return an all-zero frame of the enforced (keypoints, dim) shape."""
        return np.zeros((self.default_num_keypoints, self.default_keypoint_dim))

    @staticmethod
    def _frame_sort_key(key):
        """Sort key for frame ids: numeric ids in numeric order, then the rest.

        Returning tuples keeps numeric and non-numeric keys comparable; mixing
        bare ints and strs as sort keys raises TypeError.
        """
        text = str(key)
        if text.isdecimal():
            return (0, int(text), '')
        return (1, 0, text)

    def _pad_or_trim(self, keypoints):
        """Zero-pad or truncate ``keypoints`` along axis 0 to ``max_seq_length``."""
        # keypoints: (num_frames, num_keypoints, keypoint_dim)
        num_frames = keypoints.shape[0]
        if num_frames < self.max_seq_length:
            pad_shape = (self.max_seq_length - num_frames, *keypoints.shape[1:])
            padding = np.zeros(pad_shape, dtype=keypoints.dtype)
            keypoints = np.concatenate([keypoints, padding], axis=0)
        elif num_frames > self.max_seq_length:
            keypoints = keypoints[:self.max_seq_length]
        return keypoints

    def process_frame(self, frame):
        """Convert one raw frame into a float array of the enforced shape.

        ``frame`` may be a dict carrying the pose under the ``"pose"`` key or
        the pose itself (a nested list). The pose is zero-padded or trimmed to
        ``(default_num_keypoints, default_keypoint_dim)``. Anything that cannot
        be read as a 2-D numeric array (missing, ``None``, ragged, non-numeric)
        yields an all-zero frame; ``None``/non-finite entries become 0.
        """
        pose_src = frame.get("pose") if isinstance(frame, dict) else frame
        if pose_src is None:
            return self._zero_pose()
        try:
            # dtype=float rejects ragged and non-numeric data up front instead
            # of silently producing an object/string array.
            pose = np.asarray(pose_src, dtype=float)
        except (TypeError, ValueError):
            return self._zero_pose()
        if pose.ndim != 2:
            return self._zero_pose()
        # Missing coordinates (JSON null -> NaN) are treated like padding.
        pose = np.nan_to_num(pose, nan=0.0, posinf=0.0, neginf=0.0)
        # Copy the overlapping region into a zero canvas: pads and trims both
        # axes in one step.
        fixed = self._zero_pose()
        rows = min(pose.shape[0], self.default_num_keypoints)
        cols = min(pose.shape[1], self.default_keypoint_dim)
        fixed[:rows, :cols] = pose[:rows, :cols]
        return fixed

    def __len__(self):
        """Number of JSON files (samples) found in the directory."""
        return len(self.json_files)

    def __getitem__(self, idx):
        """Load, shape, normalize and (optionally) transform sample ``idx``.

        Returns:
            ``(keypoints, label_idx)`` where ``keypoints`` is a float tensor of
            shape (max_seq_length, default_num_keypoints, default_keypoint_dim)
            scaled so its largest absolute value is 1 (before ``transform``),
            and ``label_idx`` is a long scalar tensor.
        """
        json_path = self.json_files[idx]
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        raw_keypoints = data.get('keypoints') if isinstance(data, dict) else None

        # Only the first max_seq_length frames survive trimming, so later
        # frames are not processed at all.
        if isinstance(raw_keypoints, dict):
            # Frames stored as a dict indexed by frame id.
            ordered_keys = sorted(raw_keypoints, key=self._frame_sort_key)[:self.max_seq_length]
            frames = [self.process_frame(raw_keypoints[k]) for k in ordered_keys]
        elif isinstance(raw_keypoints, list):
            frames = [self.process_frame(frame) for frame in raw_keypoints[:self.max_seq_length]]
        else:
            # Missing, or a scalar/string that cannot describe a sequence.
            frames = []

        if frames:
            keypoints = np.stack(frames)
        else:
            keypoints = np.zeros((self.max_seq_length, self.default_num_keypoints, self.default_keypoint_dim))
        # Pad or trim the sequence along the temporal dimension
        keypoints = self._pad_or_trim(keypoints)
        # Normalize the keypoints if non-zero
        if np.any(keypoints):
            max_val = np.max(np.abs(keypoints))
            if max_val > 0:
                keypoints = keypoints / max_val
        if self.transform:
            keypoints = self.transform(keypoints)
        keypoints = torch.tensor(keypoints, dtype=torch.float)
        label_idx = torch.tensor(self.label_to_idx[self._get_label(json_path)], dtype=torch.long)
        return keypoints, label_idx

In [None]:
class STCGN(nn.Module):
    """Convolutional classifier for keypoint sequences.

    Three Conv2d -> BatchNorm2d -> ReLU stages (64, 128, 256 channels) use a
    (3, 1) kernel, i.e. they convolve along the sequence axis only. Dropout,
    adaptive average pooling to a (1, num_keypoints) map and a single linear
    layer then produce the class logits.
    """

    def __init__(self, num_keypoints, keypoint_dim, num_classes):
        """Build the layers.

        Args:
            num_keypoints: width of the pooled feature map (keypoints per frame).
            keypoint_dim: values per keypoint; used as the input channel count.
            num_classes: number of output logits.
        """
        super().__init__()
        # Shared geometry: size-3 kernel over time, padded so length is kept.
        temporal = dict(kernel_size=(3, 1), padding=(1, 0))

        # Stage 1
        self.conv1 = nn.Conv2d(keypoint_dim, 64, **temporal)
        self.bn1 = nn.BatchNorm2d(64)
        # Stage 2
        self.conv2 = nn.Conv2d(64, 128, **temporal)
        self.bn2 = nn.BatchNorm2d(128)
        # Stage 3
        self.conv3 = nn.Conv2d(128, 256, **temporal)
        self.bn3 = nn.BatchNorm2d(256)

        # Heavy dropout to reduce overfitting
        self.dropout = nn.Dropout(0.6)
        # Collapses the time axis to length 1, keeps num_keypoints columns
        self.avg_pool = nn.AdaptiveAvgPool2d((1, num_keypoints))
        self.fc = nn.Linear(256 * num_keypoints, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map (batch, seq_length, num_keypoints, keypoint_dim) to class logits."""
        # Channels-first layout: (batch, keypoint_dim, seq_length, num_keypoints)
        features = x.permute(0, 3, 1, 2)
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            features = self.relu(norm(conv(features)))
        # (batch, 256, 1, num_keypoints) after pooling
        pooled = self.avg_pool(self.dropout(features))
        # Flatten to (batch, 256 * num_keypoints) for the classifier
        flat = pooled.reshape(pooled.size(0), -1)
        return self.fc(flat)

In [None]:
# Cell: Updated Main Block with Transfer Learning, Early Stopping, and Scheduler
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

def main(
    json_dir='/home/haggenmueller/asl_detection/machine_learning/datasets/wlasl/keypoints',
    batch_size=32,
    max_seq_length=100,
    num_epochs=30,
    patience=5,
    use_transfer_learning=True,
    pretrained_path=None,
    seed=42,
):
    """Train an STCGN classifier on the keypoint dataset and save checkpoints.

    The data is split 80/20. Augmentation is applied to the training split
    only; the held-out split is read without it so the reported loss/accuracy
    are not perturbed by random noise. The learning rate is halved when the
    held-out loss plateaus, and training stops early after ``patience`` epochs
    without improvement. The best weights are written to
    ``best_stcgn_model.pth`` and the final weights to ``stcgn_model.pth``.

    Args:
        json_dir: directory with the JSON keypoint files.
        batch_size: mini-batch size for both loaders.
        max_seq_length: frames per sequence after padding/trimming.
        num_epochs: maximum number of epochs.
        patience: early-stopping patience in epochs.
        use_transfer_learning: freeze ``conv1``/``conv2`` after loading
            pretrained weights. Has no effect unless ``pretrained_path`` points
            to a checkpoint that provides those layers, because freezing
            randomly initialised layers would only cripple the model.
        pretrained_path: optional path to a state dict saved by this script.
        seed: seed for NumPy and PyTorch RNGs; ``None`` leaves them untouched.

    Raises:
        ValueError: if the dataset is too small to produce two non-empty splits.
    """
    import copy
    from torch.utils.data import Subset

    if seed is not None:
        # Makes the split, the weight init and the augmentation reproducible.
        np.random.seed(seed)
        torch.manual_seed(seed)

    transform = KeypointAugmentation(noise_std=0.01, scale_range=(0.9, 1.1))
    dataset = ASLDataset(json_dir, max_seq_length=max_seq_length, transform=transform)

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    if train_size == 0 or test_size == 0:
        raise ValueError(
            f"Need at least 2 samples in '{json_dir}' for a train/test split, found {len(dataset)}."
        )
    train_dataset, test_split = random_split(dataset, [train_size, test_size])

    # Same files and label mapping, but without augmentation, for evaluation.
    eval_dataset = copy.copy(dataset)
    eval_dataset.transform = None
    test_dataset = Subset(eval_dataset, test_split.indices)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Determine input dimensions from one (un-augmented) sample
    sample, _ = eval_dataset[0]
    num_keypoints = sample.shape[1]
    keypoint_dim = sample.shape[2]
    num_classes = len(dataset.labels)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = STCGN(num_keypoints, keypoint_dim, num_classes).to(device)

    # Transfer learning: load pretrained weights, then freeze the early convs.
    if use_transfer_learning:
        if pretrained_path and os.path.isfile(pretrained_path):
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            # Assumes the file holds a plain state dict as written by torch.save below.
            checkpoint = torch.load(pretrained_path, map_location=device)
            own_state = model.state_dict()
            # Keep only tensors whose name and shape match (the classifier head
            # may differ when the number of classes changed).
            compatible = {
                name: tensor for name, tensor in checkpoint.items()
                if name in own_state and own_state[name].shape == tensor.shape
            }
            model.load_state_dict(compatible, strict=False)
            backbone_keys = [k for k in own_state if k.startswith(('conv1.', 'conv2.'))]
            if all(k in compatible for k in backbone_keys):
                for layer in (model.conv1, model.conv2):
                    for param in layer.parameters():
                        param.requires_grad = False
            else:
                print("Transfer learning: checkpoint lacks matching conv1/conv2 weights; training all layers.")
        else:
            print("Transfer learning: no pretrained checkpoint given; training all layers.")

    criterion = nn.CrossEntropyLoss()
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=0.0005, weight_decay=1e-4)
    # `verbose` is deprecated/removed in recent PyTorch; LR changes are logged below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    best_loss = float('inf')
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        # Training Phase
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch loss is a true per-sample mean
            running_loss += loss.item() * inputs.size(0)
        train_loss = running_loss / len(train_loader.dataset)

        # Evaluation Phase
        model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                running_loss += loss.item() * inputs.size(0)
                _, preds = torch.max(outputs, 1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        test_loss = running_loss / len(test_loader.dataset)
        test_acc = correct / total

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(test_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after != lr_before:
            print(f'Reducing learning rate: {lr_before:.6g} -> {lr_after:.6g}')
        print(f'Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}')

        # Early Stopping: If test loss does not improve for 'patience' epochs, stop training
        if test_loss < best_loss:
            best_loss = test_loss
            epochs_no_improve = 0
            # Keep the best-performing weights on disk
            torch.save(model.state_dict(), 'best_stcgn_model.pth')
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print("Early Stopping: No improvement.")
                break

    # Save the final model
    torch.save(model.state_dict(), 'stcgn_model.pth')

# Run the training entry point when this code is executed directly
# (the guard is false when the module is imported, so importing does not train).
if __name__ == '__main__':
    main()