In [31]:
import os
import csv
import json
import numpy as np
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

import matplotlib.pyplot as plt

In [2]:
# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Running on {device}")

Running on cpu


In [3]:
# Locate the stroke CSV: it sits in a `training_data` directory that is a
# sibling of the current working directory.
stroke_data_filename = 'normalized_strokes.csv'

# abspath resolves the relative path against the working directory and
# collapses the '..' component.
cousin_directory_path = os.path.abspath(os.path.join('..', 'training_data'))
csv_file_path = os.path.join(cousin_directory_path, stroke_data_filename)

The dataset stores each label as a string. We build a mapping from each unique label string to a unique integer index.

In [4]:
# Collect every distinct label string found in column 1 of the CSV.
unique_labels = set()

# newline='' is what the csv module asks for, so that line breaks inside
# quoted fields are handed to the parser untouched.
with open(csv_file_path, mode='r', newline='') as file:
    csv_reader = csv.reader(file)
    # Skip header; the default keeps an empty file from raising StopIteration.
    next(csv_reader, None)
    for row in csv_reader:
        if not row:
            # csv.reader yields [] for a blank line; there is no label to read.
            continue
        label = row[1]
        unique_labels.add(label)

# Sorting fixes the label -> index assignment, so it is the same on every run.
sorted_labels = sorted(unique_labels)

# Forward (string -> int) and reverse (int -> string) lookups.
label_to_index = {label: index for index, label in enumerate(sorted_labels)}
index_to_label = {index: label for label, index in label_to_index.items()}

Create dataset for stroke data

In [10]:
class StrokeDataset(Dataset):
    """In-memory dataset of strokes and integer labels read from a CSV file.

    Every data row is expected to hold the label string in column 1 and a
    JSON-encoded list of stroke points in column 2; column 0 is not read.
    """

    def __init__(self, csv_file, label_mapping):
        """
        Args:
            csv_file (str): Path to the csv file with stroke data.
            label_mapping (dict): Mapping from string labels to integer indices.

        Raises:
            KeyError: If a row's label is not a key of ``label_mapping``.
            json.JSONDecodeError: If a row's stroke column is not valid JSON.
        """
        super().__init__()
        self.data = []
        self.labels = []
        self.label_mapping = label_mapping

        # newline='' is what the csv module asks for, so that line breaks
        # inside quoted fields are handed to the parser untouched.
        with open(csv_file, mode='r', newline='') as file:
            csv_reader = csv.reader(file)
            # Skip header; the default keeps an empty file from raising
            # StopIteration and simply yields an empty dataset instead.
            next(csv_reader, None)
            for row in csv_reader:
                if not row:
                    # csv.reader yields [] for a blank line; nothing to load.
                    continue
                label_str = row[1]

                # Decode the JSON text into nested Python lists of numbers.
                stroke_points = json.loads(row[2])
                self.data.append(stroke_points)
                self.labels.append(self.label_mapping[label_str])

    def __len__(self):
        """Return the number of strokes that were loaded."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(stroke, label)`` for the sample at ``index``.

        ``stroke`` is a fresh float32 NumPy array built from the stored
        points; ``label`` is the value taken from ``label_mapping``.
        """
        stroke = np.array(self.data[index], dtype=np.float32)
        label = self.labels[index]

        return stroke, label

In [26]:
class SubsetStrokeDataset(Dataset):
    """Dataset view over stroke data and labels that are already in memory.

    Args:
        data: Indexable collection of strokes (one entry per sample).
        labels: Indexable collection of labels, aligned with ``data``.
    """

    def __init__(self, data, labels):
        super().__init__()
        self.data, self.labels = data, labels

    def __len__(self):
        """Number of samples, taken from the length of ``data``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the stroke at ``index`` as a float32 array, plus its label."""
        return np.array(self.data[index], dtype=np.float32), self.labels[index]

In [14]:
# Read every stroke and its integer label from the CSV into memory.
dataset = StrokeDataset(csv_file=csv_file_path, label_mapping=label_to_index)

