# **Cw1 - Music Isolation and Genre Classification**

### Enrique Saldivar Corona

# 1. Setup

In [1]:
import os
import numpy as np
import librosa
import torch
import torchaudio
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import soundfile as sf
import torchvision.models as models
import random
import glob
import pandas as pd
from tqdm.notebook import tqdm
import os
import time
import shutil
import zipfile
import warnings
import traceback
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.metrics import classification_report, accuracy_score, f1_score


In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [10]:

# Prioritize CUDA (NVIDIA GPU on Colab)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: cuda


###  1.1 copy to local storage

In [15]:

print("Starting dataset copy (ZIP files) from Google Drive to Colab local storage.")
start_copy_time = time.time()

# --- Define Source Paths on Google Drive ---

DRIVE_PROJECT_PATH = "/content/drive/MyDrive/cw1_DL" # Your project folder in Drive
DRIVE_FMA_ZIP_PATH = os.path.join(DRIVE_PROJECT_PATH, "fma_small.zip") # Example name
DRIVE_ESC50_ZIP_PATH = os.path.join(DRIVE_PROJECT_PATH, "esc50_audio.zip") # Example name (assuming you zipped only the 'audio' folder contents)
DRIVE_FMA_METADATA_ZIP_PATH = os.path.join(DRIVE_PROJECT_PATH, "fma_metadata.zip") # Example name


# --- Define Destination Paths on Colab Local Disk ---
# Using /content/ is standard for temporary storage
COLAB_TEMP_ZIP_PATH = "/content/temp_zip" # Temporary location to store zips before extraction
COLAB_DATA_PATH = "/content/data"          # Final destination for extracted data

# --- Define final extracted data paths (ensure these match Cell [27] Config) ---
COLAB_FMA_AUDIO_PATH = os.path.join(COLAB_DATA_PATH, "fma_small")
COLAB_ESC50_AUDIO_PATH = os.path.join(COLAB_DATA_PATH, "esc50_audio")
COLAB_FMA_METADATA_PATH = os.path.join(COLAB_DATA_PATH, "fma_metadata")


# --- Create Destination Folders ---
# Clear previous data/zips if they exist
if os.path.exists(COLAB_DATA_PATH):
    print(f"Removing existing local data directory: {COLAB_DATA_PATH}")
    shutil.rmtree(COLAB_DATA_PATH)
if os.path.exists(COLAB_TEMP_ZIP_PATH):
     print(f"Removing existing temp zip directory: {COLAB_TEMP_ZIP_PATH}")
     shutil.rmtree(COLAB_TEMP_ZIP_PATH)

print(f"Creating local directories: {COLAB_TEMP_ZIP_PATH}, {COLAB_DATA_PATH}")
os.makedirs(COLAB_TEMP_ZIP_PATH, exist_ok=True)
os.makedirs(COLAB_DATA_PATH, exist_ok=True)
# Create subdirectories for extraction targets explicitly
os.makedirs(COLAB_FMA_AUDIO_PATH, exist_ok=True)
os.makedirs(COLAB_ESC50_AUDIO_PATH, exist_ok=True)
os.makedirs(COLAB_FMA_METADATA_PATH, exist_ok=True)


# --- Step 1: Copy ZIP files from Drive to Colab local temp storage ---
copy_errors = False
files_to_copy = {
    "FMA Audio ZIP": DRIVE_FMA_ZIP_PATH,
    "ESC-50 Audio ZIP": DRIVE_ESC50_ZIP_PATH,
    "FMA Metadata ZIP": DRIVE_FMA_METADATA_ZIP_PATH,
}

copied_zip_paths = {}

for name, drive_path in files_to_copy.items():
    if os.path.exists(drive_path):
        dest_path = os.path.join(COLAB_TEMP_ZIP_PATH, os.path.basename(drive_path))
        print(f"Copying {name} from {drive_path} to {dest_path}...")
        try:
            shutil.copy(drive_path, dest_path)
            copied_zip_paths[name] = dest_path # Store path for extraction
            print(f"{name} copy finished.")
        except Exception as e:
            print(f"!!! ERROR copying {name}: {e} !!!")
            copy_errors = True
    else:
        print(f"!!! WARNING: Source ZIP file not found: {drive_path} !!!")


copy_end_time = time.time()
print(f"\nZIP file copy attempt finished in {copy_end_time - start_copy_time:.2f} seconds.")

# --- Step 2: Extract ZIP files locally ---
extract_start_time = time.time()
print("\nStarting extraction...")

if "FMA Audio ZIP" in copied_zip_paths:
    print(f"Extracting FMA audio ({copied_zip_paths['FMA Audio ZIP']}) to {COLAB_FMA_AUDIO_PATH}...")
    # Use -q for quiet, -d specifies destination. Overwrites existing files without prompt.
    !unzip -q "{copied_zip_paths['FMA Audio ZIP']}" -d "{COLAB_FMA_AUDIO_PATH}"
    print("FMA audio extraction finished.")


if "ESC-50 Audio ZIP" in copied_zip_paths:
    print(f"Extracting ESC-50 audio ({copied_zip_paths['ESC-50 Audio ZIP']}) to {COLAB_ESC50_AUDIO_PATH}...")
    !unzip -q "{copied_zip_paths['ESC-50 Audio ZIP']}" -d "{COLAB_ESC50_AUDIO_PATH}"
    print("ESC-50 audio extraction finished.")


if "FMA Metadata ZIP" in copied_zip_paths:
     print(f"Extracting FMA metadata ({copied_zip_paths['FMA Metadata ZIP']}) to {COLAB_FMA_METADATA_PATH}...")
     !unzip -q "{copied_zip_paths['FMA Metadata ZIP']}" -d "{COLAB_FMA_METADATA_PATH}"
     print("FMA metadata extraction finished.")



extract_end_time = time.time()
print(f"\nExtraction finished in {extract_end_time - extract_start_time:.2f} seconds.")
total_time = time.time()
print(f"Total time (Copy + Extract): {total_time - start_copy_time:.2f} seconds.")

# --- Step 3: Verify extraction (Optional but recommended) ---
if not copy_errors: # Only verify if copy seemed okay
    print("\nVerifying extracted files (listing counts/contents):")
    print("\nFMA Audio subfolders (first level):")
    !ls "{COLAB_FMA_AUDIO_PATH}" | head -n 10 # List first 10 subfolders
    !find "{COLAB_FMA_AUDIO_PATH}" -maxdepth 1 -type d | wc -l # Count total subfolders

    print("\nESC-50 Audio files:")
    !ls "{COLAB_ESC50_AUDIO_PATH}" | head -n 10 # List first 10 files
    !ls "{COLAB_ESC50_AUDIO_PATH}" | wc -l # Count total files

    print("\nFMA Metadata files:")
    !ls "{COLAB_FMA_METADATA_PATH}" # List files in metadata
else:
    print("\n!!! Verification skipped due to copy errors or missing files. Please check paths and ZIP contents. !!!")


Starting dataset copy (ZIP files) from Google Drive to Colab local storage...
Removing existing local data directory: /content/data
Removing existing temp zip directory: /content/temp_zip
Creating local directories: /content/temp_zip, /content/data
Copying FMA Audio ZIP from /content/drive/MyDrive/cw1_DL/fma_small.zip to /content/temp_zip/fma_small.zip...
FMA Audio ZIP copy finished.
Copying ESC-50 Audio ZIP from /content/drive/MyDrive/cw1_DL/esc50_audio.zip to /content/temp_zip/esc50_audio.zip...
ESC-50 Audio ZIP copy finished.
Copying FMA Metadata ZIP from /content/drive/MyDrive/cw1_DL/fma_metadata.zip to /content/temp_zip/fma_metadata.zip...
FMA Metadata ZIP copy finished.

ZIP file copy attempt finished in 33.07 seconds.

Starting extraction...
Extracting FMA audio (/content/temp_zip/fma_small.zip) to /content/data/fma_small...
FMA audio extraction finished.
Extracting ESC-50 audio (/content/temp_zip/esc50_audio.zip) to /content/data/esc50_audio...
ESC-50 audio extraction finished.

# 2. Configuration

In [11]:
# Configuration

# Data Paths (Ensure these point to LOCAL Colab data

COLAB_DATA_PATH = "/content/data"
# Check if local paths exist, otherwise fall back to Drive (optional, safer to rely on copy)
if os.path.exists(os.path.join(COLAB_DATA_PATH, "fma_small", "fma_small")):
    print("Using LOCAL Colab data paths (adjusting for nested folders)...")
    FMA_AUDIO_PATH = os.path.join(COLAB_DATA_PATH, "fma_small", "fma_small")
    FMA_METADATA_PATH = os.path.join(COLAB_DATA_PATH, "fma_metadata", "fma_metadata")
    AMBIENCE_DIR = os.path.join(COLAB_DATA_PATH, "esc50_audio", "audio")
else:
    print("WARNING: Local data not found, falling back to Drive paths (TRAINING WILL BE SLOW)")
    DRIVE_BASE_PATH = "/content/drive/MyDrive/cw1_DL"
    FMA_AUDIO_PATH = os.path.join(DRIVE_BASE_PATH, "fma_small")
    FMA_METADATA_PATH = os.path.join(DRIVE_BASE_PATH, "fma_metadata") # Check nesting
    AMBIENCE_DIR = os.path.join(DRIVE_BASE_PATH, "ESC-50-master", "audio")

# These variables now use the paths determined above
MUSIC_DIR = FMA_AUDIO_PATH
GENRE_DATA_DIR = FMA_AUDIO_PATH
FMA_METADATA_CSV = os.path.join(FMA_METADATA_PATH, "tracks.csv")

#  Output Directory Options
DRIVE_PROJECT_PATH = "/content/drive/MyDrive/cw1_DL"
OUTPUT_DIR = os.path.join(DRIVE_PROJECT_PATH, "output")

# Case Study File Path (Used after training)
CASE_STUDY_FILE = os.path.join(DRIVE_PROJECT_PATH, "Case_study_city.mp3")

Using LOCAL Colab data paths (adjusting for nested folders)...


### 2.1 Parameters


In [3]:
# --- STFT Parameters ---
N_FFT = 2048
HOP_LENGTH = 512
WIN_LENGTH = N_FFT
WINDOW = 'hann'
SAMPLE_RATE = 44100
# --- CHANGE 1: Revert to 5-second chunks for Task 1 ---
AUDIO_CHUNK_DURATION_S = 5
AUDIO_CHUNK_SAMPLES = int(AUDIO_CHUNK_DURATION_S * SAMPLE_RATE)
print(f"Using Audio Chunk Duration: {AUDIO_CHUNK_DURATION_S}s")

# --- Training Hyperparameters ---
LEARNING_RATE_SEP = 1e-5 # Keep initial LR for scheduler
LEARNING_RATE_CLS = 1e-3
# --- CHANGE 2: Reduce Batch Size for 5s chunks ---
BATCH_SIZE = 32 # Reduce from 64. If this *still* fails OOM, try 16 next.
print(f"Using Batch Size: {BATCH_SIZE}")
# --- CHANGE 3: Set Task 1 Epochs ---
NUM_EPOCHS_SEP = 30 # Target for Task 1 (run fresh)
NUM_EPOCHS_CLS = 30 # Max for Task 2 (Early stopping active)
print(f"Task 1 Epochs: {NUM_EPOCHS_SEP}, Task 2 Max Epochs: {NUM_EPOCHS_CLS}")

