# OOD Detection for Intent Classification — CLINC150
### Full pipeline: data → fine-tune BERT → evaluate all OOD methods

**Runtime:** GPU (Runtime → Change runtime type → T4 GPU)

## 0. Setup

In [None]:
!pip install -q transformers datasets scikit-learn accelerate

import subprocess, sys
result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
print(result.stdout if result.returncode == 0 else 'No GPU detected — switch runtime to GPU!')

In [None]:
# Fetch the project code, move into the checkout, and run the repo's data download script.
# NOTE: `git clone` fails if the target directory already exists (e.g. when this cell is re-run).
!git clone https://github.com/denmalbas007/clinc150-ood-detection.git
%cd clinc150-ood-detection
# Presumably produces the files read by `load_clinc150` — TODO confirm against scripts/download_data.py.
!python scripts/download_data.py

## 1. Imports & Config

In [None]:
import json
import sys

# Make the repository's ./src directory importable (the cwd is the repo root after %cd).
sys.path.insert(0, 'src')

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, get_linear_schedule_with_warmup

# Project modules, resolved through the sys.path entry above.
from dataset import load_clinc150, CLINC150Dataset
from models import IntentClassifier, MCDropoutClassifier
from metrics import compute_all_metrics
from methods.msp import compute_msp_scores
from methods.energy import compute_energy_scores
from methods.mahalanobis import fit_mahalanobis, compute_mahalanobis_scores
from methods.knn import fit_knn, compute_knn_scores
from methods.mc_dropout import compute_mc_dropout_scores

# Use the GPU whenever PyTorch can see one.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {DEVICE}')

# ---- Config ----------------------------------------------------------------
MODEL_NAME = 'bert-base-uncased'
EPOCHS = 5
BATCH_SIZE = 64
LR = 2e-5        # AdamW learning rate
MAX_LEN = 64     # passed to CLINC150Dataset (presumably the max token length)
SEED = 42
CKPT_PATH = Path('checkpoints') / 'best_model.pt'
CKPT_PATH.parent.mkdir(exist_ok=True)

# Seed the PyTorch and NumPy random number generators.
torch.manual_seed(SEED)
np.random.seed(SEED)

## 2. Load Dataset

In [None]:
splits, label2id = load_clinc150()
num_classes = len(label2id)
print(f'Intent classes: {num_classes}')

# Per-split summary; the OOD flag is the last field of every sample.
for split, samples in splits.items():
    ood_flags = [flag for *_, flag in samples]
    n_ood = sum(1 for flag in ood_flags if flag)
    n_in = len(ood_flags) - n_ood
    print(f'  {split:5s}: {n_in:5d} in-domain | {n_ood:4d} OOD')

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def make_loader(split_name, shuffle=False):
    """Wrap ``splits[split_name]`` in a tokenising dataset and return its DataLoader."""
    dataset = CLINC150Dataset(splits[split_name], label2id, tokenizer, MAX_LEN)
    loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=shuffle,
                         num_workers=2, pin_memory=True)
    return DataLoader(dataset, **loader_kwargs)


train_loader, val_loader, test_loader = (
    make_loader('train', shuffle=True),
    make_loader('val'),
    make_loader('test'),
)

## 3. Fine-tune BERT

In [None]:
model = MCDropoutClassifier(MODEL_NAME, num_classes).to(DEVICE)
criterion = torch.nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=LR, weight_decay=0.01)

# Linear LR schedule: warm up over the first 10% of the planned steps
# (one step per training batch per epoch), then decay linearly.
total_steps = EPOCHS * len(train_loader)
warmup_steps = total_steps // 10
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

def run_epoch(loader, train=True):
    """Run one pass over *loader*, training or evaluating the global ``model``.

    Rows whose label is ``-1`` (OOD) are dropped from every batch, so loss and
    accuracy are computed on in-domain rows only. In training mode each batch
    takes an optimizer step (gradient norm clipped to 1.0) and a scheduler step.

    Args:
        loader: iterable of batches, each a dict with 'input_ids',
            'attention_mask' and 'label' tensors.
        train: update the model if True; otherwise evaluate without gradients.

    Returns:
        ``(mean_loss, accuracy)`` over the in-domain rows seen, or
        ``(nan, nan)`` if the loader produced no in-domain rows.
    """
    mode = bool(train)
    model.train(mode)  # train(False) is equivalent to eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.set_grad_enabled(mode):
        for batch in tqdm(loader, leave=False):
            ids = batch['input_ids'].to(DEVICE)
            mask = batch['attention_mask'].to(DEVICE)
            labels = batch['label'].to(DEVICE)
            # Skip OOD rows: they have no intent class to learn or score against.
            keep = labels != -1
            if not keep.any():
                continue
            ids, mask, labels = ids[keep], mask[keep], labels[keep]

            logits = model(ids, mask)
            loss = criterion(logits, labels)

            if mode:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

            n_rows = labels.size(0)
            total_loss += loss.item() * n_rows
            correct += (logits.argmax(-1) == labels).sum().item()
            total += n_rows

    if total == 0:
        # No in-domain rows at all: avoid ZeroDivisionError and signal "undefined".
        return float('nan'), float('nan')
    return total_loss / total, correct / total

