In [37]:
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import GlueDataTrainingArguments, GlueDataset
from torch.utils.data.dataloader import DataLoader
from transformers.data.data_collator import default_data_collator
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
from textualheatmap import TextualHeatmap

In [38]:
data_args = GlueDataTrainingArguments(task_name="CoLA", data_dir="/home/keyur/medhas/hf211/glue_data/CoLA/")

In [39]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [40]:
eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")

In [41]:
#data_loader = DataLoader(eval_dataset, collate_fn=default_data_collator, 
#                         sampler=SubsetRandomSampler([1, 302, 101, 565, 600, 680, 799, 1042]), batch_size=8)
data_loader = DataLoader(eval_dataset, collate_fn=default_data_collator, 
                         sampler=SubsetRandomSampler([101, 565, 1042]), batch_size=8)

In [42]:
inputs = data_loader.__iter__().next()

In [43]:
model.zero_grad()
model.train()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [44]:
outputs = model(**inputs, output_attentions=True)
loss, logit, attentions=outputs[:3]

In [45]:
loss.backward()

In [46]:
saliency_data = model.bert.embeddings.word_embeddings.weight.grad.data[inputs["input_ids"]].mean(dim=2)
saliency_data = np.array(saliency_data.detach().cpu())
norm_saliency_data = np.absolute(saliency_data)/np.absolute(saliency_data).sum(axis=1)[:, np.newaxis]
data = []
tiles = []
for record_index in range(inputs["input_ids"].shape[0]):
    record = inputs["input_ids"][record_index]
    tokens = tokenizer.convert_ids_to_tokens(record)
    valid_tokens = record.nonzero().squeeze().max()
    record_norm_sdata = norm_saliency_data[record_index][:valid_tokens+1]
    record_data = []
    for j in range(valid_tokens+1):
        record_data.append({
            'token': "%s "%tokens[j],
            'heat': list(map(lambda x: float(x)*3, record_norm_sdata)),
        })
    data.append(record_data) 
    tiles.append("Sal: %s"%record_index)

In [35]:
heatmap = TextualHeatmap(facet_titles = tiles, show_meta=False)
heatmap.set_data(data)
heatmap.highlight(0)