In [None]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, random_split

from AudioKeystrokeDataset.AudioKeystrokeDataset import AudioKeystrokeDataset
# from CoatNet.CoatNet import CoAtNet
# from CoatNet.Trainer import Trainer

## CoAtNet model

In [None]:
class MBConv(nn.Module):
    """
    Mobile inverted bottleneck (MBConv) block.

    Pipeline: 1x1 expansion -> depthwise KxK convolution -> 1x1 projection.
    The projection has no activation. Each ReLU is followed by Dropout(0.1).

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the 1x1 projection.
        kernel_size: depthwise kernel size; padding is kernel_size // 2.
        stride: stride of the depthwise convolution.
        expand_ratio: hidden width multiplier (hidden = in_channels * expand_ratio).

    An identity shortcut is added only when the block preserves both the channel
    count and the spatial resolution (stride == 1 and in_channels == out_channels).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, expand_ratio=4):
        super().__init__()
        width = in_channels * expand_ratio
        self.use_residual = stride == 1 and in_channels == out_channels

        expand = [
            nn.Conv2d(in_channels, width, kernel_size=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        ]
        depthwise = [
            nn.Conv2d(width, width, kernel_size=kernel_size, stride=stride,
                      padding=kernel_size // 2, groups=width, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        ]
        project = [
            nn.Conv2d(width, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]
        # One flat Sequential keeps the state_dict keys at conv.0 ... conv.9,
        # so saved checkpoints load unchanged.
        self.conv = nn.Sequential(*expand, *depthwise, *project)

    def forward(self, x):
        y = self.conv(x)
        return y + x if self.use_residual else y

class TransformerBlock(nn.Module):
    """
    Pre-norm transformer block applied to a 2-D feature map.

    A (B, C, H, W) input is flattened into H*W tokens of C features each. It then
    passes through two residual sub-blocks:
      * LayerNorm -> multi-head self-attention -> dropout
      * LayerNorm -> feed-forward (Linear-GELU-Linear, 4x width) -> dropout
    The result is reshaped back to (B, C, H, W).

    Args:
        embed_dim: token feature size; must equal the input channel count C.
        num_heads: number of attention heads (must divide embed_dim).
        dropout: dropout rate inside attention and on both residual branches.
    """
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, C, H, W = x.shape
        # (B, C, H, W) -> (S, B, C) with S = H*W: nn.MultiheadAttention is built
        # with the default batch_first=False, so it expects sequence-first input.
        tokens = x.flatten(2).permute(2, 0, 1)
        # Normalise once and reuse the result as query, key and value.
        normed = self.norm1(tokens)
        # Attention weights are not used, so skip computing/averaging them.
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        tokens = tokens + self.dropout(attn_out)
        tokens = tokens + self.dropout(self.ff(self.norm2(tokens)))
        # The permuted tensor is non-contiguous; reshape handles that safely.
        return tokens.permute(1, 2, 0).reshape(B, C, H, W)

##############################################
# Define the Modified CoAtNet Architecture
##############################################

class CoAtNet(nn.Module):
    """
    Compact CoAtNet-style classifier.

    Pipeline: conv stem -> MBConv stage -> transformer stage -> MBConv stage ->
    global average pooling -> linear head.

    Args:
        num_classes: size of the output logit vector.
        in_channels: channel count of the input image.

    Input is (B, in_channels, H, W) and output is (B, num_classes) logits. The
    stride-2 layers (the stem and the first MBConv of stage1 and of stage3) each
    roughly halve H and W.
    """
    def __init__(self, num_classes=36, in_channels=1):
        super().__init__()
        # Stem: stride-2 3x3 convolution to 32 channels.
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        # Stage 1 (convolution): downsample and widen to 64 channels.
        self.stage1 = nn.Sequential(MBConv(32, 64, stride=2), MBConv(64, 64, stride=1))
        # Stage 2 (attention): self-attention across all spatial positions.
        self.stage2 = nn.Sequential(
            TransformerBlock(embed_dim=64, num_heads=4),
            TransformerBlock(embed_dim=64, num_heads=4),
        )
        # Stage 3 (convolution): downsample and widen to 128 channels.
        self.stage3 = nn.Sequential(MBConv(64, 128, stride=2), MBConv(128, 128, stride=1))
        # Head: global average pool to (B, 128, 1, 1), then a linear classifier.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        for stage in (self.stem, self.stage1, self.stage2, self.stage3):
            x = stage(x)
        pooled = self.pool(x)
        return self.fc(pooled.view(pooled.size(0), -1))

## Trainer

In [None]:
class Trainer:
    """
    Supervised training loop with per-epoch validation, best-model checkpointing,
    optional LR scheduling and early stopping on validation accuracy.

    Args:
        model: ``nn.Module`` that maps a batch to class logits; moved to ``device``.
        train_loader, val_loader: iterables yielding ``(data, targets)`` batches.
            3-D ``data`` batches get a channel axis inserted at dim 1.
        criterion: loss called as ``criterion(outputs, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch is moved to.
        scheduler: optional LR scheduler.
            * ``ReduceLROnPlateau`` is stepped with the validation accuracy, so
              build it with ``mode='max'``.
            * Any other scheduler is stepped once per epoch without arguments.
        early_stopping_patience: consecutive epochs without a validation-accuracy
            improvement after which ``train`` stops.
    """
    def __init__(self, model, train_loader, val_loader, criterion, optimizer, device,
                 scheduler=None, early_stopping_patience=10):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.scheduler = scheduler
        self.early_stopping_patience = early_stopping_patience

    def _run_epoch(self, loader, training):
        """
        Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

        With ``training=True`` the model is in train mode and every batch takes an
        optimizer step. Otherwise the model is in eval mode and gradients are off.

        Raises:
            ValueError: if ``loader`` yields no samples (the means are undefined).
        """
        self.model.train(training)
        running_loss = 0.0
        correct = 0
        total = 0
        grad_mode = torch.enable_grad() if training else torch.no_grad()
        with grad_mode:
            for data, targets in loader:
                data, targets = data.to(self.device), targets.to(self.device)
                if data.dim() == 3:
                    # Insert a channel axis at dim 1 (presumably (B, H, W) -> (B, 1, H, W)).
                    data = data.unsqueeze(1)
                if training:
                    self.optimizer.zero_grad()
                outputs = self.model(data)
                loss = self.criterion(outputs, targets)
                if training:
                    loss.backward()
                    self.optimizer.step()
                batch_size = targets.size(0)
                # Assumes criterion returns a batch mean. Weighting by batch size makes
                # the epoch loss a per-sample mean even with a short final batch.
                running_loss += loss.item() * batch_size
                correct += (outputs.argmax(dim=1) == targets).sum().item()
                total += batch_size
        if total == 0:
            raise ValueError("Cannot compute epoch metrics: the data loader yielded no samples.")
        return running_loss / total, correct / total

    def train_epoch(self):
        """Run one optimisation pass over the training loader; return (mean loss, accuracy)."""
        return self._run_epoch(self.train_loader, training=True)

    def validate(self):
        """Evaluate on the validation loader without gradients; return (mean loss, accuracy)."""
        return self._run_epoch(self.val_loader, training=False)

    def _step_scheduler(self, val_acc):
        """Advance the LR scheduler (if any) at the end of an epoch."""
        if self.scheduler is None:
            return
        if isinstance(self.scheduler, ReduceLROnPlateau):
            # Plateau schedulers need the monitored metric.
            self.scheduler.step(val_acc)
        else:
            # Epoch-based schedulers take no metric; a positional argument would be
            # interpreted as an epoch number.
            self.scheduler.step()

    def _save_checkpoint(self, save_path, epoch, best_val_acc):
        """Write model, optimizer and (if present) scheduler state to ``save_path``."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_val_acc': best_val_acc,
        }
        if self.scheduler is not None:
            checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
        torch.save(checkpoint, save_path)

    def _load_checkpoint(self, load_path):
        """
        Restore state saved by ``_save_checkpoint``.

        Returns ``(start_epoch, best_val_acc)``. Returns ``(0, 0.0)`` and prints a
        notice when ``load_path`` does not exist.
        """
        if not os.path.exists(load_path):
            print(f"Checkpoint {load_path} not found; training from scratch.")
            return 0, 0.0
        # map_location lets a checkpoint saved on another device (e.g. CUDA) load here.
        checkpoint = torch.load(load_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        # Checkpoints written before optimizer/scheduler state was saved hold only
        # the model weights, so those entries are optional.
        if 'optimizer_state_dict' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if self.scheduler is not None and 'scheduler_state_dict' in checkpoint:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint.get('epoch', 0)
        print(f"Resuming training from epoch {start_epoch}")
        return start_epoch, checkpoint.get('best_val_acc', 0.0)

    def train(self, num_epochs, save_path=None, resume=False, load_path=None):
        """
        Train for up to ``num_epochs`` epochs, validating after each one.

        Args:
            num_epochs: epochs to run, counted from the resume point.
            save_path: if given, a checkpoint is written there whenever validation
                accuracy improves. Its ``epoch`` entry is the number of completed
                epochs, so resuming continues with the next one.
            resume: if True and ``load_path`` is given, restore state from it first.
            load_path: checkpoint file to resume from.
        """
        start_epoch = 0
        best_val_acc = 0.0
        patience_counter = 0

        if resume and load_path:
            start_epoch, best_val_acc = self._load_checkpoint(load_path)

        end_epoch = start_epoch + num_epochs
        for epoch in range(start_epoch, end_epoch):
            train_loss, train_acc = self.train_epoch()
            val_loss, val_acc = self.validate()
            print(f"Epoch {epoch+1}/{end_epoch} | "
                  f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

            self._step_scheduler(val_acc)

            # Early stopping / checkpointing are driven by validation accuracy.
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                patience_counter = 0
                if save_path:
                    self._save_checkpoint(save_path, epoch + 1, best_val_acc)
                    print(f"Saved best model at epoch {epoch+1} with Val Acc: {best_val_acc:.4f}")
            else:
                patience_counter += 1
                if patience_counter >= self.early_stopping_patience:
                    print(f"Early stopping triggered at epoch {epoch+1}. No improvement in validation accuracy for {self.early_stopping_patience} epochs.")
                    break

## Utils

In [57]:
# Prefer CUDA, then Apple-silicon MPS, then fall back to CPU.
# (The previous one-liner ended in `"cpu" "cpu"`, which Python concatenates to
# the invalid device string "cpucpu" on machines with neither accelerator.)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: mps


In [None]:
import json

# Run configuration, resolved relative to the current working directory.
CONFIG_PATH = 'config.json'

# JSON text is UTF-8 (RFC 8259). Pass the encoding explicitly so non-ASCII
# paths do not depend on the platform's default locale encoding.
with open(CONFIG_PATH, 'r', encoding='utf-8') as config_file:
    config = json.load(config_file)

# Dataset location used by this notebook: the "all" entry of DATASET_PATH.
DATASET_PATH = config['DATASET_PATH']['all']

## 1. Create the Dataset for Training

In [None]:
# Build the keystroke dataset from DATASET_PATH with full_dataset=True.
# This is slow: it processes every audio file (several minutes in the recorded run).
dataset = AudioKeystrokeDataset(
    DATASET_PATH,
    full_dataset=True,
)
print(f"Dataset contains {len(dataset)} keystroke samples.")

Processing Audio Files: 100%|██████████| 412/412 [04:51<00:00,  1.41it/s]

Dataset contains 14421 keystroke samples.





In [None]:
# Split into ~80% train / ~10% validation / ~10% test.
# The fixed-seed generator makes the partition reproducible across kernel
# restarts. Without it, a resumed run or a later test-set evaluation would see a
# different split, and samples already used for training could land in val/test.
SPLIT_SEED = 42

dataset_size = len(dataset)
train_size = int(0.8 * dataset_size)
holdout_size = dataset_size - train_size
val_size = int(0.5 * holdout_size)
test_size = holdout_size - val_size

split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)

print(f"Training dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Testing dataset size: {len(test_dataset)}")

Training dataset size: 11536
Validation dataset size: 1442
Testing dataset size: 1443


In [None]:
# Number of distinct labels returned by dataset.get_labels().
label_set = set(dataset.get_labels())
num_classes = len(label_set)
num_classes

73

In [None]:
BATCH_SIZE = 32
LEARNING_RATE = 1e-3

# Shuffle training batches every epoch; keep validation order fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Classifier head is sized from dataset.label2idx (presumably the label -> index
# mapping used to encode targets).
model = CoAtNet(num_classes=len(dataset.label2idx), in_channels=1).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# mode='max' because the Trainer steps this scheduler with validation accuracy.
# The LR is multiplied by 0.1 once accuracy plateaus (patience=5).
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=5, verbose=True)

