# Instrument Classifier

This notebook demonstrates a PyTorch project for multi-label musical instrument recognition from audio clips. It allows you to run the entire pipeline in Google Colab without any special modifications to the codebase.

## Setup

First, let's clone the repository and install the required dependencies.


In [None]:
from __future__ import annotations

# Check if running in Colab
import sys

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Clone the repository
    !git clone https://github.com/ofekdd/DL_Project.git
    %cd DL_Project

    # Install dependencies
    !pip install -r requirements.txt


fatal: destination path 'DL_Project' already exists and is not an empty directory.
/content/DL_Project
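
The `fatal: destination path 'DL_Project' already exists` message above appears when the setup cell is re-run in a runtime where the repository was already cloned. A minimal sketch of a guard that skips the clone in that case follows; `clone_if_missing` and its `runner` parameter are hypothetical helpers for illustration and are not part of the repository.

```python
import os
import subprocess

def clone_if_missing(url, dest, runner=subprocess.run):
    """Clone `url` into `dest` unless `dest` already exists.

    Returns True when a clone was attempted and False when it was skipped.
    `runner` is injectable so the logic can be exercised without network access.
    """
    if os.path.isdir(dest):
        return False
    runner(["git", "clone", url, dest], check=True)
    return True
```

In the setup cell this could be called as `clone_if_missing("https://github.com/ofekdd/DL_Project.git", "DL_Project")` before changing into the folder.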


## Import Libraries

Let's import the necessary libraries for our project.


In [None]:
import os
import torch
import librosa
import numpy as np
import matplotlib.pyplot as plt
import yaml
import pathlib
import zipfile
import urllib.request
import hashlib
import sys
import time
from torch.utils.data import DataLoader, Dataset
import warnings, tqdm
warnings.filterwarnings("ignore", category=tqdm.TqdmWarning)

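# ---- force classic text progress bar ----
# `sys` and the `tqdm` module are already imported above; alias the notebook
# variants to the plain module so every import resolves to the console bar.
sys.modules['tqdm.notebook'] = tqdm
sys.modules['tqdm.autonotebook'] = tqdm
from tqdm import tqdm  # now `tqdm(...)` is always the console bar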



## Configuration

Let's define the configuration for our project.


In [None]:
# Create configs directory if it doesn't exist
os.makedirs('configs', exist_ok=True)

# Default configuration
default_config = """
# Common hyper‑parameters
sample_rate: 22050
n_mels: 64
hop_length: 512
batch_size: 32
num_epochs: 50
learning_rate: 3e-4
num_workers: 4
"""

# ResNet configuration
resnet_config = """
# ResNet‑34 override
model_name: resnet34
learning_rate: 1e-4
batch_size: 16
"""


# Write configurations to files
with open('configs/default.yaml', 'w', encoding='utf-8') as f:
    f.write(default_config)

with open('configs/model_resnet.yaml', 'w', encoding='utf-8') as f:
    f.write(resnet_config)
# Load configuration
cfg = yaml.safe_load(default_config)
resnet_cfg = {**cfg, **yaml.safe_load(resnet_config)}

# Display configuration
print("Default configuration:")
for key, value in cfg.items():
    print(f"  {key}: {value}")

print("\nResNet configuration:")
for key, value in resnet_cfg.items():
    print(f"  {key}: {value}")


Default configuration:
  sample_rate: 22050
  n_mels: 64
  hop_length: 512
  batch_size: 32
  num_epochs: 50
  learning_rate: 3e-4
  num_workers: 4

ResNet configuration:
  sample_rate: 22050
  n_mels: 64
  hop_length: 512
  batch_size: 16
  num_epochs: 50
  learning_rate: 1e-4
  num_workers: 4
  model_name: resnet34
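
The ResNet configuration printed above is the default dictionary with the override's keys applied on top. A minimal sketch of that merge with plain dictionaries follows; the values mirror the YAML above and `merge_config` is a hypothetical helper. Note that the printed `3e-4` is a string, not a float: PyYAML only resolves scientific notation as a float when the mantissa contains a decimal point, so such fields need `float()` before use.

```python
def merge_config(base, override):
    """Return a new dict where keys in `override` replace those in `base`."""
    return {**base, **override}

# Values as PyYAML would load them from the YAML above: learning rates stay strings.
default_cfg = {"sample_rate": 22050, "batch_size": 32, "learning_rate": "3e-4"}
resnet_override = {"model_name": "resnet34", "learning_rate": "1e-4", "batch_size": 16}

merged = merge_config(default_cfg, resnet_override)
lr = float(merged["learning_rate"])  # coerce before handing it to an optimizer

print(merged)
print(lr)
```

Because the merge builds a new dictionary, `default_cfg` keeps its original batch size of 32.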


## Download and Extract Data

Let's download and extract the IRMAS dataset.


In [None]:
#!/usr/bin/env python3
"""Download the **raw** IRMAS zip (≈ 3 GB) to a persistent location
(Drive if running in Colab) and *optionally* extract it to a **scratch** folder
inside the current runtime.  The extracted data is therefore rebuilt every
session, keeping Google Drive usage tiny (≈ 3 GB).

Typical Colab usage
-------------------
```python
!python data/download_irmas.py                   # zip → Drive, extract → /content/IRMAS
```

Local (non‑Colab) usage – behaves like the original script:
```bash
python data/download_irmas.py --out_dir data/raw  # zip + extract inside repo
```
"""

import argparse, hashlib, os, pathlib, sys, urllib.request, zipfile

IRMAS_URL = (
    "https://zenodo.org/records/1290750/files/IRMAS-TrainingData.zip?download=1"
)
MD5 = "4fd9f5ed5a18d8e2687e6360b5f60afe"  # expected archive checksum


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def inside_colab() -> bool:
    return "google.colab" in sys.modules


def ensure_drive_mounted():
    """Mount Google Drive once per Colab session."""
    if not inside_colab():
        return
    from google.colab import drive  # type: ignore  # noqa

    drive_root = pathlib.Path("/content/drive")
    if not drive_root.is_dir():
        print("Mounting Google Drive …")
        drive.mount("/content/drive")


def md5(path: pathlib.Path, chunk: int = 2 ** 20) -> str:
    h = hashlib.md5()
    with path.open("rb") as fh:
        for blk in iter(lambda: fh.read(chunk), b""):
            h.update(blk)
    return h.hexdigest()


def download_zip(zip_path: pathlib.Path):
    print("Downloading IRMAS zip …")
    try:
        urllib.request.urlretrieve(IRMAS_URL, zip_path)
    except Exception as e:
        print(f"Download failed: {e}", file=sys.stderr)
        sys.exit(1)


# ---------------------------------------------------------------------------
# Main routine
# ---------------------------------------------------------------------------

def main(zip_dir: pathlib.Path):
    """Ensure the IRMAS zip is present in *zip_dir* (persistent).
    In Colab, also extract to /content/IRMAS every session.
    """
    zip_dir.mkdir(parents=True, exist_ok=True)
    zip_path = zip_dir / "IRMAS.zip"

    # 1) download once ------------------------------------------------------
    if zip_path.exists():
        print("Zip already in cache – skipping download")
    else:
        download_zip(zip_path)

    # 2) checksum -----------------------------------------------------------
    print("Verifying checksum …")
    if md5(zip_path) != MD5:
        print("❌  MD5 mismatch – delete the zip and rerun.", file=sys.stderr)
        sys.exit(1)

    # 3) decide where to extract -------------------------------------------
    if inside_colab():
        extract_root = pathlib.Path("/content/IRMAS")
    else:
        extract_root = zip_path.parent / "IRMAS"

    if extract_root.is_dir():
        print("Extracted folder already exists – skipping unzip")
        print("Data ready at", extract_root)
        return

    print(f"Extracting to {extract_root} … (this happens each new runtime)")
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(extract_root)
    print("✅  Done. Data at", extract_root)


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    default_out = (
        "/content/drive/MyDrive/datasets/IRMAS"
        if inside_colab()
        else "data/raw/IRMAS"
    )

    parser = argparse.ArgumentParser(
        description="Download IRMAS zip to a persistent folder. In Colab the zip "
        "lives on Drive but extraction goes to /content/IRMAS so Drive quota "
        "remains low."
    )
    parser.add_argument(
        "--out_dir",
        default=default_out,
        help="Folder where the IRMAS.zip will be cached (Drive for Colab).",
    )

    args, _ = parser.parse_known_args()

    ensure_drive_mounted()
    main(pathlib.Path(args.out_dir))


Mounting Google Drive …
Mounted at /content/drive
Downloading IRMAS zip …
Verifying checksum …
Extracting to /content/IRMAS … (this happens each new runtime)
✅  Done. Data at /content/IRMAS
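
The `md5` helper above hashes the archive in 1 MiB blocks, so the roughly 3 GB zip never has to fit in memory. A small self-contained sketch of the same pattern, checked against a one-shot hash of a throwaway file, is shown here; `file_md5` is an illustrative name.

```python
import hashlib
import os
import tempfile

def file_md5(path, chunk=2 ** 20):
    """Hash a file incrementally, reading `chunk` bytes at a time."""
    h = hashlib.md5()
    with open(path, "rb") as fh:
        for blk in iter(lambda: fh.read(chunk), b""):
            h.update(blk)
    return h.hexdigest()

# Write a small temporary file and compare against hashing it in one go.
payload = b"IRMAS" * 10_000
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(payload)
    tmp_path = tmp.name

chunked = file_md5(tmp_path, chunk=4096)
one_shot = hashlib.md5(payload).hexdigest()
os.remove(tmp_path)

print(chunked == one_shot)
```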


## Preprocess Data

Let's preprocess the data by converting each WAV file into nine log-magnitude STFT spectrograms (three window sizes × three frequency bands).


In [None]:
from tqdm import tqdm

def generate_multi_stft(
    y: np.ndarray,
    sr: int,
    n_ffts = (256, 512, 1024),
    band_ranges = ((0, 1000), (1000, 4000), (4000, 11025))
):
    """
    Returns { (band_label, n_fft): log-magnitude spectrogram }.
    3 window sizes × 3 frequency bands = 9 specs per clip.
    """
    result = {}
    for n_fft in n_ffts:
        hop_length = n_fft // 4
        mag = np.abs(librosa.stft(y, n_fft=n_fft, hop_length=hop_length))

        freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
        for f_low, f_high in band_ranges:
            band_label = f"{f_low}-{f_high}Hz"
            band_mask  = (freqs >= f_low) & (freqs < f_high)
            band_spec  = mag[band_mask]
            log_spec   = librosa.amplitude_to_db(band_spec,
                                                 ref=np.max).astype(np.float32)
            result[(band_label, n_fft)] = log_spec
    return result


# --------------------------------------------------------------------
# 2. Single-file processor
# --------------------------------------------------------------------
def process_file(wav_path: pathlib.Path, cfg: dict):
    y, sr = librosa.load(wav_path, sr=cfg["sample_rate"], mono=True)
    return generate_multi_stft(y, sr)


# --------------------------------------------------------------------
# 3. Bulk preprocessing with 90 / 10 split
# --------------------------------------------------------------------
def preprocess_data(in_dir: str, out_dir: str, cfg: dict):
    in_dir  = pathlib.Path(in_dir)
    out_dir = pathlib.Path(out_dir)
    train_dir, val_dir = out_dir / "train", out_dir / "val"
    for d in (train_dir, val_dir):
        d.mkdir(parents=True, exist_ok=True)

    wav_files = list(in_dir.rglob("*.wav"))
    if not wav_files:
        print(f"‼️  No WAV files under {in_dir}")
        return
    print(f"Found {len(wav_files)} WAV files")

    np.random.shuffle(wav_files)
    split = int(len(wav_files) * 0.9)
    splits = {"train": wav_files[:split], "val": wav_files[split:]}

    for split_name, files in splits.items():
        print(f"Processing {split_name} files …")
        base_out = train_dir if split_name == "train" else val_dir
        for wav in tqdm(files):
            specs = process_file(wav, cfg)

            rel_dir = wav.relative_to(in_dir).with_suffix("")
            (base_out / rel_dir).mkdir(parents=True, exist_ok=True)

            for (band_label, n_fft), spec in specs.items():
                fname = f"{band_label}_fft{n_fft}.npy"
                np.save(base_out / rel_dir / fname, spec)

    print(f"✅  Done. {len(splits['train'])} train / {len(splits['val'])} val clips")


# --------------------------------------------------------------------
# 4. Locate audio automatically
# --------------------------------------------------------------------
def find_irmas_root() -> pathlib.Path | None:
    """Return the first existing path that contains IRMAS WAVs."""
    candidates = [
        pathlib.Path("/content/IRMAS/IRMAS-TrainingData"),   # Colab scratch
        pathlib.Path("data/raw/IRMAS/IRMAS-TrainingData"),   # legacy
        pathlib.Path("data/raw/IRMAS-TrainingData"),         # legacy alt
        pathlib.Path("data/raw"),                            # fallback
    ]
    for p in candidates:
        if p.exists() and any(p.rglob("*.wav")):
            return p
    return None


# --------------------------------------------------------------------
# 5. Run
# --------------------------------------------------------------------
with open("configs/default.yaml", encoding="utf-8") as f:  # close the file handle after loading
    cfg = yaml.safe_load(f)
irmas_root = find_irmas_root()

if irmas_root is None:
    print("❌  No IRMAS audio found. Did you run download_irmas.py yet?")
else:
    print(f"Using IRMAS root: {irmas_root}")
    preprocess_data(irmas_root, "data/processed", cfg)

Using IRMAS root: /content/IRMAS/IRMAS-TrainingData
Found 6705 WAV files
Processing train files …


100%|██████████| 6034/6034 [02:42<00:00, 37.05it/s]


Processing val files …


100%|██████████| 671/671 [00:16<00:00, 41.93it/s]

✅  Done. 6034 train / 671 val clips
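
`preprocess_data` shuffles with `np.random.shuffle` and no fixed seed, so the 90/10 split differs between runs. A hedged sketch of a seeded split using only the standard library follows; `split_files` and the seed value are assumptions for illustration, not the notebook's behaviour.

```python
import random

def split_files(files, train_frac=0.9, seed=0):
    """Shuffle a copy of `files` with a fixed seed and split it into train/val."""
    shuffled = list(files)
    random.Random(seed).shuffle(shuffled)
    cut = int(len(shuffled) * train_frac)
    return shuffled[:cut], shuffled[cut:]

# 6705 placeholder names, matching the clip count reported above.
names = [f"clip_{i:04d}.wav" for i in range(6705)]
train, val = split_files(names)
print(len(train), len(val))  # 6034 671
```

Calling `split_files` again with the same seed returns the same partition, which keeps validation scores comparable across sessions.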





## Define Models

Let's define the models for our project.


In [None]:
import torch.nn as nn
from torchvision.models import resnet34

# Create models directory if it doesn't exist
os.makedirs('models', exist_ok=True)

class CNNBaseline(nn.Module):
    """Simple CNN baseline for audio classification."""
    def __init__(self, n_classes=11):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, n_classes)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))  # [B,16,32,32]
        x = self.pool(self.relu(self.conv2(x)))  # [B,32,16,16]
        x = self.pool(self.relu(self.conv3(x)))  # [B,64,8,8]
        x = x.view(-1, 64 * 8 * 8)
        x = self.relu(self.fc1(x))
        x = self.sigmoid(self.fc2(x))
        return x

class ResNetSpec(nn.Module):
    """ResNet‑34 backbone adapted for single‑channel spectrogram input."""
    def __init__(self, n_classes=11):
        super().__init__()
        self.backbone = resnet34(weights=None)
        self.backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Replace FC
        self.backbone.fc = nn.Sequential(
            nn.Linear(self.backbone.fc.in_features, n_classes),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.backbone(x)

# Display model architectures
print("CNN Baseline Architecture:")
print(CNNBaseline())
print("\nResNet Architecture:")
print(ResNetSpec())


CNN Baseline Architecture:
CNNBaseline(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=4096, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=11, bias=True)
  (relu): ReLU()
  (sigmoid): Sigmoid()
)

ResNet Architecture:
ResNetSpec(
  (backbone): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), paddi
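
`CNNBaseline` hard-codes `64 * 8 * 8` input features for `fc1`. The arithmetic behind that number can be sketched in plain Python: each 3×3 convolution with padding 1 keeps height and width, and each 2×2 max-pool halves them (flooring odd sizes), so the figure only holds for 64×64 inputs. `flat_features` is a hypothetical helper.

```python
def flat_features(height, width, channels=64, n_pools=3):
    """Features entering fc1 after `n_pools` rounds of conv(3x3, pad 1) + maxpool(2x2)."""
    for _ in range(n_pools):
        height //= 2  # MaxPool2d(2, 2) floors odd sizes
        width //= 2
    return channels * height * width

print(flat_features(64, 64))   # 4096 == 64 * 8 * 8, matching fc1
print(flat_features(69, 345))  # another input size gives a different count
```

Inputs of other sizes would therefore not line up with `fc1` as written.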

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## Define Dataset and DataLoader

Let's define the dataset and dataloader for our project.


In [None]:
LABELS = ["cello", "clarinet", "flute", "acoustic_guitar", "organ", "piano", "saxophone", "trumpet", "violin", "voice", "other"]

def pad_collate(batch):
    xs, ys = zip(*batch)
    # find the largest frequency (H) and time (W) extents in the batch
    H = max(x.shape[1] for x in xs)
    W = max(x.shape[2] for x in xs)
    padded = []
    for x in xs:
        pad_h = H - x.shape[1]
        pad_w = W - x.shape[2]
        x_padded = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h))  # (left, right, top, bottom)
        padded.append(x_padded)
    return torch.stack(padded), torch.stack(ys)

class NpyDataset(Dataset):
    def __init__(self, root):
        self.files = list(pathlib.Path(root).rglob("*.npy"))
        self.label_map = {label: i for i, label in enumerate(LABELS)}

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        spec = np.load(self.files[idx])
        x = torch.tensor(spec).unsqueeze(0)  # [1,H,W]

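        # Parse the label from the name of the folder that holds this .npy file
        # (the text before the first underscore)
        label_str = self.files[idx].parent.name.split("_")[0]
        y = torch.zeros(len(LABELS))

        # Map label string to index; a string that is not in LABELS leaves y all zeros
        if label_str in self.label_map:
            y[self.label_map[label_str]] = 1.0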

        return x, y

# Create datasets and dataloaders
train_ds = NpyDataset("data/processed/train")
val_ds = NpyDataset("data/processed/val")

print(f"Training dataset size: {len(train_ds)}")
print(f"Validation dataset size: {len(val_ds)}")

# Create dataloaders
train_loader = DataLoader(train_ds, batch_size=resnet_cfg['batch_size'], shuffle=True, num_workers=resnet_cfg['num_workers'], collate_fn=pad_collate)
val_loader = DataLoader(val_ds, batch_size=resnet_cfg['batch_size'], shuffle=False, num_workers=resnet_cfg['num_workers'], collate_fn=pad_collate)


Training dataset size: 54306
Validation dataset size: 6039
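
The dataset sizes above are nine times the clip counts (6034 × 9 = 54306 and 671 × 9 = 6039) because every clip yields one `.npy` file per band and window size. Those files differ in height, which is why `pad_collate` pads each batch. A rough sketch of the row counts follows, assuming bin k sits at k · sr / n_fft Hz (the layout `librosa.fft_frequencies` follows, up to floating-point rounding at the Nyquist bin); `band_bin_counts` is an illustrative helper.

```python
def band_bin_counts(sr, n_fft, band_ranges=((0, 1000), (1000, 4000), (4000, 11025))):
    """Count FFT bins whose centre frequency falls in each [low, high) band."""
    freqs = [k * sr / n_fft for k in range(n_fft // 2 + 1)]
    return [sum(1 for f in freqs if low <= f < high) for low, high in band_ranges]

for n_fft in (256, 512, 1024):
    print(n_fft, band_bin_counts(22050, n_fft))
```

Under this assumption the 512-point window gives 24, 69, and 163 rows for the three bands, so a shuffled batch mixes spectrograms of quite different heights.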




## Define Metrics and Model Wrapper

Let's define the metrics and a wrapper that pairs the chosen backbone with its loss, metrics, and optimizer.


In [12]:
from torchmetrics import Accuracy, Precision, Recall, F1Score
import torch.nn as nn
import torch

# -----------------------------------------------------------
# 1. Metrics collection that knows its target device
# -----------------------------------------------------------
class MetricCollection:
    def __init__(self, n_classes: int, device: torch.device | str = "cpu"):
        self.accuracy  = Accuracy(task="multilabel",  num_labels=n_classes).to(device)
        self.precision = Precision(task="multilabel", num_labels=n_classes).to(device)
        self.recall    = Recall(task="multilabel",    num_labels=n_classes).to(device)
        self.f1        = F1Score(task="multilabel",   num_labels=n_classes).to(device)

    def __call__(self, preds: torch.Tensor, targets: torch.Tensor) -> dict:
        return {
            "accuracy":  self.accuracy(preds, targets),
            "precision": self.precision(preds, targets),
            "recall":    self.recall(preds, targets),
            "f1":        self.f1(preds, targets),
        }

# -----------------------------------------------------------
# 2. Model wrapper
# -----------------------------------------------------------
class InstrumentModel(nn.Module):
    def __init__(self, cfg: dict):
        super().__init__()
        n_classes = len(LABELS)
        backbone  = ResNetSpec if cfg.get("model_name", "cnn") == "resnet34" else CNNBaseline
        self.model   = backbone(n_classes)
        self.device_ = torch.device(cfg.get("device", "cpu"))
        self.metrics = MetricCollection(n_classes, device=self.device_)
        self.lr      = float(cfg.get("learning_rate", 1e-4))

    # ---- forward / training utils ----------------------------------------
    def forward(self, x):                       # x on self.device_
        return self.model(x)

    def compute_loss_and_metrics(self, batch, stage="train"):
        x, y  = (t.to(self.device_) for t in batch)   # ensure batch on same device
        preds = self(x)
        loss  = torch.nn.functional.binary_cross_entropy(preds, y)
        return loss, self.metrics(preds, y), preds

    def get_optimizer(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

# -----------------------------------------------------------
# 3. Instantiate on the desired device BEFORE training loop
# -----------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet_cfg["device"] = str(device)               # make sure cfg knows the device

model = InstrumentModel(resnet_cfg).to(device)   # moves the parameters; the metrics were already created on `device` in __init__
print(f"Model initialised on {device}")

NameError: name 'MultilabelF1Score' is not defined
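
For intuition about what a multilabel F1 score measures, here is a plain-Python sketch that thresholds sigmoid outputs, computes F1 per label, and averages across labels (macro). The 0.5 threshold and macro averaging are assumptions here and should be checked against the torchmetrics version in use; `macro_f1` is a hypothetical helper and the numbers are made up.

```python
def macro_f1(probs, targets, threshold=0.5):
    """Macro-averaged F1 over labels for multi-label predictions.

    probs / targets: lists of per-sample lists, one entry per label.
    """
    n_labels = len(targets[0])
    scores = []
    for j in range(n_labels):
        tp = fp = fn = 0
        for p_row, t_row in zip(probs, targets):
            pred = p_row[j] > threshold
            true = t_row[j] == 1
            tp += int(pred and true)
            fp += int(pred and not true)
            fn += int((not pred) and true)
        denom = 2 * tp + fp + fn
        # A label with no positives and no predictions scores 0 in this sketch.
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / n_labels

probs = [[0.9, 0.2, 0.6], [0.8, 0.7, 0.1]]
targets = [[1, 0, 0], [1, 1, 0]]
print(macro_f1(probs, targets))
```

Here the first two labels are predicted perfectly and the third has one false positive, so the macro average is two thirds.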

## Train Model

Let's train the model using PyTorch.


In [None]:
import os
from tqdm.notebook import tqdm

# Create directories for saving models
os.makedirs('checkpoints', exist_ok=True)

class EarlyStopping:
    """Custom implementation of early stopping"""
    def __init__(self, patience=5, mode='max', min_delta=0):
        self.patience = patience
        self.mode = mode
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, score):
        if self.best_score is None:
            self.best_score = score
            return False

        if self.mode == 'max':
            if score > self.best_score + self.min_delta:
                self.best_score = score
                self.counter = 0
            else:
                self.counter += 1
        else:  # mode == 'min'
            if score < self.best_score - self.min_delta:
                self.best_score = score
                self.counter = 0
            else:
                self.counter += 1

        if self.counter >= self.patience:
            self.early_stop = True
            return True
        return False

def train_model(model, train_loader, val_loader, num_epochs, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Train the model using standard PyTorch training loop"""
    model = model.to(device)
    optimizer = model.get_optimizer()

    # Initialize early stopping
    early_stopping = EarlyStopping(patience=5, mode='max')

    # Initialize best model tracking
    best_f1 = 0.0
    best_model_path = None

    # Training loop
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_metrics = {}

        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
        for batch in progress_bar:
            # Move batch to device
            batch = [x.to(device) for x in batch]

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass, compute loss and metrics
            loss, metrics, _ = model.compute_loss_and_metrics(batch, stage='train')

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Update metrics
            train_loss += loss.item()
            for k, v in metrics.items():
                if k not in train_metrics:
                    train_metrics[k] = 0.0
                train_metrics[k] += v.item()

            # Update progress bar
            progress_bar.set_postfix({
                'loss': loss.item(),
                'f1': metrics['f1'].item()
            })

        # Compute average metrics for the epoch
        train_loss /= len(train_loader)
        for k in train_metrics:
            train_metrics[k] /= len(train_loader)

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_metrics = {}

        with torch.no_grad():
            progress_bar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]')
            for batch in progress_bar:
                # Move batch to device
                batch = [x.to(device) for x in batch]

                # Forward pass, compute loss and metrics
                loss, metrics, _ = model.compute_loss_and_metrics(batch, stage='val')

                # Update metrics
                val_loss += loss.item()
                for k, v in metrics.items():
                    if k not in val_metrics:
                        val_metrics[k] = 0.0
                    val_metrics[k] += v.item()

                # Update progress bar
                progress_bar.set_postfix({
                    'loss': loss.item(),
                    'f1': metrics['f1'].item()
                })

        # Compute average metrics for the epoch
        val_loss /= len(val_loader)
        for k in val_metrics:
            val_metrics[k] /= len(val_loader)

        # Print epoch summary
        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'  Train Loss: {train_loss:.4f}, F1: {train_metrics["f1"]:.4f}')
        print(f'  Val Loss: {val_loss:.4f}, F1: {val_metrics["f1"]:.4f}')

        # Save best model
        if val_metrics['f1'] > best_f1:
            best_f1 = val_metrics['f1']
            best_model_path = f'checkpoints/best-{epoch+1:02d}-{val_metrics["f1"]:.2f}.pt'
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_f1': val_metrics['f1'],
                'val_loss': val_loss
            }, best_model_path)
            print(f'  Saved best model to {best_model_path}')

        # Check early stopping
        if early_stopping(val_metrics['f1']):
            print(f'Early stopping triggered after {epoch+1} epochs')
            break

    # Load best model
    if best_model_path:
        checkpoint = torch.load(best_model_path, map_location=device)  # load tensors onto the training device
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f'Loaded best model from {best_model_path} with F1: {checkpoint["val_f1"]:.4f}')

    return model

# Train model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
model = train_model(model, train_loader, val_loader, num_epochs=resnet_cfg['num_epochs'], device=device)


Using device: cuda


Epoch 1/50 [Train]:   0%|          | 0/3395 [00:02<?, ?it/s]


RuntimeError: Encountered different devices in metric calculation (see stacktrace for details). This could be due to the metric class not being on the same device as input. Instead of `metric=MultilabelAccuracy(...)` try to do `metric=MultilabelAccuracy(...).to(device)` where device corresponds to the device of the input.
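
The `EarlyStopping` class above stops training once the monitored score has failed to improve for `patience` consecutive epochs. A compact standalone sketch of the same counter logic (for `mode='max'`) is shown here, run on a sequence of validation F1 values that is invented purely for illustration; `stop_epoch` is a hypothetical helper.

```python
def stop_epoch(scores, patience=5, min_delta=0.0):
    """Return the 1-based epoch at which training would stop, or None."""
    best = None
    counter = 0
    for epoch, score in enumerate(scores, start=1):
        if best is None or score > best + min_delta:
            best = score   # new best: reset the patience counter
            counter = 0
        else:
            counter += 1   # no improvement this epoch
            if counter >= patience:
                return epoch
    return None

# Invented validation F1 values: improvement stalls after epoch 3.
val_f1 = [0.40, 0.45, 0.50, 0.49, 0.50, 0.48, 0.47, 0.46, 0.52]
print(stop_epoch(val_f1, patience=5))  # 8
```

Note that a tie with the best score (epoch 5 here) counts as no improvement, so the later recovery at epoch 9 is never reached.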

## Inference

Let's perform inference on a sample audio file.


In [None]:
def extract_features(path, cfg):
    y, sr = librosa.load(path, sr=cfg['sample_rate'], mono=True)
    specs_dict = generate_multi_stft(y, sr)

    # For inference, use the middle frequency band (1000-4000Hz) with the middle FFT size (512).
    # This is a simplification; a real application might combine all bands and FFT sizes.
    key = ("1000-4000Hz", 512)
    if key not in specs_dict:
        # Fall back to the first available spectrogram if the preferred one is missing
        key = next(iter(specs_dict))
    return torch.tensor(specs_dict[key]).unsqueeze(0).unsqueeze(0)  # [1,1,H,W]

def predict(model, wav_path, cfg):
    model.eval()
    # Put the input on the same device as the model's parameters
    device = next(model.parameters()).device
    x = extract_features(wav_path, cfg).to(device)
    with torch.no_grad():
        preds = model(x).squeeze().cpu().numpy()  # back to CPU before converting to NumPy
    return {label: float(preds[i]) for i, label in enumerate(LABELS)}

# Get a sample WAV file for inference, reusing the IRMAS root found during preprocessing
sample_wav = next(pathlib.Path(irmas_root).rglob("*.wav"))
print(f"Sample WAV file: {sample_wav}")

# Perform inference
results = predict(model, sample_wav, resnet_cfg)

# Display results
print("\nPrediction results:")
for instrument, confidence in sorted(results.items(), key=lambda x: x[1], reverse=True):
    print(f"{instrument}: {confidence:.4f}")

# Visualize results
plt.figure(figsize=(10, 6))
plt.bar(results.keys(), results.values())
plt.xticks(rotation=45, ha='right')
plt.ylabel('Confidence')
plt.title('Instrument Detection Confidence')
plt.tight_layout()
plt.show()
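
`predict` returns one independent sigmoid confidence per label, so several instruments (or none) can be reported for a clip. A minimal sketch of turning such a dictionary into a list of detected labels follows; the 0.5 threshold and the confidence values are illustrative assumptions, not outputs of the trained model, and `detected_instruments` is a hypothetical helper.

```python
def detected_instruments(confidences, threshold=0.5):
    """Return labels whose confidence meets the threshold, highest first."""
    hits = [(label, score) for label, score in confidences.items() if score >= threshold]
    return [label for label, _ in sorted(hits, key=lambda item: item[1], reverse=True)]

# Made-up confidences for illustration only.
example = {"piano": 0.91, "violin": 0.62, "flute": 0.08, "voice": 0.49}
print(detected_instruments(example))        # ['piano', 'violin']
print(detected_instruments(example, 0.95))  # []
```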


## Visualize Audio and Spectrogram

Let's visualize the audio waveform and spectrogram of the sample file.


In [None]:
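import librosa.display  # explicit import: older librosa releases do not expose the submodule via `import librosa`

def visualize_audio(wav_path, cfg):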
    # Load audio
    y, sr = librosa.load(wav_path, sr=cfg['sample_rate'], mono=True)

    # Compute multi-band STFT spectrograms
    specs_dict = generate_multi_stft(y, sr)

    # Plot waveform and selected spectrograms
    plt.figure(figsize=(15, 12))

    # Plot waveform
    plt.subplot(4, 1, 1)
    librosa.display.waveshow(y, sr=sr)
    plt.title('Waveform')

    # Select three spectrograms to visualize (one from each frequency band with the middle FFT size)
    keys_to_plot = [
        ("0-1000Hz", 512),
        ("1000-4000Hz", 512),
        ("4000-11025Hz", 512)
    ]

    for i, key in enumerate(keys_to_plot):
        if key in specs_dict:
            plt.subplot(4, 1, i+2)
            spec = specs_dict[key]
            hop_length = 512 // 4  # hop_length for FFT size 512
            librosa.display.specshow(spec, sr=sr, x_axis='time', hop_length=hop_length)
            plt.colorbar(format='%+2.0f dB')
            plt.title(f'Spectrogram: {key[0]}, FFT size: {key[1]}')
        else:
            print(f"Spectrogram for {key} not found")

    plt.tight_layout()
    plt.show()

# Visualize sample audio
visualize_audio(sample_wav, resnet_cfg)


## Conclusion

In this notebook, we've demonstrated the complete pipeline for multi-label musical instrument recognition:

1. Setting up the environment
2. Downloading and preprocessing the IRMAS dataset
3. Defining and training a ResNet-34 model
4. Performing inference on audio files
5. Visualizing the results

This notebook can be run in Google Colab without any special modifications to the codebase.