In [1]:
import argparse
import torch
from datasets import load_from_disk
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, TrainerCallback
from peft import LoraConfig, get_peft_model, TaskType
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef
from scipy.stats import pearsonr, spearmanr
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class Config:
    def __init__(self):
        # Model arguments
        self.teacher_model_name = "./models/bert-base-uncased"
        self.student_model_name = "./models/distilbert-base-uncased"

        # Dataset and training parameters
        self.num_labels = 2
        self.train_batch_size = 16
        self.num_train_epochs = 3

        # LoRA parameters
        self.rank = 8
        self.lora_alpha = 16
        self.lora_dropout = 0.1

        # Learning rates
        self.teacher_learning_rate = 5e-5
        self.student_learning_rate = 5e-5

args = Config()

# Now you can access them as usual:
print(args.teacher_model_name)

./models/bert-base-uncased


In [9]:
# Step 1: Fine-tune a Teacher Model
print(f"Fine-tuning the teacher model: {args.teacher_model_name}")
teacher_model = AutoModelForSequenceClassification.from_pretrained(args.teacher_model_name, num_labels=args.num_labels)
teacher_tokenizer = AutoTokenizer.from_pretrained(args.teacher_model_name)


Fine-tuning the teacher model: ./models/bert-base-uncased


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ./models/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
teacher_training_args = TrainingArguments(
    output_dir="./teacher_results",
    learning_rate=args.teacher_learning_rate,
    per_device_train_batch_size=args.train_batch_size,
    num_train_epochs=args.num_train_epochs,
    weight_decay=0.01,
)

In [13]:
pwd

'/mnt/data2/congfeng/kd-lora'

In [14]:
from datasets import load_dataset

# 下载并加载 GLUE 的 WNLI 子集
teacher_dataset = load_dataset("glue", "wnli", cache_dir='./dataset')


In [15]:
teacher_dataset

DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 635
    })
    validation: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 71
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 146
    })
})

In [16]:
# teacher_dataset = load_from_disk(args.dataset_path)
tokenized_teacher_dataset = teacher_dataset.map(
    lambda x: teacher_tokenizer(x["sentence1"], x["sentence2"], padding="max_length", truncation=True),
    batched=True
)

Map: 100%|██████████| 635/635 [00:00<00:00, 3538.61 examples/s]
Map: 100%|██████████| 71/71 [00:00<00:00, 3173.37 examples/s]
Map: 100%|██████████| 146/146 [00:00<00:00, 1865.67 examples/s]


In [18]:
# Define trainer for teacher model
teacher_trainer = Trainer(
    model=teacher_model,
    args=teacher_training_args,
    train_dataset=tokenized_teacher_dataset["train"],
    eval_dataset=tokenized_teacher_dataset["validation"]
)
teacher_trainer.train()



Step,Training Loss


TrainOutput(global_step=15, training_loss=0.7022274653116862, metrics={'train_runtime': 57.8655, 'train_samples_per_second': 32.921, 'train_steps_per_second': 0.259, 'total_flos': 501226560460800.0, 'train_loss': 0.7022274653116862, 'epoch': 3.0})

In [None]:
# Save teacher model predictions (logits) as soft labels
teacher_logits = teacher_trainer.predict(tokenized_teacher_dataset["train"]).predictions
teacher_soft_labels = torch.tensor(teacher_logit)




In [22]:
teacher_model.save_pretrained('./pretrained/bert-base-uncased-FFT-wnli')

In [20]:
teacher_soft_labels.shape

torch.Size([635, 2])

In [23]:
# Step 2: Initialize a Smaller Student Model with LoRA
print(f"Initializing student model: {args.student_model_name} with LoRA")
student_model = AutoModelForSequenceClassification.from_pretrained(args.student_model_name, num_labels=args.num_labels)
student_tokenizer = AutoTokenizer.from_pretrained(args.student_model_name)


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at ./models/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Initializing student model: ./models/distilbert-base-uncased with LoRA


In [27]:
lora_config = LoraConfig(
    r=args.rank,
    lora_alpha=args.lora_alpha,
    target_modules=["q_lin", "v_lin"],
    lora_dropout=args.lora_dropout,
    bias="none",
    task_type="SEQ_CLS"
)

In [26]:
# Run this to see the exact names of your layers
for name, module in student_model.named_modules():
    print(name)


distilbert
distilbert.embeddings
distilbert.embeddings.word_embeddings
distilbert.embeddings.position_embeddings
distilbert.embeddings.LayerNorm
distilbert.embeddings.dropout
distilbert.transformer
distilbert.transformer.layer
distilbert.transformer.layer.0
distilbert.transformer.layer.0.attention
distilbert.transformer.layer.0.attention.dropout
distilbert.transformer.layer.0.attention.q_lin
distilbert.transformer.layer.0.attention.k_lin
distilbert.transformer.layer.0.attention.v_lin
distilbert.transformer.layer.0.attention.out_lin
distilbert.transformer.layer.0.sa_layer_norm
distilbert.transformer.layer.0.ffn
distilbert.transformer.layer.0.ffn.dropout
distilbert.transformer.layer.0.ffn.lin1
distilbert.transformer.layer.0.ffn.lin2
distilbert.transformer.layer.0.ffn.activation
distilbert.transformer.layer.0.output_layer_norm
distilbert.transformer.layer.1
distilbert.transformer.layer.1.attention
distilbert.transformer.layer.1.attention.dropout
distilbert.transformer.layer.1.attention.q

