# LoRA Distillation Pipeline

This notebook demonstrates the full pipeline for knowledge‑distillation fine‑tuning of `roberta-base` on AG News using LoRA adapters.

**Contents:**
1. Imports & Configuration  
2. Helper Functions  
3. Data Loading & Preview  
4. Teacher Logit Generation  
5. Logit Distribution Visualization  
6. Student Model & LoRA Setup  
7. Distillation Training & Evaluation  
8. Training Curves Visualization  
9. Parameter Count   
10. Inference on Unlabelled Test Set  


## 1. Imports & Configuration

In [None]:
import os, pickle
import numpy as np, pandas as pd, torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    RobertaForSequenceClassification, RobertaTokenizer,
    TrainingArguments, Trainer, DataCollatorWithPadding
)
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
import evaluate
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Global config

# Backbone checkpoint and optimisation schedule.
BASE_MODEL = "roberta-base"
NUM_EPOCHS = 3
LEARNING_RATE = 2e-4
BATCH_SIZE_TRAIN, BATCH_SIZE_EVAL = 16, 64

# Distillation hyper-parameters.
TEMPERATURE, ALPHA = 2.0, 0.7

# LoRA adapter hyper-parameters. LORA_START / LORA_END are handed to
# make_lora_target_modules, which treats them as an inclusive layer range.
LORA_RANK = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.1
LORA_START, LORA_END = 6, 11

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


## 2. Helper Functions

In [None]:
def make_lora_target_modules(start, end):
    """Return LoRA target-module names for encoder layers ``start``..``end``.

    Both bounds are inclusive. For every layer index the query projection,
    the value projection and the ``output.dense`` sub-module are listed, in
    that order, layer by layer. An empty list is returned when ``end < start``.
    """
    suffixes = ("attention.self.query", "attention.self.value", "output.dense")
    names = []
    for layer in range(start, end + 1):
        for suffix in suffixes:
            names.append(f"encoder.layer.{layer}.{suffix}")
    return names

def distillation_loss(student_logits, teacher_logits, labels, T, alpha):
    """Blend a soft-target KL term with hard-label cross-entropy.

    The KL term compares the temperature-softened student distribution with
    the temperature-softened teacher distribution (averaged per batch row)
    and is rescaled by ``T * T``. The cross-entropy term uses the raw
    student logits and ``labels``.

    Returns ``alpha * kl + (1 - alpha) * ce``.
    """
    student_log_probs = F.log_softmax(student_logits / T, dim=-1)
    teacher_probs = F.softmax(teacher_logits / T, dim=-1)

    soft_term = F.kl_div(
        input=student_log_probs,
        target=teacher_probs,
        reduction="batchmean",
    ) * (T * T)
    hard_term = F.cross_entropy(student_logits, labels)

    return alpha * soft_term + (1 - alpha) * hard_term

def compute_metrics(pred):
    """Compute accuracy for a ``Trainer`` evaluation pass.

    Args:
        pred: object exposing ``predictions`` (per-class scores, or a tuple
            whose first element holds them) and ``label_ids``.

    Returns:
        ``{"accuracy": <float>}``. The value is a plain float so that the
        Trainer reports it as a scalar ``eval_accuracy``; wrapping an
        already-keyed metric dict here would nest it one level too deep.
    """
    scores = pred.predictions
    if isinstance(scores, tuple):
        # Some models return extra outputs; the class scores come first.
        scores = scores[0]
    preds = np.asarray(scores).argmax(-1)
    references = np.asarray(pred.label_ids)
    # Computed directly with NumPy so the metric is not re-loaded on every
    # evaluation call.
    return {"accuracy": float((preds == references).mean())}

def evaluate_model(model, dataset, collator, batch_size=8):
    """
    Run evaluation manually: returns (metrics_dict, predictions_array).

    Args:
        model: classifier whose forward output exposes ``.logits``; it is
            moved to ``DEVICE`` and switched to eval mode.
        dataset: dataset to iterate over.
        collator: callable used as the ``DataLoader`` ``collate_fn``.
        batch_size: number of examples per forward pass.

    Returns:
        A tuple ``(metrics, preds)``. ``metrics`` is ``{"accuracy": float}``
        computed over the batches that carried ``"labels"``, or an empty dict
        when no batch had labels (e.g. an unlabelled test set). ``preds`` is
        a NumPy array of predicted class ids for every example.
    """
    loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collator)
    model.to(DEVICE).eval()
    all_preds = []
    n_correct = 0
    n_labelled = 0
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            labels = batch.pop("labels", None)
            # The example index is bookkeeping only; the model must not see it.
            batch.pop("idx", None)
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            preds = model(**batch).logits.argmax(dim=-1).cpu().numpy()
            all_preds.extend(preds)
            if labels is not None:
                refs = labels.cpu().numpy()
                n_correct += int((preds == refs).sum())
                n_labelled += len(refs)
    # Only report accuracy when there was something to score against;
    # an unlabelled dataset yields predictions with empty metrics.
    metrics = {"accuracy": n_correct / n_labelled} if n_labelled else {}
    return metrics, np.array(all_preds)

