In [8]:
import numpy as np
import os
import pandas as pd
import random
import torch
import torchaudio
import torchaudio.transforms as T

from time import time
from torch import nn
from torch.utils.data import DataLoader

from msc_dataset import MSCDataset

In [9]:
# Pick the compute backend in priority order: Apple-silicon GPU (MPS),
# then NVIDIA CUDA, and finally the CPU. The conditional expression
# short-circuits, so CUDA is only probed when MPS is unavailable.
DEVICE = torch.device(
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)

print(f"Using device: {DEVICE}")

Using device: mps


In [10]:
# ==================== CONFIGURATION ====================
# Hyperparameters for the front end and the training run.
# Entries marked "changed" differ from the earlier baseline run; the baseline
# value is kept in the comment so the experiment history is not lost.
CFG = {
    'sampling_rate': 16000,      # audio sample rate (Hz) passed to MelSpectrogram
    'frame_length_in_s': 0.04,   # STFT window length in seconds -> n_fft = 640 samples
    'frame_step_in_s': 0.02,     # STFT hop in seconds -> hop_length = 320 samples
    'n_mels': 40,                # number of Mel bands
    'f_min': 80,                 # lowest Mel filter frequency; changed (baseline: 0)
    'f_max': 4000,               # highest Mel filter frequency; changed (baseline: 8000)
    'seed': 0,                   # shared seed for torch / numpy / random
    'train_steps': 4000,         # total optimizer steps; changed (baseline: 2000)
    'train_batch_size': 128,
    'learning_rate': 0.001,      # initial Adam learning rate
    'epochs': 25,                # pseudo-epochs the train_steps are split into; changed (baseline: 10)
}

# Define the set of target classes
CLASSES = ['stop', 'up']

# Set deterministic behaviour: seed every RNG the notebook uses.
torch.manual_seed(CFG['seed'])
np.random.seed(CFG['seed'])
random.seed(CFG['seed'])

In [11]:
# ==================== MEL-SPECTROGRAM FEATURE EXTRACTOR ====================
class MelSpectrogramExtractor(nn.Module):
    """ONNX-compatible log-Mel-spectrogram feature extractor.

    Turns a batch of raw waveforms into per-utterance standardized log-Mel
    features with a leading channel axis, ready for a 2-D CNN.

    Args:
        cfg: Optional dict providing 'sampling_rate', 'frame_length_in_s',
            'frame_step_in_s', 'n_mels', 'f_min' and 'f_max'. Defaults to the
            notebook-global ``CFG``, so ``MelSpectrogramExtractor()`` behaves
            exactly as before.
        eps: Small constant that guards both the log and the std division.
    """
    def __init__(self, cfg=None, eps=1e-9):
        super().__init__()
        cfg = CFG if cfg is None else cfg
        self.eps = eps
        sample_rate = cfg['sampling_rate']
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=int(cfg['frame_length_in_s'] * sample_rate),
            hop_length=int(cfg['frame_step_in_s'] * sample_rate),
            n_mels=cfg['n_mels'],
            f_min=cfg['f_min'],
            f_max=cfg['f_max'],
            window_fn=torch.hann_window,
            power=2.0,
            normalized=False,
            # center=True pads n_fft // 2 on both sides, giving
            # 1 + samples // hop_length frames (51 for 1 s at the default CFG).
            center=True,
            pad_mode="reflect"
        )

    def forward(self, waveform):
        """Compute standardized log-Mel features.

        Args:
            waveform: (batch, samples) waveform tensor.

        Returns:
            (batch, 1, n_mels, frames) tensor.
        """
        mel_spec = self.mel_transform(waveform)  # (batch, n_mels, frames)

        # Log compression; eps keeps log() finite on silent frames.
        log_mel = torch.log(mel_spec + self.eps)

        # Per-utterance standardization over the Mel and time axes
        # (zero mean, unit std). NOTE: this is a z-score, not a rescale to
        # [-1, 1]; individual values can exceed +/-1.
        mean = log_mel.mean(dim=[1, 2], keepdim=True)
        std = log_mel.std(dim=[1, 2], keepdim=True)
        log_mel = (log_mel - mean) / (std + self.eps)

        return log_mel.unsqueeze(1)  # (batch, 1, n_mels, frames)

