In [48]:
# Move imports to utility
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os, gc, random, math, time, copy , zipfile, tarfile, shutil, subprocess
from pathlib import Path
from tqdm.notebook import tqdm
import IPython.display as ipd
from IPython.display import display, clear_output
import ipywidgets as widgets

import librosa
import librosa.display
import soundfile as sf

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import Dataset, DataLoader
import torch.amp as amp

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix, average_precision_score
from sklearn.preprocessing import LabelEncoder

In [2]:
import birdclef_util as utl

In [42]:
!ls /kaggle/input/precomputed-specs-np-zipped/

precomputed-specs-np-zipped


In [43]:
!file /kaggle/input/precomputed-specs-np-zipped/precomputed-specs-np-zipped

/kaggle/input/precomputed-specs-np-zipped/precomputed-specs-np-zipped: gzip compressed data, from Unix, original size modulo 2^32 240396288


In [49]:
# Location of the precomputed-spectrogram archive inside the Kaggle input dataset.
COMP_PATH = Path("/kaggle/input/precomputed-specs-np-zipped/precomputed-specs-np-zipped")
# Directory the archive is extracted into.
DECOMP_TARGET_DIR = Path("/kaggle/working/precomputed_specs_extracted")
# Create the target directory (and any missing parents); no error if it already exists.
DECOMP_TARGET_DIR.mkdir(exist_ok=True, parents=True)
print(f"Decompression target directory: {DECOMP_TARGET_DIR}")

Decompression target directory: /kaggle/working/precomputed_specs_extracted


In [50]:
# Unpack the spectrogram archive with the system `tar`, surfacing tar's own
# diagnostics (exit code, stdout, stderr) before re-raising on failure.
# NOTE(review): the recorded run of this cell failed with "No space left on device",
# so the target volume was too small for the extracted data at that time.
if not ('COMP_PATH' in globals() and COMP_PATH and os.path.exists(COMP_PATH)):
    print("Error: COMP_PATH is not defined or file does not exist.")
