# Instrument Classifier

This notebook demonstrates a PyTorch Lightning project for multi-label musical instrument recognition from audio clips. It allows you to run the entire pipeline in Google Colab without any special modifications to the codebase.

## Setup

First, let's clone the repository and install the required dependencies.


In [None]:
# Check if running in Colab
import sys
# The presence of the google.colab module in sys.modules is used as the Colab indicator.
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Clone the repository
    # NOTE(review): "your-username" looks like a placeholder; point this at the real repository.
    !git clone https://github.com/your-username/instrument-classifier.git
    # Switch into the clone so requirements.txt and the relative paths used below resolve there.
    %cd instrument-classifier
    
    # Install dependencies
    !pip install -r requirements.txt


## Import Libraries

Let's import the necessary libraries for our project.


In [None]:
import os
import torch
import torchaudio
import pytorch_lightning as pl
import librosa
import numpy as np
import matplotlib.pyplot as plt
import yaml
import pathlib
import zipfile
import urllib.request
import hashlib
import sys
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose
from tqdm.notebook import tqdm

# Set random seed for reproducibility
# (Lightning's seed_everything seeds the global random number generators with the given value.)
pl.seed_everything(42)


## Configuration

Let's define the configuration for our project.


In [None]:
# Create configs directory if it doesn't exist
os.makedirs('configs', exist_ok=True)

# Default configuration: shared audio-feature and training hyper-parameters.
default_config = """
# Common hyper‑parameters
sample_rate: 22050
n_mels: 64
hop_length: 512
batch_size: 32
num_epochs: 50
learning_rate: 3e-4
num_workers: 4
"""

# ResNet configuration: only the keys that differ from the defaults.
resnet_config = """
# ResNet‑34 override
model_name: resnet34
learning_rate: 1e-4
batch_size: 16
"""

# Write configurations to files.
# The YAML comments contain a non-ASCII hyphen, so the encoding is pinned to
# UTF-8; relying on the platform default can raise UnicodeEncodeError on
# systems whose locale encoding cannot represent that character.
for config_path, config_text in (
    ('configs/default.yaml', default_config),
    ('configs/model_resnet.yaml', resnet_config),
):
    with open(config_path, 'w', encoding='utf-8') as f:
        f.write(config_text)

# Load configuration (parsed from the in-memory strings, not re-read from disk).
cfg = yaml.safe_load(default_config)
# The ResNet settings are layered on top of the defaults; overlapping keys
# (learning_rate, batch_size) take the ResNet value.
resnet_cfg = {**cfg, **yaml.safe_load(resnet_config)}

# Display configuration
print("Default configuration:")
for key, value in cfg.items():
    print(f"  {key}: {value}")

print("\nResNet configuration:")
for key, value in resnet_cfg.items():
    print(f"  {key}: {value}")


## Download and Extract Data

Let's download and extract the IRMAS dataset.


In [None]:
# Download location of the IRMAS training-data archive (hosted on Zenodo).
IRMAS_URL = "https://zenodo.org/record/1290750/files/IRMAS-TrainingData.zip?download=1"
# download_irmas compares the MD5 hex digest of the downloaded file against this value.
MD5 = "4fd9f5ed5a18d8e2687e6360b5f60afe"  # expected archive checksum

def md5(fname, chunk=2**20):
    """Return the hexadecimal MD5 digest of the file at `fname`.

    The file is read in pieces of `chunk` bytes so that large archives never
    have to be held in memory at once.
    """
    digest = hashlib.md5()
    with open(fname, 'rb') as stream:
        # iter() with a sentinel keeps calling read() until it returns b"" (EOF).
        for block in iter(lambda: stream.read(chunk), b""):
            digest.update(block)
    return digest.hexdigest()

def download_irmas(out_dir):
    """Download, verify and extract the IRMAS archive into `out_dir`.

    The archive is stored as ``<out_dir>/IRMAS.zip``. If that file already
    exists the download is skipped. The file's MD5 digest is then compared
    with the module-level ``MD5`` constant and, when it matches, the archive
    is extracted into `out_dir`.

    Args:
        out_dir: Directory (str or path-like) that receives the archive and
            its extracted contents. Created if missing.

    Raises:
        SystemExit: With exit status 1 when the checksum does not match. The
            bad archive is deleted first so the next call downloads it again.
    """
    out_dir = pathlib.Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    archive_path = out_dir / "IRMAS.zip"

    if not archive_path.exists():
        print("Downloading IRMAS ...")
        # Download under a temporary name and rename only on success, so an
        # interrupted transfer never leaves a truncated "IRMAS.zip" that later
        # runs would mistake for a finished download.
        partial_path = archive_path.with_name(archive_path.name + ".part")
        try:
            urllib.request.urlretrieve(IRMAS_URL, partial_path)
            partial_path.replace(archive_path)
        finally:
            if partial_path.exists():
                partial_path.unlink()
    else:
        print("Archive already exists, skipping download")

    print("Verifying checksum ...")
    if md5(archive_path) != MD5:
        print("Checksum mismatch!", file=sys.stderr)
        # Drop the corrupt file; otherwise every re-run would skip the
        # download and fail on the same bad archive again.
        archive_path.unlink()
        print("Removed corrupt archive; re-run to download it again.", file=sys.stderr)
        sys.exit(1)

    print("Extracting ...")
    with zipfile.ZipFile(archive_path) as zf:
        zf.extractall(out_dir)
    print("Done. Data at", out_dir)

# Create data directories
# data/raw receives the downloaded archive and its extracted contents;
# data/processed is the output directory passed to preprocess_data later on.
os.makedirs('data/raw', exist_ok=True)
os.makedirs('data/processed', exist_ok=True)

# Download and extract IRMAS dataset
# (skips the download when data/raw/IRMAS.zip already exists, but always re-verifies and re-extracts)
download_irmas('data/raw')


## Preprocess Data

Let's preprocess the data: each WAV file is converted to a log-mel spectrogram and saved as a `.npy` file, and the files are split 90/10 into training and validation sets.


In [None]:
def process_file(wav_path, cfg):
    """Load one audio file and return its log-mel spectrogram as float32.

    The audio is loaded as mono at ``cfg['sample_rate']``. The mel
    spectrogram uses ``cfg['n_mels']`` bands, a hop of ``cfg['hop_length']``
    samples and a lower frequency bound of 30 Hz, and is converted to
    decibels relative to its own maximum.
    """
    audio, rate = librosa.load(wav_path, sr=cfg['sample_rate'], mono=True)
    mel_power = librosa.feature.melspectrogram(
        y=audio,
        sr=rate,
        n_mels=cfg['n_mels'],
        hop_length=cfg['hop_length'],
        fmin=30,
    )
    # ref=np.max makes the loudest bin 0 dB, so all values are <= 0.
    mel_db = librosa.power_to_db(mel_power, ref=np.max)
    return mel_db.astype(np.float32)

def _convert_split(wav_files, in_dir, split_dir, cfg):
    """Save a log-mel `.npy` for each WAV under `split_dir`, mirroring its path relative to `in_dir`."""
    for wav in tqdm(wav_files):
        spec = process_file(wav, cfg)
        out_path = split_dir / wav.relative_to(in_dir).with_suffix(".npy")
        out_path.parent.mkdir(parents=True, exist_ok=True)
        np.save(out_path, spec)


def preprocess_data(in_dir, out_dir, cfg, seed=42):
    """Convert every WAV under `in_dir` to a spectrogram and split 90/10.

    Spectrograms are written as `.npy` files under ``<out_dir>/train`` and
    ``<out_dir>/val``, keeping each file's sub-folder structure relative to
    `in_dir`.

    Args:
        in_dir: Root directory that is searched recursively for ``*.wav``.
        out_dir: Root directory for the ``train`` and ``val`` outputs.
        cfg: Feature settings forwarded to ``process_file``.
        seed: Seed for the shuffle that decides the split. The file list is
            sorted first, so the same seed and the same set of files always
            give the same split. Without this, a re-run could move a clip to
            the other split while its old `.npy` stayed behind, putting the
            same clip in both ``train`` and ``val``.
    """
    in_dir, out_dir = pathlib.Path(in_dir), pathlib.Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Create train and validation directories
    train_dir = out_dir / 'train'
    val_dir = out_dir / 'val'
    train_dir.mkdir(exist_ok=True)
    val_dir.mkdir(exist_ok=True)

    # Get all WAV files; sorted because directory traversal order is not
    # guaranteed to be stable across machines or runs.
    wav_files = sorted(in_dir.rglob("*.wav"))
    print(f"Found {len(wav_files)} WAV files")

    # Split into train and validation sets (90/10 split) using a local,
    # seeded generator so the result does not depend on global RNG state.
    order = np.random.default_rng(seed).permutation(len(wav_files))
    wav_files = [wav_files[i] for i in order]
    split_idx = int(len(wav_files) * 0.9)
    train_files = wav_files[:split_idx]
    val_files = wav_files[split_idx:]

    print("Processing training files...")
    _convert_split(train_files, in_dir, train_dir, cfg)

    print("Processing validation files...")
    _convert_split(val_files, in_dir, val_dir, cfg)

    print(f"Processed {len(train_files)} training files and {len(val_files)} validation files")

# Preprocess data
# NOTE(review): download_irmas extracts the zip directly into data/raw, so this
# call assumes the archive's top-level folder is named "IRMAS" — confirm; if the
# folder has a different name, no WAV files will be found here.
preprocess_data('data/raw/IRMAS', 'data/processed', cfg)


## Define Models

Let's define the models for our project.


In [None]:
import torch.nn as nn
from torchvision.models import resnet34

# Create models directory if it doesn't exist
# NOTE(review): nothing in this notebook writes to 'models'; presumably it mirrors the repository layout — confirm.
os.makedirs('models', exist_ok=True)

class CNNBaseline(nn.Module):
    """Simple CNN baseline for audio classification.

    Three conv + ReLU + 2x2 max-pool stages are followed by two fully
    connected layers and a sigmoid, giving one independent probability per
    class.

    The classifier head expects a 64x8x8 feature map, which the conv stack
    only produces for 64x64 inputs. An adaptive average pool therefore
    resizes the feature map to 8x8 for any other input size. Previously a
    different input size made ``view(-1, 64 * 8 * 8)`` silently change the
    batch dimension. For 64x64 inputs the adaptive pool is an identity, so
    existing behaviour is unchanged, and it has no parameters, so the set of
    learnable weights is unchanged too.

    Args:
        n_classes: Number of output probabilities.
    """
    def __init__(self, n_classes=11):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((8, 8))
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, n_classes)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map a [B,1,H,W] batch to [B,n_classes] probabilities."""
        x = self.pool(self.relu(self.conv1(x)))  # [B,16,H/2,W/2]
        x = self.pool(self.relu(self.conv2(x)))  # [B,32,H/4,W/4]
        x = self.pool(self.relu(self.conv3(x)))  # [B,64,H/8,W/8]
        x = self.adaptive_pool(x)                # [B,64,8,8] for any H, W
        x = x.flatten(1)                         # keeps the batch dimension intact
        x = self.relu(self.fc1(x))
        x = self.sigmoid(self.fc2(x))
        return x

class ResNetSpec(nn.Module):
    """ResNet-34 backbone adapted for single-channel spectrogram input.

    The network is created without pretrained weights. Its first convolution
    is replaced by one that takes a single input channel, and its final
    fully connected layer is replaced by a linear layer with `n_classes`
    outputs followed by a sigmoid.
    """
    def __init__(self, n_classes=11):
        super().__init__()
        net = resnet34(weights=None)
        # Same geometry as the stock stem convolution, but 1 input channel instead of 3.
        net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Replace the classification head, reusing the width of the original one.
        feature_dim = net.fc.in_features
        net.fc = nn.Sequential(
            nn.Linear(feature_dim, n_classes),
            nn.Sigmoid(),
        )
        self.backbone = net

    def forward(self, x):
        """Run the input batch through the adapted backbone."""
        out = self.backbone(x)
        return out

# Display model architectures
# Each print builds a throwaway, default-sized (11-class) instance just to show its layers.
print("CNN Baseline Architecture:")
print(CNNBaseline())
print("\nResNet Architecture:")
print(ResNetSpec())


## Define Dataset and DataLoader

Let's define the dataset and dataloader for our project.


In [None]:
LABELS = ["cello", "clarinet", "flute", "acoustic_guitar", "organ", "piano", "saxophone", "trumpet", "violin", "voice", "other"]

class NpyDataset(Dataset):
    """Dataset of `.npy` spectrograms labelled by their parent folder name.

    Each item is ``(x, y)``: ``x`` is the stored array with a leading channel
    dimension added, and ``y`` is a float vector of ``len(LABELS)`` entries
    with a single 1.0 at the folder's label, or all zeros when the folder
    cannot be mapped to a label.

    Folder names are resolved in this order: the full name, the part before
    the first underscore, then the same two candidates via the alias table.
    Trying the full name first fixes ``acoustic_guitar`` folders, which were
    previously cut to ``acoustic`` and silently given an all-zero label.
    """

    # NOTE(review): assumes the IRMAS training folders use these short codes —
    # confirm against the dataset README. "gel" (electric guitar) has no
    # counterpart in LABELS and is deliberately left unmapped; decide whether
    # it should map to "other".
    _FOLDER_ALIASES = {
        "cel": "cello",
        "cla": "clarinet",
        "flu": "flute",
        "gac": "acoustic_guitar",
        "org": "organ",
        "pia": "piano",
        "sax": "saxophone",
        "tru": "trumpet",
        "vio": "violin",
        "voi": "voice",
    }

    def __init__(self, root):
        import warnings

        # Sorted so that index -> file is stable regardless of traversal order.
        self.files = sorted(pathlib.Path(root).rglob("*.npy"))
        self.label_map = {label: i for i, label in enumerate(LABELS)}

        # Report unmapped folders once, up front, instead of silently
        # training on all-zero targets.
        unknown = sorted(
            {f.parent.name for f in self.files if self._label_index(f.parent.name) is None}
        )
        if unknown:
            warnings.warn(
                f"No label mapping for folder(s) {unknown}; their items get all-zero targets"
            )

    def _label_index(self, folder_name):
        """Return the LABELS index for `folder_name`, or None if it cannot be mapped."""
        candidates = (folder_name, folder_name.split("_")[0])
        for name in candidates:
            if name in self.label_map:
                return self.label_map[name]
        for name in candidates:
            alias = self._FOLDER_ALIASES.get(name)
            if alias in self.label_map:
                return self.label_map[alias]
        return None

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        spec = np.load(path)
        x = torch.tensor(spec).unsqueeze(0)  # [1,H,W]

        # One-hot target; stays all-zero for unmapped folders.
        y = torch.zeros(len(LABELS))
        label_idx = self._label_index(path.parent.name)
        if label_idx is not None:
            y[label_idx] = 1.0

        return x, y

# Create datasets and dataloaders
# These read the .npy files that preprocess_data wrote under data/processed.
train_ds = NpyDataset("data/processed/train")
val_ds = NpyDataset("data/processed/val")

print(f"Training dataset size: {len(train_ds)}")
print(f"Validation dataset size: {len(val_ds)}")

# Create dataloaders
# Batch size and worker count come from the merged ResNet config
# (batch_size 16 overrides the default 32; num_workers 4 is inherited).
# Only the training loader shuffles.
train_loader = DataLoader(train_ds, batch_size=resnet_cfg['batch_size'], shuffle=True, num_workers=resnet_cfg['num_workers'])
val_loader = DataLoader(val_ds, batch_size=resnet_cfg['batch_size'], shuffle=False, num_workers=resnet_cfg['num_workers'])


## Define Lightning Module

Let's define the PyTorch Lightning module for our project.


In [None]:
from torchmetrics import Accuracy, Precision, Recall, F1Score

class MetricCollection(torch.nn.Module):
    """Bundle of multilabel accuracy, precision, recall and F1 metrics.

    This is an ``nn.Module`` so that the metrics become registered
    submodules of whatever module holds the collection. As a plain Python
    object they were invisible to ``.to(device)`` and stayed on the CPU when
    the owning model was moved to an accelerator.

    Calling the instance as ``collection(preds, targets)`` works exactly as
    before, now via ``nn.Module.__call__`` -> ``forward``.

    NOTE(review): this class shadows the name of torchmetrics' own
    ``MetricCollection``; nothing here imports that one, but keep it in mind
    if it is ever needed.

    Args:
        n_classes: Number of labels, passed to each metric as ``num_labels``.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.accuracy = Accuracy(task="multilabel", num_labels=n_classes)
        self.precision = Precision(task="multilabel", num_labels=n_classes)
        self.recall = Recall(task="multilabel", num_labels=n_classes)
        self.f1 = F1Score(task="multilabel", num_labels=n_classes)

    def forward(self, preds, targets):
        """Evaluate every metric on one batch and return them keyed by name."""
        return {
            "accuracy": self.accuracy(preds, targets),
            "precision": self.precision(preds, targets),
            "recall": self.recall(preds, targets),
            "f1": self.f1(preds, targets)
        }

class LitModel(pl.LightningModule):
    """Lightning wrapper that trains a spectrogram classifier.

    ``cfg["model_name"] == "resnet34"`` selects ``ResNetSpec``; any other
    value, or a missing key, selects ``CNNBaseline``. The learning rate is
    read from ``cfg["learning_rate"]`` and the whole config is stored as
    hyperparameters.
    """

    def __init__(self, cfg):
        super().__init__()
        n_classes = len(LABELS)
        wants_resnet = cfg.get("model_name", "cnn") == "resnet34"
        self.model = ResNetSpec(n_classes) if wants_resnet else CNNBaseline(n_classes)
        self.metrics = MetricCollection(n_classes)
        self.lr = cfg["learning_rate"]
        self.save_hyperparameters(cfg)

    def forward(self, x):
        """Delegate to the wrapped network."""
        return self.model(x)

    def common_step(self, batch, stage):
        """Compute and log the loss and metrics for one batch, then return the loss.

        Everything is logged under the ``"<stage>/"`` prefix, e.g.
        ``train/loss`` or ``val/f1``.
        """
        inputs, targets = batch
        probs = self(inputs)
        # The networks already end in a sigmoid, so plain (non-logit) BCE is used.
        loss = torch.nn.functional.binary_cross_entropy(probs, targets)
        to_log = {f"{stage}/loss": loss}
        for name, value in self.metrics(probs, targets).items():
            to_log[f"{stage}/{name}"] = value
        self.log_dict(to_log, prog_bar=True)
        return loss

    def training_step(self, batch, _):
        return self.common_step(batch, "train")

    def validation_step(self, batch, _):
        return self.common_step(batch, "val")

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

# Initialize model
# resnet_cfg carries model_name "resnet34", so LitModel builds the ResNetSpec network.
model = LitModel(resnet_cfg)
print("Model initialized with ResNet-34 backbone")


## Train Model

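Let's train the model using PyTorch Lightning. The best checkpoint is selected by validation F1, and training stops early if validation F1 fails to improve for 5 consecutive validation checks.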


In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

def default_callbacks():
    """Return the training callbacks: best-checkpoint saving and early stopping.

    Both callbacks monitor ``val/f1`` and treat higher values as better. Only
    the single best checkpoint is kept, and training stops after 5
    validation checks without improvement.
    """
    return [
        ModelCheckpoint(
            monitor="val/f1",
            mode="max",
            save_top_k=1,
            # The metric name contains "/". With automatic metric-name
            # insertion that slash ends up in the file name and creates an
            # extra folder, so insertion is disabled and the labels
            # ("epoch=", "val_f1=") are written out explicitly.
            filename="best-epoch={epoch:02d}-val_f1={val/f1:.2f}",
            auto_insert_metric_name=False,
            verbose=True,
        ),
        EarlyStopping(
            monitor="val/f1",
            mode="max",
            patience=5,
            verbose=True,
        ),
    ]

# Create trainer
# max_epochs comes from the config (num_epochs: 50); early stopping may end training sooner.
# accelerator="auto" leaves the hardware choice to Lightning.
trainer = pl.Trainer(
    max_epochs=resnet_cfg['num_epochs'],
    callbacks=default_callbacks(),
    accelerator="auto"
)

# Train model
# val_loader is used for the validation passes that produce the monitored val/f1 metric.
trainer.fit(model, train_loader, val_loader)


## Inference

Let's perform inference on a sample audio file.


In [None]:
def extract_features(path, cfg):
    """Return the model input tensor for one audio file.

    Feature extraction is delegated to ``process_file``, the function that
    produced the training spectrograms, so inference sees exactly the same
    features. The previous inline copy omitted ``fmin=30`` and therefore fed
    the model a differently scaled mel spectrogram than it was trained on.

    Returns:
        The float32 log-mel spectrogram with batch and channel dimensions
        added in front ([1,1,H,W]).
    """
    logmel = process_file(path, cfg)
    return torch.tensor(logmel).unsqueeze(0).unsqueeze(0)

def predict(model, wav_path, cfg):
    """Return ``{label: probability}`` for one audio file.

    The model is put in eval mode and run without gradient tracking on the
    features of `wav_path`.

    Args:
        model: Network mapping a [1,1,H,W] tensor to per-label probabilities.
        wav_path: Audio file to classify.
        cfg: Feature settings forwarded to ``extract_features``.

    Returns:
        Dict with one float per entry of ``LABELS``, in that order.
    """
    model.eval()
    x = extract_features(wav_path, cfg)

    # Run on whatever device the model lives on. Feeding a CPU tensor to a
    # model that sits on an accelerator raises a device-mismatch error.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        x = x.to(first_param.device)

    with torch.no_grad():
        # squeeze(0) drops only the batch dimension, so a single-label model
        # still yields an indexable vector. .cpu() is needed before .numpy()
        # when the output is on an accelerator.
        preds = model(x).squeeze(0).cpu().numpy()
    return {label: float(preds[i]) for i, label in enumerate(LABELS)}

# Get a sample WAV file for inference
# Takes the first WAV the recursive search returns; raises IndexError if none are found.
# NOTE(review): the sample comes from the same raw data used for training, so it may be a training clip.
sample_wav = list(pathlib.Path('data/raw/IRMAS').rglob("*.wav"))[0]
print(f"Sample WAV file: {sample_wav}")

# Perform inference
results = predict(model, sample_wav, resnet_cfg)

# Display results
# Sorted by confidence, highest first.
print("\nPrediction results:")
for instrument, confidence in sorted(results.items(), key=lambda x: x[1], reverse=True):
    print(f"{instrument}: {confidence:.4f}")

# Visualize results
# Bars appear in LABELS order (the dict's insertion order), not sorted by confidence.
plt.figure(figsize=(10, 6))
plt.bar(results.keys(), results.values())
plt.xticks(rotation=45, ha='right')
plt.ylabel('Confidence')
plt.title('Instrument Detection Confidence')
plt.tight_layout()
plt.show()


## Visualize Audio and Spectrogram

Let's visualize the audio waveform and spectrogram of the sample file.


In [None]:
def visualize_audio(wav_path, cfg):
    """Plot the waveform and log-mel spectrogram of one audio file.

    The spectrogram uses the same settings as the training features
    (including ``fmin=30``), so the plot shows what the model actually sees.

    Args:
        wav_path: Audio file to display.
        cfg: Provides ``sample_rate``, ``n_mels`` and ``hop_length``.
    """
    # librosa.display is a submodule that a bare `import librosa` does not
    # load in every librosa release. Importing it explicitly is harmless
    # where it is already available.
    import librosa.display

    fmin = 30  # same lower frequency bound as process_file

    # Load audio
    y, sr = librosa.load(wav_path, sr=cfg['sample_rate'], mono=True)

    # Compute spectrogram
    mels = librosa.feature.melspectrogram(
        y=y, sr=sr, n_mels=cfg['n_mels'], hop_length=cfg['hop_length'], fmin=fmin
    )
    logmel = librosa.power_to_db(mels, ref=np.max)

    # Plot waveform and spectrogram
    plt.figure(figsize=(12, 8))

    plt.subplot(2, 1, 1)
    librosa.display.waveshow(y, sr=sr)
    plt.title('Waveform')

    plt.subplot(2, 1, 2)
    # fmin is passed so the mel frequency axis matches the filter bank above.
    librosa.display.specshow(
        logmel, sr=sr, x_axis='time', y_axis='mel',
        hop_length=cfg['hop_length'], fmin=fmin,
    )
    plt.colorbar(format='%+2.0f dB')
    plt.title('Mel Spectrogram')

    plt.tight_layout()
    plt.show()

# Visualize sample audio
# Uses the same sample file that was passed to predict in the inference cell.
visualize_audio(sample_wav, resnet_cfg)


## Conclusion

In this notebook, we've demonstrated the complete pipeline for multi-label musical instrument recognition:

1. Setting up the environment
2. Downloading and preprocessing the IRMAS dataset
3. Defining and training a ResNet-34 model
4. Performing inference on audio files
5. Visualizing the results

This notebook can be run in Google Colab without any special modifications to the codebase.