In [26]:

import sys
import warnings, tqdm

warnings.filterwarnings("ignore", category=tqdm.TqdmWarning)
sys.modules['tqdm.notebook'] = tqdm
sys.modules['tqdm.autonotebook'] = tqdm

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    import os

    # Always start fresh and clone the specific branch
    print("🗑️ Cleaning up any existing project...")
    %cd / content
    !rm -rf DL_Project

    #TODO: Fix the branch according to the latest changes
    print("📥 Cloning specific branch 'master'...")
    !git clone -b master https://github.com/ofekdd/DL_Project.git
    %cd DL_Project

    # Verify we're on the correct branch
    print("🔍 Verifying branch...")
    !git branch
    !git log --oneline -n 3

    # Install dependencies
    print("📦 Installing dependencies...")
    !pip install -r requirements.txt

    print("✅ Setup complete with branch 'master'!")

In [27]:
# Report where the kernel is running and whether that directory holds a `configs/` folder
from pathlib import Path

print(f"CWD : {Path.cwd()}")  # kernel working directory
print(f"Exists? {Path('configs').is_dir()}")  # True only when `configs/` sits directly in the CWD


CWD : /home/odahan/Technion/Semester_8/Deep_Learning/Project/notebooks
Exists? False


In [28]:
import yaml
import os

# Define the path to the YAML configuration file.
# On Colab a path relative to the working directory is used. Locally the project root
# defaults to the original author's checkout; set the DL_PROJECT_ROOT environment
# variable to point at a different checkout without editing this cell.
workspace = os.environ.get('DL_PROJECT_ROOT', '/home/odahan/Technion/Semester_8/Deep_Learning/Project')
yaml_path = 'configs/panns_enhanced.yaml' if IN_COLAB else f'{workspace}/configs/panns_enhanced.yaml'
print(yaml_path)

# Open and load the YAML file (explicit encoding so the result does not depend on the platform default)
with open(yaml_path, 'r', encoding='utf-8') as file:
    cfg = yaml.safe_load(file)

# Echo every top-level key so the active settings are visible in the notebook output
print("PANNs-enhanced configuration:")
for key, value in cfg.items():
    print(f"  {key}: {value}")

/home/odahan/Technion/Semester_8/Deep_Learning/Project/configs/multi_stft_cnn.yaml
9cnn configuration:
  model_name: multi_stft_cnn
  sample_rate: 22050
  n_mels: 64
  hop_length: 512
  batch_size: 8
  num_epochs: 50
  learning_rate: 2e-4
  num_workers: 4
  n_branches: 9
  branch_output_dim: 128


In [None]:
# Download the IRMAS dataset if needed
from data.download_irmas import main as download_irmas_main, find_irmas_root
import pathlib
import os

# Location where a previously downloaded archive may already live under the user's home
home_dataset_path = pathlib.Path.home() / "datasets" / "irmas" / "IRMAS.zip"

# Choose the cache directory: Drive folder on Colab, else the home-directory copy
# when its archive exists, else the project-relative fallback.
if IN_COLAB:
    # NOTE(review): assumes Google Drive is already mounted at /content/drive — confirm
    DATA_CACHE = "/content/drive/MyDrive/datasets/IRMAS"
elif home_dataset_path.exists():
    print(f"Found existing dataset at {home_dataset_path}")
    DATA_CACHE = str(home_dataset_path.parent)
else:
    DATA_CACHE = "data/raw"

# Make sure the cache directory is present before looking for the archive in it
os.makedirs(DATA_CACHE, exist_ok=True)

# Download only when the archive is missing from the cache directory
zip_path = pathlib.Path(DATA_CACHE) / "IRMAS.zip"
if not zip_path.exists():
    print(f"Downloading IRMAS dataset to {DATA_CACHE}...")
    download_irmas_main(pathlib.Path(DATA_CACHE))
else:
    print(f"Dataset already exists at {zip_path}, skipping download...")

# Ask the project helper where the dataset root is (called without arguments)
irmas_root = find_irmas_root()
print(f"IRMAS root found at: {irmas_root}")

In [None]:
# Fix NumPy compatibility issue
import sys

print("🔧 Fixing NumPy compatibility...")

# Check current NumPy version
import numpy as np

print(f"Current NumPy version: {np.__version__}")

# If NumPy 2.0+, we need to downgrade or use a workaround
if int(np.__version__.split('.')[0]) >= 2:
    print("⚠️  NumPy 2.0+ detected. Installing compatible version...")
    !pip install "numpy<2.0" --quiet

    # Restart the kernel to load the new NumPy version
    print("🔄 Restarting kernel to load compatible NumPy...")
    import os

    os.kill(os.getpid(), 9)  # This will restart the kernel in Colab
else:
    print("✅ NumPy version is compatible")

In [None]:
# Convert the training dataset into multi-label format
from data.mix_labels import create_multilabel_dataset

if not irmas_root:
    print("IRMAS root not found. Please run the download cell first.")
