Simple is best: using sklearn's multi-output classifier on top of a basic CNN model.

First, basics

In [1]:
# Probably more imports than are really necessary...
import os
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB
from tqdm import tqdm
import librosa
import numpy as np
import miditoolkit
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, average_precision_score, accuracy_score
import random

In [2]:
TAGS = ['rock', 'oldies', 'jazz', 'pop', 'dance', 'blues', 'punk', 'chill', 'electronic', 'country']
# Position of every tag inside TAGS (and therefore inside the encoded vector)
tag_to_index = dict(zip(TAGS, range(len(TAGS))))


# Multi-hot encoding of tag lists

def multi_hot_encode(tags):
    """
    Encode a list of tag names as a multi-hot float32 tensor of length len(TAGS).

    Each position that corresponds to a tag present in `tags` is set to 1.0,
    every other position stays 0.0.

    Example input: ['jazz', 'pop']
    Output: tensor([0., 0., 1., 1., 0., 0., 0., 0., 0., 0.])

    Raises:
        ValueError: if a tag is not one of TAGS.
    """
    encoded = torch.zeros(len(TAGS), dtype=torch.float32)
    for name in tags:
        position = tag_to_index.get(name)
        if position is None:
            raise ValueError(f"Unknown tag: {name}")
        encoded[position] = 1.0
    return encoded


In [3]:
# Relative path (from the notebook's working directory) to the task-3 audio classification files
dataroot3 = "data/student_files/task3_audio_classification/"

In [4]:
import torch
import joblib  # for sklearn models
import os

def save_model(model, filepath='sol_3_2.pt'):
    """Save a PyTorch or scikit-learn model to a file.

    Parameters:
    - model: a ``torch.nn.Module`` (its ``state_dict`` is written with ``torch.save``)
      or any other object (written as-is with ``joblib.dump``).
    - filepath (str): destination path.
    """
    # Dispatch on the real type. Inspecting str(type(model)) for the substring
    # 'torch' misses nn.Module subclasses defined in a notebook, whose type
    # prints as "<class '__main__.SomeNet'>", and would pickle them with joblib.
    if isinstance(model, torch.nn.Module):
        torch.save(model.state_dict(), filepath)
        print(f"PyTorch model saved to {filepath}")
    else:
        joblib.dump(model, filepath)
        print(f"scikit-learn model saved to {filepath}")

def load_model(model_class_or_type, filepath='sol_3_2.pt', *args, **kwargs):
    """Load a PyTorch or scikit-learn model from a file.

    Parameters:
    - model_class_or_type: for ``.pt`` / ``.pth`` files, a callable that builds the
      (uninitialised) PyTorch model; ignored for every other extension.
    - filepath (str): file to read.
    - *args, **kwargs: forwarded to ``model_class_or_type`` when building the model.

    Returns:
    The PyTorch model in eval mode with the stored weights loaded, or the
    object returned by ``joblib.load``.
    """
    ext = os.path.splitext(filepath)[1]

    if ext in ('.pt', '.pth'):
        model = model_class_or_type(*args, **kwargs)  # instantiate PyTorch model
        # map_location='cpu' lets a checkpoint written from an accelerator
        # ('mps' / 'cuda') be read on a machine that lacks that device;
        # load_state_dict then copies the values into the model's own parameters.
        state_dict = torch.load(filepath, map_location='cpu')
        model.load_state_dict(state_dict)
        model.eval()
        print(f"PyTorch model loaded from {filepath}")
    else:
        # NOTE: joblib.load unpickles the file, which can execute arbitrary code.
        # Only load files that come from a trusted source.
        model = joblib.load(filepath)
        print(f"scikit-learn model loaded from {filepath}")

    return model

I already have the data I want, so just load it in.

In [5]:
import pickle

# load data
# NOTE: pickle.load can execute arbitrary code while unpickling, so only open
# files that you created yourself or otherwise trust.
with open("task3_train_data.pkl", "rb") as file:
    data = pickle.load(file)

# The loaded object is indexed by key: 'x' holds the inputs, 'y' the labels
X_data = data['x']
y_data = data['y']

In [6]:
X_data.shape, y_data.shape

(torch.Size([4000, 3, 128, 512]), torch.Size([4000, 10]))

