In [1]:
from IPython.display import Image

## trainer arguments & trainer

- $q(x)$：from student model，$p(x)$：from teacher model
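- 计算 $q(x)$ 与 $p(x)$ 时，student / teacher 的 logits $z$ 都要先除以温度 $T$ 再做 softmax：$\text{softmax}(z/T)_i=\dfrac{\exp(z_i/T)}{\sum_j\exp(z_j/T)}$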
$$
\begin{split}
L_{\text{student}}&=\alpha L_{\text{CE}} + (1-\alpha)L_{KD}\\
&=\alpha L_{\text{CE}} + (1-\alpha)T^2D_{KL}(p\,\|\,q)\\
&=\alpha L_{\text{CE}} + (1-\alpha)T^2\sum_ip_i(x)\log\frac{p_i(x)}{q_i(x)}
\end{split}
$$

- 关于 `nn.KLDivLoss()`
    - input ($\log q(x)$)：student 的 log probabilities（`F.log_softmax`）
    - target ($p(x)$)：teacher 的 normal probabilities（`F.softmax`，默认 `log_target=False`）

In [27]:
from transformers import TrainingArguments, Trainer
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

In [2]:
class DistillTrainingArguments(TrainingArguments):
    """``TrainingArguments`` extended with two knowledge-distillation settings.

    ``TrainingArguments`` is a ``@dataclass``; the two extra values are
    attached as plain instance attributes after the parent initializer has
    run, so they are not declared dataclass fields.
    """

    def __init__(self, *args, alpha=0.5, temperature=2., **kwargs):
        """Build the training arguments and attach the distillation settings.

        Args:
            *args: forwarded unchanged to ``TrainingArguments.__init__``.
            alpha: weight of the cross-entropy term in the student loss;
                the distillation term is weighted by ``1 - alpha``.
                Must lie in ``[0, 1]``.
            temperature: temperature the logits are divided by before the
                softmax. Must be strictly positive.
            **kwargs: forwarded unchanged to ``TrainingArguments.__init__``.

        Raises:
            ValueError: if ``alpha`` is outside ``[0, 1]`` or ``temperature``
                is not strictly positive.
        """
        # Validate before the parent initializer so a bad value fails fast
        # instead of surfacing later as a NaN/inf loss during training.
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha!r}")
        # `not (t > 0)` also rejects NaN, which `t <= 0` would let through.
        if not temperature > 0:
            raise ValueError(f"temperature must be > 0, got {temperature!r}")

        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.temperature = temperature

In [26]:
class DistillTrainer(Trainer):
    """``Trainer`` whose loss adds a knowledge-distillation term.

    The student loss is
    ``alpha * CE + (1 - alpha) * T**2 * KL(teacher || student)``, where
    ``alpha`` and ``T`` are read from ``self.args.alpha`` and
    ``self.args.temperature`` and both distributions are computed from
    logits divided by ``T``.
    """

    def __init__(self, *args, teacher_model=None, **kwargs):
        """Create the trainer and remember the teacher.

        Args:
            *args: forwarded unchanged to ``Trainer.__init__``.
            teacher_model: model that is called with the same inputs as the
                student and whose ``.logits`` provide the soft targets. It is
                not moved to any device here; the caller is expected to have
                placed it where the training batches will live.
            **kwargs: forwarded unchanged to ``Trainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
        if self.teacher_model is not None:
            # The teacher only supplies targets; eval mode keeps dropout and
            # similar train-time layers from perturbing them.
            self.teacher_model.eval()

    # Overrides the core loss hook of Trainer.
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the combined cross-entropy + distillation loss.

        Args:
            model: the student model; called as ``model(**inputs)`` and
                expected to return an object with ``.loss`` and ``.logits``.
            inputs: mapping of keyword arguments for both student and teacher.
            return_outputs: when true, return ``(loss, student_outputs)``
                instead of just ``loss``.
            num_items_in_batch: accepted for compatibility with newer
                ``Trainer`` versions that pass it as a keyword; it is not
                used because the student's own ``.loss`` is taken as-is.

        Returns:
            The scalar loss, or ``(loss, student_outputs)`` if
            ``return_outputs`` is true.

        Raises:
            ValueError: if no ``teacher_model`` was provided.
        """
        if self.teacher_model is None:
            raise ValueError(
                "DistillTrainer needs a `teacher_model` to compute the distillation loss."
            )

        temperature = self.args.temperature
        alpha = self.args.alpha

        s_output = model(**inputs)
        s_ce = s_output.loss
        s_logits = s_output.logits

        # No gradients are needed for the teacher.
        with torch.no_grad():
            t_logits = self.teacher_model(**inputs).logits

        # KLDivLoss expects log-probabilities as input (student) and
        # probabilities as target (teacher).
        loss_kl_fct = nn.KLDivLoss(reduction='batchmean')
        loss_kd = temperature**2 * loss_kl_fct(
            F.log_softmax(s_logits / temperature, dim=-1),
            F.softmax(t_logits / temperature, dim=-1),
        )
        loss = alpha * s_ce + (1 - alpha) * loss_kd
        return (loss, s_output) if return_outputs else loss


## pipeline

### datasets

In [6]:
from datasets import load_dataset
# Load the "plus" configuration of the CLINC-OOS dataset (text -> intent).
clinc = load_dataset(path="clinc_oos", name="plus")

Using the latest cached version of the module from /home/whaow/.cache/huggingface/modules/datasets_modules/datasets/clinc_oos/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1 (last modified on Wed Jun 14 00:03:25 2023) since it couldn't be found locally at clinc_oos., or remotely on the Hugging Face Hub.
Found cached dataset clinc_oos (/home/whaow/.cache/huggingface/datasets/clinc_oos/plus/1.0.0/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1)


  0%|          | 0/3 [00:00<?, ?it/s]

In [7]:
# Show the available splits with their columns and row counts.
clinc

DatasetDict({
    train: Dataset({
        features: ['text', 'intent'],
        num_rows: 15250
    })
    validation: Dataset({
        features: ['text', 'intent'],
        num_rows: 3100
    })
    test: Dataset({
        features: ['text', 'intent'],
        num_rows: 5500
    })
})

In [25]:
# The accuracy metric was previously loaded with `datasets.load_metric('accuracy')`;
# it is loaded through the `evaluate` library here instead.
import evaluate
accuracy_score = evaluate.load(path='accuracy')

Using the latest cached version of the module from /home/whaow/.cache/huggingface/modules/datasets_modules/metrics/accuracy/9756d5fa4a0f9da966341741fc3926eafdc604b8276add51d5abbaa8958a25f9 (last modified on Thu Jun 15 21:41:00 2023) since it couldn't be found locally at accuracy, or remotely on the Hugging Face Hub.


### Student model 初始化

In [4]:
from transformers import AutoTokenizer

In [5]:
# Student checkpoint name and its tokenizer; this tokenizer is what encodes
# the dataset text in the cell that builds `clinc_enc`.
s_ckpt = 'distilbert-base-uncased'
s_tokenizer = AutoTokenizer.from_pretrained(s_ckpt)

In [28]:
# Tokenize every split in batches with the student tokenizer, drop the raw
# "text" column, then rename "intent" to "labels" in the encoded dataset.
clinc_enc = (
    clinc
    .map(
        lambda batch: s_tokenizer(batch['text'], truncation=True),
        batched=True,
        remove_columns=["text"],
    )
    .rename_columns({'intent': 'labels'})
)
clinc_enc

Loading cached processed dataset at /home/whaow/.cache/huggingface/datasets/clinc_oos/plus/1.0.0/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1/cache-e8b4fc4135f51853.arrow
Loading cached processed dataset at /home/whaow/.cache/huggingface/datasets/clinc_oos/plus/1.0.0/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1/cache-2c7c66ce85fed60a.arrow
Loading cached processed dataset at /home/whaow/.cache/huggingface/datasets/clinc_oos/plus/1.0.0/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1/cache-8bb21133bce10942.arrow


DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 15250
    })
    validation: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 3100
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 5500
    })
})

In [29]:
# Metrics callback for the trainer; a plain function, not a Trainer method.
def compute_metrics(pred):
    """Compute accuracy for one evaluation pass.

    Args:
        pred: a pair that unpacks into ``(predictions, labels)``; the arg-max
            over the last axis of ``predictions`` is taken as the predicted
            class.

    Returns:
        The result of ``accuracy_score.compute`` for the arg-max predictions
        against ``labels``.
    """
    scores, gold = pred
    predicted_ids = np.argmax(scores, axis=-1)
    return accuracy_score.compute(predictions=predicted_ids, references=gold)