GENRE_CLASSES = ['Electronic', 'Experimental', 'Folk', 'Hip-Hop', 'Instrumental', 'International', 'Pop', 'Rock'] # Should be these 8
NUM_GENRES = len(GENRE_CLASSES) # Should be 8
GENRE_MAP = {name: i for i, name in enumerate(GENRE_CLASSES)} # This line MUST be present

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"\nConfiguration Set:")
print(f" - Output directory: {OUTPUT_DIR}")
print(f" - Attempting FMA audio from: {MUSIC_DIR}")
print(f" - Attempting Ambience from: {AMBIENCE_DIR}")
print(f" - Attempting Metadata from: {FMA_METADATA_CSV}")

Using Audio Chunk Duration: 5s
Using Batch Size: 32
Task 1 Epochs: 30, Task 2 Max Epochs: 30

Configuration Set:
 - Output directory: /content/drive/MyDrive/cw1_DL/output
 - Attempting FMA audio from: /content/data/fma_small/fma_small
 - Attempting Ambience from: /content/data/esc50_audio/audio
 - Attempting Metadata from: /content/data/fma_metadata/fma_metadata/tracks.csv


In [None]:
!pip install mir_eval==0.7  # Install the mir_eval library, version 0.7 is usually suitable.

In [7]:
!pip install --upgrade mir_eval



3. Data handling

# 3 Data Handling

In [16]:
# Data Handling


#  Helper Functions

def load_audio(path, target_sr=SAMPLE_RATE):
    """Loads audio file, converts to mono, and resamples."""
    waveform = None; sr = None
    try:
        waveform_ta, sr_ta = torchaudio.load(path); waveform = waveform_ta; sr = sr_ta
    except Exception:
        try:

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                waveform_lr, sr_lr = librosa.load(path, sr=target_sr, mono=True);
            waveform = torch.from_numpy(waveform_lr).unsqueeze(0); sr = target_sr
        except Exception: return None
    if waveform is not None and waveform.shape[0] > 1: waveform = torch.mean(waveform, dim=0, keepdim=True)
    if sr is not None and sr != target_sr and waveform is not None: waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform)
    if waveform is None: return None
    return waveform.squeeze(0)

def create_mixture(music_wav, ambience_wav, target_snr_db):
    """Creates a mixture of music and ambience at a target SNR."""
    if not isinstance(music_wav, torch.Tensor): music_wav = torch.tensor(music_wav);
    if not isinstance(ambience_wav, torch.Tensor): ambience_wav = torch.tensor(ambience_wav);
    len_music = music_wav.shape[0]; len_ambience = ambience_wav.shape[0]
    if len_ambience < len_music: ambience_wav = ambience_wav.repeat(len_music // len_ambience + 1)[:len_music]
    elif len_ambience > len_music: start = np.random.randint(0, len_ambience - len_music + 1); ambience_wav = ambience_wav[start : start + len_music]
    power_music = torch.mean(music_wav**2); power_ambience = torch.mean(ambience_wav**2)
    if power_music < 1e-10 or power_ambience < 1e-10: return music_wav, music_wav
    snr_linear = 10**(target_snr_db / 10.0); target_power_ambience = power_music / snr_linear
    scaling_factor = torch.sqrt(target_power_ambience / (power_ambience + 1e-8)); scaled_ambience_wav = ambience_wav * scaling_factor
    mixture_wav = music_wav + scaled_ambience_wav; max_amp = torch.max(torch.abs(mixture_wav))
    if max_amp > 1.0: mixture_wav /= max_amp
    return mixture_wav, music_wav

def get_spectrogram(waveform, n_fft=N_FFT, hop_length=HOP_LENGTH, win_length=WIN_LENGTH, window=WINDOW):
    """Computes the STFT (magnitude and phase) of a waveform ON CPU."""

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        window_fn = torch.hann_window; stft_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, win_length=win_length, window_fn=window_fn, power=None, return_complex=True)

        spectrogram_complex = stft_transform(waveform.cpu());
    magnitude = torch.abs(spectrogram_complex); phase = torch.angle(spectrogram_complex); return magnitude, phase

def waveform_from_spectrogram(magnitude, phase, n_fft=N_FFT, hop_length=HOP_LENGTH, win_length=WIN_LENGTH, window=WINDOW):
    """Reconstructs waveform from magnitude and phase using iSTFT."""

    magnitude = magnitude.to(device); phase = phase.to(device); stft_complex = torch.polar(magnitude, phase); window_fn = torch.hann_window
    istft_transform = torchaudio.transforms.InverseSpectrogram(n_fft=n_fft, hop_length=hop_length, win_length=win_length, window_fn=window_fn).to(device)
    est_len = int(phase.shape[-1] * hop_length); waveform = istft_transform(stft_complex, length=est_len); return waveform


# Dataset Classes


class MusicAmbienceDataset(Dataset):
    def __init__(self, music_files, ambience_files, target_snr_db_range=(-5, 10), chunk_samples=AUDIO_CHUNK_SAMPLES, sr=SAMPLE_RATE):
        self.music_files = music_files; self.ambience_files = ambience_files; self.target_snr_db_range = target_snr_db_range;
        self.chunk_samples = chunk_samples; self.sr = sr;
        if not music_files or not ambience_files: raise ValueError("Music or ambience file list is empty!")
        print(f"Initialized MusicAmbienceDataset with {len(music_files)} music files.")

    def __len__(self): return len(self.music_files) * 3

    def __getitem__(self, idx):
        music_path = None
        ambience_path = None
        try:
            # Determine file paths BEFORE the main try block for better error reporting
            music_idx = idx % len(self.music_files)
            ambience_idx = random.randint(0, len(self.ambience_files) - 1)
            music_path = self.music_files[music_idx]
            ambience_path = self.ambience_files[ambience_idx]

            # Main processing
            music_wav = load_audio(music_path, target_sr=self.sr)

            if music_wav is None or len(music_wav) < self.chunk_samples:

                 return None

            ambience_wav = load_audio(ambience_path, target_sr=self.sr)
            if ambience_wav is None:

                return None

            #  Chunking and Mixing
            start_idx = random.randint(0, len(music_wav) - self.chunk_samples)
            music_chunk = music_wav[start_idx : start_idx + self.chunk_samples]
            target_snr = random.uniform(self.target_snr_db_range[0], self.target_snr_db_range[1])
            mixture_chunk, clean_music_chunk = create_mixture(music_chunk, ambience_wav, target_snr)

            #  Spectrograms
            mixture_mag, mixture_phase = get_spectrogram(mixture_chunk)
            clean_music_mag, _ = get_spectrogram(clean_music_chunk) # Uses CPU get_spectrogram

            # Mask Calculation
            mask = clean_music_mag / (mixture_mag + 1e-8)
            mask = torch.clamp(mask, 0.0, 1.0)


            return mixture_mag.cpu(), mask.cpu(), mixture_phase.cpu(), clean_music_chunk.cpu()


        except Exception as e:
            print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
            print(f"!!! ERROR in MusicAmbienceDataset __getitem__ (Worker Crashed?) !!!")
            print(f"!!! Index: {idx}")
            print(f"!!! Music Path: {music_path}")
            print(f"!!! Ambience Path: {ambience_path}")
            print(f"!!! Error Type: {type(e).__name__}")
            print(f"!!! Error Message: {e}")

            traceback.print_exc()
            print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

            return None



class GenreDataset(Dataset):
    def __init__(self, file_paths, labels, chunk_samples=AUDIO_CHUNK_SAMPLES, sr=SAMPLE_RATE, n_mels=128):
        self.file_paths = file_paths; self.labels = labels; self.chunk_samples = chunk_samples; self.sr = sr; self.n_mels = n_mels

        self.mel_spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=self.sr, n_fft=N_FFT, hop_length=HOP_LENGTH, win_length=WIN_LENGTH, n_mels=self.n_mels)
        self.log_mel_spec_transform = torchaudio.transforms.AmplitudeToDB()
        print(f"Initialized GenreDataset with {len(file_paths)} files.")

    def __len__(self): return len(self.file_paths)
    def __getitem__(self, idx):
        try:
            path = self.file_paths[idx]; label = self.labels[idx]; waveform = load_audio(path, target_sr=self.sr)
            if waveform is None or len(waveform) < self.chunk_samples: return None
            start_idx = random.randint(0, len(waveform) - self.chunk_samples); chunk = waveform[start_idx : start_idx + self.chunk_samples]

            mel_spec = self.mel_spectrogram_transform(chunk.cpu()); log_mel_spec = self.log_mel_spec_transform(mel_spec)
            return log_mel_spec, torch.tensor(label, dtype=torch.long)
        except Exception as e:

            return None

#  Define collate_fn
def collate_skip_none(batch):
    """Custom collate function that filters out None items from a batch."""
    filtered_batch = [item for item in batch if item is not None]
    if not filtered_batch: return None
    try: return torch.utils.data.default_collate(filtered_batch)
    except Exception as e:

        return None

# Data Preparation Logic

print("--- Data Preparation ---")

#  Step 1: Find initial lists of files
print("Scanning for initial file lists from local Colab storage...")

initial_music_files = glob.glob(os.path.join(MUSIC_DIR, '**', '*.mp3'), recursive=True)
all_ambience_files = glob.glob(os.path.join(AMBIENCE_DIR, '*.wav'), recursive=False)
print(f"Found {len(initial_music_files)} initial music files (Local).")
print(f"Found {len(all_ambience_files)} ambience files (Local).")
if not initial_music_files: print(f"!!! WARNING: Initial music file list empty. Check MUSIC_DIR path: {MUSIC_DIR}")
if not all_ambience_files: print(f"!!! WARNING: Ambience file list empty. Check AMBIENCE_DIR path: {AMBIENCE_DIR}")

# Step 2: Pre-filter problematic FMA files

print("\nPre-filtering FMA music files from local storage...")
valid_music_files = []
invalid_music_files_count = 0
start_filter_time = time.time()
for i, fpath in enumerate(initial_music_files):
    if (i + 1) % 1000 == 0:
        print(f"  Checked {i+1}/{len(initial_music_files)} files...")
    try:
        wav = load_audio(fpath)
        if wav is not None and len(wav) > AUDIO_CHUNK_SAMPLES:
            valid_music_files.append(fpath)
        else:
            invalid_music_files_count += 1
    except Exception as e_filter:
        invalid_music_files_count += 1
end_filter_time = time.time()
print(f"\nPre-filtering complete in {end_filter_time - start_filter_time:.2f} seconds.")
print(f"Found {len(valid_music_files)} valid music files.")
if invalid_music_files_count > 0:
    print(f"Excluded {invalid_music_files_count} problematic or short music files during pre-filtering.")

all_music_files = valid_music_files
random.shuffle(all_music_files)

# Step 3: Map Valid FMA Files to Genres

print("\nMapping valid FMA files to genres...")
all_genre_files = []
all_genre_labels = []
try:
    # Uses FMA_METADATA_CSV which points to nested local copy
    if not os.path.exists(FMA_METADATA_CSV): raise FileNotFoundError(f"Metadata CSV not found at {FMA_METADATA_CSV}")
    tracks_df = pd.read_csv(FMA_METADATA_CSV, index_col=0, header=[0, 1])
    small_subset_indices = tracks_df[tracks_df[('set', 'subset')] == 'small'].index
    print(f"Checking {len(small_subset_indices)} tracks from FMA small subset against valid files...")
    valid_music_files_dict = {os.path.basename(f): f for f in all_music_files}
    processed_count = 0
    for track_id in small_subset_indices:
        tid_str = f"{track_id:06d}"
        filename = f"{tid_str}.mp3"
        if filename in valid_music_files_dict:
            full_path = valid_music_files_dict[filename]
            track_info = tracks_df.loc[track_id]
            genre_name = track_info[('track', 'genre_top')]
            if genre_name in GENRE_MAP:
                all_genre_files.append(full_path)
                all_genre_labels.append(GENRE_MAP[genre_name])
                processed_count += 1
    print(f"Mapped {processed_count} valid small subset tracks to genres.")