In [7]:
type(X_data)

torch.Tensor

Format for sklearn

In [8]:
# Pull each of the 3 feature channels out of X_data, re-adding a singleton
# channel axis so every piece has shape (4000, 1, 128, 512)
mel, mfcc, q = (X_data[:, channel].unsqueeze(1) for channel in range(3))

# Labels are used unchanged: (4000, 10)
y = y_data

In [9]:
from skmultilearn.model_selection import iterative_train_test_split
import numpy as np

# Real input arrays (shape: [4000, 1, 128, 512])
mel_np = mel.numpy()
mfcc_np = mfcc.numpy()
q_np = q.numpy()
y_np = y.numpy()

# Step 1: Use every sample's own row index as a one-column stand-in for X, so the
# large feature arrays never have to pass through the splitter.
# (A constant column cannot be used here: the original rows could not be told
# apart afterwards, and features would no longer line up with y_train / y_val.)
X_dummy = np.arange(y_np.shape[0]).reshape(-1, 1)  # shape: (4000, 1)

# Step 2: Perform the iterative (label-stratified) split on the index column
X_dummy_train, y_train, X_dummy_val, y_val = iterative_train_test_split(X_dummy, y_np, test_size=0.1)

# Step 3: The returned X columns are the original row indices of each subset
train_idx = X_dummy_train[:, 0].astype(int)
val_idx = X_dummy_val[:, 0].astype(int)

# Sanity checks: the subsets are disjoint, cover every sample, and the labels
# returned by the splitter are exactly the labels of the selected rows.
assert np.intersect1d(train_idx, val_idx).size == 0, "train/val indices overlap"
assert len(train_idx) + len(val_idx) == y_np.shape[0], "split lost or duplicated samples"
assert np.array_equal(y_np[train_idx], y_train), "train features/labels misaligned"
assert np.array_equal(y_np[val_idx], y_val), "val features/labels misaligned"

# Step 4: Use those indices to slice real inputs
mel_train, mfcc_train, q_train = mel_np[train_idx], mfcc_np[train_idx], q_np[train_idx]
mel_val, mfcc_val, q_val = mel_np[val_idx], mfcc_np[val_idx], q_np[val_idx]

# Final output format: (mel, mfcc, q) tuples, as expected by SklearnCNNClassifier
X_train = (mel_train, mfcc_train, q_train)
X_val = (mel_val, mfcc_val, q_val)

Class Imbalances

In [10]:
import torch
from collections import Counter

def verify_data(y_train, mel=None, mfcc=None, num_classes=10):
    """Check label distribution and optional input stats.
       Returns pos_weight tensor for BCEWithLogitsLoss to handle class imbalance.

       Parameters:
       - y_train (Tensor or ndarray): shape (N, num_classes), binary multi-label.
       - mel (Tensor or ndarray): optional, for range checking.
       - mfcc (Tensor or ndarray): optional, for range checking.
       - num_classes (int): number of label columns of y_train to inspect.

       Returns:
       - Tensor of shape (num_classes,) holding (N - count) / count per class.
         A class without any positive sample uses count = 1 in the denominator,
         so its weight is N instead of inf.
    """
    if isinstance(y_train, np.ndarray):
        y_train = torch.tensor(y_train)

    sample_count = y_train.size(0)

    # Number of positive labels per class
    counts = [int((y_train[:, i] == 1).sum().item()) for i in range(num_classes)]
    total_assignments = sum(counts)

    print(f"Total samples: {sample_count}")
    print(f"Total class assignments (1s): {total_assignments}\n")

    print("Class frequency distribution:")
    for i, count in enumerate(counts):
        # Guard the share computation: with no positive label at all the total is 0
        share = count / total_assignments if total_assignments else 0.0
        print(f"  Class {i}: {count} assignments ({share:.2%})")

    empty_classes = [i for i, count in enumerate(counts) if count == 0]
    if empty_classes:
        print(f"WARNING: no positive samples for class(es) {empty_classes}; "
              f"their pos_weight is capped at the sample count.")

    # Compute pos_weight = (N - count) / count; clamp keeps the weight finite
    # for classes that never occur (a no-op whenever count >= 1).
    label_counts_tensor = torch.tensor(counts, dtype=torch.float)
    pos_weight = (sample_count - label_counts_tensor) / label_counts_tensor.clamp(min=1.0)

    print("\nComputed pos_weight (for BCEWithLogitsLoss):")
    for i, w in enumerate(pos_weight):
        print(f"  Class {i}: {w.item():.4f}")

    # Optional: check mel / mfcc value ranges
    for name, values in (("mel", mel), ("mfcc", mfcc)):
        if values is None:
            continue
        if isinstance(values, np.ndarray):
            values = torch.tensor(values)
        if torch.isnan(values).any() or torch.isinf(values).any():
            print(f"WARNING: NaN or Inf values found in {name} data!")
        else:
            print(f"{name} range: [{values.min().item():.4f}, {values.max().item():.4f}]")

    return pos_weight