else:
    print("Creating multi-label dataset from IRMAS...")
    print(f"📁 Dataset creation settings from config:")
    # Echo the dataset-creation settings; the second element is only the value displayed
    # when the key is missing from cfg (whether the helper applies the same defaults
    # is not visible here).
    for setting, shown_default in (
        ('max_original_samples', 50),
        ('num_mixtures', 100),
        ('min_instruments', 1),
        ('max_instruments', 2),
    ):
        print(f"   {setting}: {cfg.get(setting, shown_default)}")

    # When the detected root is the "IRMAS-TrainingData" folder itself, hand its parent to the helper
    corrected_root = irmas_root
    if irmas_root.name == "IRMAS-TrainingData":
        corrected_root = irmas_root.parent
        print(f"🔧 Adjusting path from {irmas_root} to {corrected_root}")

    print(f"📁 Using root path: {corrected_root}")

    # The helper returns two collections: the original samples and the generated mixtures
    original_dataset, mixed_dataset = create_multilabel_dataset(irmas_root=corrected_root, cfg=cfg)

    # Summary is printed only when at least one mixture was produced
    if mixed_dataset:
        print(f"\n📈 Dataset Summary:")
        print(f"   Original samples: {len(original_dataset)}")
        print(f"   Mixed samples: {len(mixed_dataset)}")
        print(f"   Total: {len(original_dataset) + len(mixed_dataset)}")

In [29]:

if irmas_root:
    print(f"IRMAS dataset found at: {irmas_root}")
    PROCESSED_DIR = "/content/IRMAS_features" if IN_COLAB else "data/processed"

    if 'mixed_dataset' in globals() and mixed_dataset:
        print(f"\nFound {len(mixed_dataset)} mixed samples from previous cell")

        # Use config value for original data percentage
        original_data_percentage = cfg.get('original_data_percentage', 0.1)
        print(f"Using {original_data_percentage*100}% of original IRMAS data (from config)")

        from data.preprocess import preprocess_mixed_data

        preprocess_mixed_data(
            irmas_root=irmas_root,
            mixed_dataset=mixed_dataset,
            out_dir=PROCESSED_DIR,
            cfg=cfg,
            original_data_percentage=original_data_percentage
        )

        print(f"✅ Preprocessing complete with mixed labels. Features saved to {PROCESSED_DIR}")

    else:
        print("No mixed dataset found. Running standard preprocessing...")
        print(f"To preprocess the data, you can run:")
        print(f"python data/preprocess.py --in_dir {irmas_root} --out_dir {PROCESSED_DIR}")

        # Run standard preprocessing
        preprocess_cmd = f"!python data/preprocess.py --in_dir {irmas_root} --out_dir {PROCESSED_DIR} --config configs/default.yaml"
        print(f"\nExecuting: {preprocess_cmd}")
        !python data/preprocess.py --in_dir {irmas_root} --out_dir {PROCESSED_DIR} --config configs/default.yaml

else:
    print("Could not locate IRMAS dataset after download. Check paths and try again.")

Downloading IRMAS dataset to data/raw...
Archive already exists, skipping download
Verifying checksum ...
Extracting ...
Done. Data at data/raw
IRMAS dataset found at: data/raw/IRMAS-TrainingData

To preprocess the data, you can run:
python data/preprocess.py --in_dir data/raw/IRMAS-TrainingData --out_dir data/processed

Or execute this command in the next cell:
!python data/preprocess.py --in_dir data/raw/IRMAS-TrainingData --out_dir data/processed


In [None]:
# Verify the train/val/test split after preprocessing

# Feature directory inspected below: a fixed /content path on Colab, a project-relative path otherwise
PROCESSED_DIR = "/content/IRMAS_features" if IN_COLAB else "data/processed"


def count_samples_in_dir(dir_path):
    """Tally the sample sub-directories found directly inside ``dir_path``.

    Only immediate child directories are considered; plain files are ignored.
    A child whose name contains ``'mixed_'`` is counted as a mixed sample,
    every other child directory as an original sample.

    Args:
        dir_path: Path (string or path-like) of the split directory to inspect.

    Returns:
        Tuple ``(original_count, mixed_count)``; ``(0, 0)`` when ``dir_path``
        does not exist.
    """
    root = pathlib.Path(dir_path)
    if not root.exists():
        return 0, 0

    n_original = 0
    n_mixed = 0
    for entry in root.iterdir():
        if not entry.is_dir():
            continue  # each sample is a directory; skip stray files
        if 'mixed_' in entry.name:
            n_mixed += 1
        else:
            n_original += 1

    return n_original, n_mixed


# Print original/mixed/total sample counts for every split directory
for split in ('train', 'val', 'test'):
    split_dir = "/".join([PROCESSED_DIR, split])
    original_count, mixed_count = count_samples_in_dir(split_dir)
    total_count = sum((original_count, mixed_count))

    print(f"📁 {split.upper()} split:")
    print(f"   Original samples: {original_count}")
    print(f"   Mixed samples: {mixed_count}")
    print(f"   Total: {total_count}")
    print()  # blank line between splits

print("✅ Data split verification complete!")

In [35]:
# Import required modules for the model
import torch
from var import LABELS
from models.panns_enhanced import MultiSTFTCNN_WithPANNs
from data.download_pnn import download_panns_checkpoint

# Download PANNs checkpoint if needed; the helper's return value is passed on as the pretrained path
panns_path = download_panns_checkpoint()

# One class per entry in LABELS
n_classes = len(LABELS)

# Create the enhanced model with PANNs (backbone left unfrozen)
model = MultiSTFTCNN_WithPANNs(
    n_classes=n_classes,
    pretrained_path=panns_path,
    freeze_backbone=False,
)

print("PANNs-Enhanced Architecture:")
print(model)

