In [1]:
import numpy as np
import os
import pandas as pd
import random
import torch
import torchaudio

from time import time
from torch import nn
from torch.utils.data import DataLoader

from msc_dataset import MSCDataset

# For data augmentation
import torch.nn.functional as F

In [2]:
# Device setup for Mac M4 Pro (MPS), CUDA (NVIDIA), or CPU fallback
if torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
elif torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

print(f"Using device: {DEVICE}")

Using device: mps


In [3]:
# ==================== CONFIGURATION ====================
CFG = {
    'sampling_rate': 16000,
    'frame_length_in_s': 0.04,
    'frame_step_in_s': 0.02,
    'n_mels': 40,
    'f_min': 20,
    'f_max': 4000,
    'seed': 42,
    
    # Training Parameters
    'train_steps': 6000,
    'train_batch_size': 32,
    'learning_rate': 0.0005,
    'epochs': 60,
    
    # Data Augmentation
    'time_shift_ms': 80,
    'noise_level': 0.002,
    'time_stretch_factor': 0.06,
    
    # Label Smoothing
    'label_smoothing': 0.1,
}

CLASSES = ['stop', 'up']

# Set Deterministic Behaviour
torch.manual_seed(CFG['seed'])
np.random.seed(CFG['seed'])
random.seed(CFG['seed'])
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


In [4]:
# ==================== MEL-SPECTROGRAM FEATURE EXTRACTOR ====================
class MelSpectrogramExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=CFG['sampling_rate'],
            n_fft=int(CFG['frame_length_in_s'] * CFG['sampling_rate']),
            hop_length=int(CFG['frame_step_in_s'] * CFG['sampling_rate']),
            n_mels=CFG['n_mels'],
            f_min=CFG['f_min'],
            f_max=CFG['f_max'],
            window_fn=torch.hann_window,
            power=2.0,
            normalized=False,
            center=True,
            pad_mode="reflect"
        )
        
    def forward(self, waveform):
        mel_spec = self.mel_transform(waveform)
        log_mel = torch.log(mel_spec + 1e-9)
        log_mel = (log_mel - log_mel.mean(dim=[1, 2], keepdim=True)) / (log_mel.std(dim=[1, 2], keepdim=True) + 1e-9)
        return log_mel.unsqueeze(1)


In [5]:
# ==================== DATA AUGMENTATION ====================
class AudioAugmentation:
    def __init__(self, config, training=True):
        self.config = config
        self.training = training
        
    def time_shift(self, waveform):
        if not self.training or random.random() > 0.5:
            return waveform
            
        shift_samples = int(random.uniform(-self.config['time_shift_ms'], 
                                          self.config['time_shift_ms']) 
                          * self.config['sampling_rate'] / 1000)
        return torch.roll(waveform, shifts=shift_samples, dims=-1)
    
    def add_noise(self, waveform):
        if not self.training or random.random() > 0.5:
            return waveform
            
        noise = torch.randn_like(waveform) * self.config['noise_level']
        return waveform + noise
    
    def time_stretch(self, waveform):
        if not self.training or random.random() > 0.5:
            return waveform
            
        rate = 1.0 + random.uniform(-self.config['time_stretch_factor'], 
                                    self.config['time_stretch_factor'])
        
        stretched = F.interpolate(
            waveform.unsqueeze(0), 
            size=int(waveform.shape[-1] * rate),
            mode='linear',
            align_corners=False
        ).squeeze(0)
        
        target_len = waveform.shape[-1]
        if stretched.shape[-1] < target_len:
            stretched = F.pad(stretched, (0, target_len - stretched.shape[-1]))
        else:
            stretched = stretched[..., :target_len]
            
        return stretched
    
    def __call__(self, waveform):
        waveform = self.time_shift(waveform)
        waveform = self.add_noise(waveform)
        waveform = self.time_stretch(waveform)
        return waveform


