In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
import torch
import torchaudio
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import os
from pathlib import Path
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings("ignore")

In [13]:
# 1. Constants
TARGET_SR = 22050
DURATION_S = 5        # (seconds)
NUM_SAMPLES = TARGET_SR * DURATION_S
N_MELS = 128          # Number of Mel bins (vertical resolution)
N_FFT = 1024
HOP_LENGTH = 512

# Training parameters
BATCH_SIZE = 32
EPOCHS_PHASE_1 = 20
EPOCHS_PHASE_2 = 30
LEARNING_RATE_HEAD = 1e-4
LEARNING_RATE_FINE = 1e-6

# Paths
TRAIN_DIR = Path("/content/drive/MyDrive/decibel_duel/train/train")
TEST_DIR = Path("/content/drive/MyDrive/decibel_duel/test/test")
MODEL_PATH = Path("/content/drive/MyDrive/decibel_duel")

# Get device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

Using device: cuda


In [6]:
# 2. Dataset class

mel_spectrogram_transform = torch.nn.Sequential(
    torchaudio.transforms.MelSpectrogram(
        sample_rate=TARGET_SR,
        n_mels=N_MELS,
        n_fft=N_FFT,
        hop_length=HOP_LENGTH
    ),
    torchaudio.transforms.AmplitudeToDB()
)

class AudioDataset(Dataset):
    def __init__(self, df, base_dir, transform, is_test=False):
        self.df = df
        self.base_dir = base_dir
        self.transform = transform
        self.is_test = is_test

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if self.is_test:
            filename = self.df.iloc[idx]['filename']
            filepath = self.base_dir / filename
            label_idx = -1
        else:
            filename = self.df.iloc[idx]['filepath']
            filepath = self.base_dir / filename
            label_idx = self.df.iloc[idx]['label_idx']

        # 1. Load audio
        waveform, sr = torchaudio.load(filepath)

        # 2. Resample if needed
        if sr != TARGET_SR:
            resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SR)
            waveform = resampler(waveform)

        # 3. Convert to mono
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # 4. Pad or truncate to fixed length
        if waveform.shape[1] < NUM_SAMPLES:
            num_missing = NUM_SAMPLES - waveform.shape[1]
            waveform = torch.nn.functional.pad(waveform, (0, num_missing))
        else:
            waveform = waveform[:, :NUM_SAMPLES]

        # 6. Apply Mel Spectrogram transform
        spectrogram = self.transform(waveform)

        # 7. Normalize (per-sample normalization)
        mean = spectrogram.mean()
        std = spectrogram.std()
        spectrogram = (spectrogram - mean) / (std + 1e-6)

        # 8. Format for EfficientNet (3-channel "image")
        spectrogram = spectrogram.expand(3, -1, -1) # [1, H, W] -> [3, H, W]

        if self.is_test:
            return spectrogram, filename
        else:
            return spectrogram, label_idx

In [7]:
# 3. Load File Paths and Create DataLoaders
print("Loading data...")

all_files = []
class_names = sorted([d.name for d in TRAIN_DIR.iterdir() if d.is_dir()])
class_to_idx = {name: i for i, name in enumerate(class_names)}
idx_to_class = {i: name for name, i in class_to_idx.items()}

for class_name in class_names:
    class_dir = TRAIN_DIR / class_name
    for filepath in class_dir.glob("*.wav"):
        all_files.append((str(filepath.relative_to(TRAIN_DIR)), class_to_idx[class_name]))

all_train_df = pd.DataFrame(all_files, columns=['filepath', 'label_idx'])

# Create stratified train/validation split
train_df, val_df = train_test_split(
    all_train_df,
    test_size=0.15, # 15% validation
    stratify=all_train_df['label_idx'],
    random_state=42
)

# Create Datasets
train_dataset = AudioDataset(train_df, TRAIN_DIR, mel_spectrogram_transform)
val_dataset = AudioDataset(val_df, TRAIN_DIR, mel_spectrogram_transform)

# Create DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

