<a href="https://colab.research.google.com/github/lokwq/TextBrewer/blob/add_note_examples/sqaudv1.1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook shows how to fine-tune a model on squadv1.1 dataset and how to distill the model with TextBrewer.

Detailed Docs can be find here: https://github.com/airaria/TextBrewer

In [1]:
import torch
device='cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
!pip install transformers
!pip install textbrewer

### Prepare dataset to train teacher model

In [3]:
import os
import random
import timeit

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm_notebook
from transformers import AdamW,BertConfig,BertForQuestionAnswering,BertTokenizer,get_linear_schedule_with_warmup,squad_convert_examples_to_features

from transformers.data.metrics.squad_metrics import compute_predictions_logits,squad_evaluate
from transformers.data.processors.squad import SquadResult, SquadV1Processor

In [4]:
!wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json

--2021-07-08 04:22:17--  https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
Resolving rajpurkar.github.io (rajpurkar.github.io)... 185.199.108.153, 185.199.109.153, 185.199.110.153, ...
Connecting to rajpurkar.github.io (rajpurkar.github.io)|185.199.108.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 30288272 (29M) [application/json]
Saving to: ‘train-v1.1.json’


2021-07-08 04:22:18 (78.2 MB/s) - ‘train-v1.1.json’ saved [30288272/30288272]

--2021-07-08 04:22:18--  https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json
Reusing existing connection to rajpurkar.github.io:443.
HTTP request sent, awaiting response... 200 OK
Length: 4854279 (4.6M) [application/json]
Saving to: ‘dev-v1.1.json’


2021-07-08 04:22:19 (215 MB/s) - ‘dev-v1.1.json’ saved [4854279/4854279]

FINISHED --2021-07-08 04:22:19--
Total wall clock time: 1.7s
Downloaded: 2 files, 34M in 0.4s (85.7 MB/s)


In [5]:
processor = SquadV1Processor()
examples = processor.get_train_examples('/content/')


100%|██████████| 442/442 [00:33<00:00, 13.29it/s]


In [17]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = BertForQuestionAnswering.from_pretrained('bert-base-cased')

model.to(device)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForQuestionAnswering: ['cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-cased and a

BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_

In [7]:
features,train_dataset = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=True,
    return_dataset="pt"
)


convert squad examples to features: 100%|██████████| 87599/87599 [11:27<00:00, 127.42it/s]
add example index and unique id: 100%|██████████| 87599/87599 [00:00<00:00, 761102.36it/s]


In [8]:
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=8)

In [9]:
next(iter(train_dataloader))