else:
    print(f"Attempting to decompress {COMP_PATH} into {DECOMP_TARGET_DIR}...")

    # -x extract, -z gunzip, -f read from the given archive; -C changes into the
    # target directory before extracting (passed as a plain string).
    tar_command = ["tar", "-xzf", COMP_PATH, "-C", str(DECOMP_TARGET_DIR)]

    try:
        # check=True turns a non-zero tar exit status into CalledProcessError.
        process = subprocess.run(tar_command, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print("Decompression failed!")
        print("Error code:", e.returncode)
        print("stdout:", e.stdout)
        print("stderr:", e.stderr)
        raise e
    except FileNotFoundError:
        # Raised when the `tar` executable itself cannot be located.
        print("Error: 'tar' command not found. This is unlikely in Kaggle environment.")
        raise
    else:
        print("Decompression successful!")

Attempting to decompress /kaggle/input/precomputed-specs-np-zipped/precomputed-specs-np-zipped into /kaggle/working/precomputed_specs_extracted...
Decompression failed!
Error code: 2
stdout: 
stderr: tar: kaggle/working/precomputed_specs_np/crcwoo1_XC113957.ogg_clip11.npy: Wrote only 1536 of 10240 bytes
tar: kaggle/working/precomputed_specs_np/crcwoo1_XC113957.ogg_clip14.npy: Cannot write: No space left on device
tar: kaggle/working/precomputed_specs_np/crcwoo1_XC113957.ogg_clip17.npy: Cannot write: No space left on device
tar: kaggle/working/precomputed_specs_np/crcwoo1_XC113957.ogg_clip19.npy: Cannot write: No space left on device
tar: kaggle/working/precomputed_specs_np/crcwoo1_XC114065.ogg_clip0.npy: Cannot write: No space left on device
tar: kaggle/working/precomputed_specs_np/crcwoo1_XC114065.ogg_clip4.npy: Cannot write: No space left on device
tar: kaggle/working/precomputed_specs_np/crcwoo1_XC114065.ogg_clip6.npy: Cannot write: No space left on device
tar: kaggle/working/precom

CalledProcessError: Command '['tar', '-xzf', PosixPath('/kaggle/input/precomputed-specs-np-zipped/precomputed-specs-np-zipped'), '-C', '/kaggle/working/precomputed_specs_extracted']' returned non-zero exit status 2.

In [3]:
# Function to seed everything to ensure reproducibility
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's `random`, NumPy's legacy global RNG and PyTorch (CPU and all
    CUDA devices), exports PYTHONHASHSEED, and puts cuDNN into deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Only affects interpreters started after this point (e.g. worker processes);
    # the hash seed of the running interpreter is already fixed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU, not just the current one; a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False # Change to true if input sizes are kept constant

In [4]:
# Setup: build the shared config object, seed the RNGs from it, and pick the
# compute device (GPU when CUDA is available, otherwise CPU).
cfg = utl.Config()
seed_everything(cfg.SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

In [5]:
# Load the training metadata table (train.csv under cfg.BASE_DATA_PATH).
# Later cells read its 'filename' and 'primary_label' columns.
train_df = pd.read_csv(cfg.BASE_DATA_PATH/"train.csv")

In [None]:
# MOVE TO UTILS
# Build a list of audio durations (seconds), one entry per row of train_df and in
# the same order, so it can be attached as a column in the next cell.
# Files that are missing or unreadable get NaN.
# NOTE: `durations` is only defined when train_df exists and is non-empty.
if 'train_df' in globals() and not train_df.empty:
    print("Analyzing audio durations...")
    durations = []
    pbar = tqdm(train_df['filename'].tolist(), desc="Calculating durations")
    for filename in pbar:
        file_path = cfg.BASE_DATA_PATH/"train_audio"/filename
        if file_path.exists():
            try:
                # Efficient approach: sf.info reads the duration without loading the whole file
                info = sf.info(file_path)
                durations.append(info.duration)
            except Exception as e:
                print(f"Could not get info for {filename}: {e}") # Comment out to silence per-file errors
                durations.append(np.nan) # mark errors
        else:
            # File listed in train.csv but not present on disk
            durations.append(np.nan)

In [None]:
# Attach the durations computed in the previous cell (requires `durations` to exist
# and to have one entry per row) and summarise their distribution.
train_df['duration'] = durations
train_df['duration'].describe()

In [None]:
# Build the mel-spectrogram transform; all parameters come from the shared config.
# The transform is constructed here only — nothing is applied to audio in this cell.
mel_spectrogram_tfms = T.MelSpectrogram(
    sample_rate = cfg.SAMPLE_RATE,
    n_fft = cfg.N_FFT,
    hop_length = cfg.HOP_LENGTH,
    n_mels = cfg.N_MELS,
    f_min = cfg.FMIN,
    f_max = cfg.FMAX
)#.to(device) # device move is disabled, so the transform stays on the CPU

In [None]:
# Convert a power spectrogram to decibels; top_db=80 limits the dynamic range to
# 80 dB below the peak. The device move is disabled, so this also stays on the CPU.
amp_to_db_tfms = T.AmplitudeToDB(stype='power', top_db=80)#.to(device) 

## Pipeline - Dataset and DataLoader

In [None]:
# Create label encoder: map each distinct primary label to a stable integer index
# (labels are sorted first, so the numbering is reproducible between runs).
if train_df.empty or 'primary_label' not in train_df.columns:
    print("train_df is empty or missing 'primary label'. Cannot create encoder.")
    cfg.N_CLASSES = 0
else:
    u_lbl = sorted(train_df['primary_label'].unique())
    n_class = len(u_lbl)
    print(f"Found {n_class} unique primary labels.")

    # Forward (label -> index) and reverse (index -> label) lookups
    lbl_2_int = dict(zip(u_lbl, range(n_class)))
    int_2_lbl = dict(enumerate(u_lbl))

    # Integer-encoded label column used by the clip bookkeeping further down
    train_df['primary_label_int'] = train_df['primary_label'].map(lbl_2_int)

    # Publish the class count and both mappings on the config for global access
    cfg.N_CLASSES = n_class
    cfg.LBL_2_INT = lbl_2_int
    cfg.INT_2_LBL = int_2_lbl

In [None]:
# Helper function for creating the target tensor
def create_target_tensor(pri_lbl_int, n_classes):
    """Build the float32 target vector for a clip's primary label.

    Args:
        pri_lbl_int: Integer class index of the primary label, or None.
        n_classes: Length of the target vector.

    Returns:
        A 1-D float32 tensor of zeros with a single 1.0 at `pri_lbl_int`.
        The vector is left all-zero when the index is None or outside
        the range [0, n_classes).
    """
    target = torch.zeros(n_classes, dtype=torch.float32)
    if pri_lbl_int is None:
        return target
    if not (0 <= pri_lbl_int < n_classes):
        return target
    target[pri_lbl_int] = 1.0
    return target

### Calculate Clip Information

In [None]:
# --- Step 1: Re-calculate Clip Information (Ensure this finishes) ---
# Builds `clip_samples`: one dict per non-overlapping clip of cfg.TARGET_DURATION_S
# seconds, carrying the source filename, the integer label and the clip's index
# within its file. A trailing remainder shorter than a full clip is dropped (floor).
clip_samples = [] # Reset the list
total_clips = 0

if 'train_df' in globals() and not train_df.empty and 'duration' in train_df.columns and 'primary_label_int' in train_df.columns and cfg.N_CLASSES > 0:
    print("\nRE-CALCULATING clip information...")
    # Keep only files long enough for at least one full clip. NaN durations fail the
    # comparison, so files whose duration could not be read are dropped here as well.
    valid_files_df = train_df[train_df['duration'] >= cfg.TARGET_DURATION_S].copy()
    print(f"Processing {len(valid_files_df)} files >= {cfg.TARGET_DURATION_S}s duration.")

    for index, row in tqdm(valid_files_df.iterrows(), total=len(valid_files_df), desc="Mapping clips"):
        filename = row['filename']
        primary_label_int = row['primary_label_int']
        duration = row['duration']

        # Defensive re-check; the frame was already filtered on this condition above.
        if pd.isna(duration) or duration < cfg.TARGET_DURATION_S:
            continue

        # Number of whole clips that fit in the recording
        num_clips_in_file = math.floor(duration / cfg.TARGET_DURATION_S)

        if num_clips_in_file > 0:
            for clip_idx in range(num_clips_in_file):
                clip_samples.append({
                    'filename': filename,
                    'primary_label_int': primary_label_int,
                    'clip_index': clip_idx
                })
            total_clips += num_clips_in_file

    print(f"Total number of {cfg.TARGET_DURATION_S}s clips RE-CALCULATED: {total_clips}")
    cfg.TOTAL_CLIPS = total_clips

    # --- Step 2: Verify the recalculated list immediately ---
    # Debug check for one specific recording: report how many clips were generated
    # for it and the highest clip index.
    problem_filename_str = "ywcpar/iNat922688.ogg"
    problem_clips_info_recalc = [item for item in clip_samples if item['filename'] == problem_filename_str]
    print(f"\n--- Verifying RECALCULATED clip_samples for {problem_filename_str} ---")
    if problem_clips_info_recalc:
        print(f"Found {len(problem_clips_info_recalc)} clips.")
        max_idx_recalc = max(item['clip_index'] for item in problem_clips_info_recalc)
        print(f"Highest clip_index found: {max_idx_recalc}") # Should be 2
    else:
        print("No clips found.")

else:
    # Prerequisites missing: leave an empty clip list and a zero count on the config.
    print("Cannot re-calculate clips: train_df missing, empty, or missing required columns.")
    clip_samples = []
    cfg.TOTAL_CLIPS = 0

In [None]:
# Build a dataset over every clip (before any train/validation split).
# NOTE: `bird_dataset` is only defined when clip_samples is non-empty.
print("\nInstantiating BirdClefDataset with RECALCULATED clip_samples...")
# Ensure clip_samples is not empty before creating dataset
if 'clip_samples' in globals() and clip_samples:
    bird_dataset = utl.BirdClefDataset(clip_samples, cfg.TRAIN_AUDIO_PATH, cfg.N_CLASSES)
    print(f"Dataset instantiated with {len(bird_dataset)} items.")

### Dataloaders

In [None]:
# Stratified train/validation split of the clip list.
# Produces `train_clip` and `val_clip` (lists of clip-info dicts); both are empty
# lists whenever the split cannot be performed.
if 'clip_samples' not in globals() or not clip_samples:
    print("Error: clip_samples list not found or is empty. Rerun clip calculations first.")
    # Fall back to an empty frame so the filtering step below reports the problem
    # through its own else-branch instead of failing with a NameError on split_df.
    split_df = pd.DataFrame()
else:
    print(f"Total clips available for splitting: {len(clip_samples)}.")

    # We'll convert the list to a DataFrame for easier splitting
    if not isinstance(clip_samples, pd.DataFrame):
        split_df = pd.DataFrame(clip_samples)
    else:
        split_df = clip_samples

# Filter classes with only one sample: a stratified split needs at least two
# members per class.
if not split_df.empty and 'primary_label_int' in split_df.columns:
    label_cnt = split_df['primary_label_int'].value_counts()
    single_samp = label_cnt[label_cnt == 1].index.tolist()
    if single_samp:
        print(f"Found {len(single_samp)} classes with only 1 sample clip.")
        print("Removing classes: ", [cfg.INT_2_LBL.get(lbl, lbl) for lbl in single_samp])
        # Filter
        original_clip_cnt = len(split_df)
        split_df_filt = split_df[~split_df['primary_label_int'].isin(single_samp)].copy()
        del_cnt = original_clip_cnt - len(split_df_filt)
        print(f"Removed {del_cnt} clips belonging to single-sample classes.")
        print(f"Remaining clips: {len(split_df_filt)}")
    else:
        print("No single-sample classes found. Proceed with original data.")
        split_df_filt = split_df.copy()

    cfg.VALIDATION_SPLIT = 0.2

    if not split_df_filt.empty:
        # Split on the row index so the selected rows can be looked up afterwards.
        # NOTE(review): rows are individual clips and stratification uses the label
        # only, so clips cut from the same recording can land in both train and
        # validation. Consider a group-aware split on 'filename' if that matters.
        features = split_df_filt.index
        labels = split_df_filt['primary_label_int']

        try:
            train_idx, val_idx = train_test_split(
                features,
                test_size=cfg.VALIDATION_SPLIT,
                random_state=cfg.SEED,
                stratify=labels
            )

            # Create the train and validation lists of clip info dicts
            train_clip = split_df_filt.loc[train_idx].to_dict('records')
            val_clip = split_df_filt.loc[val_idx].to_dict('records')

            print(f"\nSplit successful after filtering:")
            print(f"Training clips: {len(train_clip)}")
            print(f"Validation clips: {len(val_clip)}")

            # Verify stratification again (optional): compare one class's share in
            # each partition.
            train_labels_dist = pd.Series([d['primary_label_int'] for d in train_clip]).value_counts(normalize=True)
            val_labels_dist = pd.Series([d['primary_label_int'] for d in val_clip]).value_counts(normalize=True)
            print("\nExample Class Proportions:")
            example_class_int = cfg.LBL_2_INT.get('grekis')
            if example_class_int is not None and example_class_int not in single_samp: # Check if example class was removed
                 print(f"  Class '{cfg.INT_2_LBL[example_class_int]}' (Index {example_class_int}):")
                 print(f"    Train proportion: {train_labels_dist.get(example_class_int, 0):.4f}")
                 print(f"    Valid proportion: {val_labels_dist.get(example_class_int, 0):.4f}")
            else:
                 print("Could not check 'grekis' proportion (not found or was removed).")

        except ValueError as e:
            # train_test_split raises ValueError when stratification is impossible.
            print(f"\nError during stratified split even after filtering: {e}")
            print("This might happen if filtering leaves other classes with too few samples for the split ratio.")
            train_clip, val_clip = [], []

    else:
        print("No data left after filtering single-sample classes.")
        train_clip, val_clip = [], []

else:
    print("Cannot perform filtering/split: clip data is empty or missing labels.")
    train_clip, val_clip = [], []

In [None]:
# Packaging into a FastAI style container
class BirdDataLoaders:
    """Bundle the training and validation DataLoaders with the target device.

    Attributes:
        train: DataLoader used for training.
        valid: DataLoader used for validation.
        device: Device the training loop should move batches to.
    """

    def __init__(self, train_dl, valid_dl, device='cuda'):
        self.train = train_dl
        self.valid = valid_dl
        self.device = device

    @property
    def train_ds(self):
        """Dataset wrapped by the training loader."""
        return self.train.dataset

    @property
    def train_df(self):
        """Backward-compatible alias of `train_ds`.

        The original name suggested a DataFrame, but the value is the training dataset.
        """
        return self.train_ds

    @property
    def valid_ds(self):
        """Dataset wrapped by the validation loader."""
        return self.valid.dataset

    @property
    def num_classes(self):
        """Class count reported by the training dataset, or 0 if that dataset is falsy.

        A dataset that defines `__len__` is falsy when empty, so an empty or missing
        training dataset yields 0.
        """
        return self.train_ds.num_classes if self.train_ds else 0

In [None]:
# Make DataLoader worker processes start with the 'spawn' method.
# Per the PyTorch multiprocessing notes, the CUDA runtime does not support the
# 'fork' start method, so subprocesses that touch CUDA need 'spawn' (or 'forkserver').
# NOTE(review): whether the dataset workers actually use CUDA is not visible in this
# notebook — confirm against birdclef_util.BirdClefDataset.
#
# The check compares the start-method *name*: `mp.SpawnContext` in torch.multiprocessing
# is torch's own process-context class, not the multiprocessing spawn context, so an
# isinstance test against it never matches.
try:
    if mp.get_start_method(allow_none=True) != 'spawn':
        mp.set_start_method('spawn', force=True)
    else:
        print("MP already set to spawn.")
except RuntimeError as e:
    print(f"Couldn't set start method: {e}")

In [None]:
# NOTE(review): PRECOMPUTED_INPUT_PATH is not used in this cell.
PRECOMPUTED_INPUT_PATH = "/kaggle/input/precomputed-specs-np-zipped/precomputed-specs-np-zipped"

# --- Define Batch Size and Number of Workers ---
cfg.BATCH_SIZE = 64
cfg.NUM_WORKERS = 2 

# --- Create Datasets ---
# Requires the non-empty `train_clip` / `val_clip` lists produced by the split cell.
if 'train_clip' in globals() and 'val_clip' in globals() and train_clip and val_clip:
    train_dataset = utl.BirdClefDataset(train_clip, cfg.TRAIN_AUDIO_PATH, cfg.N_CLASSES, augmentations=None)
    val_dataset = utl.BirdClefDataset(val_clip, cfg.TRAIN_AUDIO_PATH, cfg.N_CLASSES, augmentations=None)

    print(f"Train Dataset size: {len(train_dataset)}")
    print(f"Validation Dataset size: {len(val_dataset)}")

    # --- Create DataLoaders ---
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.BATCH_SIZE,
        shuffle=True,
        num_workers=cfg.NUM_WORKERS, 
        pin_memory=False, 
        drop_last=False
    )

    # Validation runs without gradients, so it uses a batch twice as large.
    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg.BATCH_SIZE * 2,
        shuffle=False,
        num_workers=cfg.NUM_WORKERS, 
        pin_memory=False,
        drop_last=False
    )

    print(f"\nDataLoaders created (with num_workers={cfg.NUM_WORKERS}).")
    print(f"  Train loader: {len(train_loader)} batches")
    print(f"  Val loader: {len(val_loader)} batches")

    # --- Test fetching one batch ---
    # Smoke test: pull a single batch and report shapes, devices and label stats.
    print("\nFetching one batch from train_loader...")
    try:
        batch_spectrograms, batch_labels = next(iter(train_loader))

        print(f"  Batch spectrograms shape: {batch_spectrograms.shape}")
        print(f"  Batch labels shape: {batch_labels.shape}")
        print(f"  Spectrograms device: {batch_spectrograms.device}") 
        print(f"  Labels device: {batch_labels.device}") 
        print(f"  Example label sum (first item in batch): {batch_labels[0].sum().item()}")
        print(f"  Example label max value: {batch_labels[0].max().item()}")

    except Exception as e:
        # Report the worker count actually in use (the old message always said 0).
        print(f"  Error fetching batch with num_workers={cfg.NUM_WORKERS}: {e}")
        import traceback
        traceback.print_exc()

