In [None]:
# TODO: import stuff
import sys
import random
import os
import torch
import librosa
import numpy as np
import matplotlib.pyplot as plt
import yaml
import urllib.request
from torch.utils.data import DataLoader, Dataset
import warnings, tqdm
from visualization.visualization import visualize_audio

warnings.filterwarnings("ignore", category=tqdm.TqdmWarning)
sys.modules['tqdm.notebook'] = tqdm
sys.modules['tqdm.autonotebook'] = tqdm
from tqdm import tqdm  # now `tqdm(...)` is always the console bar

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Clone the repository
    !git clone https://github.com/ofekdd/DL_Project.git
    %cd DL_Project

    # Install dependencies
    !pip install -r requirements.txt



In [None]:
import yaml
import os

yaml_path = 'configs/multi_stft_cnn.yaml'

# Parse the experiment configuration; it is used below as a mapping (cfg.items()).
with open(yaml_path, 'r') as config_stream:
    cfg = yaml.safe_load(config_stream)

# Echo each top-level setting, indented two spaces under a header line.
config_report = ["9cnn configuration:"]
config_report.extend(f"  {setting_name}: {setting_value}" for setting_name, setting_value in cfg.items())
print("\n".join(config_report))

In [None]:
# Download the IRMAS dataset if needed
from data.download_irmas import main as download_irmas_main, find_irmas_root
import pathlib

# Determine the appropriate download location based on environment
if IN_COLAB:
    # Store the dataset on Google Drive. Drive must actually be mounted first;
    # otherwise /content/drive/MyDrive is just a local folder on the runtime
    # (lost on disconnect) and would block a later mount of that path.
    from google.colab import drive
    drive.mount('/content/drive')
    DATA_CACHE = "/content/drive/MyDrive/datasets/IRMAS"
else:
    # For local environment, store in the project directory
    DATA_CACHE = "data/raw"

# Output directory for preprocessed features. Defined unconditionally so later
# cells can reference it even when the IRMAS root is not found.
PROCESSED_DIR = "/content/IRMAS_features" if IN_COLAB else "data/processed"

# Create the directory if it doesn't exist
os.makedirs(DATA_CACHE, exist_ok=True)

# Download and extract the dataset
print(f"Downloading IRMAS dataset to {DATA_CACHE}...")
download_irmas_main(pathlib.Path(DATA_CACHE))

# Find the IRMAS dataset root.
# NOTE(review): find_irmas_root() is called without DATA_CACHE; presumably it
# searches known locations — confirm that includes DATA_CACHE.
irmas_root = find_irmas_root()
if irmas_root:
    print(f"IRMAS dataset found at: {irmas_root}")

    # Suggest the next step
    print("\nTo preprocess the data, you can run:")
    print(f"python data/preprocess.py --in_dir {irmas_root} --out_dir {PROCESSED_DIR}")

    preprocess_cmd = f"!python data/preprocess.py --in_dir {irmas_root} --out_dir {PROCESSED_DIR}"
    print("\nOr execute this command in the next cell:")
    print(preprocess_cmd)
else:
    print("Could not locate IRMAS dataset after download. Check paths and try again.")

In [None]:
# Check if preprocessing is needed and run the preprocessing step
import os

# Create the directory if it doesn't exist
os.makedirs(PROCESSED_DIR, exist_ok=True)

# Check if preprocessing is needed
if irmas_root and not os.path.exists(f"{PROCESSED_DIR}/train") or len(os.listdir(f"{PROCESSED_DIR}/train")) == 0:
    print(f"Starting preprocessing from {irmas_root} to {PROCESSED_DIR}...")

    # Run the preprocessing command
    !python data/preprocess.py --in_dir {irmas_root} --out_dir {PROCESSED_DIR} --config configs/default.yaml

    print(f"Preprocessing complete. Features saved to {PROCESSED_DIR}")
else:
    print(f"Processed data already exists at {PROCESSED_DIR} - skipping preprocessing")

In [None]:
# Import required modules for the model
import torch
from var import LABELS
from models.multi_stft_cnn import MultiSTFTCNN

n_classes = len(LABELS)
N_BRANCHES = 9  # 3 FFT sizes × 3 frequency bands

# Create the model
model = MultiSTFTCNN(
    n_classes=n_classes,  # Number of instrument classes
    n_branches=N_BRANCHES,
    branch_output_dim=128  # Default value for feature dimension
)

print("9 CNN Baseline Architecture:")
print(model)

# Optional: print a layer-by-layer summary if torchinfo is installed.
try:
    from torchinfo import summary
except ImportError:
    print("\nInstall torchinfo for detailed model summary: pip install torchinfo")
else:
    # One all-zeros tensor of fixed shape (1, 1, 20, 30) per branch, used only
    # to trace shapes. NOTE(review): presumably (batch, channel, freq, time) —
    # confirm. torchinfo passes a list `input_data` as separate positional
    # args to forward(); confirm that matches MultiSTFTCNN.forward.
    dummy_input = [torch.zeros(1, 1, 20, 30) for _ in range(N_BRANCHES)]
    print("\nModel Summary:")
    try:
        # torchinfo does not print by default inside notebooks, and a return
        # value inside `try` is never auto-displayed, so print it explicitly.
        print(summary(model, input_data=dummy_input, verbose=0))
    except Exception as summary_error:  # summary is best-effort; report, don't abort
        print(f"Could not build model summary: {summary_error}")

In [None]:
try:
    from training.train import main as train_main
    train_main(cfg)
    print("Training completed!")
except Exception as e:
    print(f"Error with direct import: {e}")
    print("Falling back to shell command")
    !python -m training.train --config {cfg}

In [None]:
# TODO: Test and visualize