## Extract Perplexity & Attention Weights

In [34]:
import copy

In [1]:
import torch
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModelForMaskedLM

In [2]:
import utilities

In [15]:
# Shared tokenizer and masked-LM used by every cell below.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained(
    "bert_mlm/block_512/bert_mlm_textbook",
    output_attentions=True,  # also return the per-layer attention weights
).eval()  # evaluation mode (e.g. dropout off); eval() returns the model itself
# Softmax over a 1-D score vector, used to turn logits into vocabulary probabilities.
softmax = torch.nn.Softmax(dim=0)
# Only forward passes are run in this notebook, so disable autograd globally.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fb639f0d710>

In [4]:
# Special tokens (and their vocabulary ids) used to frame and mask BERT inputs.
# Contexts fed to BERT must begin with [CLS]; whether a trailing [SEP] is
# required here is unconfirmed (TODO).
mask_token = tokenizer.mask_token
mask_id = tokenizer.mask_token_id
cls_token = tokenizer.cls_token
cls_id = tokenizer.cls_token_id
sep_token = tokenizer.sep_token
sep_id = tokenizer.sep_token_id

In [5]:
# Relevant paths.
# The suffix encodes the experiment setting: _<window size>_<max distance>.
config = "_128_50"
results_folder = f"temporal_attn_examples{config}/"
# Result files for the low / mid / high normalised-probability buckets.
low_prob_filename, mid_prob_filename, high_prob_filename = (
    "pr_0.25.txt",
    "pr_0.45_0.55.txt",
    "pr_0.9999.txt",
)

In [6]:
# Map each gendered word to its opposite-gender counterpart, in both directions.
# The lists are index-aligned: man_words[k] pairs with woman_words[k].
man_words = ['man', 'men', 'male', 'he', 'him', 'his']
woman_words = ['woman', 'women', 'female', 'she', 'her', 'hers']
pronoun_oppos = {
    word: counterpart
    for male, female in zip(man_words, woman_words)
    for word, counterpart in ((male, female), (female, male))
}

In [7]:
# Pronoun sets for fast membership tests. They are derived from man_words /
# woman_words so the two definitions cannot drift apart.
man_words_set = set(man_words)
woman_words_set = set(woman_words)

In [8]:
def prepare_mask(sentence_data):
    """Return copies of a context window with its gender word replaced by [MASK].

    Args:
        sentence_data: 5-tuple ``(tokens_tensor, segments_tensor, tokenized_text,
            sentence_info, norm_prob)`` as read by ``utilities.read_context_windows``,
            where ``sentence_info`` is ``(gender_index, query_index, gender_word,
            query_word)``.

    Returns:
        ``(tokens_tensor, tokenized_text)``: a new id tensor and a new token list,
        both masked at ``gender_index``. The caller's data is left unmodified.
    """
    raw_ids, _segments, tokenized_text, sentence_info, _norm_prob = sentence_data
    gender_index, _query_index, _gender_word, _query_word = sentence_info
    # Copy both containers so the caller's (possibly reused) data stays unmasked.
    tokens_tensor = torch.as_tensor(raw_ids).clone()
    tokenized_text = list(tokenized_text)
    tokenized_text[gender_index] = mask_token
    # NOTE(review): indexes row 0, i.e. assumes ids are shaped (1, seq_len) — confirm.
    tokens_tensor[0][gender_index] = mask_id
    return tokens_tensor, tokenized_text

In [16]:
def get_attention_and_probs(inputs, masked_position):
    """Run the masked LM and return its attentions plus the vocabulary distribution at one position.

    Args:
        inputs: token-id tensor passed directly to the global ``model``; presumably
            shaped (1, seq_len) — ``squeeze(0)`` below drops that batch axis.
        masked_position: index of the token whose prediction we want.

    Returns:
        ``(attention, probs)``. ``attention`` is ``outputs.attentions``, one tensor per
        layer, present because the model was loaded with ``output_attentions=True``.
        ``probs`` is the softmax over the vocabulary logits at ``masked_position``.
    """
    outputs = model(inputs)
    # The masked-LM head's output is prediction logits (one score per vocabulary
    # entry), not a hidden state; read it by name rather than by position.
    logits = outputs.logits.squeeze(0)
    probs = softmax(logits[masked_position])
    return outputs.attentions, probs