## 3. Data Loading & Preview

- Load AG News training split.
- Tokenize texts.
- Preview first 5 samples.


In [None]:
raw = load_dataset("ag_news", split="train")
tokenizer = RobertaTokenizer.from_pretrained(BASE_MODEL)


def preprocess_fn(batch):
    """Tokenize a batch of examples.

    Sequences are truncated but left unpadded here; padding is deferred to
    the collator used when batches are built.
    """
    return tokenizer(batch["text"], truncation=True, padding=False)


# The label column is renamed to "labels" after tokenization.
tokenized = raw.map(preprocess_fn, batched=True).rename_column("label", "labels")

# Preview: slice the first five rows, then read the columns, so the whole
# training split does not have to be pulled into memory for a 5-row table.
head = raw[:5]
pd.DataFrame({"text": head["text"], "label": head["label"]})


## 4. Teacher Logit Generation

- Wrap the dataset with `CustomDataset` and `DataCollatorWithIdx` so that each batch carries an `idx` entry identifying its examples.
- Load teacher model and generate logits for each training example.


In [None]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Map-style wrapper around a Hugging Face dataset split.

    NOTE(review): the method bodies are elided (``...``) in this notebook, so
    the class is only a placeholder as written. Presumably each item also
    carries its position under an ``"idx"`` key, because the loops consuming
    these batches pop ``"idx"`` -- confirm against the full implementation.
    """
    def __init__(self, hf_ds): ...  # hf_ds: the split to wrap, e.g. split["train"]
    def __len__(self): ...
    def __getitem__(self, idx): ...

class DataCollatorWithIdx(DataCollatorWithPadding):
    """Padding collator variant used together with ``CustomDataset``.

    NOTE(review): ``__call__`` is elided (``...``) in this notebook.
    Presumably it pads like the parent class while keeping the per-example
    ``"idx"`` values in the batch -- confirm against the full implementation.
    """
    def __call__(self, features): ...  # features: list of per-example items

# Hold out 640 examples for validation; the remainder is used for distillation.
split = tokenized.train_test_split(test_size=640, seed=42)
train_ds = CustomDataset(split["train"])
collator = DataCollatorWithIdx(tokenizer, return_tensors="pt")

# Label metadata is derived from the dataset's own label feature so the
# classification head and the logit buffer are sized consistently.
label_names = tokenized.features["labels"].names
num_labels = len(label_names)
id2label = dict(enumerate(label_names))
label2id = {name: i for i, name in id2label.items()}

# NOTE(review): BASE_MODEL is the plain pretrained checkpoint, so the
# classification head loaded here is presumably untrained -- confirm whether
# a fine-tuned teacher checkpoint should be loaded instead.
teacher = RobertaForSequenceClassification.from_pretrained(
    BASE_MODEL, id2label=id2label, label2id=label2id
).to(DEVICE).eval()

# One row of teacher logits per training example, addressed by example index.
teacher_logits = torch.zeros(len(train_ds), num_labels)
with torch.no_grad():
    for batch in tqdm(DataLoader(train_ds, batch_size=64, collate_fn=collator), desc="Teacher"):
        # "idx" tells us which rows of the buffer this batch fills; it is
        # removed before the forward pass because the model does not accept it.
        idxs = batch.pop("idx")
        inputs = {k: v.to(DEVICE) for k, v in batch.items()}
        teacher_logits[idxs] = teacher(**inputs).logits.cpu()


## 5. Logit Distribution Visualization

Plot a histogram of all teacher logits to inspect their distribution.

In [None]:
# Collapse the teacher logit tensor into one flat NumPy vector so that every
# logit of every example contributes to a single histogram.
vals = teacher_logits.reshape(-1).numpy()

plt.hist(vals, bins=50)
plt.title("Teacher Logit Distribution")
plt.show()


## 6. Student Model & LoRA Setup

- Load base RoBERTa.
- Inject LoRA adapters.
- Print number of trainable parameters.


In [None]:
# Size the student's classification head from the dataset's label feature.
# Without an explicit num_labels the model gets the library default head,
# which would not match the number of classes in the labels or the width of
# the teacher logits used in the distillation loss.
student_num_labels = tokenized.features["labels"].num_classes
student = RobertaForSequenceClassification.from_pretrained(
    BASE_MODEL, num_labels=student_num_labels
)

# LoRA adapters are attached only to the modules named by
# make_lora_target_modules for layers LORA_START..LORA_END.
lora_cfg = LoraConfig(
    r=LORA_RANK, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,
    target_modules=make_lora_target_modules(LORA_START, LORA_END),
    bias="none", task_type="SEQ_CLS"
)
peft_model = get_peft_model(student, lora_cfg)
peft_model.print_trainable_parameters()


## 7. Distillation Training & Evaluation

- Set up `TrainingArguments`.
- Override `compute_loss` to use our `distillation_loss`.
- Train with `Trainer` and evaluate.


In [None]:
training_args = TrainingArguments(
    # Where checkpoints are written; no external experiment tracker is used.
    output_dir="results/model_checkpoint",
    report_to=[],
    # Optimisation schedule.
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE_TRAIN,
    per_device_eval_batch_size=BATCH_SIZE_EVAL,
    # Evaluate and checkpoint once per epoch, then reload the best checkpoint.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

def compute_loss(model, inputs, return_outputs=False):
    """Distillation loss for one batch.

    While the model is in training mode and the batch carries ``"idx"``, the
    pre-computed teacher logits for those examples are looked up and the loss
    is ``distillation_loss`` with ``TEMPERATURE`` and ``ALPHA``. Otherwise
    (evaluation, where the indices do not refer to rows of ``teacher_logits``)
    plain cross-entropy is used.

    NOTE(review): this assumes the collator yields ``"idx"`` as an integer
    tensor indexing ``teacher_logits`` -- confirm against the collator.

    Returns:
        The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
    """
    # Work on a shallow copy so the caller's batch dict is left untouched.
    inputs = dict(inputs)
    labels = inputs.pop("labels")
    idxs = inputs.pop("idx", None)

    outputs = model(**inputs)
    logits = outputs.logits

    if model.training and idxs is not None:
        # teacher_logits lives on the CPU; index there, then move to the model's device.
        soft_targets = teacher_logits[idxs.cpu()].to(logits.device)
        loss = distillation_loss(logits, soft_targets, labels, TEMPERATURE, ALPHA)
    else:
        loss = F.cross_entropy(logits, labels)

    return (loss, outputs) if return_outputs else loss


class DistillationTrainer(Trainer):
    """Trainer whose loss is the module-level ``compute_loss`` above.

    The loss is customised by overriding the ``compute_loss`` method; the
    ``Trainer`` constructor does not take a ``compute_loss`` argument.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments newer Trainer versions may pass.
        return compute_loss(model, inputs, return_outputs=return_outputs)