In [12]:
# ==================== CNN MODEL ====================
class KeywordSpotter(nn.Module):
    """Compact 4-block CNN for Up/Stop keyword classification.

    Input: (batch, 1, n_mels, frames) log-Mel features. With the default CFG
    (40 Mel bands, 1 s at 16 kHz, hop 320, center=True) this is
    (batch, 1, 40, 51).
    Output: (batch, num_classes) unnormalized logits, intended for
    ``nn.CrossEntropyLoss``.

    Global average pooling makes the classifier independent of the spatial
    input size, so other n_mels / clip lengths work without changes.

    Args:
        num_classes: number of output logits.
    """
    def __init__(self, num_classes=2):
        super().__init__()

        # Block 1: 1 -> 64 channels, full resolution
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)

        # Block 2: 64 -> 64 channels, stride-2 downsample
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)

        # Block 3: 64 -> 128 channels, stride-2 downsample
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)

        # Block 4: 128 -> 128 channels, same resolution
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(128)
        self.relu4 = nn.ReLU(inplace=True)

        # Global Average Pooling -> one 128-d vector per example
        self.gap = nn.AdaptiveAvgPool2d(1)

        # Classifier head (convs are bias-free because BatchNorm supplies the shift)
        self.fc = nn.Linear(128, num_classes, bias=True)

    def forward(self, x):
        # Shapes below are for the default (batch, 1, 40, 51) input.
        x = self.relu1(self.bn1(self.conv1(x)))  # (batch, 64, 40, 51)
        x = self.relu2(self.bn2(self.conv2(x)))  # (batch, 64, 20, 26)
        x = self.relu3(self.bn3(self.conv3(x)))  # (batch, 128, 10, 13)
        x = self.relu4(self.bn4(self.conv4(x)))  # (batch, 128, 10, 13)

        x = self.gap(x)              # (batch, 128, 1, 1)
        x = torch.flatten(x, 1)      # (batch, 128)
        return self.fc(x)            # (batch, num_classes) logits

In [None]:
# ==================== MAIN TRAINING PIPELINE ====================
# Run summary banner: one print call, one line per item.
print(
    "=" * 60,
    "UP/STOP KEYWORD SPOTTER - TRAINING PIPELINE",
    "=" * 60,
    f"Device: {DEVICE}",
    f"Mel-Spectrogram config: n_mels={CFG['n_mels']}, "
    f"n_fft={int(CFG['frame_length_in_s'] * CFG['sampling_rate'])}, "
    f"hop={int(CFG['frame_step_in_s'] * CFG['sampling_rate'])}",
    f"Training config: epochs={CFG['epochs']}, "
    f"batch_size={CFG['train_batch_size']}, lr={CFG['learning_rate']}",
    "=" * 60,
    sep="\n",
)

# Stand-alone front-end instance.
# NOTE(review): `transform` is not referenced anywhere else in the visible
# notebook (the loops use `feature_extractor`); candidate for removal.
transform = MelSpectrogramExtractor()

# ---- Datasets: one MSCDataset per split, built in training/validation/testing order.
print("\nüìÅ Loading datasets...")
train_dataset, val_dataset, test_dataset = (
    MSCDataset(root='.', classes=CLASSES, split=split_name, preprocess=None)
    for split_name in ('training', 'validation', 'testing')
)

# ---- Loaders. The training sampler draws train_steps * batch_size indices with
# replacement, so one pass over train_loader yields exactly CFG['train_steps'] batches.
sampler = torch.utils.data.RandomSampler(
    train_dataset,
    replacement=True,
    num_samples=CFG['train_steps'] * CFG['train_batch_size'],
)
train_loader = DataLoader(
    train_dataset,
    batch_size=CFG['train_batch_size'],
    sampler=sampler,
    num_workers=0,  # Set to 0 for macOS compatibility
)
val_loader, test_loader = (
    DataLoader(split_ds, batch_size=100, num_workers=0)
    for split_ds in (val_dataset, test_dataset)
)

# ---- Models: fixed front end + trainable CNN, both placed on DEVICE.
print("\nüèóÔ∏è  Initializing models...")
feature_extractor = MelSpectrogramExtractor().to(DEVICE)
model = KeywordSpotter(num_classes=len(CLASSES)).to(DEVICE)

