#### Imports

In [None]:
from torch.utils.data import random_split, DataLoader
import torch
import torch.optim as optim
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

from AudioKeystrokeDataset import AudioKeystrokeDataset
from CoatNet import CoAtNet
from Trainer import Trainer

#### Utils

In [None]:
import json

# Load run settings from config.json (read relative to the working directory).
with open('config.json', 'r') as f:
    config = json.load(f)

# Dataset root taken from the 'all' entry; passed to AudioKeystrokeDataset below.
DATASET_PATH = config['DATASET_PATH']['all']
# Model identifier handed to the transformers text-generation pipeline in section 4.
model_id = config['MODELS']['llama3-8B']

In [None]:
# Prefer CUDA when PyTorch can see a GPU; otherwise everything runs on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## 1. Create Dataset for Training

In [None]:
# Build one dataset over DATASET_PATH; it is split into train/val/test below.
# NOTE(review): the exact effect of full_dataset=True is defined in
# AudioKeystrokeDataset, which is not visible here — confirm.
dataset = AudioKeystrokeDataset(DATASET_PATH, full_dataset=True)
print(f"Dataset contains {len(dataset)} keystroke samples.")

In [None]:
# Persist the label and device index maps so the numbering can be reused later
# (the sentence-inference cell reloads data/label2idx.json).
with open("data/label2idx.json", "w") as f:
    f.write(json.dumps(dataset.label2idx))
with open("data/device2idx.json", "w") as f:
    f.write(json.dumps(dataset.device2idx))

In [None]:
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import Subset
import numpy as np

# Class index of every sample, in dataset order. It is the stratification
# target, so each split keeps roughly the same class proportions.
all_labels = np.array([dataset.label2idx[label] for _, label, _ in dataset.samples])

# 80 % train / 20 % held-out pool. With n_splits=1 the generator yields exactly
# one (train, held-out) index pair. Only y drives the split, so a zeros array
# is enough as the X placeholder.
sss1 = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, valtest_idx = next(sss1.split(np.zeros(len(all_labels)), all_labels))
train_dataset = Subset(dataset, train_idx)
valtest_dataset = Subset(dataset, valtest_idx)
valtest_labels = all_labels[valtest_idx]

# Halve the held-out pool into validation and test (about 10 % / 10 % overall).
# val_idx / test_idx index into valtest_dataset, so these Subsets are nested.
# NOTE(review): this split raises if a class has fewer than 2 held-out samples.
sss2 = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
val_idx, test_idx = next(sss2.split(np.zeros(len(valtest_labels)), valtest_labels))
val_dataset = Subset(valtest_dataset, val_idx)
test_dataset = Subset(valtest_dataset, test_idx)

print(f"Training dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Testing dataset size: {len(test_dataset)}")

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

## 2. Create Model

In [None]:
# One output per key label and one device slot per recording device, applied
# to single-channel spectrogram input. nn.Module.to() returns the module
# itself, so construction and placement chain cleanly.
model = CoAtNet(
    num_classes=len(dataset.label2idx),
    num_devices=len(dataset.device2idx),
    in_channels=1,
).to(device)

In [None]:
# Classification loss plus Adam with light L2 weight decay.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), weight_decay=1e-5, lr=1e-3)

# mode='max': the LR is multiplied by 0.1 once the value passed to
# scheduler.step() stops increasing for more than `patience` steps.
# Presumably Trainer passes a higher-is-better metric such as validation
# accuracy — TODO confirm in Trainer.
# NOTE(review): `verbose` is deprecated for LR schedulers in recent PyTorch.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.1,
    patience=5,
    verbose=True,
)

# NOTE(review): if early_stopping_patience counts epochs, 1000 exceeds the
# 200 epochs trained below, so early stopping never triggers in that run.
trainer = Trainer(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    scheduler,
    early_stopping_patience=1000,
)

In [None]:
# Train for up to 200 epochs. The path names suggest the latest weights go to
# model_all.pth and the best checkpoint to best_model_all.pth — TODO confirm
# against Trainer.train.
history = trainer.train(num_epochs=200, save_path='models/model_all.pth', best_save_path='models/best_model_all.pth')

## 3. Evaluate