trainer = DistillationTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=CustomDataset(split["test"]),
    data_collator=collator,
    compute_metrics=compute_metrics,
)
trainer.train()
metrics = trainer.evaluate()
print(f"Validation Accuracy: {metrics['eval_accuracy']:.4f}")


## 8. Training Curves Visualization

Use the trainer’s log history to plot training loss and validation accuracy over time.

In [None]:
import pandas as pd
hist = pd.DataFrame(trainer.state.log_history)
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Training-loss and evaluation entries are logged as separate rows, so each
# metric column is NaN on the other kind of row. Drop those rows per metric;
# otherwise the line is broken at every gap and isolated points are not drawn.
x_col = "step" if "step" in hist else None  # fall back to the row index
curves = (("loss", "Train Loss"), ("eval_accuracy", "Eval Accuracy"))
for ax, (column, title) in zip(axes, curves):
    if column not in hist:
        # Nothing was logged for this metric; keep the titled, empty axes.
        ax.set_title(title)
        continue
    rows = hist.dropna(subset=[column])
    # Markers keep metrics with only a handful of points visible.
    rows.plot(x=x_col, y=column, ax=ax, title=title, marker="o", legend=False)
plt.show()


## 9. Parameter Count

Compute and display total vs. trainable parameter counts (in millions).

In [None]:
# Gather (size, is_trainable) for every parameter once, then total them up.
param_sizes = [(p.numel(), p.requires_grad) for p in peft_model.parameters()]
total = sum(size for size, _ in param_sizes)
trainable = sum(size for size, is_trainable in param_sizes if is_trainable)

# Both counts are reported in millions.
print(f"Total Params: {total/1e6:.2f}M")
print(f"Trainable Params: {trainable/1e6:.2f}M")

## 10. Inference on Unlabelled Test Set

Run the `inference.py` script to generate `submission.csv`, then preview the first rows.

In [None]:
# Run the external inference script via a shell escape. It is given the
# checkpoint directory, the pickled unlabelled test file and the CSV path.
# NOTE(review): no explicit save of the final model is visible in this
# notebook, so results/model_checkpoint presumably only contains the
# per-epoch checkpoint subfolders written by the Trainer -- confirm that
# inference.py resolves the intended one.
!python ../scripts/inference.py \
  --checkpoint results/model_checkpoint \
  --test_file ../data/test_unlabelled.pkl \
  --output_csv submission.csv

import pandas as pd
# Preview the first rows of the CSV named in --output_csv above.
pd.read_csv("submission.csv").head()
