# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer



class DistillationLoss(nn.Module):
    """Composite distillation loss: hard-label CE + softened KL + hidden-state MSE.

    The total is ``alpha * ce + beta * kl + gamma * mse``.

    Args:
        alpha: weight of the cross-entropy loss against ``labels``.
        beta: weight of the temperature-softened KL(teacher || student) term.
        gamma: weight of the MSE between the last entries of ``hidden_states``.
        temp: softmax temperature for the KL term. The KL value is multiplied
            by ``temp ** 2`` to keep its magnitude comparable across temperatures.
        student_hidden_size: hidden width of the student (optional).
        teacher_hidden_size: hidden width of the teacher (optional). When both
            sizes are given and differ, a learnable ``nn.Linear`` maps teacher
            hidden states into the student's width. Otherwise ``proj_layer`` is
            an identity and the states are compared directly.

    NOTE(review): ``proj_layer`` parameters live on this loss module, not on the
    student model. Make sure they are handed to the optimizer if they should
    be trained.
    NOTE(review): a later definition in this file reuses the name
    ``DistillationLoss`` and will shadow this one when the file runs top to bottom.
    """

    def __init__(self, alpha=0.5, beta=0.3, gamma=0.2, temp=3.0,
                 student_hidden_size=None, teacher_hidden_size=None):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.temp = temp
        # forward() needs a projection for the teacher's hidden states. Create a
        # learnable one only when the widths are known to differ.
        if (
            student_hidden_size is not None
            and teacher_hidden_size is not None
            and student_hidden_size != teacher_hidden_size
        ):
            self.proj_layer = nn.Linear(teacher_hidden_size, student_hidden_size, bias=False)
        else:
            self.proj_layer = nn.Identity()

    def forward(self, student_outputs, teacher_outputs, labels):
        """Return the weighted sum of the CE, KL and hidden-state losses.

        Args:
            student_outputs: object with ``.logits`` (last dim = vocab) and
                ``.hidden_states`` (sequence whose last entry is compared).
            teacher_outputs: same structure as ``student_outputs``.
            labels: integer targets with the same leading shape as the logits.
                Entries equal to -100 are skipped (F.cross_entropy's default
                ignore_index).
                NOTE(review): no causal shift is applied here, so labels are
                assumed to already be aligned with the logits. Confirm against
                the data pipeline.
        """
        vocab_size = student_outputs.logits.size(-1)
        # Flatten to (tokens, vocab) so every term is averaged per token.
        student_logits = student_outputs.logits.reshape(-1, vocab_size)
        teacher_logits = teacher_outputs.logits.reshape(-1, vocab_size)

        # Hard-target loss.
        ce_loss = F.cross_entropy(student_logits, labels.reshape(-1))

        # Soft-target loss. On the flattened tensor, "batchmean" divides by the
        # token count, matching the per-token mean of ce_loss instead of
        # growing with sequence length.
        soft_teacher = F.softmax(teacher_logits / self.temp, dim=-1)
        soft_student = F.log_softmax(student_logits / self.temp, dim=-1)
        kl_loss = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * (self.temp ** 2)

        # Hidden-state alignment loss on the last hidden state.
        hidden_loss = F.mse_loss(
            student_outputs.hidden_states[-1],
            self.proj_layer(teacher_outputs.hidden_states[-1]),
        )

        return self.alpha * ce_loss + self.beta * kl_loss + self.gamma * hidden_loss
    
def generate_soft_labels(batch):
    """Run the frozen teacher over one ``datasets.map(batched=True)`` batch.

    Reads the module-level ``tokenizer`` and ``teacher_model``, so both must
    exist before this is called.

    Args:
        batch: mapping of columns. The ``"instruction"`` and ``"input"`` columns
            are joined with a space to form each prompt.

    Returns:
        dict of CPU tensors:
          - ``"input_ids"`` / ``"attention_mask"``: the tokenized prompts
            (padded per batch, truncated to 512 tokens), so the student can be
            trained on exactly the tokens the teacher saw;
          - ``"logits"``: teacher logits;
          - ``"hidden_states"``: the teacher's last hidden state.

    NOTE(review): storing full-vocabulary logits for every token is very large.
    Consider top-k logits if the dataset does not fit on disk/RAM.
    NOTE(review): DistillationTrainer pops ``"teacher_logits"``,
    ``"teacher_hidden"`` and ``"labels"``. Columns must be renamed/added to match.
    """
    prompts = [f"{ins} {inp}" for ins, inp in zip(batch["instruction"], batch["input"])]
    inputs = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        max_length=512,
        truncation=True,
    )
    # The teacher is loaded with device_map="auto", so follow its placement
    # instead of assuming a "cuda" device is available.
    inputs = inputs.to(teacher_model.device)

    with torch.no_grad():
        outputs = teacher_model(**inputs, output_hidden_states=True)

    # Extract the logits and the last-layer hidden state.
    return {
        "input_ids": inputs["input_ids"].cpu(),
        "attention_mask": inputs["attention_mask"].cpu(),
        "logits": outputs.logits.cpu(),
        "hidden_states": outputs.hidden_states[-1].cpu(),
    }