[tensor([[ 101, 1332, 1108,  ...,    0,    0,    0],
         [ 101, 1327, 2523,  ...,    0,    0,    0],
         [ 101, 1130, 1184,  ...,    0,    0,    0],
         ...,
         [ 101, 1731, 1242,  ...,    0,    0,    0],
         [ 101, 1327, 1110,  ...,    0,    0,    0],
         [ 101, 1327, 1160,  ...,    0,    0,    0]]),
 tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 tensor([ 81,  99,  21, 146, 122, 130, 191,  77]),
 tensor([ 82, 100,  21, 147, 125, 130, 192,  80]),
 tensor([0, 0, 0, 0, 0, 0, 0, 0]),
 tensor([[0., 1., 1.,  ..., 1., 1., 1.],
         [0., 1., 1.,  ..., 1., 1., 

In [None]:
#Start training
import tqdm
epochs = 2
t_total = len(train_dataloader) * epochs

# Prepare optimizer and schedule 
optimizer = AdamW(model.parameters(), lr=3e-5, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=t_total)


print("  Num examples = ", len(train_dataset))
print("  Total optimization steps = ", t_total)

steps = 1
tr_loss = 0.0
model.zero_grad()

for epoch in range(epochs):
    print('Epoch:{}'.format(epoch+1))
    epoch_iterator = tqdm.notebook.tqdm(train_dataloader, desc="Iteration", disable=False)
    for step, batch in enumerate(epoch_iterator):

        model.train()
        batch = tuple(t.to(device) for t in batch)

        inputs = {
            "input_ids": batch[0],
            "attention_mask": batch[1],
            "token_type_ids": batch[2],
            "start_positions": batch[3],
            "end_positions": batch[4],
        }

        outputs = model(**inputs)
        loss = outputs.loss

        loss.backward()

        tr_loss += loss.item()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()  
        model.zero_grad()
        steps += 1
        
        # Log 
        if steps % 500 == 0:
            print('steps = {}, logging_loss = {}'.format(steps,tr_loss))

print(" steps = %s, average loss = %s", steps, tr_loss / steps)

In [None]:
torch.save(model.state_dict(), '/content/drive/MyDrive/squad_teacher_model.pt') #save the teacher model weights to distill

### Start distillation

In [9]:
import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig
from transformers import BertForSequenceClassification, BertConfig, AdamW,BertTokenizer
from transformers import get_linear_schedule_with_warmup
import torch

In [None]:
from transformers import BertForQuestionAnswering,BertConfig


bert_config = BertConfig.from_json_file('/content/drive/MyDrive/TextBrewer-master/examples/student_config/bert_base_cased_config/bert_config.json')
bert_config_T3 = BertConfig.from_json_file('/content/drive/MyDrive/TextBrewer-master/examples/student_config/bert_base_cased_config/bert_config_L3.json')

bert_config.output_hidden_states = True
bert_config_T3.output_hidden_states = True


teacher_model = BertForQuestionAnswering(bert_config) 
student_model = BertForQuestionAnswering(bert_config_T3) 

teacher_model.load_state_dict(torch.load('/content/drive/MyDrive/squad_teacher_model.pt'))

teacher_model.to(device=device)
student_model.to(device=device)

In [None]:
num_epochs = 20
num_training_steps = len(train_dataloader) * num_epochs
# Optimizer and learning rate scheduler
optimizer = AdamW(student_model.parameters(), lr=1e-4)

scheduler_class = get_linear_schedule_with_warmup
# arguments dict except 'optimizer'
scheduler_args = {'num_warmup_steps':int(0.1*num_training_steps), 'num_training_steps':num_training_steps}


def simple_adaptor(batch, model_outputs):
    return {'logits': (model_outputs.start_logits,model_outputs.end_logits), 
        'hidden': model_outputs.hidden_states, 
        'attention': model_outputs.attentions}

distill_config = DistillationConfig(
    temperature = 1,
    intermediate_matches=[{"layer_T":[0,0],  "layer_S":[0,0], "feature":"hidden", "loss":"mmd", "weight":1},
               {"layer_T":[4,4],  "layer_S":[1,1], "feature":"hidden", "loss":"mmd", "weight":1},
               {"layer_T":[8,8],  "layer_S":[2,2], "feature":"hidden", "loss":"mmd", "weight":1},
               {"layer_T":[12,12], "layer_S":[3,3], "feature":"hidden", "loss":"mmd", "weight":1}])
train_config = TrainingConfig()

distiller = GeneralDistiller(
    train_config=train_config, distill_config=distill_config,
    model_T=teacher_model, model_S=student_model, 
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)

def batch_postprocessor(batch):
  return {"input_ids": batch[0],
       "attention_mask": batch[1],
       "token_type_ids": batch[2],
       "start_positions": batch[3],
       "end_positions": batch[4]}

with distiller:
    distiller.train(optimizer, train_dataloader, num_epochs, scheduler_class=scheduler_class, scheduler_args = scheduler_args, batch_postprocessor=batch_postprocessor, callback=None)

In [None]:
bert_config_T3 = BertConfig.from_json_file('/content/drive/MyDrive/data/bert_config/bert_config_L3.json')
test_model = BertForQuestionAnswering(bert_config_T3)

test_model.load_state_dict(torch.load('/content/saved_models/gs2813.pkl'))
test_model.to(device)


In [12]:
processor = SquadV1Processor()
examples = processor.get_dev_examples('/content/')

features,eval_dataset = squad_convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=False,
    return_dataset="pt"
)


100%|██████████| 48/48 [00:06<00:00,  7.61it/s]
convert squad examples to features: 100%|██████████| 10570/10570 [01:24<00:00, 125.69it/s]
add example index and unique id: 100%|██████████| 10570/10570 [00:00<00:00, 685666.02it/s]


In [None]:
# Evaluate
import tqdm
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=8)

print("  Num examples = ", len(eval_dataset))

all_results = []

def to_tuple(self) -> Tuple[Any]:
    """
    Convert self to a tuple containing all the attributes/keys that are not ``None``.
    """
    return tuple(self[k] for k in self.keys())

for batch in tqdm.notebook.tqdm(eval_dataloader, desc="Evaluating"):
    test_model.eval()
    batch = tuple(t.to(device) for t in batch)

    with torch.no_grad():
        inputs = {
            "input_ids": batch[0],
            "attention_mask": batch[1],
            "token_type_ids": batch[2],
        }

        feature_indices = batch[3]

        outputs = model(**inputs)

    for i, feature_index in enumerate(feature_indices):
        eval_feature = features[feature_index.item()]
        unique_id = int(eval_feature.unique_id)
        #output = [output[i].detach().cpu().tolist() for output in outputs]
        #start_logits, end_logits = output
        output = [(output[i]).detach().cpu().tolist() for output in outputs.to_tuple()]
        start_logits, end_logits = output
        result = SquadResult(unique_id, start_logits, end_logits)
        all_results.append(result)


predictions = compute_predictions_logits(
    examples,
    features,
    all_results,
    n_best_size=20,
    max_answer_length=30,
    do_lower_case=False,
    output_prediction_file="predictions.json",
    output_nbest_file="nbest_predictions.json",
    output_null_log_odds_file=None,
    verbose_logging=False,
    version_2_with_negative=False,
    null_score_diff_threshold=0.0,
    tokenizer=tokenizer,)

# Compute the F1 and exact scores.
results = squad_evaluate(examples, predictions)

print("Results: {}".format(results))