<a href="https://colab.research.google.com/github/dar-tau/nlp-experiments/blob/master/introbert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Initialization

In [1]:
!pip install transformers datasets
# !pip install simpletransformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/ae/05/c8c55b600308dc04e95100dc8ad8a244dd800fe75dfafcf1d6348c6f6209/transformers-3.1.0-py3-none-any.whl (884kB)
[K     |████████████████████████████████| 890kB 5.9MB/s 
[?25hCollecting datasets
[?25l  Downloading https://files.pythonhosted.org/packages/8e/f2/d213673d76ee56d907e462e6c144f1418368d35e6a9221799403116516de/datasets-1.0.1-py3-none-any.whl (1.8MB)
[K     |████████████████████████████████| 1.8MB 20.0MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 38.7MB/s 
Collecting sentencepiece!=0.1.92
[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
[K     |███████████████████████

In [37]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from tqdm import tqdm_notebook as tqdm
import json
import os

import re
import torch
from torch.utils.data import Dataset, DataLoader
import datasets

from transformers import AutoModelForQuestionAnswering, AutoTokenizer
from transformers import pipeline
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup

In [19]:
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def torchTokenize(*args):
  """Tokenize a single text, or a (text, text_pair), with the global `tokenizer`.

  Positional arguments are forwarded unchanged. The call always requests
  truncation, padding, and PyTorch tensors (`return_tensors='pt'`).
  Relies on the module-level `tokenizer` created in the model-setup cell.
  """
  tokenizerOptions = dict(truncation = True, padding = True, return_tensors = 'pt')
  return tokenizer(*args, **tokenizerOptions)


def setModelHooks(model):
  """Attach forward hooks that capture the output of each attention sub-module.

  Every sub-module whose qualified name matches ``...layer.<N>.attention``
  gets a forward hook that stores the first element of its output tuple in
  ``model.guyData[N]``. The hook handles are kept in ``model.guyHooks`` so a
  later call removes them first. Calling this again therefore replaces the
  hooks (instead of stacking them) and resets ``model.guyData``.

  Args:
    model: a ``torch.nn.Module``. It gains two attributes: ``guyData``, a dict
      from layer index to captured tensor, and ``guyHooks``, a list of hook handles.
  """
  attentionLayerPattern = re.compile(r'^(.+\.)*layer\.(\d+)\.attention$')

  def _guyAttentionHook(name):
    layerNum = int(attentionLayerPattern.match(name).group(2))
    # Assumes there's only one attention module per layer number; a second
    # match with the same number would overwrite the first one's entry.
    def _myHook(m, inp, outp):
      # Keep only the first element of the output tuple. Extra elements
      # (presumably attention weights when output_attentions=True — confirm
      # for the model in use) are ignored instead of tripping the assert.
      assert isinstance(outp, tuple) and len(outp) >= 1, type(outp)
      model.guyData[layerNum] = outp[0]

    return _myHook

  if hasattr(model, 'guyHooks'):
    print("Removing existing hooks!")
    for hook in model.guyHooks:
      hook.remove()

  model.guyData = {}
  model.guyHooks = [module.register_forward_hook(_guyAttentionHook(name))
                    for name, module in model.named_modules()
                    if attentionLayerPattern.match(name) is not None]

def dictToDevice(d, device):
  """Return a new dict with every tensor value of mapping `d` moved to `device`.

  Non-tensor values are carried over unchanged (same objects). `d` itself is
  not modified; keys keep their original order.
  """
  return {key: (value.to(device) if isinstance(value, torch.Tensor) else value)
          for key, value in d.items()}

In [27]:
class IntrobertDataset(Dataset):
  """Wrap a source dataset so each item carries an introspection question.

  For item ``i`` the context text is ``func(srcDataset[i])``. Each item is a
  dict bundling:
    - ``context``: that text;
    - ``inputs``: the tokenized context, with tensors moved to ``device``;
    - ``question``: a fixed question about layer ``layer``;
    - ``introspection(model)``: a callable that reads the answer from
      ``model.guyData[layer]``. This assumes a forward pass over ``inputs``
      has just run with the hooks from ``setModelHooks`` installed.

  NOTE(review): the returned index is relative to the context-only encoding.
  The training loop re-tokenizes ``(context, question)`` with truncation, and
  for long contexts that keeps fewer context tokens. The index may then no
  longer point into the context — confirm/guard.

  Args:
    srcDataset: any indexable object supporting ``len()``.
    func: maps a raw example to its context string.
    device: where tokenized tensors are placed. Defaults to the global
      ``device`` value at class-definition time.
    layer: ``model.guyData`` key used for both the question text and the
      introspection target. The default, 3, reproduces the original behaviour.
  """
  def __init__(self, srcDataset, func, device = device, layer = 3):
    self.ds = srcDataset
    self.func = func
    self.device = device
    self.layer = layer

  def __getitem__(self, i):
    context = self.func(self.ds[i])

    inputs = torchTokenize(context)
    inputs = dictToDevice(inputs, self.device)

    # Build the question and the target from the same index so they cannot drift apart.
    layer = self.layer
    question = "what is the most attended word in layer {}?".format(layer)

    def introspection(model):
      # Sum the captured output over its last dim, then argmax over the
      # next-to-last dim. For a (batch, seq, hidden) output (assumed — confirm)
      # this gives one token position per batch row.
      return model.guyData[layer].sum(dim = -1).argmax(dim = -1)

    return {'context': context, 'inputs': inputs,
            'question': question, 'introspection': introspection}

  def __len__(self):
    return len(self.ds)

## Main

In [38]:
# Load the `distilbert-base-cased-distilled-squad` QA checkpoint and its tokenizer.
# Move the model to `device` before building the optimizer over its parameters,
# then install the attention-capturing hooks.
MODEL_NAME = "distilbert-base-cased-distilled-squad"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME, return_dict = True).to(device)
optimizer = AdamW(model.parameters(), lr = 1e-5)
setModelHooks(model)

In [39]:
# IMDB reviews are used only as contexts: each item's context is the review's 'text' field.
imdb = datasets.load_dataset("imdb")
dataset = IntrobertDataset(imdb['train'], func = lambda example: example['text'])
# Items hold a Python closure (`introspection`) plus tensors already placed on
# `device`, so the training loop indexes the dataset directly. Such items would
# presumably not survive a DataLoader's default collation.

Checking /root/.cache/huggingface/datasets/4d2b2997408b65402b80ecde9f2710be3b9edec2632497552299709859efe061.c39acffee84b8d7965ae2e5269ad438ebdb9a40b0607f38a5fdd81b1f8607864.py for additional imports.
Found main folder for dataset https://raw.githubusercontent.com/huggingface/datasets/1.0.1/datasets/imdb/imdb.py at /root/.cache/huggingface/modules/datasets_modules/datasets/imdb
Found specific version folder for dataset https://raw.githubusercontent.com/huggingface/datasets/1.0.1/datasets/imdb/imdb.py at /root/.cache/huggingface/modules/datasets_modules/datasets/imdb/90099cb476936b753383ba2ae6ab2eae419b2e87f71cd5189cb9c8e5814d12a3
Found script file from https://raw.githubusercontent.com/huggingface/datasets/1.0.1/datasets/imdb/imdb.py to /root/.cache/huggingface/modules/datasets_modules/datasets/imdb/90099cb476936b753383ba2ae6ab2eae419b2e87f71cd5189cb9c8e5814d12a3/imdb.py
Found dataset infos file from https://raw.githubusercontent.com/huggingface/datasets/1.0.1/datasets/imdb/dataset_info

In [None]:
total = 1000            # examples per epoch, taken from the start of the split
n_epochs = 10
num_training_steps = total * n_epochs
num_warmup_steps = total
losses = []             # losses[e] holds the per-step losses of epoch e
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)


model.train()

for e in range(n_epochs):
  losses.append([])
  # Index the dataset directly so no extra item is fetched/tokenized past `total`.
  t = tqdm(range(min(total, len(dataset))))
  for i in t:
    data = dataset[i]
    context = data['context']
    inputs = data['inputs']
    question = data['question']
    introspection = data['introspection']

    # Pass 1: run the model on the bare context so the hooks record the
    # attention output that defines this example's target. No gradients are
    # needed here, so skip building an autograd graph.
    model.eval()
    with torch.no_grad():
      model(**inputs)
    res = introspection(model).detach()

    # Pass 2: ask the question and train the QA head to answer with that
    # single token (start and end are the same position).
    # NOTE(review): `res` indexes the context-only encoding. With truncation,
    # the (context, question) pair keeps fewer context tokens, so for long
    # reviews `res` may point past the context — confirm/guard.
    model.train()
    model.zero_grad()
    inputs = torchTokenize(context, question)
    inputs = dictToDevice(inputs, device)
    outputs = model(**inputs, start_positions = res,
                    end_positions = res)

    loss = outputs.loss
    lossValue = loss.item()
    losses[e].append(lossValue)
    t.set_postfix_str("Loss: {}".format(lossValue))
    loss.backward()
    optimizer.step()
    # Advance the warmup/decay schedule. Without this the LambdaLR keeps its
    # step-0 warmup factor of 0, so the learning rate stays 0 and
    # optimizer.step() never changes the weights.
    scheduler.step()

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  del sys.path[0]


HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))