In [10]:
def get_norm_prob(probs, gender_word):
    """Return the probability mass on ``gender_word``'s gender, normalised over both genders.

    Sums ``probs`` over every token in ``man_words_set`` and in ``woman_words_set``.
    It then returns ``same / (same + opposite)``, where "same" is the set containing
    ``gender_word``. Any word not in ``man_words_set`` is treated as female.

    Args:
        probs: probability vector indexed by vocabulary token id (e.g. the
            ``probs`` returned by ``get_attention_and_probs``).
        gender_word: the original (unmasked) gendered word.

    Returns:
        A float, in [0, 1] for non-negative ``probs``. A value > 0.5 means the model
        prefers the correct gender. Returns ``nan`` if neither gender receives any
        probability mass (the ratio is undefined).
    """
    def mass(words):
        # Accumulate as Python floats, exactly as a running `+=` over .item() would.
        return sum(probs[tokenizer.convert_tokens_to_ids(w)].item() for w in words)

    man_prob = mass(man_words_set)
    woman_prob = mass(woman_words_set)
    if gender_word in man_words_set:
        gender_prob, opp_gender_prob = man_prob, woman_prob
    else:
        gender_prob, opp_gender_prob = woman_prob, man_prob

    total = gender_prob + opp_gender_prob
    if total == 0:
        return float("nan")
    return gender_prob / total

## Attention Head View

`head_view` (from bertviz) renders an interactive visualization of the attention weights for each layer and head.

In [21]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
from bertviz import head_view

In [None]:
# Needs `attention` and `tokens_tensor` from a forward pass (e.g. the Experiment 1 cell).
# Label the attention axes with the tokens that were actually fed to the model.
tokens = tokenizer.convert_ids_to_tokens(tokens_tensor[0].tolist())
head_view(attention, tokens, layer=2, heads=[3, 5])

## Dissecting Attention Examples
CURRENT TO-DOs:
- Figure out how to get the attending pair with the highest weight 
- Does it make sense to compare weights between low_prob and high_prob?
- ^First Experiment


Thinking through the files:
- In low_prob, we predicted the wrong gender with high (ish) "confidence"
- In high_prob, we predicted the right gender with high (ish) "confidence"
- In mid_prob, we weren't really sure which gender to put there

We want to do two types of experiments:
1. A perplexity study: this looks at how much BERT is attending from the MASK (gender) to the interest word when predicting the mask. It would be interesting if high attention occurs when BERT predicts the wrong or right gender with high "confidence."
2. A follow-up to cosine similarity: this looks at how much BERT is attending from the interest word to the gender word (UNMASKED). If it is high for workers but not for work, for example, that could mean workers is more gendered. 

Random thoughts on how to deal with multiple layers and heads:
- Take lowest perplexities and figure out which word has the highest attention with the masked word across each layer?
- Also repeat this for the highest perplexities and check whether anything differs (e.g., with high perplexity, does the masked word no longer attend most strongly to the interest word?).
- Take the examples with higher normalized probabilities (BERT predicts the right gender!) and check that there are no coreferences or syntactic clues, i.e. that either pronoun could be used. Then check the attention heads. Compare with examples that have lower normalized probabilities in the same situation. This could reveal that certain interest words are non-gendered.
- Layer with the max attention weight?

In [75]:
DISPLAY_CONTEXT = 5


def visualize_attention(attention, tokenized_text, gender_index, query_index, context=DISPLAY_CONTEXT):
    """Show a bertviz head view cropped to the span around the gender and query words.

    The crop runs from ``context`` tokens before the earlier of the two words to
    ``context`` tokens after the later one (inclusive), clipped to the text bounds.

    Args:
        attention: per-layer attention tensors; the last two axes are cropped
            (presumably (batch, heads, seq_len, seq_len) — TODO confirm).
        tokenized_text: tokens aligned with the attention's sequence axes.
        gender_index: position of the gender word.
        query_index: position of the query / interest word.
        context: number of extra tokens to keep on each side.
    """
    start = max(0, min(gender_index, query_index) - context)
    # +1 because the slice end is exclusive: keep `context` tokens after the later
    # word, matching the `context` tokens kept before the earlier one.
    end = min(len(tokenized_text), max(gender_index, query_index) + context + 1)
    truncated_attention = tuple(layer[:, :, start:end, start:end] for layer in attention)
    head_view(truncated_attention, tokenized_text[start:end])
    
    

