# Instrument Classifier

This notebook demonstrates a PyTorch project for multi-label musical instrument recognition from audio clips. It allows you to run the entire pipeline in Google Colab without any special modifications to the codebase.

## Setup

First, let's clone the repository and install the required dependencies.


In [2]:
# Check if running in Colab
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Clone the repository
    !git clone https://github.com/your-username/instrument-classifier.git
    %cd instrument-classifier

    # Install dependencies
    !pip install -r requirements.txt


## Import Libraries

Let's import the necessary libraries for our project.


In [3]:
import os
import torch
import librosa
import numpy as np
import matplotlib.pyplot as plt
import yaml
import pathlib
import zipfile
import urllib.request
import hashlib
import sys
import time
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm


## Configuration

Let's define the configuration for our project.


In [4]:
# Create configs directory if it doesn't exist
os.makedirs('configs', exist_ok=True)

# Default configuration
default_config = """
# Common hyper‑parameters
sample_rate: 22050
n_mels: 64
hop_length: 512
batch_size: 32
num_epochs: 50
learning_rate: 3e-4
num_workers: 4
"""

# ResNet configuration
resnet_config = """
# ResNet‑34 override
model_name: resnet34
learning_rate: 1e-4
batch_size: 16
"""

# Write configurations to files
with open('configs/default.yaml', 'w') as f:
    f.write(default_config)

with open('configs/model_resnet.yaml', 'w') as f:
    f.write(resnet_config)

# Load configuration
cfg = yaml.safe_load(default_config)
resnet_cfg = {**cfg, **yaml.safe_load(resnet_config)}

# Display configuration
print("Default configuration:")
for key, value in cfg.items():
    print(f"  {key}: {value}")

print("\nResNet configuration:")
for key, value in resnet_cfg.items():
    print(f"  {key}: {value}")


Default configuration:
  sample_rate: 22050
  n_mels: 64
  hop_length: 512
  batch_size: 32
  num_epochs: 50
  learning_rate: 3e-4
  num_workers: 4

ResNet configuration:
  sample_rate: 22050
  n_mels: 64
  hop_length: 512
  batch_size: 16
  num_epochs: 50
  learning_rate: 1e-4
  num_workers: 4
  model_name: resnet34


## Download and Extract Data

Let's download and extract the IRMAS dataset.


In [5]:

#!/usr/bin/env python3
"""Download the IRMAS dataset (≈2 GB) and extract it.

Example:
    python data/download_irmas.py --out_dir data/raw
"""
import argparse, hashlib, os, sys, pathlib
import urllib.request
import zipfile

IRMAS_URL = "https://zenodo.org/record/1290750/files/IRMAS-TrainingData.zip?download=1"
MD5       = "4fd9f5ed5a18d8e2687e6360b5f60afe"  # expected archive checksum

def md5(fname, chunk=2**20):
    m = hashlib.md5()
    with open(fname, 'rb') as fh:
        while True:
            data = fh.read(chunk)
            if not data: break
            m.update(data)
    return m.hexdigest()

