# 07 â€“ Improved Classification with Grid Search

This notebook improves on the simple baselines in `06_Baseline_Classification.ipynb` by:

- Using **richer feature representations** (AST embeddings plus discrete token histograms).
- Performing **hyperparameter tuning** via `GridSearchCV` instead of a single fixed logistic-regression configuration.
- Using **separate classification functions** for emitter and context prediction, with appropriate handling of class imbalance.

In [1]:
from __future__ import annotations

from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import GridSearchCV, StratifiedKFold, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder, StandardScaler

# Paths (assume this notebook is run from starter_code/)
ROOT = Path.cwd().resolve()
DATA_DIR = ROOT / "data"
DERIVED_DIR = ROOT / "derived"
AUDIO_DIR = DATA_DIR / "audio"
MELS_48K_DIR = DERIVED_DIR / "mels_48k"
TOKENS_DIR = DERIVED_DIR / "tokens"
AST_DIR = DERIVED_DIR / "ast_features"
KMEANS_DIR = TOKENS_DIR / "k_means"
VQ_TOKENS_DIR = TOKENS_DIR / "vqvae"
ANNOT_PATH = ROOT / "annotations_10k.csv"

ROOT, AST_DIR, KMEANS_DIR, VQ_TOKENS_DIR, ANNOT_PATH


(PosixPath('/Users/mahikacalyanakoti/Downloads/College/Year4/Year4Sem1/ESE5460/bat-llm/starter_code'),
 PosixPath('/Users/mahikacalyanakoti/Downloads/College/Year4/Year4Sem1/ESE5460/bat-llm/starter_code/derived/ast_features'),
 PosixPath('/Users/mahikacalyanakoti/Downloads/College/Year4/Year4Sem1/ESE5460/bat-llm/starter_code/derived/tokens/k_means'),
 PosixPath('/Users/mahikacalyanakoti/Downloads/College/Year4/Year4Sem1/ESE5460/bat-llm/starter_code/derived/tokens/vqvae'),
 PosixPath('/Users/mahikacalyanakoti/Downloads/College/Year4/Year4Sem1/ESE5460/bat-llm/starter_code/annotations_10k.csv'))

## Load annotations

We reuse the 10k annotations file and check that the key columns are present.


In [2]:
def load_annotations(path: Path) -> pd.DataFrame:
    if not path.exists():
        raise FileNotFoundError(f"annotations file not found at {path}")
    ann = pd.read_csv(path)
    required_cols = ["Emitter", "File Name", "Context"]
    missing = [c for c in required_cols if c not in ann.columns]
    if missing:
        raise ValueError(f"annotations_10k.csv missing columns: {missing}")
    return ann

ann = load_annotations(ANNOT_PATH)
ann.head()


Unnamed: 0,Emitter,File Name,FileID,Addressee,Context,Emitter pre-vocalization action,Addressee pre-vocalization action,Emitter post-vocalization action,Addressee post-vocalization action,Start sample,End sample
0,216,69809.wav,233219,221,11,2,3,3,3,1,590672
1,215,71889.wav,237330,220,12,2,2,3,3,1,328528
2,216,46690.wav,173649,231,12,2,2,3,3,1,467792
3,230,85411.wav,268012,221,12,2,2,3,3,1,475984
4,215,45609.wav,170616,220,12,2,2,3,3,1,336720


## Collect combined features

We build one feature vector per annotated example by combining:

- **AST pooled embeddings** (`derived/ast_features/ast_<stem>.npy`).
- **wav2vec2 + k-means token histograms** (`derived/tokens/k_means/w2v_kmeans_<stem>.npy`).
- **VQ-VAE code histograms** (`derived/tokens/vqvae/vqvae_<stem>.npy`).

You can toggle which components are used; by default we use all three for richer representations.


In [3]:
def _load_ast_vector(stem: str) -> np.ndarray | None:
    ast_path = AST_DIR / f"ast_{stem}.npy"
    if not ast_path.exists():
        return None
    vec = np.load(ast_path)
    return np.asarray(vec, dtype=np.float32).reshape(-1)


