In [1]:
import numpy as np
import json
import tensorflow as tf

from transformers import AutoTokenizer, AutoModel, utils
from bertviz import model_view, head_view

In [4]:
# NOTE(review): absolute, machine-specific path — point DATA_PATH at your own copy
# (ideally relative to a project data directory) before re-running.
DATA_PATH = '/Users/lizzy/Desktop/Universita/tesi/git/dialogue_coherence/clark/preprocessed_conversational.json'
MODEL_NAME = "DeepPavlov/bert-base-cased-conversational"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME, output_attentions=True)

# Load the preprocessed dialogues and join the words of the first one into a single string.
with open(DATA_PATH, 'r') as f:
    texts = json.load(f)
input_text = " ".join(texts[0]["words"])

# add_special_tokens=False: the encoded ids already start with [CLS] and contain [SEP]
# markers (see the outputs below), so the preprocessed text appears to carry its own
# special tokens — TODO confirm against the preprocessing script.
inputs = tokenizer.encode(input_text, return_tensors='pt', add_special_tokens=False)  # batch of one sequence
outputs = model(inputs)  # Run model
# Read the attentions by name: outputs[-1] only happens to be the attentions for this
# particular combination of output flags.
attention = outputs.attentions  # one attention tensor per layer
tokens = tokenizer.convert_ids_to_tokens(inputs[0])  # Convert input ids to token strings

Some weights of the model checkpoint at DeepPavlov/bert-base-cased-conversational were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [186]:
# Token ids of the first (and only) sequence in the encoded batch.
first_sequence_ids = inputs[0]
first_sequence_ids

tensor([  101,   750,  3119,   699, 25385,   119,   102,   146,   112,  1035,
         4932,   795,   646, 14946,   119,   102])

In [184]:
# For every token string, print whether it is the [CLS] marker.
for is_cls in (tok == "[CLS]" for tok in tokenizer.convert_ids_to_tokens(inputs[0])):
    print(is_cls)

True
False
False
False
False
False
False
False
False
False
False
False
False
False
False
False


In [187]:
def get_token_positions(encoded_input, token, tokenizer=None):
    """Return the indices at which ``token`` occurs in ``encoded_input``.

    Parameters
    ----------
    encoded_input : iterable
        Sequence to scan, e.g. the token strings from
        ``tokenizer.convert_ids_to_tokens``.
    token :
        Value to look for; each element is compared with ``==``.
    tokenizer : optional
        Unused. Kept (now with a default) so existing calls that pass it still work.

    Returns
    -------
    list of int
        0-based positions of every match, in ascending order; empty if there are none.
    """
    return [pos for pos, item in enumerate(encoded_input) if item == token]

In [188]:
def from_torch_tensor2_numpy(tensor):
    """Stack per-layer attention tensors into one float64 numpy array.

    ``tensor`` is indexable per layer; every layer is indexed as
    ``[batch, heads, rows, cols]``, and only batch element 0 is kept.
    The result has shape ``(n_layers, n_heads, n, n)``, where ``n_heads`` and
    ``n`` come from the first layer's dimensions 1 and -1.
    """
    first_layer = tensor[0]
    n_heads = first_layer.size()[1]
    side = first_layer.size()[-1]
    stacked = np.empty((len(tensor), n_heads, side, side))
    for layer_idx in range(len(tensor)):
        # Move to CPU and drop autograd history before converting to numpy.
        stacked[layer_idx] = tensor[layer_idx][0].cpu().detach().numpy()
    return stacked

In [189]:
attentions = from_torch_tensor2_numpy(attention) 