def main(out_dir: str):
    out_dir = pathlib.Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    archive_path = out_dir / "IRMAS.zip"

    if not archive_path.exists():
        print("Downloading IRMAS ...")
        try:
            urllib.request.urlretrieve(IRMAS_URL, archive_path)
        except Exception as e:
            print(f"Download failed: {e}", file=sys.stderr)
            sys.exit(1)
    else:
        print("Archive already exists, skipping download")

    print("Verifying checksum ...")
    if md5(archive_path) != MD5:
        print("Checksum mismatch! The downloaded file may be corrupted.", file=sys.stderr)
        print("Try deleting the file and running the script again.", file=sys.stderr)
        sys.exit(1)

    print("Extracting ...")
    try:
        with zipfile.ZipFile(archive_path) as zf:
            zf.extractall(out_dir)
        print("Done. Data at", out_dir)
    except zipfile.BadZipFile:
        print("Extraction failed: The file is not a valid zip archive.", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        print(f"Extraction failed: {e}", file=sys.stderr)
        sys.exit(1)

if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument("--out_dir", default="data/raw", help="Destination directory")

    # Handle Jupyter notebook execution by ignoring unknown arguments
    import sys
    if 'ipykernel_launcher.py' in sys.argv[0]:
        args, unknown = p.parse_known_args()
    else:
        args = p.parse_args()

    main(args.out_dir)

# When running in notebook, explicitly call main() with the output directory
main("data/raw")

Archive already exists, skipping download
Verifying checksum ...
Extracting ...
Done. Data at data/raw
Archive already exists, skipping download
Verifying checksum ...
Extracting ...
Done. Data at data/raw


## Preprocess Data

Let's preprocess the data by converting WAV files to log-mel spectrograms.


In [None]:
from tqdm import tqdm

def generate_multi_stft(
    y: np.ndarray,
    sr: int,
    n_ffts=(256, 512, 1024),
    band_ranges=((0, 1000), (1000, 4000), (4000, 11025))
):
    """
    Generates 9 spectrograms: 3 window sizes × 3 frequency bands.

    Parameters:
        y (np.ndarray): Audio waveform
        sr (int): Sampling rate
        n_ffts (tuple): FFT window sizes
        band_ranges (tuple): Frequency band ranges (Hz)

    Returns:
        dict: { (band_label, n_fft): spectrogram }
    """
    result = {}

    for n_fft in n_ffts:
        hop_length = n_fft // 4
        stft = librosa.stft(y, n_fft=n_fft, hop_length=hop_length)
        mag = np.abs(stft)

        freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)

        for (f_low, f_high) in band_ranges:
            band_label = f"{f_low}-{f_high}Hz"
            # Get frequency indices within this band
            band_mask = (freqs >= f_low) & (freqs < f_high)
            band_spec = mag[band_mask, :]
            # Convert to log scale
            log_spec = librosa.power_to_db(band_spec, ref=np.max).astype(np.float32)
            result[(band_label, n_fft)] = log_spec

    return result

def process_file(wav_path, cfg):
    y, sr = librosa.load(wav_path, sr=cfg['sample_rate'], mono=True)
    return generate_multi_stft(y, sr)

def preprocess_data(in_dir, out_dir, cfg):
    in_dir, out_dir = pathlib.Path(in_dir), pathlib.Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Create train and validation directories
    train_dir = out_dir / 'train'
    val_dir = out_dir / 'val'
    train_dir.mkdir(exist_ok=True)
    val_dir.mkdir(exist_ok=True)

    # Get all WAV files
    wav_files = list(in_dir.rglob("*.wav"))
    print(f"Found {len(wav_files)} WAV files")

    # Split into train and validation sets (90/10 split)
    np.random.shuffle(wav_files)
    split_idx = int(len(wav_files) * 0.9)
    train_files = wav_files[:split_idx]
    val_files = wav_files[split_idx:]

    # Process training files
    print("Processing training files...")
    for wav in tqdm(train_files):
        specs_dict = process_file(wav, cfg)

        # Create a directory for this audio file
        rel_dir = wav.relative_to(in_dir).with_suffix("")
        file_out_dir = train_dir / rel_dir
        file_out_dir.mkdir(parents=True, exist_ok=True)

        # Save each spectrogram with band and FFT size information in the filename
        for (band_label, n_fft), spec in specs_dict.items():
            spec_filename = f"{band_label}_fft{n_fft}.npy"
            np.save(file_out_dir / spec_filename, spec)

    # Process validation files
    print("Processing validation files...")
    for wav in tqdm(val_files):
        specs_dict = process_file(wav, cfg)

        # Create a directory for this audio file
        rel_dir = wav.relative_to(in_dir).with_suffix("")
        file_out_dir = val_dir / rel_dir
        file_out_dir.mkdir(parents=True, exist_ok=True)

        # Save each spectrogram with band and FFT size information in the filename
        for (band_label, n_fft), spec in specs_dict.items():
            spec_filename = f"{band_label}_fft{n_fft}.npy"
            np.save(file_out_dir / spec_filename, spec)

    print(f"Processed {len(train_files)} training files and {len(val_files)} validation files")

# Check if WAV files exist in data/raw/IRMAS
irmas_path = pathlib.Path('data/raw/IRMAS')
if irmas_path.exists() and any(irmas_path.rglob("*.wav")):
    print("Found WAV files in data/raw/IRMAS")
    preprocess_data('data/raw/IRMAS', 'data/processed', cfg)