In [None]:
# Top-1 accuracy on the held-out test split. Reuses the test_loader built next
# to the train/val loaders (same test_dataset, batch_size=32, no shuffling)
# instead of constructing an identical one.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        # The model was built with in_channels=1; 3-D batches get that singleton
        # channel axis inserted at position 1.
        if data.dim() == 3:
            data = data.unsqueeze(1)
        # NOTE(review): the sentence-inference cell calls model(x, device_id),
        # but no device id is passed here — confirm CoAtNet treats it as optional.
        outputs = model(data)
        # softmax is monotonic, so the argmax of the logits is the prediction.
        predicted = outputs.argmax(dim=1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

if total == 0:
    # Avoid ZeroDivisionError on an empty test split; report it explicitly.
    test_accuracy = float("nan")
    print("Test Accuracy: n/a (test set is empty)")
else:
    test_accuracy = correct / total
    print(f"Test Accuracy: {test_accuracy:.4f}")

In [None]:
import torch
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np

# Confusion matrix over the test split: rows are true class indices, columns
# are predicted class indices.
model.eval()

all_targets = []
all_predictions = []

with torch.no_grad():
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)

        # The model was built with in_channels=1; add the channel axis to 3-D batches.
        if len(data.shape) == 3:
            data = data.unsqueeze(1)

        outputs = model(data)
        _, predicted = torch.max(outputs, 1)

        all_targets.extend(targets.cpu().numpy())
        all_predictions.extend(predicted.cpu().numpy())

all_targets = np.array(all_targets)
all_predictions = np.array(all_predictions)

# Index the matrix by every class seen as a target OR a prediction.
# confusion_matrix drops any sample whose true or predicted label is missing
# from `labels`. Using only np.unique(all_targets) therefore silently discarded
# predictions of classes absent from the test targets.
class_labels = np.union1d(all_targets, all_predictions)

cm = confusion_matrix(all_targets, all_predictions, labels=class_labels)

plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=False, fmt="d", cmap="Blues", xticklabels=class_labels, yticklabels=class_labels)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix")
# Hide the per-class tick labels (too many classes to read individually).
plt.xticks([])
plt.yticks([])
plt.show()

In [None]:
# Printable character -> sequence of predicted key labels that types it.
# keystrokes_to_text inverts this table. Multi-key entries are 'Lshift'
# followed by the base key. Plain letters and digits are not listed; their
# single-character label is emitted as-is by the decoder.
# NOTE(review): characters such as '"', '<', '>', '|', '~' and '`' have no
# entry, and only 'Lshift' (not a right shift) is used as a modifier here.
CHAR_MAPPINGS = {
    '.': ['fullstop'],
    ',': ['comma(,)'],
    "'": ["apostrophe(')"],
    '/': ['slash'],
    '\\': ['backslash'],
    '?': ['Lshift', 'slash'],
    ';': ['semicolon(;)'],
    ':': ['Lshift', 'semicolon(;)'],
    '-' : ['dash(-)'],
    '_': ['Lshift', 'dash(-)'],
    '=': ['equal(=)'],
    '+': ['Lshift', 'equal(=)'],
    ')': ['Lshift', '0'],
    '!': ['Lshift', '1'],
    '@': ['Lshift', '2'],
    '#': ['Lshift', '3'],
    '$': ['Lshift', '4'],
    '%': ['Lshift', '5'],
    '^': ['Lshift', '6'],
    '&': ['Lshift', '7'],
    '*': ['Lshift', '8'],
    '(': ['Lshift', '9'],
    '{': ['Lshift', 'bracketopen([)'],
    '}': ['Lshift', 'bracketclose(])'],
    '[': ['bracketopen([)'],
    ']': ['bracketclose(])'],
    ' ': ['space'],
}

