# Phase 4.7: Control Experiment - Pushing the Limits with EfficientNet

**Objective:** This notebook represents the final experiment within our CNN-on-spectrograms paradigm. Having determined that data balancing is crucial, we now investigate the impact of the model architecture itself.

Our goal is to test whether a more modern CNN, **EfficientNet-B2**, can outperform our previous ResNet18 model on the same unbalanced dataset, which gives us an estimate of the performance ceiling for this approach.

In [None]:
# ===================================================================
# CELL 1: SETUP
# ===================================================================
from google.colab import drive
drive.mount('/content/drive')

# Install all necessary packages for the entire notebook
!pip install librosa pandas seaborn matplotlib tqdm

Mounted at /content/drive


In [None]:
# ===================================================================
# CELL 2: DATA PREPARATION (UNBALANCED)
# ===================================================================
import os
import random
from sklearn.model_selection import train_test_split

# --- Configuration ---
RAVDESS_PATH = "/content/drive/MyDrive/ser_project/ravdess_data/"
CREMA_D_PATH = "/content/drive/MyDrive/ser_project/crema_d_data/AudioWAV/"

# --- Mappings (6 core emotions) ---
unified_emotion_map = { "neutral": 0, "happy": 1, "sad": 2, "angry": 3, "fearful": 4, "disgust": 5 }
unified_emotion_labels = ["neutral", "happy", "sad", "angry", "fearful", "disgust"]
ravdess_map = { "01": "neutral", "03": "happy", "04": "sad", "05": "angry", "06": "fearful", "07": "disgust" }
crema_d_map = { "NEU": "neutral", "HAP": "happy", "SAD": "sad", "ANG": "angry", "FEA": "fearful", "DIS": "disgust" }

# --- Gather files and labels ---
all_files = []
all_labels_str = []
print("--- GATHERING UNBALANCED FILES ---")
# Process RAVDESS
ravdess_count = 0
for root, dirs, files in os.walk(RAVDESS_PATH):
    for f in files:
        if f.endswith('.wav'):
            try:
                code = f.split("-")[2]
                if code in ravdess_map:
                    all_files.append(os.path.join(root, f)); all_labels_str.append(ravdess_map[code])
                    ravdess_count += 1
            except IndexError: continue
print(f"Found {ravdess_count} relevant files in RAVDESS.")
# Process CREMA-D
crema_d_count = 0
for f in os.listdir(CREMA_D_PATH):
    if f.endswith('.wav'):
        try:
            code = f.split("_")[2]
            if code in crema_d_map:
                all_files.append(os.path.join(CREMA_D_PATH, f)); all_labels_str.append(crema_d_map[code])
                crema_d_count += 1
        except IndexError: continue
print(f"Found {crema_d_count} relevant files in CREMA-D.")
print(f"\nTotal files found across both datasets: {len(all_files)}")

# --- Create final data splits ---
train_val_files, test_files, train_val_labels_str, test_labels_str = train_test_split(
    all_files, all_labels_str, test_size=0.15, random_state=42, stratify=all_labels_str
)
train_files, val_files, train_labels_str, val_labels_str = train_test_split(
    train_val_files, train_val_labels_str, test_size=0.1, random_state=42, stratify=train_val_labels_str
)

print("\n--- DATA SPLITTING COMPLETE ---")
print(f"Training samples: {len(train_files)}")
print(f"Validation samples: {len(val_files)}")
print(f"Test samples: {len(test_files)}")

--- GATHERING UNBALANCED FILES ---
Found 1056 relevant files in RAVDESS.
Found 7442 relevant files in CREMA-D.

Total files found across both datasets: 8498

--- DATA SPLITTING COMPLETE ---
Training samples: 6500
Validation samples: 723
Test samples: 1275
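
The split sizes above follow directly from the two `test_size` values. As a quick sanity check, the sketch below reproduces them, assuming the held-out count is rounded up (which is how scikit-learn's `train_test_split` treats a float `test_size`, as far as we know). `split_counts` is a hypothetical helper and not part of the notebook.

```python
import math

def split_counts(n_samples, test_frac, val_frac):
    """Return (train, val, test) sizes for a two-stage split (hypothetical helper)."""
    n_test = math.ceil(n_samples * test_frac)    # first split: hold out the test set
    n_train_val = n_samples - n_test
    n_val = math.ceil(n_train_val * val_frac)    # second split: carve validation out of the rest
    return n_train_val - n_val, n_val, n_test

n_train, n_val, n_test = split_counts(8498, test_frac=0.15, val_frac=0.1)
print(n_train, n_val, n_test)                    # 6500 723 1275

# Batches per epoch with BATCH_SIZE = 32 (the last partial batch is kept)
print(math.ceil(n_train / 32), math.ceil(n_val / 32))   # 204 23
```

Because the second split is taken from the remaining 85%, the validation set is about 8.5% of all files, not 10%.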


## Part 1: Upgrading the Architecture and Scheduler

To create the strongest possible contender, we make two significant upgrades to our training process:

* **New Architecture:** We replace the `ResNet18` backbone with `EfficientNet-B2`, a more modern architecture known for its high performance and parameter efficiency.
* **New Scheduler:** We use the `OneCycleLR` scheduler, which ramps the learning rate up to a peak (`MAX_LEARNING_RATE`) and then anneals it over the rest of the run, stepping once per batch rather than once per epoch.

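To keep the data identical to the ResNet18 model from the `v6` experiment, we deliberately use the same **unbalanced** dataset. Because the scheduler changes along with the backbone, any difference reflects the combined upgrade rather than the architecture alone.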
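
To make the scheduler's behavior concrete, here is a minimal pure-Python sketch of a two-phase cosine one-cycle schedule. It is an approximation for illustration, not PyTorch's implementation: `one_cycle_lr` is a hypothetical function, and the values `pct_start=0.3`, `div_factor=25` and `final_div_factor=1e4` are assumed to match the `OneCycleLR` defaults. The step count uses the 204 training batches per epoch and 40 epochs from this notebook.

```python
import math

def one_cycle_lr(step, total_steps, max_lr,
                 pct_start=0.3, div_factor=25.0, final_div_factor=1e4):
    """Approximate learning rate at `step` for a two-phase cosine one-cycle schedule."""
    initial_lr = max_lr / div_factor          # starting learning rate
    min_lr = initial_lr / final_div_factor    # learning rate at the very end
    peak_step = pct_start * total_steps - 1   # last step of the warm-up phase
    last_step = total_steps - 1

    def cos_anneal(start, end, pct):
        # Cosine interpolation from `start` (pct = 0) to `end` (pct = 1)
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)

    if step <= peak_step:
        return cos_anneal(initial_lr, max_lr, step / peak_step)
    return cos_anneal(max_lr, min_lr, (step - peak_step) / (last_step - peak_step))

TOTAL_STEPS = 204 * 40   # batches per epoch * epochs
MAX_LR = 5e-4
lrs = [one_cycle_lr(s, TOTAL_STEPS, MAX_LR) for s in range(TOTAL_STEPS)]

peak_index = lrs.index(max(lrs))
print(f"start={lrs[0]:.2e}  peak={max(lrs):.2e} at step {peak_index}  end={lrs[-1]:.2e}")
```

Under these assumptions the learning rate climbs for roughly the first 12 epochs (30% of the run) and decays for the remaining 28, so the largest updates happen early in training.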

In [None]:
# ===================================================================
# CELL 3: THE FINAL PUSH - TRAINING WITH EFFICIENTNET & ONECYCLELR
# ===================================================================
import torch, torch.nn as nn, os, numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, classification_report
from tqdm import tqdm
from torch.optim.lr_scheduler import OneCycleLR
from torchvision import models
from torchvision import transforms

# --- Configuration ---
SPECTROGRAM_PATH = "/content/drive/MyDrive/ser_project/processed_spectrograms_final/"
MAX_LEARNING_RATE = 5e-4; BATCH_SIZE = 32; EPOCHS = 40
CHECKPOINT_BEST_PATH = "/content/drive/MyDrive/ser_project/efficientnet_final_push_best.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu"); print(f"Using device: {device}")
emotion_to_idx = {e: i for i, e in enumerate(unified_emotion_labels)}

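# --- SpecAugment-style masking (approximated with RandomErasing) & Dataset Class ---
# Each RandomErasing fires with probability 0.5 and zeroes one rectangle:
# the first allows moderate aspect ratios (0.2-5.0), the second only strongly elongated ones (0.01-0.2).
# Note: the training loop calls this on a whole batch, so each call draws one rectangle
# that is shared by every sample in that batch.
spec_augment_transform = transforms.Compose([
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.05), ratio=(0.2, 5.0), value=0),
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.08), ratio=(0.01, 0.2), value=0),
])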
class SpectrogramDataset(Dataset):
    def __init__(self, file_paths, labels, target_width=300):
        self.file_paths, self.labels, self.target_width = file_paths, labels, target_width
    def __len__(self): return len(self.file_paths)
    def __getitem__(self, idx):
        filename = os.path.basename(self.file_paths[idx]).replace('.wav', '.npy')
        file_path = os.path.join(SPECTROGRAM_PATH, filename)
        label = self.labels[idx]
        spectrogram = np.load(file_path)
        current_width = spectrogram.shape[1]
        if current_width < self.target_width: spectrogram = np.pad(spectrogram, ((0, 0), (0, self.target_width - current_width)), mode='constant')
        elif current_width > self.target_width: spectrogram = spectrogram[:, :self.target_width]
        spec_min, spec_max = spectrogram.min(), spectrogram.max()
        if spec_max > spec_min: spectrogram = (spectrogram - spec_min) / (spec_max - spec_min)
        spectrogram_3ch = np.stack([spectrogram, spectrogram, spectrogram], axis=0)
        return torch.tensor(spectrogram_3ch, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

# --- Create Datasets & DataLoaders ---
train_labels = [emotion_to_idx[lbl] for lbl in train_labels_str]; val_labels = [emotion_to_idx[lbl] for lbl in val_labels_str]
train_dataset = SpectrogramDataset(train_files, train_labels); val_dataset = SpectrogramDataset(val_files, val_labels)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2); val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# --- Initialize Model: EfficientNet-B2 ---
print("Initializing EfficientNet-B2 model...")
model = models.efficientnet_b2(weights='IMAGENET1K_V1')
num_ftrs = model.classifier[1].in_features; model.classifier[1] = nn.Linear(num_ftrs, len(unified_emotion_labels)); model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=MAX_LEARNING_RATE); criterion = nn.CrossEntropyLoss()
scheduler = OneCycleLR(optimizer, max_lr=MAX_LEARNING_RATE, total_steps=len(train_loader) * EPOCHS)

# --- Training Loop ---
best_val_acc = 0.0
print("Starting final push training...")
for epoch in range(EPOCHS):
    model.train(); running_loss = 0.0
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]"):
        inputs, labels = inputs.to(device), labels.to(device)
        inputs = spec_augment_transform(inputs)
        optimizer.zero_grad(); outputs = model(inputs); loss = criterion(outputs, labels)
        loss.backward(); optimizer.step(); scheduler.step(); running_loss += loss.item() * inputs.size(0)
    train_loss = running_loss / len(train_dataset)

    model.eval(); val_loss = 0.0; correct = 0; total = 0
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs); loss = criterion(outputs, labels); val_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1); total += labels.size(0); correct += (predicted == labels).sum().item()
    val_accuracy = 100 * correct / total; val_loss /= len(val_dataset)
    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_accuracy:.2f}%")

    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy; print(f"🎉 New best validation accuracy: {best_val_acc:.2f}%. Saving model..."); torch.save({'model_state_dict': model.state_dict()}, CHECKPOINT_BEST_PATH)

Using device: cuda
Initializing EfficientNet-B2 model...
Downloading: "https://download.pytorch.org/models/efficientnet_b2_rwightman-c35c1473.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b2_rwightman-c35c1473.pth


100%|██████████| 35.2M/35.2M [00:00<00:00, 194MB/s]


Starting final push training...


Epoch 1/40 [Train]: 100%|██████████| 204/204 [32:20<00:00,  9.51s/it]
Epoch 1/40 [Val]: 100%|██████████| 23/23 [03:33<00:00,  9.28s/it]


Epoch 1/40 | Train Loss: 1.7083 | Val Loss: 1.5713 | Val Acc: 36.93%
🎉 New best validation accuracy: 36.93%. Saving model...


Epoch 2/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.44it/s]
Epoch 2/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 14.96it/s]


Epoch 2/40 | Train Loss: 1.4840 | Val Loss: 1.3698 | Val Acc: 42.88%
🎉 New best validation accuracy: 42.88%. Saving model...


Epoch 3/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 13.76it/s]
Epoch 3/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 15.52it/s]


Epoch 3/40 | Train Loss: 1.2867 | Val Loss: 1.1761 | Val Acc: 54.77%
🎉 New best validation accuracy: 54.77%. Saving model...


Epoch 4/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 13.91it/s]
Epoch 4/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 15.16it/s]


Epoch 4/40 | Train Loss: 1.1004 | Val Loss: 1.1523 | Val Acc: 56.02%
🎉 New best validation accuracy: 56.02%. Saving model...


Epoch 5/40 [Train]: 100%|██████████| 204/204 [00:15<00:00, 13.60it/s]
Epoch 5/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 15.41it/s]


Epoch 5/40 | Train Loss: 0.9716 | Val Loss: 1.1306 | Val Acc: 56.98%
🎉 New best validation accuracy: 56.98%. Saving model...


Epoch 6/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 13.61it/s]
Epoch 6/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 15.10it/s]


Epoch 6/40 | Train Loss: 0.8300 | Val Loss: 1.0954 | Val Acc: 60.44%
🎉 New best validation accuracy: 60.44%. Saving model...


Epoch 7/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 13.84it/s]
Epoch 7/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 14.51it/s]


Epoch 7/40 | Train Loss: 0.7367 | Val Loss: 1.2087 | Val Acc: 56.43%


Epoch 8/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.42it/s]
Epoch 8/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 15.13it/s]


Epoch 8/40 | Train Loss: 0.6510 | Val Loss: 1.3307 | Val Acc: 56.71%


Epoch 9/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 13.78it/s]
Epoch 9/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 14.76it/s]


Epoch 9/40 | Train Loss: 0.5588 | Val Loss: 1.2706 | Val Acc: 59.06%


Epoch 10/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.11it/s]
Epoch 10/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 12.78it/s]


Epoch 10/40 | Train Loss: 0.4778 | Val Loss: 1.5273 | Val Acc: 58.64%


Epoch 11/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.01it/s]
Epoch 11/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 14.62it/s]


Epoch 11/40 | Train Loss: 0.3962 | Val Loss: 1.3352 | Val Acc: 61.27%
🎉 New best validation accuracy: 61.27%. Saving model...


Epoch 12/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.05it/s]
Epoch 12/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 14.90it/s]


Epoch 12/40 | Train Loss: 0.3442 | Val Loss: 1.7864 | Val Acc: 55.33%


Epoch 13/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.45it/s]
Epoch 13/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 15.41it/s]


Epoch 13/40 | Train Loss: 0.3080 | Val Loss: 1.3761 | Val Acc: 59.47%


Epoch 14/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.39it/s]
Epoch 14/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 14.65it/s]


Epoch 14/40 | Train Loss: 0.2447 | Val Loss: 1.5179 | Val Acc: 61.69%
🎉 New best validation accuracy: 61.69%. Saving model...


Epoch 15/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 13.69it/s]
Epoch 15/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 14.89it/s]


Epoch 15/40 | Train Loss: 0.2222 | Val Loss: 1.5635 | Val Acc: 62.10%
🎉 New best validation accuracy: 62.10%. Saving model...


Epoch 16/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 13.75it/s]
Epoch 16/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 14.86it/s]


Epoch 16/40 | Train Loss: 0.2223 | Val Loss: 1.5120 | Val Acc: 60.72%


Epoch 17/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.06it/s]
Epoch 17/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 14.65it/s]


Epoch 17/40 | Train Loss: 0.2230 | Val Loss: 1.4848 | Val Acc: 63.07%
🎉 New best validation accuracy: 63.07%. Saving model...


Epoch 18/40 [Train]: 100%|██████████| 204/204 [00:15<00:00, 13.55it/s]
Epoch 18/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 15.15it/s]


Epoch 18/40 | Train Loss: 0.1361 | Val Loss: 1.7777 | Val Acc: 60.86%


Epoch 19/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.20it/s]
Epoch 19/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 15.29it/s]


Epoch 19/40 | Train Loss: 0.1298 | Val Loss: 1.9125 | Val Acc: 61.00%


Epoch 20/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.13it/s]
Epoch 20/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 14.16it/s]


Epoch 20/40 | Train Loss: 0.1000 | Val Loss: 1.6218 | Val Acc: 64.45%
🎉 New best validation accuracy: 64.45%. Saving model...


Epoch 21/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 13.64it/s]
Epoch 21/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 14.78it/s]


Epoch 21/40 | Train Loss: 0.1189 | Val Loss: 1.6954 | Val Acc: 63.35%


Epoch 22/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.06it/s]
Epoch 22/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 15.15it/s]


Epoch 22/40 | Train Loss: 0.0930 | Val Loss: 1.6291 | Val Acc: 62.38%


Epoch 23/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 13.97it/s]
Epoch 23/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 14.36it/s]


Epoch 23/40 | Train Loss: 0.0880 | Val Loss: 1.7166 | Val Acc: 62.66%


Epoch 24/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.40it/s]
Epoch 24/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 15.20it/s]


Epoch 24/40 | Train Loss: 0.0630 | Val Loss: 1.8655 | Val Acc: 62.79%


Epoch 25/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.40it/s]
Epoch 25/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 14.73it/s]


Epoch 25/40 | Train Loss: 0.0614 | Val Loss: 1.7133 | Val Acc: 64.18%


Epoch 26/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.25it/s]
Epoch 26/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 14.76it/s]


Epoch 26/40 | Train Loss: 0.0514 | Val Loss: 1.7796 | Val Acc: 61.96%


Epoch 27/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.20it/s]
Epoch 27/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 14.12it/s]


Epoch 27/40 | Train Loss: 0.0417 | Val Loss: 1.7562 | Val Acc: 63.90%


Epoch 28/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.36it/s]
Epoch 28/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 15.08it/s]


Epoch 28/40 | Train Loss: 0.0330 | Val Loss: 1.6261 | Val Acc: 65.84%
🎉 New best validation accuracy: 65.84%. Saving model...


Epoch 29/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 13.82it/s]
Epoch 29/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 14.44it/s]


Epoch 29/40 | Train Loss: 0.0226 | Val Loss: 1.7424 | Val Acc: 63.35%


Epoch 30/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.31it/s]
Epoch 30/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 14.78it/s]


Epoch 30/40 | Train Loss: 0.0276 | Val Loss: 1.8138 | Val Acc: 64.18%


Epoch 31/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.09it/s]
Epoch 31/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 14.86it/s]


Epoch 31/40 | Train Loss: 0.0202 | Val Loss: 1.7802 | Val Acc: 64.73%


Epoch 32/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.29it/s]
Epoch 32/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 15.06it/s]


Epoch 32/40 | Train Loss: 0.0170 | Val Loss: 1.8110 | Val Acc: 64.87%


Epoch 33/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.19it/s]
Epoch 33/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 14.61it/s]


Epoch 33/40 | Train Loss: 0.0171 | Val Loss: 1.7431 | Val Acc: 65.42%


Epoch 34/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.35it/s]
Epoch 34/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 15.31it/s]


Epoch 34/40 | Train Loss: 0.0171 | Val Loss: 1.7832 | Val Acc: 64.45%


Epoch 35/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.39it/s]
Epoch 35/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 15.22it/s]


Epoch 35/40 | Train Loss: 0.0154 | Val Loss: 1.7493 | Val Acc: 64.73%


Epoch 36/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.26it/s]
Epoch 36/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 14.15it/s]


Epoch 36/40 | Train Loss: 0.0113 | Val Loss: 1.7236 | Val Acc: 64.87%


Epoch 37/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.41it/s]
Epoch 37/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 15.09it/s]


Epoch 37/40 | Train Loss: 0.0106 | Val Loss: 1.7442 | Val Acc: 65.28%


Epoch 38/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.46it/s]
Epoch 38/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 16.08it/s]


Epoch 38/40 | Train Loss: 0.0099 | Val Loss: 1.7554 | Val Acc: 62.79%


Epoch 39/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.16it/s]
Epoch 39/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 14.57it/s]


Epoch 39/40 | Train Loss: 0.0098 | Val Loss: 1.7880 | Val Acc: 64.18%


Epoch 40/40 [Train]: 100%|██████████| 204/204 [00:14<00:00, 14.47it/s]
Epoch 40/40 [Val]: 100%|██████████| 23/23 [00:01<00:00, 15.31it/s]

Epoch 40/40 | Train Loss: 0.0062 | Val Loss: 1.7633 | Val Acc: 64.73%





## Part 2: The Verdict - ResNet18 vs. EfficientNet-B2

This is the final comparison. We evaluate the best EfficientNet checkpoint on the held-out test split, reporting RAVDESS and CREMA-D separately. The results are then compared directly against both our champion balanced model (v5) and the unbalanced ResNet18 model (v6) to draw a final conclusion on the effectiveness of the CNN-based approach.

In [None]:
# ===================================================================
# CELL 4: FINAL DETAILED EVALUATION OF EFFICIENTNET MODEL
# ===================================================================
import torch, torch.nn as nn, os, numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, classification_report
from tqdm import tqdm
from torchvision import models

# --- Configuration ---
CHECKPOINT_BEST_PATH = "/content/drive/MyDrive/ser_project/efficientnet_final_push_best.pth"
SPECTROGRAM_PATH = "/content/drive/MyDrive/ser_project/processed_spectrograms_final/"
BATCH_SIZE = 32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu"); print(f"Using device: {device}")

# --- Mappings and Dataset Class (redefined so this cell does not depend on Cell 3) ---
unified_emotion_labels = ["neutral", "happy", "sad", "angry", "fearful", "disgust"]
emotion_to_idx = {e: i for i, e in enumerate(unified_emotion_labels)}

class SpectrogramDataset(Dataset):
    def __init__(self, file_paths, labels, target_width=300):
        self.file_paths, self.labels, self.target_width = file_paths, labels, target_width
    def __len__(self): return len(self.file_paths)
    def __getitem__(self, idx):
        filename = os.path.basename(self.file_paths[idx]).replace('.wav', '.npy')
        file_path = os.path.join(SPECTROGRAM_PATH, filename)
        label = self.labels[idx]
        spectrogram = np.load(file_path)
        current_width = spectrogram.shape[1]
        if current_width < self.target_width: spectrogram = np.pad(spectrogram, ((0, 0), (0, self.target_width - current_width)), mode='constant')
        elif current_width > self.target_width: spectrogram = spectrogram[:, :self.target_width]
        spec_min, spec_max = spectrogram.min(), spectrogram.max()
        if spec_max > spec_min: spectrogram = (spectrogram - spec_min) / (spec_max - spec_min)
        spectrogram_3ch = np.stack([spectrogram, spectrogram, spectrogram], axis=0)
        return torch.tensor(spectrogram_3ch, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

# --- Load the Best Trained Model ---
print("Loading best 'Final Push' model (EfficientNet-B2)...")
model = models.efficientnet_b2()
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, len(unified_emotion_labels))
best_checkpoint = torch.load(CHECKPOINT_BEST_PATH, map_location=device)
model.load_state_dict(best_checkpoint['model_state_dict'])
model = model.to(device)
model.eval()

# --- Prepare the Separate Test Sets ---
# Uses the 'test_files' and 'test_labels_str' variables from Cell 2
print("Preparing separate test sets for evaluation...")
test_labels = [emotion_to_idx[lbl] for lbl in test_labels_str]

# Corpus check on a normalized path (lowercase, forward slashes)
def from_corpus(path, corpus_dir):
    return corpus_dir in path.lower().replace('\\', '/')

# Filter for RAVDESS files (files and labels are filtered together so they stay aligned)
ravdess_pairs = [(f, l) for f, l in zip(test_files, test_labels) if from_corpus(f, 'ravdess_data')]
ravdess_test_files = [f for f, _ in ravdess_pairs]
ravdess_test_labels = [l for _, l in ravdess_pairs]

# Filter for CREMA-D files
crema_d_pairs = [(f, l) for f, l in zip(test_files, test_labels) if from_corpus(f, 'crema_d_data')]
crema_d_test_files = [f for f, _ in crema_d_pairs]
crema_d_test_labels = [l for _, l in crema_d_pairs]

# --- Evaluation Function ---
def evaluate(files, labels, name):
    if not files:
        print(f"\nSkipping evaluation for {name}: No test files found.")
        return

    dataset = SpectrogramDataset(files, labels)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
    all_preds, all_true = [], []
    with torch.no_grad():
        for inputs, labs in tqdm(loader, desc=f"Evaluating on {name}"):
            inputs, labs = inputs.to(device), labs.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_true.extend(labs.cpu().numpy())

    accuracy = accuracy_score(all_true, all_preds)
    print(f"\n>>> Accuracy on {name}: {accuracy * 100:.2f}%")
    print(f"Classification Report for {name}:")
    print(classification_report(all_true, all_preds, target_names=unified_emotion_labels, zero_division=0))

# --- Run the Final Evaluations ---
print("\n--- FINAL DETAILED EVALUATION ---")
if ravdess_test_files:
    evaluate(ravdess_test_files, ravdess_test_labels, "RAVDESS Test Set")
if crema_d_test_files:
    evaluate(crema_d_test_files, crema_d_test_labels, "CREMA-D Test Set")

Using device: cuda
Loading best 'Final Push' model (EfficientNet-B2)...
Preparing separate test sets for evaluation...

--- FINAL DETAILED EVALUATION ---


Evaluating on RAVDESS Test Set: 100%|██████████| 5/5 [01:19<00:00, 15.97s/it]



>>> Accuracy on RAVDESS Test Set: 80.88%
Classification Report for RAVDESS Test Set:
              precision    recall  f1-score   support

     neutral       0.76      1.00      0.87        13
       happy       0.71      0.88      0.79        25
         sad       0.88      0.52      0.65        27
       angry       0.87      0.83      0.85        24
     fearful       0.76      0.96      0.85        23
     disgust       0.95      0.79      0.86        24

    accuracy                           0.81       136
   macro avg       0.82      0.83      0.81       136
weighted avg       0.83      0.81      0.80       136



Evaluating on CREMA-D Test Set: 100%|██████████| 36/36 [10:54<00:00, 18.18s/it]


>>> Accuracy on CREMA-D Test Set: 64.97%
Classification Report for CREMA-D Test Set:
              precision    recall  f1-score   support

     neutral       0.69      0.79      0.74       164
       happy       0.63      0.58      0.60       195
         sad       0.58      0.58      0.58       192
       angry       0.70      0.82      0.76       196
     fearful       0.61      0.54      0.58       197
     disgust       0.67      0.61      0.64       195

    accuracy                           0.65      1139
   macro avg       0.65      0.65      0.65      1139
weighted avg       0.65      0.65      0.65      1139
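
The overall accuracy can be cross-checked from the per-class rows, since recall times support gives the number of correctly classified samples in each class. The sketch below does this for the RAVDESS report, using the rounded values printed above. The recovered counts are therefore an inference from two-decimal recalls, not numbers logged by the notebook, and `ravdess_report` is a hypothetical name.

```python
# (recall, support) per class, copied from the RAVDESS classification report
ravdess_report = {
    "neutral": (1.00, 13),
    "happy":   (0.88, 25),
    "sad":     (0.52, 27),
    "angry":   (0.83, 24),
    "fearful": (0.96, 23),
    "disgust": (0.79, 24),
}

# recall * support is approximately the count of correct predictions per class;
# rounding recovers an integer because the printed recall is itself rounded
correct = sum(round(recall * support) for recall, support in ravdess_report.values())
total = sum(support for _, support in ravdess_report.values())
accuracy = correct / total

print(f"{correct}/{total} correct -> {accuracy * 100:.2f}%")
```

The same check is less reliable for CREMA-D, where the larger supports mean a two-decimal recall no longer pins down a unique count.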