else:
    # Check if WAV files exist in data/raw/IRMAS-TrainingData
    training_data_path = pathlib.Path('data/raw/IRMAS-TrainingData')
    if training_data_path.exists() and any(training_data_path.rglob("*.wav")):
        print("Found WAV files in data/raw/IRMAS-TrainingData")
        preprocess_data('data/raw/IRMAS-TrainingData', 'data/processed', cfg)
    else:
        # Check if WAV files exist directly in data/raw
        raw_path = pathlib.Path('data/raw')
        if raw_path.exists() and any(raw_path.rglob("*.wav")):
            print("Found WAV files in data/raw")
            preprocess_data('data/raw', 'data/processed', cfg)
        else:
            print("No WAV files found in data/raw or its subdirectories. Please check the extraction path.")


Found WAV files in data/raw/IRMAS-TrainingData
Found 6705 WAV files
Processing training files...


 25%|██▌       | 1534/6034 [00:49<02:27, 30.42it/s]

## Define Models

Let's define the models for our project.


In [None]:
import torch.nn as nn
from torchvision.models import resnet34

# Create models directory if it doesn't exist
os.makedirs('models', exist_ok=True)

class CNNBaseline(nn.Module):
    """Simple CNN baseline for audio classification."""
    def __init__(self, n_classes=11):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, n_classes)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))  # [B,16,32,32]
        x = self.pool(self.relu(self.conv2(x)))  # [B,32,16,16]
        x = self.pool(self.relu(self.conv3(x)))  # [B,64,8,8]
        x = x.view(-1, 64 * 8 * 8)
        x = self.relu(self.fc1(x))
        x = self.sigmoid(self.fc2(x))
        return x