else:
    print("\nCannot create Datasets/DataLoaders: train_clip or val_clip not available.")

In [None]:
# --- Instantiate the container (if loaders were created successfully) ---
# `dataloaders` ends up as a BirdDataLoaders instance, or None when either loader
# is missing from the notebook namespace.
if 'train_loader' not in globals() or 'val_loader' not in globals():
    print("\nCould not create BirdDataLoaders container as loaders are missing.")
    dataloaders = None # Set to None if creation failed
else:
    # Fall back to auto-detecting the device when the setup cell has not been run.
    if 'device' not in globals():
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Device: {device}")
    dataloaders = BirdDataLoaders(train_loader, val_loader, device=device)
    print("\nBirdDataLoaders container created.")
    print(f"\nNum classes from the container: {dataloaders.num_classes}")

## Loss Function, Optimizer, Metrics and Baseline Model

In [None]:
def baseline_model(num_classes=cfg.N_CLASSES, pretrained=True):
    """Build an EfficientNet-B0 that takes single-channel input.

    The stem convolution is replaced by a 1-input-channel convolution whose
    weights are the mean of the original stem's weights over the input-channel
    axis, and the classifier is replaced by a fresh linear layer.

    Args:
        num_classes: Output size of the new classifier. The default is read from
            cfg.N_CLASSES once, when this function is defined.
        pretrained: Forwarded to `timm.create_model`.

    Returns:
        The modified timm model.
    """
    net = timm.create_model('efficientnet_b0', pretrained=pretrained)

    # --- Stem: swap in a single-channel convolution with the same geometry ---
    source_stem = net.conv_stem
    stem_has_bias = source_stem.bias is not None
    mono_stem = nn.Conv2d(
        in_channels=1,
        out_channels=source_stem.out_channels,
        kernel_size=source_stem.kernel_size,
        stride=source_stem.stride,
        padding=source_stem.padding,
        bias=stem_has_bias,
    )
    # Collapse the input-channel axis (dim=1) by averaging; keepdim leaves it at size 1.
    mono_stem.weight.data = source_stem.weight.data.mean(dim=1, keepdim=True)
    if stem_has_bias:
        mono_stem.bias.data = source_stem.bias.data
    net.conv_stem = mono_stem

    # --- Head: new classifier sized for this task ---
    net.classifier = nn.Linear(net.classifier.in_features, num_classes)
    return net