In [194]:
def perc_attention(tokens_to_pay_attention_to, attention, encoded_tokens, tokenizer):
    """Share of each head's attention mass that involves the given targets.

    Convention (the one the positional branch already relies on):
    ``head[i, j]`` is treated as token ``i`` (row, query) attending to token ``j``
    (column, key). For example, ``np.trace(head, offset=1)`` sums ``head[i, i + 1]``,
    which is "attention to the next token".

    Parameters
    ----------
    tokens_to_pay_attention_to : iterable of str
        Either positional relations ('current', 'next', 'previous'), mapped to a
        diagonal offset by ``switch_attention2``, or literal token strings such as
        '[CLS]' / '[SEP]', located in ``encoded_tokens`` with ``get_token_positions``.
    attention : array-like, shape (n_layers, n_heads, n, n)
        Attention weights, e.g. the output of ``from_torch_tensor2_numpy``.
    encoded_tokens : sequence
        Token strings in the same order as the attention axes.
    tokenizer :
        Forwarded to ``get_token_positions``.

    Returns
    -------
    dict
        ``{target: {'to': arr, 'from': arr}}``, where each ``arr`` has shape
        ``(n_layers, n_heads)`` and holds that weight sum divided by the head's
        total weight. 'to' is attention paid *to* the target (key side). 'from'
        is attention paid *by* the target (query side).
    """
    attention = np.asarray(attention)
    layers_heads = attention.shape[:-2]
    res = {}
    for tkn in tokens_to_pay_attention_to:
        positional = tkn in ['current', 'next', 'previous']
        if positional:
            offset = switch_attention2(tkn)
        else:
            # Positions depend only on the token, not on the layer/head: look them up once.
            tkn_pos = get_token_positions(encoded_tokens, tkn, tokenizer)
        perc_attention_to = np.zeros(layers_heads)
        perc_attention_from = np.zeros(layers_heads)
        for l, layer in enumerate(attention):
            for h, head in enumerate(layer):
                if positional:
                    sum_weights_to = np.trace(head, offset=offset)
                    sum_weights_from = np.trace(head.T, offset=offset)
                else:
                    # Columns = keys: everything attending *to* the token.
                    # Rows = queries: what the token itself attends to. (Previously these
                    # were swapped; with softmax-normalised rows the old 'to' was just
                    # len(tkn_pos) / n for every head.)
                    sum_weights_to = head[:, tkn_pos].sum()
                    sum_weights_from = head[tkn_pos, :].sum()
                sum_tot = np.sum(head)
                perc_attention_to[l][h] = sum_weights_to / sum_tot
                perc_attention_from[l][h] = sum_weights_from / sum_tot
        res[tkn] = {'to': perc_attention_to, 'from': perc_attention_from}
    return res

In [206]:
# Positions of the [SEP] separators in the tokenized dialogue.
sep_positions = get_token_positions(tokens, '[SEP]', tokenizer)
sep_positions

[6, 15]

In [204]:
# First token string of the sequence.
first_token = tokens[0]
first_token

'[CLS]'

In [195]:
def switch_attention2(token_to_pay_attention_to):
    """Map a positional relation name to a diagonal offset usable with ``np.trace``.

    Parameters
    ----------
    token_to_pay_attention_to : str
        One of 'current' (offset 0), 'next' (offset 1) or 'previous' (offset -1).

    Returns
    -------
    int
        The diagonal offset for that relation.

    Raises
    ------
    ValueError
        If the name is not one of the three relations. ValueError subclasses
        Exception, so existing ``except Exception`` handlers still catch it.
    """
    offsets = {'current': 0, 'next': 1, 'previous': -1}
    try:
        return offsets[token_to_pay_attention_to]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which the old == chain also rejected.
        raise ValueError(f'Token not valid: {token_to_pay_attention_to!r}') from None

In [207]:
# Positional relations first, then the special marker tokens.
targets = ['current', 'next', 'previous', '[CLS]', '[SEP]']
res = perc_attention(targets, attentions, tokens, tokenizer)

In [208]:
# Print, per target and direction, the rounded attention share of every layer/head.
for target, by_direction in res.items():
    for direction, scores in by_direction.items():
        print(f"Attention {direction} {target} token")
        for l, layer_scores in enumerate(scores):
            for h, score in enumerate(layer_scores):
                print(f"Layer {l} head {h}: {np.around(score, decimals = 2)}")

Attention to current token
Layer 0 head 0: 0.12
Layer 0 head 1: 0.08
Layer 0 head 2: 0.08
Layer 0 head 3: 0.07
Layer 0 head 4: 0.11
Layer 0 head 5: 0.08
Layer 0 head 6: 0.08
Layer 0 head 7: 0.04
Layer 0 head 8: 0.09
Layer 0 head 9: 0.07
Layer 0 head 10: 0.34
Layer 0 head 11: 0.08
Layer 1 head 0: 0.08
Layer 1 head 1: 0.07
Layer 1 head 2: 0.06
Layer 1 head 3: 0.04
Layer 1 head 4: 0.05
Layer 1 head 5: 0.09
Layer 1 head 6: 0.1
Layer 1 head 7: 0.07
Layer 1 head 8: 0.03
Layer 1 head 9: 0.03
Layer 1 head 10: 0.05
Layer 1 head 11: 0.06
Layer 2 head 0: 0.06
Layer 2 head 1: 0.04
Layer 2 head 2: 0.17
Layer 2 head 3: 0.05
Layer 2 head 4: 0.07
Layer 2 head 5: 0.12
Layer 2 head 6: 0.05
Layer 2 head 7: 0.08
Layer 2 head 8: 0.06
Layer 2 head 9: 0.09
Layer 2 head 10: 0.08
Layer 2 head 11: 0.11
Layer 3 head 0: 0.08
Layer 3 head 1: 0.13
Layer 3 head 2: 0.08
Layer 3 head 3: 0.08
Layer 3 head 4: 0.07
Layer 3 head 5: 0.1
Layer 3 head 6: 0.12
Layer 3 head 7: 0.24
Layer 3 head 8: 0.05
Layer 3 head 9: 0.08
Lay