print(f"Data ready: {len(train_dataset)} train, {len(val_dataset)} validation samples.")

Loading data...
Data ready: 2932 train, 518 validation samples.


In [8]:
def train_one_epoch(model, loader, criterion, optimizer, scaler):
    model.train()
    total_loss = 0
    total_correct = 0

    for inputs, labels in tqdm(loader, desc="Training", leave=False):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        with autocast():  # Mixed precision
            outputs = base_model(inputs)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs.data, 1)
        total_correct += (predicted == labels).sum().item()

    avg_loss = total_loss / len(loader.dataset)
    avg_acc = total_correct / len(loader.dataset)
    return avg_loss, avg_acc

In [14]:
def validate_model(model, loader, criterion, scheduler):
    model.eval()
    val_loss = 0
    val_correct = 0

    with torch.no_grad(), autocast():
        for inputs, labels in tqdm(loader, desc="Validating", leave=False):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * inputs.size(0)

            _, predicted = torch.max(outputs.data, 1)
            val_correct += (predicted == labels).sum().item()

    avg_loss = val_loss / len(loader.dataset)
    avg_acc = val_correct / len(loader.dataset)
    scheduler.step(val_loss)
    return avg_loss, avg_acc

In [10]:
class EarlyStopping:
    def __init__(self, patience=10, verbose=True, path='best_model.pth'):
        self.patience = patience
        self.verbose = verbose
        self.path = path
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        drivepath = MODEL_PATH / self.path
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), drivepath)
        self.val_loss_min = val_loss

In [11]:
# --- 4. Load Pretrained Model (EfficientNet) ---
base_model = models.efficientnet_b0(weights="IMAGENET1K_V1")

num_features = base_model.classifier[1].in_features
base_model.classifier[1] = nn.Linear(num_features, len(class_names))

base_model = base_model.to(DEVICE)

# Test with a sample batch to ensure shapes are correct
try:
    data, _ = next(iter(train_loader))
    output = base_model(data.to(DEVICE))
    print(f"Model output shape: {output.shape}") # Should be [BATCH_SIZE, 5]
except Exception as e:
    print(f"Error checking model: {e}")

Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth


100%|██████████| 20.5M/20.5M [00:00<00:00, 168MB/s]


Model output shape: torch.Size([32, 5])


In [15]:
print("\n--- Starting Phase 1: Training Classifier Head ---")

optimizer = torch.optim.Adam(base_model.classifier.parameters(), lr=LEARNING_RATE_HEAD)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
scaler = GradScaler()
criterion = nn.CrossEntropyLoss()
early_stopper = EarlyStopping(path='phase1_best_model.pth')