HISTORY_KEYS = ('train_loss', 'val_loss', 'train_acc', 'val_acc')
history = {key: [] for key in HISTORY_KEYS}
best_val_acc = 0.0

for epoch in range(1, EPOCHS + 1):
    tr_loss, tr_acc = run_epoch(train_loader, train=True)
    vl_loss, vl_acc = run_epoch(val_loader, train=False)

    for key, value in zip(HISTORY_KEYS, (tr_loss, vl_loss, tr_acc, vl_acc)):
        history[key].append(value)

    print(f'Epoch {epoch}/{EPOCHS} | '
          f'Train loss={tr_loss:.4f} acc={tr_acc:.4f} | '
          f'Val   loss={vl_loss:.4f} acc={vl_acc:.4f}')

    # Only checkpoint when validation accuracy strictly improves.
    if not vl_acc > best_val_acc:
        continue
    best_val_acc = vl_acc
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'label2id': label2id,
        'model_name': MODEL_NAME,
        'num_classes': num_classes,
        'val_acc': vl_acc,
    }
    torch.save(checkpoint, CKPT_PATH)
    print(f'  ✓ Saved checkpoint (val_acc={vl_acc:.4f})')

print(f'\nBest val accuracy: {best_val_acc:.4f}')

In [None]:
# Training curves: per-epoch loss and accuracy for train and validation.
# The x-axis is taken from the recorded history rather than EPOCHS, so the plot
# stays valid if training was interrupted or the loop was changed.
epochs_x = range(1, len(history['train_loss']) + 1)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

for ax, metric, title in ((ax1, 'loss', 'Loss'), (ax2, 'acc', 'Accuracy')):
    ax.plot(epochs_x, history[f'train_{metric}'], label='Train')
    ax.plot(epochs_x, history[f'val_{metric}'], label='Val')
    ax.set_xlabel('Epoch'); ax.set_ylabel(title); ax.set_title(title)
    ax.legend()

plt.tight_layout()
Path('report').mkdir(exist_ok=True)  # savefig does not create missing directories
plt.savefig('report/training_curves.pdf', bbox_inches='tight')
plt.show()

## 4. Load Best Checkpoint

In [None]:
# Restore the checkpoint saved by the training loop (best validation accuracy).
ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()
print(f"Loaded checkpoint (val_acc={ckpt['val_acc']:.4f})")

# Ground truth for the test set: 1 = OOD, 0 = in-domain.
# The OOD flag is the last field of each sample, matching the `*_, is_ood`
# unpacking used in the split summary.
is_ood_gt = np.array([int(sample[-1]) for sample in splits['test']])
print(f'Test: {(is_ood_gt==0).sum()} in-domain, {is_ood_gt.sum()} OOD')

## 5. OOD Detection — All Methods

In [None]:
# Mahalanobis and KNN are fitted on the training set; use an unshuffled loader for that.
train_loader_eval = make_loader('train', shuffle=False)

results = {}


def _record(name, score_tensor):
    """Store OOD metrics for *score_tensor* under ``results[name]``; return the scores as numpy."""
    scores_np = score_tensor.numpy()
    results[name] = compute_all_metrics(scores_np, is_ood_gt)
    return scores_np


# Logit-based baselines
print('MSP...')
scores = _record('MSP', compute_msp_scores(model, test_loader, DEVICE))

print('Energy...')
scores = _record('Energy', compute_energy_scores(model, test_loader, DEVICE))

# Feature-space baselines (need a fit pass over the training data first)
print('Mahalanobis (fitting)...')
class_means, precision = fit_mahalanobis(model, train_loader_eval, num_classes, DEVICE)
scores = _record('Mahalanobis', compute_mahalanobis_scores(
    model, test_loader, class_means, precision, DEVICE))

