In [26]:

import sys
import warnings, tqdm

warnings.filterwarnings("ignore", category=tqdm.TqdmWarning)
sys.modules['tqdm.notebook'] = tqdm
sys.modules['tqdm.autonotebook'] = tqdm

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    import os

    # Always start fresh and clone the specific branch
    print("🗑️ Cleaning up any existing project...")
    %cd / content
    !rm -rf DL_Project

    #TODO: Fix the branch according to the latest changes
    print("📥 Cloning specific branch 'master'...")
    !git clone -b master https://github.com/ofekdd/DL_Project.git
    %cd DL_Project

    # Verify we're on the correct branch
    print("🔍 Verifying branch...")
    !git branch
    !git log --oneline -n 3

    # Install dependencies
    print("📦 Installing dependencies...")
    !pip install -r requirements.txt

    print("✅ Setup complete with branch 'master'!")

In [27]:
# Check the current working directory and ensure it is the project root
from pathlib import Path
print("CWD :", Path.cwd())                    # where the kernel is running
print("Exists?", Path('configs').is_dir())    # should be True if CWD is project root


CWD : /home/odahan/Technion/Semester_8/Deep_Learning/Project/notebooks
Exists? False


In [28]:
import yaml
import os
from pathlib import Path

# Project root for local runs. Resolution order:
#   1. $PROJECT_ROOT, if set;
#   2. the nearest ancestor of the CWD containing `configs/` (covers kernels
#      started inside notebooks/);
#   3. the original author's absolute path as a last resort.
LEGACY_WORKSPACE = '/home/odahan/Technion/Semester_8/Deep_Learning/Project'


def resolve_workspace(start: Path, fallback: str = LEGACY_WORKSPACE) -> str:
    """Return the project root directory as a string, using the order above."""
    env_root = os.environ.get("PROJECT_ROOT")
    if env_root:
        return env_root
    for candidate in (start, *start.parents):
        if (candidate / "configs").is_dir():
            return str(candidate)
    return fallback


workspace = resolve_workspace(Path.cwd().resolve())
yaml_path = 'configs/panns_enhanced.yaml' if IN_COLAB else f'{workspace}/configs/panns_enhanced.yaml'
print(yaml_path)

# Open and load the YAML file. An empty file loads as None, which would only fail
# later at `cfg.items()` / `cfg.get(...)`, so reject anything that is not a mapping.
with open(yaml_path, 'r') as file:
    cfg = yaml.safe_load(file)
if not isinstance(cfg, dict):
    raise ValueError(f"{yaml_path} does not contain a YAML mapping (got {type(cfg).__name__})")

print("PANNs-enhanced configuration:")
for key, value in cfg.items():
    print(f"  {key}: {value}")

/home/odahan/Technion/Semester_8/Deep_Learning/Project/configs/multi_stft_cnn.yaml
9cnn configuration:
  model_name: multi_stft_cnn
  sample_rate: 22050
  n_mels: 64
  hop_length: 512
  batch_size: 8
  num_epochs: 50
  learning_rate: 2e-4
  num_workers: 4
  n_branches: 9
  branch_output_dim: 128


In [None]:
# Download the IRMAS dataset (training + testing parts) if needed
from data.download_irmas import main as download_irmas_main
import pathlib
import os

# Environment-aware cache path
if IN_COLAB:
    DATA_CACHE = "/content/drive/MyDrive/datasets/IRMAS"
    # Nothing in this notebook mounts Google Drive; without a mount the files
    # would land on the VM's ephemeral disk, so warn before creating the path.
    if not os.path.ismount("/content/drive"):
        print("⚠️ Google Drive does not appear to be mounted at /content/drive. "
              "Run `from google.colab import drive; drive.mount('/content/drive')` first.")
else:
    # Prefer a shared copy in ~/datasets/irmas, otherwise the repo-local data/raw
    home_dataset_path = pathlib.Path.home() / "datasets" / "irmas"
    DATA_CACHE = str(home_dataset_path if home_dataset_path.exists() else "data/raw")

# Create the dataset directory if it doesn't exist
os.makedirs(DATA_CACHE, exist_ok=True)
data_cache_path = pathlib.Path(DATA_CACHE)

# Archives handled by the downloader, and the folders expected after extraction
irmas_zips = [
    "IRMAS.TrainingData.zip",
    "IRMAS-TestingData-Part1.zip",
    "IRMAS-TestingData-Part2.zip",
    "IRMAS-TestingData-Part3.zip"
]
expected_dirs = [
    "IRMAS-TrainingData",
    "IRMAS-TestingData-Part1",
    "IRMAS-TestingData-Part2",
    "IRMAS-TestingData-Part3"
]

all_zips_present = all((data_cache_path / zip_name).exists() for zip_name in irmas_zips)
all_dirs_present = all((data_cache_path / dir_name).exists() for dir_name in expected_dirs)

# Skip the (large) download when every part is already extracted, even if the
# zip files were deleted afterwards to save space.
download_required = not all_zips_present and not all_dirs_present

if download_required:
    print(f"📥 Downloading IRMAS datasets into: {DATA_CACHE}")
    download_irmas_main(data_cache_path)
elif all_zips_present:
    print("✅ All IRMAS zip files already present. Skipping download.")
else:
    print("✅ All IRMAS parts already extracted. Skipping download.")

# Confirm dataset availability
print("\n🔍 Verifying dataset extraction:")
missing = []
for dir_name in expected_dirs:
    expected_path = data_cache_path / dir_name
    if expected_path.exists():
        print(f"✅ {dir_name} found at {expected_path}")
    else:
        print(f"❌ {dir_name} missing at {expected_path}")
        missing.append(dir_name)

if missing:
    print("\n⚠️ Some dataset parts are missing. You may want to delete corrupted zip files and re-run the cell.")
else:
    print("🎉 All expected IRMAS dataset parts are ready.")

# Set IRMAS root to base path (used in later cells to access all parts)
irmas_root = data_cache_path
print(f"\n📂 IRMAS base path set to: {irmas_root}")

In [None]:
# Fix NumPy compatibility issue
import sys

print("🔧 Fixing NumPy compatibility...")

# Check current NumPy version
import numpy as np

print(f"Current NumPy version: {np.__version__}")

# If NumPy 2.0+, we need to downgrade or use a workaround
if int(np.__version__.split('.')[0]) >= 2:
    print("⚠️  NumPy 2.0+ detected. Installing compatible version...")
    !pip install "numpy<2.0" --quiet

    # Restart the kernel to load the new NumPy version
    print("🔄 Restarting kernel to load compatible NumPy...")
    import os

    os.kill(os.getpid(), 9)  # This will restart the kernel in Colab
else:
    print("✅ NumPy version is compatible")

In [None]:
from data.download_irmas import load_irmas_audio_dataset, load_irmas_testing_dataset
import pathlib

if irmas_root and irmas_root.exists():

    print("📁 Dataset creation settings from config:")
    print(f"   max_original_samples: {cfg.get('max_original_samples', 50)}")
    print(f"   num_mixtures: {cfg.get('num_mixtures', 100)}")
    print(f"   min_instruments: {cfg.get('min_instruments', 1)}")
    print(f"   max_instruments: {cfg.get('max_instruments', 2)}")

    # Normalize root: accept either the IRMAS base folder or its training sub-folder
    base_root = irmas_root.parent if irmas_root.name == "IRMAS-TrainingData" else irmas_root

    # Define separate paths (printed for reference; the loaders receive base_root)
    training_path = base_root / "IRMAS-TrainingData"
    testing_paths = [base_root / f"IRMAS-TestingData-Part{i}" for i in range(1, 4)]

    print(f"📂 IRMAS paths:")
    print(f"   ├─ Training: {training_path}")
    for tp in testing_paths:
        print(f"   └─ Test Part: {tp}")

    # Load training data (single-label)
    original_dataset = load_irmas_audio_dataset(base_root, cfg,
                                                max_samples=cfg.get("max_original_samples"))

    # Load test data (multi-label) with a single call. The loader receives only
    # `base_root` (no per-part path), so calling it once per entry of
    # `testing_paths` with identical arguments would load the same samples
    # three times and inflate the test count.
    # NOTE(review): assumes the loader scans every IRMAS-TestingData-Part*
    # under base_root on its own — confirm in data/download_irmas.py.
    test_datasets = list(load_irmas_testing_dataset(base_root, cfg))

    # Summarize what was loaded
    total_loaded = len(original_dataset) + len(test_datasets)
    print(f"\n📊 Final dataset summary:")
    print(f"   ✅ Training samples loaded: {len(original_dataset)}")
    print(f"   ✅ Testing samples loaded: {len(test_datasets)}")
    print(f"   ✅ Total samples loaded:   {total_loaded}")

else:
    print("❌ IRMAS root not found or invalid. Please run the download step first.")

In [29]:
import pathlib

# `irmas_root` is a pathlib.Path and therefore always truthy, so also require
# that it exists (the same guard the dataset-loading cell uses).
if irmas_root and pathlib.Path(irmas_root).exists():
    print(f"IRMAS dataset found at: {irmas_root}")
    PROCESSED_DIR = "/content/IRMAS_features" if IN_COLAB else "data/processed"

    # Fraction (0–1) of the original IRMAS data to use, taken from the config
    original_data_percentage = cfg.get('original_data_percentage', 0.1)
    print(f"Using {original_data_percentage*100}% of original IRMAS data (from config)")

    from data.preprocess import preprocess_data

    preprocess_data(
        irmas_root=irmas_root,
        out_dir=PROCESSED_DIR,
        cfg=cfg,
        original_data_percentage=original_data_percentage
    )

    print(f"✅ Preprocessing complete with mixed labels. Features saved to {PROCESSED_DIR}")

else:
    print("Could not locate IRMAS dataset after download. Check paths and try again.")

Downloading IRMAS dataset to data/raw...
Archive already exists, skipping download
Verifying checksum ...
Extracting ...
Done. Data at data/raw
IRMAS dataset found at: data/raw/IRMAS-TrainingData

To preprocess the data, you can run:
python data/preprocess.py --in_dir data/raw/IRMAS-TrainingData --out_dir data/processed

Or execute this command in the next cell:
!python data/preprocess.py --in_dir data/raw/IRMAS-TrainingData --out_dir data/processed


In [None]:
# ──────────────────────────────────────────────────────────────────────────────
# 📦  Configure paths & echo training settings
# ──────────────────────────────────────────────────────────────────────────────
import pathlib

print("🔧 Configuring data paths and training settings...")

# Optional notebook-level override of the YAML `max_samples` (None = keep YAML value).
MAX_SAMPLES_OVERRIDE = None

# 1) show YAML-driven parameters (read before the override is applied)
base_max_samples = cfg.get("max_samples", None)
print("📁 Base configuration from YAML:")
print(f"   max_samples            : {base_max_samples}")
print(f"   max_original_samples   : {cfg.get('max_original_samples', 50)}")
print(f"   num_mixtures           : {cfg.get('num_mixtures', 100)}")
print(f"   min_instruments        : {cfg.get('min_instruments', 1)}")
print(f"   max_instruments        : {cfg.get('max_instruments', 2)}")
print(f"   original_data_percentage : {cfg.get('original_data_percentage', 0.1)}")

# 2) apply and announce the notebook override. Without an actual override step,
#    base_max_samples always equals cfg["max_samples"] and the notice never fires.
if MAX_SAMPLES_OVERRIDE is not None:
    cfg["max_samples"] = MAX_SAMPLES_OVERRIDE
if base_max_samples != cfg.get("max_samples"):
    print(f"⚠️  Notebook override: max_samples changed to {cfg.get('max_samples')}")

# 3) define processed-data locations
PROCESSED_DIR = "/content/IRMAS_features" if IN_COLAB else "data/processed"
cfg.update(
    {
        "data_dir": PROCESSED_DIR,
        "train_dir": f"{PROCESSED_DIR}/train",
        "val_dir": f"{PROCESSED_DIR}/val",
        "test_dir": f"{PROCESSED_DIR}/test",
    }
)

print("\n📂 Data directories:")
print(f"   Processed data : {cfg['data_dir']}")
print(f"   Training       : {cfg['train_dir']}")
print(f"   Validation     : {cfg['val_dir']}")
print(f"   Test           : {cfg['test_dir']}")

# ──────────────────────────────────────────────────────────────────────────────
# 🔍  Verify directory existence & detailed sample counts
# ──────────────────────────────────────────────────────────────────────────────
def count_sample_folders(split_path: pathlib.Path):
    """Count sub-folders of `split_path` by name prefix.

    Returns a dict with keys original / mixed / irmasTest / other; all zeros
    when the directory does not exist. Plain files are ignored.
    """
    counts = dict(original=0, mixed=0, irmasTest=0, other=0)
    if not split_path.exists():
        return counts  # all zeros
    for d in split_path.iterdir():
        if not d.is_dir():
            continue
        n = d.name
        if n.startswith("original_"):
            counts["original"] += 1
        elif n.startswith("mixed_"):
            counts["mixed"] += 1
        elif n.startswith("irmasTest_"):
            counts["irmasTest"] += 1
        else:
            counts["other"] += 1
    return counts


print("\n🔍 Verifying data directories & split composition:")
for split in ["train", "val", "test"]:
    path = pathlib.Path(cfg[f"{split}_dir"])
    counts = count_sample_folders(path)
    total = sum(counts.values())

    status = "✅" if total else "❌"
    print(f"\n{status} {split.upper()} ({path}): {total} total sample folders")
    print(f"      • original_:  {counts['original']}")
    print(f"      • mixed_:     {counts['mixed']}")
    print(f"      • irmasTest_: {counts['irmasTest']}")
    if counts["other"]:
        print(f"      • other:      {counts['other']}  ← check if expected!")

    # sanity warnings
    if split in ("train", "val") and counts["irmasTest"]:
        print("      ⚠️  Unexpected irmasTest_ folders in this split!")
    if split == "test" and (counts["original"] or counts["mixed"]):
        print("      ⚠️  Test split should contain ONLY irmasTest_ folders.")

# ──────────────────────────────────────────────────────────────────────────────
# 🎛️  Final training configuration summary
# ──────────────────────────────────────────────────────────────────────────────
limit_val = cfg.get('limit_val_batches', 1.0)
# PyTorch Lightning semantics: a float is a fraction of the validation set and an
# int is a number of batches, so `1` means ONE batch, not 100 %. Comparing with
# `<= 1` mislabelled ints 0/1 and raised TypeError for None.
limit_val_kind = 'percentage' if isinstance(limit_val, float) else 'batches'

print("\n✅ Final training configuration:")
print(f"   Training samples limit : {cfg.get('max_samples', 'unlimited')}")
print(f"   Batch size             : {cfg.get('batch_size')}")
print(f"   Validation limit       : {limit_val} ({limit_val_kind})")
print(f"   Learning rate          : {cfg.get('learning_rate')}")
print(f"   Epochs                 : {cfg.get('num_epochs')}")


In [None]:
# Import required modules for the model
import torch
from var import LABELS
from models.panns_enhanced import MultiSTFTCNN_WithPANNs
from data.download_pnn import download_panns_checkpoint

n_classes = len(LABELS)

# Download PANNs checkpoint if needed
panns_path = download_panns_checkpoint()

# Create the enhanced model with PANNs
model = MultiSTFTCNN_WithPANNs(
    n_classes=n_classes,
    pretrained_path=panns_path,
    freeze_backbone=False
)

print("PANNs-Enhanced Architecture:")
print(model)

# Both the torchinfo summary and the smoke test below run forward passes on
# all-zero dummy batches. Run them in eval mode so that layers with running
# statistics (e.g. BatchNorm, if the model has any) are not updated from dummy
# data. The model's original mode is restored at the end of the cell.
was_training = model.training
model.eval()

try:
    from torchinfo import summary

    class ModelWrapper(torch.nn.Module):
        """Adapter: torchinfo passes inputs positionally, the model takes one list."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, x1, x2, x3):
            return self.model([x1, x2, x3])

    wrapped_model = ModelWrapper(model)

    dummy_inputs = [
        torch.zeros(1, 1, 128, 100),
        torch.zeros(1, 1, 128, 100),
        torch.zeros(1, 1, 128, 100)
    ]

    print("\nModel Summary:")
    summary(wrapped_model, input_data=dummy_inputs, verbose=1)

except ImportError:
    print("\nInstall torchinfo for detailed model summary: pip install torchinfo")
except Exception as e:
    print(f"\nCould not generate model summary: {e}")
    print("This is normal - the model architecture is still correctly defined.")

print("\n🔧 Manual Model Summary:")
print("   📊 Input: 3 spectrograms (optimized window sizes for each frequency band)")
print("   🧠 Architecture: 3 PANNs feature extractors + fusion layer + classifier")
print(f"   📤 Output: {n_classes} instrument classes")

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"   📈 Total Parameters: {total_params:,}")
print(f"   🎯 Trainable Parameters: {trainable_params:,}")
print("   🚀 Using PANNs pretrained weights for enhanced feature extraction")

print("\n🧪 Testing PANNs-enhanced model with dummy data...")
try:
    dummy_input = [torch.zeros(2, 1, 20, 30) for _ in range(3)]
    with torch.no_grad():  # a smoke test needs no autograd graph
        output = model(dummy_input)
    print("   ✅ Model test successful!")
    print(f"   📊 Input: 3 tensors of shape {dummy_input[0].shape}")
    print(f"   📤 Output shape: {output.shape}")
    print(f"   🎯 Output range: [{output.min():.3f}, {output.max():.3f}]")
    print("   ℹ️ The PANNs model already applies sigmoid in its classifier")
except Exception as e:
    print(f"   ❌ Model test failed: {e}")
finally:
    model.train(was_training)


In [None]:
# ╔════════════════════════════════════════════════════════════════════╗
# 🚀  TRAINING LAUNCH  –  with split summary & extra sanity checks
# ╚════════════════════════════════════════════════════════════════════╝
import inspect
import pathlib
from collections import Counter
import traceback
import sys

print("🚀 Starting training…\n")

# ─────────────────────────────────────────────────────────────────────
# 1) Echo key configuration values
# ─────────────────────────────────────────────────────────────────────
limit_val = cfg.get('limit_val_batches', 1.0)
# PyTorch Lightning: a float is a fraction of the validation set, an int is a
# number of batches (so `1` means one batch, not 100 %).
limit_val_kind = 'percentage' if isinstance(limit_val, float) else 'batches'

print("📁 Configuration:")
print(f"   max_samples            : {cfg.get('max_samples', 'all')}")
print(f"   train_dir              : {cfg.get('train_dir', 'not set')}")
print(f"   val_dir                : {cfg.get('val_dir', 'not set')}")
print(f"   test_dir               : {cfg.get('test_dir', 'not set')}")
print(f"   batch_size             : {cfg.get('batch_size', 'not set')}")
print(f"   limit_val_batches      : {limit_val} ({limit_val_kind})")
print(f"   num_sanity_val_steps   : {cfg.get('num_sanity_val_steps', 'default')}\n")

# ─────────────────────────────────────────────────────────────────────
# 2) Split-directory summary helpers
# ─────────────────────────────────────────────────────────────────────
def summarize_split(path: pathlib.Path):
    """Count sample sub-folders of `path`, grouped by the name prefix before "_".

    Returns (total, detail); detail is e.g. "original:12, mixed:40",
    "empty", or "missing dir".
    """
    if not path.exists():
        return 0, "missing dir"
    counts = Counter(
        (p.name.split("_")[0] for p in path.iterdir() if p.is_dir())
    )  # original / mixed / irmasTest / …
    total = sum(counts.values())
    detail = ", ".join(f"{k}:{v}" for k, v in counts.items()) if counts else "empty"
    return total, detail


def summarize_cfg_split(split: str):
    """summarize_split for cfg['<split>_dir']; reports "not set" when unconfigured."""
    split_dir = cfg.get(f"{split}_dir")
    if not split_dir:
        # Path("") resolves to the CWD and would count unrelated folders.
        return 0, "not set"
    return summarize_split(pathlib.Path(split_dir))


for split in ("train", "val", "test"):
    total, detail = summarize_cfg_split(split)
    status = "✅" if total else "⚠️"
    print(f"{status} {split.upper():5} → {total:4} sample folders ({detail})")

print()  # spacer

# ─────────────────────────────────────────────────────────────────────
# 3) Guard against an empty validation set
# ─────────────────────────────────────────────────────────────────────
val_total, _ = summarize_cfg_split("val")
if val_total == 0:
    raise RuntimeError(
        "Validation split is empty! "
        "Increase `original_data_percentage` or check preprocessing."
    )

# ─────────────────────────────────────────────────────────────────────
# 4) Training routine with robust import / fallback
# ─────────────────────────────────────────────────────────────────────
try:
    from training.panns_train import main as train_main  # prefer PANNs variant
    print("✅ Imported training.panns_train.main")
except ImportError:
    print("ℹ️  training.panns_train not available – falling back to training.train")
    try:
        from training.train import main as train_main
    except ImportError as e:
        sys.exit(f"❌ Could not import training module: {e}")


def accepts_param(func, name: str) -> bool:
    """Return True if `func`'s signature declares a parameter called `name`."""
    try:
        return name in inspect.signature(func).parameters
    except (TypeError, ValueError):  # callables without an introspectable signature
        return False


# --------------------------------------------------------------------
# Launch training
# --------------------------------------------------------------------
try:
    train_main(cfg)  # your Lightning entry point
    print("🏁 Training completed successfully!")

    # OPTIONAL: automatic test-set evaluation.
    # Check the declared parameters: `__code__.co_varnames` also lists *local
    # variables*, which would trigger a call with an unexpected keyword.
    # NOTE(review): this is a second call to train_main – confirm that
    # run_test=True evaluates without retraining; comment out otherwise.
    if accepts_param(train_main, "run_test"):
        print("\n🔬 Running test-set evaluation…")
        train_main(cfg, run_test=True)
        print("✅ Test evaluation completed!")

except Exception as err:
    print(f"❌ Training error: {err}")
    # A negative limit keeps the innermost frames, i.e. where the error was raised.
    traceback.print_exc(limit=-3)
    # Additional debugging for empty dataset issues
    if "num_samples=0" in str(err).lower():
        train_dir_value = cfg.get("train_dir")
        td = pathlib.Path(train_dir_value) if train_dir_value else None
        if td is not None and td.exists():
            items = list(td.iterdir())
            print(f"🔍 Train dir contains {len(items)} items. First few:")
            for itm in items[:5]:
                print("   •", itm.name)
        else:
            print(f"❌ Train dir {td} does not exist!")


In [None]:
# ╔══════════════════════════════════════════════╗
# 🚀  Inference on IRMAS test set (quick demo)   ║
# ╚══════════════════════════════════════════════╝
from pathlib import Path
from utils.model_loader import load_model
from inference.predict import predict_with_ground_truth
from var import LABELS
import glob
import re

# ── 1. pick best checkpoint ─────────────────────────────────────────
def find_best_checkpoint(log_dir: str = "lightning_logs") -> str | None:
    """
    Return the .ckpt that has the highest val_mAP.
    Ignores   - last.ckpt
              - files without epoch= / val_mAP=
    """
    ckpts = glob.glob(f"{log_dir}/*/checkpoints/*.ckpt")
    if not ckpts:
        print(f"❌ No checkpoints in {log_dir}")
        return None

    pat_ep  = re.compile(r"epoch=(\d+)")
    pat_map = re.compile(r"val_mAP=([0-9]+(?:\.[0-9]*)?)")

    best, best_map, best_ep = None, -1.0, -1
    for c in ckpts:
        m_ep  = pat_ep.search(c)
        m_map = pat_map.search(c)
        if not (m_ep and m_map):              # ← skip e.g. “last.ckpt”
            continue
        ep  = int(m_ep.group(1))
        mp  = float(m_map.group(1).rstrip("."))  # trims stray dot
        if mp > best_map or (mp == best_map and ep > best_ep):
            best, best_map, best_ep = c, mp, ep

    if best:
        print(f"✅ Using checkpoint {best} (epoch {best_ep}, mAP {best_map:.4f})")
        return best

    # Fallback – take the *first* ckpt (often last.ckpt) just so the
    # rest of the code can still run, but warn the user.
    print("⚠️  No metric-tagged checkpoints found; falling back to", ckpts[0])
    return ckpts[0]


ckpt = find_best_checkpoint()
if ckpt is None:
    # Stop here with a clear message instead of passing None to load_model.
    raise FileNotFoundError("No checkpoint found under lightning_logs/ – run the training cell first.")
model = load_model(ckpt, n_classes=len(LABELS)).eval()

# ── 2. load per-class thresholds (optional) ─────────────────────────
def load_thr(path):
    """Load per-class thresholds from a YAML file.

    Returns the value stored under the file's ``thresholds`` key. Thresholds are
    optional, so a missing/unreadable file, invalid YAML or a missing key yields
    ``{}`` (with a short note) instead of an error.
    """
    # Local import: the bare `except Exception` this replaces also swallowed a
    # NameError when `yaml` had not been imported yet, silently returning {}.
    import yaml

    try:
        # `with` closes the file handle (open() inside an expression leaked it).
        with open(path) as fh:
            return yaml.safe_load(fh)["thresholds"]
    except (OSError, KeyError, TypeError, yaml.YAMLError) as exc:
        print(f"ℹ️  No thresholds loaded from {path} ({type(exc).__name__}); using {{}}")
        return {}

# The thresholds file is written by the threshold-optimisation section further
# down; on a fresh kernel it may not exist yet, in which case load_thr gives {}.
f1_thr  = load_thr("configs/optimal_thresholds_f1.yaml")

# ── 3. collect true *test* WAVs ──────────────────────────────────────
wav_root = Path(irmas_root)

wav_files = []
for part_dir in sorted(wav_root.glob("IRMAS-TestingData-Part*")):
    # pick the inner folder if it exists, else use the part dir itself
    inner = part_dir / "IRMAS-TestingData"
    search_root = inner if inner.exists() else part_dir
    # sorted(): rglob order is filesystem-dependent; sorting makes the
    # `max_test_samples` subset below reproducible.
    wav_files.extend(sorted(search_root.rglob("*.wav")))

max_n = cfg.get("max_test_samples")
if max_n:
    wav_files = wav_files[:max_n]

print(f"🗂️  Inference on {len(wav_files)} WAVs")
if not wav_files:
    print(f"⚠️  No test WAVs found under {wav_root}")

# ── 4. run prediction ------------------------------------------------
LOG_EVERY = 5          # print every 5th file (plus the first and the last)
for i, wav in enumerate(wav_files, 1):
    res = predict_with_ground_truth(
        model, str(wav), cfg, thresholds=f1_thr
    )

    # print only every k-th file (or change condition to whatever you prefer)
    if i % LOG_EVERY == 0 or i == 1 or i == len(wav_files):
        acts = ", ".join(res['active_instruments']) or "None"
        print(f"{i:>4}/{len(wav_files)} {wav.name} → {acts}")

print("✅  Inference finished!")


## Threshold Optimization

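Optimize the per-class decision thresholds on the validation split, once for F1 and once for balanced accuracy, and save them under `configs/`. Then spot-check the F1 thresholds on a few IRMAS test files. The inference cell above reads `configs/optimal_thresholds_f1.yaml`, so re-run it after this section to use the optimized thresholds.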


In [None]:
# ╔══════════════════════════════════════════════╗
# 🎯  Threshold optimisation & quick evaluation  ║
# ╚══════════════════════════════════════════════╝
import glob, re, sys, yaml, traceback
from pathlib import Path
import numpy as np
from utils.model_loader import load_model
from inference.predict import predict_with_ground_truth
from visualization.threshold_optimization import (
    find_optimal_thresholds, save_thresholds
)
from data.dataset import create_dataloaders
from var import LABELS

# ── helper to reuse the robust ckpt picker ───────────────────────────
def best_ckpt(log_dir="lightning_logs"):
    """Return the checkpoint path under ``log_dir`` with the highest val_mAP.

    Only files whose path contains both ``epoch=<int>`` and ``val_mAP=<float>``
    are ranked; equal val_mAP is resolved in favour of the later epoch.
    Calls ``sys.exit`` (raising SystemExit) when there are no checkpoints or
    none of them is metric-tagged.
    """
    candidates = glob.glob(f"{log_dir}/*/checkpoints/*.ckpt")
    if len(candidates) == 0:
        sys.exit(f"❌ No checkpoints inside {log_dir}")

    epoch_re = re.compile(r"epoch=(\d+)")
    score_re = re.compile(r"val_mAP=([0-9]+(?:\.[0-9]*)?)")

    ranked = []
    for ckpt_path in candidates:
        epoch_hit = epoch_re.search(ckpt_path)
        score_hit = score_re.search(ckpt_path)
        if epoch_hit is None or score_hit is None:
            continue  # skip malformed names such as last.ckpt
        score = float(score_hit.group(1).rstrip("."))  # drop a trailing dot
        ranked.append((score, int(epoch_hit.group(1)), ckpt_path))

    if not ranked:
        sys.exit("🤔 Found ckpts but none had both epoch= and val_mAP= in the name")

    # max() keeps the first of equal keys, matching a strict ">" scan.
    top_score, top_epoch, winner = max(ranked, key=lambda entry: (entry[0], entry[1]))
    print(f"🏆 {Path(winner).name}  (epoch {top_epoch}, mAP {top_score:.3f})")
    return winner

ckpt = best_ckpt()
# .eval(): threshold search must score the validation set with the model in
# inference mode (dropout off, normalization running stats), as the inference
# cell does; load_model is not assumed to switch modes itself.
model = load_model(ckpt, n_classes=len(LABELS)).eval()

# ── 1. optimise thresholds on validation split ───────────────────────
val_dir = Path(cfg["val_dir"])
# create_dataloaders returns (train_loader, val_loader); only the val loader is
# needed here, so val_dir is passed for both.
_, val_loader = create_dataloaders(
    train_dir=val_dir, val_dir=val_dir,
    batch_size=cfg["batch_size"], num_workers=cfg["num_workers"],
    use_multi_stft=True
)

print("\n📊 Optimising F1 thresholds …")
f1_thr = find_optimal_thresholds(model, val_loader, metric="f1")
save_thresholds(f1_thr, "configs/optimal_thresholds_f1.yaml", "f1")

print("📊 Optimising balanced-acc thresholds …")
bal_thr = find_optimal_thresholds(model, val_loader, metric="balanced")
save_thresholds(bal_thr, "configs/optimal_thresholds_balanced.yaml", "balanced")
print("✅ Threshold optimisation done!\n")

# ── 2. quick demo on a few random test files ─────────────────────────
N_DEMO = 5
wav_root = Path(irmas_root)
# Search recursively: test WAVs may sit in an inner IRMAS-TestingData/ folder
# (the inference cell handles the same layout). Sort so the seeded pick is
# reproducible across filesystems.
wav_pool = sorted(
    wav
    for part_dir in wav_root.glob("IRMAS-TestingData-Part*")
    for wav in part_dir.rglob("*.wav")
)

if not wav_pool:
    print(f"⚠️  No test WAVs found under {wav_root}; skipping demo.")
else:
    # Sample indices (not Path objects) and never request more than available.
    rng = np.random.default_rng(42)
    picks = rng.choice(len(wav_pool), size=min(N_DEMO, len(wav_pool)), replace=False)
    demo = [wav_pool[i] for i in picks]

    print("🔍 Demo predictions (F1 thresholds)")
    for w in demo:
        res = predict_with_ground_truth(model, str(w), cfg, thresholds=f1_thr)
        print(f"{Path(w).name:35} → {res['active_instruments']}")