In [None]:
# Instantiate the model and move it to the target device.
# This cell rebinds the name `baseline_model` from the factory function to the
# model instance. Keep a handle to the factory so the cell can be re-run: on a
# second run `baseline_model` is already an nn.Module and the saved factory is reused;
# after re-running the definition cell it is a function again and the handle refreshes.
if not isinstance(baseline_model, nn.Module):
    baseline_model_factory = baseline_model
baseline_model = baseline_model_factory(num_classes=cfg.N_CLASSES, pretrained=True).to(device)

In [None]:
# Test with a dummy input batch: push random data through the model once and
# check that the output has one score per class for every item in the batch.
try:
    # 313 is a hard-coded width for the last axis.
    # NOTE(review): presumably the number of time frames in a real clip — confirm.
    dummy_batch_cpu = torch.randn(cfg.BATCH_SIZE, cfg.N_MELS, 313)
    # Insert the channel axis expected by the single-channel stem.
    dummy_batch_cpu = dummy_batch_cpu.unsqueeze(1)
    dummy_batch_gpu = dummy_batch_cpu.to(device)
    #Forward pass
    baseline_model.eval() # set to evaluation mode for inference test
    with torch.no_grad():
        outputs = baseline_model(dummy_batch_gpu)
    print(f"Dummy batch input shape (on GPU): {dummy_batch_gpu.shape}")
    print(f"Model output shape: {outputs.shape}")
    assert outputs.shape == (cfg.BATCH_SIZE, cfg.N_CLASSES), "Output mismatch!!"
    print("Forward pass successful!")