import copy

from transformers import AutoConfig, BitsAndBytesConfig

# Teacher: Llama-2-13B chat, quantized to 4 bit to save GPU memory.
# The bare load_in_4bit= keyword is deprecated; a BitsAndBytesConfig is the
# supported way to request 4-bit loading.
teacher_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-13b-chat-hf",
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    ),
    torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-13b-chat-hf")
# Llama-2 tokenizers ship without a pad token, but generate_soft_labels
# tokenizes with padding=True, which needs one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Precompute teacher soft targets for the whole dataset in batches. This must
# run after teacher_model and tokenizer are defined, because
# generate_soft_labels reads both as module globals.
# NOTE(review): `dataset` is not defined in this section; presumably it is
# loaded earlier in the notebook.
soft_dataset = dataset.map(
    generate_soft_labels,
    batched=True,
    batch_size=4,
    remove_columns=dataset.column_names,
)

# Base LLaMA-2-7B configuration.
original_config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")

# Student configuration. PretrainedConfig has no .copy() method, so deep-copy
# to leave the base config untouched.
student_config = copy.deepcopy(original_config)
student_config.update({
    "num_hidden_layers": 12,  # fewer transformer layers (base 7B has 32)
    "intermediate_size": 2048,  # much smaller FFN (base 7B uses 11008)
    "num_attention_heads": 16,  # half the attention heads
    # Keep the key/value head count in step with the query heads. Otherwise
    # the 7B config's value would exceed num_attention_heads and break the
    # head grouping.
    "num_key_value_heads": 16,
})
# NOTE(review): hidden_size is inherited from the 7B config and presumably
# differs from the 13B teacher's, so hidden-state alignment needs a projection.
# Confirm the widths.

# Randomly initialized student model.
student_model = AutoModelForCausalLM.from_config(student_config)


# In [None]:
from transformers import TrainingArguments, Trainer
import numpy as np


# Trainer settings for the distillation run.
training_args = TrainingArguments(
    output_dir="./distill_results",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=16,  # effective batch per device = 8 * 16 = 128
    learning_rate=2e-4,
    warmup_ratio=0.1,
    weight_decay=0.01,
    fp16=True,
    logging_steps=50,
    max_steps=5000,  # also read by DistillationTrainer for the temperature schedule
    gradient_checkpointing=True,  # trade recomputation for activation memory
    # Keep the extra teacher columns. By default Trainer drops dataset columns
    # that the model's forward() does not accept, but
    # DistillationTrainer.compute_loss pops "teacher_logits" and
    # "teacher_hidden" from every batch.
    remove_unused_columns=False,
)

# Dynamic temperature schedule for distillation.
def dynamic_temperature_schedule(step, total_steps, initial_temp=5.0, final_temp=1.0):
    """Cosine-anneal the distillation temperature from ``initial_temp`` to ``final_temp``.

    Args:
        step: current optimisation step.
        total_steps: number of steps over which to anneal. Must be positive.
        initial_temp: temperature at step 0.
        final_temp: temperature at (and after) ``total_steps``.

    Returns:
        The temperature for ``step``. Steps outside ``[0, total_steps]`` are
        clamped, so the schedule holds at its end points instead of swinging
        back up the cosine.

    Raises:
        ValueError: if ``total_steps`` is not positive (for example the
            ``max_steps=-1`` sentinel of an epoch-based run).
    """
    if total_steps <= 0:
        raise ValueError(f"total_steps must be positive, got {total_steps}")
    progress = min(max(step / total_steps, 0.0), 1.0)
    # Cosine annealing: cos goes 1 -> -1 as progress goes 0 -> 1.
    return final_temp + 0.5 * (initial_temp - final_temp) * (1 + np.cos(np.pi * progress))