for epoch in range(EPOCHS_PHASE_1):
    train_loss, train_acc = train_one_epoch(base_model, train_loader, criterion, optimizer,scaler)
    val_loss, val_acc = validate_model(base_model, val_loader, criterion, scheduler)

    print(f"Epoch {epoch+1}/{EPOCHS_PHASE_1} | "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    early_stopper(val_loss, base_model)
    if early_stopper.early_stop:
        print("Early stopping")
        break
print("✅ Phase 1 completed.")


--- Starting Phase 1: Training Classifier Head ---


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 1/20 | Train Loss: 1.4857, Train Acc: 0.4181 | Val Loss: 1.3389, Val Acc: 0.6062
Validation loss decreased (inf --> 1.338904). Saving model...


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
 Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
    File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
      self._shutdown_workers()  
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^^    ^if w.is_alive():^
^ ^ ^ ^ ^ ^  ^^^^
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^    ^assert self._parent_pid == os.getpid(), 'can only test a child process'^
   ^ ^ ^ ^   ^^ 
  File "/usr/

Epoch 2/20 | Train Loss: 1.2679, Train Acc: 0.6310 | Val Loss: 1.1600, Val Acc: 0.7124
Validation loss decreased (1.338904 --> 1.159982). Saving model...


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 3/20 | Train Loss: 1.1180, Train Acc: 0.6951 | Val Loss: 1.0431, Val Acc: 0.7259
Validation loss decreased (1.159982 --> 1.043088). Saving model...


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 4/20 | Train Loss: 1.0204, Train Acc: 0.7190 | Val Loss: 0.9475, Val Acc: 0.7452
Validation loss decreased (1.043088 --> 0.947494). Saving model...


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
Exception ignored in:  
 <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>^   Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
       self._shutdown_workers() 
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    ^if w.is_alive():^
^^ ^ ^ ^ ^  ^^ ^^
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^    assert self._parent_pid == os.getpid(), 'can only test a child process'^^
^ ^   ^ ^ ^^^ 
  File "/usr/

Epoch 5/20 | Train Loss: 0.9460, Train Acc: 0.7329 | Val Loss: 0.8807, Val Acc: 0.7490
Validation loss decreased (0.947494 --> 0.880718). Saving model...


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 6/20 | Train Loss: 0.8843, Train Acc: 0.7466 | Val Loss: 0.8166, Val Acc: 0.7722
Validation loss decreased (0.880718 --> 0.816636). Saving model...


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 7/20 | Train Loss: 0.8304, Train Acc: 0.7589 | Val Loss: 0.7951, Val Acc: 0.7645
Validation loss decreased (0.816636 --> 0.795122). Saving model...


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>Exception ignored in: 
Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()self._shutdown_workers()

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():    
if w.is_alive():
             ^ ^^^^^^^^^^^^^^^^
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
 ^        ^  ^^^^^^^^^^^
  Fil

Epoch 8/20 | Train Loss: 0.8068, Train Acc: 0.7589 | Val Loss: 0.7560, Val Acc: 0.7799
Validation loss decreased (0.795122 --> 0.756009). Saving model...


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 9/20 | Train Loss: 0.7718, Train Acc: 0.7715 | Val Loss: 0.7073, Val Acc: 0.7761
Validation loss decreased (0.756009 --> 0.707269). Saving model...


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 10/20 | Train Loss: 0.7450, Train Acc: 0.7691 | Val Loss: 0.7080, Val Acc: 0.7838
EarlyStopping counter: 1 out of 10


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 11/20 | Train Loss: 0.7015, Train Acc: 0.7967 | Val Loss: 0.6696, Val Acc: 0.7819
Validation loss decreased (0.707269 --> 0.669641). Saving model...


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 12/20 | Train Loss: 0.7035, Train Acc: 0.7855 | Val Loss: 0.6598, Val Acc: 0.7896
Validation loss decreased (0.669641 --> 0.659776). Saving model...


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 13/20 | Train Loss: 0.6682, Train Acc: 0.7967 | Val Loss: 0.6480, Val Acc: 0.7973
Validation loss decreased (0.659776 --> 0.647986). Saving model...


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 14/20 | Train Loss: 0.6692, Train Acc: 0.7947 | Val Loss: 0.6073, Val Acc: 0.8224
Validation loss decreased (0.647986 --> 0.607341). Saving model...


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 15/20 | Train Loss: 0.6426, Train Acc: 0.8025 | Val Loss: 0.6023, Val Acc: 0.8205
Validation loss decreased (0.607341 --> 0.602292). Saving model...


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 16/20 | Train Loss: 0.6430, Train Acc: 0.7998 | Val Loss: 0.6119, Val Acc: 0.8089
EarlyStopping counter: 1 out of 10


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 17/20 | Train Loss: 0.6342, Train Acc: 0.7974 | Val Loss: 0.5849, Val Acc: 0.8243
Validation loss decreased (0.602292 --> 0.584915). Saving model...


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>^^
^^Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^    self._shutdown_workers()^
^^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^    ^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
if w.is_alive():    
assert self._parent_pid == os.getpid(), 'can only test a child process'     
   ^  ^^ ^^ ^ ^  ^^^  ^^ 
^^  File "/u

Epoch 18/20 | Train Loss: 0.6119, Train Acc: 0.8117 | Val Loss: 0.5628, Val Acc: 0.8320
Validation loss decreased (0.584915 --> 0.562806). Saving model...


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 19/20 | Train Loss: 0.6062, Train Acc: 0.8114 | Val Loss: 0.5529, Val Acc: 0.8243
Validation loss decreased (0.562806 --> 0.552909). Saving model...


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Epoch 20/20 | Train Loss: 0.5929, Train Acc: 0.8104 | Val Loss: 0.5473, Val Acc: 0.8320
Validation loss decreased (0.552909 --> 0.547291). Saving model...
✅ Phase 1 completed.


In [17]:
# 5. Load previous model weights
path = MODEL_PATH / "phase1_best_model.pth"
try:
    prev_state = torch.load(path, map_location=DEVICE)
    model_dict = base_model.state_dict()
    pretrained_dict = {k: v for k, v in prev_state.items() if k in model_dict and v.shape == model_dict[k].shape}
    model_dict.update(pretrained_dict)
    base_model.load_state_dict(model_dict)
    print(f"✅ Loaded compatible weights ({len(pretrained_dict)} layers transferred)")
except Exception as e:
    print("⚠️ Could not load previous weights:", e)

# --- 7. Training Setup ---
criterion = nn.CrossEntropyLoss()
scaler = GradScaler()
optimizer = optim.AdamW(base_model.parameters(), lr=LEARNING_RATE_FINE, weight_decay=1e-6)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
early_stopper = EarlyStopping(path="phase2_best_model.pth")

# --- 8. Fine-tuning Loop ---
for epoch in range(EPOCHS_PHASE_2):
    print(f"\n🧠 Epoch {epoch+1}/{EPOCHS_PHASE_2}")

    train_loss, train_acc = train_one_epoch(base_model, train_loader, criterion, optimizer, scaler)
    val_loss, val_acc = validate_model(base_model, val_loader, criterion, scheduler)

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}%")
    early_stopper(val_loss, base_model)
    if early_stopper.early_stop:
        print("Early stopping")
        break