print('KNN (k=1, fitting)...')
train_feats = fit_knn(model, train_loader_eval, DEVICE)
scores = _record('KNN (k=1)', compute_knn_scores(model, test_loader, train_feats, DEVICE, k=1))

print('KNN (k=10)...')
scores = _record('KNN (k=10)', compute_knn_scores(model, test_loader, train_feats, DEVICE, k=10))

# Sampling-based baseline
print('MC Dropout (20 passes)...')
scores = _record('MC Dropout', compute_mc_dropout_scores(model, test_loader, DEVICE, n_passes=20))

print('Done!')

## 6. Per-Class KNN — Our Method

Extension of Sun et al. (2022): retrieve neighbours **only from the predicted class** instead of the full training bank.
A sample is flagged as OOD if it is far from the training examples of its own predicted class. This is intended to reduce false negatives that occur when an OOD query happens to lie close to some unrelated in-domain class.

In [None]:
from methods.per_class_knn import fit_per_class_knn, compute_per_class_knn_scores

# Per-Class KNN: fit one training-feature bank per intent class, then score each
# test sample against the bank of its predicted class only (per the method description).
print('Fitting Per-Class KNN...')
class_banks = fit_per_class_knn(model, train_loader_eval, num_classes, DEVICE)
print(f'  Banks fitted for {len(class_banks)} classes')

print('Scoring test set...')
pc_knn_scores = compute_per_class_knn_scores(
    model, test_loader, class_banks, DEVICE, k=1
).numpy()

# Evaluate against the same test ground truth (`is_ood_gt`) used for every other method.
pc_knn_metrics = compute_all_metrics(pc_knn_scores, is_ood_gt)
results['Per-Class KNN (ours)'] = pc_knn_metrics
print(f"  AUROC={pc_knn_metrics['AUROC']:.4f}  "
      f"FPR@95={pc_knn_metrics['FPR@95TPR']:.4f}  "
      f"AUPR={pc_knn_metrics['AUPR']:.4f}")


## 7. Results — All Methods

Comparison table including baselines and our Per-Class KNN method.

In [None]:
# Published SotA for reference (AUPR is not reported, hence None)
published = {
    'MSP (Hendrycks 2017)':        {'AUROC': 0.8236, 'FPR@95TPR': 0.5782, 'AUPR': None},
    'Energy (Liu 2020)':           {'AUROC': 0.8844, 'FPR@95TPR': 0.4620, 'AUPR': None},
    'Mahalanobis (Podolskiy 2021)':{'AUROC': 0.9676, 'FPR@95TPR': 0.1832, 'AUPR': None},
    'KNN (Sun 2022)':              {'AUROC': 0.9530, 'FPR@95TPR': 0.2210, 'AUPR': None},
}

print(f"{'Method':<35} {'AUROC':>8} {'FPR@95':>8} {'AUPR':>8}")
print('─' * 62)
print('Published:')
for method, m in published.items():
    aupr_s = '  N/A  ' if m['AUPR'] is None else f"{m['AUPR']:.4f}"
    print(f"  {method:<33} {m['AUROC']:>8.4f} {m['FPR@95TPR']:>8.4f} {aupr_s:>8}")
print('Ours:')
for method, m in results.items():
    print(f"  {method:<33} {m['AUROC']:>8.4f} {m['FPR@95TPR']:>8.4f} {m['AUPR']:>8.4f}")

# Save. `default=float` converts NumPy scalar metrics (e.g. np.float32), which the
# json module cannot serialise natively; plain Python floats are written unchanged.
with open('results.json', 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2, default=float)
print('\nSaved to results.json')

## 7b. Visualization

In [None]:
from sklearn.metrics import roc_curve

# Re-collect raw test-set scores for ROC curves (the evaluation cell kept only metrics).
# NOTE: MC Dropout is stochastic, so its re-computed curve can differ slightly from
# the run behind results['MC Dropout']['AUROC'] shown in the legend.
raw_scores = {}

