<a href="https://colab.research.google.com/github/lokwq/TextBrewer/blob/add_note_examples/sst2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook shows how to fine-tune a model on sst-2 dataset and how to distill the model with TextBrewer.

Detailed Docs can be find here:
https://github.com/airaria/TextBrewer

In [1]:
import torch
device='cuda' if torch.cuda.is_available() else 'cpu'
#device = 'cpu'

In [2]:
import os
import torch
from transformers import BertForSequenceClassification, BertTokenizer,BertConfig, AutoModelForSequenceClassification
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import Trainer, TrainingArguments
from datasets import load_dataset,load_metric

from functools import partial
from predict_function import predict

### Prepare dataset to train

In [3]:
train_dataset = load_dataset('glue', 'sst2', split='train')
val_dataset = load_dataset('glue', 'sst2', split='validation')
test_dataset = load_dataset('glue', 'sst2', split='test')

Reusing dataset glue (/home/mhessent/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Reusing dataset glue (/home/mhessent/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Reusing dataset glue (/home/mhessent/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


In [4]:
train_dataset = train_dataset.map(lambda examples: {'labels': examples['label']}, batched=True)
val_dataset = val_dataset.map(lambda examples: {'labels': examples['label']}, batched=True)
test_dataset = test_dataset.map(lambda examples: {'labels': examples['label']}, batched=True)
val_dataset = val_dataset.remove_columns(['label'])
test_dataset = test_dataset.remove_columns(['label'])
train_dataset = train_dataset.remove_columns(['label'])

Loading cached processed dataset at /home/mhessent/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-b9e78673c79f89ac.arrow
Loading cached processed dataset at /home/mhessent/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-2258e450e7114d5b.arrow
Loading cached processed dataset at /home/mhessent/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-fc7e588b0e3f2b71.arrow


In [5]:
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
#tokenizer = BertTokenizer.from_pretrained("howey/bert-base-uncased-sst2")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [6]:
MAX_LENGTH = 128
train_dataset = train_dataset.map(lambda e: tokenizer(e['sentence'], truncation=True, padding='max_length', max_length=MAX_LENGTH), batched=True)
val_dataset = val_dataset.map(lambda e: tokenizer(e['sentence'], truncation=True, padding='max_length', max_length=MAX_LENGTH), batched=True)
test_dataset = test_dataset.map(lambda e: tokenizer(e['sentence'], truncation=True, padding='max_length', max_length=MAX_LENGTH), batched=True)

Loading cached processed dataset at /home/mhessent/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-8b91720ce2aeadc7.arrow
Loading cached processed dataset at /home/mhessent/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-31c6ef88b4a5a656.arrow
Loading cached processed dataset at /home/mhessent/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-caa256527e820588.arrow


In [7]:
train_dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
val_dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
test_dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])

In [8]:
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

In [9]:
#start training 
"""
training_args = TrainingArguments(
    output_dir='outputs/results',          #output directory
    learning_rate=3e-5,
    num_train_epochs=3,              
    per_device_train_batch_size=32,                #batch size per device during training
    per_device_eval_batch_size=32,                #batch size for evaluation
    logging_dir='outputs/logs',            
    logging_steps=100,
    do_train=True,
    do_eval=True,
    no_cuda=False,
    load_best_model_at_end=True,
    # eval_steps=100,
    evaluation_strategy="epoch",
    save_strategy="epoch"
)

trainer = Trainer(
    model=model,                         
    args=training_args,                  
    train_dataset=train_dataset,         
    eval_dataset=val_dataset,            
    compute_metrics=compute_metrics
)

train_out = trainer.train()
"""
#after training, you could find traing logs and checpoints in your own dirve. also you can reset the file address in training args

'\ntraining_args = TrainingArguments(\n    output_dir=\'outputs/results\',          #output directory\n    learning_rate=3e-5,\n    num_train_epochs=3,              \n    per_device_train_batch_size=32,                #batch size per device during training\n    per_device_eval_batch_size=32,                #batch size for evaluation\n    logging_dir=\'outputs/logs\',            \n    logging_steps=100,\n    do_train=True,\n    do_eval=True,\n    no_cuda=False,\n    load_best_model_at_end=True,\n    # eval_steps=100,\n    evaluation_strategy="epoch",\n    save_strategy="epoch"\n)\n\ntrainer = Trainer(\n    model=model,                         \n    args=training_args,                  \n    train_dataset=train_dataset,         \n    eval_dataset=val_dataset,            \n    compute_metrics=compute_metrics\n)\n\ntrain_out = trainer.train()\n'

In [10]:
#torch.save(model.state_dict(), 'outputs/sst2_teacher_model.pt')


### Start distillation

In [11]:
#train_dataset = train_dataset.select(range(10))
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32) #prepare dataloader

In [12]:
import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig
from transformers import BertForSequenceClassification, BertConfig, AdamW,BertTokenizer
from transformers import get_linear_schedule_with_warmup

Initialize the student model by BertConfig and prepare the teacher model.

bert_config_L3.json refers to a 3-layer Bert.

bert_config.json refers to a standard 12-layer Bert.

In [13]:
#model = BertForSequenceClassification.from_pretrained("howey/bert-base-uncased-sst2")
#torch.save(model.state_dict(), 'outputs/hub_sst_teacher_model.pt')
bert_config = BertConfig.from_json_file('/work/mhessent/TextBrewer/examples/student_config/bert_base_cased_config/bert_config.json')
bert_config.output_hidden_states = True
bert_config.vocab_size = 30522
#bert_config.num_labels = 2
teacher_model = BertForSequenceClassification(bert_config) #, num_labels = 2
teacher_model.load_state_dict(torch.load('outputs/hub_sst_teacher_model.pt'))

