## review

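- $q(x)$: output distribution of the student model; $p(x)$: output distribution of the teacher model (the soft targets)
- Both $q(x)$ and $p(x)$ are computed with a temperature $T$, i.e. as $\mathrm{softmax}(z/T)$ of the logits $z$. The KD term is scaled by $T^2$ because the gradients of the softened targets shrink as $1/T^2$, so the scaling keeps the two loss terms on a comparable scale
- $L_{\text{CE}}$ is the usual cross-entropy between the student's predictions (at $T=1$) and the hard labels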
$$
\begin{split}
L_{\text{student}}&=\alpha L_{\text{CE}} + (1-\alpha)L_{KD}\\
&=\alpha L_{\text{CE}} + (1-\alpha)T^2D_{KL}(p\,\|\,q)\\
&=\alpha L_{\text{CE}} + (1-\alpha)T^2\sum_ip_i(x)\log\frac{p_i(x)}{q_i(x)}
\end{split}
$$
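To see what the temperature does to $p(x)$ and $q(x)$, here is a small NumPy sketch. The helper name `softmax_with_temperature` and the 4-class logits are made up for illustration. Raising $T$ flattens the distribution while keeping the argmax, so the student also sees how the teacher ranks the "wrong" classes.

```python
import numpy as np

def softmax_with_temperature(logits, T=1.0):
    """softmax(z / T) along the last axis (hypothetical helper)."""
    z = np.asarray(logits, dtype=float) / T
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability; result unchanged
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# Made-up teacher logits for a 4-class problem
t_logits = np.array([6.0, 2.0, 1.0, -1.0])

for T in (1.0, 2.0, 5.0):
    p = softmax_with_temperature(t_logits, T)
    print(f"T={T}: {np.round(p, 3)}")
```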

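- About `nn.KLDivLoss()`:
    - input ($q(x)$): log-probabilities, e.g. `F.log_softmax(s_logits / T, dim=-1)`
    - target ($p(x)$): plain probabilities, e.g. `F.softmax(t_logits / T, dim=-1)`
    - `reduction='batchmean'` divides the summed loss by the batch size, which matches the definition of $D_{KL}$ above; the default `'mean'` divides by the total number of elements instead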
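To make the argument convention concrete, here is a NumPy re-implementation of what `nn.KLDivLoss(reduction='batchmean')(input, target)` computes: the sum of $p\,(\log p - \text{input})$ over all elements, divided by the batch size. This is a sketch for illustration, not PyTorch's code; `log_softmax` and `kl_div_batchmean` are hypothetical helper names and the random logits are made up.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def kl_div_batchmean(input_log_q, target_p):
    """Mirrors KLDivLoss(reduction='batchmean'): input is log q, target is p."""
    safe_log_p = np.log(np.where(target_p > 0, target_p, 1.0))  # log(1) = 0 where p == 0
    pointwise = target_p * (safe_log_p - input_log_q)            # p * (log p - log q), 0 where p == 0
    return pointwise.sum() / input_log_q.shape[0]                # divide by batch size

rng = np.random.default_rng(0)
s_logits = rng.normal(size=(3, 5))   # made-up student logits, batch of 3, 5 classes
t_logits = rng.normal(size=(3, 5))   # made-up teacher logits
T = 2.0

log_q = log_softmax(s_logits / T)          # student: log-probabilities (the "input")
p = np.exp(log_softmax(t_logits / T))      # teacher: probabilities (the "target")
loss = kl_div_batchmean(log_q, p)

# Same value from the textbook formula: batch mean of sum_i p_i log(p_i / q_i)
manual = np.mean(np.sum(p * np.log(p / np.exp(log_q)), axis=-1))
print(loss, manual)
```

Passing probabilities instead of log-probabilities as `input` would still return a number, just a wrong one, which is why the convention is worth remembering.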

## trainer arguments & trainer

In [11]:
from transformers import TrainingArguments, Trainer
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import wandb

In [12]:
class DistillTrainingArguments(TrainingArguments):
    # TrainingArguments is a @dataclass; here its __init__ is simply extended
    # with the two extra hyperparameters that KD needs
    def __init__(self, *args, alpha=0.5, temperature=2., **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.temperature = temperature
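Since `TrainingArguments` is a dataclass, an alternative is to declare the extra hyperparameters as fields of a `@dataclass` subclass. They then appear in `dataclasses.fields()`, which dataclass-based serialisation typically iterates over, instead of being plain attributes set in `__init__`. The sketch below shows the pattern with a hypothetical stand-in base class (`BaseArgs`) so it runs without transformers. Whether this is preferable for a particular transformers version is an assumption to check.

```python
from dataclasses import dataclass, fields

@dataclass
class BaseArgs:                  # hypothetical stand-in for TrainingArguments
    output_dir: str
    learning_rate: float = 5e-5

@dataclass
class DistillArgs(BaseArgs):     # hypothetical name; extra KD fields need defaults
    alpha: float = 0.5
    temperature: float = 2.0

args = DistillArgs(output_dir="out", alpha=0.3)
print(args)
print([f.name for f in fields(args)])
```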

In [13]:
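class DistillTrainer(Trainer):

    def __init__(self, *args, teacher_model=None, **kwargs):
        # Extra argument: the teacher model (only used for inference)
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model

    # Override the Trainer's core hook: run the forward pass and return the loss.
    # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that newer
    # transformers versions pass to compute_loss.
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Student forward pass; inputs contain `labels`, so s_output.loss is the hard-label CE
        s_output = model(**inputs)
        s_ce = s_output.loss
        s_logits = s_output.logits

        # Teacher forward pass: no gradients needed
        with torch.no_grad():
            t_output = self.teacher_model(**inputs)
            t_logits = t_output.logits

        # KLDivLoss expects log-probabilities as input and probabilities as target
        T = self.args.temperature
        loss_kl_fct = nn.KLDivLoss(reduction='batchmean')
        loss_kd = T**2 * loss_kl_fct(F.log_softmax(s_logits / T, dim=-1),
                                     F.softmax(t_logits / T, dim=-1))
        loss = self.args.alpha * s_ce + (1 - self.args.alpha) * loss_kd
        return (loss, s_output) if return_outputs else loss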
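The loss assembled in `compute_loss` can be written out in NumPy to check each piece. This is a sketch with hypothetical helper names and made-up logits. It assumes `s_output.loss` is the batch-mean cross-entropy, which is what Hugging Face sequence-classification heads compute for integer labels with more than one class.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def distill_loss(s_logits, t_logits, labels, alpha=0.5, T=2.0):
    """alpha * CE(student, labels) + (1 - alpha) * T^2 * KL(p_teacher || q_student)."""
    n = s_logits.shape[0]
    # Hard-label cross-entropy at T = 1 (plays the role of s_output.loss)
    ce = -log_softmax(s_logits)[np.arange(n), labels].mean()
    # Soft targets and student log-probabilities, both at temperature T
    log_q = log_softmax(s_logits / T)
    log_p = log_softmax(t_logits / T)
    p = np.exp(log_p)
    kl = (p * (log_p - log_q)).sum() / n   # 'batchmean' reduction
    return alpha * ce + (1 - alpha) * T**2 * kl

rng = np.random.default_rng(0)
s = rng.normal(size=(4, 6))   # made-up student logits
t = rng.normal(size=(4, 6))   # made-up teacher logits
y = np.array([0, 1, 2, 3])    # made-up labels
print(distill_loss(s, t, y))
```

Setting `alpha=1.0` recovers plain fine-tuning; when the student's logits equal the teacher's, the KD term vanishes.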


## pipeline

### datasets

In [14]:
# import os
# os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
# os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'

In [15]:
from datasets import load_dataset
clinc = load_dataset("clinc_oos", "plus")

Found cached dataset clinc_oos (/media/whaow/.cache/huggingface/datasets/clinc_oos/plus/1.0.0/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1)

In [16]:
clinc

DatasetDict({
    train: Dataset({
        features: ['text', 'intent'],
        num_rows: 15250
    })
    validation: Dataset({
        features: ['text', 'intent'],
        num_rows: 3100
    })
    test: Dataset({
        features: ['text', 'intent'],
        num_rows: 5500
    })
})

In [17]:
clinc['train'][:10]

{'text': ['what expression would i use to say i love you if i were an italian',
  "can you tell me how to say 'i do not speak much spanish', in spanish",
  "what is the equivalent of, 'life is good' in french",
  "tell me how to say, 'it is a beautiful morning' in italian",
  'if i were mongolian, how would i say that i am a tourist',
  "how do i say 'hotel' in finnish",
  "i need you to translate the sentence, 'we will be there soon' into portuguese",
  'please tell me how to ask for a taxi in french',
  "can you tell me how i would say, 'more bread please' in french",
  "what is the correct way to say 'i am a visitor' in french"],
 'intent': [61, 61, 61, 61, 61, 61, 61, 61, 61, 61]}

In [19]:
intents = clinc['train'].features['intent']
num_labels = intents.num_classes
num_labels

151

### Student model initialization

In [20]:
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoModelForSequenceClassification

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

s_ckpt = 'distilbert-base-uncased'
s_tokenizer = AutoTokenizer.from_pretrained(s_ckpt)

t_ckpt = 'transformersbook/bert-base-uncased-finetuned-clinc'
t_model = AutoModelForSequenceClassification.from_pretrained(t_ckpt, num_labels=num_labels).to(device)

In [22]:
clinc_enc = clinc.map(lambda batch: s_tokenizer(batch['text'], truncation=True), 
                      batched=True, 
                      remove_columns=["text"]
                     )
clinc_enc = clinc_enc.rename_columns({'intent': 'labels'})
clinc_enc

DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 15250
    })
    validation: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 3100
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 5500
    })
})
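`clinc_enc` is truncated but not padded, so examples have different lengths. When a tokenizer is passed to the `Trainer` (as done below), it pads each batch on the fly, by default with `DataCollatorWithPadding`. The sketch below shows the idea in plain Python. `pad_batch` is a hypothetical, simplified helper, the token ids are made up, and `pad_token_id=0` is assumed (it is the `[PAD]` id in the BERT uncased vocabulary). The column is renamed to `labels` because that is the keyword argument the model's `forward` uses to compute `s_output.loss`.

```python
from typing import Dict, List

def pad_batch(features: List[Dict], pad_token_id: int = 0) -> Dict[str, list]:
    """Hypothetical dynamic padding: pad every example to the longest one in the batch."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        batch["attention_mask"].append(f["attention_mask"] + [0] * n_pad)  # 0 = ignore padded positions
        batch["labels"].append(f["labels"])
    return batch

# Toy features shaped like clinc_enc rows (ids are made up)
features = [
    {"input_ids": [101, 2054, 102], "attention_mask": [1, 1, 1], "labels": 61},
    {"input_ids": [101, 2129, 2079, 1045, 102], "attention_mask": [1, 1, 1, 1, 1], "labels": 7},
]
batch = pad_batch(features)
print(batch)
```

Padding per batch rather than to a global maximum keeps short batches short, which saves compute on a dataset of brief utterances like CLINC.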

In [24]:
batch_size = 64
s_training_args = DistillTrainingArguments(output_dir='distilbert-base-uncased-ft-clinc', 
                                           evaluation_strategy='epoch', num_train_epochs=5, 
                                           learning_rate=3e-4, 
                                           per_device_train_batch_size=batch_size, 
                                           per_device_eval_batch_size=batch_size, 
                                           alpha=0.5, weight_decay=0.01, 
                                           logging_strategy='epoch',
                                           push_to_hub=False)
s_config = AutoConfig.from_pretrained(s_ckpt, num_labels=num_labels, 
                                      id2label=t_model.config.id2label, label2id=t_model.config.label2id)
# s_config

In [25]:

def student_init():
    return AutoModelForSequenceClassification.from_pretrained(s_ckpt, config=s_config).to(device)

### trainer.train

In [26]:
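# Older API (commented out): datasets.load_metric has been removed from recent
# versions of datasets; the evaluate library provides the same metrics
# from datasets import load_metric
# accuracy_score = load_metric('accuracy')
import evaluate
accuracy_score = evaluate.load('accuracy')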

In [27]:
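# Plain function (not a method) passed to the Trainer as a callback:
# converts logits to predicted class ids and computes accuracy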
def compute_metrics(pred):
    predictions, labels = pred
    predictions = np.argmax(predictions, axis=-1)
    return accuracy_score.compute(references=labels, predictions=predictions)
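`evaluate.load('accuracy')` returns a dict of the form `{'accuracy': ...}`. A dependency-free NumPy equivalent of `compute_metrics` makes the argmax step explicit. `accuracy_from_logits` is a hypothetical helper and the logits are made up.

```python
import numpy as np

def accuracy_from_logits(logits, labels):
    preds = np.argmax(logits, axis=-1)   # class with the highest logit per example
    return {"accuracy": float((preds == np.asarray(labels)).mean())}

logits = np.array([[2.0, 0.1, -1.0],
                   [0.0, 3.0, 1.0],
                   [0.5, 0.2, 0.9],
                   [1.0, 0.0, 0.0]])
labels = np.array([0, 1, 0, 0])
print(accuracy_from_logits(logits, labels))  # {'accuracy': 0.75}
```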

In [28]:
distill_trainer = DistillTrainer(model_init=student_init, teacher_model=t_model, args=s_training_args, 
                                 train_dataset=clinc_enc['train'], eval_dataset=clinc_enc['validation'], 
                                 compute_metrics=compute_metrics, tokenizer=s_tokenizer)
distill_trainer.train()

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_transform.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'pre_classifier.bias', 'classifier.

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.

| Epoch | Training Loss | Validation Loss | Accuracy |
|---|---|---|---|
| 1 | 1.2065 | 0.401095 | 0.912903 |
| 2 | 0.2800 | 0.318143 | 0.943871 |
| 3 | 0.2096 | 0.278893 | 0.953226 |
| 4 | 0.1901 | 0.268503 | 0.956452 |
| 5 | 0.1838 | 0.266352 | 0.957742 |

TrainOutput(global_step=600, training_loss=0.41400583267211916, metrics={'train_runtime': 76.8404, 'train_samples_per_second': 992.317, 'train_steps_per_second': 7.808, 'total_flos': 456233053284036.0, 'train_loss': 0.41400583267211916, 'epoch': 5.0})

In [29]:
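import math
# optimizer steps = ceil(train size / (per-device batch size × number of devices)) × epochs;
# the factor 2 presumably reflects training on 2 GPUs
math.ceil(15250/(64*2)) * 5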

600
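The same count as a small function, assuming the default of keeping the last partial batch (`dataloader_drop_last=False`) and no gradient accumulation. `total_optimizer_steps` is a hypothetical helper. The single-device figure shows how the step count would change on one GPU.

```python
import math

def total_optimizer_steps(n_examples, per_device_batch_size, n_devices, epochs):
    """Steps per epoch = ceil(n / effective batch size); last partial batch kept."""
    effective_bs = per_device_batch_size * n_devices
    return math.ceil(n_examples / effective_bs) * epochs

print(total_optimizer_steps(15250, 64, n_devices=2, epochs=5))  # 600, matches global_step above
print(total_optimizer_steps(15250, 64, n_devices=1, epochs=5))  # 1195
```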

### Usage

In [30]:
from transformers import pipeline

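# ft_ckpt = 'lanchunhui/distilbert-base-uncased-ft-clinc'
# distill_trainer.push_to_hub('finetune completed!')

# During training the Trainer only writes output_dir/checkpoint-*; save_model() writes the
# final model (and the tokenizer passed to the Trainer) to output_dir itself
distill_trainer.save_model()
pipe = pipeline('text-classification', model='./distilbert-base-uncased-ft-clinc/')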