In [1]:
%env CUDA_VISIBLE_DEVICES=1

env: CUDA_VISIBLE_DEVICES=1


In [None]:
from AudioKeystrokeDataset.AudioKeystrokeDataset import AudioKeystrokeDataset
from CoatNet.CoatNet import CoAtNet
from CoatNet.Trainer import Trainer

import os
import json

from torch.utils.data import random_split, DataLoader

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau

## Utils

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

with open('config.json', 'r') as f:
    config = json.load(f)

DATASET_PATH = config['DATASET_PATH']['all']

Using device: cuda


## 1. Create Dataset fot Training

In [20]:
transform = transforms.Compose([
    transforms.ToPILImage(),           
    transforms.Resize((224, 224)),    
    transforms.ToTensor(),           
])

dataset = AudioKeystrokeDataset(DATASET_PATH, full_dataset=True, transform=transform)
print(f"Dataset contains {len(dataset)} keystroke samples.")

Processing Audio Files: 100%|██████████| 412/412 [05:39<00:00,  1.22it/s]

Dataset contains 14421 keystroke samples.





In [21]:
dataset_size = len(dataset)
train_size = int(0.8 * dataset_size)
val_size = dataset_size - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

val_dataset_size = len(val_dataset)
val_size = int(0.5 * val_dataset_size)
test_size = val_dataset_size - val_size
val_dataset, test_dataset = random_split(val_dataset, [val_size, test_size])

print(f"Training dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Testing dataset size: {len(test_dataset)}")

num_classes = set(dataset.get_labels())

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

Training dataset size: 11536
Validation dataset size: 1442
Testing dataset size: 1443


## 2. Create Model

In [None]:
model = CoAtNet(num_classes=len(dataset.label2idx), in_channels=1)
model = model.to(device)



In [23]:
class Trainer:
    def __init__(self, model, train_loader, val_loader, criterion, optimizer, device, scheduler=None, early_stopping_patience=10):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.scheduler = scheduler 
        self.early_stopping_patience = early_stopping_patience

    def train_epoch(self):
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for data, targets in self.train_loader:
            # Move data to device and ensure it has a channel dimension
            data, targets = data.to(self.device), targets.to(self.device)
            # Ensure data has 4 dimensions [B, C, H, W]
            if data.ndim == 3:  # [B, H, W]
                data = data.unsqueeze(1)  # [B, 1, H, W]
            # If still only 1 channel, repeat to make 3 channels
            if data.shape[1] == 1:
                data = data.repeat(1, 3, 1, 1)  # [B, 3, H, W]
            self.optimizer.zero_grad()
            outputs = self.model(data)
            loss = self.criterion(outputs, targets)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item() * data.size(0)
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
        epoch_loss = running_loss / total
        epoch_acc = correct / total
        return epoch_loss, epoch_acc

    def validate(self):
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for data, targets in self.val_loader:
                data, targets = data.to(self.device), targets.to(self.device)
                # Ensure data has 4 dimensions [B, C, H, W]
                if data.ndim == 3:  # [B, H, W]
                    data = data.unsqueeze(1)  # [B, 1, H, W]
                # If still only 1 channel, repeat to make 3 channels
                if data.shape[1] == 1:
                    data = data.repeat(1, 3, 1, 1)  # [B, 3, H, W]
                outputs = self.model(data)
                loss = self.criterion(outputs, targets)
                running_loss += loss.item() * data.size(0)
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        epoch_loss = running_loss / total
        epoch_acc = correct / total
        return epoch_loss, epoch_acc

    def train(self, num_epochs, save_path=None, resume=False, load_path=None):
        start_epoch = 0
        best_val_acc = 0.0
        patience_counter = 0
        
        if resume and load_path and os.path.exists(load_path):
            checkpoint = torch.load(load_path)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            start_epoch = checkpoint.get('epoch', 0)
            best_val_acc = checkpoint.get('best_val_acc', 0.0)
            print(f"Resuming training from epoch {start_epoch}")
            
        for epoch in range(start_epoch, start_epoch + num_epochs):
            train_loss, train_acc = self.train_epoch()
            val_loss, val_acc = self.validate()
            print(f"Epoch {epoch+1}/{start_epoch + num_epochs} | "
                  f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
            
            # Step the scheduler if provided (using val_acc as the metric)
            if self.scheduler:
                self.scheduler.step(val_acc)
            
            # Early stopping check based on validation accuracy
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                patience_counter = 0
                if save_path:
                    torch.save({
                        'epoch': epoch + 1,
                        'model_state_dict': self.model.state_dict(),
                        'best_val_acc': best_val_acc
                    }, save_path)
                    print(f"Saved best model at epoch {epoch+1} with Val Acc: {best_val_acc:.4f}")
            else:
                patience_counter += 1
                if patience_counter >= self.early_stopping_patience:
                    print(f"Early stopping triggered at epoch {epoch+1}. No improvement in validation accuracy for {self.early_stopping_patience} epochs.")
                    break

In [24]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=5, verbose=True)

trainer = Trainer(model, train_loader, val_loader, criterion, optimizer, device, scheduler, early_stopping_patience=20)



In [None]:
trainer.train(num_epochs=10, save_path='models/vit.pth', resume=False)

## 3. Evaluate

In [None]:
model.eval()
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
correct = 0
total = 0
with torch.no_grad():
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        if len(data.shape) == 3:
            data = data.unsqueeze(1)
        outputs = model(data)
        _, predicted = torch.max(outputs, 1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()
test_accuracy = correct / total
print(f"Test Accuracy: {test_accuracy:.4f}")

Test Accuracy: 0.7970