In [6]:
# ==================== CNN MODEL ====================
class KeywordSpotterV2(nn.Module):
    def __init__(self, num_classes=2, dropout=0.35):
        super().__init__()
        
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.dropout1 = nn.Dropout2d(dropout * 0.3)
        
        self.conv2a = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2a = nn.BatchNorm2d(64)
        self.relu2a = nn.ReLU(inplace=True)
        self.conv2b = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2b = nn.BatchNorm2d(64)
        self.relu2b = nn.ReLU(inplace=True)
        self.dropout2 = nn.Dropout2d(dropout * 0.3)
        
        self.downsample1 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=1, stride=2, bias=False),
            nn.BatchNorm2d(64)
        )
        
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        self.dropout3 = nn.Dropout2d(dropout * 0.5)
        
        self.conv4a = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn4a = nn.BatchNorm2d(128)
        self.relu4a = nn.ReLU(inplace=True)
        self.conv4b = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn4b = nn.BatchNorm2d(128)
        self.relu4b = nn.ReLU(inplace=True)
        
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn5 = nn.BatchNorm2d(256)
        self.relu5 = nn.ReLU(inplace=True)
        self.dropout5 = nn.Dropout2d(dropout)
        
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(256, 64, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 256, kernel_size=1),
            nn.Sigmoid()
        )
        
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.dropout_fc = nn.Dropout(dropout)
        self.fc = nn.Linear(256, num_classes, bias=True)
        
    def forward(self, x):
        x = self.dropout1(self.relu1(self.bn1(self.conv1(x))))
        
        identity = self.downsample1(x)
        x = self.relu2a(self.bn2a(self.conv2a(x)))
        x = self.bn2b(self.conv2b(x))
        x = self.relu2b(x + identity)
        x = self.dropout2(x)
        
        x = self.dropout3(self.relu3(self.bn3(self.conv3(x))))
        
        identity = x
        x = self.relu4a(self.bn4a(self.conv4a(x)))
        x = self.bn4b(self.conv4b(x))
        x = self.relu4b(x + identity)
        
        x = self.dropout5(self.relu5(self.bn5(self.conv5(x))))
        
        att = self.attention(x)
        x = x * att
        
        x = self.gap(x)
        x = x.view(x.size(0), -1)
        x = self.dropout_fc(x)
        x = self.fc(x)
        
        return x


In [7]:
# ==================== TRAINING SETUP ====================
print("=" * 60)
print("UP/STOP KEYWORD SPOTTER")
print("=" * 60)
print(f"Device: {DEVICE}")
print(f"Training: epochs={CFG['epochs']}, batch_size={CFG['train_batch_size']}, lr={CFG['learning_rate']}")
print(f"Augmentation: time_shift=¬±{CFG['time_shift_ms']}ms, noise={CFG['noise_level']}, stretch=¬±{CFG['time_stretch_factor']*100}%")
print("=" * 60)

transform = MelSpectrogramExtractor()
train_augmentation = AudioAugmentation(CFG, training=True)
val_augmentation = AudioAugmentation(CFG, training=False)

print("\nüìÅ Loading datasets...")
train_dataset = MSCDataset(
    root='.',
    classes=CLASSES,
    split='training',
    preprocess=None,
)

val_dataset = MSCDataset(
    root='.',
    classes=CLASSES,
    split='validation',
    preprocess=None,
)

test_dataset = MSCDataset(
    root='.',
    classes=CLASSES,
    split='testing',
    preprocess=None,
)

sampler = torch.utils.data.RandomSampler(
    train_dataset,
    replacement=True,
    num_samples=CFG['train_steps'] * CFG['train_batch_size'],
)
train_loader = DataLoader(
    train_dataset,
    batch_size=CFG['train_batch_size'],
    sampler=sampler,
    num_workers=0,
)

val_loader = DataLoader(val_dataset, batch_size=100, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=100, num_workers=0)

print("\nüèóÔ∏è  Initializing models...")
feature_extractor = MelSpectrogramExtractor().to(DEVICE)
model = KeywordSpotterV2(num_classes=len(CLASSES), dropout=0.3).to(DEVICE)

total_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {total_params:,}")
print(f"Estimated size (float32): {total_params * 4 / 1024:.2f} KB")