teacher_model = teacher_model.to(device=device)
#bert_config_T3 = BertConfig.from_json_file('/work/mhessent/TextBrewer/examples/student_config/bert_base_cased_config/bert_config_L3_v2.json')#相对路径
bert_config_T3 = BertConfig.from_json_file('/work/mhessent/TextBrewer/examples/student_config/bert_base_cased_config/bert_config_L3.json')
bert_config_T3.output_hidden_states = True
bert_config_T3.vocab_size = teacher_model.config.vocab_size

student_model = BertForSequenceClassification(bert_config_T3) #, num_labels = 2
student_model = student_model.to(device=device)

"""
bert_config = BertConfig.from_json_file('/work/mhessent/TextBrewer/examples/student_config/bert_base_cased_config/bert_config.json')
bert_config.output_hidden_states = True
bert_config.vocab_size = 30522
#print(bert_config)
teacher_model = BertForSequenceClassification(bert_config) #, num_labels = 2
teacher_model.load_state_dict(torch.load('outputs/sst2_teacher_model.pt'))
#pretrained_config = BertConfig(output_hidden_states=True, output_attentions=True, num_labels=2)
#teacher_model = AutoModelForSequenceClassification.from_pretrained("howey/bert-base-uncased-sst2", config=pretrained_config)

teacher_model = teacher_model.to(device=device)
"""

'\nbert_config = BertConfig.from_json_file(\'/work/mhessent/TextBrewer/examples/student_config/bert_base_cased_config/bert_config.json\')\nbert_config.output_hidden_states = True\nbert_config.vocab_size = 30522\n#print(bert_config)\nteacher_model = BertForSequenceClassification(bert_config) #, num_labels = 2\nteacher_model.load_state_dict(torch.load(\'outputs/sst2_teacher_model.pt\'))\n#pretrained_config = BertConfig(output_hidden_states=True, output_attentions=True, num_labels=2)\n#teacher_model = AutoModelForSequenceClassification.from_pretrained("howey/bert-base-uncased-sst2", config=pretrained_config)\n\nteacher_model = teacher_model.to(device=device)\n'

In [14]:
print(teacher_model.config.vocab_size)
print(student_model.config.vocab_size)
print(len(tokenizer))


30522
30522
30522


The cell below is to distill the teacher model to student model you prepared.

After the code execution is complete, the distilled model will be in 'saved_model' in colab file list

In [15]:
num_epochs = 20
num_training_steps = len(train_dataloader) * num_epochs
# Optimizer and learning rate scheduler
optimizer = AdamW(student_model.parameters(), lr=1e-4)

scheduler_class = get_linear_schedule_with_warmup
# arguments dict except 'optimizer'
scheduler_args = {'num_warmup_steps':int(0.1*num_training_steps), 'num_training_steps':num_training_steps}


def simple_adaptor(batch, model_outputs):
    return {'logits': model_outputs.logits, 'hidden': model_outputs.hidden_states}

distill_config = DistillationConfig(
    intermediate_matches=[    
     {'layer_T':0, 'layer_S':0, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1},
     {'layer_T':8, 'layer_S':2, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1}])
train_config = TrainingConfig(device=device)

distiller = GeneralDistiller(
    train_config=train_config, distill_config=distill_config,
    model_T=teacher_model, model_S=student_model, 
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)


task_name = "sst-2"
local_rank = -1
predict_batch_size = 32
device = device
output_dir = "outputs/" + task_name + "/" 

callback_func = partial(predict, eval_datasets=[val_dataset], output_dir=output_dir,task_name=task_name,local_rank=local_rank,predict_batch_size=predict_batch_size,device=device)

with distiller:
    distiller.train(optimizer, train_dataloader, num_epochs, scheduler_class=scheduler_class, scheduler_args = scheduler_args, callback=callback_func)

2022/03/18 18:42:50 - INFO - Distillation -  Training steps per epoch: 2105
2022/03/18 18:42:50 - INFO - Distillation -  Checkpoints(step): [0]
2022/03/18 18:42:50 - INFO - Distillation -  Epoch 1
2022/03/18 18:42:50 - INFO - Distillation -  Length of current epoch in forward batch: 2105
2022/03/18 18:43:00 - INFO - Distillation -  Global step: 105, epoch step:105
2022/03/18 18:43:09 - INFO - Distillation -  Global step: 210, epoch step:210
2022/03/18 18:43:18 - INFO - Distillation -  Global step: 315, epoch step:315
2022/03/18 18:43:28 - INFO - Distillation -  Global step: 420, epoch step:420
2022/03/18 18:43:37 - INFO - Distillation -  Global step: 525, epoch step:525
2022/03/18 18:43:47 - INFO - Distillation -  Global step: 630, epoch step:630
2022/03/18 18:43:56 - INFO - Distillation -  Global step: 735, epoch step:735
2022/03/18 18:44:06 - INFO - Distillation -  Global step: 840, epoch step:840
2022/03/18 18:44:15 - INFO - Distillation -  Global step: 945, epoch step:945
2022/03/1

In [16]:
test_model = BertForSequenceClassification(bert_config_T3)
test_model.load_state_dict(torch.load('/work/mhessent/TextBrewer/examples/notebook_examples/saved_models_old/gs2105.pkl'))#gs4210 is the distilled model weights file

<All keys matched successfully>

In [17]:
from torch.utils.data import DataLoader
eval_dataloader = DataLoader(val_dataset, batch_size=8)

In [18]:
metric= load_metric("accuracy")
test_model.eval()
for batch in eval_dataloader:
    batch = {k: v for k, v in batch.items()}
    with torch.no_grad():
        outputs = test_model(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])

metric.compute()

{'accuracy': 0.7901376146788991}