except Exception as e:
    # Any failure (including the shape assertion) is reported rather than raised.
    print(f"Error during test pass: {e}")
    import traceback
    traceback.print_exc()

In [None]:
# Binary cross-entropy that applies the sigmoid internally, so the model is
# expected to output raw logits.
loss = nn.BCEWithLogitsLoss()

In [None]:
# AdamW over all model parameters; the learning rate is stored on the config.
cfg.LEARNING_RATE = 1e-3
opt = optim.AdamW(baseline_model.parameters(), lr=cfg.LEARNING_RATE)

In [None]:
## MOVE TO UTILS
def calculate_metrics(preds, labels, threshold=0.5):
    """Compute multi-label validation metrics.

    Args:
        preds: Scores with one row per sample and one column per class. A torch
            tensor is treated as raw logits and passed through a sigmoid; any
            other array-like is used as-is (i.e. assumed to already be scores).
        labels: Binary targets with the same layout as `preds` (tensor or array-like).
        threshold: Score at or above which a class counts as predicted.

    Returns:
        dict with:
            'sample_avg_accuracy': per-sample fraction of correctly thresholded
                classes, averaged over samples.
            'macro_auc': mean ROC AUC over classes that have both positive and
                negative labels; 0.0 if there is no such class.
            'macro_ap': mean average precision over classes that have at least
                one positive label; 0.0 if there is no such class.
    """
    metrics = {}
    if isinstance(preds, torch.Tensor):
        # detach() so tensors that still require grad can be converted to NumPy
        probs = torch.sigmoid(preds.detach()).cpu().numpy()
    else:
        probs = np.array(preds)

    if isinstance(labels, torch.Tensor):
        labels_np = labels.detach().cpu().numpy()
    else:
        labels_np = np.array(labels)

    binary_preds = (probs >= threshold).astype(int)
    sample_acc = (labels_np == binary_preds).mean(axis=1)
    metrics['sample_avg_accuracy'] = np.mean(sample_acc)

    n_classes = labels_np.shape[1]

    # Macro avg. AUC — only classes with both label values present contribute
    class_auc_scores = []
    for i in range(n_classes):
        class_labels = labels_np[:, i]
        class_preds = probs[:, i]
        if len(np.unique(class_labels)) > 1:
            try:
                class_auc_scores.append(roc_auc_score(class_labels, class_preds))
            except ValueError as e:
                print(f"Warning: Could not calculate AUC for class {i}: {e}")
        else: # If only one label value is present, AUC is undefined
            print(f"Warning: AUC undefined for class {i} (only one label value present)")
    metrics['macro_auc'] = np.mean(class_auc_scores) if class_auc_scores else 0.0

    # Macro-averaged Avg. Precision - for multilabel classification
    class_ap_scores = []
    for i in range(n_classes):
        class_labels = labels_np[:, i]
        class_preds = probs[:, i]
        if np.sum(class_labels) > 0: # Need at least one positive sample for AP
            try:
                class_ap_scores.append(average_precision_score(class_labels, class_preds))
            except ValueError:
                # Skip classes for which AP cannot be computed
                pass
    metrics['macro_ap'] = np.mean(class_ap_scores) if class_ap_scores else 0.0

    return metrics