In [11]:
# Per-class (negatives / positives) ratio computed from the training labels
pos_weight = verify_data(y_train)
# log1p compresses the ratios (in the output below: 86.85 -> 4.48 for class 7).
# This module-level tensor is read by SklearnCNNClassifier._build_model.
pos_weight = torch.log1p(pos_weight) 
pos_weight

Total samples: 3602
Total class assignments (1s): 4166

Class frequency distribution:
  Class 0: 1765 assignments (42.37%)
  Class 1: 147 assignments (3.53%)
  Class 2: 351 assignments (8.43%)
  Class 3: 617 assignments (14.81%)
  Class 4: 152 assignments (3.65%)
  Class 5: 220 assignments (5.28%)
  Class 6: 162 assignments (3.89%)
  Class 7: 41 assignments (0.98%)
  Class 8: 476 assignments (11.43%)
  Class 9: 235 assignments (5.64%)

Computed pos_weight (for BCEWithLogitsLoss):
  Class 0: 1.0408
  Class 1: 23.5034
  Class 2: 9.2621
  Class 3: 4.8379
  Class 4: 22.6974
  Class 5: 15.3727
  Class 6: 21.2346
  Class 7: 86.8537
  Class 8: 6.5672
  Class 9: 14.3277


tensor([0.7133, 3.1988, 2.3285, 1.7644, 3.1654, 2.7956, 3.1016, 4.4757, 2.0238,
        2.7297])