print("✅ Fine-tuning completed.")

✅ Loaded compatible weights (360 layers transferred)

🧠 Epoch 1/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.5884 | Train Acc: 81.21%
Val Loss: 0.5352 | Val Acc: 83.59%
Validation loss decreased (inf --> 0.535226). Saving model...

🧠 Epoch 2/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.5693 | Train Acc: 82.09%
Val Loss: 0.5248 | Val Acc: 84.17%
Validation loss decreased (0.535226 --> 0.524769). Saving model...

🧠 Epoch 3/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.5548 | Train Acc: 82.91%
Val Loss: 0.5079 | Val Acc: 84.36%
Validation loss decreased (0.524769 --> 0.507902). Saving model...

🧠 Epoch 4/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.5512 | Train Acc: 82.71%
Val Loss: 0.5077 | Val Acc: 83.59%
Validation loss decreased (0.507902 --> 0.507715). Saving model...

🧠 Epoch 5/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
      Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>^
Traceback (most recent call last):
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^    ^self._shutdown_workers()^
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^    ^if w.is_alive():^
^ ^ ^ 
   File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
       assert self._parent_pid == os.getpid(), 'can only test a child process'^
^ ^ ^^  ^^  ^^  ^ ^ ^
   File "/usr

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.5317 | Train Acc: 83.80%
Val Loss: 0.4948 | Val Acc: 85.14%
Validation loss decreased (0.507715 --> 0.494799). Saving model...

🧠 Epoch 6/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.5384 | Train Acc: 83.02%
Val Loss: 0.4801 | Val Acc: 85.33%
Validation loss decreased (0.494799 --> 0.480112). Saving model...

🧠 Epoch 7/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Train Loss: 0.5261 | Train Acc: 83.87%
Val Loss: 0.4730 | Val Acc: 85.33%
Validation loss decreased (0.480112 --> 0.472985). Saving model...

🧠 Epoch 8/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.5209 | Train Acc: 83.56%
Val Loss: 0.4684 | Val Acc: 85.33%
Validation loss decreased (0.472985 --> 0.468406). Saving model...