# Size estimate: 4 bytes per float32 weight, 1 byte per int8 weight.
total_params = sum(param.numel() for param in model.parameters())
print(
    f"Model parameters: {total_params:,}",
    f"Estimated size (float32): {total_params * 4 / 1024:.2f} KB",
    f"Estimated size (int8): {total_params / 1024:.2f} KB",
    sep="\n",
)

# ---- Optimisation: cross-entropy on logits, Adam on the CNN weights only, and
# a cosine LR schedule whose period (T_max) is the number of pseudo-epochs.
loss_module = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=CFG['learning_rate'])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=CFG['epochs'], eta_min=1e-5
)

# Validation evaluation function
def evaluate(model, feature_extractor, loader, device):
    """Compute accuracy and mean cross-entropy of ``model`` over ``loader``.

    Each batch is a dict with 'x' (waveforms carrying a singleton axis at
    dim 1, squeezed away here) and 'y' (integer class labels). Inputs are run
    through ``feature_extractor`` before ``model``.

    Args:
        model: classifier returning (batch, num_classes) logits.
        feature_extractor: module mapping squeezed waveforms to model inputs.
        loader: iterable of batch dicts (e.g. a DataLoader).
        device: device the batches are moved to.

    Returns:
        (accuracy_percent, mean_loss). The loss is averaged per *sample*, so a
        shorter final batch is weighted correctly.

    Raises:
        ValueError: if ``loader`` yields no samples.

    Side effect: leaves ``model`` in eval mode; callers that continue
    training must call ``model.train()`` afterwards.
    """
    model.eval()
    # reduction='sum' accumulates per-sample losses; dividing by the sample
    # count below gives a true mean even when batch sizes differ.
    criterion = nn.CrossEntropyLoss(reduction='sum')
    correct = 0
    total = 0
    total_loss = 0.0

    with torch.no_grad():
        for batch in loader:
            x = batch['x'].squeeze(1).to(device)
            y = batch['y'].to(device)
            logits = model(feature_extractor(x))
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += y.size(0)
            total_loss += criterion(logits, y).item()

    if total == 0:
        raise ValueError("evaluate() received a loader that yielded no samples")

    accuracy = correct / total * 100
    avg_loss = total_loss / total
    return accuracy, avg_loss

# Early stopping setup
best_val_acc = 0
best_val_loss = float('inf')
patience = 10          # pseudo-epochs without val-accuracy improvement before stopping
patience_counter = 0
best_model_state = None

# Training loop
print("\nüöÄ Starting training...")
model.train()