class ResNetSpec(nn.Module):
    """ResNet‑34 backbone adapted for single‑channel spectrogram input."""
    def __init__(self, n_classes=11):
        super().__init__()
        self.backbone = resnet34(weights=None)
        self.backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Replace FC
        self.backbone.fc = nn.Sequential(
            nn.Linear(self.backbone.fc.in_features, n_classes),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.backbone(x)

# Display model architectures
print("CNN Baseline Architecture:")
print(CNNBaseline())
print("\nResNet Architecture:")
print(ResNetSpec())


## Define Dataset and DataLoader

Let's define the dataset and dataloader for our project.


In [None]:
LABELS = ["cello", "clarinet", "flute", "acoustic_guitar", "organ", "piano", "saxophone", "trumpet", "violin", "voice", "other"]

class NpyDataset(Dataset):
    def __init__(self, root):
        self.files = list(pathlib.Path(root).rglob("*.npy"))
        self.label_map = {label: i for i, label in enumerate(LABELS)}

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        spec = np.load(self.files[idx])
        x = torch.tensor(spec).unsqueeze(0)  # [1,H,W]

        # Parse label from folder name
        label_str = self.files[idx].parent.name.split("_")[0]
        y = torch.zeros(len(LABELS))

        # Map label string to index
        if label_str in self.label_map:
            y[self.label_map[label_str]] = 1.0

        return x, y

# Create datasets and dataloaders
train_ds = NpyDataset("data/processed/train")
val_ds = NpyDataset("data/processed/val")

print(f"Training dataset size: {len(train_ds)}")
print(f"Validation dataset size: {len(val_ds)}")

# Create dataloaders
train_loader = DataLoader(train_ds, batch_size=resnet_cfg['batch_size'], shuffle=True, num_workers=resnet_cfg['num_workers'])
val_loader = DataLoader(val_ds, batch_size=resnet_cfg['batch_size'], shuffle=False, num_workers=resnet_cfg['num_workers'])


## Define Model

Let's define the model for our project.


In [None]:
from torchmetrics import Accuracy, Precision, Recall, F1Score

class MetricCollection:
    def __init__(self, n_classes):
        self.accuracy = Accuracy(task="multilabel", num_labels=n_classes)
        self.precision = Precision(task="multilabel", num_labels=n_classes)
        self.recall = Recall(task="multilabel", num_labels=n_classes)
        self.f1 = F1Score(task="multilabel", num_labels=n_classes)

    def __call__(self, preds, targets):
        return {
            "accuracy": self.accuracy(preds, targets),
            "precision": self.precision(preds, targets),
            "recall": self.recall(preds, targets),
            "f1": self.f1(preds, targets)
        }

class InstrumentModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        n_classes = len(LABELS)
        if cfg.get("model_name", "cnn") == "resnet34":
            self.model = ResNetSpec(n_classes)
        else:
            self.model = CNNBaseline(n_classes)
        self.metrics = MetricCollection(n_classes)
        self.lr = cfg["learning_rate"]
        self.cfg = cfg

    def forward(self, x):
        return self.model(x)

    def compute_loss_and_metrics(self, batch, stage="train"):
        x, y = batch
        preds = self(x)
        loss = torch.nn.functional.binary_cross_entropy(preds, y)
        metrics = self.metrics(preds, y)
        return loss, metrics, preds

    def get_optimizer(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

# Initialize model
model = InstrumentModel(resnet_cfg)
print("Model initialized with ResNet-34 backbone")


## Train Model

Let's train the model using PyTorch.


In [None]:
import os
from tqdm.notebook import tqdm

# Create directories for saving models
os.makedirs('checkpoints', exist_ok=True)

class EarlyStopping:
    """Custom implementation of early stopping"""
    def __init__(self, patience=5, mode='max', min_delta=0):
        self.patience = patience
        self.mode = mode
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, score):
        if self.best_score is None:
            self.best_score = score
            return False

        if self.mode == 'max':
            if score > self.best_score + self.min_delta:
                self.best_score = score
                self.counter = 0
            else:
                self.counter += 1
        else:  # mode == 'min'
            if score < self.best_score - self.min_delta:
                self.best_score = score
                self.counter = 0
            else:
                self.counter += 1

        if self.counter >= self.patience:
            self.early_stop = True
            return True
        return False

def train_model(model, train_loader, val_loader, num_epochs, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Train the model using standard PyTorch training loop"""
    model = model.to(device)
    optimizer = model.get_optimizer()

    # Initialize early stopping
    early_stopping = EarlyStopping(patience=5, mode='max')

    # Initialize best model tracking
    best_f1 = 0.0
    best_model_path = None

    # Training loop
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_metrics = {}

        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
        for batch in progress_bar:
            # Move batch to device
            batch = [x.to(device) for x in batch]

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass, compute loss and metrics
            loss, metrics, _ = model.compute_loss_and_metrics(batch, stage='train')

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Update metrics
            train_loss += loss.item()
            for k, v in metrics.items():
                if k not in train_metrics:
                    train_metrics[k] = 0.0
                train_metrics[k] += v.item()

            # Update progress bar
            progress_bar.set_postfix({
                'loss': loss.item(),
                'f1': metrics['f1'].item()
            })

        # Compute average metrics for the epoch
        train_loss /= len(train_loader)
        for k in train_metrics:
            train_metrics[k] /= len(train_loader)

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_metrics = {}

        with torch.no_grad():
            progress_bar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]')
            for batch in progress_bar:
                # Move batch to device
                batch = [x.to(device) for x in batch]

                # Forward pass, compute loss and metrics
                loss, metrics, _ = model.compute_loss_and_metrics(batch, stage='val')

                # Update metrics
                val_loss += loss.item()
                for k, v in metrics.items():
                    if k not in val_metrics:
                        val_metrics[k] = 0.0
                    val_metrics[k] += v.item()

                # Update progress bar
                progress_bar.set_postfix({
                    'loss': loss.item(),
                    'f1': metrics['f1'].item()
                })

        # Compute average metrics for the epoch
        val_loss /= len(val_loader)
        for k in val_metrics:
            val_metrics[k] /= len(val_loader)

        # Print epoch summary
        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'  Train Loss: {train_loss:.4f}, F1: {train_metrics["f1"]:.4f}')
        print(f'  Val Loss: {val_loss:.4f}, F1: {val_metrics["f1"]:.4f}')

        # Save best model
        if val_metrics['f1'] > best_f1:
            best_f1 = val_metrics['f1']
            best_model_path = f'checkpoints/best-{epoch+1:02d}-{val_metrics["f1"]:.2f}.pt'
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_f1': val_metrics['f1'],
                'val_loss': val_loss
            }, best_model_path)
            print(f'  Saved best model to {best_model_path}')

        # Check early stopping
        if early_stopping(val_metrics['f1']):
            print(f'Early stopping triggered after {epoch+1} epochs')
            break

    # Load best model
    if best_model_path:
        checkpoint = torch.load(best_model_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f'Loaded best model from {best_model_path} with F1: {checkpoint["val_f1"]:.4f}')

    return model

# Train model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
model = train_model(model, train_loader, val_loader, num_epochs=resnet_cfg['num_epochs'], device=device)


## Inference

Let's perform inference on a sample audio file.


In [None]:
def extract_features(path, cfg):
    y, sr = librosa.load(path, sr=cfg['sample_rate'], mono=True)
    specs_dict = generate_multi_stft(y, sr)

    # For inference, we'll use the middle frequency band (1000-4000Hz) with the middle FFT size (512)
    # This is a simplification - in a real application, you might want to use all bands and FFT sizes
    key = ("1000-4000Hz", 512)
    if key in specs_dict:
        spec = specs_dict[key]
        return torch.tensor(spec).unsqueeze(0).unsqueeze(0)
    else:
        # Fallback to the first available spectrogram if the preferred one is not available
        first_key = list(specs_dict.keys())[0]
        spec = specs_dict[first_key]
        return torch.tensor(spec).unsqueeze(0).unsqueeze(0)

def predict(model, wav_path, cfg):
    model.eval()
    x = extract_features(wav_path, cfg)
    with torch.no_grad():
        preds = model(x).squeeze().numpy()
    return {label: float(preds[i]) for i, label in enumerate(LABELS)}

# Get a sample WAV file for inference
sample_wav = list(pathlib.Path('data/raw/IRMAS').rglob("*.wav"))[0]
print(f"Sample WAV file: {sample_wav}")

# Perform inference
results = predict(model, sample_wav, resnet_cfg)

# Display results
print("\nPrediction results:")
for instrument, confidence in sorted(results.items(), key=lambda x: x[1], reverse=True):
    print(f"{instrument}: {confidence:.4f}")

# Visualize results
plt.figure(figsize=(10, 6))
plt.bar(results.keys(), results.values())
plt.xticks(rotation=45, ha='right')
plt.ylabel('Confidence')
plt.title('Instrument Detection Confidence')
plt.tight_layout()
plt.show()


## Visualize Audio and Spectrogram

Let's visualize the audio waveform and spectrogram of the sample file.


In [None]:
def visualize_audio(wav_path, cfg):
    # Load audio
    y, sr = librosa.load(wav_path, sr=cfg['sample_rate'], mono=True)

    # Compute multi-band STFT spectrograms
    specs_dict = generate_multi_stft(y, sr)

    # Plot waveform and selected spectrograms
    plt.figure(figsize=(15, 12))

    # Plot waveform
    plt.subplot(4, 1, 1)
    librosa.display.waveshow(y, sr=sr)
    plt.title('Waveform')

    # Select three spectrograms to visualize (one from each frequency band with the middle FFT size)
    keys_to_plot = [
        ("0-1000Hz", 512),
        ("1000-4000Hz", 512),
        ("4000-11025Hz", 512)
    ]

    for i, key in enumerate(keys_to_plot):
        if key in specs_dict:
            plt.subplot(4, 1, i+2)
            spec = specs_dict[key]
            hop_length = 512 // 4  # hop_length for FFT size 512
            librosa.display.specshow(spec, sr=sr, x_axis='time', hop_length=hop_length)
            plt.colorbar(format='%+2.0f dB')
            plt.title(f'Spectrogram: {key[0]}, FFT size: {key[1]}')
        else:
            print(f"Spectrogram for {key} not found")

    plt.tight_layout()
    plt.show()

# Visualize sample audio
visualize_audio(sample_wav, resnet_cfg)


## Conclusion

In this notebook, we've demonstrated the complete pipeline for multi-label musical instrument recognition:

1. Setting up the environment
2. Downloading and preprocessing the IRMAS dataset
3. Defining and training a ResNet-34 model
4. Performing inference on audio files
5. Visualizing the results

This notebook can be run in Google Colab without any special modifications to the codebase.