🧠 Epoch 9/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.5156 | Train Acc: 82.91%
Val Loss: 0.4672 | Val Acc: 85.33%
Validation loss decreased (0.468406 --> 0.467201). Saving model...

🧠 Epoch 10/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220> 
Traceback (most recent call last):
   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
        self._shutdown_workers()
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
^    ^if w.is_alive():^
^ ^    ^^ ^ ^^^^^^
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^    ^assert self._parent_pid == os.getpid(), 'can only test a child process'^
 ^^ ^^ 
   File "/usr/lib/py

Train Loss: 0.4963 | Train Acc: 85.10%
Val Loss: 0.4495 | Val Acc: 85.71%
Validation loss decreased (0.467201 --> 0.449548). Saving model...

🧠 Epoch 11/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.4973 | Train Acc: 84.75%
Val Loss: 0.4487 | Val Acc: 86.10%
Validation loss decreased (0.449548 --> 0.448654). Saving model...

🧠 Epoch 12/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.4905 | Train Acc: 85.06%
Val Loss: 0.4499 | Val Acc: 86.49%
EarlyStopping counter: 1 out of 10

🧠 Epoch 13/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.4799 | Train Acc: 85.37%
Val Loss: 0.4290 | Val Acc: 86.10%
Validation loss decreased (0.448654 --> 0.428952). Saving model...

🧠 Epoch 14/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.4710 | Train Acc: 85.91%
Val Loss: 0.4451 | Val Acc: 86.49%
EarlyStopping counter: 1 out of 10

🧠 Epoch 15/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.4714 | Train Acc: 85.30%
Val Loss: 0.4165 | Val Acc: 87.45%
Validation loss decreased (0.428952 --> 0.416497). Saving model...

🧠 Epoch 16/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
 Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
      self._shutdown_workers() 
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
     if w.is_alive():^
^  ^  ^^^ ^ ^ ^^^^^^^^^
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^    ^assert self._parent_pid == os.getpid(), 'can only test a child process'^^
^
  File "/usr/lib/python

Train Loss: 0.4668 | Train Acc: 85.54%
Val Loss: 0.4222 | Val Acc: 86.49%
EarlyStopping counter: 1 out of 10

🧠 Epoch 17/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.4592 | Train Acc: 85.71%
Val Loss: 0.4068 | Val Acc: 87.45%
Validation loss decreased (0.416497 --> 0.406771). Saving model...

🧠 Epoch 18/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.4516 | Train Acc: 86.12%
Val Loss: 0.4097 | Val Acc: 88.22%
EarlyStopping counter: 1 out of 10

🧠 Epoch 19/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.4413 | Train Acc: 85.57%
Val Loss: 0.4102 | Val Acc: 87.84%
EarlyStopping counter: 2 out of 10

🧠 Epoch 20/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^Exception ignored in: ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>^
^^Traceback (most recent call last):
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
^^^^    
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
self._shutdown_workers()
    assert self._parent_pid == os.getpid(), 'can only test a child process'  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    
if w.is_alive(): 
              ^  ^ ^^^^^^^^^^^^^^^^^^^^^^^

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.4388 | Train Acc: 86.49%
Val Loss: 0.3973 | Val Acc: 88.22%
Validation loss decreased (0.406771 --> 0.397290). Saving model...

🧠 Epoch 21/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.4312 | Train Acc: 86.56%
Val Loss: 0.3890 | Val Acc: 88.22%
Validation loss decreased (0.397290 --> 0.389043). Saving model...

🧠 Epoch 22/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.4132 | Train Acc: 87.62%
Val Loss: 0.3853 | Val Acc: 88.42%
Validation loss decreased (0.389043 --> 0.385325). Saving model...

🧠 Epoch 23/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^Exception ignored in: ^^<function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
^  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    
self._shutdown_workers()  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    
assert self._parent_pid == os.getpid(), 'can only test a child process'  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

      if w.is_alive(): 
              ^ ^^^^^^^^^^^^^^^^^^^^^
^ 

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.4232 | Train Acc: 86.39%
Val Loss: 0.3900 | Val Acc: 88.42%
EarlyStopping counter: 1 out of 10