trainer = Trainer(model, train_loader, val_loader, criterion, optimizer, device,
                  scheduler=scheduler)

In [None]:
# Train for up to 100 epochs; early stopping may end the run sooner. Whenever
# validation accuracy improves, a checkpoint is written to 100_model_opt.pth.
trainer.train(
    num_epochs=100,
    save_path='100_model_opt.pth',
    resume=False,
)

Epoch 1/100 | Train Loss: 4.2848, Train Acc: 0.0189 | Val Loss: 4.2388, Val Acc: 0.0229
Saved best model at epoch 1 with Val Acc: 0.0229
Epoch 2/100 | Train Loss: 4.1334, Train Acc: 0.0370 | Val Loss: 4.0254, Val Acc: 0.0479
Saved best model at epoch 2 with Val Acc: 0.0479
Epoch 3/100 | Train Loss: 3.8140, Train Acc: 0.0719 | Val Loss: 3.6795, Val Acc: 0.0818
Saved best model at epoch 3 with Val Acc: 0.0818
Epoch 4/100 | Train Loss: 3.4052, Train Acc: 0.1249 | Val Loss: 3.0912, Val Acc: 0.1574
Saved best model at epoch 4 with Val Acc: 0.1574
Epoch 5/100 | Train Loss: 3.0433, Train Acc: 0.1808 | Val Loss: 2.7672, Val Acc: 0.2178
Saved best model at epoch 5 with Val Acc: 0.2178
Epoch 6/100 | Train Loss: 2.6948, Train Acc: 0.2347 | Val Loss: 2.5915, Val Acc: 0.2587
Saved best model at epoch 6 with Val Acc: 0.2587
Epoch 7/100 | Train Loss: 2.4232, Train Acc: 0.2961 | Val Loss: 2.2753, Val Acc: 0.3155
Saved best model at epoch 7 with Val Acc: 0.3155
Epoch 8/100 | Train Loss: 2.1893, Train A