with torch.no_grad():
    raw_scores['MSP']         = compute_msp_scores(model, test_loader, DEVICE).numpy()
    raw_scores['Energy']      = compute_energy_scores(model, test_loader, DEVICE).numpy()
    raw_scores['Mahalanobis'] = compute_mahalanobis_scores(model, test_loader, class_means, precision, DEVICE).numpy()
    raw_scores['KNN (k=1)']   = compute_knn_scores(model, test_loader, train_feats, DEVICE, k=1).numpy()
    raw_scores['MC Dropout']  = compute_mc_dropout_scores(model, test_loader, DEVICE, n_passes=20).numpy()

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# ROC curves (OOD is the positive class)
colors = ['#e41a1c','#377eb8','#4daf4a','#984ea3','#ff7f00']
for (method, scores), color in zip(raw_scores.items(), colors):
    fpr, tpr, _ = roc_curve(is_ood_gt, scores)
    auroc_val = results[method]['AUROC']
    axes[0].plot(fpr, tpr, label=f'{method} ({auroc_val:.3f})', color=color, lw=1.8)
axes[0].plot([0,1],[0,1],'k--',lw=1)
axes[0].set_xlabel('FPR'); axes[0].set_ylabel('TPR')
axes[0].set_title('ROC Curves — OOD Detection on CLINC150')
axes[0].legend(fontsize=9)

# AUROC bar chart with SotA reference
methods_all  = list(results.keys())
aurocs_ours  = [results[m]['AUROC'] * 100 for m in methods_all]
sota_auroc   = 96.76  # Podolskiy 2021
positions    = range(len(methods_all))

bars = axes[1].bar(positions, aurocs_ours,
                   color=['#4daf4a' if v >= sota_auroc else '#377eb8' for v in aurocs_ours])
axes[1].axhline(sota_auroc, color='red', linestyle='--', lw=1.5,
                label=f'SotA (Podolskiy 2021): {sota_auroc:.2f}%')
for bar, val in zip(bars, aurocs_ours):
    axes[1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.3,
                 f'{val:.2f}', ha='center', va='bottom', fontsize=9)
axes[1].set_ylabel('AUROC (%)')
axes[1].set_title('AUROC Comparison (green = beats SotA)')
# Fix tick positions before labelling them; set_xticklabels alone warns on a
# non-fixed tick locator.
axes[1].set_xticks(list(positions))
axes[1].set_xticklabels(methods_all, rotation=25, ha='right')
axes[1].legend()
axes[1].set_ylim(70, 100)

plt.tight_layout()
Path('report').mkdir(exist_ok=True)  # savefig does not create missing directories
plt.savefig('report/ood_results.pdf', bbox_inches='tight', dpi=150)
plt.show()

## 8. Update Report with Real Numbers

In [None]:
# Auto-fill result numbers into the LaTeX report.
# Each row "<tex_name> & XX.XX & XX.XX & XX.XX" is replaced by the measured
# AUROC / FPR@95TPR / AUPR in percent. Rows whose placeholder is not found are
# reported rather than silently skipped.
tex_path = Path('report/report.tex')
tex = tex_path.read_text(encoding='utf-8')

replacements = {
    'MSP (ours)':         'MSP',
    'Energy (ours)':      'Energy',
    'Mahalanobis (ours)': 'Mahalanobis',
    '$k$-NN k=1 (ours)':  'KNN (k=1)',
    'MC Dropout (ours)':  'MC Dropout',
}

filled, missing = [], []
for tex_name, key in replacements.items():
    if key not in results:
        continue
    old = f'{tex_name} & XX.XX & XX.XX & XX.XX'
    if old not in tex:
        # Placeholder already filled or row renamed in report.tex.
        missing.append(tex_name)
        continue
    m = results[key]
    auroc_s = f"{m['AUROC']*100:.2f}"
    fpr_s   = f"{m['FPR@95TPR']*100:.2f}"
    aupr_s  = f"{m['AUPR']*100:.2f}"
    new = f'{tex_name} & {auroc_s} & {fpr_s} & {aupr_s}'
    tex = tex.replace(old, new)
    filled.append(tex_name)

tex_path.write_text(tex, encoding='utf-8')
print(f'report/report.tex updated: {len(filled)} row(s) filled with real numbers')
if missing:
    print(f'  No XX.XX placeholder row found for: {", ".join(missing)}')

# Print final table
df = pd.DataFrame(results).T * 100
df.index.name = 'Method'
print(df.round(2).to_string())

## 9. Save to Google Drive (optional)

In [None]:
# Optionally copy the checkpoint and results to Google Drive (Colab only).
# Flip the flag instead of uncommenting code.
SAVE_TO_DRIVE = False

