In [1]:
from IPython.display import Image

## trainer arguments & trainer

- $q(x)$：from student model，$p(x)$：from teacher model
- 计算 $q(x)$、$p(x)$ 时，需先将 logits $z$ 除以温度 $T$ 再做 softmax，即 $\mathrm{softmax}(z/T)$
$$
\begin{split}
L_{\text{student}}&=\alpha L_{\text{CE}} + (1-\alpha)L_{\text{KD}}\\
&=\alpha L_{\text{CE}} + (1-\alpha)T^2D_{\text{KL}}(p\,\|\,q)\\
&=\alpha L_{\text{CE}} + (1-\alpha)T^2\sum_ip_i(x)\log\frac{p_i(x)}{q_i(x)}
\end{split}
$$

- 关于 `nn.KLDivLoss()`
    - `input`（$q(x)$，student）: log probabilities，即 `log_softmax` 的输出
    - `target`（$p(x)$，teacher）: normal probabilities，即 `softmax` 的输出

In [2]:
from transformers import TrainingArguments, Trainer
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import wandb

In [3]:
class DistillTrainingArguments(TrainingArguments):
    """`TrainingArguments` extended with the two knowledge-distillation knobs.

    `TrainingArguments` is a dataclass; this subclass overrides `__init__` to
    accept two extra keyword arguments and stores them as plain attributes on
    the instance.

    Args:
        alpha: Weight of the student's own cross-entropy loss; the
            distillation (KL) term is weighted by ``1 - alpha``. Must lie in
            ``[0, 1]``.
        temperature: Value the logits are divided by before the softmax when
            the distillation term is computed. Must be strictly positive.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]`` or ``temperature`` is
            not strictly positive.
    """

    def __init__(self, *args, alpha=0.5, temperature=2., **kwargs):
        # Validate first so a bad value fails before the parent does any setup.
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"`alpha` must be in [0, 1], got {alpha!r}")
        # The temperature is used as a divisor for the logits, so it must be > 0.
        if temperature <= 0:
            raise ValueError(f"`temperature` must be > 0, got {temperature!r}")
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.temperature = temperature