In [77]:
# Set pronoun_to_interest to True to find the weights with which the MASKed pronoun
# attends to the query word, and to False for the reverse direction.
def get_attending_weights(attention, pronoun_idx, interest_idx, pronoun_to_interest):
    """Collect one attention weight per (layer, head) between two token positions.

    Args:
        attention: sequence of per-layer attention tensors, each presumably shaped
            (1, num_heads, seq_len, seq_len) — TODO confirm batch size is always 1.
        pronoun_idx: position of the gender word.
        interest_idx: position of the query / interest word.
        pronoun_to_interest: if True, read ``[pronoun_idx, interest_idx]``
            (row = attending token); otherwise read ``[interest_idx, pronoun_idx]``.

    Returns:
        ``np.ndarray`` of shape (num_layers, num_heads).
    """
    if pronoun_to_interest:
        src, dst = pronoun_idx, interest_idx
    else:
        src, dst = interest_idx, pronoun_idx
    att_weights = []
    for att_layer in attention:
        # Drop only the batch axis. A bare squeeze() would also collapse a
        # single-head or single-token axis and break the indexing below.
        layer = att_layer.squeeze(0)
        att_weights.append(layer[:, src, dst].detach().cpu().numpy())
    return np.stack(att_weights, axis=0)

In [78]:
# EXPERIMENT 1: attention from the masked gender word to the query word, on the
# low-probability examples. Visualize the first pair whose words are close together.
data = utilities.read_context_windows(results_folder + low_prob_filename)
for sentence_data in data:
    _, _, _, sentence_info, norm_prob = sentence_data
    gender_index, query_index, gender_word, query_word = sentence_info
    # Only visualize if the distance between the words is small. Check this before
    # the forward pass so the model is not run on examples that are skipped anyway.
    if abs(gender_index - query_index) >= 10:
        continue
    tokens_tensor, tokenized_text = prepare_mask(sentence_data)
    attention, probs = get_attention_and_probs(tokens_tensor, gender_index)
    print(gender_word, query_word, norm_prob)
    # Attending weights at each layer and head: a (num_layers, num_heads) array.
    attending_weights = get_attending_weights(attention, gender_index, query_index, True)
    visualize_attention(attention, tokenized_text, gender_index, query_index)
    # Possible sanity check: norm_prob should match get_norm_prob(probs, gender_word)
    # (compare with a float tolerance rather than ==).
    break
    

men power 0.1975165175904562


<IPython.core.display.Javascript object>

In [None]:
# EXPERIMENT 2: how strongly the interest word attends to the (unmasked) gender word,
# collected separately for the query words "work" and "workers".
work_weights = []
workers_weights = []
data = utilities.read_context_windows(results_folder + low_prob_filename)
for sentence_data in data:
    tokens_tensor, segments_tensor, tokenized_text, sentence_info, norm_prob = sentence_data
    gender_index, query_index, gender_word, query_word = sentence_info
    if query_word not in ("work", "workers"):
        continue
    # The input stays unmasked here. Convert the stored ids to a tensor before the
    # forward pass, as prepare_mask does.
    tokens_tensor = torch.as_tensor(tokens_tensor)
    attention, probs = get_attention_and_probs(tokens_tensor, gender_index)
    attending_weights = get_attending_weights(attention, gender_index, query_index, False)
    if query_word == "work":
        work_weights.append(attending_weights)
    else:
        workers_weights.append(attending_weights)
    

## Attention Neuron View

In [None]:
# Import specialized versions of models (that return query/key vectors)
from bertviz.transformers_neuron_view import BertModel, BertTokenizer
from bertviz.neuron_view import show

In [None]:
# bertviz's neuron view needs its own BERT / tokenizer classes (which expose
# query/key vectors). Use separate names so the masked-LM `model` and `tokenizer`
# used by the helpers above are not overwritten by this plain BertModel.
neuron_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
neuron_model = BertModel.from_pretrained('bert_mlm/block_512/bert_mlm_textbook', output_attentions=True)
model_type = 'bert'
sentence = "The cat sat on the mat"
show(neuron_model, model_type, neuron_tokenizer, sentence)