<a href="https://colab.research.google.com/github/dar-tau/nlp-experiments/blob/master/introbert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Initialization

In [None]:
# Install the Hugging Face libraries imported below (`transformers`, `datasets`).
# NOTE(review): no versions are pinned; the saved outputs further down reference
# datasets 1.0.1, so a fresh install may pull incompatible releases - confirm and pin.
!pip install transformers datasets
# !pip install simpletransformers

In [36]:
%cd /content
!mkdir data
%cd /content/data
!wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json

/content
--2020-09-19 15:56:17--  https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json
Resolving rajpurkar.github.io (rajpurkar.github.io)... 185.199.108.153, 185.199.109.153, 185.199.110.153, ...
Connecting to rajpurkar.github.io (rajpurkar.github.io)|185.199.108.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 42123633 (40M) [application/json]
Saving to: ‘train-v2.0.json’


2020-09-19 15:56:17 (153 MB/s) - ‘train-v2.0.json’ saved [42123633/42123633]

/content/data: Scheme missing.
FINISHED --2020-09-19 15:56:17--
Total wall clock time: 0.4s
Downloaded: 1 files, 40M in 0.3s (153 MB/s)


In [47]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from tqdm import tqdm_notebook as tqdm
import json
import os

import re
import torch
from torch.utils.data import Dataset, DataLoader
import datasets

from transformers import AutoModelForQuestionAnswering, AutoTokenizer
from transformers import pipeline
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup

from transformers.data.processors.squad import SquadV2Processor, squad_convert_examples_to_features

In [3]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU so the
# notebook still runs (slowly) on runtimes without CUDA instead of crashing
# on the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def torchTokenize(*args):
  """Tokenize `args` with the notebook-global `tokenizer` and return PyTorch tensors.

  All positional arguments are forwarded unchanged; truncation and padding are
  always switched on and `return_tensors` is fixed to 'pt'.
  """
  tokenizerOptions = dict(truncation = True, padding = True, return_tensors = 'pt')
  return tokenizer(*args, **tokenizerOptions)


def setModelHooks(model):
  """Attach forward hooks that record the output of every attention sub-module.

  A sub-module qualifies when its dotted name ends in `layer.<N>.attention`.
  After each forward pass `model.guyData[N]` holds element 0 of that module's
  output tuple. Hook handles are kept in `model.guyHooks`; calling this function
  again first removes the handles registered by the previous call.
  """
  attentionLayerRegex = r'^(.+\.)*layer\.(\d+)\.attention$'

  def _makeRecorder(layerIdx):
    # Assumes a single attention module per layer number; a second module with
    # the same number would overwrite the first one's entry.
    def _record(module, inputs, output):
      assert type(output) == tuple and len(output) == 1
      model.guyData[layerIdx] = output[0]
    return _record

  if hasattr(model, 'guyHooks'):
    print("Removing existing hooks!")
    for oldHandle in model.guyHooks:
      oldHandle.remove()

  model.guyData = {}
  handles = []
  for name, module in model.named_modules():
    found = re.match(attentionLayerRegex, name)
    if found is None:
      continue
    layerIdx = int(found.group(2))
    handles.append(module.register_forward_hook(_makeRecorder(layerIdx)))
  model.guyHooks = handles

def dictToDevice(d, device):
  """Return a new plain dict with every tensor value of `d` moved to `device`.

  Values that are not `torch.Tensor` instances are carried over untouched, and
  the input mapping itself is not modified.
  """
  return {key: (val.to(device) if isinstance(val, torch.Tensor) else val)
          for key, val in d.items()}