### BASIC TRAINING LOOP

In [None]:
scaler = amp.GradScaler() # testing automatic mixed precision

def train_one_epoch(model, loader, loss, opt, device, scaler):
    """Train `model` for one pass over `loader` using mixed precision.

    Args:
        model: Network to train (switched to train mode here).
        loader: Iterable of (inputs, labels) batches; a channel axis is inserted
            into `inputs` at position 1 before the forward pass.
        loss: Callable returning the batch loss from (outputs, labels).
        opt: Optimizer stepped once per batch.
        device: Device (torch.device or string) batches are moved to.
        scaler: Gradient scaler used for the backward pass and optimizer step.

    Returns:
        Sample-weighted mean loss over the epoch, or NaN if the loader yielded
        no samples.
    """
    model.train()
    total_loss = 0.0
    n_samples = 0
    # Autocast follows the device actually in use and is active on CUDA only, so a
    # CPU run neither requests CUDA autocast nor changes precision.
    device_type = torch.device(device).type
    use_autocast = device_type == 'cuda'
    pbar= tqdm(loader, desc="Training", leave=False)

    for inputs, labels in pbar:
        inputs = inputs.unsqueeze(1)
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        opt.zero_grad()
        with amp.autocast(device_type=device_type, enabled=use_autocast): # amp implementation
            outputs = model(inputs)
            batch_loss = loss(outputs, labels)

        scaler.scale(batch_loss).backward() # Backward pass on the scaled loss
        scaler.step(opt) # Unscales gradients, then steps the optimizer
        scaler.update() # Update scaler for next iteration
        # Accumulate loss, weighted by batch size so a short final batch counts correctly
        batch_size = inputs.size(0)
        total_loss += batch_loss.item() * batch_size
        n_samples += batch_size
        pbar.set_postfix(loss=f"{batch_loss.item():.4f}")

    if n_samples == 0:
        # Empty loader: avoid ZeroDivisionError and signal "no data" explicitly
        return float('nan')
    avg_loss = total_loss / n_samples
    return avg_loss