Model stuff

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNNBranch(nn.Module):
    """Convolutional feature extractor for one single-channel input map.

    Four conv(3x3) -> ReLU -> max-pool(2x2) stages followed by a dense layer.
    For a (B, 1, 128, 512) input the output is a (B, 512) feature vector.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels) of each stage; every stage halves H and W:
        # (B, 1, 128, 512) -> (B, 16, 64, 256) -> (B, 32, 32, 128)
        #                  -> (B, 32, 16, 64)  -> (B, 64, 8, 32)
        stage_channels = [(1, 16), (16, 32), (32, 32), (32, 64)]
        stages = []
        for c_in, c_out in stage_channels:
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ]
        self.features = nn.Sequential(*stages)

        self.fc_branch = nn.Sequential(
            nn.Flatten(),                    # (B, 64 * 8 * 32 = 16,384)
            nn.Linear(64 * 8 * 32, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
        )

    def forward(self, x):
        """Map a (B, 1, 128, 512) batch to (B, 512) features."""
        return self.fc_branch(self.features(x))

class MultiInputCNNClassifier(nn.Module):
    """Three-branch classifier: one CNNBranch per input map, fused by an MLP head.

    The three 512-dimensional branch outputs are concatenated and mapped to
    ``n_classes`` raw scores (no activation is applied to the output).
    """

    def __init__(self, n_classes=10):
        super().__init__()
        self.branch_mel = CNNBranch()
        self.branch_mfcc = CNNBranch()
        self.branch_q = CNNBranch()

        fused_dim = 3 * 512  # three branches, 512 features each -> 1536
        self.classifier = nn.Sequential(
            nn.Linear(fused_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, n_classes),
        )

    def forward(self, mel, mfcc, q):
        """Each input: (B, 1, 128, 512). Returns (B, n_classes)."""
        branch_outputs = (
            self.branch_mel(mel),
            self.branch_mfcc(mfcc),
            self.branch_q(q),
        )
        fused = torch.cat(branch_outputs, dim=1)  # (B, 1536)
        return self.classifier(fused)

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNNClassifier(nn.Module):
    """Small two-stage CNN for multi-label classification.

    ``forward`` returns raw logits. The training code in this notebook pairs the
    model with ``nn.BCEWithLogitsLoss`` and applies ``torch.sigmoid`` to the
    outputs itself, so applying a sigmoid here as well would squash the scores
    twice. Use ``torch.sigmoid(model(x))`` to obtain probabilities.

    Parameters:
    - n_classes (int): number of output scores.
    - input_shape (tuple): (channels, height, width) of one input sample.
    """

    def __init__(self, n_classes=10, input_shape=(1, 128, 512)):
        super().__init__()
        # Input channel count follows input_shape (1 for the default shape)
        self.conv1 = nn.Conv2d(input_shape[0], 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.3)

        # Compute flattened feature size dynamically
        self.feature_dim = self._get_feature_dim(input_shape)

        self.fc1 = nn.Linear(self.feature_dim, 256)
        self.fc2 = nn.Linear(256, n_classes)

    def _get_feature_dim(self, input_shape):
        """Run a dummy sample through the conv stack and return its flattened size."""
        # Shape probing only: no gradients needed
        with torch.no_grad():
            x = torch.zeros(1, *input_shape)
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
        return x.view(1, -1).shape[1]

    def forward(self, x):  # x: (B, 1, 128, 512) for the default input_shape
        x = self.pool(F.relu(self.conv1(x)))  # (B, 16, 64, 256)
        x = self.pool(F.relu(self.conv2(x)))  # (B, 32, 32, 128)
        x = x.view(x.size(0), -1)             # Flatten
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)                    # Raw logits, shape (B, n_classes)


In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleInputCNNClassifier(nn.Module):
    """Batch-normalised CNN for one single-channel input map.

    Three conv -> batch-norm -> ReLU -> max-pool stages, adaptive average
    pooling down to 4x4, then a two-layer MLP that outputs ``n_classes`` raw
    scores (no activation on the output).

    Parameters:
    - n_classes (int): number of output scores.
    - dropout_rate (float): dropout probability inside the MLP head.
    """

    def __init__(self, n_classes=10, dropout_rate=0.3):
        super().__init__()

        def conv_stage(c_in, c_out, spatial_dropout=None):
            # One stage halves H and W; spatial dropout is appended only when requested
            stage = [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ]
            if spatial_dropout is not None:
                stage.append(nn.Dropout2d(spatial_dropout))
            return stage

        # For a (B, 1, 128, 512) input:
        # -> (B, 16, 64, 256) -> (B, 32, 32, 128) -> (B, 64, 16, 64)
        self.features = nn.Sequential(
            *conv_stage(1, 16, spatial_dropout=0.1),
            *conv_stage(16, 32),
            *conv_stage(32, 64, spatial_dropout=0.15),
        )

        # Adaptive pooling keeps the flattened size small and input-size independent
        self.pool = nn.AdaptiveAvgPool2d((4, 4))  # Output: (B, 64, 4, 4)

        pooled_dim = 64 * 4 * 4  # 1024
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(pooled_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, n_classes),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal convs, unit batch-norm scale, N(0, 0.01) linear weights, zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.zeros_(module.bias)

    def forward(self, x):  # x: (B, 1, 128, 512)
        pooled = self.pool(self.features(x))
        return self.classifier(pooled)  # (B, n_classes)


Wrapper for sklearn's multi-output classifier

In [23]:
from sklearn.base import BaseEstimator, ClassifierMixin
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
from sklearn.metrics import average_precision_score


# Assumes MultiInputCNNClassifier is defined elsewhere and imported
# from your_model import MultiInputCNNClassifier

class SklearnCNNClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, epochs=500, lr=1e-3, batch_size=24, device='cpu', n_classes=1):
        """Store the hyper-parameters and build the model, loss and optimizer.

        Parameters:
        - epochs (int): maximum number of epochs run by ``fit``.
        - lr (float): learning rate handed to the optimizer.
        - batch_size (int): mini-batch size used for training and scoring.
        - device (str): torch device string the model is moved to.
        - n_classes (int): number of outputs reported for this estimator.
        """
        self.epochs, self.lr, self.batch_size = epochs, lr, batch_size
        self.device, self.n_classes = device, n_classes
        # One dict per finished epoch, filled in by fit()
        self.history = []
        self._build_model()


    def _build_model(self):
        """(Re)create the network, the loss and the optimizer from the current hyper-parameters.

        The network is built with ``self.n_classes`` outputs. The loss uses the
        notebook-level ``pos_weight`` tensor when it has exactly one entry per
        output; otherwise an unweighted ``BCEWithLogitsLoss`` is used and a
        warning is printed.
        """
        # self.model = MultiInputCNNClassifier(n_classes=self.n_classes).to(self.device)
        # self.model = SingleInputCNNClassifier(n_classes=self.n_classes).to(self.device)
        self.model = CNNClassifier(n_classes=self.n_classes).to(self.device)

        # BCEWithLogitsLoss expects pos_weight to hold one weight per class
        class_weights = pos_weight.to(self.device)
        if class_weights.numel() == self.n_classes:
            self.criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)
        else:
            print(f"⚠️  pos_weight has {class_weights.numel()} entries but n_classes={self.n_classes}; "
                  f"using an unweighted BCEWithLogitsLoss.")
            self.criterion = nn.BCEWithLogitsLoss()

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=1e-4)


    def overfit(self, X, y, max_epochs=100, patience=10, restore_best=True):
        """Train on (X, y) without a validation set, tracking mAP on the training data.

        Intended as a sanity check that the model can memorise a small subset.

        Parameters:
        - X: sequence of three arrays (mel, mfcc, q) with matching first dimension.
        - y: label array with one row per sample.
        - max_epochs (int): upper bound on the number of epochs.
        - patience (int): stop once training mAP has not improved for this many epochs.
        - restore_best (bool): if True, reload the weights of the best-mAP epoch at the end.

        Returns:
        - self
        """
        import copy
        import time
        import torch
        from torch.utils.data import TensorDataset, DataLoader

        device = torch.device(self.device)
        print(f"\n🖥️  Using device: {device}")

        mel, mfcc, q = [torch.tensor(arr, dtype=torch.float32) for arr in X]
        y = torch.tensor(y, dtype=torch.float32)

        dataset = TensorDataset(mel, mfcc, q, y)
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        self.model.to(device)
        self.model.train()

        start_time = time.time()
        print(f"🔥 Overfitting CNN classifier (output dim={self.n_classes}) to maximize mAP...")

        best_map = -1
        best_epoch = 0
        best_state = None

        for epoch in range(1, max_epochs + 1):
            epoch_loss = 0.0
            num_batches = 0

            all_preds = []
            all_targets = []

            for i, (mel_b, mfcc_b, q_b, y_b) in enumerate(loader):

                mel_b = mel_b.to(device)
                mfcc_b = mfcc_b.to(device)
                q_b = q_b.to(device)
                y_b = y_b.to(device)

                self.optimizer.zero_grad()
                # outputs = self.model(mel_b, mfcc_b, q_b)
                outputs = self.model(mel_b)
                loss = self.criterion(outputs, y_b)
                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.item()
                num_batches += 1

                preds = torch.sigmoid(outputs).detach().cpu()
                all_preds.append(preds)
                all_targets.append(y_b.detach().cpu())

                # Show one target/prediction pair per epoch as a quick visual check
                if i == 0:
                    print(f"\nEpoch {epoch}, Batch {i}:")
                    print(f"  Targets     : {y_b.cpu().numpy()[:1]}")
                    print(f"  Predictions : {preds.numpy()[:1]}")

            avg_loss = epoch_loss / num_batches
            all_preds = torch.cat(all_preds).numpy()
            all_targets = torch.cat(all_targets).numpy()

            try:
                map_score = average_precision_score(all_targets, all_preds, average="macro")
            except ValueError as e:
                map_score = float("nan")
                print(f"⚠️  mAP computation failed: {e}")

            print(f"✅ Epoch {epoch}/{max_epochs} — Avg Loss: {avg_loss:.4f} | mAP: {map_score:.4f}")

            # Track best model
            if map_score > best_map:
                best_map = map_score
                best_epoch = epoch
                if restore_best:
                    # Snapshot the weights; without this there is nothing to restore later
                    best_state = copy.deepcopy(self.model.state_dict())
            elif epoch - best_epoch >= patience:
                print(f"\n⏹️ Early stopping: mAP has not improved for {patience} epochs (Best: {best_map:.4f} @ epoch {best_epoch})")
                break

            if device.type == "mps":
                torch.mps.empty_cache()

        if restore_best and best_state is not None:
            self.model.load_state_dict(best_state)
            print(f"\n🧠 Restored best model state from epoch {best_epoch} with mAP={best_map:.4f}")

        elapsed = time.time() - start_time
        print(f"🏁 Overfitting complete. Best mAP: {best_map:.4f}. Time elapsed: {elapsed:.2f}s\n")

        return self


    def fit(self, X_train, y_train, X_val=None, y_val=None, patience=500):
        """Train the model, optionally validating after every epoch.

        Parameters:
        - X_train: sequence of three arrays (mel, mfcc, q) with matching first dimension.
        - y_train: label array with one row per training sample.
        - X_val, y_val: optional validation data in the same format. Validation,
          LR scheduling, checkpointing and early stopping only run when both are given.
        - patience (int): stop after this many consecutive epochs without a
          validation-mAP improvement.

        Side effects:
        - ``self.history`` is reset and receives one dict per epoch.
        - every time validation mAP improves, the model is written to disk via ``save_model``.
        - at the end the weights of the best validation epoch are restored.

        Returns:
        - self
        """
        import time
        import copy
        import torch
        from torch.utils.data import TensorDataset, DataLoader
        from sklearn.metrics import average_precision_score
        from torch.optim.lr_scheduler import ReduceLROnPlateau

        device = torch.device(self.device)
        print(f"\n🖥️  Using device: {device}")
        if device.type == "mps":
            torch.mps.empty_cache()

        self.history = []
        best_val_map = float('-inf')
        best_model_state = None
        epochs_without_improvement = 0

        # Halve the learning rate when validation loss stops decreasing
        scheduler = ReduceLROnPlateau(self.optimizer, mode='min', factor=0.5, patience=5)

        def make_loader(X, y, shuffle):
            # Keep the full dataset in host memory; only the current mini-batch is
            # moved to the device inside the loops below.
            mel, mfcc, q = [torch.as_tensor(arr, dtype=torch.float32) for arr in X]
            y = torch.as_tensor(y, dtype=torch.float32)
            dataset = TensorDataset(mel, mfcc, q, y)
            return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)

        def macro_map(targets, scores, split_name):
            # average_precision_score raises ValueError on inputs it cannot score
            try:
                return average_precision_score(targets, scores, average="macro")
            except ValueError as e:
                print(f"⚠️  {split_name} mAP computation failed: {e}")
                return float("nan")

        train_loader = make_loader(X_train, y_train, shuffle=True)
        val_loader = make_loader(X_val, y_val, shuffle=False) if X_val is not None and y_val is not None else None

        self.model.to(device)
        start_time = time.time()
        print(f"🚀 Training CNN classifier (output dim={self.n_classes}) for up to {self.epochs} epochs...")

        for epoch in range(1, self.epochs + 1):
            # ---- Training pass ----
            self.model.train()
            epoch_loss = 0.0
            num_batches = 0
            all_preds_train = []
            all_targets_train = []

            for mel_b, mfcc_b, q_b, y_b in train_loader:
                mel_b, mfcc_b, q_b, y_b = mel_b.to(device), mfcc_b.to(device), q_b.to(device), y_b.to(device)

                self.optimizer.zero_grad()
                outputs = self.model(mel_b)  # or (mel_b, mfcc_b, q_b)
                loss = self.criterion(outputs, y_b)
                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.item()
                num_batches += 1

                all_preds_train.append(torch.sigmoid(outputs).detach().cpu())
                all_targets_train.append(y_b.detach().cpu())

            avg_train_loss = epoch_loss / num_batches
            train_map = macro_map(torch.cat(all_targets_train).numpy(),
                                  torch.cat(all_preds_train).numpy(),
                                  "Train")

            # ---- Validation pass ----
            avg_val_loss = None
            val_map = None

            if val_loader is not None:
                self.model.eval()
                all_preds_val, all_targets_val = [], []
                val_loss, val_batches = 0.0, 0

                with torch.no_grad():
                    for mel_b, mfcc_b, q_b, y_b in val_loader:
                        mel_b, mfcc_b, q_b, y_b = mel_b.to(device), mfcc_b.to(device), q_b.to(device), y_b.to(device)
                        outputs = self.model(mel_b)
                        loss = self.criterion(outputs, y_b)

                        val_loss += loss.item()
                        val_batches += 1
                        all_preds_val.append(torch.sigmoid(outputs).cpu())
                        all_targets_val.append(y_b.cpu())

                avg_val_loss = val_loss / val_batches
                val_map = macro_map(torch.cat(all_targets_val).numpy(),
                                    torch.cat(all_preds_val).numpy(),
                                    "Val")

            # ---- Logging ----
            current_lr = self.optimizer.param_groups[0]['lr']
            elapsed = time.time() - start_time

            val_part = (f"| Val Loss: {avg_val_loss:.4f} | Val mAP: {val_map:.4f}"
                        if val_map is not None else "")
            print(f"\n✅ Epoch {epoch}/{self.epochs} — "
                  f"Train Loss: {avg_train_loss:.4f} "
                  f"| Train mAP: {train_map:.4f} "
                  f"{val_part} "
                  f"| LR: {current_lr:.6f} | Elapsed: {elapsed:.1f}s")

            # ---- History ----
            self.history.append({
                "epoch": epoch,
                "train_loss": avg_train_loss,
                "train_map": train_map,
                "val_loss": avg_val_loss,
                "val_map": val_map
            })

            # ---- LR scheduler ----
            if avg_val_loss is not None:
                scheduler.step(avg_val_loss)

            # ---- Checkpointing / early stopping on validation mAP ----
            if val_map is not None:
                # A NaN val_map never compares greater, so it counts as "no improvement"
                if val_map > best_val_map:
                    best_val_map = val_map
                    best_model_state = copy.deepcopy(self.model.state_dict())
                    save_model(self.model)
                    epochs_without_improvement = 0
                else:
                    epochs_without_improvement += 1
                    if epochs_without_improvement >= patience:
                        print(f"\n🛑 Early stopping at epoch {epoch} (no val mAP improvement for {patience} epochs).")
                        break

            if device.type == "mps":
                torch.mps.empty_cache()

        # Restore best
        if best_model_state is not None:
            self.model.load_state_dict(best_model_state)

        elapsed = time.time() - start_time
        if val_loader is not None:
            print(f"\n🏁 Training complete. Best val mAP: {best_val_map:.4f}. Total time: {elapsed:.2f}s")
        else:
            print(f"\n🏁 Training complete (no validation set). Total time: {elapsed:.2f}s")
        return self



    def predict(self, X):
        """Predict binary labels for X.

        Parameters:
        - X: sequence of three arrays (mel, mfcc, q) with matching first dimension.

        Returns:
        - int ndarray with one row per sample: 1 where sigmoid(model output) > 0.5, else 0.
        """
        import torch

        self.model.eval()
        device = torch.device(self.device)

        # Inputs stay on the CPU; like score(), only one mini-batch at a time is
        # moved to the device so large inputs do not exhaust device memory.
        mel, mfcc, q = [torch.as_tensor(arr, dtype=torch.float32) for arr in X]

        batch_probs = []
        with torch.no_grad():
            for start in range(0, len(mel), self.batch_size):
                mel_b = mel[start:start + self.batch_size].to(device)
                # outputs = self.model(mel_b, mfcc_b, q_b)
                outputs = self.model(mel_b)
                batch_probs.append(torch.sigmoid(outputs).cpu())

        if not batch_probs:
            # No samples: return an empty prediction matrix
            return np.zeros((0, self.n_classes), dtype=int)

        preds = torch.cat(batch_probs).numpy()
        return (preds > 0.5).astype(int)



    def score(self, X, y):
        """Return the average precision of the model outputs on (X, y).

        Parameters:
        - X: sequence of three arrays (mel, mfcc, q) with matching first dimension.
        - y: labels, either 1-D (single output) or 2-D with one column per class.

        Returns:
        - float: ``average_precision_score(y, model outputs)`` with its default averaging.
        """
        from sklearn.metrics import average_precision_score
        import torch

        self.model.eval()
        device = torch.device(self.device)

        # Unpack and convert inputs to torch tensors on CPU
        mel, mfcc, q = [torch.tensor(arr, dtype=torch.float32) for arr in X]
        y_true = torch.tensor(y, dtype=torch.float32)
        # Only a 1-D label vector needs a trailing axis to line up with the
        # (N, 1) model output; multi-label (N, C) targets must stay 2-D.
        if y_true.dim() == 1:
            y_true = y_true.unsqueeze(-1)

        preds = []

        # Predict in batches to avoid OOM
        batch_size = self.batch_size
        for i in range(0, len(y_true), batch_size):
            mel_b = mel[i:i+batch_size].to(device)
            mfcc_b = mfcc[i:i+batch_size].to(device)
            q_b = q[i:i+batch_size].to(device)

            with torch.no_grad():
                # outputs = self.model(mel_b, mfcc_b, q_b)
                outputs = self.model(mel_b)

                preds.append(outputs.cpu())

        # Raw model outputs are used as scores
        y_pred = torch.cat(preds).numpy()
        y_true = y_true.numpy()

        # Compute mean Average Precision (mAP)
        return average_precision_score(y_true, y_pred)


    def get_params(self, deep=True):
        """Return the constructor hyper-parameters as a dict (``deep`` is accepted but unused)."""
        param_names = ('epochs', 'lr', 'batch_size', 'device', 'n_classes')
        return {name: getattr(self, name) for name in param_names}

    def set_params(self, **params):
        """Assign each given keyword as an attribute, then rebuild model, loss and optimizer.

        Rebuilding creates a fresh model, so previously trained weights are discarded.
        Returns self.
        """
        for name in params:
            setattr(self, name, params[name])
        self._build_model()
        return self

Let's try to overfit on small data

In [15]:
from sklearn.multioutput import MultiOutputClassifier
# clf = MultiOutputClassifier(SklearnCNNClassifier(device='mps', n_classes=1, epochs=10), n_jobs=None)
# clf = SklearnCNNClassifier(n_classes=10)
# X_subset = tuple(x[:30] for x in X_train)  # Correctly slice each of mel, mfcc, q
# clf.overfit(X_subset, y_train[:30])
# clf.fit(X_subset, y_train[:30])

In [24]:
from sklearn.multioutput import MultiOutputClassifier
# Use the best accelerator available instead of assuming Apple-silicon 'mps',
# so the cell also runs on CUDA or CPU-only machines.
if torch.backends.mps.is_available():
    train_device = 'mps'
elif torch.cuda.is_available():
    train_device = 'cuda'
else:
    train_device = 'cpu'

clf = SklearnCNNClassifier(device=train_device, n_classes=10, epochs=50, lr=0.001)
clf.fit(X_train, y_train, X_val, y_val)
# clf.fit(X_train, y_train)


🖥️  Using device: mps
🚀 Training CNN classifier (output dim=10) for up to 50 epochs...



KeyboardInterrupt



In [None]:
import matplotlib.pyplot as plt

def plot_training_history(history):
    """Plot loss and mAP curves (train vs. validation) side by side.

    Parameters:
    - history: list of per-epoch dicts with the keys "epoch", "train_loss",
      "val_loss", "train_map" and "val_map".
    """
    # Gather every curve up front, one list per history key
    series = {
        key: [entry[key] for entry in history]
        for key in ("epoch", "train_loss", "val_loss", "train_map", "val_map")
    }

    # (y-axis label, title, curves...) for the left and right panel
    panels = [
        ("Loss", "Loss over Epochs", ("train_loss", "Train Loss"), ("val_loss", "Val Loss")),
        ("mAP", "mAP over Epochs", ("train_map", "Train mAP"), ("val_map", "Val mAP")),
    ]

    plt.figure(figsize=(12, 5))

    for position, (y_label, title, *curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for key, legend_name in curves:
            plt.plot(series["epoch"], series[key], label=legend_name)
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()

    plt.tight_layout()
    plt.show()

In [None]:
# After training:
plot_training_history(clf.history)