In [None]:
def keystrokes_to_text(predicted_keys, char_mappings=None):
    """Decode a sequence of predicted key labels into text.

    Args:
        predicted_keys: Sequence of key-label strings from the classifier
            (e.g. 'a', 'space', 'Lshift', 'caps'). Any iterable is accepted;
            it is copied to a list first, so tuples decode correctly.
        char_mappings: Optional ``{char: [key, ...]}`` table. Defaults to the
            module-level ``CHAR_MAPPINGS``.

    Returns:
        The decoded string. At each position the rules are tried in order:

        * 'caps' toggles caps lock and emits nothing.
        * The first key sequence in the table (in table order) that matches
          here emits its character.
        * A shift key ('Lshift' / 'Rshift') followed by a single alphabetic
          label emits that letter, upper-case unless caps lock is on
          (shift inverts caps lock).
        * 'space' emits ' ' (only reached if the table has no space entry).
        * Any other single-character label is emitted, upper-cased while
          caps lock is on.
        * Anything else is emitted as '<label>' so it stays visible.
    """
    if char_mappings is None:
        char_mappings = CHAR_MAPPINGS

    # Invert char -> keys into keys -> char, then bucket by first key so each
    # position only tests sequences that can match. Buckets keep table order,
    # so the first matching sequence still wins. Empty sequences are skipped:
    # they would "match" without advancing and loop forever.
    candidates = {}
    for seq, char in {tuple(v): k for k, v in char_mappings.items()}.items():
        if seq:
            candidates.setdefault(seq[0], []).append((seq, char))

    keys = list(predicted_keys)
    n = len(keys)
    shift_keys = {'Lshift', 'Rshift'}
    result = []
    caps_mode = False
    i = 0
    while i < n:
        key = keys[i]
        if key == 'caps':
            caps_mode = not caps_mode
            i += 1
            continue

        match = next(
            ((seq, char) for seq, char in candidates.get(key, ())
             if tuple(keys[i:i + len(seq)]) == seq),
            None,
        )
        if match is not None:
            seq, char = match
            result.append(char)
            i += len(seq)
            continue

        nxt = keys[i + 1] if i + 1 < n else None
        if key in shift_keys and nxt is not None and len(nxt) == 1 and nxt.isalpha():
            # Shift + letter: the opposite case of what caps lock alone gives.
            result.append(nxt.lower() if caps_mode else nxt.upper())
            i += 2
            continue

        if key == 'space':
            result.append(' ')
        elif len(key) == 1:
            result.append(key.upper() if caps_mode else key)
        else:
            # Unrecognized/special key: keep it visible in the output.
            result.append(f"<{key}>")
        i += 1
    return ''.join(result)

In [None]:
import os
import json
import librosa
import torch
from tqdm import tqdm

from save_individual_keystrokes import isolate_keystrokes, generate_mel_spectrogram

# === Paths ===
SENT_DIR = "data/sentences"
LABEL_PATH = "data/label2idx.json"
DEVICE_MAP_PATH = os.path.join(SENT_DIR, "sentence2device.json")

# === Constants ===
SR = 44100
N_FFT = 1024
HOP_LENGTH_STFT = 256
SEGMENT_LENGTH = 14400
# NOTE(review): N_MELS, SPEC_HOP_LENGTH and THRESHOLD are not passed to any call
# in this cell, so generate_mel_spectrogram / isolate_keystrokes use their own
# defaults — confirm those match the training preprocessing.
N_MELS = 64
SPEC_HOP_LENGTH = 500
THRESHOLD = 100  # Use if isolate_keystrokes needs it

# === Load Label Mappings ===
with open(LABEL_PATH, 'r') as f:
    label2idx = json.load(f)
# Invert label -> index so predicted class indices map back to key labels.
idx2label = {idx: label for label, idx in label2idx.items()}

# === Load Sentence to Device Map ===
with open(DEVICE_MAP_PATH, 'r') as f:
    sentence2device = json.load(f)

# === Model Eval Mode ===
model.eval()

# === Collect Predictions ===
predictions = {}

