In [1]:
from transformers import BertTokenizer, BertForSequenceClassification
import torch
from transformers import GlueDataTrainingArguments, GlueDataset
from torch.utils.data.dataloader import DataLoader
from transformers.data.data_collator import default_data_collator


In [2]:
import os

# Root of the GLUE data. Set the GLUE_DATA_DIR environment variable to point elsewhere
# instead of editing this machine-specific default path.
GLUE_DATA_DIR = os.environ.get("GLUE_DATA_DIR", "/home/keyur/medhas/hf211/glue_data")
data_args = GlueDataTrainingArguments(task_name="CoLA", data_dir=os.path.join(GLUE_DATA_DIR, "CoLA", ""))

In [3]:
# One checkpoint name shared by the tokenizer and the classifier so they cannot drift apart.
# NOTE: the warning printed below reports weights that were not initialised from the
# checkpoint (presumably the new classification head).
PRETRAINED_MODEL_NAME = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)
model = BertForSequenceClassification.from_pretrained(PRETRAINED_MODEL_NAME)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [4]:
# CoLA dev split, tokenized with the tokenizer loaded above.
eval_dataset = GlueDataset(data_args, mode="dev", tokenizer=tokenizer)

In [5]:
# Use the first GPU when CUDA is available; n_gpu decides below whether DataParallel is used.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
n_gpu = torch.cuda.device_count()

In [6]:
# Place the model on the chosen device and, with more than one GPU, wrap it for data parallelism.
model = model.to(device)
model = torch.nn.DataParallel(model) if n_gpu > 1 else model

## Visualise the attention weights of each layer and head

In [7]:
from textualheatmap import TextualHeatmap

In [8]:
# Single-segment inputs only: the first token is [CLS] and the last is [SEP].
def insightful_ahead(valid_tokens_attentions, cutoff=0.6):
    """Return True if an attention head is not dominated by [CLS]/[SEP] attention.

    Assumes a square (query x key) attention matrix, already trimmed to the valid
    (non-padding) tokens of a single-segment input, whose first token is [CLS] and
    last token is [SEP].

    The [CLS]/[SEP] weight is the total weight lying in the first/last rows or the
    first/last columns. Every cell is counted once. Summing the two rows and two
    columns separately would count the corner cells and the [CLS]<->[SEP] cells twice.
    That would let the share exceed 1, so even ``cutoff=1`` could reject a head.

    Args:
        valid_tokens_attentions: 2-D array (numpy or torch) of attention weights.
        cutoff: largest allowed fraction of the total weight that may involve [CLS]/[SEP].

    Returns:
        False if that fraction exceeds ``cutoff``, otherwise True.
    """
    total_weight = valid_tokens_attentions.sum()
    # Cells whose query and key are both ordinary tokens; every other cell sits in a
    # [CLS]/[SEP] row or column.
    inner_weight = valid_tokens_attentions[1:-1, 1:-1].sum()
    cls_sep_weight = total_weight - inner_weight
    if cls_sep_weight > cutoff * total_weight:
        return False
    return True

In [9]:
def get_token_attention_data(inputs, tokenizer, attentions, record_index, cutoff=0.6, layers_heads=None):
    """Build TextualHeatmap facets for one record of a batch, one facet per (layer, head).

    Args:
        inputs: batch dict with an "input_ids" tensor. Trailing zero ids are treated as
            padding (assumes the padding id is 0 - confirm for non-BERT tokenizers).
        tokenizer: tokenizer used to turn ids back into token strings.
        attentions: per-layer attention tensors returned by the model with
            output_attentions=True, indexed as attentions[layer][record][head][query, key].
            The numbers of layers and heads are read from these tensors.
        record_index: which record of the batch to visualise.
        cutoff: forwarded to ``insightful_ahead``; heads dominated by [CLS]/[SEP]
            beyond this fraction are skipped.
        layers_heads: optional collection of "layer-head" strings (e.g. "2-8"); when
            given, only those heads are considered.

    Returns:
        (data, tiles). ``data[i]`` is a list with one {"token", "heat"} dict per valid
        token, where "heat" is that token's attention row. ``tiles[i]`` is the facet
        title "H-<layer>-<head>".
    """
    input_ids = inputs["input_ids"][record_index]
    tokens = tokenizer.convert_ids_to_tokens(input_ids.tolist())
    # Position of the last non-zero id; everything after it is padding.
    n_valid = int(input_ids.nonzero().max()) + 1

    data = []
    tiles = []
    for layer, layer_attentions in enumerate(attentions):
        # Copy this record's attentions for the whole layer to the CPU once, trimmed to
        # the valid tokens; shape (heads, n_valid, n_valid).
        record_attentions = (
            layer_attentions[record_index, :, :n_valid, :n_valid].detach().cpu().numpy()
        )
        for ahead, head_attentions in enumerate(record_attentions):
            if layers_heads is not None and "%s-%s" % (layer, ahead) not in layers_heads:
                continue
            if not insightful_ahead(head_attentions, cutoff):
                continue
            data.append([
                {"token": "%s " % tokens[j], "heat": head_attentions[j, :].tolist()}
                for j in range(n_valid)
            ])
            tiles.append("H-%s-%s" % (layer, ahead))
    return data, tiles

