# MLP Probe — All Encoder Models

Extract frozen, mean-pooled embeddings from each encoder model and train an MLP classifier on them to predict each writer's native language.
Runs all models sequentially and produces a comparison table.

**Instructions:**
1. Set runtime to **GPU** (Runtime → Change runtime type → T4 GPU)
2. Upload the JCLCv2 document files (`<doc_id>.txt`) and `index.csv` to your Google Drive so that they all sit directly in `MyDrive/NNP/JCLCv2/`
3. Run all cells

In [None]:
# Mount Google Drive so files uploaded there are readable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# Folder expected to contain index.csv and one <doc_id>.txt file per document
# (both are read by load_corpus below).
DRIVE_DATA_DIR = '/content/drive/MyDrive/NNP/JCLCv2'

In [None]:
# Install the third-party packages this notebook relies on.
# NOTE(review): jieba is not imported in the cells shown here — presumably it is
# needed by some model's remote tokenizer code; confirm whether it is required.
!pip install -q transformers jieba scikit-learn tqdm

In [None]:
import gc
import time
import traceback
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm, trange
from transformers import AutoModel, AutoTokenizer

# Run on the GPU when the runtime exposes one; fall back to the CPU otherwise.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {DEVICE}')
if DEVICE.type == 'cuda':
    # Record which GPU model was assigned (useful when comparing run times).
    print(f'GPU: {torch.cuda.get_device_name(0)}')

## Configuration

In [None]:
# Hugging Face model IDs to probe, in the order they are run and reported.
ENCODER_MODELS = [
    # BERT family
    'google-bert/bert-base-chinese',
    'google-bert/bert-base-uncased',
    'google-bert/bert-large-uncased',
    'google-bert/bert-base-multilingual-cased',
    # Chinese-specific encoders
    'hfl/chinese-roberta-wwm-ext',
    'voidful/albert_chinese_base',
    'shibing624/text2vec-base-chinese',
    # Embedding-oriented models
    'jinaai/jina-embeddings-v2-base-zh',
    'jinaai/jina-embeddings-v3',
    'Qwen/Qwen3-Embedding-0.6B',
    'Qwen/Qwen3-Embedding-4B',
    'DMetaSoul/Dmeta-embedding-zh-small',
]

MAX_LENGTH = 512   # tokens per document; longer texts are truncated
BATCH_SIZE = 32    # default embedding batch size
RANDOM_SEED = 42   # seeds the data split and the MLP

# Per-model batch sizes for memory-hungry encoders (lower these if you hit OOM).
# NOTE(review): a 4B-parameter model held in float32 needs ~16 GB for weights
# alone, which may exceed a T4 regardless of batch size — confirm the load dtype.
BATCH_OVERRIDES = {
    'google-bert/bert-large-uncased': 8,
    'jinaai/jina-embeddings-v3': 16,
    'Qwen/Qwen3-Embedding-4B': 4,
}

# Input corpus lives on Drive; outputs go to a local folder that a later cell
# copies back to Drive.
DATA_DIR = Path(DRIVE_DATA_DIR)
INDEX_CSV = DATA_DIR / 'index.csv'
RESULTS_DIR = Path('/content/results')
RESULTS_DIR.mkdir(exist_ok=True)

## Load & Split Data

In [None]:
def load_corpus(data_dir, index_csv):
    """Read the corpus index and attach each document's text.

    Args:
        data_dir: folder (``Path`` or ``str``) holding one ``<doc_id>.txt``
            file per document.
        index_csv: path to a header-less CSV whose columns are
            doc_id, context, native_language, gender.

    Returns:
        The index as a DataFrame with an extra ``text`` column containing each
        file's UTF-8 contents with surrounding whitespace stripped.

    Raises:
        FileNotFoundError: if a listed document has no matching ``.txt`` file.
    """
    data_dir = Path(data_dir)
    df = pd.read_csv(
        index_csv, header=None,
        names=['doc_id', 'context', 'native_language', 'gender'],
        # Keep IDs exactly as written in the CSV: numeric inference would turn
        # e.g. "0042" into 42 and then look for "42.txt" instead of "0042.txt".
        dtype={'doc_id': str},
    )
    df['text'] = [
        (data_dir / f'{doc_id}.txt').read_text(encoding='utf-8').strip()
        for doc_id in tqdm(df['doc_id'], desc='Loading texts')
    ]
    return df


