# Instrument Classification – Colab Runtime
This notebook sets up the environment and downloads the IRMAS dataset **once** into a shared Google Drive folder. It then regenerates the features in every session, trains the model, and runs a short inference and visualization demo.

In [None]:

# --- Colab & Drive setup ----------------------------------------------------
import sys, pathlib, os
# Detect Colab by whether `google.colab` is already imported in this kernel;
# only then is there a Google Drive to mount at /content/drive.
IN_COLAB = 'google.colab' in sys.modules

if not IN_COLAB:
    print("Running outside Colab")
else:
    from google.colab import drive
    drive.mount('/content/drive')
    print("Drive mounted.")


In [None]:

# --- Project clone & dependency install ------------------------------------
REPO_URL = "https://github.com/ofekdd/DL_Project.git"   # <- adjust if needed
REPO_DIR = "DL_Project"

if not pathlib.Path(REPO_DIR).is_dir():
    !git clone $REPO_URL
%cd $REPO_DIR

!pip -q install -r requirements.txt


In [None]:

# --- Dataset download (raw zip only on Drive) ------------------------------
import subprocess, pathlib, os, sys

DATA_CACHE = "/content/drive/MyDrive/DL_Shared/IRMAS" if IN_COLAB else "data/raw/IRMAS"
!python data/download_irmas.py --out_dir $DATA_CACHE


In [None]:

# --- Feature preprocessing (done each session into /content) ---------------
import pathlib, os, sys, subprocess, json, shutil

FEATURE_DIR = "/content/IRMAS_features"
if not pathlib.Path(FEATURE_DIR).is_dir():
    print("Preprocessing train split ...")
    !python data/preprocess.py --in_dir $DATA_CACHE/IRMAS-TrainingData --out_dir $FEATURE_DIR/train
    print("Preprocessing test split ...")
    !python data/preprocess.py --in_dir $DATA_CACHE/IRMAS-TestingData --out_dir $FEATURE_DIR/test
else:
    print("Features already exist in this runtime – skipping.")


In [None]:

# --- Training --------------------------------------------------------------
import torch, yaml
from training.train import main as train_main

CONFIG = "configs/model_resnet.yaml"
!python -m training.train --config $CONFIG


In [None]:

# --- Inference demo --------------------------------------------------------
CKPT_PATH = !ls lightning_logs/*/checkpoints/*.ckpt | tail -n 1
TEST_WAV = f"{DATA_CACHE}/IRMAS-TestingData/0001.wav"
!python inference/predict.py {CKPT_PATH[0]} $TEST_WAV


In [None]:

# --- Visualization ---------------------------------------------------------
import matplotlib.pyplot as plt
import librosa
import librosa.display
import numpy as np
import torch
import yaml
from data.preprocess import generate_multi_stft
from inference.predict import predict
from models import ResNetSpec

def visualize_audio(wav_path, cfg, keys_to_plot=None):
    """Plot a waveform and selected multi-band STFT spectrograms for one audio file.

    Args:
        wav_path: Path of the audio file to load.
        cfg: Configuration mapping. Only ``cfg['sample_rate']`` is read; it is
            passed to ``librosa.load`` as the target sample rate.
        keys_to_plot: Optional sequence of ``(band_label, n_fft)`` tuples used to
            look up spectrograms in the dict returned by ``generate_multi_stft``.
            Defaults to the three frequency bands at FFT size 512. One subplot
            row is added per key, below the waveform.

    Keys not found in the spectrogram dict are reported with a printed message,
    and their subplot is left empty. The figure is shown with ``plt.show()``;
    nothing is returned.
    """
    if keys_to_plot is None:
        # Default: one key per frequency band, all at the middle FFT size.
        keys_to_plot = [
            ("0-1000Hz", 512),
            ("1000-4000Hz", 512),
            ("4000-11025Hz", 512),
        ]

    # Load audio as mono at the configured sample rate.
    y, sr = librosa.load(wav_path, sr=cfg['sample_rate'], mono=True)

    # Compute multi-band STFT spectrograms.
    specs_dict = generate_multi_stft(y, sr)

    # One row for the waveform plus one per requested spectrogram. Three inches
    # per row reproduces the original 15x12 figure for the default four rows.
    n_rows = 1 + len(keys_to_plot)
    plt.figure(figsize=(15, 3 * n_rows))

    # Plot waveform
    plt.subplot(n_rows, 1, 1)
    librosa.display.waveshow(y, sr=sr)
    plt.title('Waveform')

    for row, key in enumerate(keys_to_plot, start=2):
        if key not in specs_dict:
            print(f"Spectrogram for {key} not found")
            continue
        band, n_fft = key
        plt.subplot(n_rows, 1, row)
        # Derive the hop from this key's FFT size instead of hard-coding 512 // 4,
        # so the time axis stays consistent with the key being plotted.
        # NOTE(review): assumes generate_multi_stft uses hop_length = n_fft // 4
        # for every FFT size (the original used 512 // 4 for size 512) — confirm.
        hop_length = n_fft // 4
        librosa.display.specshow(specs_dict[key], sr=sr, x_axis='time', hop_length=hop_length)
        plt.colorbar(format='%+2.0f dB')
        plt.title(f'Spectrogram: {band}, FFT size: {n_fft}')

    plt.tight_layout()
    plt.show()

# Load configuration; the context manager closes the file handle.
with open("configs/default.yaml") as cfg_file:
    cfg = yaml.safe_load(cfg_file)

# Number of output classes passed to ResNetSpec. Presumably the 11 IRMAS
# instrument classes — confirm against the training config.
N_CLASSES = 11

# Visualize a sample audio file. If a checkpoint exists, also plot the model's
# per-instrument confidences for it.
if pathlib.Path(TEST_WAV).exists():
    visualize_audio(TEST_WAV, cfg)

    if CKPT_PATH and pathlib.Path(CKPT_PATH[0]).exists():
        model = ResNetSpec(N_CLASSES)
        # NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
        # Assumes the checkpoint's "state_dict" keys match ResNetSpec's own
        # parameter names (no LightningModule attribute prefix) — TODO confirm.
        checkpoint = torch.load(CKPT_PATH[0], map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
        # Switch to inference mode before predicting. This matters for dropout
        # and batch-norm layers, if ResNetSpec has any.
        model.eval()
        with torch.no_grad():
            results = predict(model, TEST_WAV, cfg)

        # Display results. The dict views are materialized as lists for matplotlib.
        plt.figure(figsize=(10, 6))
        plt.bar(list(results.keys()), list(results.values()))
        plt.xticks(rotation=45, ha='right')
        plt.ylabel('Confidence')
        plt.title('Instrument Detection Confidence')
        plt.tight_layout()
        plt.show()
else:
    print(f"Test WAV file not found: {TEST_WAV}")