# Instrument Classification – Colab Runtime
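This notebook sets up the environment (mounting Google Drive when running in Colab), clones the project and installs its dependencies, downloads the IRMAS dataset **once** to a shared folder on your Google Drive so the raw data persists across sessions, regenerates the features on the runtime's local disk each session, trains the model, and finishes with a short inference demo.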

In [None]:

# --- Colab & Drive setup ----------------------------------------------------
# Detects whether the notebook runs inside Google Colab and, if so, mounts
# Google Drive at /content/drive. Later cells read IN_COLAB.
import sys, pathlib, os

# The session counts as Colab when the `google.colab` module is already loaded.
IN_COLAB = 'google.colab' in sys.modules

if not IN_COLAB:
    print("Running outside Colab")
else:
    # Imported lazily: the package is only importable inside a Colab runtime.
    from google.colab import drive
    drive.mount('/content/drive')
    print("Drive mounted.")


In [None]:

# --- Project clone & dependency install ------------------------------------
REPO_URL = "https://github.com/ofekdd/DL_Project.git"   # <- adjust if needed
REPO_DIR = "DL_Project"

if not pathlib.Path(REPO_DIR).is_dir():
    !git clone $REPO_URL
%cd $REPO_DIR

!pip -q install -r requirements.txt


In [None]:

# --- Dataset download (raw zip only on Drive) ------------------------------
import subprocess, pathlib, os, sys

DATA_CACHE = "/content/drive/MyDrive/DL_Shared/IRMAS" if IN_COLAB else "data/raw/IRMAS"
!python data/download_irmas.py --out_dir $DATA_CACHE


In [None]:

# --- Feature preprocessing (done each session into /content) ---------------
import pathlib, os, sys, subprocess, json, shutil

FEATURE_DIR = "/content/IRMAS_features"
if not pathlib.Path(FEATURE_DIR).is_dir():
    print("Preprocessing train split ...")
    !python data/preprocess.py --in_dir $DATA_CACHE/IRMAS-TrainingData --out_dir $FEATURE_DIR/train
    print("Preprocessing test split ...")
    !python data/preprocess.py --in_dir $DATA_CACHE/IRMAS-TestingData --out_dir $FEATURE_DIR/test
else:
    print("Features already exist in this runtime – skipping.")


In [None]:

# --- Training --------------------------------------------------------------
import torch, yaml
from training.train import main as train_main

CONFIG = "configs/model_resnet.yaml"
!python -m training.train --config $CONFIG


In [None]:

# --- Inference demo --------------------------------------------------------
CKPT_PATH = !ls lightning_logs/*/checkpoints/*.ckpt | tail -n 1
TEST_WAV = f"{DATA_CACHE}/IRMAS-TestingData/0001.wav"
!python inference/predict.py {CKPT_PATH[0]} $TEST_WAV
