In [35]:
import argparse
import torch
from datasets import load_from_disk
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, TrainerCallback
from peft import LoraConfig, get_peft_model, TaskType
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef
from scipy.stats import pearsonr, spearmanr
import torch.nn.functional as F

In [36]:
glue_tasks = [
    "cola", "sst2", "mrpc", "qqp", "stsb",
    "mnli", "qnli", "rte", "wnli",
]
glue_metrics = {}

In [71]:
import evaluate

for task in glue_tasks:
    print('task:', task)
    glue_metrics[task] = evaluate.load('glue', task)

task: cola


task: sst2
task: mrpc
task: qqp
task: stsb
task: mnli
task: qnli
task: rte
task: wnli


In [38]:
class Config:
    def __init__(self):
        # Model arguments
        self.teacher_model_name = "./models/bert-base-uncased"
        self.student_model_name = "./models/distilbert-base-uncased"
        self.task = 'sst2'

        # Dataset and training parameters
        self.num_labels = 2
        self.train_batch_size = 16
        self.num_train_epochs = 3

        # LoRA parameters
        self.rank = 8
        self.lora_alpha = 16
        self.lora_dropout = 0.1

        # Learning rates
        self.teacher_learning_rate = 5e-5
        self.student_learning_rate = 5e-5

args = Config()

# Now you can access them as usual:
print(args.teacher_model_name)

./models/bert-base-uncased


In [39]:
# Step 1: Fine-tune a Teacher Model
print(f"Fine-tuning the teacher model: {args.teacher_model_name}")
teacher_model = AutoModelForSequenceClassification.from_pretrained(args.teacher_model_name, num_labels=args.num_labels)
teacher_tokenizer = AutoTokenizer.from_pretrained(args.teacher_model_name)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ./models/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Fine-tuning the teacher model: ./models/bert-base-uncased


In [40]:
teacher_training_args = TrainingArguments(
    output_dir="./teacher_results/" + args.task,
    learning_rate=args.teacher_learning_rate,
    per_device_train_batch_size=args.train_batch_size,
    num_train_epochs=args.num_train_epochs,
    weight_decay=0.01,
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [41]:
pwd

'/mnt/data2/congfeng/kd-lora'

In [42]:
from datasets import load_dataset

# 下载并加载 GLUE 的 WNLI 子集
teacher_dataset = load_dataset("glue", args.task, cache_dir='./dataset')


In [43]:
teacher_dataset

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})

In [49]:
if args.task == 'wnli':
    tokenized_teacher_dataset = teacher_dataset.map(
        lambda x: teacher_tokenizer(x["sentence1"], x["sentence2"], padding="max_length", truncation=True),
        batched=True
    )
elif args.task in ['cola', 'sst2']:
    def preprocess_cola(examples):
        # CoLA 只有 'sentence' 这一个输入字段
        return teacher_tokenizer(
            examples["sentence"], 
            truncation=True, 
            padding="max_length", 
        )

    # 3. 加载原始数据
    cola_dataset = load_dataset("glue", "cola")

    # 4. 批量映射
    tokenized_teacher_dataset = teacher_dataset.map(preprocess_cola, batched=True)

    # 5. 格式转换：移除原始文本列，保留模型需要的张量列
    # 注意：一定要把 'label' 改名为 'labels'
    tokenized_teacher_dataset = tokenized_teacher_dataset.rename_column("label", "labels")
    tokenized_teacher_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
else:
    raise ValueError

Map: 100%|██████████| 67349/67349 [00:14<00:00, 4627.05 examples/s]
Map: 100%|██████████| 872/872 [00:00<00:00, 4140.28 examples/s]
Map: 100%|██████████| 1821/1821 [00:00<00:00, 5255.74 examples/s]


In [50]:
tokenized_teacher_dataset

