# Audrey - Building a Speech Recognition System

In 1952, Bell Labs created **AUDREY** (Automatic Digit Recognizer), one of the first speech recognition systems. It could recognize the spoken digits 0-9 with ~90% accuracy, but only for a single speaker!

In this notebook, we'll build our own version of Audrey using modern deep learning. Along the way, you'll learn:

- **Data Collection**: Recording your own speech dataset
- **Data Augmentation**: Creating variations to make your model more robust  
- **Neural Networks**: Training both simple (MLP) and convolutional (CNN) models
- **Inference**: Using your trained model for real-time digit recognition

**Key Concept**: Machine learning models learn patterns from data. The more varied and representative your training data, the better your model will generalize to new inputs.

## Part 1: Recording Your Dataset

Every machine learning project starts with **data**. We need examples of what we want our model to recognize.

For speech recognition, we need audio recordings of each digit (0-9). You'll record yourself saying each digit, creating 10 audio files that will form the basis of our training data.

**Why record yourself?** Like the original Audrey, our model will be "speaker-dependent" - trained on your voice. This makes the problem easier to solve with limited data.

In [None]:
# Let's create our own speech dataset, starting with the digits 0-9

# Recording audio with sounddevice and soundfile
# 
# https://python-soundfile.readthedocs.io/
# https://python-sounddevice.readthedocs.io


import os
import sounddevice as sd
import soundfile as sf
import numpy as np

def record_audio(filename: str, duration: float):
    """Record mono audio from the default microphone and save it as a 16-bit WAV file."""

    # config
    samplerate = 44100
    channels = 1

    # make sure the output folder (e.g. 'unprocessed/') exists before saving
    os.makedirs(os.path.dirname(filename) or '.', exist_ok=True)

    print(f"Recording for {duration} seconds at {samplerate} Hz...")

    # record audio from the microphone into a numpy array with sounddevice
    recording = sd.rec(
        int(duration * samplerate),
        samplerate=samplerate,
        channels=channels,
        dtype='float32'
    )
    sd.wait()  # block until the recording is finished

    print(f"Recording finished. Saving to {filename}...")

    # save the file with soundfile as 16-bit PCM
    sf.write(
        filename,
        recording,
        samplerate,
        subtype='PCM_16'
    )

    print(f"File '{filename}' saved successfully.")

In [None]:
# record yourself saying the digits 0-9. Say each number clearly and distinctly. This will matter a lot later on!

record_audio(
    filename='unprocessed/0.wav', # change to 1.wav, 2.wav, ... and re-run this cell for each digit
    duration=1)

## Part 2: Data Augmentation

We only have 10 recordings, but neural networks typically need thousands of examples to learn well. 

**Data augmentation** solves this by creating variations of our original data:
- **Noise**: Adding random background noise (simulates different environments)
- **Time Stretch**: Making audio faster/slower (simulates speaking pace variation)
- **Pitch Shift**: Raising/lowering pitch (simulates voice variation)
- **Time Shift**: Moving audio left/right (simulates different recording starts)

From 10 original recordings, we'll create **10,000 augmented samples** (1,000 variations per digit)!
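Two of these augmentations are simple enough to sketch with NumPy alone. One detail worth knowing: the `shift` function in the next cell uses `np.roll`, which wraps samples that fall off one end back around to the other. The sketch below contrasts that with a zero-filled shift that inserts silence instead. `add_noise` and `shift_with_silence` are hypothetical helpers (not used elsewhere in this notebook), and `add_noise` scales by the peak absolute amplitude rather than the notebook's `np.amax`.

```python
import numpy as np

def add_noise(audio: np.ndarray, noise_amt: float, rng: np.random.Generator) -> np.ndarray:
    """Add Gaussian noise scaled relative to the loudest sample."""
    noise_amp = noise_amt * np.max(np.abs(audio))
    return audio + noise_amp * rng.standard_normal(audio.shape[0])

def shift_with_silence(audio: np.ndarray, n: int) -> np.ndarray:
    """Shift right by n samples (left if n < 0), filling the gap with zeros instead of wrapping."""
    shifted = np.zeros_like(audio)
    if n > 0:
        shifted[n:] = audio[:-n]
    elif n < 0:
        shifted[:n] = audio[-n:]
    else:
        shifted[:] = audio
    return shifted

rng = np.random.default_rng(0)
clip = np.array([1.0, 2.0, 3.0, 4.0, 5.0])

print(np.roll(clip, 2))              # [4. 5. 1. 2. 3.]  the tail wrapped to the front
print(shift_with_silence(clip, 2))   # [0. 0. 1. 2. 3.]
print(shift_with_silence(clip, -2))  # [3. 4. 5. 0. 0.]
noisy = add_noise(clip, 0.01, rng)
print(noisy.shape)                   # (5,): adding noise never changes the length
```

Whether wrap-around matters depends on the clip: with silence at both ends it is mostly harmless, but a shift that wraps part of the spoken digit to the other end creates an artifact that never occurs in real recordings.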

In [None]:
# Data augmentation functions - each creates a different variation of the audio

import numpy as np
import librosa

def noise(data, noise_amt=0.035):
    """Add random background noise to simulate different recording environments"""
    noise_amp = noise_amt * np.random.uniform() * np.amax(data)
    data = data + noise_amp * np.random.normal(size=data.shape[0])
    return data

def stretch(data, rate=0.8):
    """Speed up or slow down the audio (rate < 1 = slower, rate > 1 = faster)"""
    return librosa.effects.time_stretch(data, rate=rate)

def shift(data):
    """Shift audio left or right by up to 5,000 samples (simulates different recording starts).
    Note: np.roll wraps samples that fall off one end around to the other end."""
    shift_range = int(np.random.uniform(low=-5, high=5) * 1000)
    return np.roll(data, shift_range)

def pitch(data, sampling_rate, n_steps=2):
    """Shift pitch up or down (n_steps = semitones, + = higher, - = lower)"""
    return librosa.effects.pitch_shift(data, sr=sampling_rate, n_steps=n_steps)

In [None]:
# take our recorded digits, and augment them to create a larger dataset

import os
import glob
import librosa
import numpy as np
import soundfile as sf
from tqdm.notebook import tqdm

# get all .wav files in the 'unprocessed' directory
files = glob.glob('unprocessed/*.wav')
print(files)

for file in tqdm(files):
    # get the digit from the file name, e.g. 'unprocessed/3.wav' -> '3'
    digit = os.path.splitext(os.path.basename(file))[0]

    # create the output directory if it doesn't exist
    os.makedirs(f'processed/{digit}', exist_ok=True)
    # load the recording with soundfile (a mono file loads as a 1-D array)
    audio, sample_rate = sf.read(file)

    for i in tqdm(range(1000), leave=False):
        # apply a random combination of all four augmentations
        processed_audio = noise(audio, np.random.uniform(0.001, 0.01))
        processed_audio = stretch(processed_audio, rate=np.random.uniform(0.8, 1.2))
        processed_audio = shift(processed_audio)
        # randint's upper bound is exclusive, so this draws -3..3 semitones
        processed_audio = pitch(processed_audio, sample_rate, n_steps=np.random.randint(-3, 4))

        sf.write(f'processed/{digit}/{digit}_{i}.wav', processed_audio, sample_rate)

In [None]:
# get all of the files in processed/ with glob
import glob

files = glob.glob('processed/*/*')

print(len(files))
print(files[:5])


## Part 3: Data Preprocessing

Before training, we need to prepare our data:

1. **Consistent Length**: Neural networks expect fixed-size inputs. We'll pad shorter audio files to match the longest one.

2. **Labels**: Each file needs a label (0-9) so the model knows what it should predict.

3. **Verification**: Check that all files are the same length after processing.

**Why padding?** Think of it like standardizing paper sizes before putting them in a binder - everything needs to fit the same format.
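The centre padding used in the second pass below can be sketched on toy arrays. This assumes 1-D mono clips, as produced by `librosa.load`; `center_pad` is a hypothetical name for illustration.

```python
import numpy as np

def center_pad(audio: np.ndarray, target_length: int) -> np.ndarray:
    """Zero-pad a 1-D clip on both sides so it is exactly target_length samples long.
    An odd leftover sample goes on the right, as in the notebook's second pass."""
    pad_size = target_length - len(audio)
    if pad_size < 0:
        raise ValueError("clip is longer than target_length")
    left_pad = pad_size // 2
    right_pad = pad_size - left_pad
    return np.pad(audio, (left_pad, right_pad), mode='constant')

clips = [np.ones(3), np.ones(6), np.ones(4)]
longest = max(len(c) for c in clips)          # 6
padded = [center_pad(c, longest) for c in clips]
print([len(p) for p in padded])               # [6, 6, 6]
print(padded[0])                              # [0. 1. 1. 1. 0. 0.]
```

Padding in the centre keeps the spoken digit roughly in the middle of every clip; padding only on the right would also produce equal lengths but would push short (sped-up) clips to the left edge.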

In [None]:
from IPython.display import Audio, display
import glob
import os
import librosa
import librosa.display  # provides waveshow and specshow
import matplotlib.pyplot as plt
import numpy as np

# Get all augmented files
files = glob.glob('processed/*/*')
print(f"Found {len(files)} files")

# Display the waveform, spectrogram and an audio player for the first few files
for file in files[:5]:
    # Load the audio file at its native sample rate
    y, sr = librosa.load(file, sr=None)

    # Display the waveform
    plt.figure(figsize=(10, 4))
    librosa.display.waveshow(y, sr=sr)
    plt.title(f'Waveform for {os.path.basename(file)}')
    plt.show()

    # Display the spectrogram (STFT magnitude in dB)
    D = librosa.amplitude_to_db(np.abs(librosa.stft(y)), ref=np.max)

    plt.figure(figsize=(10, 4))
    librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='log')
    plt.colorbar(format='%+2.0f dB')
    plt.title(f'Spectrogram for {os.path.basename(file)}')
    plt.show()

    display(Audio(file))

In [None]:
# First pass: preprocess audio files so that they are all the same length

import glob
import librosa
import soundfile as sf
import numpy as np
from tqdm import tqdm

files = glob.glob('processed/*/*')
print(f"Total files: {len(files)}")

audio_data = []
labels = [] # here is where we create our labels
longest_audio_file_length = 0

# First pass: load data and find longest audio file
for f in tqdm(files):
    try:
        audio, sample_rate = librosa.load(f)
        if len(audio) == 0:
            print(f"Warning: Empty audio file: {f}")
            continue
        labels.append(int(f.split('/')[-2]))  # Adjust this based on your file structure
        longest_audio_file_length = max(longest_audio_file_length, len(audio))
    except Exception as e:
        print(f"Error processing file {f}: {str(e)}")

print(f"Longest audio size: {longest_audio_file_length}")

In [None]:
# Second pass: Pad audio files and resave them
for f in tqdm(files):
    try:
        audio, sample_rate = librosa.load(f)
        if len(audio) == 0:
            print(f"Warning: Empty audio file: {f}")
            continue
        current_size = len(audio)
        pad_size = longest_audio_file_length - current_size
        left_pad = pad_size // 2
        right_pad = pad_size - left_pad
        padded_audio = np.pad(audio, (left_pad, right_pad), mode='constant')
        sf.write(f, padded_audio, sample_rate)
    except Exception as e:
        print(f"Error processing file {f}: {str(e)}")

In [None]:
# Third pass: Verify that all files have the same size
file_sizes = []
for f in tqdm(files):
    try:
        audio, _ = librosa.load(f)
        file_sizes.append(len(audio))
    except Exception as e:
        print(f"Error processing file {f}: {str(e)}")

if len(set(file_sizes)) == 1:
    print(f"All files have the same size: {file_sizes[0]} samples")
else:
    print("Warning: Not all files have the same size")
    print(f"Unique file sizes: {set(file_sizes)}")
    print(f"Min size: {min(file_sizes)}, Max size: {max(file_sizes)}")

In [None]:
print(files[1000:1005])
print(labels[1000:1005])


## Part 4: Creating the Dataset and DataLoader

Now we convert our audio files into a format PyTorch can use for training.

### Key Concepts:

**Mel Spectrogram**: Instead of feeding raw audio waveforms to our model, we convert them to **mel spectrograms**: images of sound that show how much energy each frequency band has over time. The bands are spaced on the mel scale, which, like human hearing, resolves low frequencies finely and high frequencies coarsely.
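To make this concrete: `AudioDataset` below uses `transforms.MelSpectrogram()` with its defaults, which in torchaudio are `n_fft=400`, `hop_length = n_fft // 2 = 200`, `center=True`, `n_mels=128`, and the HTK mel formula `m = 2595 * log10(1 + f / 700)`. A small pure-Python sketch of that formula and of the spectrogram shape those defaults should produce; `hz_to_mel` and `mel_spectrogram_shape` are hypothetical helpers, and the frame count assumes `center=True` with an even `n_fft`.

```python
import math

def hz_to_mel(f_hz: float) -> float:
    """HTK mel scale (torchaudio's default mel_scale='htk')."""
    return 2595.0 * math.log10(1.0 + f_hz / 700.0)

def mel_spectrogram_shape(num_samples: int, n_mels: int = 128, n_fft: int = 400, hop_length=None):
    """Expected (channels, n_mels, frames) for a mono clip with center=True."""
    if hop_length is None:
        hop_length = n_fft // 2  # torchaudio default when win_length = n_fft
    frames = 1 + num_samples // hop_length
    return (1, n_mels, frames)

# 0-1000 Hz spans ~1000 mel, but 7000-8000 Hz spans far fewer mel
print(round(hz_to_mel(1000)))                 # 1000
print(hz_to_mel(8000) - hz_to_mel(7000))

# One second of 22,050 Hz audio with the torchaudio defaults
print(mel_spectrogram_shape(22050))           # (1, 128, 111)
```

So the `time_steps` the model sees is set by the padded clip length divided by the hop length, which is why every file must be padded to the same length first.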

**Dataset**: A PyTorch class that holds our data and knows how to load individual samples.

**DataLoader**: Handles batching (grouping samples together) and shuffling during training.

### Train/Validation/Test Split (Critical!)

- **Training set (70%)**: What the model learns from
- **Validation set (20%)**: Used to check for overfitting during training
- **Test set (10%)**: Held out completely - only used for final evaluation

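**Data Leakage**: All 10,000 augmented files come from just 10 original recordings, one per digit. That means *any* split of these files, random or not, puts variations of the same source recording into train, validation, AND test. Our held-out scores therefore measure how well the model recognizes new augmentations of *your* recordings, not new utterances; a truly independent test needs fresh recordings that never went through the augmentation pipeline.

**Our approach**: Split **proportionally within each digit** (a stratified split). For each digit's 1,000 files, the first 700 go to train, the next 200 to validation, and the last 100 to test, so every split has the same balance of digits.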
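The per-digit slicing can be checked on made-up file names without touching any audio. A pure-Python sketch; `stratified_split` is a hypothetical helper that mirrors the split cell's `int(0.7 * n)` / `int(0.9 * n)` cut points.

```python
from collections import Counter

def stratified_split(files_by_digit, train_cut=0.7, val_cut=0.9):
    """Split each digit's sorted file list into train/val/test lists of (path, digit)."""
    splits = {"train": [], "val": [], "test": []}
    for digit in sorted(files_by_digit):
        files = sorted(files_by_digit[digit])   # sorted so the split is reproducible
        n = len(files)
        train_end = int(train_cut * n)           # first 70% -> train
        val_end = int(val_cut * n)               # next 20% -> validation, last 10% -> test
        splits["train"] += [(path, digit) for path in files[:train_end]]
        splits["val"] += [(path, digit) for path in files[train_end:val_end]]
        splits["test"] += [(path, digit) for path in files[val_end:]]
    return splits

# Hypothetical file names mirroring processed/<digit>/<digit>_<i>.wav
fake = {d: [f"processed/{d}/{d}_{i}.wav" for i in range(1000)] for d in range(10)}
splits = stratified_split(fake)

for name, items in splits.items():
    print(name, len(items), Counter(digit for _, digit in items)[0])
# train 7000 700
# val 2000 200
# test 1000 100
```

Disjoint file lists with balanced classes are all this split guarantees: with one recording per digit, every split still holds augmentations of that same recording.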

In [None]:
import torch as t
from torchaudio import transforms
import torchaudio
import random
from torch.utils.data import Dataset, DataLoader

# Custom Dataset class - tells PyTorch how to load our audio data
class AudioDataset(Dataset):
    def __init__(self, file_paths, labels, transforms=transforms.MelSpectrogram()):
        self.file_paths = file_paths  # List of paths to audio files
        self.labels = labels          # List of corresponding labels (0-9)
        self.transforms = transforms  # MelSpectrogram converts audio to spectrogram

    def __len__(self):
        """How many samples in the dataset?"""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load and return a single sample (called by DataLoader)"""
        audio_path = self.file_paths[idx]
        waveform, _ = torchaudio.load(audio_path)

        # Ensure mono audio (single channel)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0).unsqueeze(0)

        # Convert waveform to mel spectrogram (or return the raw waveform if no transform)
        spec = self.transforms(waveform) if self.transforms else waveform
        return spec, self.labels[idx]


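# Split files per digit (stratified) so every split gets the same mix of digits.
# Caveat: each digit's files are augmentations of ONE recording, so this keeps the
# classes balanced but cannot stop the same source audio appearing in every split.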
import glob

train_files, train_labels = [], []
val_files, val_labels = [], []
test_files, test_labels = [], []

for digit in range(10):
    # Get all files for this digit, sorted by name so the split is reproducible
    digit_files = sorted(glob.glob(f'processed/{digit}/*.wav'))
    n = len(digit_files)
    
    # Split: first 70% train, next 20% val, last 10% test
    train_end = int(0.7 * n)
    val_end = int(0.9 * n)
    
    train_files.extend(digit_files[:train_end])
    train_labels.extend([digit] * train_end)
    
    val_files.extend(digit_files[train_end:val_end])
    val_labels.extend([digit] * (val_end - train_end))
    
    test_files.extend(digit_files[val_end:])
    test_labels.extend([digit] * (n - val_end))

print(f"Training samples: {len(train_files)}")
print(f"Validation samples: {len(val_files)}")
print(f"Test samples: {len(test_files)}")

# Create separate datasets for each split (no random_split needed!)
train_dataset = AudioDataset(train_files, train_labels, transforms=transforms.MelSpectrogram())
validation_dataset = AudioDataset(val_files, val_labels, transforms=transforms.MelSpectrogram())
test_dataset = AudioDataset(test_files, test_labels, transforms=transforms.MelSpectrogram())

# DataLoaders handle batching and shuffling
# batch_size=32 means we process 32 samples at a time
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=32, shuffle=False)  # no need to shuffle for evaluation

print(f"Number of training batches: {len(train_loader)}")
print(f"Number of validation batches: {len(validation_loader)}")

In [None]:
import matplotlib.pyplot as plt

# see a batch
for batch in train_loader:
    inputs, targets = batch
    print(inputs.shape)
    print(inputs[0][0].shape)
    print(targets)
    break


mel_freq_bins = inputs[0][0].shape[0]
time_steps = inputs[0][0].shape[1]

print("mel freq bins: ", mel_freq_bins)
print("time steps: ", time_steps)


In [None]:
# train with a simple Multi-Layer Perceptron (MLP) - Fully-Connected Neural Network

device = t.device('cuda' if t.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = t.nn.Sequential(
    t.nn.Flatten(),
    t.nn.Linear(mel_freq_bins * time_steps, 512),  # e.g. 128 mel bins x 366 time steps
    t.nn.ReLU(),
    t.nn.Linear(512, 512),
    t.nn.ReLU(),
    t.nn.Linear(512, 10),  # 10 raw scores (logits); no Softmax, CrossEntropyLoss applies it internally
)
model.to(device)

loss_fn = t.nn.CrossEntropyLoss()
optimizer = t.optim.Adam(model.parameters(), lr=0.001)

epochs = 10

# train our model
print(f"Training for {epochs} epochs")
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{epochs}, Avg Train Loss: {avg_loss:.4f}")

# evaluate our model on the validation set
model.eval()

correct = 0
total = 0

with t.no_grad():
    for batch in validation_loader:
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = model(inputs)
        predicted = outputs.argmax(dim=1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

print(f"Accuracy of the model on the validation set: {100 * correct / total:.1f}%")

### Understanding the MLP Above

The cell above trained a **Multi-Layer Perceptron (MLP)** - the simplest neural network architecture.

**How it works:**
1. **Flatten** the 2D spectrogram into a 1D vector (loses spatial structure)
2. Pass through **linear layers** that learn weights for each input
3. **ReLU activation** adds non-linearity: `output = max(0, input)`
4. Final layer outputs 10 values (one per digit)
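PyTorch's `CrossEntropyLoss` applies log-softmax to its input itself, so the final layer should output raw scores (logits). Putting a `Softmax` layer in front of it applies softmax twice, which squashes the scores into [0, 1] and puts a floor under the loss. The NumPy sketch below re-implements the per-sample loss by hand to show the effect; `cross_entropy` here is a hypothetical re-implementation, not the PyTorch function.

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max()           # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def cross_entropy(scores: np.ndarray, target: int) -> float:
    """What CrossEntropyLoss computes for one sample: -log(softmax(scores)[target])."""
    return float(-np.log(softmax(scores)[target]))

logits = np.array([10.0, 0.0, 0.0])   # the model is very sure the answer is class 0

print(cross_entropy(logits, 0))            # ~0.0001: correct use, loss near zero
print(cross_entropy(softmax(logits), 0))   # ~0.55: double softmax, loss stuck well above zero
```

With 10 classes, even a perfectly confident double-softmax output (`[1, 0, ..., 0]`) still gives a loss of about 1.46, so the gradients never stop pushing and training becomes slow and noisy.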

**Training Loop:**
1. **Forward pass**: Feed data through the network to get predictions
2. **Calculate loss**: Measure how wrong the predictions are (CrossEntropyLoss)
3. **Backward pass**: Calculate gradients (how to adjust weights)
4. **Update weights**: Use optimizer (Adam) to improve the model
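The four steps can be written out by hand for a tiny model. This NumPy sketch trains a linear classifier on made-up 2-D data with plain gradient descent (not Adam, which the notebook uses); in PyTorch, `loss.backward()` computes the gradients of step 3 automatically. The toy data and learning rate are assumptions chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy dataset: two well-separated 2-D blobs, labels 0 and 1
X = np.vstack([rng.normal(-2.0, 1.0, size=(50, 2)), rng.normal(2.0, 1.0, size=(50, 2))])
y = np.array([0] * 50 + [1] * 50)

W = np.zeros((2, 2))  # (features, classes)
b = np.zeros(2)
lr = 0.1
losses = []

def softmax_rows(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

for step in range(100):
    # 1. Forward pass: raw scores (logits) for every sample
    logits = X @ W + b
    probs = softmax_rows(logits)

    # 2. Loss: mean cross-entropy of the correct class
    loss = -np.log(probs[np.arange(len(y)), y]).mean()
    losses.append(loss)

    # 3. Backward pass: gradient w.r.t. the logits, then w.r.t. W and b
    grad_logits = probs.copy()
    grad_logits[np.arange(len(y)), y] -= 1.0
    grad_logits /= len(y)
    grad_W = X.T @ grad_logits
    grad_b = grad_logits.sum(axis=0)

    # 4. Update: step against the gradient (plain gradient descent)
    W -= lr * grad_W
    b -= lr * grad_b

accuracy = (np.argmax(X @ W + b, axis=1) == y).mean()
print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}, accuracy: {accuracy:.0%}")
```

The first loss is `ln 2` because all-zero weights give every class equal probability; each later step should push it down.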

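MLPs work, but flattening throws away the spectrogram's 2D layout: every input value gets its own weights, so a pattern learned at one point in time has to be learned again at every other position. Let's try something better...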

In [None]:
# Convolutional Neural Network for audio classification
import torch as t
import torch.nn as nn

# Use GPU if available (much faster training!)
device = t.device('cuda' if t.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

class ConvModel(nn.Module):
    def __init__(self, mel_freq_bins, time_steps, num_classes=10):
        super(ConvModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Calculate the size of the flattened features (using explicit parameters)
        self.flat_features = 128 * (mel_freq_bins // 8) * (time_steps // 8)
        
        self.fc1 = nn.Linear(self.flat_features, 512)
        self.relu4 = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        # Input shape: (batch_size, 1, mel_freq_bins, time_steps)
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.pool3(self.relu3(self.conv3(x)))
        x = x.flatten(start_dim=1)  # (batch_size, flat_features)
        x = self.relu4(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

### Understanding the CNN Above

The cell above defined a **Convolutional Neural Network (CNN)** - designed for data with spatial structure like images and spectrograms.

**Key Components in ConvModel:**

| Layer | Purpose |
|-------|---------|
| `Conv2d` | Slides small filters across the input, detecting local patterns |
| `ReLU` | Activation function: `output = max(0, input)` |
| `MaxPool2d` | Reduces dimensions by keeping max value in each region |
| `Dropout` | Randomly disables neurons during training (prevents overfitting) |
| `Linear` | Fully connected layers for final classification |

**Architecture Flow:**
```
Spectrogram → [Conv→ReLU→Pool] x3 → Flatten → Linear → Linear → 10 digit scores
```

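**Why CNNs beat MLPs:** A convolution applies the same small filter at every position, so a pattern (like a burst of energy in one frequency band) is detected wherever it appears in the spectrogram, using far fewer weights than a fully connected layer. The pooling layers add some robustness to small shifts in timing.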
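PyTorch computes the output size of `Conv2d` along each axis as `floor((n + 2*padding - kernel) / stride) + 1` (with dilation 1), and `MaxPool2d` uses the same formula without padding. The sketch below traces both axes through the three blocks to confirm the `flat_features` formula in `ConvModel`: each 3x3 convolution with padding 1 keeps the size, and each 2x2 pool halves it, rounding down. The function names are hypothetical, and 128 x 366 is just an example input size.

```python
def conv2d_out(size: int, kernel: int = 3, stride: int = 1, padding: int = 1) -> int:
    """Output length of Conv2d along one axis (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1

def maxpool_out(size: int, kernel: int = 2, stride: int = 2) -> int:
    """Output length of MaxPool2d along one axis (no padding, floor mode)."""
    return (size - kernel) // stride + 1

def conv_model_flat_features(mel_freq_bins: int, time_steps: int, channels: int = 128) -> int:
    h, w = mel_freq_bins, time_steps
    for _ in range(3):  # three Conv -> ReLU -> Pool blocks
        h, w = maxpool_out(conv2d_out(h)), maxpool_out(conv2d_out(w))
    return channels * h * w

print(conv_model_flat_features(128, 366))  # 92160 = 128 * 16 * 45
```

Because three floored halvings equal one floored division by 8, `128 * (mel_freq_bins // 8) * (time_steps // 8)` gives the same number without tracing the layers.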

In [None]:
# Initialize the model (passing mel_freq_bins and time_steps explicitly)

conv_model = ConvModel(mel_freq_bins=mel_freq_bins, time_steps=time_steps)
print(conv_model)

In [None]:
# Train the CNN with proper metric tracking
# We'll track TRAINING LOSS, VALIDATION LOSS, and VALIDATION ACCURACY each epoch

conv_model = conv_model.to(device)

loss_fn = t.nn.CrossEntropyLoss()
optimizer = t.optim.Adam(conv_model.parameters(), lr=0.001)

epochs = 15

# Track metrics for each epoch
train_loss_history = []
val_loss_history = []
val_accuracy_history = []

print(f"Training for {epochs} epochs")
print("=" * 70)

for epoch in range(epochs):
    # =====================
    # TRAINING PHASE
    # =====================
    conv_model.train()  # Training mode: dropout is active
    train_loss = 0
    
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()       # Clear gradients from last batch
        outputs = conv_model(inputs) # Forward pass
        loss = loss_fn(outputs, targets)  # Calculate loss
        loss.backward()             # Backpropagation (compute gradients)
        optimizer.step()            # Update weights

        train_loss += loss.item()
    
    avg_train_loss = train_loss / len(train_loader)
    train_loss_history.append(avg_train_loss)
    
    # =====================
    # VALIDATION PHASE
    # =====================
    conv_model.eval()  # Evaluation mode: dropout is disabled
    val_loss = 0
    correct = 0
    total = 0
    
    with t.no_grad():  # Don't compute gradients for validation
        for batch in validation_loader:
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)
            
            outputs = conv_model(inputs)
            loss = loss_fn(outputs, targets)
            val_loss += loss.item()
            
            _, predicted = t.max(outputs.data, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    
    avg_val_loss = val_loss / len(validation_loader)
    val_accuracy = 100 * correct / total
    
    val_loss_history.append(avg_val_loss)
    val_accuracy_history.append(val_accuracy)
    
    # Print epoch summary
    print(f"Epoch {epoch+1}/{epochs}: Train Loss={avg_train_loss:.4f}, Val Loss={avg_val_loss:.4f}, Val Acc={val_accuracy:.1f}%")

print("=" * 70)
print("Finished training!")

# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Loss plot
axes[0].plot(train_loss_history, label='Training Loss', marker='o')
axes[0].plot(val_loss_history, label='Validation Loss', marker='s')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training vs Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy plot
axes[1].plot(val_accuracy_history, label='Validation Accuracy', marker='o', color='green')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Validation Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
axes[1].set_ylim([0, 105])

plt.tight_layout()
plt.show()

print(f"\nFinal Validation Accuracy: {val_accuracy_history[-1]:.1f}%")





### Understanding the Training Curves

The plots above show the **learning progress** of our neural network. Here's how to read them:

#### Training Loss vs Validation Loss

| Pattern | What it means | What to do |
|---------|---------------|------------|
| Both decreasing together | Model is learning well! | Keep training |
| Training ↓, Validation ↑ | **Overfitting!** Model memorizes training data | Stop earlier, add regularization, or get more data |
| Both high and flat | **Underfitting.** Model can't learn patterns | Train longer, use bigger model, or check data |

#### Why Two Losses?

- **Training Loss**: How wrong the model is on data it's *actively learning from*
- **Validation Loss**: How wrong the model is on data it has *never seen*

We care most about validation loss because it predicts real-world performance!

#### Validation Accuracy

- Shows what percentage of validation samples the model classifies correctly
- Should increase as training progresses
- If it plateaus while validation loss starts *rising*, the model is likely growing over-confident on the samples it gets wrong (an early sign of overfitting)

#### The Overfitting Gap

Watch the **gap** between training and validation metrics:
- Small gap = Good generalization (model learned the concept)
- Large gap = Overfitting (model memorized the examples)

For our digit recognition task, we want validation accuracy > 90% with training and validation loss staying close together.
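"Stop earlier" is usually automated as **early stopping**: keep the weights from the epoch with the lowest validation loss, and stop once it hasn't improved for `patience` epochs in a row. A minimal sketch that works on a list like `val_loss_history`; the `early_stopping` helper and its `patience` / `min_delta` parameters are assumptions, not part of the notebook.

```python
def early_stopping(val_losses, patience=3, min_delta=0.0):
    """Return (best_epoch, stop_epoch) for a sequence of validation losses.

    Training stops after `patience` consecutive epochs without an improvement
    larger than `min_delta`; if that never happens, it runs to the last epoch.
    """
    best_loss = float("inf")
    best_epoch = -1
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:
            best_loss = loss
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, epoch
    return best_epoch, len(val_losses) - 1

history = [1.00, 0.80, 0.70, 0.75, 0.76, 0.90]   # made-up validation losses
print(early_stopping(history, patience=3))        # (2, 5)
```

In a real training loop you would run the same check after each epoch and save a copy of `conv_model.state_dict()` whenever a new best validation loss appears, then reload that copy when stopping.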

### What Are Model Weights?

When we say we "trained" a neural network, we mean we adjusted millions of numbers (called **weights** or **parameters**) until the network could recognize patterns in our data.

The cell below shows you what these weights actually look like - just arrays of decimal numbers!
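The parameter count can be worked out by hand: a `Conv2d` layer has `out * in * k * k` weights plus one bias per output channel, and a `Linear` layer has `in * out` weights plus `out` biases. A sketch for `ConvModel`, assuming a 128 x 366 spectrogram (your `time_steps` depends on your recordings' length); the helper names are hypothetical.

```python
def conv2d_params(in_ch: int, out_ch: int, kernel: int = 3) -> int:
    return out_ch * in_ch * kernel * kernel + out_ch  # weights + one bias per filter

def linear_params(in_features: int, out_features: int) -> int:
    return in_features * out_features + out_features  # weight matrix + bias vector

def conv_model_params(mel_freq_bins: int, time_steps: int, num_classes: int = 10) -> dict:
    flat = 128 * (mel_freq_bins // 8) * (time_steps // 8)
    return {
        "conv1": conv2d_params(1, 32),
        "conv2": conv2d_params(32, 64),
        "conv3": conv2d_params(64, 128),
        "fc1": linear_params(flat, 512),
        "fc2": linear_params(512, num_classes),
    }

layers = conv_model_params(128, 366)
total = sum(layers.values())
for name, n in layers.items():
    print(f"{name:6} {n:>12,} ({n / total:.2%})")
print(f"total  {total:>12,}")
```

Almost all of the parameters sit in `fc1`, the layer that connects the flattened feature maps to 512 hidden units; the convolutional filters themselves are tiny by comparison.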

In [None]:
# Peek inside the model weights - what did the network actually learn?

print("=" * 60)
print("MODEL WEIGHTS - The numbers the network learned!")
print("=" * 60)

# Show all the layer names and their shapes
print("\nLayers in the model:\n")
for name, param in conv_model.named_parameters():
    print(f"  {name:30} shape: {str(list(param.shape)):20} ({param.numel():,} numbers)")

total_params = sum(p.numel() for p in conv_model.parameters())
print(f"\n  TOTAL: {total_params:,} learnable parameters!")

# Peek at actual values from the first conv layer
print("\n" + "=" * 60)
print("Zooming into conv1 weights (the full 3x3 kernel of the first filter):")
print("=" * 60)
conv1_weights = conv_model.conv1.weight.data[0, 0]  # first filter, only input channel: shape (3, 3)
print(conv1_weights)

print("\nThese numbers were RANDOM before training!")
print("Training adjusted them to recognize patterns in spectrograms.")

In [None]:
# save the model with today's datetime
import datetime
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Create config dict with all parameters needed for inference
config = {
    # Model architecture parameters
    "mel_freq_bins": mel_freq_bins,
    "time_steps": time_steps,
    "num_classes": 10,
    
    # Audio preprocessing parameters (needed to recreate the same spectrogram shape)
    "sample_rate": 22050,  # librosa default
    "longest_audio_file_length": longest_audio_file_length,  # in samples
    
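    # MelSpectrogram params (torchaudio defaults, as used by AudioDataset)
    "n_mels": 128,
    "n_fft": 400,
    "hop_length": 200,  # torchaudio default = n_fft // 2 (512 is librosa's default, not torchaudio's)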
}

# Extract the held-out test set file paths (files the model never saw during training)
# test_dataset is an AudioDataset, so it exposes file_paths and labels directly
test_file_paths = test_dataset.file_paths
test_file_labels = test_dataset.labels

print(f"Saving {len(test_file_paths)} test set file paths (held out from training)")

# Save config, model weights, AND test set info
checkpoint = {
    "config": config,
    "model_state_dict": conv_model.state_dict(),
    "test_file_paths": test_file_paths,
    "test_file_labels": test_file_labels,
}

# make sure the model_weights directory exists
os.makedirs('model_weights', exist_ok=True) 
saved_model_path = f'model_weights/audrey_model_weights_{timestamp}.pth'

t.save(checkpoint, saved_model_path)
print(f"Saved model checkpoint to: {saved_model_path}")
print(f"(The inference cell will automatically use this path)")
print(f"Config: {config}")

### What Did We Just Save?

The cell above created a **checkpoint** file containing:

- **Model weights**: The learned parameters (millions of numbers the network learned during training)
- **Config**: Architecture details (input dimensions, etc.) needed to reconstruct the model
- **Test set paths**: The exact files held out from training, for fair evaluation later

This checkpoint lets us load the trained model after restarting the notebook - no need to retrain!

---

## Part 9: Using the Trained Model

Everything below this point can be run **after restarting the notebook**. You just need to run the inference cell to load your trained model.
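After a restart, the inference cell falls back to a hard-coded checkpoint path that you have to edit by hand. Because the save cell names files with a zero-padded `%Y-%m-%d_%H-%M-%S` timestamp, ordered from year down to second, the newest checkpoint is simply the largest file name. A sketch of a hypothetical `latest_checkpoint` helper, assuming the `model_weights/audrey_model_weights_<timestamp>.pth` naming used in this notebook:

```python
import glob
import os

def latest_checkpoint(directory: str = "model_weights") -> str:
    """Return the newest audrey_model_weights_<timestamp>.pth file in directory.

    The timestamp format %Y-%m-%d_%H-%M-%S is zero-padded and ordered from year
    down to second, so sorting the names as strings also sorts them by time.
    """
    paths = glob.glob(os.path.join(directory, "audrey_model_weights_*.pth"))
    if not paths:
        raise FileNotFoundError(f"no checkpoints found in {directory!r}")
    return max(paths)

# Usage after a restart, instead of hard-coding the path:
# saved_model_path = latest_checkpoint()
```

Relying on the name rather than the file's modification time keeps the result stable even if checkpoints are copied between machines.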

In [None]:
# ============================================================
# INFERENCE ONLY - Run this cell after restart to load model
# ============================================================
# This cell is self-contained and can be run after clearing 
# all variables or restarting the notebook kernel.

import torch as t
import torch.nn as nn
from torchaudio import transforms
import torchaudio

# 1. Define the model architecture (must match training)
class ConvModel(nn.Module):
    def __init__(self, mel_freq_bins, time_steps, num_classes=10):
        super(ConvModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.flat_features = 128 * (mel_freq_bins // 8) * (time_steps // 8)
        
        self.fc1 = nn.Linear(self.flat_features, 512)
        self.relu4 = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.pool3(self.relu3(self.conv3(x)))
        x = x.view(-1, self.flat_features)
        x = self.relu4(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# 2. Load checkpoint (contains both config and weights)
device = t.device('cuda' if t.cuda.is_available() else 'cpu')
# Use the path from the save cell if available, otherwise specify manually
if 'saved_model_path' not in dir():
    # Update this path if running after kernel restart
    saved_model_path = 'model_weights/audrey_model_weights_2026-02-01_18-00-30.pth'
print(f"Loading model from: {saved_model_path}")

checkpoint = t.load(saved_model_path, map_location=device)
config = checkpoint['config']

print(f"Loaded config: {config}")

# Load the true test set (files the model never saw during training)
test_file_paths = checkpoint.get('test_file_paths', [])
test_file_labels = checkpoint.get('test_file_labels', [])
print(f"Loaded {len(test_file_paths)} held-out test files")

# 3. Initialize model with saved config and load weights
model = ConvModel(
    mel_freq_bins=config['mel_freq_bins'],
    time_steps=config['time_steps'],
    num_classes=config['num_classes']
)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

# 4. Define preprocessing function
def preprocess_audio(audio_path, config):
    """Load and preprocess audio to match training data shape."""
    waveform, sr = torchaudio.load(audio_path)
    
    # Resample to match training sample rate (librosa default is 22050)
    target_sr = config['sample_rate']  # 22050
    if sr != target_sr:
        resampler = transforms.Resample(orig_freq=sr, new_freq=target_sr)
        waveform = resampler(waveform)
    
    # Ensure mono
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0).unsqueeze(0)
    
    # Pad or truncate to match training audio length
    target_length = config['longest_audio_file_length']
    current_length = waveform.shape[1]
    
    if current_length < target_length:
        # Pad (center padding like training)
        pad_size = target_length - current_length
        left_pad = pad_size // 2
        right_pad = pad_size - left_pad
        waveform = t.nn.functional.pad(waveform, (left_pad, right_pad))
    elif current_length > target_length:
        # Truncate (center crop)
        start = (current_length - target_length) // 2
        waveform = waveform[:, start:start + target_length]
    
    # Use default MelSpectrogram (matches AudioDataset training setup)
    mel_transform = transforms.MelSpectrogram()
    spec = mel_transform(waveform)
    
    return spec.unsqueeze(0)  # Add batch dimension

# 5. Define prediction function
def predict_digit(audio_path):
    """Predict the digit from an audio file."""
    spec = preprocess_audio(audio_path, config).to(device)
    
    with t.no_grad():
        output = model(spec)
        predicted = t.argmax(output, dim=1).item()
        confidence = t.softmax(output, dim=1)[0, predicted].item()
    
    return predicted, confidence

print(f"Model loaded successfully! Ready for inference.")
print(f"Use predict_digit('path/to/audio.wav') to make predictions.")

### Understanding the Inference Cell Above

The large cell above is **self-contained** - it can be run after restarting the notebook to load your trained model. It includes:

1. **Model architecture** (must match what was trained)
2. **Checkpoint loading** (weights + config)
3. **`preprocess_audio()`** - Prepares new audio to match training format
4. **`predict_digit()`** - Takes an audio file path, returns predicted digit

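**Why preprocessing matters:** The model expects input in exactly the format it was trained on (sample rate, length, spectrogram shape). If new audio is preprocessed differently, a length mismatch will typically crash the model with a shape error, and subtler differences (such as a different sample rate) will silently produce unreliable predictions.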

## Part 10: Evaluating Model Performance

### Why Test Sets Matter

The cells below demonstrate a critical ML concept: **proper evaluation**.

| Test Type | What It Tests | Caveat |
|-----------|---------------|--------|
| **True Test Set** | Files held out from training | Fair within this dataset, but still augmentations of the same 10 recordings |
| **Random Files** | Any files from `processed/` | About 70% are training files, so accuracy is inflated |

**Overfitting** happens when a model memorizes training examples instead of learning general patterns. Signs of overfitting:
- High accuracy on training data
- Lower accuracy on truly new data (test set)

If the random files score much higher than the true test set, that gap is a strong sign your model is overfitting.
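With 20 files, each file moves the accuracy by 5 percentage points, so a small difference between the two tests can be pure chance. A standard way to gauge this is the binomial standard error, sqrt(p(1 - p) / n). The helper below is a sketch and is not part of the notebook. The approximation is rough for small n, and it gives a misleading 0 when accuracy is exactly 0% or 100%.

```python
import math


def accuracy_with_stderr(correct, total):
    """Return (accuracy, binomial standard error) for `correct` hits out of `total` tests."""
    if total <= 0:
        raise ValueError("total must be positive")
    p = correct / total
    return p, math.sqrt(p * (1 - p) / total)


acc, se = accuracy_with_stderr(18, 20)
print(f"accuracy = {acc:.0%} +/- {se:.0%}")  # prints: accuracy = 90% +/- 7%
```

A gap between the held-out and random-file accuracies that is only one or two standard errors wide is weak evidence of overfitting on its own.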

In [None]:
# Test on the TRUE HELD-OUT TEST SET (model never saw these during training!)
import random

if not test_file_paths:
    print("No test set found in checkpoint. Re-run training with the updated save cell.")
else:
    # Test on a sample of the held-out test set
    num_tests = min(20, len(test_file_paths))
    test_indices = random.sample(range(len(test_file_paths)), num_tests)
    correct = 0
    
    print(f"Testing on {num_tests} files from the TRUE HELD-OUT TEST SET")
    print(f"(Total held-out files: {len(test_file_paths)})")
    print("=" * 50)
    print("These files were NEVER seen during training!\n")
    
    for i, idx in enumerate(test_indices):
        test_file = test_file_paths[idx]
        true_label = test_file_labels[idx]
        predicted_digit, confidence = predict_digit(test_file)
        
        is_correct = predicted_digit == true_label
        correct += is_correct
        
        status = "correct" if is_correct else "WRONG"
        print(f"{i+1}. True={true_label}, Pred={predicted_digit}, Conf={confidence:.1%} [{status}]")
    
    print("=" * 50)
    print(f"TRUE TEST SET Accuracy: {correct}/{num_tests} = {100*correct/num_tests:.1f}%")

In [None]:
# Test inference with a random file from the processed dataset
# (Works after kernel restart - just needs the inference cell to be run first)
import random
import glob
import os

# Get all processed audio files
processed_files = glob.glob('processed/*/*')
print(f"Found {len(processed_files)} files in processed/")

# Pick a random file
test_file = random.choice(processed_files)

# Extract true label from folder name (e.g., "processed/3/3_123.wav" -> 3).
# os.path handles both "/" and Windows "\" separators, unlike split('/').
true_label = int(os.path.basename(os.path.dirname(test_file)))

# Make prediction
predicted_digit, confidence = predict_digit(test_file)

print(f"\nTest file: {test_file}")
print(f"True label: {true_label}")
print(f"Predicted: {predicted_digit}")
print(f"Confidence: {confidence:.2%}")
print(f"Correct: {'Yes' if predicted_digit == true_label else 'No'}")

In [None]:
# Test on RANDOM files from processed/ (MAY INCLUDE TRAINING DATA!)
# Compare this accuracy to the true test set above to see if the model is overfitting
import random
import glob
import os

processed_files = glob.glob('processed/*/*')
num_tests = min(20, len(processed_files))
correct = 0

# Get random files
test_files = random.sample(processed_files, num_tests)

print(f"Testing on {num_tests} RANDOM files from processed/")
print("=" * 50)
print("WARNING: Some of these may have been in the training set!\n")

for i, test_file in enumerate(test_files):
    # Folder name is the label; os.path works with "/" and "\" separators
    true_label = int(os.path.basename(os.path.dirname(test_file)))
    predicted_digit, confidence = predict_digit(test_file)
    
    is_correct = predicted_digit == true_label
    correct += is_correct
    
    status = "correct" if is_correct else "WRONG"
    print(f"{i+1}. True={true_label}, Pred={predicted_digit}, Conf={confidence:.1%} [{status}]")

print("=" * 50)
print(f"RANDOM FILES Accuracy: {correct}/{num_tests} = {100*correct/num_tests:.1f}%")
print("\nCompare this to the TRUE TEST SET accuracy above!")

In [None]:
# Test on 10 random files from the processed dataset
import glob
import os
import random

processed_files = glob.glob('processed/*/*')
print(f"Found {len(processed_files)} files in processed/\n")

num_tests = min(10, len(processed_files))  # avoid an error if there are fewer than 10 files
test_files = random.sample(processed_files, num_tests)
correct = 0

for i, test_file in enumerate(test_files):
    # Folder name is the label; os.path works with "/" and "\" separators
    true_label = int(os.path.basename(os.path.dirname(test_file)))
    predicted_digit, confidence = predict_digit(test_file)
    
    is_correct = predicted_digit == true_label
    correct += is_correct
    
    status = "correct" if is_correct else "WRONG"
    print(f"{i+1}. {os.path.basename(test_file)}: True={true_label}, Pred={predicted_digit}, Conf={confidence:.1%} [{status}]")

print(f"\nAccuracy: {correct}/{num_tests} = {100*correct/num_tests:.1f}%")

In [None]:
# Test on ALL original recordings from the unprocessed dataset
import glob
import os

unprocessed_files = sorted(glob.glob('unprocessed/*.wav'))
print(f"Found {len(unprocessed_files)} original recordings in unprocessed/\n")

correct = 0

for test_file in unprocessed_files:
    # Extract true label from filename (e.g., "unprocessed/0.wav" -> 0)
    true_label = int(os.path.splitext(os.path.basename(test_file))[0])
    predicted_digit, confidence = predict_digit(test_file)
    
    is_correct = predicted_digit == true_label
    correct += is_correct
    
    status = "correct" if is_correct else "WRONG"
    print(f"{test_file}: True={true_label}, Pred={predicted_digit}, Conf={confidence:.1%} [{status}]")

# Guard against dividing by zero when no recordings were found
if unprocessed_files:
    print(f"\nAccuracy on original recordings: {correct}/{len(unprocessed_files)} = {100*correct/len(unprocessed_files):.1f}%")

## Part 11: Real-time Digit Recognition

Now for the fun part - using your trained model in real-time!

### Critical Insight: Processing Pipeline Must Match Training

For real-time recognition to work accurately, the audio must be processed **exactly the same way** as during training. This means:

1. **Save to file** (even though it adds a bit of latency)
2. **Load with librosa** (automatically resamples to 22050 Hz, just like training)
3. **Pad to target length** (center padding, same as preprocessing)
4. **Save padded file, then load with torchaudio** (same as AudioDataset)
5. **Apply MelSpectrogram** (same transform as training)

Because this file-based approach reuses the same save, load, pad, and transform steps, live audio reaches the model in the same format as the training data.
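One concrete change the file round trip introduces is 16-bit quantization. The raw recording is written with `subtype='PCM_16'`, and soundfile's default for `.wav` is also 16-bit, so every sample is rounded to one of 65,536 levels before it is read back. The training preprocessing presumably did the same. The sketch below uses only the standard-library `wave` module and NumPy to show how large that rounding is. The helper name `pcm16_round_trip` is ours. Its scaling (multiply by 32767 on write, divide by 32768 on read) is an assumed convention, and soundfile or torchaudio may differ slightly.

```python
import io
import wave

import numpy as np


def pcm16_round_trip(audio, sample_rate):
    """Write mono float audio to an in-memory 16-bit WAV and read it back as float32."""
    # Clip to the valid range, then scale to 16-bit integers (assumed convention: x * 32767)
    ints = np.round(np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)

    buffer = io.BytesIO()
    with wave.open(buffer, 'wb') as writer:
        writer.setnchannels(1)
        writer.setsampwidth(2)            # 2 bytes = 16 bits per sample
        writer.setframerate(sample_rate)
        writer.writeframes(ints.tobytes())

    buffer.seek(0)
    with wave.open(buffer, 'rb') as reader:
        frames = reader.readframes(reader.getnframes())

    # Back to float (assumed convention: divide by 32768)
    return np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0


# Example: a 440 Hz tone at half amplitude, one second at 22050 Hz
sr = 22050
tone = 0.5 * np.sin(2 * np.pi * 440 * np.arange(sr) / sr)
restored = pcm16_round_trip(tone, sr)
print(f"max round-trip error: {np.max(np.abs(restored - tone)):.2e}")
```

The error is tiny, so quantization alone is unlikely to explain large accuracy gaps. Sample rate and length mismatches are usually far bigger effects.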

### Two Working Approaches

We provide two real-time recognition methods below:

| Method | How it works | Best for |
|--------|--------------|----------|
| **Enter-to-Record** | Press Enter, speak for 1 second, get prediction | Simple, reliable |
| **Buffered Onset Detection** | Hands-free, auto-triggers when you speak | More interactive |

**Requirements:** Run the inference cell first to load the model!
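The buffered onset detector can be simulated offline on a NumPy array, which lets you test its logic without a microphone. This sketch mirrors the notebook's idea: a block-RMS threshold plus a rolling `deque` of recent blocks that already includes the current block. It is a simplified stand-in, though. The function name and defaults are ours, it ignores the cooldown, and it returns exactly `total_samples` rather than whole blocks.

```python
from collections import deque

import numpy as np


def capture_with_prebuffer(signal, block_size=1024, threshold=0.04,
                           buffer_blocks=21, total_samples=44100):
    """Offline sketch of buffered onset detection on a mono float array.

    Returns the captured segment (pre-onset blocks followed by the audio after
    the onset), or None if no block's RMS exceeds the threshold.
    """
    rolling = deque(maxlen=buffer_blocks)  # recent blocks, current one included
    n_blocks = len(signal) // block_size
    for i in range(n_blocks):
        block = signal[i * block_size:(i + 1) * block_size]
        rolling.append(block)
        rms = np.sqrt(np.mean(block ** 2))
        if rms > threshold:
            # Start with the rolling buffer (captures the word's start),
            # then append everything after the onset block.
            pre_onset = np.concatenate(list(rolling))
            after = signal[(i + 1) * block_size:]
            return np.concatenate([pre_onset, after])[:total_samples]
    return None
```

With a 2-second recording from `sd.rec`, `capture_with_prebuffer(recording.flatten())` returns roughly the 1-second window the live cell would send to the model. The default `buffer_blocks=21` matches `int(0.5 * 44100 / 1024)`.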

In [None]:
# =============================================================================
# ENTER-TO-RECORD: Simple, reliable real-time recognition
# =============================================================================
# Press Enter to record for 1 second, then get a prediction.
# Uses the EXACT same pipeline as training for maximum accuracy.

import sounddevice as sd
import soundfile as sf
import numpy as np
import torch as t
import torchaudio
import librosa
import os

# Temporary files for the file-based pipeline
TEMP_RAW = '_temp_recording.wav'
TEMP_PADDED = '_temp_padded.wav'

# Recording config - 44100 Hz is common for most microphones
RECORD_SAMPLE_RATE = 44100
RECORD_DURATION = 1.0  # 1 second, matching our training data

def predict_from_recording(raw_audio):
    """
    Process audio through the EXACT same pipeline as training data.
    This is the key to accurate real-time recognition!
    """
    # Step 1: Save raw recording to file
    sf.write(TEMP_RAW, raw_audio, RECORD_SAMPLE_RATE, subtype='PCM_16')
    
    # Step 2: Load with librosa (resamples to 22050 Hz - same as training!)
    audio, sr = librosa.load(TEMP_RAW)  # Default sr=22050
    
    # Step 3: Pad to target length (center padding - same as preprocessing!)
    target_length = config['longest_audio_file_length']
    current_length = len(audio)
    
    if current_length < target_length:
        pad_size = target_length - current_length
        left_pad = pad_size // 2
        right_pad = pad_size - left_pad
        audio = np.pad(audio, (left_pad, right_pad), mode='constant')
    elif current_length > target_length:
        # Center crop if too long
        start = (current_length - target_length) // 2
        audio = audio[start:start + target_length]
    
    # Step 4: Save padded audio and reload with torchaudio (like AudioDataset!)
    sf.write(TEMP_PADDED, audio, sr)
    waveform, loaded_sr = torchaudio.load(TEMP_PADDED)
    
    # Step 5: Apply MelSpectrogram (same as AudioDataset!)
    mel = torchaudio.transforms.MelSpectrogram()
    spec = mel(waveform)
    
    # Step 6: Add batch dimension and move to device
    spec = spec.unsqueeze(0).to(device)
    
    # Step 7: Run inference
    model.eval()
    with t.no_grad():
        output = model(spec)
        predicted = t.argmax(output, dim=1).item()
        confidence = t.softmax(output, dim=1)[0, predicted].item()
    
    return predicted, confidence

# Main loop
print("=" * 50)
print("ENTER-TO-RECORD DIGIT RECOGNITION")
print("=" * 50)
print("Instructions:")
print("  1. Press Enter when ready to record")
print("  2. Say a digit clearly")
print(f"  3. Recording lasts {RECORD_DURATION} seconds")
print("  4. Type 'q' to quit")
print("=" * 50)

try:
    while True:
        user_input = input("\nPress Enter to record (or 'q' to quit): ")
        if user_input.lower() == 'q':
            break
        
        # Record audio
        print("Recording... Say a digit NOW!")
        recording = sd.rec(
            int(RECORD_DURATION * RECORD_SAMPLE_RATE),
            samplerate=RECORD_SAMPLE_RATE,
            channels=1,
            dtype='float32'
        )
        sd.wait()
        
        # Process and predict
        print("Processing...")
        recording = recording.flatten()
        predicted, confidence = predict_from_recording(recording)
        
        print(f"\n  Predicted: {predicted}")
        print(f"  Confidence: {confidence:.1%}")
        
except KeyboardInterrupt:
    pass
finally:
    # Clean up temp files
    for f in [TEMP_RAW, TEMP_PADDED]:
        if os.path.exists(f):
            os.remove(f)
    print("\nGoodbye!")

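In [None]:
# =============================================================================
# VOICE-ACTIVATED (silence timeout): alternative hands-free recognition
# =============================================================================
# Starts recording when the volume rises above a threshold and stops after a
# short silence. Unlike the two methods above, this variant processes audio in
# memory (peak normalization + torchaudio resampling) instead of using the
# file-based pipeline, so it may be less accurate.

import sounddevice as sd
import numpy as np
import torch as t
from torchaudio import transforms
from IPython.display import clear_output

# Audio settings (assumed values; adjust for your microphone)
SAMPLE_RATE = 44100         # microphone sample rate
TARGET_SAMPLE_RATE = 22050  # librosa's default rate, which the training data used
BLOCK_SIZE = 1024           # samples per callback block

# Voice-activation settings (assumed starting values; tune for your room)
VOLUME_THRESHOLD = 0.04       # RMS level that starts a recording
SILENCE_DURATION = 0.3        # seconds of quiet that end a recording
MIN_RECORDING_DURATION = 0.3  # never stop earlier than this
MAX_RECORDING_DURATION = 1.0  # always stop by this point (training clips are ~1 s)

# Derived limits: silence is counted in callback blocks, length in samples
silence_samples_threshold = int(SILENCE_DURATION * SAMPLE_RATE / BLOCK_SIZE)
min_samples = int(MIN_RECORDING_DURATION * SAMPLE_RATE)
max_samples = int(MAX_RECORDING_DURATION * SAMPLE_RATE)

# Recording state shared with audio_callback()
audio_buffer = []
is_recording = False
silence_samples = 0

def process_and_predict(audio_data):
    """Process accumulated audio and run prediction."""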
    # Normalize audio (peak normalization to match training data levels)
    audio_data = audio_data / (np.max(np.abs(audio_data)) + 1e-8)
    audio_data = audio_data * 0.9  # Scale to ~90% to avoid clipping
    
    # Convert to tensor
    waveform = t.tensor(audio_data, dtype=t.float32).unsqueeze(0)
    
    # Resample to match training
    resampler = transforms.Resample(orig_freq=SAMPLE_RATE, new_freq=TARGET_SAMPLE_RATE)
    waveform = resampler(waveform)
    
    # Pad or truncate to match training audio length
    target_length = config['longest_audio_file_length']
    current_length = waveform.shape[1]
    
    if current_length < target_length:
        pad_size = target_length - current_length
        left_pad = pad_size // 2
        right_pad = pad_size - left_pad
        waveform = t.nn.functional.pad(waveform, (left_pad, right_pad))
    elif current_length > target_length:
        start = (current_length - target_length) // 2
        waveform = waveform[:, start:start + target_length]
    
    # Compute mel spectrogram (using defaults to match training)
    mel_transform = transforms.MelSpectrogram()
    spec = mel_transform(waveform).unsqueeze(0).to(device)
    
    # Run inference
    with t.no_grad():
        output = model(spec)
        predicted = t.argmax(output, dim=1).item()
        confidence = t.softmax(output, dim=1)[0, predicted].item()
    
    return predicted, confidence

def audio_callback(indata, frames, time_info, status):
    """Called for each block of audio from the microphone."""
    global audio_buffer, is_recording, silence_samples
    
    if status:
        print(f"Status: {status}")
    
    # Calculate RMS volume for this block
    audio_block = indata[:, 0]
    rms = np.sqrt(np.mean(audio_block**2))
    
    if not is_recording:
        # Waiting for speech to start
        if rms > VOLUME_THRESHOLD:
            is_recording = True
            silence_samples = 0
            audio_buffer = list(audio_block)  # Start with this block
            clear_output(wait=True)
            print("Recording... (speak your digit)")
    else:
        # Currently recording
        audio_buffer.extend(audio_block.tolist())
        
        if rms < VOLUME_THRESHOLD:
            silence_samples += 1
        else:
            silence_samples = 0  # Reset silence counter if sound detected
        
        # Check if we should stop recording
        should_stop = False
        
        if len(audio_buffer) >= max_samples:
            should_stop = True  # Hit max duration
        elif silence_samples >= silence_samples_threshold and len(audio_buffer) >= min_samples:
            should_stop = True  # Silence detected after minimum recording
        
        if should_stop:
            # Process the recording
            audio_data = np.array(audio_buffer)
            duration = len(audio_data) / SAMPLE_RATE
            
            predicted, confidence = process_and_predict(audio_data)
            
            clear_output(wait=True)
            print("=" * 40)
            print(f"  Predicted Digit:  {predicted}")
            print(f"  Confidence:       {confidence:.1%}")
            print(f"  Audio duration:   {duration:.2f}s")
            print("=" * 40)
            print(f"\nListening... (speak a digit to start)")
            
            # Reset state
            audio_buffer = []
            is_recording = False
            silence_samples = 0

# Start listening
print("Real-time Digit Recognition (Voice Activated)")
print("=" * 40)
print(f"Volume threshold: {VOLUME_THRESHOLD} RMS")
print(f"Silence timeout:  {SILENCE_DURATION}s")
print(f"Max recording:    {MAX_RECORDING_DURATION}s")
print("=" * 40)
print("\nListening... (speak a digit to start)")
print("\n>>> Click the STOP button (square) or press 'i' twice to stop <<<")

stream = None
try:
    stream = sd.InputStream(
        samplerate=SAMPLE_RATE,
        blocksize=BLOCK_SIZE,
        channels=1,
        callback=audio_callback
    )
    stream.start()
    
    # Loop with short sleeps - much easier to interrupt
    while True:
        sd.sleep(100)  # 100ms intervals - responsive to interrupts
        
except KeyboardInterrupt:
    pass
finally:
    if stream is not None:
        stream.stop()
        stream.close()
    audio_buffer = []
    is_recording = False
    silence_samples = 0
    print("\nStopped!")

In [None]:
# =============================================================================
# BUFFERED ONSET DETECTION: Hands-free automatic recognition
# =============================================================================
# Listens continuously and auto-triggers when you speak.
# Uses a rolling buffer to capture the START of your speech (not just the middle!).

import sounddevice as sd
import soundfile as sf
import numpy as np
import torch as t
import torchaudio
import librosa
import os
from collections import deque
from IPython.display import clear_output

# Temporary files for the file-based pipeline
TEMP_RAW = '_temp_onset_recording.wav'
TEMP_PADDED = '_temp_onset_padded.wav'

# Audio settings
SAMPLE_RATE = 44100
BLOCK_SIZE = 1024

# Onset detection settings - ADJUST FOR YOUR ENVIRONMENT
# If too sensitive (triggers on noise): increase ONSET_THRESHOLD to 0.05 or higher
# If not sensitive enough (misses words): decrease to 0.02
ONSET_THRESHOLD = 0.04   # RMS level to detect speech (default: 0.04, noisy rooms: 0.06-0.08)
RECORD_AFTER_ONSET = 1.0 # Seconds captured per detection, counted from the start of the pre-onset buffer
BUFFER_SECONDS = 0.5     # Seconds of pre-onset audio to keep (captures word start!)
COOLDOWN_SECONDS = 0.5   # Wait this long after prediction before listening again

# Calculate buffer sizes
buffer_blocks = int(BUFFER_SECONDS * SAMPLE_RATE / BLOCK_SIZE)
record_samples_after_onset = int(RECORD_AFTER_ONSET * SAMPLE_RATE)
cooldown_blocks = int(COOLDOWN_SECONDS * SAMPLE_RATE / BLOCK_SIZE)

# Rolling buffer to capture audio BEFORE onset
rolling_buffer = deque(maxlen=buffer_blocks)

# State
is_recording = False
recorded_audio = []
samples_recorded = 0
cooldown_remaining = 0  # Blocks to wait before listening again

def predict_from_audio(audio_data):
    """
    Process audio through the EXACT same pipeline as training data.
    """
    # Step 1: Save raw recording
    sf.write(TEMP_RAW, audio_data, SAMPLE_RATE, subtype='PCM_16')
    
    # Step 2: Load with librosa (resamples to 22050 Hz)
    audio, sr = librosa.load(TEMP_RAW)
    
    # Step 3: Pad to target length
    target_length = config['longest_audio_file_length']
    current_length = len(audio)
    
    if current_length < target_length:
        pad_size = target_length - current_length
        left_pad = pad_size // 2
        right_pad = pad_size - left_pad
        audio = np.pad(audio, (left_pad, right_pad), mode='constant')
    elif current_length > target_length:
        start = (current_length - target_length) // 2
        audio = audio[start:start + target_length]
    
    # Step 4: Save and reload with torchaudio
    sf.write(TEMP_PADDED, audio, sr)
    waveform, _ = torchaudio.load(TEMP_PADDED)
    
    # Step 5: MelSpectrogram
    mel = torchaudio.transforms.MelSpectrogram()
    spec = mel(waveform).unsqueeze(0).to(device)
    
    # Step 6: Inference
    model.eval()
    with t.no_grad():
        output = model(spec)
        predicted = t.argmax(output, dim=1).item()
        confidence = t.softmax(output, dim=1)[0, predicted].item()
    
    return predicted, confidence

def audio_callback(indata, frames, time_info, status):
    """Called for each block of audio from the microphone."""
    global is_recording, recorded_audio, samples_recorded, rolling_buffer, cooldown_remaining
    
    audio_block = indata[:, 0].copy()
    rms = np.sqrt(np.mean(audio_block**2))
    
    # Handle cooldown period (prevents rapid re-triggering)
    if cooldown_remaining > 0:
        cooldown_remaining -= 1
        return
    
    if not is_recording:
        # Keep audio in rolling buffer (captures pre-onset audio!)
        rolling_buffer.append(audio_block)
        
        # Check for onset
        if rms > ONSET_THRESHOLD:
            is_recording = True
            # rolling_buffer already contains the current block (appended above),
            # so its contents capture the word START without duplicating audio
            recorded_audio = list(np.concatenate(list(rolling_buffer)))
            samples_recorded = len(recorded_audio)
            clear_output(wait=True)
            print("Recording... (detected speech)")
    else:
        # Continue recording
        recorded_audio.extend(audio_block.tolist())
        samples_recorded += len(audio_block)
        
        # Check if we've recorded enough
        if samples_recorded >= record_samples_after_onset:
            # Process the recording
            audio_data = np.array(recorded_audio, dtype=np.float32)
            
            try:
                predicted, confidence = predict_from_audio(audio_data)
                
                clear_output(wait=True)
                print("=" * 50)
                print(f"  Predicted Digit: {predicted}")
                print(f"  Confidence: {confidence:.1%}")
                print("=" * 50)
                print("\nListening... (speak a digit)")
            except Exception as e:
                clear_output(wait=True)
                print(f"Error: {e}")
                print("\nListening... (speak a digit)")
            
            # Reset state with cooldown to prevent immediate re-trigger
            is_recording = False
            recorded_audio = []
            samples_recorded = 0
            rolling_buffer.clear()  # drop stale pre-onset audio from this detection
            cooldown_remaining = cooldown_blocks  # Wait before listening again

# Start listening
print("=" * 50)
print("BUFFERED ONSET DETECTION")
print("=" * 50)
print(f"Onset threshold: {ONSET_THRESHOLD} RMS")
print(f"  (If glitchy/too sensitive, increase to 0.06 or 0.08)")
print(f"Pre-onset buffer: {BUFFER_SECONDS}s")
print(f"Recording duration: {RECORD_AFTER_ONSET}s total (includes up to {BUFFER_SECONDS}s of pre-onset audio)")
print(f"Cooldown between detections: {COOLDOWN_SECONDS}s")
print("=" * 50)
print("\nListening... (speak a digit)")
print("\n>>> Press STOP (square button) or Kernel > Interrupt to stop <<<")

stream = None
try:
    stream = sd.InputStream(
        samplerate=SAMPLE_RATE,
        blocksize=BLOCK_SIZE,
        channels=1,
        callback=audio_callback
    )
    stream.start()
    
    while True:
        sd.sleep(100)  # Short sleep for responsive interrupts
        
except KeyboardInterrupt:
    pass
finally:
    if stream is not None:
        stream.stop()
        stream.close()
    # Clean up temp files and reset state
    for f in [TEMP_RAW, TEMP_PADDED]:
        if os.path.exists(f):
            os.remove(f)
    rolling_buffer.clear()
    is_recording = False
    recorded_audio = []
    samples_recorded = 0
    cooldown_remaining = 0
    print("\nStopped!")

## Exercises and Next Steps

Now that you've built your own Audrey, try these challenges:

### Exploration
1. **Test robustness**: What happens if you say a digit in a different accent or tone?
2. **Out-of-vocabulary**: What does the model predict when you say a word that isn't a digit?
3. **Different device**: Play a digit from your phone speaker - does it still work?

### Extensions
4. **New vocabulary**: Train on different sounds (you might need to change how the files get saved and processed in the dataset creation step):
   - Yes / No
   - Colors (red, blue, green...)
   - Letters (A, B, C...)
   - Environmental sounds (clap, snap, whistle...)
   
5. **Improve the model**: 
   - Record more original samples
   - Adjust augmentation parameters
   - Try different network architectures

6. **Investigate bias**: 
   - Train on recordings from multiple people
   - Test in different acoustic environments
   - What makes a "fair" speech recognition dataset?
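For exercise 4, note that the evaluation cells above turn a folder name into a label with `int(...)`, which only works for digit folders. One way to generalize is to sort the folder names and number them. The helper below is hypothetical and is not part of the notebook. When all ten folders `'0'` to `'9'` are present, it reproduces the current digit labels, because their alphabetical and numeric orders agree. The model's output layer, presumably sized for 10 digits, would also need one output per class.

```python
import os


def labels_from_folders(paths):
    """Map each file's parent-folder name to an integer class index.

    Returns (class_to_index, labels), where labels[i] is the index for paths[i].
    """
    class_names = sorted({os.path.basename(os.path.dirname(p)) for p in paths})
    class_to_index = {name: i for i, name in enumerate(class_names)}
    labels = [class_to_index[os.path.basename(os.path.dirname(p))] for p in paths]
    return class_to_index, labels


mapping, labels = labels_from_folders(
    ['processed/yes/yes_1.wav', 'processed/no/no_1.wav', 'processed/yes/yes_2.wav'])
print(mapping, labels)
```

Keep `class_to_index` with the checkpoint so inference can map a predicted index back to a word.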

---

## Summary

In this notebook you learned:
- **Data collection** and the importance of representative training data
- **Data augmentation** to create variations and improve robustness
- **Neural network architectures**: MLP vs CNN for audio classification
- **Training loop**: forward pass → loss → backpropagation → optimization
- **Evaluation**: Why held-out test sets matter
- **Inference**: Using trained models on new data
- **Dataset bias**: How recording conditions affect model performance

Congratulations - you've built a speech recognition system from scratch!

---

## Appendix: Debugging Real-time Recognition

If real-time recognition isn't working well, the cells below can help diagnose why.

### Common Issues

| Symptom | Likely Cause | Solution |
|---------|--------------|----------|
| Always predicts same digit | Sample rate mismatch | Use file-based pipeline |
| Very low confidence | Spectrogram shape mismatch | Check MelSpectrogram params |
| Works on files, fails live | Processing pipeline differs from training | Use the file-save/load approach |
| Inconsistent predictions | Background noise | Increase onset threshold |
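For the last row, you can measure your room instead of guessing a threshold. Record a few seconds of silence, compute the RMS of each block the same way the real-time callbacks do, and set the threshold a safe margin above it. This is a sketch: the helper name, the 95th percentile, and the 3x margin are our assumptions, not values from the notebook.

```python
import numpy as np


def suggest_onset_threshold(noise, block_size=1024, margin=3.0, percentile=95):
    """Suggest an RMS onset threshold from a recording of background noise."""
    n_blocks = len(noise) // block_size
    if n_blocks == 0:
        raise ValueError("need at least one full block of noise")
    blocks = np.asarray(noise[:n_blocks * block_size]).reshape(n_blocks, block_size)
    block_rms = np.sqrt(np.mean(blocks ** 2, axis=1))  # same RMS as the callbacks
    return float(margin * np.percentile(block_rms, percentile))
```

For example, record 3 seconds of silence with `noise = sd.rec(int(3 * 44100), samplerate=44100, channels=1, dtype='float32')` followed by `sd.wait()`, then set `ONSET_THRESHOLD = suggest_onset_threshold(noise.flatten())`.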

### The Golden Rule

**If file-based evaluation works (high accuracy on the held-out test set) but real-time recognition fails, the problem is most likely in the real-time preprocessing, not the model.**

The diagnostic cells below help verify that real-time audio is being processed identically to training data.
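The visual spectrogram comparison can also be turned into a number. This hypothetical helper averages each spectrogram over time, converts the per-bin energy to dB, and correlates the two profiles. A value near 1 means the frequency bands line up, regardless of overall loudness, because a volume change only shifts the dB profile. It is a rough sanity check under these assumptions, not a metric the notebook defines.

```python
import numpy as np


def band_profile_similarity(spec_a, spec_b, eps=1e-10):
    """Correlate the time-averaged per-mel-bin energy (in dB) of two spectrograms.

    spec_a, spec_b: arrays shaped [n_mels, time]; the time lengths may differ.
    Returns a Pearson correlation in [-1, 1].
    """
    if spec_a.shape[0] != spec_b.shape[0]:
        raise ValueError(f"mel bin mismatch: {spec_a.shape[0]} vs {spec_b.shape[0]}")
    profile_a = 10 * np.log10(spec_a.mean(axis=1) + eps)
    profile_b = 10 * np.log10(spec_b.mean(axis=1) + eps)
    return float(np.corrcoef(profile_a, profile_b)[0, 1])
```

After running both diagnostic cells, `band_profile_similarity(train_spec[0].numpy(), spec[0].numpy())` gives a quick score. Compare it against a second training file of the same digit to see what "similar" looks like for your data.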

In [None]:
# =============================================================================
# DIAGNOSTIC: Compare Real-time vs Training Pipeline
# =============================================================================
# This cell helps you verify that real-time audio is processed the same way
# as training data. If both predict the same digit, your pipeline is correct!

import sounddevice as sd
import soundfile as sf
import numpy as np
import torch as t
import torchaudio
import librosa
import os

TEMP_FILE = '_diagnostic_recording.wav'
TEMP_PADDED = '_diagnostic_padded.wav'

print("DIAGNOSTIC: Recording 2 seconds... Say a digit!")
recording = sd.rec(int(2 * 44100), samplerate=44100, channels=1, dtype='float32')
sd.wait()
recording = recording.flatten()
print("Recording done!\n")

# Save and process through file-based pipeline (like training)
sf.write(TEMP_FILE, recording, 44100, subtype='PCM_16')

# Step 1: Load with librosa (this is what training uses!)
audio, sr = librosa.load(TEMP_FILE)  # Resamples to 22050 by default
print(f"After librosa.load: SR={sr}, length={len(audio)}")

# Step 2: Pad to target length
target_length = config['longest_audio_file_length']
if len(audio) < target_length:
    pad_size = target_length - len(audio)
    audio = np.pad(audio, (pad_size // 2, pad_size - pad_size // 2), mode='constant')
elif len(audio) > target_length:
    start = (len(audio) - target_length) // 2
    audio = audio[start:start + target_length]
print(f"After padding: length={len(audio)} (target={target_length})")

# Step 3: Save and reload with torchaudio (like AudioDataset)
sf.write(TEMP_PADDED, audio, sr)
waveform, loaded_sr = torchaudio.load(TEMP_PADDED)
print(f"After torchaudio.load: shape={waveform.shape}, SR={loaded_sr}")

# Step 4: Create spectrogram
mel = torchaudio.transforms.MelSpectrogram()
spec = mel(waveform)
print(f"Spectrogram shape: {spec.shape}")
print(f"Expected shape: [1, 128, {config['time_steps']}]")

# Step 5: Run inference
spec_batch = spec.unsqueeze(0).to(device)
model.eval()
with t.no_grad():
    output = model(spec_batch)
    predicted = t.argmax(output, dim=1).item()
    confidence = t.softmax(output, dim=1)[0, predicted].item()

print(f"\n{'=' * 50}")
print(f"PREDICTION: {predicted} (confidence: {confidence:.1%})")
print(f"{'=' * 50}")

# Clean up
for f in [TEMP_FILE, TEMP_PADDED]:
    if os.path.exists(f):
        os.remove(f)

print("\nIf this prediction is correct but real-time recognition is wrong,")
print("then the real-time cell is not using the same processing pipeline.")

In [None]:
# =============================================================================
# DIAGNOSTIC: Visualize Training vs Live Audio Spectrograms
# =============================================================================
# Compare spectrograms from a training file vs your live recording.
# They should look similar in structure (frequency patterns, energy distribution).

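import matplotlib.pyplot as plt
import glob
import os
import random
import torchaudio

# Get a random training file
processed_files = glob.glob('processed/*/*')
random_file = random.choice(processed_files)
true_label = int(os.path.basename(os.path.dirname(random_file)))

# Load training file spectrogram (exactly how the model sees it)
train_waveform, _ = torchaudio.load(random_file)
train_spec = torchaudio.transforms.MelSpectrogram()(train_waveform)

print(f"Training file: {random_file}")
print(f"Training spec shape: {train_spec.shape}")

# Mel power values span several orders of magnitude, so plot them in dB.
# (This conversion is for display only; the model still sees linear values.)
to_db = torchaudio.transforms.AmplitudeToDB()

# Create comparison plot
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Training spectrogram
axes[0].imshow(to_db(train_spec)[0].numpy(), aspect='auto', origin='lower', cmap='viridis')
axes[0].set_title(f'Training File (digit {true_label})')
axes[0].set_xlabel('Time frames')
axes[0].set_ylabel('Mel frequency bins')

# Live spectrogram: reuse `spec` from the diagnostic cell above, if it has been run
if 'spec' in globals():
    axes[1].imshow(to_db(spec.cpu())[0].numpy(), aspect='auto', origin='lower', cmap='viridis')
    axes[1].set_title('Live Recording (from diagnostic cell)')
    axes[1].set_xlabel('Time frames')
    axes[1].set_ylabel('Mel frequency bins')
else:
    axes[1].text(0.5, 0.5, 'Run the diagnostic cell above\nto compare with live audio',
                 ha='center', va='center', fontsize=12)
    axes[1].set_title('Live Recording (compare here)')
    axes[1].axis('off')

plt.tight_layout()
plt.show()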

print("\nTip: The frequency patterns (horizontal bands) should be similar between")
print("training and live audio. Big differences indicate a preprocessing mismatch.")