In [None]:
def validate_one_epoch(model, loader, loss, device):
    """Evaluate `model` on `loader` and compute validation metrics.

    Args:
        model: Network to evaluate (switched to eval mode here).
        loader: Iterable of (inputs, labels) batches; a channel axis is inserted
            into `inputs` at position 1 before the forward pass.
        loss: Callable returning the batch loss from (outputs, labels).
        device: Device (torch.device or string) batches are moved to.

    Returns:
        (avg_loss, metrics): the sample-weighted mean loss and the dict returned
        by `calculate_metrics` over all collected logits and labels. Returns
        (NaN, {}) if the loader yielded no samples.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    # Separate accumulators: the loop variable `labels` must not overwrite them
    all_preds = []
    all_labels = []
    # Autocast follows the device actually in use and is active on CUDA only
    device_type = torch.device(device).type
    use_autocast = device_type == 'cuda'

    with torch.no_grad():
        pbar = tqdm(loader, desc="Validation", leave=False)
        for inputs, labels in pbar:
            inputs = inputs.unsqueeze(1)

            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            with amp.autocast(device_type=device_type, enabled=use_autocast):
                outputs = model(inputs)
                batch_loss = loss(outputs, labels)

            batch_size = inputs.size(0)
            total_loss += batch_loss.item() * batch_size
            n_samples += batch_size
            all_preds.append(outputs.cpu())
            all_labels.append(labels.cpu())
            pbar.set_postfix(loss=f"{batch_loss.item():.4f}")

    if n_samples == 0:
        # Nothing to concatenate or score
        return float('nan'), {}

    avg_loss = total_loss / n_samples
    preds_tensor = torch.cat(all_preds)
    labels_tensor = torch.cat(all_labels)
    # Ensure metrics are calculated using float32 since amp may output float16 logits
    metrics = calculate_metrics(preds_tensor.float(), labels_tensor.float())
    return avg_loss, metrics

In [None]:
def run_training(model, dataloaders, loss, opt, num_epochs, device):
    """Runs the main training loop.

    Each epoch trains on `dataloaders.train`, validates on `dataloaders.valid`,
    records the results, and keeps a copy of the weights with the highest
    validation macro AUC seen so far.

    Args:
        model: Network to train; modified in place.
        dataloaders: Object exposing `.train` and `.valid` loaders.
        loss: Loss callable passed through to the per-epoch functions.
        opt: Optimizer passed through to `train_one_epoch`.
        num_epochs: Number of epochs to run.
        device: Device passed through to the per-epoch functions.

    Returns:
        (model, history): the model, with the best-scoring weights loaded if any
        epoch beat the initial best of 0.0, and a dict of per-epoch lists keyed
        'train_loss', 'val_loss', 'val_macro_auc', 'val_macro_ap' and
        'val_sample_avg_accuracy'.
    """
    # A fresh gradient scaler is created for this run (not the module-level one).
    scaler = amp.GradScaler()
    start_time = time.time()
    best_val_metric = 0.0 
    best_model_state = None
    history = {'train_loss': [], 'val_loss': [], 'val_macro_auc': [], 
               'val_macro_ap': [], 'val_sample_avg_accuracy': []}

    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        print(f"\n--- Epoch {epoch+1}/{num_epochs} ---")

        # --- Training Phase ---
        train_loss = train_one_epoch(model, dataloaders.train, loss, opt, device, scaler) # pass scaler
        history['train_loss'].append(train_loss)
        print(f"Epoch {epoch+1} Train Loss: {train_loss:.4f}")

        # --- Validation Phase ---
        val_loss, metrics = validate_one_epoch(model, dataloaders.valid, loss, device)
        history['val_loss'].append(val_loss)
        # Missing metric keys are recorded as 0.0
        history['val_macro_auc'].append(metrics.get('macro_auc', 0.0))
        history['val_macro_ap'].append(metrics.get('macro_ap', 0.0))
        history['val_sample_avg_accuracy'].append(metrics.get('sample_avg_accuracy', 0.0))

        print(f"Epoch {epoch+1} Val Loss : {val_loss:.4f}")
        print(f"Epoch {epoch+1} Val Metrics: ")
        for k, v in metrics.items():
            print(f"  {k}: {v:.4f}")

        # --- Checkpoint Best Model (based on Macro AUC) ---
        # Strict '>' against an initial 0.0: an epoch must score above zero to be kept.
        current_val_metric = metrics.get('macro_auc', 0.0)
        if current_val_metric > best_val_metric:
            best_val_metric = current_val_metric
            # Use deepcopy to avoid saving reference to changing model
            best_model_state = copy.deepcopy(model.state_dict())
            print(f"*** New best Macro AUC: {best_val_metric:.4f}. Saving model state. ***")

        # --- Epoch Timing ---
        epoch_end_time = time.time()
        print(f"Epoch {epoch+1} Time: {epoch_end_time - epoch_start_time:.2f} seconds")

        # --- Optional: Clear CUDA Cache ---
        torch.cuda.empty_cache()
        gc.collect()

    # --- End of Training ---
    total_time = time.time() - start_time
    print(f"\n--- Training Finished ---")
    print(f"Total Training Time: {total_time // 60:.0f}m {total_time % 60:.0f}s")
    print(f"Best Validation Macro AUC: {best_val_metric:.4f}")

    # Load the best model state back into the model
    # (skipped when no epoch improved on the initial best, leaving the final weights)
    if best_model_state:
         print("Loading best model weights...")
         model.load_state_dict(best_model_state)

    return model, history

In [None]:
# Train the baseline model, then plot the loss curves and the validation macro AUC.
cfg.EPOCHS = 5
# Every component must exist, and `dataloaders` must not be the None placeholder
# that the container cell leaves behind when loader creation failed.
if all(k in globals() for k in ['baseline_model', 'dataloaders', 'loss', 
                                'opt', 'device']) and dataloaders is not None:
    trained_model, train_history = run_training(
        baseline_model,
        dataloaders,
        loss,
        opt,
        num_epochs=cfg.EPOCHS,
        device=device
    )
    # The plotting code previously read an undefined `training_history`;
    # keep that name available as an alias of the actual result.
    training_history = train_history

    # Derive the x-axis from the recorded history so it always matches the data
    epochs = range(1, len(train_history['train_loss']) + 1)
    fig, (ax_loss, ax_auc) = plt.subplots(1, 2, figsize=(12, 4))

    # Plot Losses
    ax_loss.plot(epochs, train_history['train_loss'], 'bo-', label='Training Loss')
    ax_loss.plot(epochs, train_history['val_loss'], 'ro-', label='Validation Loss')
    ax_loss.set_title('Training and Validation Loss')
    ax_loss.set_xlabel('Epochs')
    ax_loss.set_ylabel('Loss (BCEWithLogits)')
    ax_loss.legend()
    ax_loss.grid(True)

    # Plot Macro AUC
    ax_auc.plot(epochs, train_history['val_macro_auc'], 'go-', label='Validation Macro AUC')
    ax_auc.set_title('Validation Macro AUC')
    ax_auc.set_xlabel('Epochs')
    ax_auc.set_ylabel('Macro AUC')
    ax_auc.legend()
    ax_auc.grid(True)

    fig.tight_layout()
    plt.show()

else:
    print("Error: One or more required components (model, dataloaders, criterion, optimizer, device) not found.")
    