# sorted(): os.listdir order is arbitrary, so this keeps output reproducible.
for fname in tqdm(sorted(os.listdir(SENT_DIR))):
    if not fname.endswith(".wav"):
        continue

    # splitext strips only the extension, so sentence IDs containing dots
    # are kept whole (split(".")[0] truncated them at the first dot).
    sid = os.path.splitext(fname)[0]
    wav_path = os.path.join(SENT_DIR, fname)

    if sid not in sentence2device:
        print(f"Skipping {sid}: device not found")
        continue

    device_name = sentence2device[sid]
    # device2idx only covers devices present in the training dataset; skip
    # unknown ones the same way as unmapped sentences instead of a KeyError.
    if device_name not in dataset.device2idx:
        print(f"Skipping {sid}: unknown device '{device_name}'")
        continue
    device_id = torch.tensor([dataset.device2idx[device_name]]).to(device)

    audio, _ = librosa.load(wav_path, sr=SR)
    threshold, segments, segment_starts = isolate_keystrokes(
        audio,
        sr=SR,
        segment_length=SEGMENT_LENGTH,
        n_fft=N_FFT,
        hop_length=HOP_LENGTH_STFT
    )

    # Explicit emptiness check: `not segments` raises for multi-element
    # NumPy arrays, while len() works for lists and arrays alike.
    if segments is None or len(segments) == 0:
        print(f"No keystrokes detected in {fname}")
        predictions[sid] = ""
        continue

    pred_chars = []
    with torch.no_grad():
        for seg in segments:
            mel = generate_mel_spectrogram(seg, sr=SR)
            # Per-segment min-max scaling to [0, 1).
            # NOTE(review): must match the dataset's training normalisation.
            mel = (mel - mel.min()) / (mel.max() - mel.min() + 1e-6)
            # .float(): the model presumably runs in float32; a float64
            # spectrogram would otherwise hit a dtype mismatch. No-op for float32.
            x = torch.from_numpy(mel).float().unsqueeze(0).unsqueeze(0).to(device)
            out = model(x, device_id)
            # argmax of the logits equals argmax of their softmax.
            idx = out.argmax(dim=1).item()
            pred_chars.append(idx2label[idx])

    text = keystrokes_to_text(pred_chars)
    predictions[sid] = text

# === Final Output ===
print("\nPredictions:")
for sid, text in predictions.items():
    print(f"{sid}: {text}")

## 4. Language Model

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers
import torch
import torch.nn.functional as F

In [None]:
# Few-shot instruction prompt for the LLM correction step. Each example pairs a
# noisy keystroke-decoded sentence ("Typo") with its intended text ("Correct").
# The generation cell appends `Typo: "<sentence>"` lines after this text.
# NOTE(review): the last example outputs "I'm" while the rest stay lowercase —
# confirm this casing is intended, since the model will imitate it.
PROMPT = '''You are a helpful assistant that corrects typos in sentences based on the likely intent of the user. Each typo was generated by a model that tries to guess typed words based on keystroke sounds. 

The correction should be the closest valid sentence with proper spelling and grammar, assuming the model made as few mistakes as possible.

Examples are given below.

Typo: "the wuick bronw fix"
Correct: "the quick brown fox"

Typo: "how arw ypu"
Correct: "how are you"

Typo: "in tghe beeginning"
Correct: "in the beginning"

Typo: "whi is thsi hapenung"
Correct: "why is this happening"

Typo: "i kniw waht im doinf"
Correct: "i know what I'm doing"

Typo: "plrase snd help"
Correct: "please send help"

Correct each of the following typo sentences. Give the corrected sentence only, without any additional text or explanation.
'''

In [None]:
from transformers import pipeline

# Instructions first, then one `Typo: "..."` line per decoded sentence.
# Only the first decoded sentence is sent for now ([:1]).
prompt = PROMPT
for sentence in list(predictions.values())[:1]:
    prompt += f'\nTypo: "{sentence}"'

# Load the configured causal LM in bfloat16 and let accelerate place it on the
# available devices.
generator = pipeline(
    task="text-generation",
    model=model_id,
    device_map="auto",
    model_kwargs=dict(torch_dtype=torch.bfloat16),
)

outputs = generator(prompt, max_new_tokens=256)
generated_text = outputs[0]['generated_text']

# With a plain string prompt and default settings, this text is the prompt
# followed by the model's continuation.
print(generated_text)

In [None]:
# Extract the model's answer from each pipeline output.
# The pipeline above is called with a plain string prompt, so 'generated_text'
# is a str. It is a list of {'role', 'content'} messages only when the input
# was a chat-message list. The old code always assumed the chat form and
# raised TypeError when indexing a str with 'role'.
for output in outputs:
    generated = output['generated_text']
    if isinstance(generated, str):
        # A string result may echo the prompt first; keep only the continuation.
        if generated.startswith(prompt):
            generated = generated[len(prompt):]
        assistant_response = generated.strip()
    else:
        assistant_response = next(
            (msg['content'] for msg in reversed(generated) if msg['role'] == 'assistant'),
            None,
        )
    print(assistant_response)
    print("===")