In [1]:
%env CUDA_VISIBLE_DEVICES=1,2

env: CUDA_VISIBLE_DEVICES=1,2


In [5]:
import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
def extract_asd(audio, sr, n_fft=512, hop_length=128):
    """Reduce an audio signal to one mean STFT magnitude value per frame.

    Args:
        audio: Signal handed directly to ``librosa.stft``.
        sr: Sample rate. Part of the signature but not used by this function.
        n_fft: FFT size forwarded to ``librosa.stft``.
        hop_length: Hop between successive frames, forwarded to
            ``librosa.stft``.

    Returns:
        The absolute value of the STFT averaged over axis 0. For 1-D audio
        that axis holds the frequency bins, so the result has one entry per
        STFT frame.
    """
    spectrogram = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length)
    return np.abs(spectrogram).mean(axis=0)

In [6]:
class KeystrokeCountCNN(nn.Module):
    """Small 1-D CNN that scores how many keystrokes a sequence contains.

    The network produces ``max_keystrokes + 1`` logits per input row, i.e.
    one class for each count from 0 up to ``max_keystrokes``.
    """

    def __init__(self, max_keystrokes=3):
        """Build the layers.

        Args:
            max_keystrokes: Largest count the classifier can output; the
                final layer has ``max_keystrokes + 1`` units.
        """
        super().__init__()
        n_classes = max_keystrokes + 1
        # Two same-length convolutions (kernel 3, padding 1) widen the single
        # input channel to 16 feature channels.
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        # Adaptive pooling fixes the time axis at 32 steps so the dense head
        # has a constant input width regardless of sequence length.
        self.pool = nn.AdaptiveAvgPool1d(output_size=32)
        self.fc1 = nn.Linear(in_features=32 * 16, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=n_classes)

    def forward(self, x):
        """Return count logits for ``x``.

        Args:
            x: Tensor of shape ``(batch, length)``; a channel axis is
                inserted at position 1 before the convolutions.

        Returns:
            Tensor of shape ``(batch, max_keystrokes + 1)`` with raw logits.
        """
        feats = x.unsqueeze(dim=1)
        for conv in (self.conv1, self.conv2):
            feats = F.relu(conv(feats))
        pooled = self.pool(feats)
        flat = pooled.view(pooled.shape[0], -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [9]:
from scipy.signal import find_peaks, peak_prominences

def estimate_keystroke_onsets(audio, sr, hop_length=128, segment_length=14400, max_peaks=3):
    """Pick the most prominent envelope peaks and return them as sample offsets.

    Args:
        audio: Signal forwarded to ``extract_asd``.
        sr: Sample rate, forwarded to ``extract_asd``.
        hop_length: Hop used for the envelope; also the factor that converts
            a peak's frame index into a sample offset.
        segment_length: Part of the signature but not used by this function.
        max_peaks: Number of top-ranked peaks to keep.

    Returns:
        A list of sample offsets (``frame_index * hop_length``), ordered from
        the most prominent peak to the least prominent one, not by time.
    """
    envelope = extract_asd(audio, sr, hop_length=hop_length)

    # Candidate peaks must be at least 5 frames apart.
    peak_frames, _props = find_peaks(envelope, distance=5)
    peak_strengths = peak_prominences(envelope, peak_frames)[0]

    # Sort (prominence, frame) pairs in descending order; ties on prominence
    # fall back to the larger frame index first.
    scored = list(zip(peak_strengths, peak_frames))
    scored.sort(reverse=True)

    onset_samples = []
    for _, frame in scored[:max_peaks]:
        onset_samples.append(frame * hop_length)
    return onset_samples

def segment_from_onsets(audio, sr, onsets, segment_length=14400):
    """Cut one fixed-length window of ``audio`` around each onset.

    Each window is nominally centred on its onset. When that window would
    run past either edge of the recording it is shifted inward so it still
    fits, rather than being discarded. Every onset therefore yields exactly
    one segment and the result stays aligned with ``onsets``.

    Args:
        audio: Sequence supporting ``len`` and slicing (e.g. a 1-D array).
        sr: Sample rate. Part of the signature but not used by this function.
        onsets: Iterable of integer sample offsets.
        segment_length: Number of samples in every returned segment.

    Returns:
        A list with one ``segment_length``-sample slice per onset, in the
        same order as ``onsets``. The list is empty when ``audio`` is shorter
        than ``segment_length``, because no full window exists.
    """
    last_start = len(audio) - segment_length
    if last_start < 0:
        # The recording cannot hold even one full-length window.
        return []

    half = segment_length // 2
    segments = []
    for onset in onsets:
        # Clamp the window start on BOTH sides: previously only the left edge
        # was clamped, so onsets near the end of the audio were dropped.
        start = min(max(onset - half, 0), last_start)
        segments.append(audio[start:start + segment_length])
    return segments

In [10]:
def chameleon_pipeline(audio, sr, count_model, device='cpu'):
    """Run count prediction, onset estimation and segmentation on one clip.

    Args:
        audio: NumPy array. Inputs with more than one dimension are averaged
            over axis 0 before any further processing.
        sr: Sample rate, forwarded to the helper functions.
        count_model: Callable that takes a ``(1, n_frames)`` float32 tensor
            and returns logits whose argmax along dim 1 is the keystroke count.
        device: Device the model input tensor is moved to.

    Returns:
        ``(segments, onsets)`` as produced by ``segment_from_onsets`` and
        ``estimate_keystroke_onsets`` respectively.
    """
    # Collapse multi-dimensional input to a single channel.
    mono = np.mean(audio, axis=0) if audio.ndim > 1 else audio

    # === Step 1: ASD envelope, shaped as a batch of one ===
    envelope = extract_asd(mono, sr)
    model_input = torch.tensor(envelope, dtype=torch.float32).unsqueeze(0).to(device)

    # === Step 2: predicted keystroke count (inference only, no gradients) ===
    with torch.no_grad():
        pred_count = torch.argmax(count_model(model_input), dim=1).item()

    # === Step 3: keep as many onsets as the model predicted ===
    onsets = estimate_keystroke_onsets(mono, sr, max_peaks=pred_count)

    # === Step 4: cut a segment around each onset ===
    return segment_from_onsets(mono, sr, onsets), onsets