In [4]:
class DistillTrainer(Trainer):
    """`Trainer` whose loss blends the student's cross-entropy with a KD term.

    The loss is ``alpha * CE + (1 - alpha) * T**2 * KL(teacher || student)``,
    where ``alpha`` and ``T`` are read from ``self.args.alpha`` and
    ``self.args.temperature``.

    Args:
        teacher_model: Model whose logits serve as soft targets. It is called
            with exactly the same ``inputs`` as the student.
    """

    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
        if self.teacher_model is not None:
            # The teacher only supplies targets; keep it in inference mode so
            # layers such as dropout do not make those targets stochastic.
            self.teacher_model.eval()

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the combined cross-entropy / distillation loss for one batch.

        Args:
            model: The student model being trained.
            inputs: Keyword arguments forwarded to both student and teacher.
            return_outputs: When true, return ``(loss, student_outputs)``.
            num_items_in_batch: Accepted (and ignored) so the override stays
                callable by Trainer versions that pass this keyword.

        Returns:
            The scalar loss, or ``(loss, student_outputs)`` if
            ``return_outputs`` is true.

        Raises:
            ValueError: If no teacher model was supplied to the trainer.
        """
        if self.teacher_model is None:
            raise ValueError(
                "DistillTrainer needs a `teacher_model` to compute the distillation loss."
            )

        s_output = model(**inputs)
        s_ce = s_output.loss
        s_logits = s_output.logits

        # The teacher is frozen: no graph is needed for its forward pass.
        with torch.no_grad():
            t_logits = self.teacher_model(**inputs).logits

        temperature = self.args.temperature
        alpha = self.args.alpha

        # kl_div expects log-probabilities as input (student) and plain
        # probabilities as target (teacher); 'batchmean' divides the summed
        # divergence by the batch size.
        loss_kd = temperature ** 2 * F.kl_div(
            F.log_softmax(s_logits / temperature, dim=-1),
            F.softmax(t_logits / temperature, dim=-1),
            reduction='batchmean',
        )
        loss = alpha * s_ce + (1 - alpha) * loss_kd
        return (loss, s_output) if return_outputs else loss


## pipeline

### datasets

In [5]:
from datasets import load_dataset
# Load the "plus" configuration of the `clinc_oos` dataset (text -> intent id).
clinc = load_dataset("clinc_oos", "plus")

Using the latest cached version of the module from /home/whaow/.cache/huggingface/modules/datasets_modules/datasets/clinc_oos/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1 (last modified on Wed Jun 14 00:03:25 2023) since it couldn't be found locally at clinc_oos., or remotely on the Hugging Face Hub.
Found cached dataset clinc_oos (/home/whaow/.cache/huggingface/datasets/clinc_oos/plus/1.0.0/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1)


  0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
# Show the available splits with their columns and row counts.
clinc

DatasetDict({
    train: Dataset({
        features: ['text', 'intent'],
        num_rows: 15250
    })
    validation: Dataset({
        features: ['text', 'intent'],
        num_rows: 3100
    })
    test: Dataset({
        features: ['text', 'intent'],
        num_rows: 5500
    })
})

In [30]:
# Peek at the first ten training rows: raw text plus the integer intent id.
clinc['train'][:10]

{'text': ['what expression would i use to say i love you if i were an italian',
  "can you tell me how to say 'i do not speak much spanish', in spanish",
  "what is the equivalent of, 'life is good' in french",
  "tell me how to say, 'it is a beautiful morning' in italian",
  'if i were mongolian, how would i say that i am a tourist',
  "how do i say 'hotel' in finnish",
  "i need you to translate the sentence, 'we will be there soon' into portuguese",
  'please tell me how to ask for a taxi in french',
  "can you tell me how i would say, 'more bread please' in french",
  "what is the correct way to say 'i am a visitor' in french"],
 'intent': [61, 61, 61, 61, 61, 61, 61, 61, 61, 61]}

In [14]:
# The `intent` feature describes the label column; its class count is reused
# below when loading the teacher and configuring the student.
intents = clinc['train'].features['intent']
num_labels = intents.num_classes

### Student model 初始化

In [9]:
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoModelForSequenceClassification

In [39]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Student checkpoint; its tokenizer is the one used to encode the dataset.
s_ckpt = 'distilbert-base-uncased'
s_tokenizer = AutoTokenizer.from_pretrained(s_ckpt)

# Teacher checkpoint (per its name, BERT-base already fine-tuned on CLINC), moved to `device`.
# NOTE(review): the teacher is later fed batches encoded with the student
# tokenizer - presumably both checkpoints share the same vocabulary; confirm
# before swapping either model.
t_ckpt = 'transformersbook/bert-base-uncased-finetuned-clinc'
t_model = AutoModelForSequenceClassification.from_pretrained(t_ckpt, num_labels=num_labels).to(device)

In [17]:
# Tokenize every split in batches with the student tokenizer (truncating long
# inputs) and drop the raw `text` column, which is no longer needed afterwards.
clinc_enc = clinc.map(lambda batch: s_tokenizer(batch['text'], truncation=True), 
                      batched=True, 
                      remove_columns=["text"]
                     )
# Rename the target column from `intent` to `labels`.
clinc_enc = clinc_enc.rename_columns({'intent': 'labels'})
clinc_enc

Loading cached processed dataset at /home/whaow/.cache/huggingface/datasets/clinc_oos/plus/1.0.0/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1/cache-f9c6f4c987e9be44.arrow
Loading cached processed dataset at /home/whaow/.cache/huggingface/datasets/clinc_oos/plus/1.0.0/abcc41d382f8137f039adc747af44714941e8196e845dfbdd8ae7a7e020e6ba1/cache-83787f330c65b9f8.arrow


Map:   0%|          | 0/5500 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 15250
    })
    validation: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 3100
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 5500
    })
})

In [53]:
# Per-device batch size shared by training and evaluation.
batch_size = 64
# NOTE: with alpha=1 the distillation term is weighted by (1 - alpha) = 0, so
# this run optimises the student's cross-entropy only (plain fine-tuning).
# Lower alpha to actually mix in the teacher's soft targets.
# `push_to_hub=True` enables uploading to the Hugging Face Hub.
s_training_args = DistillTrainingArguments(output_dir='distilbert-base-uncased-ft-clinc', 
                                           evaluation_strategy='epoch', num_train_epochs=5, 
                                           learning_rate=3e-4, 
                                           per_device_train_batch_size=batch_size, 
                                           per_device_eval_batch_size=batch_size, 
                                           alpha=1, weight_decay=0.01, 
                                           logging_strategy='epoch',
                                           push_to_hub=True)
# Student config: same number of classes as the dataset, with the label <-> id
# mappings copied from the teacher's config.
s_config = AutoConfig.from_pretrained(s_ckpt, num_labels=num_labels, 
                                      id2label=t_model.config.id2label, label2id=t_model.config.label2id)
# s_config

In [35]:

def student_init():
    """Build a fresh student classifier from `s_ckpt` using `s_config` and place it on `device`."""
    student = AutoModelForSequenceClassification.from_pretrained(s_ckpt, config=s_config)
    return student.to(device)

### trainer.train

In [36]:
# Accuracy metric used by `compute_metrics`. The commented-out lines below are
# the earlier `datasets.load_metric` way of loading it, kept for reference.
# from datasets import load_metric
# accuracy_score = load_metric('accuracy')
# SequenceClassification
import evaluate
accuracy_score = evaluate.load('accuracy')

Using the latest cached version of the module from /home/whaow/.cache/huggingface/modules/evaluate_modules/metrics/evaluate-metric--accuracy/f887c0aab52c2d38e1f8a215681126379eca617f96c447638f751434e8e65b14 (last modified on Sun Jun 25 23:06:51 2023) since it couldn't be found locally at evaluate-metric--accuracy, or remotely on the Hugging Face Hub.


In [54]:
# Metrics callback handed to the trainer (a plain function, not a Trainer method).
def compute_metrics(pred):
    """Return the accuracy for one evaluation pass.

    `pred` unpacks into (per-class scores, reference labels); the predicted
    class is the argmax of the scores along the last axis.
    """
    scores, gold = pred
    predicted_ids = np.argmax(scores, axis=-1)
    return accuracy_score.compute(predictions=predicted_ids, references=gold)

In [57]:
# `model_init` lets the trainer build the student itself via `student_init`;
# the teacher is passed through the extra `teacher_model` keyword that
# `DistillTrainer` adds on top of `Trainer`. Evaluation runs on the validation split.
distill_trainer = DistillTrainer(model_init=student_init, teacher_model=t_model, args=s_training_args, 
                                 train_dataset=clinc_enc['train'], eval_dataset=clinc_enc['validation'], 
                                 compute_metrics=compute_metrics, tokenizer=s_tokenizer)
distill_trainer.train()

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_transform.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias', 'pre_classifier.weight', 'pre_classifier.

Epoch,Training Loss,Validation Loss,Accuracy
1,1.7659,0.369926,0.912903
2,0.156,0.278254,0.936452
3,0.0623,0.295684,0.939032
4,0.0241,0.239457,0.947097
5,0.0094,0.234831,0.947742




TrainOutput(global_step=600, training_loss=0.4035262276728948, metrics={'train_runtime': 66.1569, 'train_samples_per_second': 1152.563, 'train_steps_per_second': 9.069, 'total_flos': 456233053284036.0, 'train_loss': 0.4035262276728948, 'epoch': 5.0})

In [45]:
import math
# Sanity check of the `global_step` reported by training:
# ceil(train rows / effective batch size) * epochs, with 15250 training rows,
# a per-device batch size of 64 and 5 epochs.
# NOTE(review): the factor 2 is presumably the number of GPUs used - confirm.
math.ceil(15250/(64*2)) * 5

600

### 使用

In [60]:
import os
# Point HTTP(S) traffic at a local proxy before talking to the Hub.
# NOTE(review): hard-coded, machine-specific address - adjust or remove when
# running elsewhere.
os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'

In [61]:
# Hub repo id of the fine-tuned student (defined here, not used in this cell).
ft_ckpt = 'lanchunhui/distilbert-base-uncased-ft-clinc'
# Upload the trained model to the Hub; the string argument is passed as the commit message.
distill_trainer.push_to_hub('finetune completed!')

Several commits (3) will be pushed upstream.
The progress bars may be unreliable.
fatal: unable to access 'https://huggingface.co/lanchunhui/distilbert-base-uncased-ft-clinc/': gnutls_handshake() failed: Error in the pull function.

Error pushing update to the model card. Please read logs and retry.
$fatal: unable to access 'https://huggingface.co/lanchunhui/distilbert-base-uncased-ft-clinc/': gnutls_handshake() failed: Error in the pull function.