except FileNotFoundError as e: print(f"!!! ERROR: {e}")
except ImportError: print("!!! ERROR: pandas library is required. `pip install pandas`")
except KeyError as e: print(f"!!! ERROR: Column name issue reading FMA metadata CSV (Maybe {e} ?).")
except Exception as e: print(f"Error during FMA genre mapping: {e}")

print(f"Total genre files prepared for Task 2: {len(all_genre_files)}")
if not all_genre_files: print("!!! WARNING: Genre file list empty after filtering/mapping.")


#  Step 4: Data Splitting

print("\nSplitting data...")
music_train_files, music_val_files, music_test_files = [], [], []
genre_files_train, genre_files_val, genre_files_test = [], [], []
genre_labels_train, genre_labels_val, genre_labels_test = [], [], []

if all_music_files:
    music_train_val, music_test_files = train_test_split(all_music_files, test_size=0.15, random_state=42)
    if len(music_train_val) > 1:
        music_train_files, music_val_files = train_test_split(music_train_val, test_size=0.1765, random_state=42)
    else: music_train_files = music_train_val

if all_genre_files and len(all_genre_files) == len(all_genre_labels):
    unique_labels, counts = np.unique(all_genre_labels, return_counts=True)
    min_samples_for_stratify = 3
    can_stratify = counts.min() >= min_samples_for_stratify if len(counts) > 0 else False
    stratify_opt = all_genre_labels if can_stratify else None
    if not can_stratify: print("Warning: Cannot stratify genre split (min samples per class < 3).")

    genre_files_train_val, genre_files_test, genre_labels_train_val, genre_labels_test = train_test_split(
        all_genre_files, all_genre_labels, test_size=0.15, stratify=stratify_opt, random_state=42)
    stratify_opt_2 = genre_labels_train_val if can_stratify else None
    if len(genre_files_train_val) > 1:
         genre_files_train, genre_files_val, genre_labels_train, genre_labels_val = train_test_split(
             genre_files_train_val, genre_labels_train_val, test_size=0.1765, stratify=stratify_opt_2, random_state=42)
    else: genre_files_train, genre_labels_train = genre_files_train_val, genre_labels_train_val
else:
    print("Warning: Cannot split genre data.")

print(f"Task 1 Splits: Train={len(music_train_files)}, Val={len(music_val_files)}, Test={len(music_test_files)}")
print(f"Task 2 Splits: Train={len(genre_files_train)}, Val={len(genre_files_val)}, Test={len(genre_files_test)}")


#  Step 5: Create Datasets and DataLoaders
print("\nCreating Datasets and DataLoaders...")

num_workers_flag = 2

pin_memory_flag = True if device.type == 'cuda' else False

train_sep_dataset = MusicAmbienceDataset(music_train_files, all_ambience_files) if music_train_files and all_ambience_files else None
val_sep_dataset = MusicAmbienceDataset(music_val_files, all_ambience_files) if music_val_files and all_ambience_files else None
test_sep_dataset = MusicAmbienceDataset(music_test_files, all_ambience_files) if music_test_files and all_ambience_files else None

train_genre_dataset = GenreDataset(genre_files_train, genre_labels_train) if genre_files_train else None
val_genre_dataset = GenreDataset(genre_files_val, genre_labels_val) if genre_files_val else None
test_genre_dataset = GenreDataset(genre_files_test, genre_labels_test) if genre_files_test else None


train_sep_loader = DataLoader(train_sep_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers_flag, collate_fn=collate_skip_none, pin_memory=pin_memory_flag, persistent_workers=True if num_workers_flag > 0 else False) if train_sep_dataset else None
val_sep_loader = DataLoader(val_sep_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers_flag, collate_fn=collate_skip_none, pin_memory=pin_memory_flag, persistent_workers=True if num_workers_flag > 0 else False) if val_sep_dataset else None
test_sep_loader = DataLoader(test_sep_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers_flag, collate_fn=collate_skip_none, pin_memory=pin_memory_flag, persistent_workers=True if num_workers_flag > 0 else False) if test_sep_dataset else None

train_genre_loader = DataLoader(train_genre_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers_flag, collate_fn=collate_skip_none, pin_memory=pin_memory_flag, persistent_workers=True if num_workers_flag > 0 else False) if train_genre_dataset else None
val_genre_loader = DataLoader(val_genre_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers_flag, collate_fn=collate_skip_none, pin_memory=pin_memory_flag, persistent_workers=True if num_workers_flag > 0 else False) if val_genre_dataset else None
test_genre_loader = DataLoader(test_genre_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers_flag, collate_fn=collate_skip_none, pin_memory=pin_memory_flag, persistent_workers=True if num_workers_flag > 0 else False) if test_genre_dataset else None

print(f"Train Sep Loader: {'Created' if train_sep_loader else 'Not Created'} (Workers: {num_workers_flag})")
print(f"Val Sep Loader: {'Created' if val_sep_loader else 'Not Created'} (Workers: {num_workers_flag})")
print(f"Test Sep Loader: {'Created' if test_sep_loader else 'Not Created'} (Workers: {num_workers_flag})")
print(f"Train Genre Loader: {'Created' if train_genre_loader else 'Not Created'} (Workers: {num_workers_flag})")
print(f"Val Genre Loader: {'Created' if val_genre_loader else 'Not Created'} (Workers: {num_workers_flag})")
print(f"Test Genre Loader: {'Created' if test_genre_loader else 'Not Created'} (Workers: {num_workers_flag})")
print("Data preparation section finished.")

--- Data Preparation ---
Scanning for initial file lists from local Colab storage...
Found 8000 initial music files (Local).
Found 2000 ambience files (Local).

Pre-filtering FMA music files from local storage...
  Checked 1000/8000 files...
  Checked 2000/8000 files...
  Checked 3000/8000 files...
  Checked 4000/8000 files...
  Checked 5000/8000 files...
  Checked 6000/8000 files...
  Checked 7000/8000 files...
  Checked 8000/8000 files...

Pre-filtering complete in 528.87 seconds.
Found 7994 valid music files.
Excluded 6 problematic or short music files during pre-filtering.

Mapping valid FMA files to genres...
Checking 8000 tracks from FMA small subset against valid files...
Mapped 7994 valid small subset tracks to genres.
Total genre files prepared for Task 2: 7994

Splitting data...
Task 1 Splits: Train=5594, Val=1200, Test=1200
Task 2 Splits: Train=5594, Val=1200, Test=1200

Creating Datasets and DataLoaders...
Initialized MusicAmbienceDataset with 5594 music files.
Initialized 

 # 4. Model Definitions

In [4]:
# Define the deep learning models for separation (Task 1) and classification (Task 2).



import torch
import torch.nn as nn
import torchvision.models as models # For pre-trained models

# Task 1: Separation (U-Net)

class ConvBlock(nn.Module):
    """(Conv2d => BN => LeakyReLU) * 2"""
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, x):
        return self.block(x)