DatasetDict({
    train: Dataset({
        features: ['sentence', 'labels', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'labels', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'labels', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1821
    })
})

In [51]:
# Define trainer for teacher model
teacher_trainer = Trainer(
    model=teacher_model,
    args=teacher_training_args,
    train_dataset=tokenized_teacher_dataset["train"],
    eval_dataset=tokenized_teacher_dataset["validation"]
)
teacher_trainer.train()



Step,Training Loss
500,0.2166
1000,0.1029
1500,0.0617




TrainOutput(global_step=1581, training_loss=0.12308960998457945, metrics={'train_runtime': 1775.0017, 'train_samples_per_second': 113.829, 'train_steps_per_second': 0.891, 'total_flos': 5.316079940232192e+16, 'train_loss': 0.12308960998457945, 'epoch': 3.0})

In [52]:
# Save teacher model predictions (logits) as soft labels
teacher_logits = teacher_trainer.predict(tokenized_teacher_dataset["train"]).predictions
teacher_soft_labels = torch.tensor(teacher_logits)




In [None]:
teacher_model.save_pretrained(f'./teacher_model_FFT/{args.task}/bert-base-uncased-FFT')

In [54]:
teacher_soft_labels.shape

torch.Size([67349, 2])

In [55]:
# Step 2: Initialize a Smaller Student Model with LoRA
print(f"Initializing student model: {args.student_model_name} with LoRA")
student_model = AutoModelForSequenceClassification.from_pretrained(args.student_model_name, num_labels=args.num_labels)
student_tokenizer = AutoTokenizer.from_pretrained(args.student_model_name)


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at ./models/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Initializing student model: ./models/distilbert-base-uncased with LoRA


In [56]:
lora_config = LoraConfig(
    r=args.rank,
    lora_alpha=args.lora_alpha,
    target_modules=["q_lin", "v_lin"],
    lora_dropout=args.lora_dropout,
    bias="none",
    task_type="SEQ_CLS"
)

In [57]:
# Apply LoRA configuration to the student model
student_model = get_peft_model(student_model, lora_config)


In [58]:
student_model

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): DistilBertForSequenceClassification(
      (distilbert): DistilBertModel(
        (embeddings): Embeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (transformer): Transformer(
          (layer): ModuleList(
            (0-5): 6 x TransformerBlock(
              (attention): DistilBertSdpaAttention(
                (dropout): Dropout(p=0.1, inplace=False)
                (q_lin): lora.Linear(
                  (base_layer): Linear(in_features=768, out_features=768, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.1, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=76

In [59]:
# Freeze all layers except LoRA parameters
for param in student_model.parameters():
    param.requires_grad = False
for name, param in student_model.named_parameters():
    if "lora_" in name:
        param.requires_grad = True  # Only LoRA weights are trainable


In [60]:
# Step 3: Distillation from Teacher to Student
print("Starting knowledge distillation from teacher to student")
student_training_args = TrainingArguments(
    output_dir="./student_results/" + args.task,
    learning_rate=args.student_learning_rate,
    per_device_train_batch_size=args.train_batch_size,
    num_train_epochs=args.num_train_epochs,
    weight_decay=0.01,
    remove_unused_columns=False,
)

Starting knowledge distillation from teacher to student


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [61]:
def distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5):
    # Compute the distillation loss with temperature scaling
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean"
    ) * (temperature ** 2)
    hard_loss = F.cross_entropy(student_logits, labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss

save_inputs = None

# Define a custom training loop for distillation
class DistillationTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwds):
        global save_inputs
        save_inputs = inputs.copy()

        labels = inputs.pop("labels")
        idx = inputs.pop('idx').long().cpu()
        outputs = model(**inputs)
        student_logits = outputs.logits
        teacher_logits = teacher_soft_labels[idx]  # Align teacher logits with batch size
        # teacher_logits = teacher_soft_labels[inputs["input_ids"].shape[0]]  # Align teacher logits with batch size
        teacher_logits = teacher_logits.to(student_logits.device)
        loss = distillation_loss(student_logits, teacher_logits, labels)
        return (loss, outputs) if return_outputs else loss


In [62]:
# Tokenize student dataset
if args.task == 'wnli':
    tokenized_student_dataset = teacher_dataset.map(
        lambda x, idx: {**student_tokenizer(x["sentence1"], x["sentence2"], padding="max_length", truncation=True), 'idx': idx},
        batched=True, with_indices=True
    )
elif args.task in ['sst2', 'cola']:
    tokenized_student_dataset = teacher_dataset.map(
        lambda x, idx: {**student_tokenizer(x["sentence"], padding="max_length", truncation=True), 'idx': idx},
        batched=True, with_indices=True
    )
else:
    raise ValueError

Map:  15%|█▍        | 10000/67349 [00:02<00:13, 4236.34 examples/s]

Map: 100%|██████████| 67349/67349 [00:12<00:00, 5574.30 examples/s]
Map: 100%|██████████| 872/872 [00:00<00:00, 5297.05 examples/s]
Map: 100%|██████████| 1821/1821 [00:00<00:00, 3626.74 examples/s]


In [63]:
tokenized_student_dataset['train'][0]['idx']

0

In [64]:
# Initialize Distillation Trainer
student_trainer = DistillationTrainer(
    model=student_model,
    args=student_training_args,
    train_dataset=tokenized_student_dataset["train"],
    eval_dataset=tokenized_student_dataset["validation"]
)

# Train student model with knowledge distillation
student_trainer.train()

Step,Training Loss
500,1.1096
1000,0.7003
1500,0.6486


TrainOutput(global_step=1581, training_loss=0.8099147186182784, metrics={'train_runtime': 1127.0118, 'train_samples_per_second': 179.277, 'train_steps_per_second': 1.403, 'total_flos': 2.7223692935417856e+16, 'train_loss': 0.8099147186182784, 'epoch': 3.0})

In [65]:
# Evaluate student model
student_trainer.evaluate()


{'eval_loss': 1.57597017288208,
 'eval_runtime': 3.2514,
 'eval_samples_per_second': 268.192,
 'eval_steps_per_second': 4.306,
 'epoch': 3.0}

In [None]:
# Save the fine-tuned LoRA student model
output_dir = "./student_model_LoRA/" + args.task
student_model.save_pretrained(output_dir)
student_tokenizer.save_pretrained(output_dir)
print(f"Student model saved to {output_dir}")

Student model saved to ./fine_tuned_student_model/sst2


In [67]:
import torch
from tqdm import tqdm
import numpy as np
import evaluate
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

def run_glue_inference(task_name, model, tokenizer, eval_dataset, device="cuda", batch_size=16):
    """
    针对 GLUE 任务运行推理并计算指标
    """
    model.to(device)
    model.eval()

    # 1. 加载对应任务的指标
    metric = glue_metrics[task] #evaluate.load("glue", task_name)

    # 2. 准备 DataCollator (自动处理 Padding)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    # 3. 创建 DataLoader
    # 确保 eval_dataset 已经包含了 input_ids 等模型需要的字段
    dataloader = DataLoader(
        eval_dataset,
        batch_size=batch_size,
        collate_fn=data_collator
    )

    all_preds = []
    all_labels = []

    print(f"开始推理任务: {task_name.upper()}")
    for batch in tqdm(dataloader):
        # 将数据移动到设备
        inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"}
        labels = batch["labels"]

        with torch.no_grad():
            outputs = model(**inputs)

        logits = outputs.logits

        # --- 核心逻辑：区分分类任务和回归任务 ---
        if task_name == "stsb":
            # STS-B 是回归任务，直接取第一个数值
            preds = logits.squeeze().cpu().numpy()
        else:
            # 其他是分类任务，取概率最大的索引
            preds = torch.argmax(logits, dim=-1).cpu().numpy()

        all_preds.extend(preds)
        all_labels.extend(labels.cpu().numpy())

    # 4. 计算最终指标
    results = metric.compute(predictions=all_preds, references=all_labels)
    return results

# --- 使用示例 ---
# 假设你已经准备好了之前加载的模型 model, tokenizer 和 tokenized_dataset
# results = run_glue_inference("wnli", model, tokenizer, tokenized_dataset["validation"])
# print(results)

In [72]:
args.task

'sst2'

In [73]:
run_glue_inference(task_name=args.task, model=student_model, tokenizer=student_tokenizer,
                   eval_dataset=tokenized_teacher_dataset['validation'], device="cuda", batch_size=16)

开始推理任务: SST2


  4%|▎         | 2/55 [00:00<00:10,  5.15it/s]

100%|██████████| 55/55 [00:09<00:00,  5.94it/s]


{'accuracy': 0.8600917431192661}

Exactly the same as reported in the paper.