In [4]:
class IntrobertDataset(Dataset):
  """Dataset that turns source items into "introspection" question samples.

  `func` extracts the context text from one item of `srcDataset`. Each sample
  pairs that context with a question about a randomly chosen layer and with a
  callable that reads the answer out of the hook data stored on the model.
  `setModel` must be called before any item is requested.
  """

  def __init__(self, srcDataset, func, device = device):
    self.ds, self.func, self.device = srcDataset, func, device
    self.isModelSet = False

  def setModel(self, model, nLayers):
    """Record the model and the number of layers that questions may refer to."""
    self.model, self.nLayers = model, nLayers
    self.isModelSet = True

  def __getitem__(self, i):
    """Build the sample for source item `i`.

    Returns a dict with the raw `context`, its tokenized `inputs` (moved to
    `self.device`), the generated `question`, and an `introspection` callable.
    """
    assert self.isModelSet
    context = self.func(self.ds[i])
    inputs = dictToDevice(torchTokenize(context), self.device)

    # A fresh layer is drawn on every access, so the same index can yield
    # different questions.
    chosenLayer = np.random.choice(self.nLayers)
    question = "what is the most attended word in layer {}?".format(chosenLayer)

    def introspection(model):
      # Reads whatever the hooks stored for the chosen layer on the model's most
      # recent forward pass: sum over the last dim, then argmax over the new last dim.
      # presumably the stored tensor is (batch, seq, hidden), giving one token
      # index per batch item - TODO confirm
      recorded = model.guyData[chosenLayer]
      return recorded.sum(dim = -1).argmax(dim = -1)

    return dict(context = context, inputs = inputs,
                question = question, introspection = introspection)

  def __len__(self):
    return len(self.ds)

## Main

