In [None]:
# Cell 1: Import libraries
import os
import json
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
import torch.optim as optim

In [None]:
# Cell 2: Define the dataset that reads one JSON keypoint file per sample
class ASLDataset(Dataset):
    """Sign-language samples stored as one JSON keypoint file each.

    Every ``*.json`` file directly inside ``json_dir`` becomes one sample. Its
    ``"keypoints"`` entry is turned into a numpy array that is assumed to be
    laid out as ``(num_frames, num_keypoints, keypoint_dim)`` -- NOTE(review):
    confirm against the keypoint extractor. The class label is parsed from the
    file name, which is assumed to follow ``"<label>_<videoId>.json"``.

    Args:
        json_dir: Directory scanned (non-recursively) for ``.json`` files.
        max_seq_length: Sequences are zero-padded or truncated to this many frames.
        default_num_keypoints: Keypoints per frame of the all-zero placeholder
            used when a file has no (or empty) ``"keypoints"``.
        default_keypoint_dim: Coordinates per keypoint of that placeholder.
        transform: Optional callable applied to the normalized numpy array
            before it is converted to a tensor.

    Items are ``(keypoints, label_idx)`` where ``keypoints`` is a float tensor
    and ``label_idx`` indexes into ``self.labels`` (sorted label names).
    """

    def __init__(self, json_dir, max_seq_length=100, default_num_keypoints=17, default_keypoint_dim=2, transform=None):
        # Sorted so the index -> file mapping is stable across runs and machines;
        # os.listdir order is arbitrary, which would make even a seeded split
        # pick different files on different filesystems.
        self.json_files = sorted(
            os.path.join(json_dir, f) for f in os.listdir(json_dir) if f.endswith('.json')
        )
        self.max_seq_length = max_seq_length
        self.default_num_keypoints = default_num_keypoints
        self.default_keypoint_dim = default_keypoint_dim
        self.transform = transform
        # Unique labels from file names, sorted so class indices are deterministic.
        self.labels = sorted({self._get_label(f) for f in self.json_files})
        self.label_to_idx = {label: idx for idx, label in enumerate(self.labels)}

    def _get_label(self, filepath):
        """Return the label prefix of ``filepath`` (text before the first '_').

        NOTE(review): a label that itself contains '_' would be truncated here;
        confirm labels never contain underscores.
        """
        filename = os.path.basename(filepath)
        return filename.split('_')[0]

    def _pad_or_trim(self, keypoints):
        """Zero-pad (at the end) or truncate axis 0 to ``max_seq_length`` frames."""
        num_frames = keypoints.shape[0]
        if num_frames < self.max_seq_length:
            pad_shape = (self.max_seq_length - num_frames, *keypoints.shape[1:])
            padding = np.zeros(pad_shape, dtype=keypoints.dtype)
            keypoints = np.concatenate([keypoints, padding], axis=0)
        elif num_frames > self.max_seq_length:
            keypoints = keypoints[:self.max_seq_length]
        return keypoints

    def __len__(self):
        return len(self.json_files)

    def __getitem__(self, idx):
        """Load, pad/trim, normalize and tensorize sample ``idx``."""
        json_path = self.json_files[idx]
        # JSON is UTF-8 by specification; don't depend on the platform default.
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        if 'keypoints' not in data or not data['keypoints']:
            # Placeholder for files without detections.
            # NOTE(review): its shape comes from the default_* arguments, so it
            # only batches with real samples if those match the real layout.
            keypoints = np.zeros((self.max_seq_length, self.default_num_keypoints, self.default_keypoint_dim))
        else:
            keypoints = np.array(data['keypoints'])
        keypoints = self._pad_or_trim(keypoints)
        # Per-sample scaling into [-1, 1] by the largest absolute coordinate;
        # all-zero sequences are left untouched to avoid dividing by zero.
        if np.any(keypoints):
            max_val = np.max(np.abs(keypoints))
            if max_val > 0:
                keypoints = keypoints / max_val
        label_idx = self.label_to_idx[self._get_label(json_path)]
        if self.transform:
            keypoints = self.transform(keypoints)
        return (
            torch.tensor(keypoints, dtype=torch.float),
            torch.tensor(label_idx, dtype=torch.long),
        )

