# 06 – Baseline Classification with AST Features

This notebook builds simple baseline classifiers for the bat vocalization dataset,
using pre-computed **AST embeddings** from `05_Tokenization_Strategies.ipynb` and
labels from `annotations_10k.csv`.

We train logistic-regression baselines for:
- **Emitter classification** (which bat emitted the call)
- **Context classification** (behavioral context)


In [1]:
from __future__ import annotations

from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler

# Every path below is derived from the current working directory, so the
# notebook is expected to be launched from the starter_code/ directory.
ROOT = Path.cwd().resolve()
ANNOT_PATH = ROOT / 'annotations_10k.csv'   # labels table
DERIVED_DIR = ROOT / 'derived'              # outputs of earlier notebooks
AST_DIR = DERIVED_DIR.joinpath('ast_features')

ROOT, AST_DIR, ANNOT_PATH


(PosixPath('/Users/mahikacalyanakoti/Downloads/College/Year4/Year4Sem1/ESE5460/bat-llm/starter_code'),
 PosixPath('/Users/mahikacalyanakoti/Downloads/College/Year4/Year4Sem1/ESE5460/bat-llm/starter_code/derived/ast_features'),
 PosixPath('/Users/mahikacalyanakoti/Downloads/College/Year4/Year4Sem1/ESE5460/bat-llm/starter_code/annotations_10k.csv'))

## Load annotations

We load the 10k annotations file and sanity-check that the key columns exist.

In [None]:
def load_annotations(
    path: Path,
    required_cols: tuple[str, ...] = ('Emitter', 'File Name', 'Context'),
) -> pd.DataFrame:
    """Read the annotations CSV and verify that the required columns exist.

    Parameters
    ----------
    path : Path
        Location of the annotations CSV file.
    required_cols : tuple of str, default ('Emitter', 'File Name', 'Context')
        Column names that must be present in the file.

    Returns
    -------
    pd.DataFrame
        The annotations table exactly as read by ``pd.read_csv``.

    Raises
    ------
    FileNotFoundError
        If ``path`` does not exist.
    ValueError
        If any of ``required_cols`` is missing; the message names the file
        that was actually read.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f'annotations file not found at {path}')
    ann = pd.read_csv(path)
    # The full annotations file is expected to contain at least:
    # Emitter, File Name, FileID, Addressee, Context,
    # Emitter pre-vocalization action, Addressee pre-vocalization action,
    # Emitter post-vocalization action, Addressee post-vocalization action,
    # Start sample, End sample
    # Only the columns used downstream are enforced here.
    missing = [c for c in required_cols if c not in ann.columns]
    if missing:
        raise ValueError(f'{path.name} missing columns: {missing}')
    return ann

ann = load_annotations(path=ANNOT_PATH)
ann.head(n=5)


## Collect AST features

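Each annotated file has a pooled AST embedding saved as `derived/ast_features/ast_<stem>.npy`,
where `<stem>` is the `File Name` value without its directory or extension.
We load these embeddings and align them with the emitter and context labels from the annotations.
Rows whose embedding file is missing are skipped.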

In [None]:
def collect_ast_features(
    ann: pd.DataFrame,
    ast_dir: Path | None = None,
) -> tuple[np.ndarray, list[str], list[str]]:
    """Load one AST embedding per annotation row and align it with its labels.

    For each row of ``ann``, the embedding is read from
    ``<ast_dir>/ast_<stem>.npy``. Here ``<stem>`` is the row's ``File Name``
    with its directory and extension removed. Rows whose embedding file does
    not exist are skipped, and the number skipped is printed.

    Parameters
    ----------
    ann : pd.DataFrame
        Annotations containing at least the ``File Name``, ``Emitter`` and
        ``Context`` columns.
    ast_dir : Path, optional
        Directory holding the ``ast_*.npy`` files. Defaults to the
        notebook-level ``AST_DIR``.

    Returns
    -------
    X : np.ndarray, shape (N, D), dtype float32
        AST pooled embeddings, one (flattened) row per kept annotation.
    emitters : list[str]
        Emitter labels per example. This is ``str()`` of the raw value, so a
        missing value becomes the string ``'nan'``.
    contexts : list[str]
        Context labels per example, converted the same way as ``emitters``.

    Raises
    ------
    RuntimeError
        If no embedding file was found for any row.
    ValueError
        If the loaded embeddings do not all have the same length.
    """
    if ast_dir is None:
        ast_dir = AST_DIR
    ast_dir = Path(ast_dir)

    X_list: list[np.ndarray] = []
    emitters: list[str] = []
    contexts: list[str] = []

    missing = 0
    for file_name, emitter, ctx in zip(ann['File Name'], ann['Emitter'], ann['Context']):
        stem = Path(str(file_name)).stem
        ast_path = ast_dir / f'ast_{stem}.npy'
        if not ast_path.exists():
            missing += 1
            continue
        # Embeddings are expected to be 1-D; flatten defensively so that a
        # saved (1, D) array still yields a single row.
        vec = np.asarray(np.load(ast_path), dtype=np.float32).reshape(-1)
        X_list.append(vec)
        emitters.append(str(emitter))
        contexts.append(str(ctx))

    if not X_list:
        raise RuntimeError(
            'No AST feature files found. Make sure 05_Tokenization_Strategies.ipynb '
            'has been run to generate derived/ast_features/ast_*.npy.'
        )

    # np.vstack would fail with an opaque shape error; report the mismatch clearly.
    dims = {vec.shape[0] for vec in X_list}
    if len(dims) > 1:
        raise ValueError(f'AST embeddings have inconsistent dimensions: {sorted(dims)}')

    X = np.vstack(X_list)
    print(
        f'Loaded AST embeddings for {X.shape[0]} examples '
        f'(skipped {missing} without features); dim={X.shape[1]}'
    )
    return X, emitters, contexts

X_ast, y_emitters, y_contexts = collect_ast_features(ann)
# Features and both label lists must stay row-aligned for the baselines below.
assert X_ast.shape[0] == len(y_emitters) == len(y_contexts)
X_ast.shape, len(y_emitters), len(y_contexts)


## Helper to train a simple baseline

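We first split the data into stratified train/test sets (80/20 by default).
We then standardize the AST features using statistics from the training split only and fit a logistic regression.
On the held-out split, we report per-class precision, recall and F1, plus a confusion matrix.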

In [None]:
def run_baseline(
    X: np.ndarray,
    y: list[str],
    task_name: str,
    test_size: float = 0.2,
    random_state: int = 42,
) -> None:
    """Train a logistic-regression baseline and print held-out metrics.

    The examples are first split into stratified train/test sets. Features
    are then standardized with statistics from the training split only. A
    logistic regression is fit and evaluated on the test split, and a
    classification report and confusion matrix are printed.

    Parameters
    ----------
    X : np.ndarray, shape (N, D)
        Feature matrix, one row per example.
    y : list[str]
        Class label per example, aligned with the rows of ``X``.
    task_name : str
        Name shown in the printed report header.
    test_size : float, default 0.2
        Fraction of examples held out for evaluation.
    random_state : int, default 42
        Seed for the train/test split.

    Raises
    ------
    ValueError
        If ``X`` and ``y`` have different numbers of examples.

    Notes
    -----
    A stratified split cannot place a class with a single example, so such
    classes are dropped (with a printed notice) before splitting. If fewer
    than two classes remain, the baseline is skipped.
    """
    X = np.asarray(X)
    y_arr = np.asarray(y, dtype=object)
    if X.shape[0] != y_arr.shape[0]:
        raise ValueError(
            f'{task_name}: X has {X.shape[0]} rows but y has {y_arr.shape[0]} labels'
        )

    # train_test_split(stratify=...) raises if any class has only one member,
    # so drop such classes up front instead of aborting the whole baseline.
    classes, counts = np.unique(y_arr, return_counts=True)
    rare_labels = classes[counts < 2].tolist()
    if rare_labels:
        rare_set = set(rare_labels)
        print(
            f'[info] {task_name}: dropping {len(rare_labels)} class(es) '
            f'with fewer than 2 examples: {rare_labels}'
        )
        keep = np.array([label not in rare_set for label in y_arr], dtype=bool)
        X = X[keep]
        y_arr = y_arr[keep]
    if len(classes) - len(rare_labels) < 2:
        print(f'[info] Skipping {task_name}: fewer than two classes have at least 2 examples.')
        return

    # Encode labels
    le = LabelEncoder()
    y_enc = le.fit_transform(y_arr)

    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y_enc,
        test_size=test_size,
        stratify=y_enc,
        random_state=random_state,
    )

    # Standardize features (fit on the training split only to avoid leakage).
    scaler = StandardScaler()
    X_train_s = scaler.fit_transform(X_train)
    X_test_s = scaler.transform(X_test)

    clf = LogisticRegression(max_iter=2000, n_jobs=-1)
    clf.fit(X_train_s, y_train)

    y_pred = clf.predict(X_test_s)

    # Small classes can be absent from both y_test and y_pred. Without an
    # explicit label set, classification_report would then reject
    # target_names, and the confusion matrix would silently shrink.
    label_ids = np.arange(len(le.classes_))

    print("\n" + "=" * 80)
    print(f"Baseline: {task_name} (AST + logistic regression)")
    print("Classes:", list(le.classes_))
    print("\nClassification report:")
    print(
        classification_report(
            y_test,
            y_pred,
            labels=label_ids,
            target_names=le.classes_,
            zero_division=0,
        )
    )

    cm = confusion_matrix(y_test, y_pred, labels=label_ids)
    print("Confusion matrix (rows=true, cols=pred):")
    print(cm)

## Baseline 1 – Emitter classification

Predict which bat (emitter) produced each call from its AST embedding.

In [None]:
# Logistic regression needs at least two classes; mirror the guard used for
# the context baseline so a degenerate label set skips instead of crashing.
if len(set(y_emitters)) > 1:
    run_baseline(X_ast, y_emitters, task_name='Emitter classification')
else:
    print('[info] Skipping emitter baseline: only one unique emitter label found.')


## Baseline 2 – Context classification

Predict the behavioral context label from the same AST embedding. This baseline is skipped if the annotations contain only one distinct context label.

In [None]:
# A classifier needs at least two distinct labels to be meaningful.
n_context_labels = len(set(y_contexts))
if n_context_labels <= 1:
    print('[info] Skipping context baseline: only one unique context label found.')
else:
    run_baseline(X_ast, y_contexts, task_name='Context classification')