In [5]:
# Load the `distilbert-base-cased-distilled-squad` checkpoint and its tokenizer.
# `tokenizer` is the global read by torchTokenize; `model` and `optimizer` are the
# globals used by the training loop.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased-distilled-squad")
model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-cased-distilled-squad", return_dict = True)
# Move the weights to `device` before building the optimizer over model.parameters().
model.to(device)
optimizer = AdamW(model.parameters(), lr = 5e-5)
# Register the forward hooks that fill model.guyData with per-layer attention outputs.
setModelHooks(model)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=473.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=213450.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=260793700.0, style=ProgressStyle(descri…




In [43]:
# Feature-conversion settings; they are forwarded to
# squad_convert_examples_to_features in the next cell.
max_seq_length = 384
doc_stride = 128
max_query_length = 64

# Parse the SQuAD v2.0 training examples from /content/data, the directory the
# download cell saved train-v2.0.json into.
squadExamples = SquadV2Processor().get_train_examples("/content/data")

100%|██████████| 442/442 [00:44<00:00,  9.91it/s]


In [None]:
# Convert the parsed SQuAD examples into tokenized model features.
# With return_dataset = 'pt' the call returns a (features, dataset) pair. Bind both
# to names so the result of this long-running conversion is kept for later cells
# instead of being discarded (and dumped into the cell output).
squadFeatures, squadDataset = squad_convert_examples_to_features(
    squadExamples, tokenizer = tokenizer,
    max_seq_length = max_seq_length,
    max_query_length = max_query_length,
    doc_stride = doc_stride, is_training = True, return_dataset = 'pt')




convert squad examples to features:  30%|██▉       | 38497/130319 [04:06<11:27, 133.51it/s][A[A[A


convert squad examples to features:  30%|██▉       | 38529/130319 [04:07<11:15, 135.89it/s][A[A[A


convert squad examples to features:  30%|██▉       | 38561/130319 [04:07<11:02, 138.47it/s][A[A[A


convert squad examples to features:  30%|██▉       | 38593/130319 [04:07<10:52, 140.68it/s][A[A[A


convert squad examples to features:  30%|██▉       | 38625/130319 [04:07<10:42, 142.63it/s][A[A[A


convert squad examples to features:  30%|██▉       | 38657/130319 [04:07<11:15, 135.67it/s][A[A[A


convert squad examples to features:  30%|██▉       | 38689/130319 [04:08<11:07, 137.30it/s][A[A[A


convert squad examples to features:  30%|██▉       | 38721/130319 [04:08<11:47, 129.54it/s][A[A[A


convert squad examples to features:  30%|██▉       | 38753/130319 [04:08<12:07, 125.88it/s][A[A[A


convert squad examples to features:  30%|██▉       | 38785/130319 [04:

In [7]:
# Wrap the training split so each item yields a context plus an introspection question.
# NOTE(review): `squad` is not assigned in any visible cell; the saved output below
# suggests it came from loading the `squad` dataset with the `datasets` library -
# confirm and restore that statement so the notebook runs on a fresh kernel.
# NOTE(review): the lambda reads x['text']; verify that key exists on the source items.
dataset = IntrobertDataset(squad['train'], lambda x: x['text'])
# 6 is passed as nLayers - presumably the number of transformer layers in the loaded
# DistilBERT checkpoint; TODO confirm against the model config.
dataset.setModel(model, 6)

Checking /root/.cache/huggingface/datasets/1825be4101447d340c1153faa326883028c67acd5c49bbf76ba67648fb87c216.85f43de978b9b25921cb78d7a2f2b350c04acdbaedb9ecb5f7101cd7c0950e68.py for additional imports.
Found main folder for dataset https://raw.githubusercontent.com/huggingface/datasets/1.0.1/datasets/squad/squad.py at /root/.cache/huggingface/modules/datasets_modules/datasets/squad
Found specific version folder for dataset https://raw.githubusercontent.com/huggingface/datasets/1.0.1/datasets/squad/squad.py at /root/.cache/huggingface/modules/datasets_modules/datasets/squad/1244d044b266a5e4dbd4174d23cb995eead372fbca31a03edc3f8a132787af41
Found script file from https://raw.githubusercontent.com/huggingface/datasets/1.0.1/datasets/squad/squad.py to /root/.cache/huggingface/modules/datasets_modules/datasets/squad/1244d044b266a5e4dbd4174d23cb995eead372fbca31a03edc3f8a132787af41/squad.py
Found dataset infos file from https://raw.githubusercontent.com/huggingface/datasets/1.0.1/datasets/squad/d

In [None]:
# NOTE(review): leftover scratch cell. `SquadExample` is not among the names imported
# in the visible import cell, so this presumably raises NameError on a fresh kernel -
# confirm and delete if unused.
SquadExample()

In [74]:
# Fine-tune the QA model to answer questions about its own attention outputs.
# An "epoch" here is the first `total` items of `dataset`; every epoch restarts from
# item 0, but the questioned layer is re-drawn at random on each access.
total = 5000
n_epochs = 10
num_training_steps = total * n_epochs
num_warmup_steps = total  # warm up over the first epoch's worth of steps
losses = []  # losses[e] collects the per-step loss values of epoch e
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)


model.train()

for e in range(n_epochs):
  losses.append([])
  t = tqdm(dataset, total = total)
  acc_sum1 = 0
  acc_sum2 = 0
  for i, data in enumerate(t):
    if i >= total:
      break
    context = data['context']
    inputs = data['inputs']
    question = data['question']
    introspection = data['introspection']

    # Pass 1 (context only, eval mode): run the model purely so the forward hooks
    # fill model.guyData. No gradients are needed for this pass, so skip building
    # the autograd graph; the stored activations are only read to derive the target.
    model.eval()
    with torch.no_grad():
      model(**inputs)

    model.train()
    model.zero_grad()
    # Read the target BEFORE the second forward pass: that pass re-triggers the
    # hooks and overwrites model.guyData.
    res = introspection(model).detach()

    # Pass 2 (context + question, train mode): the same index is used as both the
    # start and the end position label.
    inputs = torchTokenize(context, question)
    inputs = dictToDevice(inputs, device)
    outputs = model(**inputs, start_positions = res,
                    end_positions = res)

    loss = outputs.loss
    losses[e].append(loss.item())

    acc_sum1 += int((res[0] == outputs.start_logits.argmax()).item())
    acc_sum2 += int((res[0] == outputs.end_logits.argmax()).item())

    # `i` is zero-based, so i + 1 items have been scored so far. Dividing by `i`
    # would report 0 after the first item and could exceed 1.0 afterwards.
    acc1 = acc_sum1 / (i + 1)
    acc2 = acc_sum2 / (i + 1)

    t.set_postfix_str("Loss: {:.2f}, Acc1: {:.2f}, Acc2: {:.2f}".format(loss.item(), acc1, acc2))
    loss.backward()
    optimizer.step()
    scheduler.step()
    

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  del sys.path[0]


HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))

KeyboardInterrupt: ignored