In [29]:
# Target share of the data for each partition.
train_size = 0.7
val_size = 0.15
test_size = 0.15

# First cut: peel off the training portion, stratified on the labels. The
# remainder is a temporary pool that is divided again further down.
train_data, temp_data, train_labels, temp_labels = train_test_split(
    dataset.data,
    dataset.labels,
    train_size=train_size,
    stratify=dataset.labels,
    random_state=42,
)

# Find the classes that are left with fewer than two samples in the pool.
counter_temp = Counter(temp_labels)
insufficient_classes = [cls for cls in counter_temp if counter_temp[cls] < 2]
print("Classes with fewer than 2 samples in temp_labels:", insufficient_classes)

if insufficient_classes:
    # Positions (in ascending order) of the pool samples from those classes.
    indices_to_move = [
        i for i, label in enumerate(temp_labels)
        if label in insufficient_classes
    ]

    # Hand those samples over to the training set. Going from the highest
    # position to the lowest means each pop leaves the remaining positions valid.
    for idx in reversed(indices_to_move):
        train_data.append(temp_data.pop(idx))
        train_labels.append(temp_labels.pop(idx))

    print(f"Moved {len(indices_to_move)} samples to the training set to ensure all classes in temp_labels have at least 2 samples.")

# Second cut: divide the pool into validation and test, again stratified.
# The fraction is expressed relative to the pool, not the full dataset.
val_fraction = val_size / (val_size + test_size)
val_data, test_data, val_labels, test_labels = train_test_split(
    temp_data,
    temp_labels,
    train_size=val_fraction,
    stratify=temp_labels,
    random_state=42,
)

# Wrap each partition so it can be fed to a DataLoader.
train_dataset = SubsetStrokeDataset(data=train_data, labels=train_labels)
val_dataset = SubsetStrokeDataset(data=val_data, labels=val_labels)
test_dataset = SubsetStrokeDataset(data=test_data, labels=test_labels)

print(f"Final dataset sizes:\n"
      f" - Training: {len(train_dataset)} samples\n"
      f" - Validation: {len(val_dataset)} samples\n"
      f" - Testing: {len(test_dataset)} samples")

Classes with fewer than 2 samples in temp_labels: [729, 237, 330, 845, 208, 170, 955, 286, 219, 1076, 829, 337, 154, 55, 910, 341, 673, 892, 963, 789, 805, 972, 117, 1082, 152, 272, 859, 216, 971, 228, 383, 848, 98]
Moved 33 samples to the training set to ensure all classes in temp_labels have at least 2 samples.
Final dataset sizes:
 - Training: 147350 samples
 - Validation: 31552 samples
 - Testing: 31552 samples


In [32]:
def collate_fn(batch):
    """
    Collate variable-length strokes into one zero-padded batch tensor.

    Args:
        batch (List[Tuple]): Each tuple contains (stroke, label).

    Returns:
        padded_strokes (torch.Tensor): float32 tensor with the batch dimension
            first; every stroke is padded with zeros up to the length of the
            longest stroke in the batch.
        labels (torch.Tensor): int64 tensor of shape (batch_size,).
    """
    # Split the list of (stroke, label) pairs into two parallel tuples.
    stroke_arrays, label_values = zip(*batch)

    stroke_tensors = []
    for stroke_array in stroke_arrays:
        stroke_tensors.append(torch.tensor(stroke_array, dtype=torch.float32))

    # Shorter strokes are filled with 0.0 at the end so all lengths match.
    padded_strokes = pad_sequence(stroke_tensors, batch_first=True, padding_value=0.0)

    return padded_strokes, torch.tensor(label_values, dtype=torch.long)

In [38]:
# Create dataloaders
batch_size = 32

# Arguments that are identical for all three loaders.
loader_kwargs = dict(batch_size=batch_size, collate_fn=collate_fn)

# Training: reshuffle every epoch and drop the final incomplete batch.
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_kwargs)

# Validation and test: fixed order, and every sample is kept.
val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, drop_last=False, **loader_kwargs)