## Select specific examples and visualise their attention weights

In [10]:
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler

In [11]:
# Dev-set indices of the examples to inspect. Passing the list directly (instead of a
# SubsetRandomSampler, which shuffles them with no seed) keeps the order fixed.
# record_index 0 is therefore always dataset item 101 and record_index 1 is item 1042.
EXAMPLE_INDICES = [101, 1042]
batch_sampler = BatchSampler(EXAMPLE_INDICES, batch_size=len(EXAMPLE_INDICES), drop_last=False)

In [12]:
# batch_sampler alone decides which dataset indices make up each batch.
data_loader = DataLoader(eval_dataset, batch_sampler=batch_sampler, collate_fn=default_data_collator)

In [13]:
# Take the single batch produced by batch_sampler. `.next()` is not part of the Python 3
# iterator protocol (only `__next__` is); the builtin next() works with any iterator.
inputs = next(iter(data_loader))

In [14]:
# eval() disables dropout. BERT's returned attention probabilities (output_attentions=True)
# are taken after its attention dropout. In train() mode the visualised weights would be
# randomly zeroed and rescaled, and rows would no longer sum to 1. Gradients still flow
# in eval mode, so the later loss.backward() is unaffected.
model.eval()
model.zero_grad()

In [15]:
# Move every tensor in the batch onto the model's device; non-tensor entries are left untouched.
inputs.update({key: value.to(device) for key, value in inputs.items() if isinstance(value, torch.Tensor)})

In [16]:
# Forward pass that additionally returns the per-layer attention weights.
outputs = model(output_attentions=True, **inputs)



In [17]:
# With DataParallel (n_gpu > 1) the loss comes back per replica, so reduce it to a scalar first.
loss = outputs[0].mean() if n_gpu > 1 else outputs[0]
loss.backward()

In [18]:
# Index 2 of the forward outputs holds the attention weights requested with
# output_attentions=True (index 0 is the loss used above).
attentions = outputs[2]

In [19]:
# Heads of record 0 that pass the 0.65 cutoff. The next cell overwrites data/tiles,
# so show the facet titles here rather than computing them silently.
data, tiles = get_token_attention_data(inputs, tokenizer, attentions, record_index=0, cutoff=0.65)
tiles

In [20]:
# Heatmap of every head of record 0 that passes the 0.7 cutoff.
data, tiles = get_token_attention_data(
    inputs, tokenizer, attentions, record_index=0, cutoff=0.7
)
heatmap = TextualHeatmap(show_meta=False, facet_titles=tiles)
heatmap.set_data(data)

<IPython.core.display.Javascript object>

In [21]:
# Heatmap of every head of record 1 that passes the 0.7 cutoff.
data, tiles = get_token_attention_data(
    inputs, tokenizer, attentions, record_index=1, cutoff=0.7
)
heatmap = TextualHeatmap(show_meta=False, facet_titles=tiles)
heatmap.set_data(data)

<IPython.core.display.Javascript object>

In [26]:
# Hand-picked heads. Record 1 is shown first (data_0/heatmap_0), record 0 second (data_1/heatmap_1).
# NOTE(review): cutoff=1 presumably intends to keep both selected heads regardless of
# [CLS]/[SEP] weight - confirm both appear in tiles_0 / tiles_1.
selected_heads = ['2-8', '3-5']

data_0, tiles_0 = get_token_attention_data(
    inputs, tokenizer, attentions, record_index=1, cutoff=1, layers_heads=selected_heads
)
heatmap_0 = TextualHeatmap(show_meta=False, facet_titles=tiles_0)
heatmap_0.set_data(data_0)

data_1, tiles_1 = get_token_attention_data(
    inputs, tokenizer, attentions, record_index=0, cutoff=1, layers_heads=selected_heads
)
heatmap_1 = TextualHeatmap(show_meta=False, facet_titles=tiles_1)
heatmap_1.set_data(data_1)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>