[Findings of EMNLP22] From Mimicking to Integrating: Knowledge Integration for Pre-Trained Language Models
We recommend using a virtual environment to reproduce the results:
conda create -n muki python=3.7.10
conda activate muki
conda install pytorch torchvision cudatoolkit=11.0 -c pytorch
pip install -r requirements.txt
The main setup of our paper is to train two teacher models specialized in different class subsets of a classification problem.
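For concreteness, the teacher setup can be sketched as below. This is only an illustration of the class-subset split: the helper name `split_labels` and the even 50/50 split are our assumptions, not necessarily the paper's exact recipe.

```python
def split_labels(labels):
    """Partition a label list into two disjoint subsets, one per teacher.
    Hypothetical helper: an even split is used purely for illustration."""
    mid = len(labels) // 2
    return labels[:mid], labels[mid:]

# e.g. a 10-class problem such as the THU-CNews benchmark subset
cnews_labels = list(range(10))
subset_a, subset_b = split_labels(cnews_labels)
# teacher A is trained only on examples whose label falls in subset_a,
# teacher B only on those in subset_b; the student must cover all classes
```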
Taking THU-CNews as an example, run the following command to obtain the two teacher models:
bash scripts/train_teacher.sh
After teacher training finishes, we can perform knowledge integration via various distillation methods.
# vanilla KD
bash scripts/vkd.sh
# UHC
bash scripts/uhc.sh
# DFA
bash scripts/dfa.sh
# MUKI(Ours)
bash scripts/muki.sh
For our method MUKI, please check the script and the corresponding model file models/uka_multiple_teacher.py for more details.
For Monte-Carlo dropout, to reduce the computational cost of uncertainty estimation, we pre-compute the scores and save them into a NumPy file (see models/monte_carlo.py for details). The integration can then be conducted by reading the corresponding files, which accelerates training.
We provide the corresponding weights in Google Drive.
The uncertainty can also be computed on the fly for your own custom dataset by adding code like the following:
with torch.no_grad():  # Monte Carlo Dropout on the fly
    probs = []
    for m in range(self.mc_number):  # number of Monte Carlo dropout samples
        for i, t_model in enumerate(self.teachers):
            t_model.train()  # activate dropout
            teacher_output = t_model(input_ids,
                                     attention_mask=attention_mask,
                                     token_type_ids=token_type_ids,
                                     position_ids=position_ids,
                                     head_mask=head_mask,
                                     inputs_embeds=inputs_embeds,
                                     output_attentions=output_attentions,
                                     output_hidden_states=False,
                                     return_dict=return_dict)
            teacher_logit = teacher_output[0]  # bsz, seq_len, logits
            teacher_prob = F.softmax(teacher_logit, dim=-1)
            if m == 0:
                probs.append(teacher_prob)
            else:
                probs[i] += teacher_prob
    # average the probabilities over the MC samples
    probs = [prob / self.mc_number for prob in probs]
    # get the logits with dropout deactivated
    t_logits = []
    for i, t_model in enumerate(self.teachers):
        t_model.eval()  # deactivate dropout
        teacher_output = t_model(input_ids,
                                 attention_mask=attention_mask,
                                 token_type_ids=token_type_ids,
                                 position_ids=position_ids,
                                 head_mask=head_mask,
                                 inputs_embeds=inputs_embeds,
                                 output_attentions=output_attentions,
                                 output_hidden_states=False,
                                 return_dict=return_dict)
        t_logit = teacher_output[0]  # bsz, seq_len, logits
        t_logits.append(t_logit)
    teacher_probs = [F.softmax(t_logit / self.kd_temperature, dim=-1) for t_logit in t_logits]
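To make the idea concrete, here is a minimal self-contained toy of MC-dropout uncertainty estimation. The two-layer model and the predictive-entropy score are illustrative choices on our part; MUKI's actual integration logic lives in models/uka_multiple_teacher.py.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# toy "teacher": dropout makes forward passes stochastic in train mode
teacher = torch.nn.Sequential(torch.nn.Linear(8, 5),
                              torch.nn.Dropout(0.5),
                              torch.nn.Linear(5, 5))
x = torch.randn(2, 8)  # a batch of 2 examples
mc_number = 16

teacher.train()  # activate dropout for Monte Carlo sampling
with torch.no_grad():
    mc_prob = sum(F.softmax(teacher(x), dim=-1)
                  for _ in range(mc_number)) / mc_number

# predictive entropy of the averaged distribution as an uncertainty score:
# higher entropy means the teacher is less certain about the example
entropy = -(mc_prob * mc_prob.log()).sum(dim=-1)
```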
If you want to compute the teacher score file yourself, please check scripts/md_cnews.sh and models/monte_carlo.py for more details.