# Layer-by-layer summary through torchinfo, when that optional package is installed
try:
    from torchinfo import summary

    class ModelWrapper(torch.nn.Module):
        """Adapter that accepts nine positional tensors and forwards them as one list."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, x1, x2, x3, x4, x5, x6, x7, x8, x9):
            # The wrapped model is called with a single list argument
            return self.model([x1, x2, x3, x4, x5, x6, x7, x8, x9])

    wrapped_model = ModelWrapper(model)

    # Zero-filled spectrogram stand-ins, one per branch, each shaped (1, 1, freq_bins, 100).
    # Bin counts are listed band by band (bands 1-3), and within a band by FFT size 256/512/1024.
    dummy_inputs = [
        torch.zeros(1, 1, freq_bins, 100)
        for freq_bins in (
            32, 64, 128,    # Band 1
            48, 96, 192,    # Band 2
            89, 178, 356,   # Band 3
        )
    ]

    print("\nModel Summary:")
    summary(wrapped_model, input_data=dummy_inputs, verbose=1)

except ImportError:
    print("\nInstall torchinfo for detailed model summary: pip install torchinfo")
except Exception as e:
    # Any other failure is reported and the cell carries on to the manual summary
    print(f"\nCould not generate model summary: {e}")
    print("This is normal - the model architecture is still correctly defined.")

# Hand-written overview that does not depend on torchinfo
print(f"\n🔧 Manual Model Summary:")
print("   📊 Input: 9 spectrograms (3 frequency bands × 3 FFT sizes)")
print(f"   🧠 Architecture: 9 PANNs feature extractors + fusion layer + classifier")
print(f"   📤 Output: {n_classes} instrument classes")

# Single pass over the parameters: total element count, and the subset that requires gradients
total_params = 0
trainable_params = 0
for param in model.parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()

print(f"   📈 Total Parameters: {total_params:,}")
print(f"   🎯 Trainable Parameters: {trainable_params:,}")
print(f"   🚀 Using PANNs pretrained weights for enhanced feature extraction")

# Smoke test: push zero-filled inputs through the model and report what comes back
print(f"\n🧪 Testing PANNs-enhanced model with dummy data...")
try:
    # Nine separate zero tensors (batch size 2), passed to the model as one list
    dummy_input = [torch.zeros(2, 1, 20, 30) for _ in range(9)]
    output = model(dummy_input)
    print(f"   ✅ Model test successful!")
    print(f"   📊 Input: 9 tensors of shape {dummy_input[0].shape}")
    print(f"   📤 Output shape: {output.shape}")
    print(f"   🎯 Output range: [{output.min():.3f}, {output.max():.3f}]")
    print(f"   ℹ️ The PANNs model already applies sigmoid in its classifier")
except Exception as e:
    # Failures are reported rather than raised so the rest of the notebook can still run
    print(f"   ❌ Model test failed: {e}")

9 CNN Baseline Architecture:
MultiSTFTCNN(
  (branches): ModuleList(
    (0-8): 9 x STFTBranch(
      (cnn): Sequential(
        (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU()
        (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (8): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (10): ReLU()
        (11): AdaptiveAvgPool2d(output_size=(1, 1))
        (12): Flatten(start_dim=1, end_dim=-1)
      )
    )
  )
  (class

In [40]:
# Configure data paths and training settings from config
print("🔧 Configuring data paths and training settings...")

# Remember max_samples as loaded, so an override applied further down can be detected
base_max_samples = cfg.get('max_samples', None)
print(f"📁 Base configuration from YAML:")
print(f"   max_samples: {base_max_samples}")
# Remaining settings: the second element is the value displayed when the key is absent from cfg
for option, shown_default in (
    ('max_original_samples', 50),
    ('num_mixtures', 100),
    ('min_instruments', 1),
    ('max_instruments', 2),
    ('original_data_percentage', 0.1),
):
    print(f"   {option}: {cfg.get(option, shown_default)}")

# Optional notebook-level override for quick experimentation
# Uncomment and modify these lines to override config values:
# cfg['max_samples'] = 30  # Override for faster notebook testing
# cfg['max_original_samples'] = 25  # Override dataset creation
# cfg['num_mixtures'] = 50  # Override number of synthetic mixtures

# Report the override only when max_samples no longer matches the value captured above
if cfg.get('max_samples') != base_max_samples:
    print(f"⚠️  Notebook override: max_samples changed to {cfg.get('max_samples')}")

# Record the processed-feature directory and its per-split sub-directories in the config
PROCESSED_DIR = "data/processed" if not IN_COLAB else "/content/IRMAS_features"
cfg['data_dir'] = PROCESSED_DIR
cfg.update(
    train_dir=f"{PROCESSED_DIR}/train",
    val_dir=f"{PROCESSED_DIR}/val",
    test_dir=f"{PROCESSED_DIR}/test",
)

print(f"\n📂 Data directories:")
print(f"   Processed data: {PROCESSED_DIR}")
print(f"   Training: {cfg['train_dir']}")
print(f"   Validation: {cfg['val_dir']}")
print(f"   Test: {cfg['test_dir']}")

# Verify that data directories exist and contain samples
import pathlib

print(f"\n🔍 Verifying data directories:")
for split in ('train', 'val', 'test'):
    split_dir = f"{PROCESSED_DIR}/{split}"
    if not pathlib.Path(split_dir).exists():
        print(f"   ❌ {split} directory doesn't exist!")
        continue

    # Every immediate sub-directory counts as one sample
    sample_count = sum(1 for d in pathlib.Path(split_dir).iterdir() if d.is_dir())
    print(f"   {split}: {sample_count} samples")
    if sample_count == 0:
        print(f"   ⚠️  Warning: {split} directory is empty!")

# Display final configuration summary
print(f"\n✅ Final training configuration:")
print(f"   Training samples limit: {cfg.get('max_samples', 'unlimited')}")
print(f"   Batch size: {cfg.get('batch_size')}")
# Label the validation limit by type rather than magnitude: a float is a fraction of the
# validation set, an int is a batch count. A plain `<= 1` test mislabels the integer 1
# (one batch) as a percentage.
# NOTE(review): assumes the value is forwarded unchanged to Lightning's Trainer — confirm in the training code.
print(f"   Validation limit: {cfg.get('limit_val_batches', 1.0)} ({'percentage' if isinstance(cfg.get('limit_val_batches', 1.0), float) else 'batches'})")
print(f"   Learning rate: {cfg.get('learning_rate')}")
print(f"   Epochs: {cfg.get('num_epochs')}")

Training with limited samples: 1


In [None]:
# Training with better error handling and debugging
print("🚀 Starting training...")
print(f"📁 Configuration:")
print(f"   max_samples: {cfg.get('max_samples', 'all')}")
print(f"   train_dir: {cfg.get('train_dir', 'not set')}")
print(f"   val_dir: {cfg.get('val_dir', 'not set')}")
print(f"   batch_size: {cfg.get('batch_size', 'not set')}")
print(f"   limit_val_batches: {cfg.get('limit_val_batches', 1.0)} ({'percentage' if cfg.get('limit_val_batches', 1.0) <= 1 else 'batches'})")
print(f"   num_sanity_val_steps: {cfg.get('num_sanity_val_steps', 'default')}")

# Optional: Override validation settings for even faster development
# cfg['limit_val_batches'] = 0.05  # Use only 5% for even faster validation
# cfg['num_sanity_val_steps'] = 1   # Minimal sanity checks

try:
    # Try direct import first
    from training.panns_train import main as train_main

    print("✅ Direct import successful, PANNs-enhanced training...")

    # Debug: Check if training directory has data
    train_dir = cfg.get('train_dir')
    if train_dir and pathlib.Path(train_dir).exists():
        sample_dirs = [d for d in pathlib.Path(train_dir).iterdir() if d.is_dir()]
        print(f"🔍 Found {len(sample_dirs)} training samples")
        if len(sample_dirs) == 0:
            print("❌ Training directory is empty! Cannot proceed.")
            print("💡 Make sure you've run the preprocessing step successfully.")
        else:
            # Show first few samples
            print(f"📂 Sample directories: {[d.name for d in sample_dirs[:3]]}")
            train_main(cfg)
            print("🎉 Training completed successfully!")
    else:
        print(f"❌ Training directory not found: {train_dir}")
        print("💡 Make sure you've run the preprocessing step successfully.")

except ImportError as import_error:
    print(f"❌ Import error: {import_error}")
    print("💡 This likely means pytorch-lightning is not installed.")

    if IN_COLAB:
        print("📦 Installing pytorch-lightning...")
        !pip install pytorch-lightning>=2.0.0 --quiet
        print("✅ Dependency installed, retrying...")

        # Retry after installation
        try:
            from training.train import main as train_main

            train_main(cfg)
            print("🎉 Training completed successfully!")
        except Exception as retry_error:
            print(f"❌ Still failing after dependency installation: {retry_error}")
            import traceback

            traceback.print_exc()
    else:
        print("Please install: pip install pytorch-lightning>=2.0.0")

except Exception as general_error:
    print(f"❌ Training error: {general_error}")
    print("📋 Error details:")
    import traceback

    traceback.print_exc()

    # Additional debugging
    if "num_samples=0" in str(general_error):
        print("\n🔍 Debugging empty dataset issue:")
        train_dir = cfg.get('train_dir')
        if train_dir:
            print(f"   Checking {train_dir}...")
            if pathlib.Path(train_dir).exists():
                sample_dirs = list(pathlib.Path(train_dir).iterdir())
                print(f"   Found {len(sample_dirs)} items in training directory")
                for item in sample_dirs[:5]:
                    print(f"     - {item.name} ({'dir' if item.is_dir() else 'file'})")
            else:
                print(f"   Directory {train_dir} does not exist!")
        else:
            print("   train_dir not set in config!")

Error with direct import: expected str, bytes or os.PathLike object, not dict
Falling back to shell command
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type             | Params
---------------------------------------------
0 | model   | MultiSTFTCNN     | 850 K 
1 | metrics | MetricCollection | 0     
---------------------------------------------
850 K     Trainable params
0         Non-trainable params
850 K     Total params
3.403     Total estimated model params size (MB)
  rank_zero_warn(
Epoch 0: 100%|█| 1/1 [00:01<00:00,  1.01s/it, v_num=2, train/loss=0.738, train/m
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                       | 0/160 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                          | 0/160 [00:00<?, ?it/s][A
Validation DataLoader 0:   1%|                  | 1/160 [00:02<06:30,  2.46s/it][

In [None]:
# Inference and visualization using the test set

import glob
import re
from pathlib import Path
import yaml
import torch
import librosa
import matplotlib.pyplot as plt
import numpy as np


In [None]:
# Setup inference with adaptive thresholds
# (this cell defines the checkpoint/threshold helpers and then runs inference on a few files)
print("🚀 Running inference with adaptive thresholds")

# Step 1: Find the best checkpoint
def find_best_checkpoint(lightning_logs_dir="lightning_logs"):
    """Find the best checkpoint based on the validation mAP encoded in its filename.

    Searches ``<lightning_logs_dir>/*/checkpoints/*.ckpt`` and parses ``epoch=<int>``
    and ``val_mAP=<float>`` from each filename. A trailing version suffix such as
    ``-v1`` before ``.ckpt`` is tolerated. Files whose name lacks either field, or
    whose mAP text is not a valid number, are ignored for ranking.

    Args:
        lightning_logs_dir: Directory holding the per-run sub-directories.

    Returns:
        Path (str) of the checkpoint with the highest mAP, ties broken by the
        higher epoch. If no filename could be ranked, the first checkpoint in
        sorted path order. ``None`` when no ``.ckpt`` file exists.
    """
    checkpoint_pattern = f"{lightning_logs_dir}/*/checkpoints/*.ckpt"
    # Sorted so that the fallback choice does not depend on filesystem listing order
    checkpoints = sorted(glob.glob(checkpoint_pattern))

    if not checkpoints:
        print(f"❌ No checkpoints found in {lightning_logs_dir}")
        return None

    print(f"🔍 Found {len(checkpoints)} checkpoint(s)")

    best_checkpoint = None
    best_map = -1
    best_epoch = -1

    for ckpt_path in checkpoints:
        ckpt_name = Path(ckpt_path).name

        # Parse metrics from filename; the mAP value ends at ".ckpt" or at "-v<N>.ckpt"
        epoch_match = re.search(r'epoch=(\d+)', ckpt_name)
        map_match = re.search(r'val_mAP=([0-9.]+?)(?=(?:-v\d+)?\.ckpt)', ckpt_name)
        if not (epoch_match and map_match):
            continue

        try:
            val_map = float(map_match.group(1))
        except ValueError:
            # e.g. "0.9.9" matches the character class but is not a number; skip the file
            continue
        epoch = int(epoch_match.group(1))

        # Highest mAP wins; on equal mAP prefer the later epoch
        if val_map > best_map or (val_map == best_map and epoch > best_epoch):
            best_checkpoint = ckpt_path
            best_map = val_map
            best_epoch = epoch

    if best_checkpoint:
        print(f"✅ Selected best checkpoint: {best_checkpoint}")
        print(f"   Epoch: {best_epoch}, val_mAP: {best_map:.4f}")
        return best_checkpoint

    # Fallback to first available checkpoint
    print(f"⚠️ Selecting first available checkpoint: {checkpoints[0]}")
    return checkpoints[0]

# Step 2: Load thresholds and model
def load_adaptive_thresholds(threshold_file="configs/optimal_thresholds_f1.yaml"):
    """Read the ``thresholds`` mapping from a YAML file.

    Args:
        threshold_file: Path of the YAML file to read.

    Returns:
        The value stored under the top-level ``thresholds`` key (an empty dict
        when the key is absent), or ``None`` if the file cannot be opened or
        parsed. Failures are reported on stdout rather than raised.
    """
    try:
        with open(threshold_file, 'r') as handle:
            payload = yaml.safe_load(handle)
        thresholds = payload.get('thresholds', {})
        print(f"✅ Loaded {len(thresholds)} thresholds from {threshold_file}")
        return thresholds
    except Exception as err:
        # Best-effort load: callers treat None as "no thresholds available"
        print(f"⚠️ Could not load thresholds: {err}")
        return None

# Locate the checkpoint to evaluate (None when no checkpoint file is found)
ckpt_path = find_best_checkpoint()

# Load both threshold sets, F1-optimised first and balanced second; each is None if its file cannot be read
f1_thresholds, balanced_thresholds = (
    load_adaptive_thresholds(threshold_file)
    for threshold_file in (
        "configs/optimal_thresholds_f1.yaml",
        "configs/optimal_thresholds_balanced.yaml",
    )
)

# Setup model and config, then run inference on up to five audio files and compare
# the fixed, F1-optimised and balanced threshold strategies for each file.
if ckpt_path:
    # Load configuration
    config_path = yaml_path
    with open(config_path, 'r') as f:
        file_cfg = yaml.safe_load(f) or {}
    # Values from the file take precedence, but keys added to `cfg` earlier in the notebook
    # (data_dir / train_dir / val_dir / test_dir) are preserved instead of being discarded,
    # so later cells that read them keep working.
    cfg = {**(globals().get('cfg') or {}), **file_cfg}

    # Import the model loading function
    from inference.predict import load_model_from_checkpoint
    from var import LABELS

    # Load model
    # NOTE(review): the threshold-optimization cell calls this helper with `cfg` as a third
    # argument while this call passes two — confirm the helper's signature.
    model = load_model_from_checkpoint(ckpt_path, len(LABELS))
    model.eval()

    print("✅ Model loaded successfully!")

    # Find test files
    # NOTE(review): files are taken from anywhere under `irmas_root`, not from a dedicated
    # test split — confirm this is intended.
    if 'irmas_root' in globals() and irmas_root:
        test_files = list(Path(irmas_root).rglob("*.wav"))[:5]  # Limit to 5 files

        if test_files:
            print(f"📊 Running inference on {len(test_files)} test files")

            # Import prediction function
            from inference.predict import predict_with_ground_truth

            for i, wav_file in enumerate(test_files):
                print(f"\n🎵 File {i+1}/{len(test_files)}: {wav_file.name}")

                # Run with different threshold strategies
                strategies = {
                    "Fixed (0.5)": None,
                    "F1 Optimized": f1_thresholds,
                    "Balanced": balanced_thresholds
                }

                results = {}
                for name, thresholds in strategies.items():
                    # Skip adaptive strategies whose threshold file could not be loaded
                    if name != "Fixed (0.5)" and thresholds is None:
                        continue

                    # Every strategy uses 0.5 as the default threshold; the per-instrument
                    # mapping (when present) is passed separately.
                    default_threshold = 0.5

                    # Run prediction with appropriate thresholds
                    result = predict_with_ground_truth(
                        model, str(wav_file), cfg,
                        threshold=default_threshold,
                        thresholds=thresholds
                    )

                    results[name] = result

                # Display ground truth if available
                if "ground_truth" in results["Fixed (0.5)"]:
                    gt = results["Fixed (0.5)"]["ground_truth"]
                    print(f"🎯 Ground truth: {gt}")

                # Compare results
                print("\n📊 Detected instruments by threshold strategy:")
                for name, result in results.items():
                    active = result["active_instruments"]
                    active_str = ", ".join(active) if active else "None"
                    acc = result.get("accuracy", None)
                    acc_str = f" (Accuracy: {acc:.2f})" if acc is not None else ""
                    print(f"   {name}: {active_str}{acc_str}")

                # Show top 5 predictions with scores
                # NOTE(review): scores come from the fixed-threshold run, but the ✅/❌ marker
                # compares against the F1 thresholds when they are loaded — confirm intent.
                print("\n📊 Top 5 predictions (with Fixed threshold):")
                top_preds = sorted(results["Fixed (0.5)"]["predictions"].items(), key=lambda x: x[1], reverse=True)[:5]
                for label, score in top_preds:
                    threshold = f1_thresholds.get(label, 0.5) if f1_thresholds else 0.5
                    active = "✅" if score >= threshold else "❌"
                    print(f"   {active} {label:<15} {score:.4f} (threshold: {threshold:.2f})")

                # Visualize first file only
                if i == 0:
                    try:
                        # `librosa.display` is a submodule; import it explicitly so that
                        # `librosa.display.specshow` is available regardless of librosa version.
                        import librosa.display

                        # Load audio for visualization
                        y, sr = librosa.load(str(wav_file), sr=cfg['sample_rate'])

                        # Plot waveform
                        plt.figure(figsize=(12, 4))
                        plt.plot(np.linspace(0, len(y)/sr, len(y)), y)
                        plt.title(f"Waveform: {wav_file.name}")
                        plt.xlabel("Time (s)")
                        plt.ylabel("Amplitude")
                        plt.show()

                        # Plot spectrogram
                        D = librosa.amplitude_to_db(np.abs(librosa.stft(y)), ref=np.max)
                        plt.figure(figsize=(12, 6))
                        librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='log')
                        plt.colorbar(format='%+2.0f dB')
                        plt.title(f"Spectrogram: {wav_file.name}")
                        plt.show()
                    except Exception as e:
                        # Plotting is optional; report the problem and keep going
                        print(f"⚠️ Visualization error: {e}")
        else:
            print("❌ No test files found")
    else:
        print("❌ IRMAS root not found")
else:
    print("❌ No checkpoint found. Please run training first.")


## Threshold Optimization

Optimize classification thresholds to improve instrument detection accuracy.


In [None]:
# Run threshold optimization if needed.
# Generates per-instrument thresholds (F1 and balanced accuracy) on the validation split and
# saves them under configs/, unless both files already exist and `regenerate` is left False.
print("🎯 Threshold Optimization")

# Check if we already have a best checkpoint
if 'ckpt_path' not in globals() or ckpt_path is None:
    print("❌ No checkpoint found! Run the training cell first.")
    ckpt_path = None
else:
    print(f"✅ Using checkpoint: {ckpt_path}")

    # Check if thresholds already exist
    f1_thresholds_exist = os.path.exists('configs/optimal_thresholds_f1.yaml')
    balanced_thresholds_exist = os.path.exists('configs/optimal_thresholds_balanced.yaml')

    if f1_thresholds_exist and balanced_thresholds_exist:
        print("✅ Threshold files already exist. Set regenerate=True to recreate them.")
        regenerate = False  # Change to True to force regeneration
    else:
        print("⚠️ Threshold files not found. Will generate them.")
        regenerate = True

    if regenerate:
        print("\n🧪 Running threshold optimization...")

        try:
            # Import the threshold optimization module
            from visualization.threshold_optimization import find_optimal_thresholds, save_thresholds
            from data.dataset import create_dataloaders
            from inference.predict import load_model_from_checkpoint
            from var import LABELS

            # Load model
            model = load_model_from_checkpoint(ckpt_path, len(LABELS), cfg)

            # Resolve the feature directories the same way as the other cells: prefer the paths
            # recorded in cfg, otherwise fall back to the environment-specific processed directory
            # (a hard-coded "data/processed" is wrong on Colab, where features live under /content).
            processed_dir = cfg.get('data_dir', "/content/IRMAS_features" if IN_COLAB else "data/processed")
            val_dir = cfg.get('val_dir', f"{processed_dir}/val")

            # Create validation dataloader; the training loader returned alongside it is discarded
            _, val_loader = create_dataloaders(
                train_dir=cfg.get('train_dir', f"{processed_dir}/train"),  # Not used
                val_dir=val_dir,
                batch_size=cfg.get('batch_size', 32),
                num_workers=cfg.get('num_workers', 4),
                use_multi_stft=True
            )

            # Generate F1-optimized thresholds
            print("\n📊 Optimizing thresholds for F1 score...")
            f1_thresholds = find_optimal_thresholds(model, val_loader, metric='f1')
            save_thresholds(f1_thresholds, 'configs/optimal_thresholds_f1.yaml', 'f1')

            # Generate balanced accuracy thresholds
            print("\n📊 Optimizing thresholds for balanced accuracy...")
            balanced_thresholds = find_optimal_thresholds(model, val_loader, metric='balanced')
            save_thresholds(balanced_thresholds, 'configs/optimal_thresholds_balanced.yaml', 'balanced')

            print("\n✅ Threshold optimization complete!")
            print("   You can now use these thresholds for improved instrument detection.")

        except Exception as e:
            # Optimization is optional for the rest of the notebook: report and point to the CLI
            print(f"❌ Error during threshold optimization: {e}")
            print("   You can still run threshold optimization manually:")
            print("   python visualization/optimize_thresholds.py CHECKPOINT_PATH")
    else:
        print("\nℹ️ To manually run threshold optimization:")
        print("   python visualization/optimize_thresholds.py CHECKPOINT_PATH")
        print("   python visualization/optimize_thresholds.py --metric balanced CHECKPOINT_PATH")


## Adaptive Threshold Evaluation (New Feature)


In [None]:
print("🔍 Loading model for instrument recognition")

# Per-instrument F1 thresholds; stays None when the file is missing or unreadable
f1_thresholds = None
try:
    # Absolute path inside the cloned repo on Colab, project-relative path otherwise
    f1_thresholds_path = '/content/DL_Project/configs/optimal_thresholds_f1.yaml' if IN_COLAB else 'configs/optimal_thresholds_f1.yaml'

    if not os.path.exists(f1_thresholds_path):
        print(f"⚠️ F1 thresholds file not found")
    else:
        print(f"✅ Found F1-optimized thresholds at {f1_thresholds_path}")
        with open(f1_thresholds_path, 'r') as f:
            threshold_data = yaml.safe_load(f)
        # Missing 'thresholds' key yields an empty mapping
        f1_thresholds = threshold_data.get('thresholds', {})
        print(f"   Loaded {len(f1_thresholds)} instrument-specific thresholds")

except Exception as e:
    # Best-effort: report the problem and continue without adaptive thresholds
    print(f"⚠️ Error loading thresholds: {e}")


# Find the best checkpoint
def find_best_checkpoint(lightning_logs_dir=None):
    """Find the checkpoint with the highest validation mAP.

    The epoch and score are parsed from each checkpoint's file name
    (``epoch=<int>`` and ``val_mAP=<float>``). The number after ``val_mAP=``
    may be followed by anything, so names carrying a version suffix
    (``...val_mAP=0.85-v1.ckpt``) or further metrics are recognised too.

    Args:
        lightning_logs_dir: Directory that contains
            ``<run>/checkpoints/*.ckpt``. When None, a fixed list of
            candidate locations is probed and the first existing directory
            is used.

    Returns:
        The path (str) of the selected checkpoint. Ties on val_mAP are broken
        by the higher epoch. If no file name carries both fields, the first
        checkpoint in sorted path order is returned. None is returned when no
        logs directory or no checkpoint file is found.
    """
    # Auto-detect lightning_logs directory
    if lightning_logs_dir is None:
        possible_dirs = [
            "/content/DL_Project/lightning_logs",
            "/content/DL_Project/DL_Project/lightning_logs",
            "lightning_logs",
            "./lightning_logs"
        ]
        lightning_logs_dir = next(
            (dir_path for dir_path in possible_dirs if Path(dir_path).is_dir()),
            None,
        )
        if lightning_logs_dir is None:
            print("❌ Could not find lightning_logs directory")
            return None
        print(f"🔍 Found lightning_logs at: {lightning_logs_dir}")

    # glob returns matches in arbitrary order; sorting makes both the
    # tie-breaking below and the fallback choice deterministic.
    checkpoints = sorted(glob.glob(f"{lightning_logs_dir}/*/checkpoints/*.ckpt"))

    if not checkpoints:
        print("❌ No checkpoints found")
        return None

    print(f"🔍 Found {len(checkpoints)} checkpoint(s)")

    # Captures a complete float literal (optional exponent included), so
    # float() on the captured text cannot fail.
    map_pattern = r'val_mAP=(\d*\.?\d+(?:[eE][-+]?\d+)?)'

    # Find best checkpoint based on validation mAP, highest epoch as tiebreaker.
    best_checkpoint = None
    best_map = -1
    best_epoch = -1

    for ckpt_path in checkpoints:
        ckpt_name = Path(ckpt_path).name
        epoch_match = re.search(r'epoch=(\d+)', ckpt_name)
        map_match = re.search(map_pattern, ckpt_name)
        if not (epoch_match and map_match):
            continue  # name carries no usable score; only eligible as fallback

        epoch = int(epoch_match.group(1))
        val_map = float(map_match.group(1))
        if (val_map, epoch) > (best_map, best_epoch):
            best_checkpoint = ckpt_path
            best_map = val_map
            best_epoch = epoch

    if best_checkpoint:
        print(f"\n🏆 Selected best checkpoint: {Path(best_checkpoint).name}")
        print(f"   📊 Epoch: {best_epoch}, val_mAP: {best_map:.4f}")
        return best_checkpoint

    # No file name exposed both fields; `checkpoints` is non-empty here.
    print("💡 Using first available checkpoint as fallback")
    return checkpoints[0]


# Load checkpoint and run inference.
# Flow of this part of the cell: pick a checkpoint, load config + model, run
# per-file inference on a small sample of raw .wav files, then (if the processed
# test directory exists) run the comprehensive evaluation.
ckpt_path = find_best_checkpoint()  # no argument -> the logs directory is auto-detected
config_path = yaml_path  # `yaml_path` comes from the configuration cell near the top of the notebook

if not ckpt_path:
    print("❌ No checkpoint found! Make sure training completed successfully.")
else:
    print(f"✅ Will use checkpoint: {ckpt_path}")

    # Prepare test data
    # `test_data_dir` is only consumed by the comprehensive evaluation further down.
    PROCESSED_DIR = "/content/IRMAS_features" if IN_COLAB else "data/processed"
    test_data_dir = f"{PROCESSED_DIR}/test"

    # Load model and test files
    # Everything below runs inside one try block: any failure is reported with a
    # traceback by the handler at the bottom instead of aborting the notebook.
    try:
        # Load configuration
        # Note: this rebinds the notebook-level `cfg` from `config_path`.
        with open(config_path, 'r') as f:
            cfg = yaml.safe_load(f)

        # Load model using evaluation helper
        from visualization.evaluation import load_model_from_checkpoint

        # `LABELS` is not defined in this cell; it must already exist in the
        # kernel (presumably the instrument label list — TODO confirm).
        model = load_model_from_checkpoint(ckpt_path, len(LABELS), cfg)
        print("✅ Model loaded successfully")

        # Get test files
        # `irmas_root` is not defined in this cell; the check guards against it
        # being missing or empty in the current kernel.
        if 'irmas_root' in globals() and irmas_root:
            # Sample a few test files from the dataset
            all_wav_files = list(pathlib.Path(irmas_root).rglob("*.wav"))
            # NOTE(review): the list is shuffled as returned by rglob, without
            # sorting first; if that order differs between machines, the fixed
            # seed alone will not select the same files — confirm.
            # np.random.seed also resets NumPy's global RNG state for later cells.
            np.random.seed(42)  # For reproducibility
            np.random.shuffle(all_wav_files)
            # Take the last 10% of the shuffled list as the "test" pool...
            val_split = int(len(all_wav_files) * 0.9)
            test_wav_files = all_wav_files[val_split:][:5]  # Limit to 5 files

            print(f"📊 Running inference on {len(test_wav_files)} test files")

            # Run inference on test files
            for i, wav_file in enumerate(test_wav_files):
                wav_path = str(wav_file)
                print(f"\n🎵 Testing file {i + 1}/{len(test_wav_files)}: {pathlib.Path(wav_path).name}")

                # Run prediction with adaptive thresholds if available
                # (`f1_thresholds` is None when no thresholds file was loaded.)
                from inference.predict import predict_with_ground_truth

                result = predict_with_ground_truth(
                    model, wav_path, cfg,
                    show_ground_truth=True,
                    thresholds=f1_thresholds
                )

                # Display ground truth if available
                if "ground_truth" in result and result["ground_truth"]:
                    print(f"🎯 Ground truth: {result['ground_truth']}")

                # Display sorted predictions
                print("📊 Predicted instrument probabilities:")
                print("=" * 40)

                # Highest score first.
                sorted_scores = sorted(result["predictions"].items(), key=lambda x: x[1], reverse=True)
                for label, score in sorted_scores:
                    # The 🔥 marker uses a fixed 0.5 cut-off; it does not consult
                    # `f1_thresholds`, so it may disagree with the thresholded result.
                    confidence = "🔥" if score > 0.5 else "  "
                    print(f"  {confidence} {label:<15} {score:.4f}")

                # Show prediction status
                # Only printed when the helper put a "correct" entry in its result.
                if "correct" in result:
                    status = "✅ CORRECT" if result["correct"] else "❌ INCORRECT"
                    print(f"🎯 Prediction status: {status}")

                # Visualize first file only
                if i == 0:
                    # A plotting failure is reported but does not stop the loop.
                    try:
                        from visualization.visualization import visualize_audio

                        print("\n📈 Visualizing waveform & spectrograms:")
                        visualize_audio(wav_path, cfg)
                    except Exception as viz_error:
                        print(f"⚠️ Visualization error: {viz_error}")

            print(f"\n🎉 Inference complete on {len(test_wav_files)} test files!")

            # Run comprehensive evaluation if test directory exists
            # This step sits inside the `irmas_root` branch, so it is skipped
            # whenever `irmas_root` is missing, even if `test_data_dir` exists.
            if pathlib.Path(test_data_dir).exists():
                print(f"\n🔍 Running comprehensive evaluation...")
                try:
                    from visualization.evaluation import run_comprehensive_evaluation

                    # A single fixed threshold of 0.5 is passed here; the loaded
                    # `f1_thresholds` are not forwarded to this call.
                    eval_results = run_comprehensive_evaluation(
                        checkpoint_path=ckpt_path,
                        test_dir=test_data_dir,
                        config_path=config_path,
                        threshold=0.5
                    )

                    # A falsy return value prints nothing.
                    if eval_results:
                        print("✅ Comprehensive evaluation completed!")
                except Exception as e:
                    print(f"⚠️ Evaluation error: {e}")
            else:
                print("⚠️ Test data directory not found for comprehensive evaluation")
        else:
            print("❌ Original IRMAS dataset root not found")
    except Exception as e:
        # Catch-all for config loading, model loading and the inference loop;
        # the full traceback is printed to help with debugging.
        print(f"❌ Error during model loading or inference: {e}")
        import traceback

        traceback.print_exc()