class EncoderBlock(nn.Module):
    """ConvBlock followed by MaxPool"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_block = ConvBlock(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        skip = self.conv_block(x)
        pooled = self.pool(skip)
        return skip, pooled

class DecoderBlock(nn.Module):
    """UpConv + Concatenate + ConvBlock + Optional Dropout"""
    def __init__(self, in_channels, skip_channels, out_channels, dropout_p=0.5):
        super().__init__()

        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)


        self.conv_block = ConvBlock(out_channels + skip_channels, out_channels)

        self.dropout = nn.Dropout2d(p=dropout_p) if dropout_p > 0 else nn.Identity() # Identity if no dropout

    def _match_height_width(self, up_x, skip_x):
        """Helper to crop skip connection if needed after up-convolution."""



        h_diff = skip_x.shape[2] - up_x.shape[2]
        w_diff = skip_x.shape[3] - up_x.shape[3]
        if h_diff > 0 or w_diff > 0:


             crop_h_start = h_diff // 2
             crop_w_start = w_diff // 2
             skip_x = skip_x[:, :, crop_h_start:crop_h_start + up_x.shape[2], crop_w_start:crop_w_start + up_x.shape[3]]
        elif h_diff < 0 or w_diff < 0:

             up_x = nn.functional.pad(up_x, (-w_diff // 2, w_diff - (-w_diff // 2),
                                            -h_diff // 2, h_diff - (-h_diff // 2)))

        return up_x, skip_x

    def forward(self, x, skip):


        up_x = self.upconv(x)
        # Match spatial dimensions before concatenation
        up_x_matched, skip_matched = self._match_height_width(up_x, skip)
        # Concatenate along the channel dimension
        concat_x = torch.cat([up_x_matched, skip_matched], dim=1)
        # Apply dropout
        concat_x = self.dropout(concat_x)
        # Apply convolutional block
        out = self.conv_block(concat_x)
        return out

class SeparationUNet(nn.Module):
    """More standard U-Net for audio source separation (mask estimation)."""
    def __init__(self, n_channels=1, n_classes=1, features=[16, 32, 64, 128, 256]):


        super().__init__()
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()

        # Encoder Path
        in_ch = n_channels
        for feature in features:
            self.encoders.append(EncoderBlock(in_ch, feature))
            in_ch = feature
        print(f"U-Net Encoder features: {features}")

        # Bottleneck
        bottleneck_features = features[-1] * 2
        self.bottleneck = ConvBlock(features[-1], bottleneck_features)
        print(f"U-Net Bottleneck features: {bottleneck_features}")

        # Decoder Path

        decoder_features = features[::-1]
        in_ch_decode = bottleneck_features
        for i in range(len(decoder_features)):
            skip_ch = decoder_features[i]
            out_ch = decoder_features[i]
            self.decoders.append(DecoderBlock(in_ch_decode, skip_ch, out_ch))
            in_ch_decode = out_ch
        print(f"U-Net Decoder features (output channels per block): {decoder_features}")

        # Final Output Convolution

        self.final_conv = nn.Conv2d(features[0], n_classes, kernel_size=1)
        # Sigmoid activation to ensure mask output is between 0 and 1
        self.final_activation = nn.Sigmoid()

    def forward(self, x):

        skip_connections = []

        #  Encoder
        for encoder in self.encoders:
            skip, x = encoder(x)
            skip_connections.append(skip)


        # Bottleneck
        x = self.bottleneck(x)


        #  Decoder

        skip_connections = skip_connections[::-1] # Reverse list
        for i in range(len(self.decoders)):
            skip = skip_connections[i]
            x = self.decoders[i](x, skip)


        # --- Final Output ---
        mask = self.final_conv(x)
        mask = self.final_activation(mask)


        return mask


# Task 2: Classification Model
# Using a pre-trained ResNet18 adapted for audio

def get_pretrained_genre_classifier(num_genres=NUM_GENRES, pretrained=True):
    """Loads a pre-trained ResNet18 model and adapts it for genre classification."""
    print(f"Loading {'pre-trained' if pretrained else 'randomly initialized'} ResNet18 model...")
    weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    model = models.resnet18(weights=weights)

    # Adapt Input Layer
    original_conv1 = model.conv1
    model.conv1 = nn.Conv2d(1, original_conv1.out_channels, kernel_size=original_conv1.kernel_size,
                            stride=original_conv1.stride, padding=original_conv1.padding, bias=False)
    print("Adapted model.conv1 to accept 1 input channel.")
    if pretrained:
        try:
             original_weights_tensor = original_conv1.weight.data
        except AttributeError:
             original_weights_tensor = weights.state_dict()['conv1.weight']
        model.conv1.weight.data = torch.mean(original_weights_tensor, dim=1, keepdim=True)
        print("Initialized new conv1 weights by averaging original RGB weights.")

    # Adapt Output Layer
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_genres)
    print(f"Adapted model.fc to output {num_genres} classes.")
    return model

print("Model definitions updated: Using enhanced U-Net for Task 1.")


Model definitions updated: Using enhanced U-Net for Task 1.


# 5. Training Loop

In [5]:
# Training Loop Definitions
# Includes interpolation for shape matching

from tqdm.notebook import tqdm # Progress bars
import torch
import torch.nn.functional as F # For interpolate
from sklearn.metrics import accuracy_score, f1_score # Evaluation metrics

# --- Function: train_epoch ---
def train_epoch(model, dataloader, criterion, optimizer, device, task_type='separation', epoch_num=0, total_epochs=1):
    """Runs a single training epoch with a tqdm progress bar."""
    model.train()
    total_loss = 0.0
    num_items = 0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch_num}/{total_epochs} [{task_type.capitalize()} Train]", leave=False)

    for batch_idx, batch in enumerate(progress_bar):
        if batch is None: continue

        if task_type == 'separation':
             if len(batch) < 2: continue
             inputs, targets = batch[0], batch[1] # mixture_mag, mask
        elif task_type == 'classification':
             if len(batch) < 2: continue
             inputs, targets = batch[0], batch[1] # log_mel_spec, label
        else: raise ValueError("Invalid task_type for training")

        if inputs is None or targets is None or inputs.nelement() == 0 or targets.nelement() == 0: continue

        inputs, targets = inputs.to(device), targets.to(device)

        if inputs.dim() == 3: inputs = inputs.unsqueeze(1)
        if task_type == 'separation' and targets.dim() == 3: targets = targets.unsqueeze(1)

        optimizer.zero_grad()
        outputs = model(inputs)

        # Interpolate output shape to match target
        if task_type == 'separation' and outputs.shape != targets.shape:
            # Get target spatial dimensions (Height = Freq, Width = Time)
            target_h, target_w = targets.shape[2], targets.shape[3]
            # Resize output to match target size using bilinear interpolation
            outputs = F.interpolate(outputs, size=(target_h, target_w), mode='bilinear', align_corners=False)
            if batch_idx == 0 and epoch_num == 1: # Print message only once
                 print(f"Interpolated output shape to match target: {outputs.shape}")
        # End interpolation

        # Now calculate loss - shapes should match for separation task
        loss = criterion(outputs, targets)

        if batch_idx == 0 and epoch_num == 1: # Print loss for first batch
            print(f"Loss value (first batch): {loss.item():.6f}")
            # Note: Removed the larger debug block here

        if torch.isnan(loss):
            print(f"Warning: NaN loss detected at batch {batch_idx}. Skipping batch.")
            continue
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_size = inputs.size(0)
        total_loss += batch_loss * batch_size
        num_items += batch_size

        if num_items > 0:
            progress_bar.set_postfix(loss=f"{(total_loss / num_items):.4f}")

    avg_loss = total_loss / num_items if num_items > 0 else 0
    return avg_loss


# --- Function: validate_epoch ---
def validate_epoch(model, dataloader, criterion, device, task_type='separation', epoch_num=0, total_epochs=1):
    """Runs a single validation epoch with a tqdm progress bar."""
    model.eval()
    total_loss = 0.0
    num_items = 0
    all_preds = []
    all_targets = []
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch_num}/{total_epochs} [{task_type.capitalize()} Val]", leave=False)

    with torch.no_grad():
        for batch_idx, batch in enumerate(progress_bar):
            if batch is None: continue

            if task_type == 'separation':
                 if len(batch) < 2: continue
                 inputs, targets = batch[0], batch[1] # mixture_mag, mask
            elif task_type == 'classification':
                 if len(batch) < 2: continue
                 inputs, targets = batch[0], batch[1] # log_mel_spec, label
            else: raise ValueError("Invalid task_type for validation")

            if inputs is None or targets is None or inputs.nelement() == 0 or targets.nelement() == 0: continue

            inputs, targets = inputs.to(device), targets.to(device)
            if inputs.dim() == 3: inputs = inputs.unsqueeze(1)
            if task_type == 'separation' and targets.dim() == 3: targets = targets.unsqueeze(1)

            outputs = model(inputs)

            # Interpolate output shape to match target
            if task_type == 'separation' and outputs.shape != targets.shape:
                target_h, target_w = targets.shape[2], targets.shape[3]
                outputs = F.interpolate(outputs, size=(target_h, target_w), mode='bilinear', align_corners=False)
            # End interpolation

            loss = criterion(outputs, targets)
            if torch.isnan(loss):
                print(f"Warning: NaN validation loss detected at batch {batch_idx}. Skipping batch.")
                continue

            total_loss += loss.item() * inputs.size(0)
            num_items += inputs.size(0)

            if task_type == 'classification':
                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

            if num_items > 0:
                 progress_bar.set_postfix(loss=f"{(total_loss / num_items):.4f}")

    avg_loss = total_loss / num_items if num_items > 0 else 0
    val_metrics = {'loss': avg_loss}

    if task_type == 'classification' and all_targets:
        accuracy = accuracy_score(all_targets, all_preds)
        f1 = f1_score(all_targets, all_preds, average='weighted', zero_division=0)
        val_metrics['accuracy'] = accuracy
        val_metrics['f1'] = f1

    return val_metrics

# 6. Evaluation Functions

In [6]:
# evaluation functions


try:
    import mir_eval
    MIR_EVAL_AVAILABLE = True
except ImportError:
    print("WARNING: mir_eval not installed. SDR calculation will be skipped.")
    MIR_EVAL_AVAILABLE = False

# function: calculate_sdr
def calculate_sdr(estimated_source_wav, true_source_wav):
    """Calculates SDR using mir_eval (if available). Attempts float64 casting."""
    if not MIR_EVAL_AVAILABLE:
        return 0.0 # return 0 if mir_eval not installed
    try:
        # ensure inputs are numpy arrays on cpu
        if not isinstance(estimated_source_wav, np.ndarray): estimated_source_wav = estimated_source_wav.cpu().numpy()
        if not isinstance(true_source_wav, np.ndarray): true_source_wav = true_source_wav.cpu().numpy()

        # ensure inputs are 1d
        if estimated_source_wav.ndim > 1: estimated_source_wav = estimated_source_wav.squeeze()
        if true_source_wav.ndim > 1: true_source_wav = true_source_wav.squeeze()
        if estimated_source_wav.ndim != 1 or true_source_wav.ndim != 1:
             return 0.0 # return 0 if dimensions are wrong

        min_len = min(len(estimated_source_wav), len(true_source_wav))
        if min_len == 0: return 0.0 # avoid empty arrays

        # cast to float64 and add epsilon
        est = estimated_source_wav[:min_len].astype(np.float64)
        true = true_source_wav[:min_len].astype(np.float64)
        # add tiny noise to prevent potential perfect silence issues in calculation
        epsilon = 1e-10
        est += np.random.randn(est.shape[0]) * epsilon
        true += np.random.randn(true.shape[0]) * epsilon
        # end modification

        # add check for silence or near-silence in ground truth after adding epsilon
        if np.sum(np.abs(true)) < 1e-8:
            return 0.0

        # mir_eval expects shape (n_sources, n_samples)
        sdr_value, _, _, _ = mir_eval.separation.bss_eval_sources(true[np.newaxis, :], est[np.newaxis, :])

        # return 0 if mir_eval calculation itself fails in a way that produces nan
        if np.isnan(sdr_value[0]):
             return 0.0 # or return np.nan if you want to handle it differently later

        return sdr_value[0]

    except ValueError as ve:
        # error during sdr calculation (valueerror)
        return 0.0
    except Exception as e:
        # unexpected error during sdr calculation
        traceback.print_exc() # print full error details
        return 0.0

# function: evaluate_separation_model
def evaluate_separation_model(model, test_loader, device):
    """Evaluates the separation model on the test set using SDR."""
    model.eval()
    total_sdr = 0.0
    valid_sdr_count = 0 # use a separate counter for valid sdr scores
    processed_items = 0
    print("Evaluating separation model on test set...")
    if not test_loader:
         print("Test separation loader not available. Skipping evaluation.")
         return 0.0

    progress_bar = tqdm(test_loader, desc="Separation Eval", leave=False)

    with torch.no_grad():
        for batch_idx, batch in enumerate(progress_bar):
            if batch is None or len(batch) < 4: continue
            mixture_mag, _, mixture_phase, clean_music_chunk = batch
            if mixture_mag is None or mixture_phase is None or clean_music_chunk is None: continue

            mixture_mag_dev = mixture_mag.to(device)
            mixture_phase_dev = mixture_phase.to(device)

            if mixture_mag_dev.dim() == 3: mixture_mag_dev = mixture_mag_dev.unsqueeze(1)

            predicted_mask = model(mixture_mag_dev)

            if predicted_mask.shape != mixture_mag_dev.shape:
                target_h, target_w = mixture_mag_dev.shape[2], mixture_mag_dev.shape[3]
                predicted_mask = F.interpolate(predicted_mask, size=(target_h, target_w), mode='bilinear', align_corners=False)

            for i in range(mixture_mag_dev.size(0)):
                processed_items += 1
                input_mag_i = mixture_mag_dev[i].squeeze(0)
                mask_i = predicted_mask[i].squeeze(0)
                phase_i = mixture_phase_dev[i].squeeze(0)

                if input_mag_i.shape != mask_i.shape: continue # skip if shapes mismatch

                est_mag_i = input_mag_i * mask_i
                true_wav_i = clean_music_chunk[i].cpu()

                try:
                    # assuming waveform_from_spectrogram is defined elsewhere
                    est_wav_i = waveform_from_spectrogram(est_mag_i, phase_i)
                    if est_wav_i is None:
                         continue # skip if reconstruction fails

                    est_wav_i = est_wav_i.cpu()
                    # ensure estimated waveform length matches true waveform length
                    est_wav_i = est_wav_i[:len(true_wav_i)]

                    sdr = calculate_sdr(est_wav_i, true_wav_i) # calculate sdr using modified function

                    # only accumulate and count non-zero and non-nan sdrs
                    if not np.isnan(sdr) and sdr != 0.0:
                        total_sdr += sdr
                        valid_sdr_count += 1
                    # else: # optional: track zero/nan sdrs
                        # pass

                except Exception as recon_err:
                    print(f"Error during waveform reconstruction or SDR for item {i} in batch {batch_idx}: {recon_err}")
                    traceback.print_exc() # print full error details
                    continue

            # update progress bar description
            if valid_sdr_count > 0:
                progress_bar.set_description(f"Separation Eval (Avg SDR: {total_sdr / valid_sdr_count:.2f})")
            else:
                progress_bar.set_description(f"Separation Eval (Avg SDR: N/A)")


    if valid_sdr_count == 0:
        print(f"--- No valid (non-zero, non-NaN) SDR scores calculated out of {processed_items} items processed. Check for errors above. ---")
        return 0.0

    avg_sdr = total_sdr / valid_sdr_count
    print(f"\n--- Average SDR on Test Set (from {valid_sdr_count} valid scores): {avg_sdr:.4f} ---")
    return avg_sdr


# function: evaluate_classification_model
def evaluate_classification_model(model, test_loader, device):
    """Evaluates the classification model on the test set."""
    model.eval()
    all_preds = []
    all_targets = []
    print("Evaluating classification model on test set...")
    if not test_loader:
         print("Test genre loader not available. Skipping evaluation.")
         return {}
    progress_bar = tqdm(test_loader, desc="Classification Eval", leave=False)
    with torch.no_grad():
        for batch_idx, batch in enumerate(progress_bar):
            if batch is None or len(batch) < 2: continue
            inputs, targets = batch[0], batch[1] # log_mel_spec, label
            if inputs is None or targets is None or inputs.nelement() == 0 or targets.nelement() == 0: continue
            inputs, targets = inputs.to(device), targets.to(device)
            if inputs.dim() == 3: inputs = inputs.unsqueeze(1)
            outputs = model(inputs)
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
    if not all_targets: print("No targets found for classification evaluation."); return {}
    # assuming GENRE_CLASSES and NUM_GENRES are defined globally or passed differently
    target_names_cls = GENRE_CLASSES if 'GENRE_CLASSES' in locals() else [str(i) for i in range(NUM_GENRES)]
    present_labels = sorted(np.unique(all_targets + all_preds))
    valid_present_labels = [l for l in present_labels if l >= 0 and l < len(target_names_cls)]
    filtered_target_names = [target_names_cls[i] for i in valid_present_labels]
    accuracy = accuracy_score(all_targets, all_preds)
    f1 = f1_score(all_targets, all_preds, average='weighted', zero_division=0)
    print("\n--- Classification Report ---")
    if not filtered_target_names: print("Cannot generate classification report: No valid target names found."); report = "N/A"; conf_matrix = "N/A"
    else:
         report = classification_report(all_targets, all_preds, labels=valid_present_labels, target_names=filtered_target_names, zero_division=0); print(report)
         try:
             print("\n--- Confusion Matrix ---"); cm = confusion_matrix(all_targets, all_preds, labels=valid_present_labels); print(f"Labels: {filtered_target_names}"); print(cm); conf_matrix = cm
         except Exception as cm_err: print(f"Could not generate confusion matrix: {cm_err}"); conf_matrix = "N/A"
    print(f"--- Overall Accuracy: {accuracy:.4f} ---"); print(f"--- Weighted F1-Score: {f1:.4f} ---")
    return {"accuracy": accuracy, "f1": f1, "report": report, "conf_matrix": conf_matrix}



# 7. Combination Logic

In [8]:
# combination logic
# defines function to separate then classify audio
# includes interpolation fix for shape mismatch
# ensure helpers (load_audio, get_spectrogram, waveform_from_spectrogram)
# and variables (sample_rate, audio_chunk_samples, n_fft, hop_length, device, num_genres, genre_classes)
# are defined and accessible

# function: separate_and_classify
def separate_and_classify(mixture_waveform_path, separation_model, classification_model, device):
    """Applies separation then classification (Sequential Approach) to an audio file path."""
    # set models to evaluation mode
    separation_model.eval()
    classification_model.eval()

    try:
        # load mixture waveform
        # assumes load_audio defined (returns cpu tensor)
        mixture_waveform = load_audio(mixture_waveform_path, target_sr=SAMPLE_RATE)
        if mixture_waveform is None:
            print(f"Error: Failed to load audio file {mixture_waveform_path}")
            return None, None

        with torch.no_grad():
            # separation step
            # 1. get spectrogram (cpu)
            # assumes get_spectrogram defined
            mixture_mag_cpu, mixture_phase_cpu = get_spectrogram(mixture_waveform)

            # 2. prepare for model (add dims, move to device)
            mixture_mag_dev = mixture_mag_cpu.unsqueeze(0).unsqueeze(0).to(device)

            # 3. predict mask
            predicted_mask = separation_model(mixture_mag_dev) # output on device

            # interpolate mask to match input shape
            if predicted_mask.shape != mixture_mag_dev.shape:
                 target_h, target_w = mixture_mag_dev.shape[2], mixture_mag_dev.shape[3] # get target shape from input mag
                 predicted_mask = F.interpolate(predicted_mask, size=(target_h, target_w), mode='bilinear', align_corners=False)
            # end interpolation

            # remove batch/channel dims for reconstruction
            input_mag_i = mixture_mag_dev.squeeze(0).squeeze(0) # shape [freq, time] on device
            mask_i = predicted_mask.squeeze(0).squeeze(0)      # shape [freq, time] on device

            if input_mag_i.shape != mask_i.shape:
                 print(f"Error: Shape mismatch even after interpolation! Input: {input_mag_i.shape}, Mask: {mask_i.shape}")
                 return None, None # cannot proceed

            # 4. apply mask
            estimated_music_mag = input_mag_i * mask_i # result on device

            # 5. reconstruct waveform
            # assumes waveform_from_spectrogram defined
            separated_music_waveform = waveform_from_spectrogram(estimated_music_mag, mixture_phase_cpu.to(device))
            separated_music_waveform = separated_music_waveform.cpu() # move result to cpu
            separated_music_waveform = separated_music_waveform[:mixture_waveform.shape[0]] # trim length


            # classification step (on separated waveform)
            required_samples = AUDIO_CHUNK_SAMPLES
            if len(separated_music_waveform) >= required_samples:
                 start = (len(separated_music_waveform) - required_samples) // 2
                 class_chunk = separated_music_waveform[start : start + required_samples]
            else:
                 padding = required_samples - len(separated_music_waveform)
                 class_chunk = torch.nn.functional.pad(separated_music_waveform, (0, padding))

            # extract features (mel spectrogram)
            # ensure config variables (sample_rate, n_fft, hop_length) are available
            mel_spectrogram_transform = torchaudio.transforms.MelSpectrogram(
                sample_rate=SAMPLE_RATE, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=128 # assuming n_mels=128
            ).to(device)
            log_mel_spec_transform = torchaudio.transforms.AmplitudeToDB().to(device)

            mel_spec = mel_spectrogram_transform(class_chunk.to(device))
            log_mel_spec = log_mel_spec_transform(mel_spec)
            log_mel_spec = log_mel_spec.unsqueeze(0).unsqueeze(0) # add batch/channel dims

            # classify
            logits = classification_model(log_mel_spec)
            prediction = torch.argmax(logits, dim=1).item()

        # return prediction index and separated waveform (cpu)
        return prediction, separated_music_waveform

    except Exception as e:
        print(f"Error during separate_and_classify for {mixture_waveform_path}: {e}")
        traceback.print_exc() # print full traceback
        return None, None

# 8. Training execution function

In [18]:
# main execution block for training
# trains task 1 fresh (5s chunks, bs=32, lr scheduler)
# trains task 2 (weight decay, early stopping)

import time
import torch
import torch.nn as nn
import torch.optim as optim
# lr scheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt
import numpy as np
import os
import soundfile as sf
from sklearn.metrics import classification_report, accuracy_score, f1_score
import traceback # for detailed errors

# record start time
overall_start_time = time.time()

print("=============================================")
print("=== Starting Project Execution (Final Run Attempt) ===")
# ensure config variables defined (from cell [27])
# check existence using locals()/globals()
config_vars_ok = all(var in locals() or var in globals() for var in ['device', 'AUDIO_CHUNK_DURATION_S', 'BATCH_SIZE', 'NUM_EPOCHS_SEP', 'NUM_EPOCHS_CLS', 'LEARNING_RATE_SEP', 'LEARNING_RATE_CLS', 'OUTPUT_DIR', 'NUM_GENRES', 'GENRE_CLASSES', 'SAMPLE_RATE', 'N_FFT', 'HOP_LENGTH'])
if not config_vars_ok:
     print("!!! ERROR: Essential config variables missing.")
     print("!!! Run Configuration Cell [27] first.")
     raise NameError("Configuration variables missing.")
else:
    print(f"=== Using Device: {device} ===")
    print(f"=== Target Separation Chunks: {AUDIO_CHUNK_DURATION_S}s ===")
    print(f"=== Target Batch Size: {BATCH_SIZE} ===")
print("=============================================")

# configuration: set to true to run task 1 training
TRAIN_TASK_1 = True

# initialize models
print("\nInitializing models...")
try:
    # assumes model classes defined
    separation_model = SeparationUNet().to(device)
    classification_model = get_pretrained_genre_classifier(num_genres=NUM_GENRES, pretrained=True).to(device)
    print("Models initialized.")
except NameError as ne:
    print(f"!!! ERROR: Model class definition not found ({ne}). Run Cell [29].")
    raise

# set up loss functions and optimizers
print("\nSetting up loss functions and optimizers...")
try:
    # assumes learning rates defined (in cell [27])
    criterion_sep = nn.L1Loss()
    criterion_cls = nn.CrossEntropyLoss()

    # task 1 optimizer & lr scheduler
    optimizer_sep = optim.Adam(separation_model.parameters(), lr=LEARNING_RATE_SEP)
    scheduler_sep = ReduceLROnPlateau(optimizer_sep, mode='min', factor=0.1, patience=5, verbose=True, min_lr=1e-7)
    print(f"Optimizer_sep: Adam, LR={LEARNING_RATE_SEP}, Scheduler=ReduceLROnPlateau(patience=5)")

    # task 2 optimizer (with weight decay)
    weight_decay_cls = 1e-4
    optimizer_cls = optim.Adam(classification_model.parameters(), lr=LEARNING_RATE_CLS, weight_decay=weight_decay_cls)
    print(f"Optimizer_cls: LR={LEARNING_RATE_CLS}, Weight Decay={weight_decay_cls}")

except Exception as e:
     print(f"!!! ERROR setting up loss/optimizers: {e}")
     raise

# check if dataloaders created
# assumes cell [34] (data handling) was run after cell [27]
loaders_available = all([
    'train_sep_loader' in locals() and train_sep_loader is not None,
    'val_sep_loader' in locals() and val_sep_loader is not None,
    'test_sep_loader' in locals() and test_sep_loader is not None,
    'train_genre_loader' in locals() and train_genre_loader is not None,
    'val_genre_loader' in locals() and val_genre_loader is not None,
    'test_genre_loader' in locals() and test_genre_loader is not None
])

if not loaders_available:
    print("\n!!! ERROR: DataLoaders not available. Run Cell [34] after updating Config Cell [27].")
    raise ValueError("DataLoaders missing, cannot proceed.")
elif 'separation_model' in locals() and 'classification_model' in locals():

    # training phase: task 1 (separation - 5s chunks)
    if TRAIN_TASK_1:
        print("\n--- Training Separation Model (Task 1 - 5s Chunks) ---")

        # starting fresh: delete old checkpoint if it exists
        best_sep_model_path = os.path.join(OUTPUT_DIR, "separation_model_best.pth")
        if os.path.exists(best_sep_model_path):
            print(f"Deleting previous checkpoint to start fresh: {best_sep_model_path}")
            try: os.remove(best_sep_model_path)
            except Exception as e_del: print(f"Warning: Could not delete {best_sep_model_path}: {e_del}")
        else:
            print("No previous checkpoint found, starting fresh.")

        start_epoch_sep = 0
        best_sep_val_loss = float('inf')
        epochs_to_run_sep = NUM_EPOCHS_SEP # use value from cell [27]

        print(f"Starting training for {epochs_to_run_sep} epochs...")
        sep_train_losses = []
        sep_val_losses = []

        for epoch in range(epochs_to_run_sep):
            current_epoch_display = start_epoch_sep + epoch + 1
            print(f"\nEpoch {current_epoch_display}/{epochs_to_run_sep}")
            epoch_start_time = time.time()
            try:
                # assumes train/validate_epoch defined (cell [30])
                train_loss = train_epoch(separation_model, train_sep_loader, criterion_sep, optimizer_sep, device, task_type='separation', epoch_num=current_epoch_display, total_epochs=epochs_to_run_sep)
                val_metrics = validate_epoch(separation_model, val_sep_loader, criterion_sep, device, task_type='separation', epoch_num=current_epoch_display, total_epochs=epochs_to_run_sep)
                val_loss = val_metrics['loss']

                if np.isnan(train_loss) or np.isnan(val_loss): print(f"!!! ERROR: NaN loss detected in Epoch {current_epoch_display}. Stopping."); break

                sep_train_losses.append(train_loss); sep_val_losses.append(val_loss)
                epoch_end_time = time.time()
                print(f"  Epoch {current_epoch_display} Summary: Train Loss={train_loss:.6f}, Val Loss={val_loss:.6f}, Time={epoch_end_time - epoch_start_time:.2f}s")

                if val_loss < best_sep_val_loss:
                    best_sep_val_loss = val_loss
                    save_path = os.path.join(OUTPUT_DIR, "separation_model_best.pth")
                    try: torch.save(separation_model.state_dict(), save_path); print(f"  Saved best separation model (Val Loss: {best_sep_val_loss:.6f})")
                    except Exception as e: print(f"  Error saving model: {e}")

                scheduler_sep.step(val_loss) # step the lr scheduler

            except NameError as e_func: print(f"!!! ERROR: Training function not defined ({e_func}). Run Cell [30]."); break
            except RuntimeError as e_rt: print(f"!!! RUNTIME ERROR in epoch {current_epoch_display}: {e_rt}"); traceback.print_exc(); print("Try reducing BATCH_SIZE further (e.g., 16) in Cell [27] if this is OOM."); break
            except Exception as e_gen: print(f"!!! UNEXPECTED ERROR in epoch {current_epoch_display}: {e_gen}"); traceback.print_exc(); break

        print("\nSeparation Model Training Finished.")
        if sep_train_losses: # plot task 1 results
            plt.figure(figsize=(10, 5))
            plt.plot(range(1, len(sep_train_losses) + 1), sep_train_losses, label='Train Loss')
            plt.plot(range(1, len(sep_val_losses) + 1), sep_val_losses, label='Validation Loss')
            plt.xlabel('Epoch'); plt.ylabel('Loss (L1)'); plt.title('Task 1: Separation Model Training Curve (5s Chunks)')
            plt.legend(); plt.grid(True)
            plot_save_path = os.path.join(OUTPUT_DIR, 'separation_training_curve_5s.png')
            plt.savefig(plot_save_path); print(f"Saved training curve plot to {plot_save_path}")

    # load best separation model checkpoint
    best_sep_model_path = os.path.join(OUTPUT_DIR, "separation_model_best.pth")
    if os.path.exists(best_sep_model_path):
        print(f"\nLoading best separation model for evaluation from: {best_sep_model_path}")
        try: separation_model.load_state_dict(torch.load(best_sep_model_path, map_location=device)); separation_model.eval(); print("Successfully loaded BEST separation model checkpoint.")
        except Exception as e: print(f"!!! ERROR loading final best separation model checkpoint: {e} !!!"); separation_model.eval()
    else: print(f"\n!!! WARNING: Best separation model checkpoint not found. Using potentially untrained model! !!!"); separation_model.eval()

    # training classification model (task 2 - weight decay & early stopping)
    print("\nRe-initializing classification model for fresh training...")
    classification_model = get_pretrained_genre_classifier(num_genres=NUM_GENRES, pretrained=True).to(device)
    optimizer_cls = optim.Adam(classification_model.parameters(), lr=LEARNING_RATE_CLS, weight_decay=weight_decay_cls)

    best_cls_model_path = os.path.join(OUTPUT_DIR, "classification_model_best.pth")
    # delete previous checkpoint
    if os.path.exists(best_cls_model_path):
        print(f"Deleting previous classification checkpoint: {best_cls_model_path}")
        try:
            os.remove(best_cls_model_path)
        except Exception as e_del_cls:
            print(f"Warning: Could not delete {best_cls_model_path}: {e_del_cls}")

    print("\n--- Training Classification Model (Task 2) ---")
    print(f"Starting training for up to {NUM_EPOCHS_CLS} epochs (with Early Stopping)...")
    best_cls_val_f1 = -1.0; epochs_no_improve_cls = 0; patience_cls = 7
    cls_train_losses, cls_val_losses, cls_val_accuracies, cls_val_f1s = [], [], [], []
    for epoch in range(NUM_EPOCHS_CLS):
        current_epoch_display = epoch + 1; print(f"\nEpoch {current_epoch_display}/{NUM_EPOCHS_CLS}"); epoch_start_time = time.time()
        try:
            train_loss = train_epoch(classification_model, train_genre_loader, criterion_cls, optimizer_cls, device, task_type='classification', epoch_num=current_epoch_display, total_epochs=NUM_EPOCHS_CLS)
            val_metrics = validate_epoch(classification_model, val_genre_loader, criterion_cls, device, task_type='classification', epoch_num=current_epoch_display, total_epochs=NUM_EPOCHS_CLS)
            val_loss = val_metrics['loss']; val_accuracy = val_metrics.get('accuracy', 0.0); val_f1 = val_metrics.get('f1', 0.0)
            if np.isnan(train_loss) or np.isnan(val_loss): print(f"!!! ERROR: NaN loss detected in Epoch {current_epoch_display}. Stopping."); break
            cls_train_losses.append(train_loss); cls_val_losses.append(val_loss); cls_val_accuracies.append(val_accuracy); cls_val_f1s.append(val_f1)
            epoch_end_time = time.time()
            print(f"  Epoch {current_epoch_display}/{NUM_EPOCHS_CLS} Summary: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, Val Acc={val_accuracy:.4f}, Val F1={val_f1:.4f}, Time={epoch_end_time - epoch_start_time:.2f}s")
            # early stopping check
            if val_f1 > best_cls_val_f1:
                best_cls_val_f1 = val_f1; epochs_no_improve_cls = 0
                save_path = os.path.join(OUTPUT_DIR, "classification_model_best.pth")
                try: torch.save(classification_model.state_dict(), save_path); print(f"  Saved best classification model (F1: {val_f1:.4f})")
                except Exception as e: print(f"  Error saving model: {e}")
            else:
                epochs_no_improve_cls += 1; print(f"  Validation F1 did not improve for {epochs_no_improve_cls} epoch(s). Best: {best_cls_val_f1:.4f}")
            if epochs_no_improve_cls >= patience_cls: print(f"\nEarly stopping triggered."); break
        except NameError as e_func: print(f"!!! ERROR: Training function not defined ({e_func}). Run Cell [30]."); break
        except Exception as e_gen: print(f"!!! UNEXPECTED ERROR in epoch {current_epoch_display}: {e_gen}"); traceback.print_exc(); break
    print("\nClassification Model Training Finished.")
    # plot task 2 results
    if cls_train_losses:
        epochs_trained_cls = len(cls_train_losses); fig, ax1 = plt.subplots(figsize=(12, 5)); color='tab:red'; ax1.set_xlabel('Epoch'); ax1.set_ylabel('Loss', color=color); ax1.plot(range(1, epochs_trained_cls + 1), cls_train_losses, '--', label='Train Loss', color=color); ax1.plot(range(1, epochs_trained_cls + 1), cls_val_losses, label='Validation Loss', color=color); ax1.tick_params(axis='y', labelcolor=color); ax1.legend(loc='upper left'); ax1.grid(True); ax2 = ax1.twinx(); color='tab:blue'; ax2.set_ylabel('Metrics (F1/Acc)', color=color); ax2.plot(range(1, epochs_trained_cls + 1), cls_val_accuracies, ':', label='Validation Accuracy', color=color); ax2.plot(range(1, epochs_trained_cls + 1), cls_val_f1s, label='Validation F1', color=color); ax2.tick_params(axis='y', labelcolor=color); ax2.set_ylim(bottom=0); ax2.legend(loc='upper right'); fig.tight_layout(); plt.title('Task 2: Classification Model Training Curve (Regularized)'); plot_save_path_cls = os.path.join(OUTPUT_DIR, 'classification_training_curve_regularized_final.png'); plt.savefig(plot_save_path_cls); print(f"Saved training curve plot to {plot_save_path_cls}")

    # load best classification model
    best_cls_model_path = os.path.join(OUTPUT_DIR, "classification_model_best.pth")
    if os.path.exists(best_cls_model_path):
        print(f"\nLoading best classification model (based on Val F1) from: {best_cls_model_path}")
        try:
            classification_model.load_state_dict(torch.load(best_cls_model_path, map_location=device))
            classification_model.eval()
            print("Successfully loaded BEST classification model checkpoint.")
        except Exception as e:
            print(f"!!! ERROR loading best classification model checkpoint: {e} !!!")
            if 'classification_model' in locals() and hasattr(classification_model, 'eval'): classification_model.eval()
    else:
        print(f"\n!!! WARNING: Best classification model checkpoint not found at {best_cls_model_path}. Using current model state. !!!")
        if 'classification_model' in locals() and hasattr(classification_model, 'eval'): classification_model.eval()


    # evaluation phase
    separation_model.eval(); classification_model.eval()
    print("\n--- Evaluating Models on Test Set ---")
    print("\nEvaluating Separation Model (SDR)...")
    try: avg_sdr = evaluate_separation_model(separation_model, test_sep_loader, device)
    except Exception as e_eval_sep: print(f"Error during separation eval: {e_eval_sep}"); avg_sdr = "N/A"
    print("\nEvaluating Classification Model (Baseline)...")
    try: baseline_cls_results = evaluate_classification_model(classification_model, test_genre_loader, device)
    except Exception as e_eval_cls: print(f"Error during classification eval: {e_eval_cls}"); baseline_cls_results = {}

    # combination evaluation
    print("\nEvaluating Combined Model (Sequential)...")
    combined_test_files = genre_files_test if 'genre_files_test' in locals() else None
    combined_test_labels = genre_labels_test if 'genre_labels_test' in locals() else None
    all_combined_preds = []; all_combined_targets = []
    if combined_test_files and combined_test_labels:
        print(f"Processing {len(combined_test_files)} files for combined evaluation...")
        try:
            for i, file_path in enumerate(combined_test_files):
                true_label = combined_test_labels[i]
                pred_label, _ = separate_and_classify(file_path, separation_model, classification_model, device)
                if pred_label is not None: all_combined_preds.append(pred_label); all_combined_targets.append(true_label)
                if (i + 1) % 50 == 0: print(f"  Processed {i+1}/{len(combined_test_files)} for combined eval...")
        except NameError as func_err: print(f"!!! ERROR: Required function not defined ({func_err}). Make sure Cell [32] was run.")
        except Exception as e_comb_eval: print(f"Error during combined eval processing: {e_comb_eval}")
    if all_combined_targets:
        combined_accuracy = accuracy_score(all_combined_targets, all_combined_preds); combined_f1 = f1_score(all_combined_targets, all_combined_preds, average='weighted', zero_division=0); print("\n--- Combined (Sequential) Classification Report ---"); target_names_cls = GENRE_CLASSES if 'GENRE_CLASSES' in locals() else [str(i) for i in range(NUM_GENRES)]; present_labels = sorted(np.unique(all_combined_targets + all_combined_preds)); valid_present_labels = [l for l in present_labels if l >= 0 and l < len(target_names_cls)]; filtered_target_names = [target_names_cls[i] for i in valid_present_labels];
        if not filtered_target_names: print("No valid predictions/targets for classification report.")
        else: print(classification_report(all_combined_targets, all_combined_preds, labels=valid_present_labels, target_names=filtered_target_names, zero_division=0))
        print(f"--- Combined Accuracy: {combined_accuracy:.4f} ---"); print(f"--- Combined Weighted F1-Score: {combined_f1:.4f} ---"); baseline_acc = baseline_cls_results.get('accuracy', -1); baseline_f1 = baseline_cls_results.get('f1', -1); print(f"\nComparison: Baseline Acc={baseline_acc:.4f}, F1={baseline_f1:.4f} | Combined Acc={combined_accuracy:.4f}, F1={combined_f1:.4f}")
    else: print("Could not perform combined evaluation (no valid targets/predictions).")


    # case study
    print("\n--- Processing Case Study File ---")
    if 'CASE_STUDY_FILE' in locals() and os.path.exists(CASE_STUDY_FILE):
        print(f"Loading case study file: {CASE_STUDY_FILE}")
        # classify original case study audio
        print("Classifying original case study audio...")
        original_pred = None
        try:
            case_study_waveform = load_audio(CASE_STUDY_FILE, target_sr=SAMPLE_RATE)
            if case_study_waveform is not None:
                required_samples = AUDIO_CHUNK_SAMPLES
                if len(case_study_waveform) >= required_samples:
                    start = (len(case_study_waveform) - required_samples) // 2
                    case_chunk = case_study_waveform[start : start + required_samples]
                else:
                    padding = required_samples - len(case_study_waveform)
                    case_chunk = torch.nn.functional.pad(case_study_waveform, (0, padding))
                mel_spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=SAMPLE_RATE, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=128).to(device)
                log_mel_spec_transform = torchaudio.transforms.AmplitudeToDB().to(device)
                mel_spec = mel_spectrogram_transform(case_chunk.to(device))
                log_mel_spec = log_mel_spec_transform(mel_spec).unsqueeze(0).unsqueeze(0)
                with torch.no_grad():
                    logits = classification_model(log_mel_spec)
                    original_pred = torch.argmax(logits, dim=1).item()
                print(f"  Original Prediction: {GENRE_CLASSES[original_pred] if original_pred is not None and original_pred < len(GENRE_CLASSES) else 'Error/Unknown'}")
            else:
                print("  Failed to load case study waveform.")
        except Exception as e:
            print(f"  Error classifying original: {e}")
            traceback.print_exc() # print full traceback

        print("\nSeparating and Classifying case study audio...")
        try:
            combined_pred, separated_wav = separate_and_classify(CASE_STUDY_FILE, separation_model, classification_model, device)
            if combined_pred is not None and separated_wav is not None:
                print(f"  Combined Prediction: {GENRE_CLASSES[combined_pred] if combined_pred < len(GENRE_CLASSES) else 'Unknown'}")
                separated_filename = os.path.join(OUTPUT_DIR, f"{os.path.splitext(os.path.basename(CASE_STUDY_FILE))[0]}_separated_classified_final.wav")
                try:
                    if separated_wav.is_cuda or hasattr(separated_wav, 'is_mps') and separated_wav.is_mps: separated_wav = separated_wav.cpu()
                    save_data = separated_wav.to(torch.float32).numpy()
                    if save_data.ndim > 1: save_data = save_data.squeeze()
                    sf.write(separated_filename, save_data, SAMPLE_RATE)
                    print(f"  Saved separated audio to {separated_filename}")
                except Exception as e: print(f"  Error saving separated audio: {e}")
            else: print("  Failed to process combined.")
        except NameError: print("separate_and_classify function not found. Make sure Cell [32] was run.")
        except Exception as e_sep_cls: print(f"  Error during combined processing: {e_sep_cls}"); traceback.print_exc() # print full traceback
    else:
        print(f"Case study file not found or not defined: {locals().get('CASE_STUDY_FILE', 'Not Defined')}")

else:
     print("!!! Skipping Run because DataLoaders or Models are missing. Check previous cells.")


overall_end_time = time.time()
print("\n=============================================")
print("=== Project Execution Finished ===")
if 'overall_start_time' in locals(): print(f"Total execution time: {(overall_end_time - overall_start_time)/60:.2f} minutes")
else: print("Total execution time not calculated.")
print("=============================================")

=== Starting Project Execution (Final Run Attempt) ===
=== Using Device: cuda ===
=== Target Separation Chunks: 5s ===
=== Target Batch Size: 32 ===

Initializing models...
U-Net Encoder features: [16, 32, 64, 128, 256]
U-Net Bottleneck features: 512
U-Net Decoder features (output channels per block): [256, 128, 64, 32, 16]
Loading pre-trained ResNet18 model...
Adapted model.conv1 to accept 1 input channel.
Initialized new conv1 weights by averaging original RGB weights.
Adapted model.fc to output 8 classes.
Models initialized.

Setting up loss functions and optimizers...
Optimizer_sep: Adam, LR=1e-05, Scheduler=ReduceLROnPlateau(patience=5)
Optimizer_cls: LR=0.001, Weight Decay=0.0001

--- Training Separation Model (Task 1 - 5s Chunks) ---
Deleting previous checkpoint to start fresh: /content/drive/MyDrive/cw1_DL/output/separation_model_best.pth
Starting training for 30 epochs...

Epoch 1/30




Epoch 1/30 [Separation Train]:   0%|          | 0/525 [00:00<?, ?it/s]

Interpolated output shape to match target: torch.Size([32, 1, 1025, 431])

--- Debug Info (Epoch 1, Batch 0) ---
Task Type: separation
Input shape: torch.Size([32, 1, 1025, 431]), Input Min: 0.0000, Max: 443.3140, Mean: 0.9287
Target shape: torch.Size([32, 1, 1025, 431])
Target Min: 0.0000, Max: 1.0000, Mean: 0.7087
Output shape (potentially after interpolation): torch.Size([32, 1, 1025, 431])
Output Min: 0.0000, Max: 1.0000, Mean: 0.5363
Loss value: 0.419299
--- End Debug Info ---



KeyboardInterrupt: 

# 8.1 Evaluation exectution function


In [17]:

# record start time
overall_start_time = time.time()

print("=============================================")
print("=== Starting Project Execution (EVALUATION ONLY) ===")
# ensure config variables defined (from cell [27])
config_vars_ok = all(var in locals() or var in globals() for var in ['device', 'AUDIO_CHUNK_DURATION_S', 'BATCH_SIZE', 'NUM_EPOCHS_SEP', 'NUM_EPOCHS_CLS', 'LEARNING_RATE_SEP', 'LEARNING_RATE_CLS', 'OUTPUT_DIR', 'NUM_GENRES', 'GENRE_CLASSES', 'SAMPLE_RATE', 'N_FFT', 'HOP_LENGTH', 'AUDIO_CHUNK_SAMPLES']) # Added AUDIO_CHUNK_SAMPLES check
if not config_vars_ok:
     print("!!! ERROR: Essential config variables missing.")
     print("!!! Run Configuration Cell [27] first.")
     raise NameError("Configuration variables missing.")
else:
    print(f"=== Using Device: {device} ===")
    print(f"=== Target Separation Chunks (from training): {AUDIO_CHUNK_DURATION_S}s ===")
    print(f"=== Target Batch Size (from training): {BATCH_SIZE} ===")
print("=============================================")

# initialize models (needed before loading weights)
print("\nInitializing model architectures...")
try:
    # assumes model classes defined (cell [29])
    separation_model = SeparationUNet().to(device)
    classification_model = get_pretrained_genre_classifier(num_genres=NUM_GENRES, pretrained=False).to(device) # pretrained=false is fine when loading state_dict
    print("Model architectures initialized.")
except NameError as ne:
    print(f"!!! ERROR: Model class definition not found ({ne}). Run Cell [29].")
    raise
except Exception as e_init:
     print(f"!!! ERROR initializing models: {e_init}")
     raise

# define paths to saved models
# assumes output_dir defined (cell [27])
best_sep_model_path = os.path.join(OUTPUT_DIR, "separation_model_best.pth")
best_cls_model_path = os.path.join(OUTPUT_DIR, "classification_model_best.pth")

# load best separation model checkpoint
model_loaded_sep = False
if os.path.exists(best_sep_model_path):
    print(f"\nLoading best separation model for evaluation from: {best_sep_model_path}")
    try:
        separation_model.load_state_dict(torch.load(best_sep_model_path, map_location=device))
        separation_model.eval() # set to evaluation mode
        print("Successfully loaded BEST separation model checkpoint.")
        model_loaded_sep = True
    except Exception as e:
        print(f"!!! ERROR loading separation model checkpoint: {e} !!!")
else:
    print(f"\n!!! WARNING: Best separation model checkpoint not found at {best_sep_model_path} !!!")

# load best classification model checkpoint
model_loaded_cls = False
if os.path.exists(best_cls_model_path):
    print(f"\nLoading best classification model (based on Val F1) from: {best_cls_model_path}")
    try:
        classification_model.load_state_dict(torch.load(best_cls_model_path, map_location=device))
        classification_model.eval() # set to evaluation mode
        print("Successfully loaded BEST classification model checkpoint.")
        model_loaded_cls = True
    except Exception as e:
        print(f"!!! ERROR loading best classification model checkpoint: {e} !!!")
else:
    print(f"\n!!! WARNING: Best classification model checkpoint not found at {best_cls_model_path} !!!")


# check if dataloaders created (test sets needed)
# assumes cell [34] was run after cell [27]
# only need test loaders
loaders_available = all([
    'test_sep_loader' in locals() and test_sep_loader is not None,
    'test_genre_loader' in locals() and test_genre_loader is not None
])

if not loaders_available:
    print("\n!!! ERROR: Test DataLoaders not available. Run Cell [34] after Config Cell [27].")
    raise ValueError("Test DataLoaders missing, cannot proceed.")

# proceed only if models loaded and test loaders available
elif model_loaded_sep and model_loaded_cls:

    # evaluation phase
    # ensure models are in eval mode
    separation_model.eval()
    classification_model.eval()
    print("\n--- Evaluating Models on Test Set ---")

    # evaluate separation model
    print("\nEvaluating Separation Model (SDR)...")
    avg_sdr = "N/A" # default value
    try:
        # assumes evaluate_separation_model defined (cell [31])
        avg_sdr = evaluate_separation_model(separation_model, test_sep_loader, device)
    except NameError:
         print("evaluate_separation_model function not found. Run Cell [31].")
    except Exception as e_eval_sep:
         print(f"Error during separation eval: {e_eval_sep}")
         traceback.print_exc() # print traceback for eval errors

    # evaluate classification model (baseline)
    print("\nEvaluating Classification Model (Baseline)...")
    baseline_cls_results = {} # default value
    try:
        # assumes evaluate_classification_model defined (cell [31])
        baseline_cls_results = evaluate_classification_model(classification_model, test_genre_loader, device)
    except NameError:
         print("evaluate_classification_model function not found. Run Cell [31].")
    except Exception as e_eval_cls:
         print(f"Error during classification eval: {e_eval_cls}")
         traceback.print_exc()

    # combination evaluation
    print("\nEvaluating Combined Model (Sequential)...")
    # assumes test files/labels defined (cell [34])
    combined_test_files = genre_files_test if 'genre_files_test' in locals() else None
    combined_test_labels = genre_labels_test if 'genre_labels_test' in locals() else None
    all_combined_preds = []; all_combined_targets = []

    if combined_test_files and combined_test_labels:
        print(f"Processing {len(combined_test_files)} files for combined evaluation...")
        try:
            # assumes separate_and_classify defined (cell [32])
            for i, file_path in enumerate(tqdm(combined_test_files, desc="Combined Eval")): # add progress bar
                true_label = combined_test_labels[i]
                pred_label, _ = separate_and_classify(file_path, separation_model, classification_model, device)
                if pred_label is not None:
                    all_combined_preds.append(pred_label)
                    all_combined_targets.append(true_label)
                # optional: print progress less frequently
        except NameError as func_err:
             print(f"!!! ERROR: Required function not defined ({func_err}). Make sure Cell [32] was run.")
        except Exception as e_comb_eval:
            print(f"Error during combined eval processing: {e_comb_eval}")
            traceback.print_exc()

        # display combined results if generated
        if all_combined_targets:
            combined_accuracy = accuracy_score(all_combined_targets, all_combined_preds)
            combined_f1 = f1_score(all_combined_targets, all_combined_preds, average='weighted', zero_division=0)
            print("\n--- Combined (Sequential) Classification Report ---")
            # assumes genre_classes defined (cell [27])
            target_names_cls = GENRE_CLASSES if 'GENRE_CLASSES' in locals() else [str(i) for i in range(NUM_GENRES)]
            present_labels = sorted(np.unique(all_combined_targets + all_combined_preds))
            valid_present_labels = [l for l in present_labels if l >= 0 and l < len(target_names_cls)]
            filtered_target_names = [target_names_cls[i] for i in valid_present_labels]

            if not filtered_target_names: print("No valid predictions/targets for classification report.")
            else: print(classification_report(all_combined_targets, all_combined_preds, labels=valid_present_labels, target_names=filtered_target_names, zero_division=0))

            print(f"--- Combined Accuracy: {combined_accuracy:.4f} ---")
            print(f"--- Combined Weighted F1-Score: {combined_f1:.4f} ---")
            baseline_acc = baseline_cls_results.get('accuracy', -1)
            baseline_f1 = baseline_cls_results.get('f1', -1)
            # ensure baseline results are numbers before formatting
            baseline_acc_str = f"{baseline_acc:.4f}" if isinstance(baseline_acc, (int, float)) else "N/A"
            baseline_f1_str = f"{baseline_f1:.4f}" if isinstance(baseline_f1, (int, float)) else "N/A"
            print(f"\nComparison: Baseline Acc={baseline_acc_str}, F1={baseline_f1_str} | Combined Acc={combined_accuracy:.4f}, F1={combined_f1:.4f}")
        else:
            print("Could not perform combined evaluation (no valid targets/predictions).")
    else:
        print("Skipping combined evaluation - Test files/labels not found (check Cell [34] split).")


    # case study
    print("\n--- Processing Case Study File ---")
    # assumes case_study_file defined (cell [27])
    if 'CASE_STUDY_FILE' in locals() and os.path.exists(CASE_STUDY_FILE):
        print(f"Loading case study file: {CASE_STUDY_FILE}")
        print("Classifying original case study audio...")
        original_pred = None
        try:
            # assumes load_audio, sample_rate, etc. defined
            case_study_waveform = load_audio(CASE_STUDY_FILE, target_sr=SAMPLE_RATE)
            if case_study_waveform is not None:
                required_samples = AUDIO_CHUNK_SAMPLES
                if len(case_study_waveform) >= required_samples: start = (len(case_study_waveform) - required_samples)//2; case_chunk = case_study_waveform[start : start + required_samples]
                else: padding = required_samples - len(case_study_waveform); case_chunk = torch.nn.functional.pad(case_study_waveform, (0, padding))
                mel_spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=SAMPLE_RATE, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=128).to(device); log_mel_spec_transform = torchaudio.transforms.AmplitudeToDB().to(device); mel_spec = mel_spectrogram_transform(case_chunk.to(device)); log_mel_spec = log_mel_spec_transform(mel_spec).unsqueeze(0).unsqueeze(0);
                with torch.no_grad(): logits = classification_model(log_mel_spec); original_pred = torch.argmax(logits, dim=1).item();
                print(f"  Original Prediction: {GENRE_CLASSES[original_pred] if original_pred is not None and original_pred < len(GENRE_CLASSES) else 'Error/Unknown'}")
            else: print("  Failed to load case study waveform.")
        except Exception as e: print(f"  Error classifying original: {e}"); traceback.print_exc()

        print("\nSeparating and Classifying case study audio...")
        try:
            # assumes separate_and_classify defined (cell [32])
            combined_pred, separated_wav = separate_and_classify(CASE_STUDY_FILE, separation_model, classification_model, device)
            if combined_pred is not None and separated_wav is not None:
                print(f"  Combined Prediction: {GENRE_CLASSES[combined_pred] if combined_pred < len(GENRE_CLASSES) else 'Unknown'}")
                # assumes output_dir defined
                separated_filename = os.path.join(OUTPUT_DIR, f"{os.path.splitext(os.path.basename(CASE_STUDY_FILE))[0]}_separated_classified_eval_only.wav") # use different filename for eval output
                try:
                    if separated_wav.is_cuda or hasattr(separated_wav, 'is_mps') and separated_wav.is_mps: separated_wav = separated_wav.cpu()
                    save_data = separated_wav.to(torch.float32).numpy();
                    if save_data.ndim > 1: save_data = save_data.squeeze()
                    # assumes sf, sample_rate defined
                    sf.write(separated_filename, save_data, SAMPLE_RATE); print(f"  Saved separated audio to {separated_filename}")
                except Exception as e: print(f"  Error saving separated audio: {e}")
            else: print("  Failed to process combined.")
        except NameError: print("separate_and_classify function not found. Make sure Cell [32] was run.")
        except Exception as e_sep_cls: print(f"  Error during combined processing: {e_sep_cls}"); traceback.print_exc()
    else: print(f"Case study file not found or not defined: {locals().get('CASE_STUDY_FILE', 'Not Defined')}")

else:
     print("!!! Skipping Evaluation because models could not be loaded or DataLoaders missing. Check previous cells and file paths.")


overall_end_time = time.time()
print("\n=============================================")
print("=== Project Execution Finished ===")
if 'overall_start_time' in locals(): print(f"Total execution time: {(overall_end_time - overall_start_time)/60:.2f} minutes")
else: print("Total execution time not calculated.")
print("=============================================")



=== Starting Project Execution (EVALUATION ONLY) ===
=== Using Device: cuda ===
=== Target Separation Chunks (from training): 5s ===
=== Target Batch Size (from training): 32 ===

Initializing model architectures...
U-Net Encoder features: [16, 32, 64, 128, 256]
U-Net Bottleneck features: 512
U-Net Decoder features (output channels per block): [256, 128, 64, 32, 16]
Loading randomly initialized ResNet18 model...
Adapted model.conv1 to accept 1 input channel.
Adapted model.fc to output 8 classes.
Model architectures initialized.

Loading best separation model for evaluation from: /content/drive/MyDrive/cw1_DL/output/separation_model_best.pth
Successfully loaded BEST separation model checkpoint.

Loading best classification model (based on Val F1) from: /content/drive/MyDrive/cw1_DL/output/classification_model_best.pth
Successfully loaded BEST classification model checkpoint.

--- Evaluating Models on Test Set ---

Evaluating Separation Model (SDR)...
Evaluating separation model on test 

Separation Eval:   0%|          | 0/113 [00:00<?, ?it/s]


--- SDR Debug (Batch 0, Item 0) ---
  true_wav_i | Shape: torch.Size([220500]), Min: -0.5025, Max: 0.6137, Mean: 0.0011, Std: 0.1557
  est_wav_i  | Shape: torch.Size([220500]), Min: -0.5030, Max: 0.5754, Mean: 0.0003, Std: 0.0632
  Individual SDR calculated: 12.1899
--- End SDR Debug ---

--- SDR Debug (Batch 0, Item 1) ---
  true_wav_i | Shape: torch.Size([220500]), Min: -0.7108, Max: 0.7418, Mean: 0.0001, Std: 0.1634
  est_wav_i  | Shape: torch.Size([220500]), Min: -0.6768, Max: 0.5292, Mean: 0.0000, Std: 0.0939


	Deprecated as of mir_eval version 0.8.
	It will be removed in mir_eval version 0.9.
  sdr_value, _, _, _ = mir_eval.separation.bss_eval_sources(true[np.newaxis, :], est[np.newaxis, :])


  Individual SDR calculated: 2.0751
--- End SDR Debug ---

--- Average SDR on Test Set (from 3600 valid scores): 4.8087 ---

Evaluating Classification Model (Baseline)...
Evaluating classification model on test set...


Classification Eval:   0%|          | 0/38 [00:00<?, ?it/s]


--- Classification Report ---
              precision    recall  f1-score   support

           0       0.52      0.73      0.61       150
           1       0.47      0.45      0.46       150
           2       0.76      0.51      0.61       150
           3       0.84      0.61      0.71       150
           4       0.54      0.68      0.60       150
           5       0.72      0.65      0.68       150
           6       0.38      0.32      0.35       150
           7       0.55      0.68      0.61       150

    accuracy                           0.58      1200
   macro avg       0.60      0.58      0.58      1200
weighted avg       0.60      0.58      0.58      1200


--- Confusion Matrix ---
Labels: ['0', '1', '2', '3', '4', '5', '6', '7']
[[109   9   0   4  12   5   7   4]
 [ 16  68   5   0  21   7  16  17]
 [  3  17  76   0  17   9  17  11]
 [ 35   4   0  92   2   5  11   1]
 [ 14  14   4   2 102   0   3  11]
 [ 16   7   2   3   9  97  13   3]
 [ 10  14   6   8  20   8  48  36

Combined Eval:   0%|          | 0/1200 [00:00<?, ?it/s]


--- Combined (Sequential) Classification Report ---
               precision    recall  f1-score   support

   Electronic       0.55      0.67      0.60       150
 Experimental       0.48      0.47      0.47       150
         Folk       0.72      0.53      0.61       150
      Hip-Hop       0.82      0.65      0.72       150
 Instrumental       0.51      0.68      0.58       150
International       0.73      0.60      0.66       150
          Pop       0.34      0.31      0.32       150
         Rock       0.56      0.67      0.61       150

     accuracy                           0.57      1200
    macro avg       0.59      0.57      0.57      1200
 weighted avg       0.59      0.57      0.57      1200

--- Combined Accuracy: 0.5725 ---
--- Combined Weighted F1-Score: 0.5735 ---

Comparison: Baseline Acc=0.5783, F1=0.5779 | Combined Acc=0.5725, F1=0.5735

--- Processing Case Study File ---
Case study file not found or not defined: /content/drive/MyDrive/cw1_DL/Case_study_city.mp3

=