# Transformer Fine-Tuning for Chinese NLI

Fine-tune encoder models on the JCLCv2 corpus for Native Language Identification.

**Instructions:**
1. Set runtime to **GPU** (Runtime → Change runtime type → T4 GPU)
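2. Upload the contents of the `JCLCv2/` folder (the `<doc_id>.txt` documents) together with `index.csv` to your Google Drive under `NNP/JCLCv2/`, so that the notebook finds `NNP/JCLCv2/index.csv` and `NNP/JCLCv2/<doc_id>.txt`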
3. Run all cells

In [None]:
# Mount Google Drive at /content/drive so files stored there can be read
# through the regular filesystem.
from google.colab import drive
drive.mount('/content/drive')

# Directory on the mounted Drive that holds index.csv and the corpus .txt
# files — adjust if your Drive layout differs.
DRIVE_DATA_DIR = '/content/drive/MyDrive/NNP/JCLCv2'

In [None]:
# Install third-party packages into the runtime (-q reduces pip's output).
!pip install -q transformers jieba scikit-learn tqdm

In [None]:
import os
import time
import json
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm, trange
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

print(f'Device: {DEVICE}')
if DEVICE.type == 'cuda':
    # Report which accelerator was picked up (device index 0).
    print(f'GPU: {torch.cuda.get_device_name(0)}')

## Configuration

In [None]:
# ── Choose your model ──────────────────────────────────────────────────
MODEL_NAME = 'google-bert/bert-base-chinese'  # change this to try other models

# Available models:
# 'google-bert/bert-base-chinese'
# 'google-bert/bert-base-uncased'
# 'google-bert/bert-large-uncased'
# 'google-bert/bert-base-multilingual-cased'
# 'hfl/chinese-roberta-wwm-ext'
# 'voidful/albert_chinese_base'

# ── Hyperparameters ────────────────────────────────────────────────────
MAX_LENGTH = 512      # token length every text is truncated/padded to
BATCH_SIZE = 16       # examples per DataLoader batch
LR = 2e-5             # AdamW learning rate
EPOCHS = 5            # maximum number of training epochs
WARMUP_RATIO = 0.1    # fraction of all optimizer steps used for LR warmup
PATIENCE = 2          # epochs without val-loss improvement before stopping
RANDOM_SEED = 42      # seed for random operations (e.g. the data split)

# ── Data paths ─────────────────────────────────────────────────────────
DATA_DIR = Path(DRIVE_DATA_DIR)
INDEX_CSV = DATA_DIR / 'index.csv'
RESULTS_DIR = Path('/content/results')
# parents=True matches how the checkpoint directory is created and keeps
# this cell working when the parent directory does not exist yet.
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

## Load & Split Data

In [None]:
def load_corpus(data_dir, index_csv):
    """Load the corpus index and attach the text of every document.

    Args:
        data_dir: Directory containing one ``<doc_id>.txt`` file per row of
            the index. May be a ``str`` or a ``Path``.
        index_csv: Path to the header-less index CSV whose columns are read
            as ``doc_id``, ``context``, ``native_language`` and ``gender``.

    Returns:
        A DataFrame with the four index columns plus a ``text`` column that
        holds each document's UTF-8 content with surrounding whitespace
        stripped.

    Raises:
        FileNotFoundError: If a document listed in the index has no
            matching ``.txt`` file in ``data_dir``.
    """
    data_dir = Path(data_dir)  # accept plain strings as well as Path objects
    df = pd.read_csv(
        index_csv, header=None,
        names=['doc_id', 'context', 'native_language', 'gender'],
        # Keep IDs exactly as written: dtype inference would turn an ID such
        # as "0012" into the integer 12 and break the filename lookup below.
        dtype={'doc_id': str},
    )
    df['text'] = [
        (data_dir / f'{doc_id}.txt').read_text(encoding='utf-8').strip()
        for doc_id in tqdm(df['doc_id'], desc='Loading texts')
    ]
    return df


def stratified_split(df, seed=42):
    """Split ``df`` into train/validation/test sets stratified by language.

    Rows without a ``native_language`` are dropped. Languages with fewer
    than three documents cannot be spread over three splits, so all of
    their rows go to the training set. The remaining rows are split 80/20
    into train and a held-out pool, and the pool is split 50/50 into
    validation and test, both times stratified on ``native_language``.

    A language that ends up with a single row in the held-out pool cannot be
    stratified across validation and test; that row is returned to the
    training set instead of being discarded.

    Args:
        df: DataFrame with at least a ``native_language`` column.
        seed: Random state used for both splits and for the final shuffle
            of the training set.

    Returns:
        ``(df_train, df_val, df_test)``, each with a fresh 0..n-1 index.
    """
    df = df.dropna(subset=['native_language'])
    counts = df['native_language'].value_counts()
    rare_langs = counts[counts < 3].index
    is_rare = df['native_language'].isin(rare_langs)
    df_rare = df[is_rare]
    # Every language left in df_main has at least three rows, which is what
    # the stratified split below needs.
    df_main = df[~is_rare]

    df_train, df_valtest = train_test_split(
        df_main, test_size=0.2, random_state=seed,
        stratify=df_main['native_language'],
    )

    # Stratifying the 50/50 split needs at least two rows per language in the
    # pool. Singletons go back to training so no labeled document is lost.
    pool_counts = df_valtest['native_language'].map(
        df_valtest['native_language'].value_counts()
    )
    df_leftover = df_valtest[pool_counts <= 1]
    df_valtest = df_valtest[pool_counts > 1]

    df_val, df_test = train_test_split(
        df_valtest, test_size=0.5, random_state=seed,
        stratify=df_valtest['native_language'],
    )

    df_train = pd.concat([df_train, df_rare, df_leftover], ignore_index=True)
    # Shuffle so the appended rare/leftover rows are not clustered at the end.
    df_train = df_train.sample(frac=1, random_state=seed).reset_index(drop=True)
    return df_train, df_val.reset_index(drop=True), df_test.reset_index(drop=True)


df = load_corpus(DATA_DIR, INDEX_CSV)

# Drop unlabeled rows BEFORE fitting the encoder. The split discards them
# anyway, and fitting on them would either fail or create a phantom "nan"
# class that inflates num_classes (and the size of the classifier head).
df = df.dropna(subset=['native_language']).reset_index(drop=True)

# Map each native-language name to an integer id; le.classes_ keeps the
# id -> name mapping used later for reporting.
le = LabelEncoder()
df['label'] = le.fit_transform(df['native_language'])
train_df, val_df, test_df = stratified_split(df, RANDOM_SEED)
label_names = list(le.classes_)
num_classes = len(label_names)

print(f'Train: {len(train_df)}  Val: {len(val_df)}  Test: {len(test_df)}')
print(f'Classes: {num_classes}')

## Tokenize & Create DataLoaders

In [None]:
class NLIDataset(Dataset):
    """Pre-tokenized text classification dataset.

    All texts are tokenized once at construction time (truncated and padded
    to ``max_length``) and kept as stacked tensors; indexing returns one row
    of every encoding tensor plus the matching label under ``'labels'``.
    """

    def __init__(self, texts, labels, tokenizer, max_length, desc='Tokenizing'):
        """Tokenize ``texts`` in chunks and store the resulting tensors.

        Args:
            texts: List of input strings.
            labels: Integer class id for each text.
            tokenizer: Callable invoked on a list of strings that returns a
                mapping with ``input_ids`` and ``attention_mask`` (and
                optionally ``token_type_ids``) as PyTorch tensors.
            max_length: Length every sequence is truncated/padded to.
            desc: Caption for the progress bar.
        """
        required = ('input_ids', 'attention_mask')
        optional = 'token_type_ids'
        chunk_size = 256  # tokenize piecewise so the progress bar advances
        pieces = {name: [] for name in (*required, optional)}

        for start in trange(0, len(texts), chunk_size, desc=desc):
            encoded = tokenizer(
                texts[start:start + chunk_size],
                truncation=True,
                padding='max_length',
                max_length=max_length,
                return_tensors='pt',
            )
            for name in required:
                pieces[name].append(encoded[name])
            # Not every tokenizer emits segment ids; keep them only if present.
            if optional in encoded:
                pieces[optional].append(encoded[optional])

        self.encodings = {name: torch.cat(pieces[name]) for name in required}
        if pieces[optional]:
            self.encodings[optional] = torch.cat(pieces[optional])
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        """Return the number of examples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the encoding tensors and label of example ``idx`` as a dict."""
        sample = {name: tensor[idx] for name, tensor in self.encodings.items()}
        sample['labels'] = self.labels[idx]
        return sample


# Seed NumPy and PyTorch before anything random happens, so that the newly
# initialized classification head and the shuffling order of the training
# DataLoader are driven by RANDOM_SEED rather than differing on every run.
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

print(f'Loading {MODEL_NAME}...')
# NOTE(review): trust_remote_code=True allows the model repository to supply
# Python code that is executed locally — only point MODEL_NAME at
# repositories you trust.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, num_labels=num_classes, trust_remote_code=True,
).to(DEVICE)

# Tokenize each split once up front; the datasets hold the padded tensors.
train_dataset = NLIDataset(train_df['text'].tolist(), train_df['label'].tolist(), tokenizer, MAX_LENGTH, 'Tokenizing train')
val_dataset = NLIDataset(val_df['text'].tolist(), val_df['label'].tolist(), tokenizer, MAX_LENGTH, 'Tokenizing val')
test_dataset = NLIDataset(test_df['text'].tolist(), test_df['label'].tolist(), tokenizer, MAX_LENGTH, 'Tokenizing test')

# Only the training data is shuffled; val/test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
print('Dataloaders ready.')

## Training Loop

In [None]:
# Inverse-frequency class weights to counter class imbalance:
#   weight[c] = n_train / (num_classes * n_train_in_class_c)
counts = np.bincount(train_df['label'].values, minlength=num_classes).astype(float)
# Classes absent from the training split get a count of 1 to avoid dividing by zero.
counts = np.where(counts == 0, 1.0, counts)
class_weights = torch.tensor(
    len(train_df) / (num_classes * counts),
    dtype=torch.float32,
    device=DEVICE,
)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)

# The scheduler is stepped once per batch, so its horizon is counted in
# batches: linear warmup over the first WARMUP_RATIO share, then linear decay.
total_steps = EPOCHS * len(train_loader)
warmup_steps = int(total_steps * WARMUP_RATIO)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)


def train_epoch(model, loader, optimizer, scheduler, criterion, epoch):
    """Run one optimization pass over ``loader``.

    Each batch is moved to ``DEVICE``, the loss is computed from the model's
    logits with ``criterion``, gradients are clipped to a norm of 1.0, and
    both the optimizer and the scheduler are stepped once per batch.

    Args:
        model: Model called with the batch tensors as keyword arguments; its
            output must expose ``.logits``.
        loader: DataLoader yielding dicts of tensors that include ``'labels'``.
        optimizer: Optimizer over the model's parameters.
        scheduler: Learning-rate scheduler, stepped after every batch.
        criterion: Loss function taking ``(logits, labels)``.
        epoch: Epoch number, shown in the progress bar caption.

    Returns:
        The sum of per-batch losses, each scaled by its batch size, divided
        by the number of examples in ``loader.dataset``.
    """
    model.train()
    running_loss = 0
    progress = tqdm(loader, desc=f'Train epoch {epoch}', leave=False)
    for batch in progress:
        targets = batch['labels'].to(DEVICE)
        inputs = {
            name: tensor.to(DEVICE)
            for name, tensor in batch.items()
            if name != 'labels'
        }

        logits = model(**inputs).logits
        loss = criterion(logits, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        batch_loss = loss.item()
        running_loss += batch_loss * len(targets)
        progress.set_postfix(loss=f'{batch_loss:.4f}')
    return running_loss / len(loader.dataset)


@torch.no_grad()
def evaluate_model(model, loader, criterion, desc='Eval'):
    """Compute the loss and predictions of ``model`` over ``loader``.

    Runs with gradient tracking disabled and the model in eval mode.

    Args:
        model: Model called with the batch tensors as keyword arguments; its
            output must expose ``.logits``.
        loader: DataLoader yielding dicts of tensors that include ``'labels'``.
        criterion: Loss function taking ``(logits, labels)``.
        desc: Caption for the progress bar.

    Returns:
        A tuple ``(loss, preds, labels)``: the batch-size-weighted sum of
        losses divided by the number of examples in ``loader.dataset``, and
        two NumPy arrays with the argmax class and the true class of every
        example, in loader order.
    """
    model.eval()
    running_loss = 0
    pred_chunks = []
    label_chunks = []
    for batch in tqdm(loader, desc=desc, leave=False):
        targets = batch['labels'].to(DEVICE)
        inputs = {
            name: tensor.to(DEVICE)
            for name, tensor in batch.items()
            if name != 'labels'
        }

        logits = model(**inputs).logits
        running_loss += criterion(logits, targets).item() * len(targets)
        pred_chunks.append(logits.argmax(dim=1).cpu().numpy())
        label_chunks.append(targets.cpu().numpy())

    mean_loss = running_loss / len(loader.dataset)
    return mean_loss, np.concatenate(pred_chunks), np.concatenate(label_chunks)


# ── Training ───────────────────────────────────────────────────────────
# Early stopping on validation loss: a CPU copy of the weights is kept for
# the best epoch so far, and training stops after PATIENCE consecutive
# epochs without improvement.
best_val_loss = float('inf')
patience_counter = 0
best_state = None
history = []

for epoch in range(1, EPOCHS + 1):
    t0 = time.time()
    train_loss = train_epoch(model, train_loader, optimizer, scheduler, criterion, epoch)
    val_loss, val_preds, val_labels = evaluate_model(model, val_loader, criterion, 'Val')
    elapsed = time.time() - t0

    val_acc = accuracy_score(val_labels, val_preds)
    val_f1 = f1_score(val_labels, val_preds, average='macro', zero_division=0)
    history.append({'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss, 'val_acc': val_acc, 'val_f1': val_f1})

    print(
        f'Epoch {epoch:3d} | train_loss={train_loss:.4f}  val_loss={val_loss:.4f}  '
        f'val_acc={val_acc:.4f}  val_macro_f1={val_f1:.4f}  ({elapsed:.1f}s)'
    )

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        # Snapshot on the CPU so the copy does not occupy GPU memory.
        best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print(f'Early stopping at epoch {epoch}')
            break

# Restore best model
if best_state is None:
    # No epoch ever improved on the initial `inf` (no epochs ran, or val_loss
    # was NaN every time). Fail with a clear message instead of passing None
    # to load_state_dict.
    raise RuntimeError(
        'No best checkpoint was recorded: validation loss never improved '
        '(EPOCHS < 1 or a non-finite val_loss). Check the data and hyperparameters.'
    )
model.load_state_dict(best_state)
model = model.to(DEVICE)
print('\nTraining complete. Best val_loss:', f'{best_val_loss:.4f}')

## Evaluation

In [None]:
# ── Validation ─────────────────────────────────────────────────────────
# Re-evaluate with the current (restored) weights; the loss is not needed here.
_, val_preds, val_labels = evaluate_model(model, val_loader, criterion, 'Val')

print('=== Validation Set ===')
print(f'Accuracy:    {accuracy_score(val_labels, val_preds):.4f}')
print(f'Macro-F1:    {f1_score(val_labels, val_preds, average="macro", zero_division=0):.4f}')
print(f'Weighted-F1: {f1_score(val_labels, val_preds, average="weighted", zero_division=0):.4f}')

# ── Test ───────────────────────────────────────────────────────────────
_, test_preds, test_labels = evaluate_model(model, test_loader, criterion, 'Test')
# Restrict the per-class report to classes that occur in the test labels or
# predictions, so labels and target_names stay aligned.
present = sorted(set(test_labels) | set(test_preds))
names = [label_names[i] for i in present]

print('\n=== Test Set ===')
print(f'Accuracy:    {accuracy_score(test_labels, test_preds):.4f}')
print(f'Macro-F1:    {f1_score(test_labels, test_preds, average="macro", zero_division=0):.4f}')
print(f'Weighted-F1: {f1_score(test_labels, test_preds, average="weighted", zero_division=0):.4f}')
print()
print(classification_report(test_labels, test_preds, labels=present, target_names=names, zero_division=0))

In [None]:
# ── Save results ──────────────────────────────────────────────────────
# Hub ids contain '/', which cannot appear inside a single file or directory name.
safe_name = MODEL_NAME.replace('/', '_')
out_path = RESULTS_DIR / f'finetune_{safe_name}.json'

# Summary metrics for the test and validation splits plus the per-epoch
# history, converted to plain floats for JSON.
results = dict(
    model=MODEL_NAME,
    test_accuracy=float(accuracy_score(test_labels, test_preds)),
    test_macro_f1=float(f1_score(test_labels, test_preds, average='macro', zero_division=0)),
    test_weighted_f1=float(f1_score(test_labels, test_preds, average='weighted', zero_division=0)),
    val_accuracy=float(accuracy_score(val_labels, val_preds)),
    val_macro_f1=float(f1_score(val_labels, val_preds, average='macro', zero_division=0)),
    history=history,
)
out_path.write_text(json.dumps(results, indent=2))
print(f'Saved results to {out_path}')

# Save model checkpoint to Drive, next to the data directory, so it is kept
# together with the corpus on the mounted Drive.
ckpt_dir = Path(DRIVE_DATA_DIR).parent.joinpath('checkpoints', safe_name)
ckpt_dir.mkdir(parents=True, exist_ok=True)
for artifact in (model, tokenizer):
    artifact.save_pretrained(ckpt_dir)
print(f'Saved checkpoint to {ckpt_dir}')

## Training History Plot

In [None]:
import matplotlib.pyplot as plt

# One row per completed epoch, with the metrics recorded during training.
hist_df = pd.DataFrame(history)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
epoch_axis = hist_df['epoch']

# Left panel: training vs. validation loss.
for column, legend_label in (('train_loss', 'Train'), ('val_loss', 'Val')):
    ax1.plot(epoch_axis, hist_df[column], label=legend_label)
ax1.set(xlabel='Epoch', ylabel='Loss', title='Loss')
ax1.legend()

# Right panel: validation accuracy and macro-F1.
for column, legend_label in (('val_acc', 'Val Accuracy'), ('val_f1', 'Val Macro-F1')):
    ax2.plot(epoch_axis, hist_df[column], label=legend_label)
ax2.set(xlabel='Epoch', ylabel='Score', title='Validation Metrics')
ax2.legend()

plt.suptitle(MODEL_NAME)
plt.tight_layout()
plt.show()