if SAVE_TO_DRIVE:
    import shutil
    from google.colab import drive

    drive.mount('/content/drive')
    shutil.copy('checkpoints/best_model.pt', '/content/drive/MyDrive/clinc150_best_model.pt')
    shutil.copy('results.json', '/content/drive/MyDrive/clinc150_results.json')
    print('Saved to Google Drive')
else:
    print('Set SAVE_TO_DRIVE = True in this cell to save to Google Drive')

## 10. Commit results back to GitHub

In [None]:
# Commit the generated results and report figures to the local clone.
# Set your git identity first
!git config user.email "you@example.com"
!git config user.name "Your Name"

# NOTE: `git commit` exits non-zero (harmlessly) when nothing changed, e.g. on a re-run.
# NOTE(review): report/layer_analysis.pdf and layer_metrics.json are produced in the
# layer-wise analysis cells further down and are not staged here.
!git add results.json report/report.tex report/training_curves.pdf report/ood_results.pdf
!git commit -m "Add training results and figures"

# To push: uncomment and set your token
# Prefer reading the token at runtime (e.g. `from getpass import getpass; token = getpass()`)
# rather than typing it into the notebook, where it would be saved with the file.
# Note that `git remote set-url` below also stores the token in .git/config.
# import os
# token = 'ghp_YOUR_TOKEN'
# !git remote set-url origin https://{token}@github.com/denmalbas007/clinc150-ood-detection.git
# !git push origin main

## 11. Layer-wise Mahalanobis Analysis

We analyse which BERT layer produces the best OOD-discriminative features.
Podolskiy (2021) only used the last layer — we sweep all 13 hidden states.

In [None]:
from methods.mahalanobis import layer_wise_mahalanobis

print('Running layer-wise Mahalanobis (this takes ~3-5 min)...')
layer_scores, layer_metrics = layer_wise_mahalanobis(
    model, train_loader_eval, test_loader,
    is_ood_gt, num_classes, DEVICE, num_layers=12
)

# The highest layer index returned is marked as the last layer. This avoids
# hard-coding an index that depends on how layer_wise_mahalanobis numbers layers.
last_layer = max(layer_metrics)

# Table
print(f"\n{'Layer':>8} {'AUROC':>8} {'FPR@95':>8} {'AUPR':>8}")
print('─' * 38)
for layer_idx in sorted(layer_metrics):
    m = layer_metrics[layer_idx]
    label = f'{layer_idx} (last)' if layer_idx == last_layer else str(layer_idx)
    print(f"  {label:>8} {m['AUROC']:>8.4f} {m['FPR@95TPR']:>8.4f} {m['AUPR']:>8.4f}")

best_layer = max(layer_metrics, key=lambda l: layer_metrics[l]['AUROC'])
print(f"\nBest layer: {best_layer}  AUROC={layer_metrics[best_layer]['AUROC']:.4f}")

# Save layer metrics (JSON keys must be strings; `default=float` handles NumPy scalars)
with open('layer_metrics.json', 'w', encoding='utf-8') as f:
    json.dump({str(k): v for k, v in layer_metrics.items()}, f, indent=2, default=float)
print('Saved to layer_metrics.json')

In [None]:
# Layer-wise visualization: Mahalanobis AUROC and FPR@95TPR per layer, each with
# Podolskiy 2021's last-layer figure (AUROC 96.76%, FPR@95TPR 18.32%) as a reference line.
layers = sorted(layer_metrics)
aurocs = [100 * layer_metrics[l]['AUROC'] for l in layers]
fprs = [100 * layer_metrics[l]['FPR@95TPR'] for l in layers]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 4))

panels = (
    (ax1, aurocs, 'o-', '#377eb8', 96.76, 'AUROC (%)',
     'AUROC by Layer — Mahalanobis Distance'),
    (ax2, fprs, 's-', '#e41a1c', 18.32, 'FPR@95TPR (%)',
     'FPR@95TPR by Layer — Mahalanobis Distance'),
)
for ax, values, marker_fmt, color, reference, ylabel, title in panels:
    ax.plot(layers, values, marker_fmt, color=color, lw=2, ms=6)
    ax.axhline(reference, color='red', ls='--', lw=1.5, label='Podolskiy 2021 (last layer)')
    ax.set_xlabel('BERT Layer'); ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(layers)
    ax.legend(); ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig('report/layer_analysis.pdf', bbox_inches='tight', dpi=150)
plt.show()
print(f'Best layer by AUROC: {best_layer}')