loss_module = nn.CrossEntropyLoss(label_smoothing=CFG['label_smoothing'])
optimizer = torch.optim.AdamW(
    model.parameters(), 
    lr=CFG['learning_rate'], 
    weight_decay=2e-4,
    betas=(0.9, 0.999)
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, 
    T_0=10,
    T_mult=2,
    eta_min=1e-7
)

best_val_acc = 0
best_val_loss = float('inf')
patience = 20
patience_counter = 0
best_model_state = None


UP/STOP KEYWORD SPOTTER
Device: mps
Training: epochs=60, batch_size=32, lr=0.0005
Augmentation: time_shift=¬±80ms, noise=0.002, stretch=¬±6.0%

üìÅ Loading datasets...
Using data folder: ./msc-training
Loaded 1600 samples from ./msc-training for classes ['stop', 'up']
Using data folder: ./msc-validation
Loaded 200 samples from ./msc-validation for classes ['stop', 'up']
Using data folder: ./msc-testing
Loaded 200 samples from ./msc-testing for classes ['stop', 'up']

üèóÔ∏è  Initializing models...
Model parameters: 777,346
Estimated size (float32): 3036.51 KB


In [8]:
# ==================== TRAINING & EVALUATION FUNCTIONS ====================
def evaluate(model, feature_extractor, loader, device):
    model.eval()
    correct = total = 0
    total_loss = 0
    loss_module = nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for batch in loader:
            x = batch['x'].squeeze(1).to(device)
            y = batch['y'].to(device)
            features = feature_extractor(x)
            output = model(features)
            predictions = output.argmax(dim=1)
            correct += (predictions == y).sum().item()
            total += y.size(0)
            total_loss += loss_module(output, y).item()
    
    accuracy = (correct / total) * 100
    avg_loss = total_loss / len(loader)
    return accuracy, avg_loss


def train_epoch(model, feature_extractor, train_loader, optimizer, loss_module, device, 
                steps_per_epoch, current_epoch, augmentation):
    model.train()
    
    start_step = current_epoch * steps_per_epoch
    end_step = start_step + steps_per_epoch
    
    epoch_loss = 0
    step_count = 0
    
    for step, batch in enumerate(train_loader):
        if step < start_step:
            continue
        if step >= end_step:
            break
            
        x = batch['x'].squeeze(1)
        y = batch['y'].to(device)
        
        x_augmented = []
        for i in range(x.shape[0]):
            aug_sample = augmentation(x[i:i+1])
            x_augmented.append(aug_sample)
        x = torch.cat(x_augmented, dim=0).to(device)
        
        with torch.no_grad():
            features = feature_extractor(x)
        
        output = model(features)
        loss = loss_module(output, y)
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        epoch_loss += loss.item()
        step_count += 1
        
        if ((step + 1) % 100) == 0 or step == 0:
            print(f'Step={step}; Training Loss={loss.item():.4f}')
    
    avg_epoch_loss = epoch_loss / step_count if step_count > 0 else 0
    return avg_epoch_loss


def test_model(model, feature_extractor, test_loader, device):
    model.eval()
    correct = total = 0
    
    with torch.no_grad():
        for batch in test_loader:
            x = batch['x'].squeeze(1).to(device)
            y = batch['y'].to(device)
            features = feature_extractor(x)
            output = model(features)
            predictions = output.argmax(dim=1)
            correct += (predictions == y).sum().item()
            total += y.size(0)
    
    accuracy = (correct / total) * 100
    return accuracy


In [9]:
# ==================== TRAINING LOOP ====================
print("\nüöÄ Starting training...")

steps_per_epoch = len(train_loader) // CFG['epochs']
current_epoch = 0
train_history = {'epoch': [], 'train_loss': [], 'val_acc': [], 'val_loss': [], 'lr': []}