def stratified_split(df, seed=42):
    """Split labelled documents into stratified train / val / test frames (~80/10/10).

    Rows without a ``native_language`` are discarded. Languages with fewer
    than 3 documents are too small to stratify, so all their rows go to train.
    The remaining rows are split 80/20 (stratified on language), and the 20 %
    is split again 50/50 into val and test. A language left with a single row
    in that 20 % cannot be stratified a second time; such rows are moved into
    train rather than being dropped.

    Args:
        df: DataFrame with at least a ``native_language`` column.
        seed: random state for both splits and for the final train shuffle.

    Returns:
        ``(train, val, test)`` DataFrames, each with a fresh RangeIndex;
        train is shuffled.
    """
    df = df.dropna(subset=['native_language'])
    counts = df['native_language'].value_counts()
    rare_langs = counts[counts < 3].index
    is_rare = df['native_language'].isin(rare_langs)
    df_rare = df[is_rare]
    # Every remaining language has >= 3 rows, enough for a stratified split.
    df_main = df[~is_rare]

    df_train, df_valtest = train_test_split(
        df_main, test_size=0.2, random_state=seed,
        stratify=df_main['native_language'],
    )
    # The second stratified split needs >= 2 rows per language. Singletons are
    # sent to train so that no labelled document is silently lost.
    valtest_counts = df_valtest['native_language'].map(
        df_valtest['native_language'].value_counts()
    )
    df_singletons = df_valtest[valtest_counts <= 1]
    df_valtest = df_valtest[valtest_counts > 1]
    df_val, df_test = train_test_split(
        df_valtest, test_size=0.5, random_state=seed,
        stratify=df_valtest['native_language'],
    )
    df_train = pd.concat([df_train, df_rare, df_singletons], ignore_index=True)
    df_train = df_train.sample(frac=1, random_state=seed).reset_index(drop=True)
    return df_train, df_val.reset_index(drop=True), df_test.reset_index(drop=True)


# Load every document, keep only rows that have a native-language label, encode
# the labels as integers, then make the stratified train/val/test split.
df = load_corpus(DATA_DIR, INDEX_CSV)
# Drop unlabelled rows *before* fitting the encoder: stratified_split discards
# them anyway, and fitting on them would put a missing-value entry into
# label_names (inflating the class count printed below).
df = df.dropna(subset=['native_language']).reset_index(drop=True)
le = LabelEncoder()
df['label'] = le.fit_transform(df['native_language'])
train_df, val_df, test_df = stratified_split(df, RANDOM_SEED)
label_names = list(le.classes_)

train_texts = train_df['text'].tolist()
val_texts = val_df['text'].tolist()
test_texts = test_df['text'].tolist()
y_train = train_df['label'].values
y_val = val_df['label'].values
y_test = test_df['label'].values

print(f'Train: {len(train_df)}  Val: {len(val_df)}  Test: {len(test_df)}  Classes: {len(label_names)}')

## Helper Functions

