In [None]:
from AudioKeystrokeDataset.AudioKeystrokeDataset import AudioKeystrokeDataset
from CoatNet.CoatNet import CoAtNet

from torch.utils.data import random_split, DataLoader

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms

## Utils

In [59]:
# Pick the compute backend: use "mps" when PyTorch reports it as available,
# otherwise fall back to the CPU.
device_name = "mps" if torch.backends.mps.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

Using device: mps


In [60]:
import json

# Parse the JSON configuration file that sits next to the notebook.
with open('config.json', 'r') as config_file:
    config = json.loads(config_file.read())

# Dataset location is taken from the 'mac' entry of the 'DATASET_PATH' section.
DATASET_PATH = config['DATASET_PATH']['mac']

## 1. Create Dataset for Training

In [61]:
# Build the keystroke dataset from the configured path and report how many
# samples it exposes through len().
dataset = AudioKeystrokeDataset(DATASET_PATH)
num_samples = len(dataset)
print(f"Dataset contains {num_samples} keystroke samples.")

Processing Audio Files: 100%|██████████| 61/61 [00:31<00:00,  1.93it/s]

Label mapping: {'(.).wav': 0, '0.wav': 1, '1.wav': 2, '2.wav': 3, '3.wav': 4, '4.wav': 5, '5.wav': 6, '6.wav': 7, '7.wav': 8, '8.wav': 9, '9.wav': 10, 'Lalt.wav': 11, 'Lcmd.wav': 12, 'Lctrl.wav': 13, 'Lshift.wav': 14, 'Ralt.wav': 15, 'Rshift.wav': 16, 'a.wav': 17, "apostrophe(').wav": 18, 'b.wav': 19, 'backslash.wav': 20, 'bracketclose(]).wav': 21, 'bracketopen([).wav': 22, 'c.wav': 23, 'caps.wav': 24, 'd.wav': 25, 'delete.wav': 26, 'down.wav': 27, 'e.wav': 28, 'enter.wav': 29, 'equal(=).wav': 30, 'esc.wav': 31, 'f.wav': 32, 'fn.wav': 33, 'g.wav': 34, 'h.wav': 35, 'i.wav': 36, 'j.wav': 37, 'k.wav': 38, 'l.wav': 39, 'left.wav': 40, 'm.wav': 41, 'n.wav': 42, 'o.wav': 43, 'p.wav': 44, 'q.wav': 45, 'r.wav': 46, 'right.wav': 47, 's.wav': 48, 'slash.wav': 49, 'space.wav': 50, 'start.wav': 51, 't.wav': 52, 'tab.wav': 53, 'u.wav': 54, 'up.wav': 55, 'v.wav': 56, 'w.wav': 57, 'x.wav': 58, 'y.wav': 59, 'z.wav': 60}
Dataset contains 2132 keystroke samples.





In [62]:
# 80/20 train/validation split.
# The split is drawn from a dedicated, explicitly seeded generator so that the
# same samples land in the same subset on every run (for the same dataset),
# which keeps validation accuracy comparable between runs.
SPLIT_SEED = 42

dataset_size = len(dataset)
train_size = int(0.8 * dataset_size)
# The validation set takes the remainder so the two sizes always sum to dataset_size.
val_size = dataset_size - train_size

split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

print(f"Training dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

Training dataset size: 1705
Validation dataset size: 427


In [71]:
# Batch both splits in groups of 32; only the training loader shuffles.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)

# CoAtNet classifier with 61 output classes, moved to the chosen device.
model = CoAtNet(num_classes=61).to(device)

# Cross-entropy loss, optimised with Adam at a learning rate of 5e-4.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=5e-4)

In [72]:
# Training configuration.
# NUM_EPOCHS drives both the loop and the progress message, so the reported
# total always matches the number of epochs actually run.
NUM_EPOCHS = 1100
VALIDATE_EVERY = 5  # run a validation pass after every 5th epoch

# Training loop
# NOTE(review): the weights are only written to disk after the final epoch; if
# the run is interrupted before that, nothing is saved.
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        # unsqueeze(1) adds a singleton dimension at index 1 (presumably the
        # channel axis the model expects; confirm against CoAtNet).
        inputs = inputs.unsqueeze(1).float().to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses over this epoch.
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{NUM_EPOCHS}], Loss: {avg_loss:.4f}")

    # Validation every few epochs
    if (epoch + 1) % VALIDATE_EVERY == 0:
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.unsqueeze(1).float().to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                # Predicted class = index of the largest score along dim 1.
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        print(f"Validation Accuracy: {correct / total:.4f}")


# Persist the trained weights (state_dict only) once all epochs have finished.
torch.save(model.state_dict(), 'model.pth')

Epoch [1/10], Loss: 4.2954
Epoch [2/10], Loss: 4.1424
Epoch [3/10], Loss: 4.1304
Epoch [4/10], Loss: 4.1230
Epoch [5/10], Loss: 4.1215
Validation Accuracy: 0.0094
Epoch [6/10], Loss: 4.1187
Epoch [7/10], Loss: 4.1189
Epoch [8/10], Loss: 4.1153
Epoch [9/10], Loss: 4.1152
Epoch [10/10], Loss: 4.1152
Validation Accuracy: 0.0094
Epoch [11/10], Loss: 4.1121
Epoch [12/10], Loss: 4.1068
Epoch [13/10], Loss: 4.0994
Epoch [14/10], Loss: 4.0918
Epoch [15/10], Loss: 4.0809
Validation Accuracy: 0.0211
Epoch [16/10], Loss: 4.0611
Epoch [17/10], Loss: 4.0496
Epoch [18/10], Loss: 4.0182
Epoch [19/10], Loss: 3.9999
Epoch [20/10], Loss: 3.9453
Validation Accuracy: 0.0328
Epoch [21/10], Loss: 3.8980
Epoch [22/10], Loss: 3.8839
Epoch [23/10], Loss: 3.8473
Epoch [24/10], Loss: 3.7900
Epoch [25/10], Loss: 3.7283
Validation Accuracy: 0.0492
Epoch [26/10], Loss: 3.7060
Epoch [27/10], Loss: 3.6727
Epoch [28/10], Loss: 3.6532
Epoch [29/10], Loss: 3.6130
Epoch [30/10], Loss: 3.6029
Validation Accuracy: 0.0562
E

KeyboardInterrupt: 