<a href="https://colab.research.google.com/github/lokwq/TextBrewer/blob/add_note_examples/sst2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook shows how to fine-tune a model on sst-2 dataset and how to distill the model with TextBrewer.

Detailed Docs can be find here:
https://github.com/airaria/TextBrewer

In [1]:
import torch
device='cuda' if torch.cuda.is_available() else 'cpu'
#device = 'cpu'

In [2]:
import os
import torch
from transformers import BertForSequenceClassification, BertTokenizer,BertConfig, AutoModelForSequenceClassification,RobertaTokenizer, RobertaForSequenceClassification
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import Trainer, TrainingArguments
from datasets import load_dataset,load_metric
from functools import partial
from predict_function import predict

In [None]:
# Settings 
task_name = "stsb"
base_model_name = 'roberta-base'

### Prepare dataset to train

In [3]:
train_dataset = load_dataset('glue', 'stsb', split='train')#,cache_dir="/work/mhessent/cache")
val_dataset = load_dataset('glue', 'stsb', split='validation')#,cache_dir="/work/mhessent/cache")
test_dataset = load_dataset('glue', 'stsb', split='test')#,cache_dir="/work/mhessent/cache")

Reusing dataset glue (/home/mhessent/.cache/huggingface/datasets/glue/stsb/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Reusing dataset glue (/home/mhessent/.cache/huggingface/datasets/glue/stsb/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Reusing dataset glue (/home/mhessent/.cache/huggingface/datasets/glue/stsb/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


In [4]:
train_dataset = train_dataset.map(lambda examples: {'labels': examples['label']}, batched=True)
val_dataset = val_dataset.map(lambda examples: {'labels': examples['label']}, batched=True)
test_dataset = test_dataset.map(lambda examples: {'labels': examples['label']}, batched=True)

val_dataset = val_dataset.remove_columns(['label'])
test_dataset = test_dataset.remove_columns(['label'])
train_dataset = train_dataset.remove_columns(['label'])

Loading cached processed dataset at /home/mhessent/.cache/huggingface/datasets/glue/stsb/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-2a1905efa4704bcb.arrow
Loading cached processed dataset at /home/mhessent/.cache/huggingface/datasets/glue/stsb/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-e194e0b596fcf478.arrow
Loading cached processed dataset at /home/mhessent/.cache/huggingface/datasets/glue/stsb/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-c97e180e5b68e6bf.arrow


In [5]:
#model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=1)
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

In [6]:
MAX_LENGTH = 128
train_dataset = train_dataset.map(lambda e: tokenizer(e['sentence1'],e['sentence2'], truncation=True, padding='max_length', max_length=MAX_LENGTH), batched=True)
val_dataset = val_dataset.map(lambda e: tokenizer(e['sentence1'],e['sentence2'], truncation=True, padding='max_length', max_length=MAX_LENGTH), batched=True)
test_dataset = test_dataset.map(lambda e: tokenizer(e['sentence1'],e['sentence2'], truncation=True, padding='max_length', max_length=MAX_LENGTH), batched=True)

Loading cached processed dataset at /home/mhessent/.cache/huggingface/datasets/glue/stsb/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-8ce536a089ec4a7b.arrow
Loading cached processed dataset at /home/mhessent/.cache/huggingface/datasets/glue/stsb/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-5d5d583045ec07fc.arrow
Loading cached processed dataset at /home/mhessent/.cache/huggingface/datasets/glue/stsb/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-4b0cd40886585b17.arrow


In [7]:
train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
test_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

In [8]:
train_dataset[1]

{'labels': tensor(3.8000),
 'input_ids': tensor([   0,  250,  313,   16,  816,   10,  739, 2342, 4467,    4,    2,    2,
          250,  313,   16,  816,   10, 2342, 4467,    4,    2,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0

In [9]:
val_dataset[1]

{'labels': tensor(4.7500),
 'input_ids': tensor([   0,  250,  664,  920,   16, 5793,   10, 5253,    4,    2,    2,  250,
          920,   16, 5793,   10, 5253,    4,    2,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0

In [10]:

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    metric = load_metric("glue","stsb")
    return metric.compute(predictions=preds, references=labels)


In [11]:
#start training 
"""
training_args = TrainingArguments(
    output_dir='outputs/results',          #output directory
    learning_rate=3e-5,
    num_train_epochs=3,              
    per_device_train_batch_size=32,                #batch size per device during training
    per_device_eval_batch_size=32,                #batch size for evaluation
    logging_dir='outputs/logs',            
    logging_steps=500,
    do_train=True,
    do_eval=True,
    no_cuda=False,
    load_best_model_at_end=True,
    # eval_steps=100,
    evaluation_strategy="epoch",
    save_strategy="epoch"
)

trainer = Trainer(
    model=model,                         
    args=training_args,                  
    train_dataset=train_dataset,         
    eval_dataset=val_dataset,            
    compute_metrics=compute_metrics
)

train_out = trainer.train()
"""
#after training, you could find traing logs and checpoints in your own dirve. also you can reset the file address in training args

'\ntraining_args = TrainingArguments(\n    output_dir=\'outputs/results\',          #output directory\n    learning_rate=3e-5,\n    num_train_epochs=3,              \n    per_device_train_batch_size=32,                #batch size per device during training\n    per_device_eval_batch_size=32,                #batch size for evaluation\n    logging_dir=\'outputs/logs\',            \n    logging_steps=500,\n    do_train=True,\n    do_eval=True,\n    no_cuda=False,\n    load_best_model_at_end=True,\n    # eval_steps=100,\n    evaluation_strategy="epoch",\n    save_strategy="epoch"\n)\n\ntrainer = Trainer(\n    model=model,                         \n    args=training_args,                  \n    train_dataset=train_dataset,         \n    eval_dataset=val_dataset,            \n    compute_metrics=compute_metrics\n)\n\ntrain_out = trainer.train()\n'

In [12]:
#torch.save(model.state_dict(), 'outputs/stsb_teacher_model.pt')


### Start distillation

In [13]:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32) #prepare dataloader

In [14]:
import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig
from transformers import BertForSequenceClassification, BertConfig, AdamW,BertTokenizer, RobertaConfig, RobertaForSequenceClassification
from transformers import get_linear_schedule_with_warmup

In [15]:
config = RobertaConfig.from_pretrained("roberta-base", output_hidden_states=True)
config

RobertaConfig {
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "output_hidden_states": true,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.17.0",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}

In [16]:
distilroberta_config = RobertaConfig.from_pretrained("distilroberta-base", output_hidden_states=True)
distilroberta_config

RobertaConfig {
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "output_hidden_states": true,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.17.0",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}

Initialize the student model by BertConfig and prepare the teacher model.

bert_config_L3.json refers to a 3-layer Bert.

bert_config.json refers to a standard 12-layer Bert.

In [28]:
config = RobertaConfig.from_pretrained("roberta-base", output_hidden_states=True)
#config = RobertaConfig.from_json_file('/work/mhessent/TextBrewer/examples/student_config/bert_base_cased_config/bert_config.json')
config.output_hidden_states = True
#config.vocab_size = len(tokenizer)
config.num_labels = 1

teacher_model = RobertaForSequenceClassification(config)
#teacher_model.load_state_dict(torch.load('/work/mhessent/master_thesis/eval_out/roberta-base/stsb/lr3e-05_bs16_epochs10/torch_state_dict.pt'))
teacher_model.load_state_dict(torch.load('outputs/stsb_teacher_model.pt'))
"""
model = BertForSequenceClassification.from_pretrained("/work/mhessent/master_thesis/eval_out/bert-base-uncased/mnli/lr3e-05_bs32_epochs3/checkpoint-36816")
torch.save(model.state_dict(), 'outputs/hub_mnli_teacher_model.pt')
bert_config = BertConfig.from_json_file('/work/mhessent/TextBrewer/examples/student_config/bert_base_cased_config/bert_config.json')
bert_config.output_hidden_states = True
bert_config.vocab_size = 30522
bert_config.num_labels = 3
teacher_model = BertForSequenceClassification(bert_config) #, num_labels = 2
teacher_model.load_state_dict(torch.load('outputs/hub_mnli_teacher_model.pt'))
"""


teacher_model = teacher_model.to(device=device)



student_config = RobertaConfig.from_pretrained("roberta-base", output_hidden_states=True)
student_config.output_hidden_states = True
student_config.num_labels = 1
student_config.num_hidden_layers = 3
student_config.hidden_dropout_prob = 0.3
student_config.attention_probs_dropout_prob = 0.3

#student_config.vocab_size = teacher_model.config.vocab_size

student_model = RobertaForSequenceClassification(student_config)
student_model = student_model.to(device=device)


print(teacher_model.config.vocab_size)
print(student_model.config.vocab_size)
print(len(tokenizer))

50265
50265
50265


The cell below is to distill the teacher model to student model you prepared.

After the code execution is complete, the distilled model will be in 'saved_model' in colab file list

In [29]:
num_epochs = 60
num_training_steps = len(train_dataloader) * num_epochs
# Optimizer and learning rate scheduler
optimizer = AdamW(student_model.parameters(), lr=1e-4)

scheduler_class = get_linear_schedule_with_warmup
# arguments dict except 'optimizer'
scheduler_args = {'num_warmup_steps':int(0.1*num_training_steps), 'num_training_steps':num_training_steps}


def simple_adaptor(batch, model_outputs):
    return {'logits': model_outputs.logits, 'hidden': model_outputs.hidden_states}


from matches import matches
intermediate_matches = None
match_list_L4t = ["L4t_hidden_mse", "L4_hidden_smmd"]
match_list_L3 = ["L3_hidden_mse", "L3_hidden_smmd"]
intermediate_matches = []
for match in match_list_L3:
        intermediate_matches += matches[match]

distill_config = DistillationConfig(kd_loss_type='mse',temperature=4)#,intermediate_matches=intermediate_matches)
train_config = TrainingConfig(device=device)



task_name = "stsb"
local_rank = -1
predict_batch_size = 32
device = device
output_dir = "outputs/" + task_name + "/" 
eval_datasets = [val_dataset]
do_train_eval = True

callback_func = partial(predict, eval_datasets=eval_datasets, output_dir=output_dir,task_name=task_name,local_rank=local_rank,predict_batch_size=predict_batch_size,device=device, do_train_eval=do_train_eval, train_dataset=train_dataset)

distiller = GeneralDistiller(
    train_config=train_config, distill_config=distill_config,
    model_T=teacher_model, model_S=student_model, 
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)


with distiller:
    distiller.train(optimizer, train_dataloader, num_epochs, scheduler_class=scheduler_class, scheduler_args = scheduler_args, callback=callback_func)

2022/03/29 18:48:48 - INFO - Distillation -  Training steps per epoch: 180
2022/03/29 18:48:48 - INFO - Distillation -  Checkpoints(step): [0]
2022/03/29 18:48:48 - INFO - Distillation -  Epoch 1
2022/03/29 18:48:48 - INFO - Distillation -  Length of current epoch in forward batch: 180
2022/03/29 18:48:49 - INFO - Distillation -  Global step: 9, epoch step:9
2022/03/29 18:48:52 - INFO - Distillation -  Global step: 18, epoch step:18
2022/03/29 18:48:54 - INFO - Distillation -  Global step: 27, epoch step:27
2022/03/29 18:48:56 - INFO - Distillation -  Global step: 36, epoch step:36
2022/03/29 18:48:58 - INFO - Distillation -  Global step: 45, epoch step:45
2022/03/29 18:49:00 - INFO - Distillation -  Global step: 54, epoch step:54
2022/03/29 18:49:02 - INFO - Distillation -  Global step: 63, epoch step:63
2022/03/29 18:49:04 - INFO - Distillation -  Global step: 72, epoch step:72
2022/03/29 18:49:06 - INFO - Distillation -  Global step: 81, epoch step:81
2022/03/29 18:49:08 - INFO - Di

KeyboardInterrupt: 

In [None]:
from textbrewer.distiller_utils import move_to_device

In [None]:
test_model = RobertaForSequenceClassification(student_config)
test_model.load_state_dict(torch.load('/work/mhessent/TextBrewer/examples/notebook_examples/saved_models/gs9900.pkl'))#gs4210 is the distilled model weights file

In [None]:
from torch.utils.data import DataLoader
eval_dataloader = DataLoader(val_dataset, batch_size=8)

In [None]:
metric= load_metric("glue","stsb")
test_model.to(device)
test_model.eval()
for batch in train_dataloader:
    batch = {k: v for k, v in batch.items()}
    batch = move_to_device(batch,device)
    with torch.no_grad():
        outputs = test_model(**batch)

    logits = outputs.logits
    metric.add_batch(predictions=logits, references=batch["labels"])

metric.compute()

In [None]:
#teacher_model = RobertaForSequenceClassification.from_pretrained('/work/mhessent/master_thesis/eval_out/roberta-base/stsb/lr3e-05_bs32_epochs10/checkpoint-1620')
metric= load_metric("glue","stsb")
#teacher_model.cpu()
teacher_model.to(device)
teacher_model.eval()
for batch in train_dataloader:
    batch = {k: v for k, v in batch.items()}
    batch = move_to_device(batch,device)
    with torch.no_grad():
        outputs = teacher_model(**batch)

    logits = outputs.logits
    metric.add_batch(predictions=logits, references=batch["labels"])

metric.compute()