In [28]:
# Apply LoRA configuration to the student model
student_model = get_peft_model(student_model, lora_config)


In [29]:
student_model

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): DistilBertForSequenceClassification(
      (distilbert): DistilBertModel(
        (embeddings): Embeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (transformer): Transformer(
          (layer): ModuleList(
            (0-5): 6 x TransformerBlock(
              (attention): DistilBertSdpaAttention(
                (dropout): Dropout(p=0.1, inplace=False)
                (q_lin): lora.Linear(
                  (base_layer): Linear(in_features=768, out_features=768, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.1, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=76

In [30]:
# Freeze all layers except LoRA parameters
for param in student_model.parameters():
    param.requires_grad = False
for name, param in student_model.named_parameters():
    if "lora_" in name:
        param.requires_grad = True  # Only LoRA weights are trainable


In [57]:
# Step 3: Distillation from Teacher to Student
print("Starting knowledge distillation from teacher to student")
student_training_args = TrainingArguments(
    output_dir="./student_results",
    learning_rate=args.student_learning_rate,
    per_device_train_batch_size=args.train_batch_size,
    num_train_epochs=args.num_train_epochs,
    weight_decay=0.01,
    remove_unused_columns=False,
)

Starting knowledge distillation from teacher to student


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [65]:
def distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5):
    # Compute the distillation loss with temperature scaling
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean"
    ) * (temperature ** 2)
    hard_loss = F.cross_entropy(student_logits, labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss

save_inputs = None

# Define a custom training loop for distillation
class DistillationTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwds):
        global save_inputs
        save_inputs = inputs.copy()

        labels = inputs.pop("labels")
        idx = inputs.pop('idx').long().cpu()
        outputs = model(**inputs)
        student_logits = outputs.logits
        teacher_logits = teacher_soft_labels[idx]  # Align teacher logits with batch size
        # teacher_logits = teacher_soft_labels[inputs["input_ids"].shape[0]]  # Align teacher logits with batch size
        teacher_logits = teacher_logits.to(student_logits.device)
        loss = distillation_loss(student_logits, teacher_logits, labels)
        return (loss, outputs) if return_outputs else loss


In [49]:
# Tokenize student dataset
tokenized_student_dataset = teacher_dataset.map(
    lambda x, idx: {**student_tokenizer(x["sentence1"], x["sentence2"], padding="max_length", truncation=True), 'idx': idx},
    batched=True, with_indices=True
)


In [53]:
tokenized_student_dataset['train'][0]['idx']

0

In [66]:
# Initialize Distillation Trainer
student_trainer = DistillationTrainer(
    model=student_model,
    args=student_training_args,
    train_dataset=tokenized_student_dataset["train"],
    eval_dataset=tokenized_student_dataset["validation"]
)

# Train student model with knowledge distillation
student_trainer.train()

Step,Training Loss


TrainOutput(global_step=15, training_loss=0.34891834259033205, metrics={'train_runtime': 10.6745, 'train_samples_per_second': 178.463, 'train_steps_per_second': 1.405, 'total_flos': 256678570045440.0, 'train_loss': 0.34891834259033205, 'epoch': 3.0})

In [56]:
save_inputs.keys()

dict_keys(['labels', 'input_ids', 'attention_mask'])

In [67]:
# Evaluate student model
student_trainer.evaluate()


{'eval_loss': 0.3447301983833313,
 'eval_runtime': 4.7439,
 'eval_samples_per_second': 14.967,
 'eval_steps_per_second': 0.422,
 'epoch': 3.0}

In [68]:
# Save the fine-tuned LoRA student model
output_dir = "./fine_tuned_student_model"
student_model.save_pretrained(output_dir)
student_tokenizer.save_pretrained(output_dir)
print(f"Student model saved to {output_dir}")

Student model saved to ./fine_tuned_student_model


In [69]:
import torch
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import PeftModel
from torch.utils.data import DataLoader
from tqdm import tqdm

In [72]:
test_data = teacher_dataset['validation'] # 如果要生成提交文件，请换成 dataset["test"]
model = student_model

# 4. 预处理函数
def preprocess_function(examples):
    return student_tokenizer(examples["sentence1"], examples["sentence2"], 
                     truncation=True, padding="max_length", max_length=128)

tokenized_test = test_data.map(preprocess_function, batched=True)
tokenized_test.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

# 5. 推理循环
predictions = []
references = []

dataloader = DataLoader(tokenized_test, batch_size=16)

print("正在进行推理...")
for batch in tqdm(dataloader):
    inputs = {k: v.to(model.device) for k, v in batch.items() if k != "label"}
    with torch.no_grad():
        outputs = model(**inputs)
    
    logits = outputs.logits
    preds = torch.argmax(logits, dim=-1)
    predictions.extend(preds.cpu().numpy())
    references.extend(batch["label"].cpu().numpy())

# 6. 计算准确率 (仅适用于 validation)
correct = sum(1 for p, r in zip(predictions, references) if p == r)
accuracy = correct / len(references)
print(f"\nValidation Accuracy: {accuracy:.4f}")

正在进行推理...


100%|██████████| 5/5 [00:00<00:00, 23.40it/s]


Validation Accuracy: 0.5634





Exactly the same as reported in the paper.