<a href="https://colab.research.google.com/github/bhadreshpsavani/UnderstandingNLP/blob/master/DistilbertPerformance.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q transformers
!pip install -q datasets

[K     |████████████████████████████████| 1.9MB 10.9MB/s 
[K     |████████████████████████████████| 3.2MB 52.8MB/s 
[K     |████████████████████████████████| 890kB 58.0MB/s 
[?25h  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 184kB 7.7MB/s 
[K     |████████████████████████████████| 245kB 13.2MB/s 
[K     |████████████████████████████████| 112kB 15.7MB/s 
[K     |████████████████████████████████| 20.7MB 46.5MB/s 
[?25h

In [2]:
from transformers import squad_convert_examples_to_features
from transformers.data.processors.squad import SquadV2Processor, SquadFeatures
from tqdm.notebook import tqdm
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from datasets import load_metric, load_dataset
device =  'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [5]:
def get_result(predictions, references, dataset='squad_v2'):
  squad_metric = load_metric(dataset)
  score = squad_metric.compute(predictions=predictions, references=references)
  return score

def get_validation_data(dataset='squad_v2'):
  datasets = load_dataset(dataset)
  valid_dataset = datasets['validation']
  return valid_dataset

def get_infernece(valid_dataset, model, tokenizer, device):
  predictions=[]
  references=[]
  model.to(device)
  for example in tqdm(valid_dataset):
    inputs = tokenizer(example['question'], example['context'], return_tensors="pt", truncation=True)
    inputs.to(device)
    output = model(**inputs)
    start_index = torch.argmax(output['start_logits'])
    end_index = torch.argmax(output['end_logits'])
    ans_ids = inputs['input_ids'][0][start_index :end_index+1]
    answer = tokenizer.decode(ans_ids)
    answer = tokenizer.clean_up_tokenization(answer).strip()
    no_answer_probability = 1 if len(example['answers']['answer_start'])!=0 else 0
    """
    predictions = [{'prediction_text': '1976', 'id': '56e10a3be3433e1400422b22', 'no_answer_probability': 0.}]
    references = [{'answers': {'answer_start': [97], 'text': ['1976']}, 'id': '56e10a3be3433e1400422b22'}]
    """
    pred = {'prediction_text':answer, 'id': example['id'], 'no_answer_probability' : no_answer_probability}
    ref = {'answers': example['answers'] , 'id': example['id']}
    predictions.append(pred)
    references.append(ref)
  return predictions, references

In [6]:
def print_results(model, tokenizer, device, dataset_name='squad_v2'):
  validation_dataset = get_validation_data(dataset_name)
  predictions, references = get_infernece(validation_dataset, model, tokenizer, device)
  score = get_result(predictions, references, dataset_name)
  print(score)

In [7]:
distilbert_path='twmkn9/distilbert-base-uncased-squad2'
distilbert_tokenizer = AutoTokenizer.from_pretrained(distilbert_path)
distilbert_model = AutoModelForQuestionAnswering.from_pretrained(distilbert_path)
distilbert_model.to(device)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=478.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=112.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=39.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=265482418.0, style=ProgressStyle(descri…




DistilBertForQuestionAnswering(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            

In [8]:
print_results(distilbert_model, distilbert_tokenizer, device)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1806.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=963.0, style=ProgressStyle(description_…


Downloading and preparing dataset squad_v2/squad_v2 (download: 44.34 MiB, generated: 122.57 MiB, post-processed: Unknown size, total: 166.91 MiB) to /root/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/9cac55034b086140f0649ecb5c604d09d7da2f2f5b73a90caa2e2bcc1f5cac09...


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=9551051.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=800683.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset squad_v2 downloaded and prepared to /root/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/9cac55034b086140f0649ecb5c604d09d7da2f2f5b73a90caa2e2bcc1f5cac09. Subsequent calls will reuse this data.


HBox(children=(FloatProgress(value=0.0, max=11873.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2264.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=3182.0, style=ProgressStyle(description…


{'exact': 31.272635391223783, 'f1': 35.63616173418905, 'total': 11873, 'HasAns_exact': 59.83468286099865, 'HasAns_f1': 68.57424903340527, 'HasAns_total': 5928, 'NoAns_exact': 2.7922624053826746, 'NoAns_f1': 2.7922624053826746, 'NoAns_total': 5945, 'best_exact': 50.07159100480081, 'best_exact_thresh': 0.0, 'best_f1': 50.07159100480081, 'best_f1_thresh': 0.0}
