## 07.1 — Nonlinear Classification with Class-Weighted PyTorch MLP

This notebook introduces **07.1**, which extends the 07 logistic-regression baseline by adding
nonlinear modeling capacity and explicit class-imbalance handling, while keeping the same acoustic
feature representations. The goal is to assess whether improved modeling alone can recover
minority-class performance, especially for context labels.

**Key design choices:**
- **PyTorch MLP** to enable class-weighted cross-entropy; scikit-learn's `MLPClassifier` has no `class_weight` parameter.
- **Shallow nonlinear architecture** to capture interactions in fixed-length acoustic embeddings while limiting the risk of overfitting.
- **Class-weighted loss** to prevent majority classes from dominating training and suppressing rare contexts.

**Objective:** improve macro-F1 and motivate the move to structured, multi-task modeling in 07.2.
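To make the class-weighted loss concrete, the sketch below checks PyTorch's built-in weighted cross-entropy against a manual computation. The logits, labels, and per-class weights are synthetic and chosen only for illustration. One detail worth noting: with the default `reduction="mean"`, PyTorch divides by the sum of the weights of the true classes in the batch, not by the batch size.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(6, 3)                   # 6 samples, 3 classes (synthetic)
targets = torch.tensor([0, 0, 0, 0, 1, 2])   # class 0 dominates the batch
weights = torch.tensor([0.5, 2.0, 2.0])      # hypothetical per-class weights

# Built-in weighted cross-entropy (default reduction="mean")
builtin = nn.CrossEntropyLoss(weight=weights)(logits, targets)

# Manual version: per-sample negative log-likelihood of the true class,
# scaled by that class's weight, then normalized by the SUM of those weights
nll = -F.log_softmax(logits, dim=1)[torch.arange(len(targets)), targets]
w = weights[targets]
manual = (w * nll).sum() / w.sum()

print(builtin.item(), manual.item())
```

Because rare classes carry larger weights, each minority-class error moves the loss more than a majority-class error does.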


In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import classification_report, confusion_matrix, f1_score

## Dataset Wrapper for NumPy Feature Arrays

This dataset class wraps precomputed NumPy feature matrices and label arrays into a PyTorch
`Dataset`, enabling efficient batching and shuffling with `DataLoader`. The model operates on
fixed-length feature vectors rather than raw audio or sequences at this stage.

In [None]:
class NumpyDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
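For a quick sanity check of the batching behavior, the sketch below uses PyTorch's built-in `TensorDataset`, which behaves the same way as `NumpyDataset` for in-memory arrays. The 300 × 64 feature matrix and 5-class labels are synthetic stand-ins, not the notebook's real features.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 64)).astype(np.float32)   # synthetic feature matrix
y = rng.integers(0, 5, size=300)                    # synthetic integer labels

# Same dtypes as NumpyDataset: float32 features, int64 (long) labels
ds = TensorDataset(torch.as_tensor(X, dtype=torch.float32),
                   torch.as_tensor(y, dtype=torch.long))
loader = DataLoader(ds, batch_size=128, shuffle=True)

xb, yb = next(iter(loader))
print(xb.shape, xb.dtype, yb.shape, yb.dtype)

# The final batch holds the remainder: 300 - 2 * 128 = 44 samples
sizes = [len(batch_y) for _, batch_y in loader]
print(sizes)
```

Labels must be `torch.long` because `nn.CrossEntropyLoss` expects integer class indices as targets.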


## Multilayer Perceptron (MLP) Architecture

This cell defines a shallow MLP classifier used for 07.1 classification.

- Stacks one fully connected hidden layer per entry in `hidden_dims` (the grid search below uses one or two).
- Uses ReLU activations and dropout for regularization.
- Designed to capture nonlinear interactions between acoustic representations while
  remaining lightweight and stable during training.

In [None]:
class MLPClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dims, num_classes, dropout=0.3):
        super().__init__()

        layers = []
        prev_dim = input_dim
        for h in hidden_dims:
            layers.append(nn.Linear(prev_dim, h))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
            prev_dim = h

        self.feature_extractor = nn.Sequential(*layers)
        self.classifier = nn.Linear(prev_dim, num_classes)

    def forward(self, x):
        x = self.feature_extractor(x)
        return self.classifier(x)
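To see how much capacity the two grid configurations add, the sketch below builds the same Linear, ReLU, Dropout stack with a hypothetical `build_mlp` helper (a flattened `nn.Sequential`, so its `state_dict` keys differ from `MLPClassifier`). The input size of 768 is an assumed value for illustration; the 10 output classes match the emitter task reported below.

```python
import torch
import torch.nn as nn

def build_mlp(input_dim, hidden_dims, num_classes, dropout=0.3):
    # Hypothetical helper: same layer order as MLPClassifier, in one Sequential
    layers, prev = [], input_dim
    for h in hidden_dims:
        layers += [nn.Linear(prev, h), nn.ReLU(), nn.Dropout(dropout)]
        prev = h
    layers.append(nn.Linear(prev, num_classes))  # raw logits, no softmax
    return nn.Sequential(*layers)

input_dim, num_classes = 768, 10  # input_dim is an assumed size
counts = {}
for hidden in [(512,), (512, 256)]:
    model = build_mlp(input_dim, hidden, num_classes)
    counts[hidden] = sum(p.numel() for p in model.parameters())
    print(hidden, counts[hidden])

# eval() disables dropout, so repeated forward passes give identical logits
model.eval()
x = torch.randn(4, input_dim)
out1, out2 = model(x), model(x)
print(out1.shape, torch.equal(out1, out2))
```

Most parameters sit in the first layer (768 × 512 weights), so adding the 256-unit second layer increases the model size by roughly a third.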


## Training and Evaluation Loops

This cell implements the training and evaluation routines for the PyTorch MLP. Models are optimized
using Adam with a class-weighted cross-entropy loss, and performance is evaluated using macro-F1
to emphasize minority-class behavior rather than overall accuracy.

In [None]:
def train_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0

    for X, y in loader:
        X, y = X.to(device), y.to(device)

        optimizer.zero_grad()
        logits = model(X)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * len(y)

    return total_loss / len(loader.dataset)


@torch.no_grad()
def eval_epoch(model, loader, device):
    model.eval()
    all_preds, all_true = [], []

    for X, y in loader:
        X = X.to(device)
        logits = model(X)
        preds = torch.argmax(logits, dim=1).cpu().numpy()

        all_preds.append(preds)
        all_true.append(y.numpy())

    y_pred = np.concatenate(all_preds)
    y_true = np.concatenate(all_true)

    return f1_score(y_true, y_pred, average="macro"), y_true, y_pred
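The reason for preferring macro-F1 over accuracy can be shown with a toy example: on synthetic labels with a 90/10 class split, a model that always predicts the majority class reaches 90% accuracy, yet its macro-F1 is below 0.5 because the minority class scores an F1 of zero.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

# Synthetic imbalanced labels: 90 majority (0), 10 minority (1)
y_true = np.array([0] * 90 + [1] * 10)
y_majority = np.zeros_like(y_true)   # always predicts the majority class

acc = accuracy_score(y_true, y_majority)
# zero_division=0 scores the never-predicted class as 0 instead of warning
macro = f1_score(y_true, y_majority, average="macro", zero_division=0)
print(acc, macro)
```

Macro-F1 averages per-class F1 with equal weight, so a class the model never predicts pulls the score down regardless of how small it is.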


## Training with Class-Weighted Loss and Lightweight Hyperparameter Search

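This function trains and evaluates the 07.1 model over a small, fixed grid of hyperparameters to limit
computational cost. We perform a manual grid search over network depth (one or two hidden layers) and
learning rate, selecting the configuration that maximizes macro-F1 on the held-out test set. Because the
same test set is used both to choose the configuration and to report the final scores, the reported best
macro-F1 is optimistically biased; a separate validation split would give an unbiased estimate.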
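Choosing hyperparameters on the test set and then reporting on that same set inflates the reported score. One common alternative, sketched here on synthetic data (the array sizes and 60/20/20 proportions are assumptions), is a stratified three-way split: compare configurations on the validation set and evaluate the chosen model on the test set only once.

```python
import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(42)
X = rng.normal(size=(1000, 16))      # synthetic features
y = np.repeat(np.arange(5), 200)     # 5 balanced synthetic classes

# First hold out the final test set (20%), then split the remaining 80%
# into train (75% of it) and validation (25% of it): 60/20/20 overall
X_tmp, X_test, y_tmp, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(
    X_tmp, y_tmp, test_size=0.25, stratify=y_tmp, random_state=42)

print(len(y_train), len(y_val), len(y_test))
```

Under this scheme, the grid loop would score each configuration on `(X_val, y_val)`, and the classification report would be computed on `(X_test, y_test)` for the selected model alone.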


In [None]:
def run_mlp_pytorch(
    X_all,
    y_all,
    title,
    hidden_grid=((512,), (512, 256)),
    lr_grid=(1e-3, 3e-4),
    batch_size=128,
    epochs=25,
    random_state=42,
):
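    # Seed PyTorch so weight initialization and DataLoader shuffling are repeatable
    # (some GPU kernels may still be nondeterministic)
    torch.manual_seed(random_state)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")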

    # Encode labels
    le = LabelEncoder()
    y_enc = le.fit_transform(y_all)

    # Train / test split
    X_train, X_test, y_train, y_test = train_test_split(
        X_all, y_enc,
        test_size=0.2,
        stratify=y_enc,
        random_state=random_state
    )

    # Standardize features (important for MLP)
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

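    # "Balanced" class weights: n_samples / (n_classes * count_c), the same
    # formula scikit-learn uses for class_weight="balanced"
    class_counts = np.bincount(y_train)
    class_weights = len(y_train) / (len(class_counts) * class_counts)
    class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)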

    train_ds = NumpyDataset(X_train, y_train)
    test_ds = NumpyDataset(X_test, y_test)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=batch_size)

    input_dim = X_train.shape[1]
    num_classes = len(le.classes_)

    print("\n" + "=" * 80)
    print(f"07.1: {title} (PyTorch MLP, class-weighted CE)")

    best_f1 = -1.0
    best_model = None

    for hidden_dims in hidden_grid:
        for lr in lr_grid:
            model = MLPClassifier(
                input_dim=input_dim,
                hidden_dims=hidden_dims,
                num_classes=num_classes,
                dropout=0.3,
            ).to(device)

            optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            criterion = nn.CrossEntropyLoss(weight=class_weights)

            for _ in range(epochs):
                train_epoch(model, train_loader, optimizer, criterion, device)

            f1, _, _ = eval_epoch(model, test_loader, device)
            print(f"hidden={hidden_dims}, lr={lr:.1e} → macro-F1={f1:.4f}")

            if f1 > best_f1:
                best_f1 = f1
                best_model = model

    # Final evaluation
    _, y_true, y_pred = eval_epoch(best_model, test_loader, device)

    print("\nBest macro-F1:", best_f1)
    print("Classes:", list(le.classes_))
    print("\nClassification report (test set):")
    print(classification_report(
        le.inverse_transform(y_true),
        le.inverse_transform(y_pred),
        target_names=le.classes_,
        zero_division=0
    ))

    # Rows and columns follow the sorted class order printed above
    print("Confusion matrix (rows=true, cols=pred):")
    print(confusion_matrix(
        le.inverse_transform(y_true),
        le.inverse_transform(y_pred),
        labels=le.classes_,
    ))
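The class-weight formula used in `run_mlp_pytorch` can be cross-checked against scikit-learn's `compute_class_weight`. The sketch below uses synthetic labels with a 60/30/10 split; it also shows the intended effect, namely that after weighting every class contributes the same total weight to the loss.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y_train = np.array([0] * 60 + [1] * 30 + [2] * 10)   # synthetic imbalanced labels

counts = np.bincount(y_train)
manual = len(y_train) / (len(counts) * counts)        # n / (K * n_c)
sk = compute_class_weight(class_weight="balanced",
                          classes=np.unique(y_train), y=y_train)

print(manual)
print(counts * manual)   # total weight per class after balancing
```

Each class ends up with total weight `n_samples / n_classes` (here 100 / 3), which is what keeps the majority class from dominating the gradient.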


## Running Experiments

This cell runs the 07.1 MLP model separately for emitter and context classification, mirroring the
07 experimental setup. Results are reported using per-class precision, recall, F1 scores, and
confusion matrices to facilitate direct comparison with the linear baseline.

In [None]:
run_mlp_pytorch(
    X_all,
    y_emitters_all,
    title="Emitter classification",
)

run_mlp_pytorch(
    X_all,
    y_contexts_all,
    title="Context classification",
)


A1: Emitter classification (PyTorch MLP, class-weighted CE)
hidden=(512,), lr=1.0e-03 → macro-F1=0.6717
hidden=(512,), lr=3.0e-04 → macro-F1=0.6880
hidden=(512, 256), lr=1.0e-03 → macro-F1=0.6654
hidden=(512, 256), lr=3.0e-04 → macro-F1=0.6812

Best macro-F1: 0.6879860535750391
Classes: [np.str_('111'), np.str_('210'), np.str_('211'), np.str_('215'), np.str_('216'), np.str_('220'), np.str_('226'), np.str_('228'), np.str_('230'), np.str_('231')]

Classification report (test set):
              precision    recall  f1-score   support

         111       0.79      0.85      0.82       200
         210       0.54      0.54      0.54       200
         211       0.60      0.56      0.58       200
         215       0.60      0.54      0.57       200
         216       0.65      0.66      0.65       200
         220       0.60      0.63      0.61       200
         226       0.86      0.85      0.85       200
         228       0.86      0.84      0.85       200
         230       0.71     