for epoch in range(CFG['epochs']):
    print(f"\n{'='*60}")
    print(f"EPOCH {epoch+1}/{CFG['epochs']}")
    print(f"{'='*60}")
    
    train_loss = train_epoch(
        model, feature_extractor, train_loader, optimizer, loss_module, 
        DEVICE, steps_per_epoch, epoch, train_augmentation
    )
    
    current_epoch += 1
    val_acc, val_loss = evaluate(model, feature_extractor, val_loader, DEVICE)
    current_lr = optimizer.param_groups[0]['lr']
    
    print(f'\nüìä Epoch {current_epoch} Summary:')
    print(f'   Train Loss: {train_loss:.4f}')
    print(f'   Val Acc: {val_acc:.2f}%')
    print(f'   Val Loss: {val_loss:.4f}')
    print(f'   Learning Rate: {current_lr:.6f}')
    
    train_history['epoch'].append(current_epoch)
    train_history['train_loss'].append(train_loss)
    train_history['val_acc'].append(val_acc)
    train_history['val_loss'].append(val_loss)
    train_history['lr'].append(current_lr)
    
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_val_loss = val_loss
        best_model_state = model.state_dict().copy()
        patience_counter = 0
        print(f'‚úÖ New best model! Val Acc={val_acc:.2f}%')
    else:
        patience_counter += 1
        print(f'‚è≥ No improvement. Patience: {patience_counter}/{patience}')
    
    if patience_counter >= patience:
        print(f'\nüõë Early stopping at epoch {current_epoch}')
        break
    
    scheduler.step()

if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print(f'\n‚úÖ Loaded best model with Val Acc={best_val_acc:.2f}%')

print(f"\n{'='*60}")
print(f"Training completed after {current_epoch} epochs")
print(f"Best validation accuracy: {best_val_acc:.2f}%")
print(f"{'='*60}")



üöÄ Starting training...

EPOCH 1/60
Step=0; Training Loss=0.6940
Step=0; Training Loss=0.6940
Step=99; Training Loss=0.4114

üìä Epoch 1 Summary:
   Train Loss: 0.5773
   Val Acc: 87.50%
   Val Loss: 0.3306
   Learning Rate: 0.000500
‚úÖ New best model! Val Acc=87.50%

EPOCH 2/60
Step=99; Training Loss=0.4114

üìä Epoch 1 Summary:
   Train Loss: 0.5773
   Val Acc: 87.50%
   Val Loss: 0.3306
   Learning Rate: 0.000500
‚úÖ New best model! Val Acc=87.50%

EPOCH 2/60
Step=199; Training Loss=0.4022

üìä Epoch 2 Summary:
   Train Loss: 0.4307
   Val Acc: 79.00%
   Val Loss: 0.4314
   Learning Rate: 0.000488
‚è≥ No improvement. Patience: 1/20

EPOCH 3/60
Step=199; Training Loss=0.4022

üìä Epoch 2 Summary:
   Train Loss: 0.4307
   Val Acc: 79.00%
   Val Loss: 0.4314
   Learning Rate: 0.000488
‚è≥ No improvement. Patience: 1/20

EPOCH 3/60
Step=299; Training Loss=0.3899

üìä Epoch 3 Summary:
   Train Loss: 0.3917
   Val Acc: 91.50%
   Val Loss: 0.2261
   Learning Rate: 0.000452
‚úÖ New

In [10]:
# ==================== TEST EVALUATION ====================
print("\nüìä Evaluating model on test set...")

model = model.to(DEVICE)
feature_extractor = feature_extractor.to(DEVICE)

test_accuracy = test_model(model, feature_extractor, test_loader, DEVICE)
print(f'\nüéØ Test Accuracy: {test_accuracy:.2f}%')

if test_accuracy > 99.4:
    print("‚úÖ PASSED: Accuracy > 99.4%")
else:
    print("‚ùå FAILED: Accuracy <= 99.4%")



üìä Evaluating model on test set...

üéØ Test Accuracy: 99.50%
‚úÖ PASSED: Accuracy > 99.4%


In [11]:
# ==================== SAVE MODEL ====================
print("\n" + "="*60)
print("SAVING MODEL")
print("="*60)

timestamp = int(time())
saved_model_dir = './saved_models/'
if not os.path.exists(saved_model_dir):
    os.makedirs(saved_model_dir)