🧠 Epoch 24/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.4074 | Train Acc: 87.59%
Val Loss: 0.3736 | Val Acc: 88.61%
Validation loss decreased (0.385325 --> 0.373601). Saving model...

🧠 Epoch 25/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.4189 | Train Acc: 86.83%
Val Loss: 0.3755 | Val Acc: 89.19%
EarlyStopping counter: 1 out of 10

🧠 Epoch 26/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__

Traceback (most recent call last):
      File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
 ^self._shutdown_workers()    
self._shutdown_workers()  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():    
if w.is_alive():
           ^^  ^^ ^^^^^^^^^^^^^^^^^
^  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
^^    
assert self._parent_pid == os.getpid(), 'can only test a child process'
  File "/usr/lib/pyth

Train Loss: 0.4002 | Train Acc: 87.86%
Val Loss: 0.3720 | Val Acc: 89.19%
Validation loss decreased (0.373601 --> 0.372000). Saving model...

🧠 Epoch 27/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.3978 | Train Acc: 87.76%
Val Loss: 0.3616 | Val Acc: 88.80%
Validation loss decreased (0.372000 --> 0.361584). Saving model...

🧠 Epoch 28/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.3927 | Train Acc: 88.03%
Val Loss: 0.3594 | Val Acc: 89.00%
Validation loss decreased (0.361584 --> 0.359400). Saving model...

🧠 Epoch 29/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.4035 | Train Acc: 87.35%
Val Loss: 0.3558 | Val Acc: 89.38%
Validation loss decreased (0.359400 --> 0.355850). Saving model...

🧠 Epoch 30/30


Training:   0%|          | 0/92 [00:00<?, ?it/s]

Validating:   0%|          | 0/17 [00:00<?, ?it/s]

Train Loss: 0.3864 | Train Acc: 88.23%
Val Loss: 0.3498 | Val Acc: 89.19%
Validation loss decreased (0.355850 --> 0.349820). Saving model...
✅ Fine-tuning completed.


In [18]:
# --- 6. Inference and Submission ---
print("\n--- Starting Inference on Test Set ---")

# Load the best model saved
path = MODEL_PATH / "phase2_best_model.pth"
base_model.load_state_dict(torch.load(path))
base_model.eval()

# Create test dataset and loader
test_files = [f.name for f in TEST_DIR.glob("*.wav")]
test_df = pd.DataFrame(test_files, columns=['filename'])
test_dataset = AudioDataset(test_df, TEST_DIR, mel_spectrogram_transform, is_test=True)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2
)

predictions = []
filenames_list = []

with torch.no_grad():
    for inputs, filenames in tqdm(test_loader, desc="Predicting"):
        inputs = inputs.to(DEVICE)

        outputs = base_model(inputs)

        # Get the predicted class index
        _, predicted_indices = torch.max(outputs.data, 1)

        # Store predictions and filenames
        predictions.extend(predicted_indices.cpu().numpy())
        filenames_list.extend(filenames)

# Map indices back to class names
predicted_labels = [idx_to_class[idx] for idx in predictions]

# Create submission DataFrame
submission_df = pd.DataFrame({
    'ID': filenames_list,
    'Class': predicted_labels
})

# Save to CSV
submission_df.to_csv('submission.csv', index=False)
print("\nSubmission file created successfully: 'submission.csv'")
print(submission_df.head())


--- Starting Inference on Test Set ---


Predicting:   0%|          | 0/24 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7c0b9a398220>

  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
Traceback (most recent call last):
      File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    assert self._parent_pid == os.getpid(), 'can only test a child process'self._shutdown_workers()

   File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
      if w.is_alive(): 
             ^ ^^^^^^^^^^^^^^^^^^^^^^^^


Submission file created successfully: 'submission.csv'
                  ID          Class
0  108357-9-0-15.wav   street_music
1  113601-9-0-22.wav   street_music
2  106015-5-0-15.wav  engine_idling
3   112075-5-0-1.wav       drilling
4   105289-8-0-1.wav          siren
