# Music Spectrogram Classification (Custom CNN Implementation)

This notebook implements an **end-to-end training + evaluation + submission** pipeline for music spectrogram classification, using a CNN built entirely **without `torch.nn` high-level layers**. All layers (convolution, pooling, activations, batch norm, dropout, linear) are custom-implemented.

---

## Model Architecture

### Custom Layer Implementations
All layers built manually using only PyTorch primitives:
- **`CustomConv2d`**: Convolution via `unfold + batch matmul` (no `F.conv2d`)
- **`CustomMaxPool2d`**: Max pooling via `unfold + max`
- **`CustomBatchNorm2d`**: Batch normalization with running statistics
- **`CustomLinear`**: Fully connected layer with manual matrix multiplication
- **`CustomDropout`**: Dropout with manual masking
- **`leaky_relu`**: LeakyReLU activation function
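The `unfold + batch matmul` idea behind `CustomConv2d` can be checked on its own. The sketch below is a NumPy re-implementation for illustration only. The helpers `im2col`, `conv2d_unfold` and `conv2d_naive` are hypothetical and not part of the notebook. It unfolds zero-padded patches into channel-major columns (the same layout `F.unfold` produces), multiplies them by the flattened kernel, and compares the result against a direct sliding-window loop.

```python
import numpy as np

def im2col(x, k, stride=1):
    """Unfold (B, C, H, W) into (B, C*k*k, L) columns, channel-major like F.unfold."""
    B, C, H, W = x.shape
    Hout = (H - k) // stride + 1
    Wout = (W - k) // stride + 1
    cols = np.empty((B, C, k, k, Hout, Wout), dtype=x.dtype)
    for i in range(k):
        for j in range(k):
            # Kernel offset (i, j) sampled at every output position
            cols[:, :, i, j] = x[:, :, i:i + stride * Hout:stride, j:j + stride * Wout:stride]
    return cols.reshape(B, C * k * k, Hout * Wout), Hout, Wout

def conv2d_unfold(x, weight, bias, padding=1, stride=1):
    """Convolution as one matrix product per sample: (out_c, C*k*k) @ (C*k*k, L)."""
    out_c, _, k, _ = weight.shape
    xp = np.pad(x, ((0, 0), (0, 0), (padding, padding), (padding, padding)))
    cols, Hout, Wout = im2col(xp, k, stride)
    out = weight.reshape(out_c, -1) @ cols       # broadcasts over the batch -> (B, out_c, L)
    out = out + bias.reshape(1, out_c, 1)
    return out.reshape(x.shape[0], out_c, Hout, Wout)

def conv2d_naive(x, weight, bias, padding=1, stride=1):
    """Reference: explicit sliding-window cross-correlation."""
    out_c, _, k, _ = weight.shape
    xp = np.pad(x, ((0, 0), (0, 0), (padding, padding), (padding, padding)))
    B, _, H, W = xp.shape
    Hout = (H - k) // stride + 1
    Wout = (W - k) // stride + 1
    out = np.zeros((B, out_c, Hout, Wout), dtype=x.dtype)
    for b in range(B):
        for o in range(out_c):
            for r in range(Hout):
                for c in range(Wout):
                    patch = xp[b, :, r * stride:r * stride + k, c * stride:c * stride + k]
                    out[b, o, r, c] = np.sum(patch * weight[o]) + bias[o]
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 8, 8))
w = rng.standard_normal((4, 3, 3, 3))
b = rng.standard_normal(4)
y_fast = conv2d_unfold(x, w, b)
y_ref = conv2d_naive(x, w, b)
print(y_fast.shape, np.allclose(y_fast, y_ref))  # (2, 4, 8, 8) True
```

With `padding=1` and a 3×3 kernel the spatial size is preserved, which is why every `ConvBlock` leaves downsampling to the max-pool.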

### Triple-Branch CNN Architecture
```
Input: 3 images (96×96×3 each)
    ↓
Branch 1, 2, 3 (independent):
  - Block 1: Conv(3→16) → BN (batch normalization) → LeakyReLU → MaxPool → Dropout(0.00)
  - Block 2: Conv(16→32) → BN → LeakyReLU → MaxPool → Dropout(0.05)
  - Block 3: Conv(32→48) → BN → LeakyReLU → MaxPool → Dropout(0.08)
  - Block 4: Conv(48→64) → BN → LeakyReLU → MaxPool → Dropout(0.10)
  - Block 5: Conv(64→64) → BN → LeakyReLU → Dropout(0.10)
  - Global Average Pooling → 64 features
    ↓
Concatenate: 3 × 64 = 192 features
    ↓
Fusion Head:
  - Linear(192→128) → LeakyReLU → Dropout(0.30)
  - Linear(128→16) → Logits
```
**Key Features**:
- Progressive dropout schedule (0% → 10% through depth)
- Global average pooling for parameter efficiency
- ~278K parameters (well under 500K limit)
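The ~278K figure can be reproduced by hand from the layer shapes in the diagram. This is a minimal sketch, assuming what the notebook's `ConvBlock` and fusion head use: 3×3 kernels with a bias in every conv, two learnable parameters (γ, β) per BN channel, and biased linear layers. The helper names are hypothetical.

```python
def conv_params(c_in, c_out, k=3):
    return c_in * c_out * k * k + c_out   # weights + bias

def bn_params(c):
    return 2 * c                          # gamma + beta (running stats are buffers, not parameters)

def linear_params(f_in, f_out):
    return f_in * f_out + f_out           # weights + bias

# (in_channels, out_channels) for blocks 1-5 of one branch
channels = [(3, 16), (16, 32), (32, 48), (48, 64), (64, 64)]
branch = sum(conv_params(ci, co) + bn_params(co) for ci, co in channels)
head = linear_params(3 * 64, 128) + linear_params(128, 16)
total = 3 * branch + head
print(f"per branch: {branch:,}  head: {head:,}  total: {total:,}")
# per branch: 84,048  head: 26,768  total: 278,912
```

About 90% of the budget sits in the three convolutional branches. The fusion head stays small because each branch contributes only 64 pooled features.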

---

## Experimentation History

### Evolution to Final Architecture

**Experiment 1: Initial Baseline (Macro F1 = 0.4893)**
- **Image size**: 64×64
- **Activation**: ReLU
- **Epochs**: 10
- **Architecture**: Triple-branch CNN with basic blocks
- **Regularization**: None
- **Results**: Poor performance, underfitting
- **Key Issues**: 
  - Small image size lost important spectral details
  - ReLU caused dying neuron problems
  - Insufficient training time
  - No regularization led to poor generalization

**Experiment 2: Improved Architecture (Macro F1 ≈ 0.876)**
- **Image size**: 96×96 (increased from 64×64)
- **Activation**: LeakyReLU (replaced ReLU)
- **Epochs**: 30 (increased from 10)
- **Regularization**: Added
  - Batch normalization after each conv layer
  - Label smoothing (0.1)
  - Progressive dropout (0.0 → 0.1)
  - Weight decay (1e-4)
- **Results**: Major improvement in performance
- **Key Improvements**:
  - Larger images preserved spectral details
  - LeakyReLU prevented dying neurons
  - More epochs allowed better convergence
  - Regularization improved generalization

**Experiment 3: Extended Training (Macro F1 = 0.89)**
- **Image size**: 96×96
- **Activation**: LeakyReLU
- **Epochs**: 40 
- **Regularization**: Same as Experiment 2
- **Additional changes**:
  - Added OneCycleLR scheduler
  - Gradient clipping (1.0)
- **Results**: Further improvement from extended training
- **Key Insight**: Model still improving, not saturated

**Experiment 4: Final Architecture (Macro F1 ≈ 0.90)** ✓
- **Image size**: 96×96
- **Activation**: LeakyReLU
- **Epochs**: 50 
- **Optimization enhancements**:
  - OneCycleLR (3e-4 → 3e-3 → 3e-6, from `max_lr=3e-3`, `div_factor=10`, `final_div_factor=100`)
  - EMA weight averaging (decay=0.999)
- **Regularization**:
  - Progressive dropout (0.0 → 0.1 in conv, 0.3 in FC)
  - Batch normalization
  - Weight decay (1e-4)
- **Results**: Best performance achieved

---

## Data Preprocessing Pipeline

- **Dataset**: Reads `train/metadata.csv` and `test/metadata.csv`
- **Triple-input format**: Loads **three images per sample** (`input_1`, `input_2`, `input_3`)
- **Image size**: Resizes to **96×96**
- **Augmentation**: Spectrogram-safe transformations only
  - Horizontal flip (p=0.5): reverses the time axis only, leaving the frequency axis intact
  - SpecAugment (one time mask and one frequency mask of up to 12 pixels each, applied with p=0.7)
  - **No vertical flips or rotations** (would destroy frequency structure)

---

## Training Details

### Hyperparameters
| Parameter | Value | Description |
|-----------|-------|-------------|
| **Image Size** | 96×96 | Input resolution |
| **Batch Size** | 32 | Training batch size |
| **Epochs** | 50 | Total training epochs |
| **Optimizer** | Adam | With weight decay |
| **Learning Rate** | 3e-4 → 3e-3 → 3e-6 | OneCycleLR (`max_lr=3e-3`, `div_factor=10`, `final_div_factor=100`) |
| **Weight Decay** | 1e-4 | L2 regularization |
| **Label Smoothing** | 0.1 | Prevents overconfidence |
| **MixUp Alpha** | 0.3 | Data mixing augmentation |
| **Gradient Clip** | 1.0 | Max gradient norm |
| **Loss Function** | Cross-entropy | With label smoothing, combined across each MixUp pair |

### Training Strategy
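- **Loss Function**: Label-smoothed cross-entropy. With MixUp, the per-sample loss is `λ·CE(logits, y) + (1−λ)·CE(logits, y[idx])`
- **LR Scheduler**: `OneCycleLR` stepped **per batch** (it overrides the `LR=1e-3` passed to Adam)
  - Warmup: 30% of training (`pct_start=0.3`)
  - Initial LR: 3e-4 (`MAX_LR / div_factor`, with `div_factor=10`)
  - Max LR: 3e-3
  - Final LR: 3e-6 (initial LR / `final_div_factor`, with `final_div_factor=100`)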
- **Gradient Clipping**: Max norm 1.0 for stability
- **EMA**: Exponential moving average (decay=0.999) for smoother weights
- **Model Selection**: Best model saved based on **validation macro-F1**
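`OneCycleLR` derives its endpoints from `max_lr` alone. The PyTorch documentation defines `initial_lr = max_lr / div_factor` and `min_lr = initial_lr / final_div_factor`. The sketch below is a stand-alone re-implementation of the default two-phase cosine schedule for illustration. `one_cycle_lr` is hypothetical, `STEPS_PER_EPOCH = 100` is a placeholder (the real value is `len(loader_train)`), and PyTorch's scheduler also cycles Adam's β₁ by default, which this sketch omits.

```python
import math

def one_cycle_lr(step, total_steps, max_lr=3e-3, pct_start=0.3,
                 div_factor=10.0, final_div_factor=100.0):
    """Two-phase cosine one-cycle schedule: warm up to max_lr, then anneal to min_lr."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = pct_start * total_steps - 1

    def cos_anneal(start, end, pct):
        # pct=0 -> start, pct=1 -> end, cosine-shaped in between
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)

    if step <= warmup_end:
        return cos_anneal(initial_lr, max_lr, step / warmup_end)
    return cos_anneal(max_lr, min_lr, (step - warmup_end) / (total_steps - 1 - warmup_end))

EPOCHS, STEPS_PER_EPOCH = 50, 100   # STEPS_PER_EPOCH is a placeholder
total = EPOCHS * STEPS_PER_EPOCH
lrs = [one_cycle_lr(s, total) for s in range(total)]
print(f"start={lrs[0]:.1e}  peak={max(lrs):.1e}  end={lrs[-1]:.1e}")
# start=3.0e-04  peak=3.0e-03  end=3.0e-06
```

The endpoints do not depend on `STEPS_PER_EPOCH`. Only the position of the peak (30% of all batches) and the shape of the curve do.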

### Regularization Techniques
1. Progressive dropout (0.0 → 0.10 in conv blocks, 0.30 in FC)
2. Label smoothing (0.1)
3. MixUp augmentation (α=0.3)
4. SpecAugment (time/freq masking)
5. Weight decay (1e-4)
6. Batch normalization (implicit regularization)
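Label smoothing and MixUp combine cleanly because cross-entropy is linear in the target distribution. The MixUp loss used here, `λ·CE(logits, y) + (1−λ)·CE(logits, y[idx])`, equals a single cross-entropy against the mixed, smoothed soft target. The following is a small NumPy check of that identity. The helper names are hypothetical, and smoothing follows the `F.cross_entropy` convention of spreading ε uniformly over all K classes (target = (1−ε)·onehot + ε/K).

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=1, keepdims=True)          # numerical stability
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def smoothed_targets(y, num_classes, eps):
    t = np.full((len(y), num_classes), eps / num_classes)
    t[np.arange(len(y)), y] += 1.0 - eps
    return t

def soft_ce(logits, targets):
    return -(targets * log_softmax(logits)).sum(axis=1).mean()

rng = np.random.default_rng(0)
K, eps, lam = 16, 0.1, 0.7
logits = rng.standard_normal((8, K))
y = rng.integers(0, K, size=8)
perm = rng.permutation(8)                         # plays the role of idx in mixup_batch

# MixUp as in mixup_loss: weight the two label-smoothed losses
loss_pair = (lam * soft_ce(logits, smoothed_targets(y, K, eps))
             + (1 - lam) * soft_ce(logits, smoothed_targets(y[perm], K, eps)))
# Equivalent: one loss against the mixed soft target
mixed = lam * smoothed_targets(y, K, eps) + (1 - lam) * smoothed_targets(y[perm], K, eps)
loss_soft = soft_ce(logits, mixed)
print(np.isclose(loss_pair, loss_soft))  # True
```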

---

## Evaluation Strategy and Metrics

After training completes, the notebook prints comprehensive validation metrics:

### Reported Metrics
- ✅ **Total trainable parameters** (enforces ≤500K constraint)
- ✅ **Accuracy**: Overall classification accuracy
- ✅ **Macro-Precision**: Unweighted mean of per-class precision over the 16 classes
- ✅ **Macro-Recall**: Unweighted mean of per-class recall over the 16 classes
- ✅ **Macro-F1**: Unweighted mean of per-class F1; primary metric for model selection
- ✅ **Confusion Matrix**: 16×16 matrix (rows = true class, columns = predicted class)
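The macro metrics are unweighted means of per-class scores, so they can be read directly off the confusion matrix. This is a NumPy sketch assuming scikit-learn's orientation (rows = true class, columns = predicted class), with 0/0 treated as 0 to mirror `zero_division=0`. `macro_scores` is a hypothetical helper, and the 3-class matrix is a toy example.

```python
import numpy as np

def macro_scores(cm):
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    pred_pos = cm.sum(axis=0)   # column sums: how often each class was predicted
    true_pos = cm.sum(axis=1)   # row sums: how often each class actually occurs
    precision = np.divide(tp, pred_pos, out=np.zeros_like(tp), where=pred_pos > 0)
    recall = np.divide(tp, true_pos, out=np.zeros_like(tp), where=true_pos > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    accuracy = tp.sum() / cm.sum()
    return accuracy, precision.mean(), recall.mean(), f1.mean()

# Toy 3-class confusion matrix
cm = [[5, 1, 0],
      [2, 3, 1],
      [0, 0, 4]]
acc, macro_p, macro_r, macro_f1 = macro_scores(cm)
print(", ".join(f"{v:.4f}" for v in (acc, macro_p, macro_r, macro_f1)))
# 0.7500, 0.7548, 0.7778, 0.7527
```

Each class counts equally regardless of its size, so a weak minority class pulls Macro-F1 down more than it pulls accuracy down.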

### Validation Strategy
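- **Split**: 15% stratified hold-out from the training metadata (`random_state=SEED`)
- **Evaluation**: EMA weights are swapped in for validation each epoch, then the live training weights are restored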

---

## Hyperparameter Tuning Strategy

We tuned hyperparameters using a **controlled, validation-driven process** with the goal of maximising **validation Macro-F1** on a fixed stratified split, while keeping the model within the **500K parameter constraint**.

### Fixed Protocol for Fair Comparison

- **Same stratified split every run**: `train_test_split` with fixed random state
- **Same random seed**: `SEED=42` for Python/NumPy/PyTorch
- **Primary selection metric**: Validation Macro-F1 (not accuracy)
- **EMA evaluation**: Model selection based on Macro-F1 evaluated using EMA-smoothed weights (`EMA.apply_to(model)` during validation)
- **One change group at a time**: Tune one group (LR, resolution, regularisation, augmentation) while keeping others fixed

---

### Stage A — Optimisation Stability (Learning Rate + OneCycleLR)

Because training uses Adam + OneCycleLR, we tuned the schedule primarily via `MAX_LR`, keeping the schedule-shape constants (`pct_start`, `final_div_factor`) fixed unless needed:

**Tuned (coarse → fine):**
- `MAX_LR ∈ {1e-3, 2e-3, 3e-3}`
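- Corresponding base LR values: `LR ∈ {3e-4, 7e-4, 1e-3}` (note that `OneCycleLR` replaces the optimizer's `lr` with `MAX_LR / div_factor`, so the effective schedule is set by `MAX_LR`, `div_factor` and `final_div_factor` rather than by `LR`)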

**If training showed late-epoch instability/metric drop:**
- Reduced `LR`: 1e-3 → 7e-4 → 5e-4
- Reduced `MAX_LR`: 3e-3 → 2e-3 → 1.5e-3

**Scheduler shape (tuned only if needed):**
- `pct_start ∈ {0.2, 0.3, 0.4}`
- `final_div_factor ∈ {50, 100, 200}`

**Selection rule:** Choose the setting with the best EMA Macro-F1 and smooth training (no degradation), then confirm in a full 50-epoch run.

### Stage B — Input Resolution (IMG_SIZE) with Batch Size Scaling

Since the model uses global average pooling in each branch (`global_avg_pool`), increasing `IMG_SIZE` does not increase FC input size. We tested:

**Resolutions tested:**
- `IMG_SIZE ∈ {96, 112, 128}`

**Adjusted BATCH_SIZE for GPU memory:**
- 96 → 32
- 112 → 24–32
- 128 → 16–24

**Selection rule:** Pick the resolution that improves Macro-F1 **without overfitting**, i.e. reject a setting if validation Macro-F1 plateaus or drops while training loss keeps decreasing.
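The classifier head is unaffected by `IMG_SIZE` because of how the branches downsample. Each branch halves the spatial size four times (blocks 1 to 4), and block 5 keeps it, since the padded 3×3 convs preserve size. The final map therefore grows with resolution, but global average pooling always reduces it to 64 values per branch. The NumPy sketch below uses random arrays as stand-ins for the block-5 feature maps. `final_map_size` is a hypothetical helper, and its floor division mirrors `CustomMaxPool2d` with `k=2, stride=2`.

```python
import numpy as np

def final_map_size(img_size, num_pools=4):
    s = img_size
    for _ in range(num_pools):
        s = (s - 2) // 2 + 1    # CustomMaxPool2d output size with k=2, stride=2, no padding
    return s

rng = np.random.default_rng(0)
for img_size in (96, 112, 128):
    s = final_map_size(img_size)
    fmap = rng.standard_normal((2, 64, s, s))   # stand-in for one branch's block-5 output (B, C, H, W)
    pooled = fmap.mean(axis=(2, 3))             # global average pooling -> (B, 64)
    fused_dim = 3 * pooled.shape[1]             # three branches concatenated
    print(img_size, (s, s), pooled.shape, fused_dim)
# 96 (6, 6) (2, 64) 192
# 112 (7, 7) (2, 64) 192
# 128 (8, 8) (2, 64) 192
```

Higher resolutions therefore cost compute and memory (hence the smaller batch sizes) but add no parameters.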


### Stage C — Regularisation (Generalisation Control)

We tuned regularisation to improve Macro-F1 (especially on hard classes) and reduce overfitting:

**Parameters tuned:**
- `WD ∈ {5e-5, 1e-4, 2e-4}` (weight decay)
- `LABEL_SMOOTH ∈ {0.05, 0.10, 0.15}` (applied inside `F.cross_entropy(..., label_smoothing=...)`)
- `GRAD_CLIP ∈ {0.5, 1.0}` (keep 1.0 unless gradients spike)

**Selection rule:** Choose the combination that improves Macro-F1 and macro-recall, **not just accuracy**.


### Stage D — Augmentation and Robustness (SpecAugment + MixUp + EMA)

We tuned augmentation strength because the inputs are spectrogram-like images, where robustness to masked time and frequency bands and to mixed examples is expected to help generalization:

**SpecAugment (implemented in dataset):**
- `(time_mask, freq_mask) ∈ {(12,12), (16,16)}`
- `p ∈ {0.5, 0.7, 0.85}`

**MixUp:**
- `MIX_ALPHA ∈ {0.2, 0.3, 0.4}`

**EMA decay:**
- `decay ∈ {0.998, 0.999}`

**Selection rule:** Prioritise the config that reduces off-diagonal errors in the validation confusion matrix and raises Macro-F1 **consistently across epochs**.
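A common rule of thumb for the EMA decay is that decay d averages over roughly 1/(1−d) recent updates. That works out to about 500 updates for 0.998 and 1,000 for 0.999. The sketch below applies the same rule as the notebook's `EMA.update` (`shadow = d·shadow + (1−d)·w`) to a scalar weight that jumps from 0 to 1. It counts how many updates the shadow needs to close 63% of the gap (1 − 1/e, the usual time-constant definition). `steps_to_track` is a hypothetical helper for illustration.

```python
import math

def steps_to_track(decay, target=1 - 1 / math.e):
    """Updates until the EMA of a weight that jumps 0 -> 1 reaches `target`."""
    shadow, steps = 0.0, 0
    while shadow < target:
        shadow = decay * shadow + (1.0 - decay) * 1.0   # same rule as EMA.update
        steps += 1
    return steps

for d in (0.998, 0.999):
    print(d, steps_to_track(d), round(1 / (1 - d)))
# 0.998 500 500
# 0.999 1000 1000
```

`EMA.update` runs once per batch, so the number of epochs this window spans depends on `len(loader_train)`.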

---

## Inference & Submission

### Test Set Processing
1. Loads test metadata from `test/metadata.csv`
2. Applies same preprocessing as validation (**no augmentation**)
3. Runs inference using the best EMA checkpoint
4. Generates predictions in **exact metadata order**

### Submission File Creation
- **Format**: CSV with columns `id,target`
- **Output**: `/kaggle/working/submission96.csv`
- **Ordering**: Preserves original test set order

---

## Results on the Validation Set

### Validation Performance
- **Total trainable parameters**: 278,912
- **Macro-F1**:  0.927060
- **Accuracy**:  0.932498
- **Macro-Precision**: 0.929580
- **Macro-Recall**: 0.925461
- **Training Time**: ~2 hours on Kaggle GPU (50 epochs)

### Model Characteristics
- **Architecture**: Triple-branch CNN with late fusion
- **Parameters**: 278,912
- **Regularization**: Progressive dropout + MixUp + Augmentation + EMA
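- **Efficiency**: Global average pooling collapses each branch's final 64×6×6 map to 64 values, so the fusion head takes 192 inputs instead of the 6,912 that flattening would give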

---

## Team Contributions

**Group 19**

| Name | Email | Contributions |
|------|-------|--------------|
| Sontam Deekshitha | sontamd22@iitk.ac.in | 14.29% |
| Mahathi Garapati | gnagal22@iitk.ac.in | 14.29% |
| Maradana Kasi Sri Roshan | mkasi22@iitk.ac.in | 14.29% |
| Chintapudi Gowtham Chand | cgchand22@iitk.ac.in | 14.29% |
| Krishna Kumar Bais | krishnakb24@iitk.ac.in | 14.29% |
| Sevak Baliram Shekokar | bssevak24@iitk.ac.in | 14.29% |
| Daksh Kumar Singh | dakshks22@iitk.ac.in | 14.29%|

**Team Coordination**: All members of the team contributed equally.

---

```python
import os, math, random, warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from tqdm import tqdm
```

## Reproducibility, Device, and Dataset Path Detection

```python
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Choose device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

def find_dataset_music_root():
    """
    Automatically finds the dataset root directory in Kaggle's input folder.
    Handles different possible folder structures.
    """
    base = "/kaggle/input"
    if not os.path.exists(base):
        raise FileNotFoundError("/kaggle/input not found. Are you running on Kaggle?")

    candidates = []
    for d in os.listdir(base):
        p = os.path.join(base, d)
        if not os.path.isdir(p):
            continue

        # Check if dataset_music subfolder exists
        cand1 = os.path.join(p, "dataset_music")
        if os.path.exists(os.path.join(cand1, "train", "metadata.csv")) and os.path.exists(os.path.join(cand1, "test", "metadata.csv")):
            candidates.append(cand1)

        # Check if train/test folders are directly in this directory
        if os.path.exists(os.path.join(p, "train", "metadata.csv")) and os.path.exists(os.path.join(p, "test", "metadata.csv")):
            candidates.append(p)

    if len(candidates) == 0:
        # Debug: print what's actually in /kaggle/input
        print("Could not auto-detect dataset root. Contents of /kaggle/input:")
        for d in os.listdir(base):
            print(" -", d)
        raise FileNotFoundError("No folder found containing train/metadata.csv and test/metadata.csv")

    # Return the first valid match
    return candidates[0]

DATA_ROOT = find_dataset_music_root()
TRAIN_META = os.path.join(DATA_ROOT, "train", "metadata.csv")
TEST_META  = os.path.join(DATA_ROOT, "test", "metadata.csv")

print("Detected DATA_ROOT:", DATA_ROOT)
print("TRAIN_META:", TRAIN_META)
print("TEST_META :", TEST_META)

assert os.path.exists(TRAIN_META), f"TRAIN_META not found: {TRAIN_META}"
assert os.path.exists(TEST_META),  f"TEST_META not found:  {TEST_META}"
```


```text
Device: cuda
Detected DATA_ROOT: /kaggle/input/music-dataset/dataset_music
TRAIN_META: /kaggle/input/music-dataset/dataset_music/train/metadata.csv
TEST_META : /kaggle/input/music-dataset/dataset_music/test/metadata.csv
```


## Image Preprocessing Utilities

```python
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD  = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def pil_resize(img, size):
    """Resize image to square dimension"""
    return img.resize((size, size), Image.BILINEAR)

def pil_hflip(img, p=0.5):
    """Horizontal flip with probability p (reverses the time axis only)"""
    if random.random() < p:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img

def pil_to_tensor_norm(img):
    """Convert PIL image to normalized (C, H, W) float tensor"""
    arr = np.asarray(img).astype(np.float32) / 255.0
    arr = (arr - MEAN) / STD
    arr = np.transpose(arr, (2,0,1))
    return torch.tensor(arr, dtype=torch.float32)

def spec_augment(x, time_mask=12, freq_mask=12, p=0.7):
    """
    SpecAugment: randomly mask one frequency band and one time band.
    Masked values are set to 0, which after normalization equals the channel mean.
    Modifies x in place (a freshly created tensor here) and also returns it.
    """
    if random.random() > p:
        return x
    C, H, W = x.shape

    # Frequency masking: zero a band of rows (horizontal stripe; frequency runs along H)
    fm = random.randint(0, freq_mask)
    if fm > 0 and H > fm:
        f0 = random.randint(0, H - fm)
        x[:, f0:f0+fm, :] = 0

    # Time masking: zero a band of columns (vertical stripe; time runs along W)
    tm = random.randint(0, time_mask)
    if tm > 0 and W > tm:
        t0 = random.randint(0, W - tm)
        x[:, :, t0:t0+tm] = 0

    return x
```

## Dataset Class

```python
class MusicTripleDataset(Dataset):
    """
    Loads three images per sample (input_1, input_2, input_3).
    Handles both training set (with labels) and test set (without labels).
    """
    def __init__(self, meta_csv, img_size=96, is_train=True, df=None):
        self.meta_csv = meta_csv
        self.img_size = img_size
        self.is_train = is_train

        if df is None:
            self.df = pd.read_csv(meta_csv)
        else:
            self.df = df.reset_index(drop=True)

        # Handle different column names for train vs test
        self.col1 = "input_1_path" if "input_1_path" in self.df.columns else "input_1"
        self.col2 = "input_2"
        self.col3 = "input_3"
        self.has_labels = "target" in self.df.columns

        # Base directory for resolving relative paths
        self.base_root = os.path.dirname(os.path.dirname(meta_csv))

    def __len__(self):
        return len(self.df)

    def _resolve_path(self, p):
        """
        Handles various path formats in the metadata CSV.
        Tries multiple strategies to find the actual file.
        """
        p = str(p).strip()
        if os.path.isabs(p) and os.path.exists(p):
            return p

        # Try relative to current working directory
        p1 = os.path.normpath(os.path.join(os.getcwd(), p))
        if os.path.exists(p1):
            return p1

        # Strip common prefixes and try relative to base_root
        rel = p.replace("./", "")
        rel = rel.replace("dataset_music/", "")
        p2 = os.path.join(self.base_root, rel)
        if os.path.exists(p2):
            return p2

        # Last resort: just join with base_root
        p3 = os.path.join(self.base_root, p)
        return p3

    def _load_one(self, path):
        """Load and preprocess a single image"""
        fp = self._resolve_path(path)
        img = Image.open(fp).convert("RGB")
        img = pil_resize(img, self.img_size)
        if self.is_train:
            img = pil_hflip(img, p=0.5)
        x = pil_to_tensor_norm(img)
        if self.is_train:
            x = spec_augment(x, time_mask=12, freq_mask=12, p=0.7)
        return x

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        x1 = self._load_one(row[self.col1])
        x2 = self._load_one(row[self.col2])
        x3 = self._load_one(row[self.col3])
        if self.has_labels:
            y = int(row["target"])
            return x1, x2, x3, torch.tensor(y, dtype=torch.long)
        return x1, x2, x3
```


## Custom Layers - Activation, Dropout, Batch Norm

```python
def leaky_relu(x, neg=0.01):
    """LeakyReLU activation (prevents dying neurons)"""
    return torch.where(x > 0, x, neg * x)

class CustomDropout(nn.Module):
    """Dropout layer built manually (inverted dropout: rescales kept units by 1/keep)"""
    def __init__(self, p=0.3):
        super().__init__()
        self.p = p
    
    def forward(self, x):
        if (not self.training) or self.p <= 0:
            return x
        keep = 1.0 - self.p
        mask = (torch.rand_like(x) < keep).float() / keep
        return x * mask

class CustomBatchNorm2d(nn.Module):
    """
    Batch normalization for 2D feature maps.
    Normalizes across batch and spatial dimensions.
    """
    def __init__(self, c, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        self.gamma = nn.Parameter(torch.ones(c))
        self.beta  = nn.Parameter(torch.zeros(c))
        self.register_buffer("running_mean", torch.zeros(c))
        self.register_buffer("running_var",  torch.ones(c))

    def forward(self, x):
        if self.training:
            # Compute batch statistics
            mean = x.mean(dim=(0,2,3))
            var  = x.var(dim=(0,2,3), unbiased=False)
            # Update running statistics (this stores the biased batch variance;
            # torch.nn.BatchNorm2d stores the unbiased estimate instead)
            with torch.no_grad():
                self.running_mean.mul_(1-self.momentum).add_(mean, alpha=self.momentum)
                self.running_var.mul_(1-self.momentum).add_(var,  alpha=self.momentum)
        else:
            # Use running statistics during eval
            mean = self.running_mean
            var  = self.running_var

        # Normalize and apply learnable affine transform
        xhat = (x - mean.view(1,-1,1,1)) / torch.sqrt(var.view(1,-1,1,1) + self.eps)
        return xhat * self.gamma.view(1,-1,1,1) + self.beta.view(1,-1,1,1)
```


## Custom Layers - Convolution and Max Pooling

```python
class CustomConv2d(nn.Module):
    """
    2D convolution using unfold and batch matrix multiplication.
    Implements convolution without using torch.nn.Conv2d.
    """
    def __init__(self, in_c, out_c, k=3, stride=1, padding=1, bias=True):
        super().__init__()
        self.out_c = out_c
        self.k = k
        self.stride = stride
        self.padding = padding
        
        # He initialization (good for ReLU-like activations)
        fan_in = in_c * k * k
        w = torch.randn(out_c, in_c, k, k) * math.sqrt(2.0 / fan_in)
        self.weight = nn.Parameter(w)
        self.bias = nn.Parameter(torch.zeros(out_c)) if bias else None

    def forward(self, x):
        B, C, H, W = x.shape
        
        # Add padding if needed
        if self.padding > 0:
            x = F.pad(x, (self.padding, self.padding, self.padding, self.padding))
        
        # Extract patches using unfold
        patches = F.unfold(x, kernel_size=self.k, stride=self.stride)  # (B, C*k*k, L)
        CK2 = patches.size(1)
        
        # Reshape weight for matrix multiplication
        Wmat = self.weight.view(self.out_c, CK2).unsqueeze(0).expand(B, -1, -1)  # (B, out_c, CK2)
        
        # Compute convolution via batch matmul
        out = torch.bmm(Wmat, patches)  # (B, out_c, L)
        
        if self.bias is not None:
            out = out + self.bias.view(1, self.out_c, 1)
        
        # Reshape back to 2D feature map
        Hp, Wp = x.shape[2], x.shape[3]
        Hout = (Hp - self.k)//self.stride + 1
        Wout = (Wp - self.k)//self.stride + 1
        return out.view(B, self.out_c, Hout, Wout)

class CustomMaxPool2d(nn.Module):
    """Max pooling using unfold and max operation"""
    def __init__(self, k=2, stride=2, padding=0):
        super().__init__()
        self.k = k
        self.stride = stride
        self.padding = padding
    
    def forward(self, x):
        if self.padding > 0:
            x = F.pad(x, (self.padding, self.padding, self.padding, self.padding), value=float("-inf"))
        
        B, C, H, W = x.shape
        
        # Extract patches
        patches = F.unfold(x, kernel_size=self.k, stride=self.stride)  # (B, C*k*k, L)
        L = patches.size(2)
        patches = patches.view(B, C, self.k*self.k, L)
        
        # Take max over each patch
        out, _ = patches.max(dim=2)  # (B, C, L)
        
        # Reshape to 2D
        Hout = (H - self.k)//self.stride + 1
        Wout = (W - self.k)//self.stride + 1
        return out.view(B, C, Hout, Wout)
```


## Custom Layers - Fully Connected and Global Pooling

```python
class CustomLinear(nn.Module):
    """Fully connected layer built manually"""
    def __init__(self, in_f, out_f, bias=True):
        super().__init__()
        bound = 1.0 / math.sqrt(in_f)
        self.weight = nn.Parameter(torch.empty(out_f, in_f).uniform_(-bound, bound))
        self.bias   = nn.Parameter(torch.zeros(out_f)) if bias else None
    
    def forward(self, x):
        y = x @ self.weight.t()
        if self.bias is not None:
            y = y + self.bias
        return y

def global_avg_pool(x):
    """Global average pooling: reduces (B,C,H,W) to (B,C)"""
    return x.mean(dim=(2,3))
```

## Convolutional Blocks and Branch Architecture

```python
class ConvBlock(nn.Module):
    """
    A single convolutional block:
    Conv -> BatchNorm -> LeakyReLU -> [optional MaxPool] -> [optional Dropout]
    """
    def __init__(self, in_c, out_c, pool=True, drop=0.0):
        super().__init__()
        self.conv = CustomConv2d(in_c, out_c, k=3, stride=1, padding=1, bias=True)
        self.bn   = CustomBatchNorm2d(out_c)
        self.pool = CustomMaxPool2d(k=2, stride=2) if pool else None
        self.drop = CustomDropout(drop) if drop > 0 else None
    
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = leaky_relu(x, 0.01)
        if self.pool is not None:
            x = self.pool(x)
        if self.drop is not None:
            x = self.drop(x)
        return x

class BranchCNN(nn.Module):
    """
    Single branch of the triple-branch architecture.
    Five conv blocks with progressive dropout, ending with global average pooling.
    """
    def __init__(self):
        super().__init__()
        self.b1 = ConvBlock(3, 16, pool=True,  drop=0.00)
        self.b2 = ConvBlock(16, 32, pool=True, drop=0.05)
        self.b3 = ConvBlock(32, 48, pool=True, drop=0.08)
        self.b4 = ConvBlock(48, 64, pool=True, drop=0.10)
        self.b5 = ConvBlock(64, 64, pool=False, drop=0.10)  # No pooling in last block
    
    def forward(self, x):
        x = self.b1(x)
        x = self.b2(x)
        x = self.b3(x)
        x = self.b4(x)
        x = self.b5(x)
        return global_avg_pool(x)  # Output: (B, 64)
```

## Complete Model Architecture

```python
class TripleBranchNet(nn.Module):
    """
    Main model: three independent CNN branches + fusion head.
    Each branch processes one input image, then features are concatenated.
    """
    def __init__(self, num_classes=16):
        super().__init__()
        self.branch1 = BranchCNN()
        self.branch2 = BranchCNN()
        self.branch3 = BranchCNN()
        
        # Fusion head: concatenate 3x64 features, then classify
        self.fc1 = CustomLinear(64*3, 128)
        self.dp  = CustomDropout(0.30)
        self.fc2 = CustomLinear(128, num_classes)
    
    def forward(self, x1, x2, x3):
        # Process each input through its own branch
        f1 = self.branch1(x1)
        f2 = self.branch2(x2)
        f3 = self.branch3(x3)
        
        # Concatenate features
        fused = torch.cat([f1, f2, f3], dim=1)
        
        # Classification head
        x = self.fc1(fused)
        x = leaky_relu(x, 0.01)
        x = self.dp(x)
        return self.fc2(x)

def count_params(model):
    """Count trainable parameters"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
```



## Training Utilities - MixUp and EMA

```python
def mixup_batch(x1, x2, x3, y, alpha=0.3):
    """
    MixUp: create virtual training samples by mixing pairs of examples.
    Helps model generalize better by encouraging linear behavior between examples.
    """
    if alpha <= 0:
        return x1, x2, x3, y, None
    
    lam = float(np.random.beta(alpha, alpha))
    idx = torch.randperm(x1.size(0), device=x1.device)
    y2 = y[idx]
    
    # Mix the inputs (same lam and permutation for all three images of a sample)
    x1m = lam*x1 + (1-lam)*x1[idx]
    x2m = lam*x2 + (1-lam)*x2[idx]
    x3m = lam*x3 + (1-lam)*x3[idx]
    
    return x1m, x2m, x3m, y, (y2, lam)

def mixup_loss(logits, y, mix_info=None, label_smoothing=0.1):
    """Compute loss for possibly-mixed batch"""
    if mix_info is None:
        return F.cross_entropy(logits, y, label_smoothing=label_smoothing)
    
    y2, lam = mix_info
    l1 = F.cross_entropy(logits, y,  reduction="none", label_smoothing=label_smoothing)
    l2 = F.cross_entropy(logits, y2, reduction="none", label_smoothing=label_smoothing)
    return (lam*l1 + (1-lam)*l2).mean()

class EMA:
    """
    Maintains exponential moving average of model weights.
    This gives us a smoother version of the model that often generalizes better.
    """
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}
    
    @torch.no_grad()
    def update(self, model):
        """Update shadow weights (and BN buffers) after each training step"""
        for k, v in model.state_dict().items():
            self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1.0 - self.decay)
    
    def apply_to(self, model):
        """Load EMA weights into model"""
        model.load_state_dict(self.shadow, strict=True)
```


## Evaluation and Gradient Clipping

```python
@torch.no_grad()
def evaluate(model, loader):
    """Run evaluation and compute metrics"""
    model.eval()
    preds, trues = [], []
    
    for batch in loader:
        x1, x2, x3, y = batch
        x1, x2, x3 = x1.to(DEVICE), x2.to(DEVICE), x3.to(DEVICE)
        logits = model(x1, x2, x3)
        p = logits.argmax(dim=1).cpu().numpy()
        preds.append(p)
        trues.append(y.numpy())
    
    preds = np.concatenate(preds)
    trues = np.concatenate(trues)
    
    acc = accuracy_score(trues, preds)
    prec, rec, f1, _ = precision_recall_fscore_support(trues, preds, average="macro", zero_division=0)
    cm = confusion_matrix(trues, preds)
    
    return acc, prec, rec, f1, cm

def clip_gradients(model, max_norm=1.0):
    """
    Clip gradients by global norm to prevent exploding gradients.
    Helps stabilize training, especially early on.
    """
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            n = p.grad.data.norm(2)
            total += n.item()**2
    total = total**0.5
    
    if total > max_norm:
        scale = max_norm / (total + 1e-6)
        for p in model.parameters():
            if p.grad is not None:
                p.grad.data.mul_(scale)
```


## Training Loop

```python
def train_one_epoch(model, loader, optimizer, scheduler, ema, mix_alpha=0.3, label_smoothing=0.1, grad_clip=1.0):
    """Train for one epoch"""
    model.train()
    running = 0.0
    
    for batch in tqdm(loader, desc="Train", leave=False):
        x1, x2, x3, y = batch
        x1, x2, x3, y = x1.to(DEVICE), x2.to(DEVICE), x3.to(DEVICE), y.to(DEVICE)
        
        optimizer.zero_grad(set_to_none=True)
        
        # Apply MixUp
        x1m, x2m, x3m, y1, mix_info = mixup_batch(x1, x2, x3, y, alpha=mix_alpha)
        
        # Forward pass
        logits = model(x1m, x2m, x3m)
        loss = mixup_loss(logits, y1, mix_info=mix_info, label_smoothing=label_smoothing)
        
        # Backward pass
        loss.backward()
        clip_gradients(model, max_norm=grad_clip)
        optimizer.step()
        
        # Step scheduler (per batch) and update EMA
        if scheduler is not None:
            scheduler.step()
        if ema is not None:
            ema.update(model)
        
        running += loss.item()
    
    return running / max(1, len(loader))
```

## Complete Training Pipeline and Submission

```python
def main():
    # Hyperparameters
    IMG_SIZE = 96
    BATCH_SIZE = 32
    NUM_WORKERS = 2
    VAL_RATIO = 0.15

    EPOCHS = 50
    LR = 1e-3        # passed to Adam, but OneCycleLR overrides it with MAX_LR / div_factor = 3e-4
    WD = 1e-4
    MAX_LR = 3e-3
    MIX_ALPHA = 0.3
    LABEL_SMOOTH = 0.10
    GRAD_CLIP = 1.0

    # Load and split data
    full_df = pd.read_csv(TRAIN_META)
    train_df, val_df = train_test_split(
        full_df, test_size=VAL_RATIO, random_state=SEED, stratify=full_df["target"]
    )

    # Create datasets
    ds_train = MusicTripleDataset(TRAIN_META, img_size=IMG_SIZE, is_train=True,  df=train_df)
    ds_val   = MusicTripleDataset(TRAIN_META, img_size=IMG_SIZE, is_train=False, df=val_df)

    # Create data loaders
    loader_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS, pin_memory=True)
    loader_val   = DataLoader(ds_val,   batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

    # Initialize model
    model = TripleBranchNet(num_classes=16).to(DEVICE)
    n_params = count_params(model)
    print(f"Total trainable parameters: {n_params:,}")
    assert n_params <= 500_000, f"Param limit exceeded: {n_params}"

    # Setup optimizer and scheduler
    # Schedule: 3e-4 -> 3e-3 over the first 30% of steps, then cosine-annealed down to 3e-6
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WD)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=MAX_LR, epochs=EPOCHS, steps_per_epoch=len(loader_train),
        pct_start=0.3, div_factor=10, final_div_factor=100
    )

    # Initialize EMA
    ema = EMA(model, decay=0.999)

    best_f1 = -1.0
    best_state = None

    print("\nStarting training...\n")

    # Training loop
    for epoch in range(1, EPOCHS+1):
        # Train for one epoch
        loss = train_one_epoch(
            model, loader_train, optimizer, scheduler, ema,
            mix_alpha=MIX_ALPHA, label_smoothing=LABEL_SMOOTH, grad_clip=GRAD_CLIP
        )

        # Validate using EMA weights, then restore the live training weights
        cur = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        ema.apply_to(model)
        acc, prec, rec, f1, cm = evaluate(model, loader_val)
        model.load_state_dict(cur, strict=True)

        print(f"Epoch {epoch:02d}/{EPOCHS} | loss={loss:.4f} | val_acc={acc:.4f} | macroP={prec:.4f} | macroR={rec:.4f} | macroF1={f1:.4f}")

        # Save best model
        if f1 > best_f1:
            best_f1 = f1
            best_state = {k: v.detach().cpu().clone() for k, v in ema.shadow.items()}
            torch.save(best_state, "/kaggle/working/best_model_ema.pth")
            print("  ✓ Saved new best EMA model")

    # Final evaluation
    print("\n" + "="*60)
    print("FINAL VALIDATION (BEST EMA)")
    print("="*60)
    
    # Load best model
    model.load_state_dict(best_state, strict=True)
    acc, prec, rec, f1, cm = evaluate(model, loader_val)
    
    print(f"Total trainable parameters: {n_params:,}")
    print(f"Accuracy:        {acc:.6f}")
    print(f"Macro-Precision: {prec:.6f}")
    print(f"Macro-Recall:    {rec:.6f}")
    print(f"Macro-F1:        {f1:.6f}")
    print("\nConfusion Matrix:\n", cm)
    print("="*60 + "\n")

    # Test inference and submission
    
    # Load test data
    test_df = pd.read_csv(TEST_META)
    ds_test = MusicTripleDataset(TEST_META, img_size=IMG_SIZE, is_train=False, df=test_df)
    loader_test = DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

    # Run inference
    model.eval()
    preds = []
    with torch.no_grad():
        for batch in tqdm(loader_test, desc="Infer Test"):
            x1, x2, x3 = batch
            x1, x2, x3 = x1.to(DEVICE), x2.to(DEVICE), x3.to(DEVICE)
            logits = model(x1, x2, x3)
            preds.append(logits.argmax(dim=1).cpu().numpy())
    preds = np.concatenate(preds)[:len(test_df)]

    # Create submission file
    sub = pd.DataFrame({"id": test_df["id"].values, "target": preds.astype(int)})
    out_path = "/kaggle/working/submission96.csv"
    sub.to_csv(out_path, index=False)
    print("Saved:", out_path)
    print(sub.head(10))

# Execute
main()
```

Total trainable parameters: 278,912

Starting training...

Epoch 01/50 | loss=2.3133 | val_acc=0.1726 | macroP=0.0687 | macroR=0.1255 | macroF1=0.0703
  ✓ Saved new best EMA model
Epoch 02/50 | loss=2.0434 | val_acc=0.2834 | macroP=0.1686 | macroR=0.2277 | macroF1=0.1380
  ✓ Saved new best EMA model
Epoch 03/50 | loss=1.9162 | val_acc=0.3894 | macroP=0.4551 | macroR=0.3405 | macroF1=0.2884
  ✓ Saved new best EMA model
Epoch 04/50 | loss=1.8529 | val_acc=0.5287 | macroP=0.5559 | macroR=0.4893 | macroF1=0.4870
  ✓ Saved new best EMA model
Epoch 05/50 | loss=1.7885 | val_acc=0.6161 | macroP=0.6005 | macroR=0.5811 | macroF1=0.5794
  ✓ Saved new best EMA model
Epoch 06/50 | loss=1.7509 | val_acc=0.6530 | macroP=0.6343 | macroR=0.6232 | macroF1=0.6198
  ✓ Saved new best EMA model
Epoch 07/50 | loss=1.7111 | val_acc=0.6799 | macroP=0.6647 | macroR=0.6528 | macroF1=0.6493
  ✓ Saved new best EMA model
Epoch 08/50 | loss=1.7176 | val_acc=0.7101 | macroP=0.7000 | macroR=0.6863 | macroF1=0.6843
  ✓ Saved new best EMA model
Epoch 09/50 | loss=1.6745 | val_acc=0.7327 | macroP=0.7235 | macroR=0.7114 | macroF1=0.7100
  ✓ Saved new best EMA model
Epoch 10/50 | loss=1.6481 | val_acc=0.7495 | macroP=0.7403 | macroR=0.7302 | macroF1=0.7278
  ✓ Saved new best EMA model
Epoch 11/50 | loss=1.6179 | val_acc=0.7706 | macroP=0.7629 | macroR=0.7513 | macroF1=0.7515
  ✓ Saved new best EMA model
Epoch 12/50 | loss=1.6149 | val_acc=0.7908 | macroP=0.7839 | macroR=0.7740 | macroF1=0.7745
  ✓ Saved new best EMA model
Epoch 13/50 | loss=1.5915 | val_acc=0.8079 | macroP=0.8008 | macroR=0.7916 | macroF1=0.7925
  ✓ Saved new best EMA model
Epoch 14/50 | loss=1.6140 | val_acc=0.8183 | macroP=0.8133 | macroR=0.8032 | macroF1=0.8053
  ✓ Saved new best EMA model
Epoch 15/50 | loss=1.5572 | val_acc=0.8207 | macroP=0.8143 | macroR=0.8040 | macroF1=0.8066
  ✓ Saved new best EMA model
Epoch 16/50 | loss=1.5458 | val_acc=0.8293 | macroP=0.8220 | macroR=0.8111 | macroF1=0.8144
  ✓ Saved new best EMA model
Epoch 17/50 | loss=1.5113 | val_acc=0.8387 | macroP=0.8337 | macroR=0.8220 | macroF1=0.8256
  ✓ Saved new best EMA model
Epoch 18/50 | loss=1.5201 | val_acc=0.8439 | macroP=0.8398 | macroR=0.8280 | macroF1=0.8311
  ✓ Saved new best EMA model
Epoch 19/50 | loss=1.4794 | val_acc=0.8497 | macroP=0.8447 | macroR=0.8339 | macroF1=0.8369
  ✓ Saved new best EMA model
Epoch 20/50 | loss=1.5005 | val_acc=0.8580 | macroP=0.8522 | macroR=0.8454 | macroF1=0.8472
  ✓ Saved new best EMA model
Epoch 21/50 | loss=1.4732 | val_acc=0.8601 | macroP=0.8544 | macroR=0.8472 | macroF1=0.8494
  ✓ Saved new best EMA model
Epoch 22/50 | loss=1.4339 | val_acc=0.8650 | macroP=0.8606 | macroR=0.8527 | macroF1=0.8550
  ✓ Saved new best EMA model
Epoch 23/50 | loss=1.4552 | val_acc=0.8677 | macroP=0.8629 | macroR=0.8561 | macroF1=0.8579
  ✓ Saved new best EMA model
Epoch 24/50 | loss=1.4333 | val_acc=0.8717 | macroP=0.8669 | macroR=0.8593 | macroF1=0.8615
  ✓ Saved new best EMA model
Epoch 25/50 | loss=1.3764 | val_acc=0.8763 | macroP=0.8728 | macroR=0.8640 | macroF1=0.8668
  ✓ Saved new best EMA model
Epoch 26/50 | loss=1.4150 | val_acc=0.8839 | macroP=0.8800 | macroR=0.8719 | macroF1=0.8744
  ✓ Saved new best EMA model
Epoch 27/50 | loss=1.3981 | val_acc=0.8839 | macroP=0.8787 | macroR=0.8721 | macroF1=0.8741
Epoch 28/50 | loss=1.3846 | val_acc=0.8870 | macroP=0.8818 | macroR=0.8762 | macroF1=0.8777
  ✓ Saved new best EMA model
Epoch 29/50 | loss=1.3603 | val_acc=0.8949 | macroP=0.8892 | macroR=0.8852 | macroF1=0.8860
  ✓ Saved new best EMA model
Epoch 30/50 | loss=1.3631 | val_acc=0.8940 | macroP=0.8891 | macroR=0.8838 | macroF1=0.8852
Epoch 31/50 | loss=1.3393 | val_acc=0.8940 | macroP=0.8909 | macroR=0.8840 | macroF1=0.8862
  ✓ Saved new best EMA model
Epoch 32/50 | loss=1.2931 | val_acc=0.8989 | macroP=0.8958 | macroR=0.8897 | macroF1=0.8916
  ✓ Saved new best EMA model
Epoch 33/50 | loss=1.3256 | val_acc=0.9053 | macroP=0.9020 | macroR=0.8958 | macroF1=0.8980
  ✓ Saved new best EMA model
Epoch 34/50 | loss=1.3160 | val_acc=0.9071 | macroP=0.9034 | macroR=0.8980 | macroF1=0.8999
  ✓ Saved new best EMA model
Epoch 35/50 | loss=1.2839 | val_acc=0.9108 | macroP=0.9082 | macroR=0.9016 | macroF1=0.9042
  ✓ Saved new best EMA model
Epoch 36/50 | loss=1.2737 | val_acc=0.9136 | macroP=0.9103 | macroR=0.9050 | macroF1=0.9069
  ✓ Saved new best EMA model
Epoch 37/50 | loss=1.2370 | val_acc=0.9154 | macroP=0.9119 | macroR=0.9070 | macroF1=0.9086
  ✓ Saved new best EMA model
Epoch 38/50 | loss=1.2167 | val_acc=0.9172 | macroP=0.9129 | macroR=0.9090 | macroF1=0.9104
  ✓ Saved new best EMA model
Epoch 39/50 | loss=1.2629 | val_acc=0.9206 | macroP=0.9172 | macroR=0.9134 | macroF1=0.9147
  ✓ Saved new best EMA model
Epoch 40/50 | loss=1.2248 | val_acc=0.9224 | macroP=0.9195 | macroR=0.9154 | macroF1=0.9170
  ✓ Saved new best EMA model
Epoch 41/50 | loss=1.2149 | val_acc=0.9236 | macroP=0.9205 | macroR=0.9167 | macroF1=0.9180
  ✓ Saved new best EMA model
Epoch 42/50 | loss=1.2036 | val_acc=0.9255 | macroP=0.9225 | macroR=0.9184 | macroF1=0.9200
  ✓ Saved new best EMA model
Epoch 43/50 | loss=1.2163 | val_acc=0.9261 | macroP=0.9235 | macroR=0.9187 | macroF1=0.9207
  ✓ Saved new best EMA model
Epoch 44/50 | loss=1.1835 | val_acc=0.9267 | macroP=0.9241 | macroR=0.9193 | macroF1=0.9213
  ✓ Saved new best EMA model
Epoch 45/50 | loss=1.1663 | val_acc=0.9276 | macroP=0.9242 | macroR=0.9201 | macroF1=0.9217
  ✓ Saved new best EMA model
Epoch 46/50 | loss=1.1750 | val_acc=0.9285 | macroP=0.9255 | macroR=0.9211 | macroF1=0.9228
  ✓ Saved new best EMA model
Epoch 47/50 | loss=1.1678 | val_acc=0.9291 | macroP=0.9259 | macroR=0.9215 | macroF1=0.9232
  ✓ Saved new best EMA model
Epoch 48/50 | loss=1.2045 | val_acc=0.9307 | macroP=0.9272 | macroR=0.9234 | macroF1=0.9249
  ✓ Saved new best EMA model
Epoch 49/50 | loss=1.1628 | val_acc=0.9325 | macroP=0.9296 | macroR=0.9255 | macroF1=0.9271
  ✓ Saved new best EMA model
Epoch 50/50 | loss=1.1528 | val_acc=0.9319 | macroP=0.9288 | macroR=0.9248 | macroF1=0.9264

FINAL VALIDATION (BEST EMA)
Total trainable parameters: 278,912
Accuracy:        0.932498
Macro-Precision: 0.929580
Macro-Recall:    0.925461
Macro-F1:        0.927060

Confusion Matrix:
 [[340   1   0   0   0   0   0   0   0   0   1   0   0   0   0   0]
 [  6 119   1   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0 245   1   0   0   0   0   0   0   0   0   0   0   0   0]
 [  1   0   0 328   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0 144   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   1   1   0 164   3   1   0   1   3   7   0   1   1   3]
 [  0   0   0   0   1   4 158   1   3   1   3   5   2   0   1   0]
 [  0   0   0   0   0   2   3 147   3   0   0   6   0   1   0   0]
 [  2   0   0   0   0   6   0   3 210   1   4   6   3   0   3   0]
 [  0   0   0   0   0   0   0   0   0 148   9   1   1   2   1   0]
 [  2   0   0   1   1   0   0   1   0   8 165   1

Infer Test: 100%|██████████| 455/455 [04:01<00:00,  1.89it/s]

Saved: /kaggle/working/submission96.csv
   id  target
0   0       3
1   1       8
2   2       8
3   3       0
4   4      11
5   5      10
6   6       3
7   7       8
8   8       4
9   9       2
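`OneCycleLR` derives its whole schedule from `max_lr`: in PyTorch's implementation it starts at `max_lr / div_factor`, rises to `max_lr` over the first `pct_start` fraction of the `epochs * steps_per_epoch` steps, then anneals (cosine by default) to `max_lr / div_factor / final_div_factor`. With the settings in `main()` that is MAX_LR/10 → MAX_LR → MAX_LR/1000, and the `lr=LR` passed to `Adam` is overwritten when the scheduler is constructed. The sketch below traces that schedule on a dummy parameter, assuming `train_one_epoch` calls `scheduler.step()` once per batch (as `steps_per_epoch=len(loader_train)` implies). The values `max_lr=1e-3`, 4 epochs and 5 steps per epoch are toy assumptions, not the notebook's configuration.

```python
import torch

# Toy settings for illustration only; the notebook's MAX_LR, EPOCHS and
# len(loader_train) are defined elsewhere and are not reproduced here.
max_lr, epochs, steps_per_epoch = 1e-3, 4, 5
total_steps = epochs * steps_per_epoch

param = torch.nn.Parameter(torch.zeros(1))
# lr=0.5 is deliberately different from max_lr: OneCycleLR overwrites it
# with max_lr / div_factor when the scheduler is constructed.
optimizer = torch.optim.Adam([param], lr=0.5)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=max_lr, epochs=epochs, steps_per_epoch=steps_per_epoch,
    pct_start=0.3, div_factor=10, final_div_factor=100,
)

# Record the learning rate each of the total_steps batches would use.
lrs = [optimizer.param_groups[0]["lr"]]
for _ in range(total_steps - 1):
    optimizer.step()   # param has no gradient, so this leaves it unchanged
    scheduler.step()   # one scheduler step per batch
    lrs.append(optimizer.param_groups[0]["lr"])

peak_step = lrs.index(max(lrs))
print(f"start={lrs[0]:.1e} peak={max(lrs):.1e} (step {peak_step}) end={lrs[-1]:.1e}")
```

Because the starting rate is tied to `MAX_LR`, tuning `LR` alone has no effect on this run's schedule; `MAX_LR`, `div_factor` and `final_div_factor` are the knobs that matter.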


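The training loop validates with EMA weights by backing up the raw `state_dict`, calling `ema.apply_to(model)`, evaluating, and reloading the backup. The notebook's `EMA` class is defined in an earlier cell and is not shown here, so the class below is a hypothetical minimal version that mirrors the `shadow` / `apply_to` usage. It assumes `shadow` holds the full `state_dict` (parameters and buffers such as BatchNorm running statistics); the `strict=True` load of `best_state` only succeeds if `shadow` has exactly the model's state-dict keys. The helper `evaluate_with_ema` is also hypothetical: it wraps the swap in `try/finally` so the raw weights are restored even if evaluation raises.

```python
import torch


class EMA:
    """Hypothetical minimal EMA mirroring the `shadow` / `apply_to` usage in main()."""

    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        self.decay = decay
        # Copy the full state_dict (parameters AND buffers), so the shadow can
        # later be loaded back with load_state_dict(..., strict=True).
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

    @torch.no_grad()
    def update(self, model: torch.nn.Module) -> None:
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                # shadow = decay * shadow + (1 - decay) * current
                self.shadow[k].mul_(self.decay).add_(v, alpha=1.0 - self.decay)
            else:
                # Integer buffers (e.g. a batch counter) are copied, not averaged.
                self.shadow[k].copy_(v)

    def apply_to(self, model: torch.nn.Module) -> None:
        model.load_state_dict(self.shadow, strict=True)


def evaluate_with_ema(model, ema, eval_fn):
    """Hypothetical helper: evaluate with EMA weights, then restore the raw weights."""
    backup = {k: v.detach().clone() for k, v in model.state_dict().items()}
    ema.apply_to(model)
    try:
        return eval_fn(model)
    finally:
        # Runs even if eval_fn raises, so training never resumes on EMA weights.
        model.load_state_dict(backup, strict=True)


# Tiny demo module with one parameter and one buffer.
class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.zeros(3))
        self.register_buffer("running_mean", torch.zeros(3))


model = TinyModel()
ema = EMA(model, decay=0.5)
with torch.no_grad():
    model.w.fill_(2.0)             # pretend an optimizer step moved the weights
    model.running_mean.fill_(4.0)
ema.update(model)                  # shadow: w -> 1.0, running_mean -> 2.0
val = evaluate_with_ema(model, ema, lambda m: m.w.sum().item())
print(val, model.w.tolist())       # 3.0 [2.0, 2.0, 2.0]
```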