# The sampler yields CFG['train_steps'] batches in total; they are split into
# CFG['epochs'] pseudo-epochs for validation, early stopping and LR scheduling.
# max(1, ...) avoids a modulo-by-zero below when train_steps < epochs.
# NOTE: if train_steps is not a multiple of epochs, the trailing steps after
# the last pseudo-epoch are never validated.
steps_per_epoch = max(1, len(train_loader) // CFG['epochs'])
current_epoch = 0

for step, batch in enumerate(train_loader):
    x = batch['x'].squeeze(1).to(DEVICE)  # Remove channel dim from waveform
    y = batch['y'].to(DEVICE)

    # The front end is not optimized (the optimizer only holds model
    # parameters), so keep it out of autograd.
    with torch.no_grad():
        features = feature_extractor(x)

    # Forward pass
    output = model(features)
    loss = loss_module(output, y)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step == 0 or (step + 1) % 100 == 0:
        print(f'Step={step}; Training Loss={loss.item():.3f}')

    # End of a pseudo-epoch: validate, keep the best weights, maybe stop.
    if (step + 1) % steps_per_epoch == 0:
        current_epoch += 1

        val_acc, val_loss = evaluate(model, feature_extractor, val_loader, DEVICE)
        print(f'\nüìä Epoch {current_epoch}: Val Acc={val_acc:.2f}%, Val Loss={val_loss:.4f}')

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_val_loss = val_loss
            # state_dict() returns references to the live tensors and
            # dict.copy() is shallow, so a plain copy keeps changing with every
            # optimizer step. Clone each tensor to freeze this epoch's weights.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            patience_counter = 0
            print(f'‚úÖ New best model! Val Acc={val_acc:.2f}%')
        else:
            patience_counter += 1
            print(f'‚ö†Ô∏è  No improvement. Patience: {patience_counter}/{patience}')

        if patience_counter >= patience:
            print(f'\nüõë Early stopping at epoch {current_epoch}')
            break

        # Learning rate step (once per pseudo-epoch, matching T_max=epochs)
        scheduler.step()
        # evaluate() switched the model to eval mode; switch back for training.
        model.train()

# Restore the best validation checkpoint before testing.
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print(f'\n‚úÖ Loaded best model with Val Acc={best_val_acc:.2f}%')

# Evaluation: reuse the validation helper so test metrics are computed identically.
print("\nüìä Evaluating model on test set...")
test_accuracy, test_loss = evaluate(model, feature_extractor, test_loader, DEVICE)
print(f'\nüéØ Test Accuracy: {test_accuracy:.2f}%')

if test_accuracy > 99.4:
    print("‚úÖ PASSED: Accuracy > 99.4%")
else:
    print("‚ùå FAILED: Accuracy <= 99.4%")


UP/STOP KEYWORD SPOTTER - TRAINING PIPELINE
Device: mps
Mel-Spectrogram config: n_mels=40, n_fft=640, hop=320
Training config: epochs=25, batch_size=128, lr=0.001

üìÅ Loading datasets...
Using data folder: ./msc-training
Loaded 1600 samples from ./msc-training for classes ['stop', 'up']
Using data folder: ./msc-validation
Loaded 200 samples from ./msc-validation for classes ['stop', 'up']
Using data folder: ./msc-testing
Loaded 200 samples from ./msc-testing for classes ['stop', 'up']

üèóÔ∏è  Initializing models...
Model parameters: 259,650
Estimated size (float32): 1014.26 KB
Estimated size (int8): 253.56 KB

üöÄ Starting training...
Step=0; Training Loss=0.670
Step=99; Training Loss=0.161
Step=99; Training Loss=0.161

üìä Epoch 1: Val Acc=75.00%, Val Loss=0.6960
‚úÖ New best model! Val Acc=75.00%

üìä Epoch 1: Val Acc=75.00%, Val Loss=0.6960
‚úÖ New best model! Val Acc=75.00%
Step=199; Training Loss=0.049
Step=199; Training Loss=0.049
Step=299; Training Loss=0.035
Step=299; Tr

In [14]:
# ==================== SAVE MODEL ====================
print("\n" + "="*60)
print("SAVING MODEL")
print("="*60)

timestamp = int(time())
saved_model_dir = './saved_models/'
# exist_ok=True makes the cell safe to re-run without a separate existence check.
os.makedirs(saved_model_dir, exist_ok=True)

# Build artifact paths once. os.path.join avoids the doubled slash
# ("./saved_models//<ts>_...") that f'{saved_model_dir}/...' produced because
# the directory string already ends in '/'.
frontend_path = os.path.join(saved_model_dir, f'{timestamp}_frontend.onnx')
model_path = os.path.join(saved_model_dir, f'{timestamp}_model.onnx')

print(f'Model Timestamp: {timestamp}')

model.eval()
feature_extractor.eval()

# Move models to CPU for ONNX export (MPS not supported for export).
# nn.Module.cpu() moves the module in place and returns it, so the *_cpu names
# refer to the same objects as model / feature_extractor; both are moved back
# to DEVICE at the end of this cell.
print("\nüîÑ Moving models to CPU for ONNX export...")
model_cpu = model.cpu()
feature_extractor_cpu = feature_extractor.cpu()

# Export Feature Extractor to ONNX
print("\nüì¶ Exporting Feature Extractor to ONNX...")
torch.onnx.export(
    feature_extractor_cpu,                 # model to export
    torch.randn(1, CFG['sampling_rate']),  # dummy input: one second of waveform, (1, samples)
    frontend_path,                         # filename of the ONNX model
    input_names=['input'],                 # input name in the ONNX model
    dynamo=True,
    optimize=True,
    report=False,
    external_data=False,
)
print(f"‚úÖ Feature extractor saved: {frontend_path}")

# Export Model to ONNX
print("\nüì¶ Exporting Model to ONNX...")
# Trace the classifier with real features from one training clip.
# Assumes dataset items store 'x' as (1, samples), as the squeeze(1) in training implies.
sample_waveform = train_dataset[0]['x'].squeeze(0).unsqueeze(0).cpu()  # (1, samples)
with torch.no_grad():
    sample_features = feature_extractor_cpu(sample_waveform)  # (1, 1, n_mels, frames)
torch.onnx.export(
    model_cpu,               # model to export
    sample_features,         # inputs of the model (mel-spectrogram features)
    model_path,              # filename of the ONNX model
    input_names=['input'],   # input name in the ONNX model
    dynamo=True,
    optimize=True,
    report=False,
    external_data=False,
)
print(f"‚úÖ Model saved: {model_path}")

# Check sizes
fe_size = os.path.getsize(frontend_path) / 1024
model_size = os.path.getsize(model_path) / 1024
total_size = fe_size + model_size

print("\n" + "="*60)
print("SIZE REPORT (ONNX - Float32)")
print("="*60)
print(f"Feature Extractor: {fe_size:.2f} KB")
print(f"Model: {model_size:.2f} KB")
print(f"Total: {total_size:.2f} KB")

if total_size < 300:
    print("‚úÖ PASSED: Total size < 300 KB (before quantization)")
else:
    print("‚ö†Ô∏è  WARNING: Size > 300 KB - quantization required!")

# Save Hyperparameters & Results (one appended row per run).
# NOTE(review): the header is written only when the file is new, so changing
# the CFG keys between runs will misalign columns in an existing CSV.
print("\nüìù Saving hyperparameters and results...")
output_dict = {
    'timestamp': timestamp,
    **CFG,
    'test_accuracy': test_accuracy
}

df = pd.DataFrame([output_dict])
output_path = './keyword_spotter_results.csv'
df.to_csv(output_path, mode='a', header=not os.path.exists(output_path), index=False)
print(f"‚úÖ Results saved to {output_path}")

# .cpu() above moved the live modules; return them to DEVICE so cells that
# send batches to DEVICE (training / evaluation) can be re-run after saving.
model.to(DEVICE)
feature_extractor.to(DEVICE)

print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)