print(f'Model Timestamp: {timestamp}')

model.eval()
feature_extractor.eval()

print("\nüîÑ Moving models to CPU for ONNX export...")
model_cpu = model.cpu()
feature_extractor_cpu = feature_extractor.cpu()

print("\nüì¶ Exporting Feature Extractor to ONNX...")
torch.onnx.export(
    feature_extractor_cpu,
    torch.randn(1, 16000),
    f'{saved_model_dir}/{timestamp}_frontend.onnx',
    input_names=['input'],
    dynamo=True,
    optimize=True,
    report=False,
    external_data=False,
)
print(f"‚úÖ Feature extractor saved: {saved_model_dir}/{timestamp}_frontend.onnx")

print("\nüì¶ Exporting Model to ONNX...")
sample_waveform = train_dataset[0]['x'].squeeze(0).unsqueeze(0).cpu()
sample_features = feature_extractor_cpu(sample_waveform)
torch.onnx.export(
    model_cpu,
    sample_features,
    f'{saved_model_dir}/{timestamp}_model.onnx',
    input_names=['input'],
    dynamo=True,
    optimize=True,
    report=False,
    external_data=False,
)
print(f"‚úÖ Model saved: {saved_model_dir}/{timestamp}_model.onnx")

fe_size = os.path.getsize(f'{saved_model_dir}/{timestamp}_frontend.onnx') / 1024
model_size = os.path.getsize(f'{saved_model_dir}/{timestamp}_model.onnx') / 1024
total_size = fe_size + model_size

print("\n" + "="*60)
print("SIZE REPORT (ONNX - Float32)")
print("="*60)
print(f"Feature Extractor: {fe_size:.2f} KB")
print(f"Model: {model_size:.2f} KB")
print(f"Total: {total_size:.2f} KB")

if total_size < 300:
    print("‚úÖ PASSED: Total size < 300 KB (before quantization)")
else:
    print("‚ö†Ô∏è  WARNING: Size > 300 KB - quantization required!")

print("\nüìù Saving results...")
output_dict = {
    'timestamp': timestamp,
    **CFG,
    'test_accuracy': test_accuracy
}

df = pd.DataFrame([output_dict])
output_path = './keyword_spotter_results.csv'
df.to_csv(output_path, mode='a', header=not os.path.exists(output_path), index=False)
print(f"‚úÖ Results saved to {output_path}")

print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)



SAVING MODEL
Model Timestamp: 1764424228

üîÑ Moving models to CPU for ONNX export...

üì¶ Exporting Feature Extractor to ONNX...


W1129 14:50:29.259000 49064 torch/onnx/_internal/exporter/_registration.py:107] torchvision is not installed. Skipping torchvision::nms


[torch.onnx] Obtain model graph for `MelSpectrogramExtractor([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `MelSpectrogramExtractor([...]` with `torch.export.export(..., strict=False)`... ‚úÖ
[torch.onnx] Run decomposition...


W1129 14:50:29.664000 49064 torch/onnx/_internal/exporter/_registration.py:107] torchvision is not installed. Skipping torchvision::nms


[torch.onnx] Run decomposition... ‚úÖ
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ‚úÖ
Applied 4 of general pattern rewrite rules.
‚úÖ Feature extractor saved: ./saved_models//1764424228_frontend.onnx

üì¶ Exporting Model to ONNX...
[torch.onnx] Obtain model graph for `KeywordSpotterV2([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `KeywordSpotterV2([...]` with `torch.export.export(..., strict=False)`... ‚úÖ
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ‚úÖ
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ‚úÖ
Applied 16 of general pattern rewrite rules.
‚úÖ Model saved: ./saved_models//1764424228_model.onnx

SIZE REPORT (ONNX - Float32)
Feature Extractor: 332.22 KB
Model: 3057.24 KB
Total: 3389.46 KB

üìù Saving results...
‚úÖ Results saved to ./keyword_spotter_results.csv

TRAINING COMPLETE
[torch.onnx] Obtain model graph for `K