def _load_kmeans_hist(stem: str, n_clusters: int = 128) -> np.ndarray | None:
    tok_path = KMEANS_DIR / f"w2v_kmeans_{stem}.npy"
    if not tok_path.exists():
        return None
    tokens = np.load(tok_path).astype(int)
    hist = np.bincount(tokens, minlength=n_clusters).astype(np.float32)
    total = hist.sum()
    if total > 0:
        hist /= total  # normalize to frequencies
    return hist


def _load_vqvae_hist(stem: str, n_codes: int = 256) -> np.ndarray | None:
    tok_path = VQ_TOKENS_DIR / f"vqvae_{stem}.npy"
    if not tok_path.exists():
        return None
    tokens = np.load(tok_path).astype(int)
    hist = np.bincount(tokens, minlength=n_codes).astype(np.float32)
    total = hist.sum()
    if total > 0:
        hist /= total
    return hist


def collect_features(
    ann: pd.DataFrame,
    use_ast: bool = True,
    use_kmeans_tokens: bool = True,
    use_vqvae_tokens: bool = True,
) -> Tuple[np.ndarray, List[str], List[str]]:
    """Collect feature matrix and labels for all examples with available features."""

    stems = ann["File Name"].apply(lambda s: Path(str(s)).stem)

    X_list: List[np.ndarray] = []
    emitters: List[str] = []
    contexts: List[str] = []

    missing_any = 0

    for fn, stem, emitter, ctx in zip(
        ann["File Name"], stems, ann["Emitter"], ann["Context"],
    ):
        parts: List[np.ndarray] = []

        if use_ast:
            ast_vec = _load_ast_vector(stem)
            if ast_vec is None:
                missing_any += 1
                continue
            parts.append(ast_vec)

        if use_kmeans_tokens:
            km_hist = _load_kmeans_hist(stem)
            if km_hist is None:
                missing_any += 1
                continue
            parts.append(km_hist)

        if use_vqvae_tokens:
            vq_hist = _load_vqvae_hist(stem)
            if vq_hist is None:
                missing_any += 1
                continue
            parts.append(vq_hist)

        if not parts:
            # No features requested
            continue

        feat_vec = np.concatenate(parts).astype(np.float32)
        X_list.append(feat_vec)
        emitters.append(str(emitter))
        contexts.append(str(ctx))

    if not X_list:
        raise RuntimeError(
            "No feature vectors constructed. Make sure 05_Tokenization_Strategies.ipynb "
            "has been run to generate AST embeddings and token files."
        )

    X = np.vstack(X_list)
    print(
        f"Built combined features for {X.shape[0]} examples; "
        f"dim={X.shape[1]} (skipped {missing_any} with missing components)."
    )
    return X, emitters, contexts


X_all, y_emitters_all, y_contexts_all = collect_features(
    ann,
    use_ast=True,
    use_kmeans_tokens=True,
    use_vqvae_tokens=True,
)
X_all.shape, len(y_emitters_all), len(y_contexts_all)


Built combined features for 10000 examples; dim=1152 (skipped 0 with missing components).


((10000, 1152), 10000, 10000)

## Helper: emitter classification with grid search

We fit a logistic-regression classifier on top of standardized features and use `GridSearchCV`
with stratified folds to pick the best hyperparameters for the emitter task.


In [8]:
def run_emitter_grid_search(
    X: np.ndarray,
    y_emitters: List[str],
    test_size: float = 0.2,
    random_state: int = 42,
    n_splits: int = 5,
) -> None:
    """Train an improved emitter classifier with hyperparameter search."""

    y_arr = np.asarray(y_emitters, dtype=object)

    le = LabelEncoder()
    y_enc = le.fit_transform(y_arr)

    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y_enc,
        test_size=test_size,
        stratify=y_enc,
        random_state=random_state,
    )

    pipe = Pipeline(
        steps=[
            ("scaler", StandardScaler()),
            (
                "clf",
                LogisticRegression(
                    max_iter=2000,
                    n_jobs=-1,
                ),
            ),
        ]
    )

    param_grid = {
        "clf__C": [0.1, 1.0, 10.0],
        "clf__class_weight": [None, "balanced"],
        "clf__solver": ["lbfgs"],
        "clf__penalty": ["l2"],
    }

    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    grid = GridSearchCV(
        pipe,
        param_grid=param_grid,
        scoring="f1_macro",
        n_jobs=-1,
        cv=cv,
        refit=True,
        verbose=1,
    )

    grid.fit(X_train, y_train)

    best_clf = grid.best_estimator_

    y_pred = best_clf.predict(X_test)

    print("\n" + "=" * 80)
    print("Improved baseline: Emitter classification (AST + tokens + logistic regression)")
    print("Best CV macro-F1:", grid.best_score_)
    print("Best params:", grid.best_params_)

    classes = le.classes_
    if classes is None:
        raise RuntimeError("LabelEncoder.classes_ is None; encoder was not fitted correctly.")

    print("Classes:", list(classes))
    print("\nClassification report (test set):")
    print(classification_report(y_test, y_pred, target_names=classes))

    cm = confusion_matrix(y_test, y_pred)
    print("Confusion matrix (rows=true, cols=pred):")
    print(cm)