In [None]:
class TextDataset(Dataset):
    """Map-style dataset over pre-tokenized tensors (as built by tokenize_texts).

    ``encodings`` maps field names to tensors whose first dimension indexes
    documents; item ``i`` is a dict holding row ``i`` of every field, with the
    same key order as ``encodings``.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        # Document count = leading dimension of the input_ids tensor.
        n_docs = self.encodings['input_ids'].shape[0]
        return n_docs

    def __getitem__(self, idx):
        item = {}
        for field, tensor in self.encodings.items():
            item[field] = tensor[idx]
        return item


def tokenize_texts(texts, tokenizer, max_length, desc):
    """Tokenize ``texts`` 256 at a time and concatenate the per-chunk tensors.

    Each tokenizer call truncates and pads to ``max_length`` and returns PyTorch
    tensors, so the chunks can be concatenated along dim 0.

    Returns:
        Dict with ``input_ids`` and ``attention_mask`` and, only when the
        tokenizer produced it, ``token_type_ids``.
    """
    chunk_size = 256
    ids_chunks, mask_chunks, type_chunks = [], [], []
    for start in trange(0, len(texts), chunk_size, desc=desc):
        encoded = tokenizer(
            texts[start:start + chunk_size],
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors='pt',
        )
        ids_chunks.append(encoded['input_ids'])
        mask_chunks.append(encoded['attention_mask'])
        if 'token_type_ids' in encoded:
            type_chunks.append(encoded['token_type_ids'])

    merged = {
        'input_ids': torch.cat(ids_chunks),
        'attention_mask': torch.cat(mask_chunks),
    }
    if type_chunks:
        merged['token_type_ids'] = torch.cat(type_chunks)
    return merged


@torch.no_grad()
def extract_embeddings(model, encodings, batch_size, desc):
    """Return one mean-pooled embedding per document as a float32 NumPy array.

    Args:
        model: encoder whose forward output exposes ``last_hidden_state``.
        encodings: dict of equally sized tensors, as built by ``tokenize_texts``.
        batch_size: number of documents per forward pass.
        desc: label for the progress bar.

    Returns:
        Array with one row per document: the average of the final-layer hidden
        states over the positions where ``attention_mask`` is non-zero.
    """
    model.eval()
    loader = DataLoader(TextDataset(encodings), batch_size=batch_size)
    pooled_chunks = []
    for batch in tqdm(loader, desc=desc):
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        hidden = model(**batch).last_hidden_state
        # Pool explicitly in float32 whatever dtype the model runs in, so the
        # result never depends on implicit type promotion (NumPy has no bfloat16).
        mask = batch['attention_mask'].unsqueeze(-1).to(torch.float32)
        summed = (hidden.to(torch.float32) * mask).sum(dim=1)
        n_tokens = mask.sum(dim=1).clamp(min=1e-9)  # guard against all-padding rows
        pooled_chunks.append((summed / n_tokens).cpu().numpy())
    return np.concatenate(pooled_chunks)

## Run All Probes

In [None]:
# Probe every encoder in turn: embed all splits with the frozen model, fit an
# MLP on the train embeddings, and record val/test metrics. A failure in one
# model is logged and recorded as a 'FAIL' row so the remaining models still run.
rows = []

for i, model_name in enumerate(ENCODER_MODELS):
    print(f'\n{"=" * 60}')
    print(f'[{i+1}/{len(ENCODER_MODELS)}] {model_name}')
    print(f'{"=" * 60}')

    bs = BATCH_OVERRIDES.get(model_name, BATCH_SIZE)
    t0 = time.time()
    # Pre-bind the large per-model objects so the `finally` cleanup can drop
    # them no matter where this iteration stops.
    tokenizer = encoder = clf = None
    train_enc = val_enc = test_enc = None
    X_train = X_val = X_test = None

    try:
        # Load
        print(f'  Loading model (batch_size={bs})...')
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        encoder = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(DEVICE)

        # Tokenize
        train_enc = tokenize_texts(train_texts, tokenizer, MAX_LENGTH, '  Tok train')
        val_enc = tokenize_texts(val_texts, tokenizer, MAX_LENGTH, '  Tok val')
        test_enc = tokenize_texts(test_texts, tokenizer, MAX_LENGTH, '  Tok test')

        # Extract
        X_train = extract_embeddings(encoder, train_enc, bs, '  Embed train')
        X_val = extract_embeddings(encoder, val_enc, bs, '  Embed val')
        X_test = extract_embeddings(encoder, test_enc, bs, '  Embed test')
        emb_dim = X_train.shape[1]
        print(f'  Embedding dim: {emb_dim}')

        # Free GPU memory before the MLP step, which only needs the embeddings
        tokenizer = encoder = None
        train_enc = val_enc = test_enc = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Train MLP
        print('  Training MLP...')
        clf = MLPClassifier(
            hidden_layer_sizes=(512, 256), activation='relu',
            max_iter=200, early_stopping=True, validation_fraction=0.1,
            random_state=RANDOM_SEED, verbose=False,
        )
        clf.fit(X_train, y_train)

        # Evaluate
        val_pred = clf.predict(X_val)
        test_pred = clf.predict(X_test)
        elapsed = time.time() - t0

        row = {
            'model': model_name,
            'emb_dim': emb_dim,
            'val_acc': accuracy_score(y_val, val_pred),
            'val_f1': f1_score(y_val, val_pred, average='macro', zero_division=0),
            'test_acc': accuracy_score(y_test, test_pred),
            'test_f1': f1_score(y_test, test_pred, average='macro', zero_division=0),
            'test_wf1': f1_score(y_test, test_pred, average='weighted', zero_division=0),
            'time_s': f'{elapsed:.0f}',
        }
        rows.append(row)
        print(f'  Done in {elapsed:.0f}s — test_acc={row["test_acc"]:.4f}, test_f1={row["test_f1"]:.4f}')

    except Exception:
        elapsed = time.time() - t0
        print(f'  FAILED after {elapsed:.0f}s:')
        traceback.print_exc()
        rows.append({
            'model': model_name, 'emb_dim': None,
            'val_acc': None, 'val_f1': None,
            'test_acc': None, 'test_f1': None, 'test_wf1': None,
            'time_s': 'FAIL',
        })

    finally:
        # Drop every per-model reference before releasing memory. Without this,
        # a failure (e.g. CUDA OOM during extraction) would leave the model bound
        # to `encoder`, so empty_cache() could not free it and the next model
        # would be loaded on top of it.
        tokenizer = encoder = clf = None
        train_enc = val_enc = test_enc = None
        X_train = X_val = X_test = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

## Results Table

In [None]:
# One row per model; failed models carry None metrics, rendered as "FAIL".
results = pd.DataFrame(rows)
metric_cols = ['val_acc', 'val_f1', 'test_acc', 'test_f1', 'test_wf1']
styled = results.style.format(
    {col: '{:.4f}' for col in metric_cols},
    na_rep='FAIL',
)
display(styled.set_caption('MLP Probe Results (frozen embeddings)'))

# Save the raw (unformatted) table alongside the other result files.
csv_path = RESULTS_DIR / 'probe_results.csv'
results.to_csv(csv_path, index=False)
print(f'Saved {csv_path}')

In [None]:
# ── Bar chart comparison ──────────────────────────────────────────────
import matplotlib.pyplot as plt

# Successful runs only, sorted ascending so the best model is drawn at the top.
valid = results.dropna(subset=['test_f1']).sort_values('test_f1', ascending=True)
f1_scores = valid['test_f1']
labels = [name.split('/')[-1] for name in valid['model']]  # drop the org prefix

fig, ax = plt.subplots(figsize=(10, max(4, len(valid) * 0.5)))
ax.barh(labels, f1_scores, color='steelblue')
ax.set_xlabel('Test Macro-F1')
ax.set_title('MLP Probe — Frozen Embeddings')
# Write each score just past the end of its bar.
for y_pos, score in enumerate(f1_scores):
    ax.text(score + 0.005, y_pos, f'{score:.3f}', va='center', fontsize=9)
plt.tight_layout()
plt.show()

In [None]:
# ── LaTeX table ───────────────────────────────────────────────────────
# The output uses \toprule/\midrule/\bottomrule (booktabs) and \resizebox (graphicx).
valid = results.dropna(subset=['test_acc'])

header = [
    r'\begin{table}[htbp]',
    r'\centering',
    r'\caption{MLP probe results (frozen embeddings).}',
    r'\label{tab:probe-results}',
    r'\resizebox{\textwidth}{!}{%',
    r'\begin{tabular}{lrrrrrr}',
    r'\toprule',
    r'\textbf{Model} & \textbf{Dim} & \textbf{Val Acc} & \textbf{Val F1} & \textbf{Test Acc} & \textbf{Test F1} & \textbf{Time (s)} \\',
    r'\midrule',
]
footer = [r'\bottomrule', r'\end{tabular}}', r'\end{table}']

table_rows = []
for _, rec in valid.iterrows():
    escaped_name = rec['model'].replace('_', r'\_')  # '_' is special in LaTeX text
    cells = [
        escaped_name,
        str(int(rec['emb_dim'])),
        f"{rec['val_acc']:.4f}",
        f"{rec['val_f1']:.4f}",
        f"{rec['test_acc']:.4f}",
        f"{rec['test_f1']:.4f}",
        f"{rec['time_s']}",
    ]
    table_rows.append(' & '.join(cells) + r' \\')

lines = header + table_rows + footer
tex = '\n'.join(lines)
print(tex)

tex_path = RESULTS_DIR / 'probe_results.tex'
tex_path.write_text(tex)
print(f'\nSaved {tex_path}')

In [None]:
# Copy results to Drive for persistence
import shutil

# A "results" folder next to the data folder (MyDrive/NNP/results with the default path).
drive_results = Path(DRIVE_DATA_DIR).parent / 'results'
drive_results.mkdir(exist_ok=True)
for src in RESULTS_DIR.iterdir():
    # copy2 copies file contents plus metadata such as modification time.
    shutil.copy2(src, drive_results / src.name)
print(f'Copied results to {drive_results}')