class DistillationTrainer(Trainer):
    """Trainer that optimises a student against precomputed teacher outputs.

    Each batch must contain ``"labels"``, ``"teacher_logits"`` and
    ``"teacher_hidden"`` next to the student's model inputs. Those three keys
    are removed before the student forward pass and passed to ``loss_fn``
    instead.

    Args:
        loss_fn: module called as ``loss_fn(student_outputs, teacher_outputs, labels)``.
            Its temperature is reset every step from ``dynamic_temperature_schedule``.
            May also be assigned after construction via ``trainer.loss_fn = ...``.
    """

    def __init__(self, *args, loss_fn=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_fn = loss_fn  # distillation loss used by compute_loss

    def _distill_total_steps(self):
        """Return the step count the temperature schedule anneals over."""
        # args.max_steps is -1 for epoch-based runs. In that case fall back to
        # the total step count recorded on the trainer state.
        if self.args.max_steps > 0:
            return self.args.max_steps
        return self.state.max_steps

    def _apply_temperature(self, value):
        """Push the scheduled temperature into the loss module."""
        # Loss modules in this file name the attribute either "temp" or
        # "temperature", so keep both in sync.
        self.loss_fn.temp = value
        if hasattr(self.loss_fn, "temperature"):
            self.loss_fn.temperature = value

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the distillation loss for one batch.

        ``num_items_in_batch`` is accepted (and unused) because newer Trainer
        versions pass it as a keyword argument.
        """
        if self.loss_fn is None:
            raise ValueError("DistillationTrainer needs a loss_fn before training")

        # Shallow-copy so popping the teacher/label keys does not mutate the
        # caller's batch.
        inputs = dict(inputs)
        labels = inputs.pop("labels")
        teacher_logits = inputs.pop("teacher_logits")
        teacher_hidden = inputs.pop("teacher_hidden")

        # Student forward pass. Labels were removed, so the model computes no loss itself.
        outputs = model(**inputs, output_hidden_states=True)

        # Wrap the teacher tensors in the same output type the loss expects.
        teacher_outputs = type(outputs)(
            logits=teacher_logits,
            hidden_states=[teacher_hidden],
        )

        # Anneal the temperature with training progress.
        current_temp = dynamic_temperature_schedule(
            self.state.global_step, self._distill_total_steps()
        )
        self._apply_temperature(current_temp)

        loss = self.loss_fn(outputs, teacher_outputs, labels)
        return (loss, outputs) if return_outputs else loss

# Simpler distillation loss: softened KL plus optional hard-label cross-entropy.
class DistillationLoss(nn.Module):
    """Temperature-softened KL distillation loss, plus hard-label CE when labels are given.

    Args:
        temperature: softmax temperature for the KL term (default 1.0). The KL
            value is multiplied by ``temperature ** 2`` so its magnitude stays
            comparable as the temperature changes.
    """

    def __init__(self, temperature=1.0):
        super().__init__()
        self.temperature = temperature  # current softmax temperature
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")

    @property
    def temp(self):
        """Alias of ``temperature``. Callers may set either name."""
        return self.temperature

    @temp.setter
    def temp(self, value):
        self.temperature = value

    def forward(self, student_outputs, teacher_outputs, labels=None):
        """Return ``T^2 * KL(teacher || student)`` (+ CE against ``labels`` if given).

        Args:
            student_outputs: object with ``.logits`` whose last dim is the vocab.
            teacher_outputs: object with ``.logits`` of the same shape.
            labels: optional integer targets matching the logits' leading
                shape. -100 entries are skipped (F.cross_entropy's default
                ignore_index).
        """
        vocab_size = student_outputs.logits.size(-1)
        # Flatten to (tokens, vocab). cross_entropy expects the class dim second,
        # and "batchmean" then averages per token.
        student_logits = student_outputs.logits.reshape(-1, vocab_size)
        teacher_logits = teacher_outputs.logits.reshape(-1, vocab_size)

        t = self.temperature
        kl_loss = self.kl_loss(
            F.log_softmax(student_logits / t, dim=-1),
            F.softmax(teacher_logits / t, dim=-1),
        ) * (t ** 2)

        if labels is None:
            return kl_loss

        # The hard-target loss uses the unsoftened student logits.
        ce_loss = F.cross_entropy(student_logits, labels.reshape(-1))
        return kl_loss + ce_loss