In [None]:
# Cell 3: Define the STCGN model (baseline: per-joint 1x1 conv, temporal mean pooling, linear classifier; no graph convolution yet)
class STCGN(nn.Module):
    """Minimal keypoint-sequence classifier.

    A 1x1 ``Conv2d`` maps each joint's ``keypoint_dim`` coordinates to 64
    features independently per frame, followed by ReLU, a mean over the frame
    axis, and a linear layer over the concatenated per-joint features.
    Despite the name, no graph convolution over the skeleton is performed.

    Args:
        num_keypoints: Joints per frame; fixes the input width of ``fc``.
        keypoint_dim: Coordinates per joint; input channels of ``conv``.
        num_classes: Number of output logits.
    """

    def __init__(self, num_keypoints, keypoint_dim, num_classes):
        super().__init__()
        hidden_channels = 64
        # conv/relu/fc names are the state_dict keys of saved checkpoints -- keep them.
        self.conv = nn.Conv2d(in_channels=keypoint_dim, out_channels=hidden_channels, kernel_size=(1, 1))
        self.relu = nn.ReLU()
        self.fc = nn.Linear(hidden_channels * num_keypoints, num_classes)

    def forward(self, x):
        """Map ``(batch, seq_length, num_keypoints, keypoint_dim)`` input to
        ``(batch, num_classes)`` logits.

        NOTE(review): zero-padded frames take part in the temporal mean; because
        the conv has a bias they contribute ``relu(bias)`` rather than zero, so
        the amount of padding influences the output. Consider masking them.
        """
        # Conv2d expects channels first: (batch, keypoint_dim, seq_length, num_keypoints).
        channels_first = x.permute(0, 3, 1, 2)
        per_joint = self.relu(self.conv(channels_first))
        # Collapse time -> (batch, 64, num_keypoints), then one vector per sample.
        pooled = per_joint.mean(dim=2)
        flat = pooled.reshape(pooled.shape[0], -1)
        return self.fc(flat)

In [None]:
# Cell 4: Define training and evaluation functions
def train_model(model, dataloader, criterion, optimizer, device):
    """Run one training epoch over ``dataloader``.

    Args:
        model: Module to train; switched to ``train()`` mode.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss function; assumed to use mean reduction, since each
            batch loss is re-weighted by its batch size.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        Sample-weighted mean training loss over the samples actually yielded
        (correct even with ``drop_last=True`` or a subset sampler), or ``nan``
        if the loader yielded no samples.
    """
    model.train()
    running_loss = 0.0
    num_samples = 0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        num_samples += batch_size
    # Normalize by samples seen, not len(dataloader.dataset), which over-counts
    # when batches are dropped or the sampler covers only part of the dataset.
    if num_samples == 0:
        return float('nan')
    return running_loss / num_samples

def evaluate_model(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Args:
        model: Module to evaluate; switched to ``eval()`` mode.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss function; assumed to use mean reduction, since each
            batch loss is re-weighted by its batch size.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)``, both computed over the samples actually
        yielded by the loader; ``(nan, nan)`` if it yielded none.
    """
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size
    # Loss and accuracy share one denominator (samples seen); the dataset
    # length would disagree with it under drop_last or a subset sampler.
    if total == 0:
        return float('nan'), float('nan')
    return running_loss / total, correct / total


In [None]:
# Cell 5: Main training loop
def main(json_dir='/home/haggenmueller/asl_detection/machine_learning/datasets/wlasl/keypoints',
         batch_size=32, max_seq_length=100, num_epochs=10, lr=0.001, seed=42,
         output_path='st_cgn_model.pth'):
    """Train an STCGN classifier on a directory of keypoint JSON files.

    Builds ``ASLDataset(json_dir)`` and splits it 80/20 into train/test. It
    then trains with Adam and cross-entropy for ``num_epochs``, printing the
    train loss, test loss and test accuracy after each epoch. Finally it
    saves the trained ``state_dict`` to ``output_path``. Output logit ``i``
    corresponds to ``dataset.labels[i]``.

    Args:
        json_dir: Directory of ``<label>_<videoId>.json`` keypoint files. The
            default is the original author's absolute local path -- pass your own.
        batch_size: Mini-batch size for both loaders.
        max_seq_length: Frames each sample is padded/truncated to.
        num_epochs: Number of training epochs.
        lr: Adam learning rate.
        seed: Seeds torch's global RNG (weight init, shuffling) and the
            train/test split so reruns are repeatable; ``None`` disables seeding.
        output_path: File the trained ``state_dict`` is written to.

    Raises:
        ValueError: if fewer than two samples are found, since the split would
            leave the training set empty.
    """
    split_kwargs = {}
    if seed is not None:
        torch.manual_seed(seed)
        # Dedicated generator so the split does not depend on other RNG consumers.
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)

    dataset = ASLDataset(json_dir, max_seq_length=max_seq_length)
    if len(dataset) < 2:
        raise ValueError(
            f'Need at least 2 .json keypoint files in {json_dir!r} for a train/test split, found {len(dataset)}'
        )
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size], **split_kwargs)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Input dimensions are read from the first sample.
    # NOTE(review): if that file had no keypoints, this is the placeholder shape
    # built from ASLDataset's default_* arguments, not the real layout.
    sample, _ = dataset[0]
    num_keypoints = sample.shape[1]
    keypoint_dim = sample.shape[2]
    num_classes = len(dataset.labels)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = STCGN(num_keypoints, keypoint_dim, num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        train_loss = train_model(model, train_loader, criterion, optimizer, device)
        test_loss, test_acc = evaluate_model(model, test_loader, criterion, device)
        print(f'Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}')

    torch.save(model.state_dict(), output_path)

In [None]:
# Cell 6: Execute the main function
# Runs training end to end: main() prints per-epoch metrics and writes the weights to st_cgn_model.pth.
main()