# Audio Classification Tutorial

This notebook demonstrates how to use the audio_classify library for:
1. Loading and preprocessing audio data
2. Building a model and loading a trained checkpoint (training itself is done with `train.py`)
3. Running inference
4. Visualizing results

**Prerequisites:**
- UrbanSound8K dataset downloaded and extracted
- Update `DATA_ROOT` below with your dataset path


In [None]:
import sys
from pathlib import Path

# Add parent directory to path
sys.path.insert(0, str(Path().absolute().parent))

import torch
import numpy as np
import matplotlib.pyplot as plt
from datasets.urbansound8k import UrbanSound8K
from models.small_cnn import SmallCNN
from transforms.audio import get_mel_transform, wav_to_logmel
from utils.device import get_device, get_device_name

# Configuration
DATA_ROOT = "/path/to/UrbanSound8K"  # Update this!
SEED = 0  # fixed seed so the freshly created model in section 3 is initialized identically on every run

# Seed the RNGs up front so Restart & Run All reproduces the same outputs
torch.manual_seed(SEED)
np.random.seed(SEED)

# Flag the most common setup mistake early, with a readable message, instead of
# relying on whatever error the dataset loader raises for a missing directory
if not Path(DATA_ROOT).is_dir():
    print(f"WARNING: DATA_ROOT does not exist: {DATA_ROOT} (update it in this cell)")

DEVICE = get_device()
print(f"Using device: {get_device_name()} ({DEVICE})")


## 1. Load Dataset

Load a subset of the UrbanSound8K dataset (here, fold 10, used as the validation split).


In [None]:
# Build the dataset from a single fold, with every preprocessing knob spelled
# out in one place so it is easy to tweak
dataset_kwargs = dict(
    root=DATA_ROOT,
    folds=[10],        # Use fold 10 for validation
    target_sr=16000,
    duration=4.0,
    n_mels=64,
    n_fft=1024,
    hop_length=256,
    augment=None,      # No augmentation for visualization
)
dataset = UrbanSound8K(**dataset_kwargs)

# Summarize what was loaded
class_names = list(dataset.idx2name.values())
print(f"Dataset size: {len(dataset)}")
print(f"Number of classes: {len(dataset.class_ids)}")
print(f"Class names: {class_names}")


## 2. Inspect a Sample

Load and visualize a single sample from the dataset


In [None]:
# Get a sample; `label` is used as a key into dataset.idx2name
# (presumably an int class index — confirm against UrbanSound8K.__getitem__)
sample_idx = 0  # distinct name: later cells reuse `idx` as a loop variable
log_mel, label = dataset[sample_idx]
label_name = dataset.idx2name[label]

print(f"Sample {sample_idx}:")
print(f"  Log-mel shape: {log_mel.shape}")
print(f"  Label index: {label}")
print(f"  Class name: {label_name}")

# Visualize the spectrogram with the explicit Figure/Axes API.
# squeeze(0) drops a leading singleton axis (assumed to be the channel axis,
# i.e. (1, n_mels, frames) — TODO confirm)
spec = log_mel.squeeze(0).detach().cpu().numpy()
fig, ax = plt.subplots(figsize=(10, 4))
im = ax.imshow(spec, aspect='auto', origin='lower')
ax.set(
    title=f"Log-Mel Spectrogram: {label_name}",
    xlabel="Time frames",
    ylabel="Mel bins",
)
fig.colorbar(im, ax=ax, label="dB")
fig.tight_layout()
plt.show()


## 3. Create and Test Model

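Create a SmallCNN model and test the forward pass. No checkpoint has been loaded yet, so the predictions below are not meaningful; this cell only confirms that the model runs and returns one score per class.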


In [None]:
# Create a fresh SmallCNN sized to the dataset's label set
num_classes = len(dataset.class_ids)
model = SmallCNN(n_classes=num_classes).to(DEVICE)

# Smoke-test the forward pass on the sample from section 2. No checkpoint is
# loaded yet, so the ranking below only checks the plumbing, not accuracy.
model.eval()
with torch.no_grad():
    x = log_mel.unsqueeze(0).to(DEVICE)  # Add batch dimension
    logits = model(x)
    probs = torch.softmax(logits, dim=-1)

print(f"Model output shape: {logits.shape}")

# torch.topk raises if k exceeds the number of classes, so clamp it
top_k = min(3, num_classes)
print(f"\nTop {top_k} predictions:")
top_probs, top_indices = torch.topk(probs[0], k=top_k)
# Convert to Python scalars once; `class_idx` avoids clobbering the sample index name
for rank, (prob, class_idx) in enumerate(zip(top_probs.tolist(), top_indices.tolist()), start=1):
    class_name = dataset.idx2name[class_idx]
    print(f"  {rank}. {class_name}: {prob:.3f}")


## 4. Load Trained Model (if available)

Load a trained model checkpoint (produced by `train.py`) and run inference on the same sample.


In [None]:
# Checkpoint expected from train.py (see the hint printed when it is missing).
# The relative path is resolved against the kernel's working directory.
# NOTE(review): the setup cell adds the parent directory to sys.path, which
# suggests this notebook lives in a subdirectory; if train.py is run from the
# project root, the file may be at ../artifacts/best_model.pt instead — confirm.
CHECKPOINT_PATH = "artifacts/best_model.pt"

# Guard only the load: a FileNotFoundError raised later (e.g. during inference)
# must surface as a real error, not be misreported as a missing checkpoint.
try:
    # weights_only=True refuses arbitrary pickled objects; a plain state_dict loads fine
    state_dict = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True)
except FileNotFoundError:
    print(f"Model checkpoint not found at {CHECKPOINT_PATH}")
    print("Train a model first using: python train.py --data_root <path> --epochs 5")
else:
    model.load_state_dict(state_dict)
    model.eval()
    print(f"Loaded model from {CHECKPOINT_PATH}")

    # Run inference on the sample from section 2
    with torch.no_grad():
        x = log_mel.unsqueeze(0).to(DEVICE)
        logits = model(x)
        probs = torch.softmax(logits, dim=-1)

    print(f"\nPredictions for sample (true label: {dataset.idx2name[label]}):")
    # torch.topk raises if k exceeds the number of classes, so clamp it
    top_k = min(5, probs.shape[-1])
    top_probs, top_indices = torch.topk(probs[0], k=top_k)
    for rank, (prob, class_idx) in enumerate(zip(top_probs.tolist(), top_indices.tolist()), start=1):
        class_name = dataset.idx2name[class_idx]
        is_correct = "✓" if class_idx == label else ""  # was mis-encoded as "âœ“"
        print(f"  {rank}. {class_name}: {prob:.3f} {is_correct}")


## Next Steps

- Train your own model: `python train.py --data_root <path> --epochs 5`
- Run inference on files: `python predict.py --wav <file> --data_root <path>`
- Try live streaming: `python scripts/stream_infer.py --data_root <path>`
- Explore the dataset: `python scripts/vis_dataset.py --data_root <path>`

For more examples, see the [examples/](../examples/) directory.