## Helper: context classification with grid search

We define a separate function for the context task, using a similar pipeline but
emphasizing macro-F1 and allowing `class_weight='balanced'` to better handle
class imbalance.


In [9]:
def run_context_grid_search(
    X: np.ndarray,
    y_contexts: List[str],
    test_size: float = 0.2,
    random_state: int = 42,
    n_splits: int = 5,
) -> None:
    """Train an improved context classifier with hyperparameter search."""

    y_arr = np.asarray(y_contexts, dtype=object)

    le = LabelEncoder()
    y_enc = le.fit_transform(y_arr)

    # If there is only a single unique context label, skip training.
    classes = le.classes_
    if classes is None or len(classes) < 2:
        print("[info] Skipping context baseline: only one unique context label found.")
        return

    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y_enc,
        test_size=test_size,
        stratify=y_enc,
        random_state=random_state,
    )

    pipe = Pipeline(
        steps=[
            ("scaler", StandardScaler()),
            (
                "clf",
                LogisticRegression(
                    max_iter=2000,
                    n_jobs=-1,
                ),
            ),
        ]
    )

    param_grid = {
        "clf__C": [0.1, 1.0, 10.0],
        "clf__class_weight": [None, "balanced"],
        "clf__solver": ["lbfgs"],
        "clf__penalty": ["l2"],
    }

    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    grid = GridSearchCV(
        pipe,
        param_grid=param_grid,
        scoring="f1_macro",
        n_jobs=-1,
        cv=cv,
        refit=True,
        verbose=1,
    )

    grid.fit(X_train, y_train)

    best_clf = grid.best_estimator_

    y_pred = best_clf.predict(X_test)

    print("\n" + "=" * 80)
    print("Improved baseline: Context classification (AST + tokens + logistic regression)")
    print("Best CV macro-F1:", grid.best_score_)
    print("Best params:", grid.best_params_)

    classes = le.classes_
    if classes is None:
        raise RuntimeError("LabelEncoder.classes_ is None; encoder was not fitted correctly.")

    print("Classes:", list(classes))
    print("\nClassification report (test set):")
    print(classification_report(y_test, y_pred, target_names=classes))

    cm = confusion_matrix(y_test, y_pred)
    print("Confusion matrix (rows=true, cols=pred):")
    print(cm)


## Run improved baselines

Now we run the improved emitter and context classifiers on the combined features.


In [10]:
# Emitter classification
run_emitter_grid_search(X_all, y_emitters_all)

# Context classification
run_context_grid_search(X_all, y_contexts_all)


Fitting 5 folds for each of 6 candidates, totalling 30 fits

Improved baseline: Emitter classification (AST + tokens + logistic regression)
Best CV macro-F1: 0.6333754654976389
Best params: {'clf__C': 0.1, 'clf__class_weight': None, 'clf__penalty': 'l2', 'clf__solver': 'lbfgs'}
Classes: ['111', '210', '211', '215', '216', '220', '226', '228', '230', '231']

Classification report (test set):
              precision    recall  f1-score   support

         111       0.78      0.82      0.80       200
         210       0.46      0.46      0.46       200
         211       0.52      0.54      0.53       200
         215       0.56      0.49      0.52       200
         216       0.61      0.65      0.63       200
         220       0.54      0.57      0.55       200
         226       0.83      0.86      0.85       200
         228       0.88      0.81      0.85       200
         230       0.69      0.64      0.66       200
         231       0.63      0.67      0.65       200

    accura

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