SAVING MODEL
Model Timestamp: 1764340364

üîÑ Moving models to CPU for ONNX export...

üì¶ Exporting Feature Extractor to ONNX...


W1128 15:32:44.736000 34793 torch/onnx/_internal/exporter/_registration.py:107] torchvision is not installed. Skipping torchvision::nms


[torch.onnx] Obtain model graph for `MelSpectrogramExtractor([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `MelSpectrogramExtractor([...]` with `torch.export.export(..., strict=False)`... ‚úÖ
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ‚úÖ
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ‚úÖ
Applied 4 of general pattern rewrite rules.
‚úÖ Feature extractor saved: ./saved_models//1764340364_frontend.onnx

üì¶ Exporting Model to ONNX...
[torch.onnx] Run decomposition... ‚úÖ
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ‚úÖ
Applied 4 of general pattern rewrite rules.
‚úÖ Feature extractor saved: ./saved_models//1764340364_frontend.onnx

üì¶ Exporting Model to ONNX...


W1128 15:32:45.084000 34793 torch/onnx/_internal/exporter/_registration.py:107] torchvision is not installed. Skipping torchvision::nms


[torch.onnx] Obtain model graph for `KeywordSpotter([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `KeywordSpotter([...]` with `torch.export.export(..., strict=False)`... ‚úÖ
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ‚úÖ
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ‚úÖ
Applied 8 of general pattern rewrite rules.
‚úÖ Model saved: ./saved_models//1764340364_model.onnx

SIZE REPORT (ONNX - Float32)
Feature Extractor: 332.49 KB
Model: 1022.68 KB
Total: 1355.17 KB

üìù Saving hyperparameters and results...
‚úÖ Results saved to ./keyword_spotter_results.csv

TRAINING COMPLETE
[torch.onnx] Run decomposition... ‚úÖ
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ‚úÖ
Applied 8 of general pattern rewrite rules.
‚úÖ Model saved: ./saved_models//1764340364_model.onnx

SIZE REPORT (ONNX - Float32)
Feature Extractor: 332.49 KB
Model: 1022.