## Perplexity & Attention Weights

In [214]:
import torch
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModelForMaskedLM

In [190]:
# Gendered word lists, aligned by position: man_words[i] pairs with woman_words[i].
man_words = ['man', 'men', 'male', 'he', 'him', 'his']
woman_words = ['woman', 'women', 'female', 'she', 'her', 'hers']
# Symmetric lookup table: every word maps to its opposite-gender counterpart.
pronoun_oppos = {
    word: counterpart
    for male, female in zip(man_words, woman_words)
    for word, counterpart in ((male, female), (female, male))
}

In [191]:
# Tokenizer: stock uncased BERT vocabulary. Model: MLM weights from a local
# checkpoint directory (presumably trained on textbook data — confirm).
# output_attentions=True makes every forward pass also return attention weights.
checkpoint_dir = "bert_mlm/block_512/bert_mlm_textbook"
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained(checkpoint_dir, output_attentions=True)

In [192]:
# Special-token strings and their vocabulary ids, used below for manual masking
# and inspection. NOTE(review): tokenizer.encode adds [CLS]/[SEP] by default,
# so these are not needed just to build valid inputs.
mask_token = tokenizer.mask_token
mask_id = tokenizer.mask_token_id
cls_token = tokenizer.cls_token
cls_id = tokenizer.cls_token_id
sep_token = tokenizer.sep_token
sep_id = tokenizer.sep_token_id

In [193]:
# Probe sentence; encoded as a (1, seq_len) tensor of token ids.
sentence = "The man sat on the mat at work"
inputs = tokenizer(sentence, return_tensors='pt')['input_ids']

In [194]:
# Token positions in the encoded sentence "[CLS] the man sat on the mat at work [SEP]"
# (assuming one wordpiece per word): index 2 is "man", the word that gets masked,
# and index 8 is "work", the token whose attention from the masked slot is
# inspected later. They are defined here, before use, so the notebook runs
# top to bottom on a fresh kernel.
pronoun_idx = 2
interest_idx = 8

tokens = tokenizer.convert_ids_to_tokens(inputs[0])
# Guard against the hard-coded indices drifting onto [CLS]/[SEP] or past the end.
assert 0 < pronoun_idx < len(tokens) - 1 and 0 < interest_idx < len(tokens) - 1, tokens
# Replace the pronoun with [MASK] in both the readable tokens and the model input.
tokens[pronoun_idx] = mask_token
inputs[0][pronoun_idx] = mask_id

In [196]:
# Inference only: put the model in eval mode and disable autograd globally for the
# rest of the notebook, so later .numpy() conversions work without .detach().
model.eval()
# dim=0 because it is applied below to a 1-D vector of vocabulary scores.
softmax = torch.nn.Softmax(dim=0)
# Trailing ';' suppresses the returned context-manager repr in the cell output.
torch.set_grad_enabled(False);

In [197]:
# Position of the [MASK] token in the single input sentence. .item() needs
# exactly one match, so check that explicitly to get a readable failure.
mask_positions = (inputs[0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
assert mask_positions.numel() == 1, f"expected exactly one [MASK] token, found {mask_positions.numel()}"
masked_position = mask_positions.item()
print(masked_position)

2


In [198]:
# Forward pass on the masked sentence.
outputs = model(inputs)
# Per-layer attention weights (returned because output_attentions=True).
attention = outputs.attentions
# MLM prediction scores over the vocabulary for every position (the same tensor
# outputs[0] returns). These are logits, not hidden states; [0] selects the
# single sentence in the batch.
logits = outputs.logits[0]
# Scores for the masked slot only: one entry per vocabulary item.
mask_logits = logits[masked_position]
# Normalise the scores into a probability distribution over the vocabulary.
probs = softmax(mask_logits)

In [199]:
# Compare the model's probability for the masked word against its
# opposite-gender counterpart from pronoun_oppos.
pronoun = 'man'
opp_pronoun = pronoun_oppos[pronoun]
pronoun_id, opp_pronoun_id = tokenizer.convert_tokens_to_ids([pronoun, opp_pronoun])
for word, word_id in ((pronoun, pronoun_id), (opp_pronoun, opp_pronoun_id)):
    print(word, 'probability', probs[word_id].item())

man probability 0.00973216351121664
woman probability 0.005248120985925198


In [200]:
# Highest-probability vocabulary entry for the masked slot.
top_word = probs.argmax()
top_prob = probs[top_word].item()
print('Top Prediction:')
print(tokenizer.decode(top_word), 'probability', top_prob)

Top Prediction:
children probability 0.02098049595952034


In [225]:
# Attention that the masked pronoun (query row pronoun_idx) pays to the token at
# interest_idx (key column), for every head of every layer.
# Assumes each element of `attention` is (batch, heads, query, key) — TODO confirm.
# [0] selects the single sentence explicitly; squeeze() would also collapse any
# other size-1 dimension and silently shift the indexing.
att_weights = np.stack(
    [att_layer[0, :, pronoun_idx, interest_idx].numpy() for att_layer in attention],
    axis=0,
)  # one row per layer, one column per head

## Attention Head View

In [178]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
from bertviz import head_view

In [179]:
# Reload tokenizer and attention-returning MLM so this section can run on its own.
checkpoint_dir = "bert_mlm/block_512/bert_mlm_textbook"
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained(checkpoint_dir, output_attentions=True)

In [180]:
# Encode a short demo sentence and collect per-layer attentions plus the
# matching token strings for bertviz's head view.
inputs = tokenizer("The cat sat on the mat", return_tensors='pt')['input_ids']
outputs = model(inputs)
attention = outputs.attentions
tokens = tokenizer.convert_ids_to_tokens(inputs[0])

In [181]:
# Interactive per-head attention visualisation.
head_view(attention=attention, tokens=tokens)

<IPython.core.display.Javascript object>

## Attention Neuron View

In [182]:
# Import specialized versions of models (that return query/key vectors)
from bertviz.transformers_neuron_view import BertModel, BertTokenizer
from bertviz.neuron_view import show

In [183]:
# The neuron view needs bertviz's own BERT classes (imported above), which
# expose query/key vectors, so the tokenizer and model are loaded again here.
model_type = 'bert'
sentence = "The cat sat on the mat"
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
model = BertModel.from_pretrained('bert_mlm/block_512/bert_mlm_textbook', output_attentions=True)
show(model, model_type, tokenizer, sentence)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>