In [1]:
%load_ext autoreload
%autoreload 2

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'

In [2]:
import os
import json

import numpy as np
import math
import matplotlib
import matplotlib.pyplot as plt
from pylab import rcParams

import torch
import torch.nn.functional as F
from pytorch_pretrained_bert import tokenization, BertTokenizer, BertModel, BertForMaskedLM, BertForPreTraining, BertConfig
from examples.extract_features import *

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [3]:
class Args:
    def __init__(self):
        pass
    
args = Args()
args.no_cuda = True

CONFIG_NAME = 'bert_config.json'
BERT_DIR = '/nas/pretrain-bert/pretrain-tensorflow/uncased_L-12_H-768_A-12/'
config_file = os.path.join(BERT_DIR, CONFIG_NAME)
config = BertConfig.from_json_file(config_file)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')#do_lower_case：在标记化时将文本转换为小写。默认= True
model = BertForPreTraining.from_pretrained(BERT_DIR)
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
_ = model.to(device)
_ = model.eval()

05/14/2019 15:48:11 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/xd/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
05/14/2019 15:48:12 - INFO - pytorch_pretrained_bert.modeling -   loading archive file /nas/pretrain-bert/pretrain-tensorflow/uncased_L-12_H-768_A-12/
05/14/2019 15:48:12 - INFO - pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}



In [33]:
tokens = ['death','died','dead','die','dying','dies']
tokenizer.convert_tokens_to_ids(tokens)

[2331, 2351, 2757, 3280, 5996, 8289]

BertForPreTraining：
Outputs:
        if `masked_lm_labels` and `next_sentence_label` are not `None`:
            Outputs the total_loss which is the sum of the masked language modeling loss and the next
            sentence classification loss.
        if `masked_lm_labels` or `next_sentence_label` is `None`:
            Outputs a tuple comprising
            - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
            - the next sentence classification logits of shape [batch_size, 2].

from_pretrained：
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.

In [4]:
import re
def convert_text_to_examples(text): #把每一行的句子变成一个实例，一个实例中包含text_a,text_b(text_b目前是没用的)
    examples = []
    unique_id = 0
    if True:
        for line in text:
            line = line.strip()
            text_a = None
            text_b = None
            m = re.match(r"^(.*) \|\|\| (.*)$", line) #想要匹配这样的字符串'You are my sunshine. ||| I love you.'
            
            if m is None:
                text_a = line
            else:
                text_a = m.group(1) #匹配的第一句,比如You are my sunshine,my only sunshine.
                text_b = m.group(2) #匹配的第二句，比如I love you.
            
            examples.append(
                InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
            unique_id += 1
    return examples
#疑问，当text是一行的时候，line是一个个字母 -> text是["***"]的形式
#print(convert_text_to_examples({"I love you","hello everybody"})[0].text_a)

def convert_examples_to_features(examples, tokenizer, append_special_tokens=True, replace_mask=True, print_info=False):
    #把实例变成一个特征
    features = []
    for (ex_index, example) in enumerate(examples):
        tokens_a = tokenizer.tokenize(example.text_a) #tokenizer的作用是
        #print(example.unique_id) #*****************************
        tokens_b = None
        if example.text_b:
            tokens_b = tokenizer.tokenize(example.text_b)

        tokens = []
        input_type_ids = [] #segment embedding
        if append_special_tokens: #输入参数中默认为true
            tokens.append("[CLS]")
            input_type_ids.append(0)
        for token in tokens_a:
            if replace_mask and token == '_':  # XD
                token = "[MASK]"
            tokens.append(token)
            input_type_ids.append(0)
        if append_special_tokens:
            tokens.append("[SEP]")
            input_type_ids.append(0)

        if tokens_b:
            for token in tokens_b:
                if replace_mask and token == '_':  # XD
                    token = "[MASK]"
                tokens.append(token)
                input_type_ids.append(1)
            if append_special_tokens:
                tokens.append("[SEP]")
                input_type_ids.append(1)
        print(tokens) #*******************************
        input_ids = tokenizer.convert_tokens_to_ids(tokens) #把原来句子中的词语编成在字典中的编号
        input_mask = [1] * len(input_ids) 
        print(input_ids)#***********************************
        if ex_index < 5:
#             logger.info("*** Example ***")
#             logger.info("unique_id: %s" % (example.unique_id))
            logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
#             logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
#             logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
#             logger.info(
#                 "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids]))
            
        features.append(
            InputFeatures(
                unique_id=example.unique_id,
                tokens=tokens,
                input_ids=input_ids,#字符串中的每个单词在词典中的index序列
                input_mask=input_mask, #一堆1
                input_type_ids=input_type_ids)) #第0类和第一类，对text_a,text_b的区分
    return features
                
examples = convert_text_to_examples({"I love you","hello everybody"})
features = convert_examples_to_features(examples, tokenizer, print_info=False)

def copy_and_mask_feature(feature, masked_tokens=None):
    import copy
    tokens = feature.tokens
    masked_positions = [tokens.index(t) for t in masked_tokens if t in tokens] \
        if masked_tokens is not None else range(len(tokens))
    
    assert len(masked_positions) > 0
    masked_feature_copies = []
    for masked_pos in masked_positions: #用[mask]依次掩盖每一个位置
        feature_copy = copy.deepcopy(feature)
        feature_copy.input_ids[masked_pos] = tokenizer.vocab["[MASK]"]
        masked_feature_copies.append(feature_copy)
    return masked_feature_copies, masked_positions

#masked_feature_copies, masked_positions = copy_and_mask_feature(features[1])
#print(masked_feature_copies[0].input_ids) #结果[101, 1045, 2293, 103, 102]
#print(masked_positions) #结果是一个range(0,5)

05/14/2019 15:48:15 - INFO - examples.extract_features -   tokens: [CLS] i love you [SEP]
05/14/2019 15:48:15 - INFO - examples.extract_features -   tokens: [CLS] hello everybody [SEP]


['[CLS]', 'i', 'love', 'you', '[SEP]']
[101, 1045, 2293, 2017, 102]
['[CLS]', 'hello', 'everybody', '[SEP]']
[101, 7592, 7955, 102]


In [47]:
def show_lm_probs(tokens, input_ids, probs, topk=5, firstk=20): #输出结果的函数，要最高概率topk个输出
    def print_pair(token, prob, end_str='', hit_mark=' '):
        if i < firstk:
            # token = token.replace('</w>', '').replace('\n', '/n')
            print('{}{: >3} | {: <12}'.format(hit_mark, int(round(prob*100)), token), end=end_str)
    
    ret = None
    for i in range(len(tokens)):
        ind_ = input_ids[i].item() if input_ids is not None else tokenizer.vocab[tokens[i]]
        prob_ = probs[i][ind_].item() #这个probs是该字符串第i个位置上填上词典上各个词的概率
        print_pair(tokens[i], prob_, end_str='\t')
        values, indices = probs[i].topk(topk)
        top_pairs = []
        for j in range(topk):
            ind, prob = indices[j].item(), values[j].item()
            hit_mark = '*' if ind == ind_ else ' '
            token = tokenizer.ids_to_tokens[ind]
            print_pair(token, prob, hit_mark=hit_mark, end_str='' if j < topk - 1 else '\n')
            top_pairs.append((token, prob))
        if tokens[i] == "[MASK]":
            ret = top_pairs
    return ret #返回的这是个啥

In [48]:
import colored
from colored import stylize

def show_abnormals(tokens, probs, show_suggestions=False):
    def gap2color(gap):
        if gap <= 5:
            return 'yellow_1'
        elif gap <= 10:
            return 'orange_1'
        else:
            return 'red_1'
        
    def print_token(token, suggestion, gap):
        
        if gap == 0:
            print(stylize(token + ' ', colored.fg('white') + colored.bg('black')), end='')
        else:
            print(stylize(token, colored.fg(gap2color(gap)) + colored.bg('black')), end='')
            if show_suggestions and gap > 5:
                print(stylize('/' + suggestion + ' ', colored.fg('green' if gap > 10 else 'cyan') + colored.bg('black')), end='')
            else:
                print(stylize(' ', colored.fg(gap2color(gap)) + colored.bg('black')), end='')
                
                # print('/' + suggestion, end=' ')
            # print('%.2f' % gap, end=' ')
        #print(gap)
    avg_gap = 0.
    for i in range(1, len(tokens) - 1):  # skip first [CLS] and last [SEP]
        ind_ = tokenizer.vocab[tokens[i]]
        prob_ = probs[i][ind_].item()
        top_prob = probs[i].max().item()
        top_ind = probs[i].argmax().item()
        gap = math.log(top_prob) - math.log(prob_) #计算两个词之间的差距
        #print(top_prob,prob_)
        suggestion = tokenizer.ids_to_tokens[top_ind]
        print_token(tokens[i], suggestion, gap)
        avg_gap += gap
    avg_gap /= (len(tokens) - 2)
    print()
    print(avg_gap)

In [49]:
analyzed_cache = {}

def analyze_text(text, masked_tokens=None, show_suggestions=True, show_firstk_probs=20):
    if text[0] in analyzed_cache: #分析过的缓存
        features, mlm_probs = analyzed_cache[text[0]]
        given_mask = "[MASK]" in features[0].tokens
        tokens = features[0].tokens 
    else:
        examples = convert_text_to_examples(text)
        features = convert_examples_to_features(examples, tokenizer, print_info=False)
        given_mask = "[MASK]" in features[0].tokens
        if not given_mask or masked_tokens is not None:
            assert len(features) == 1
            features, masked_positions = copy_and_mask_feature(features[0], masked_tokens=masked_tokens)

        input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) #把input_ids增加了一个维度，变成[n_features,sequence_len]
        #这里的n_features实际上是句子有多少个单词位置，每个位置依次换成[mask]
        input_type_ids = torch.tensor([f.input_type_ids for f in features], dtype=torch.long) #把input_type_ids增加了一个维度，其实每一行都一样
        input_ids = input_ids.to(device) #拿去GPU
        input_type_ids = input_type_ids.to(device)
        
        time_start=time.time()
        mlm_logits, _ = model(input_ids, input_type_ids)
        time_end=time.time()
        print('time cost1',time_end-time_start,'s')
        
        mlm_probs = F.softmax(mlm_logits, dim=-1) #最后一维，也就是vocab 换算成概率和为百分之百
        #print(mlm_probs.size())#这里实验的是torch.Size([5, 5, 30522])
        tokens = features[0].tokens #不知道要干嘛
        if not given_mask or masked_tokens is not None:
            bsz, seq_len, vocab_size = mlm_probs.size() #三个维度分别是batch_size, sequence_length, vocab_size
            assert bsz == len(masked_positions)
            # reduced_mlm_probs = torch.Tensor(1, seq_len, vocab_size)
            # for i in range(seq_len):
            #    reduced_mlm_probs[0, i] = mlm_probs[i, i]
            reduced_mlm_probs = torch.Tensor(1, len(masked_positions), vocab_size)
            for i, pos in enumerate(masked_positions):
                reduced_mlm_probs[0, i] = mlm_probs[i, pos]
            mlm_probs = reduced_mlm_probs #压缩一下大小，节约不必要浪费的空间（只需要第i个batch里面[mask]位置的词汇表概率即可）
            tokens = [tokens[i] for i in masked_positions]
        
        analyzed_cache[text[0]] = (features, mlm_probs)
        
    top_pairs = show_lm_probs(tokens, None, mlm_probs[0], firstk=show_firstk_probs) #传入的probs是二维的
    #print("************************************************************************************************************")
    #print(top_pairs) #******************************
    if not given_mask:
        show_abnormals(tokens, mlm_probs[0], show_suggestions=show_suggestions)
    return top_pairs


In [50]:
# text = ["Who was Jim Henson? Jim Henson _ a puppeteer."]
#text = ["Last week I went to the theatre. I had very good seat. The play was very interesting. But I didn't enjoy it. A young man and a young woman were sitting behind me. They were talking loudly. I got very angry. I couldn't hear a word. I turned round. I looked at the man angry. They didn't pay any attention.In the end, I couldn't bear it. I turned round again. 'I can't hear a word!' I said angrily. 'It's none of your business,' the young man said rudely. 'This is a private conversation!'"]
#text = ["After the outbreak of the disease, the Ministry of Agriculture and rural areas immediately sent a supervision team to the local. Local Emergency Response Mechanism has been activated in accordance with the requirements, to take blockade, culling, harmless treatment, disinfection and other treatment measures to all disease and culling of pigs for harmless treatment. At the same time, all live pigs and their products are prohibited from transferring out of the blockade area, and live pigs are not allowed to be transported into the blockade area. At present, all the above measures have been implemented."]
#text = ["The journey was long and tired. We left London at five o'clock in the evening and spend eight hours in the train. We had been travelled for 3 hours after someone appeared selling food and drinks. It was darkness all the time we were crossing Wales, but we could see nothing through the windows. When we finally arrived at Holyhead nearly , everyone was slept. As soon as the train stopped, everybody come to life, grabbing their suitcases and rushing onto the platform."]
#text = ["When I was little, Friday's night was our family game night. After supper, we would play card games of all sort in the sitting room. As the kid, I loved to watch cartoons,but no matter how many times I asked for watching them, my parents would not to let me.They would say to us that playing card games would help my brain. Still I unwilling to play the games for them sometimes. "]
#text = ["After the outbreak of the disease, the Ministry of Agriculture and rural areas immediately sent a supervision team to the local. Local Emergency Response Mechanism has been activated in accordance with the requirements, to take blockade, culling, harmless treatment, disinfection and other treatment measures to all disease and culling of pigs for harmless treatment. At the same time, all live pigs and their products are prohibited from transferring out of the blockade area, and live pigs are not allowed to be transported into the blockade area. At present, all the above measures have been implemented."]
# text = ["Early critics of Emily Dickinson's poetry mistook for simplemindedness the surface of artlessness that in fact she constructed with such innocence."]
#text = ["During my last winter holiday, I went to the countryside with my father to visit my grandparents. I find a big change there. The first time I went there, they were living in a small house with dogs, ducks, and another animals. Last winter when I went here again, they had a big separate house to raise dozens of chicken. They also had a small pond which they raised fish. My grandpa said last summer they earned quite a lot by sell the fish. I felt happily that their life had improved. At the end of our trip，I told my father that I planned to return for every two years, but he agreed."]
# text = ['The problem is difficult than that one.']
#text = ["It was Monday morning, and the writing class had just begin. Everyone was silent, wait to see who would be called upon to read his and her paragraph aloud. Some of us were confident and eagerly take part in the class activity, others were nervous and anxious. I had done myself homework but I was shy. I was afraid that to speak in front of a larger group of people. At that moment, I remembered that my father once said, 'The classroom is a place for learning and that include learning from the textbooks, and mistakes as well.' Immediate, I raised my hand."]
text = ["He is dies."]
import time
time_start=time.time()
#text = ["The play was very interesting."]
#text = ["The question is easy than that one."]
#text =["The apple a eat by me. I had a very good seat. The play was very interesting.But I didn't enjoy it. A young man and a young woman were sitting behind me.They were talking loudly. I got very angry."]#因为外面有中括号，所以是二维的
analyze_text(text, show_firstk_probs=200)
#print(analyzed_cache)
time_end=time.time()
print('time cost',time_end-time_start,'s')

05/21/2019 16:22:56 - INFO - examples.extract_features -   tokens: [CLS] he is dies . [SEP]


['[CLS]', 'he', 'is', 'dies', '.', '[SEP]']
[101, 2002, 2003, 8289, 1012, 102]
time cost1 0.0779261589050293 s
   0 | [CLS]       	   4 | .              1 | ,              1 | the            1 | )              1 | "           
  19 | he          	* 19 | he             8 | it             6 | she            3 | and            2 | the         
   0 | is          	  33 | then          15 | soon          12 | eventually     7 | later          4 | also        
   0 | dies        	   4 | dead           3 | alive          3 | right          2 | beautiful      2 | not         
  93 | .           	* 93 | .              6 | ;              1 | !              0 | ?              0 | |           
   0 | [SEP]       	  11 | "              5 | he             2 | .              1 | and            1 | it          
[38;5;15m[48;5;0mhe [0m[38;5;214m[48;5;0mis[0m[38;5;6m[48;5;0m/then [0m[38;5;214m[48;5;0mdies[0m[38;5;6m[48;5;0m/dead [0m[38;5;15m[48;5;0m. [0m
3.6602350262062977
time cost 0.

In [19]:
text = [
    # same / different
    "Tom has black hair. Mary has black hair. John has yellow hair. _  and Mary have the same hair color.",
    "Tom has black hair. Mary has black hair. John has yellow hair. _  and Mary have different hair colors.",
    "Tom has yellow hair. Mary has black hair. John has black hair. Mary and _ have the same hair color.",
    # because / although
    "John is taller/shorter than Mary because/although _ is older/younger.",
    "The red ball is heavier/lighter than the blue ball because/although the _ ball is bigger/smaller.",
    "Charles did a lot better/worse than his good friend Nancy on the test because/although _ had/hadn't studied so hard.",
    "The trophy doesn't fit into the brown suitcase because/although the _ is too small/large.",
    "John thought that he would arrive earlier than Susan, but/and indeed _ was the first to arrive.",
    # reverse
    "John came then Mary came. They left in reverse order. _ left then _ left.",
    "John came after Mary. They left in reverse order. _ left after _ .",
    "John came first, then came Mary. They left in reverse order: _ left first, then left _ .",
    # compare
    "Though John is tall, Tom is taller than John. So John is _ than Tom.",
    "Tom is taller than John. So _ is shorter than _.",
    # WSC-style: before /after
    "Mary came before/after John. _ was late/early .",
    # yes / no
    "Was Tom taller than Susan? Yes, _ was taller.",
    # right / wrong, epistemic modality
    "John said the rain was about to stop. Mary said the rain would continue. Later the rain stopped. _ was wrong.",
    
    "The trophy doesn't fit into the brown suitcase because/although the _ is too small/large.",
    "John thanked Mary because  _ had given help to _ . ",
    "John felt vindicated/crushed when his longtime rival Mary revealed that _ was the winner of the competition.",
    "John couldn't see the stage with Mary in front of him because _ is so short/tall.",
    "Although they ran at about the same speed, John beat Sally because _ had such a bad start.",
    "The fish ate the worm. The _ was hungry/tasty.",
    
    "John beat Mary. _ won the game/e winner.",
]
text

['Tom has black hair. Mary has black hair. John has yellow hair. _  and Mary have the same hair color.',
 'Tom has black hair. Mary has black hair. John has yellow hair. _  and Mary have different hair colors.']

In [None]:
config

In [22]:
with open('WSC_switched_label.json') as f:
    examples = json.load(f)

In [9]:
with open('WSC_child_problem.json') as f:
    cexamples = json.load(f)

In [89]:
for ce in cexamples:
    for s in ce['sentences']:
        for a in s['answer0'] + s['answer1']:
            a = a.lower()
            if a not in tokenizer.vocab:
                ce
                print(a, 'not in vocab!!!')

In [23]:
for ce in cexamples:
    if len(ce['sentences']) > 0:
        e = examples[ce['index']]
        assert ce['index'] == e['index']
        e['score'] = all([s['score'] for s in ce['sentences']])
        assert len(set([s['adjacent_ref'] for s in ce['sentences']])) == 1, 'adjcent_refs are different!'
        e['adjacent_ref'] = ce['sentences'][0]['adjacent_ref']

In [24]:
from collections import defaultdict

groups = defaultdict(list)
for e in examples:
    if 'score' in e:
        index = e['index']
        if index < 252:
            if index % 2 == 1:
                index -= 1
        elif index in [252, 253, 254]:
            index = 252
        else:
            if index % 2 == 0:
                index -= 1
        groups[index].append(e)

In [62]:
def filter_dict(d, keys=['index', 'sentence', 'correct_answer', 'relational_word', 'is_associative', 'score']):
    return {k: d[k] for k in d if k in keys}

# ([[filter_dict(e) for e in eg] for eg in groups.values() if eg[0]['relational_word'] != 'none' and all([e['score'] for e in eg])])# / len([eg for eg in groups.values() if eg[0]['relational_word'] != 'none'])
[(index, eg[0]['relational_word'], all([e['score'] for e in eg])) for index, eg in groups.items() if eg[0]['relational_word'] != 'none']
# len([filter_dict(e) for e in examples if 'score' in e and not e['score'] and e['adjacent_ref']])
# for e in examples:
#     if e['index'] % 2 == 0:
#         print(e['sentence'])

[(2, 'fit into:large/small', False),
 (4, 'thank:receive/give', False),
 (6, 'call:successful available', True),
 (8, 'ask:repeat answer', False),
 (10, 'zoom by:fast/slow', False),
 (12, 'vindicated/crushed:be the winner', False),
 (14, 'lift:weak heavy', False),
 (16, 'crash through:[hard]/[soft]', False),
 (18, '[block]:short/tall', False),
 (20, 'down to:top/bottom', False),
 (22, 'beat:good/bad', False),
 (24, 'roll off:anchored level', False),
 (26, 'above/below', False),
 (28, 'better/worse:study hard', False),
 (30, 'after/before:far away', False),
 (32, 'be upset with:buy from not work/sell not work', True),
 (34, '?yell at comfort:upset', False),
 (36, 'above/below:moved first', False),
 (38, 'although/because', False),
 (40, 'bully:punish rescue', False),
 (42, 'pour:empty/full', False),
 (44, 'know:nosy indiscreet', False),
 (46, 'explain:convince/understand', True),
 (48, '?know tell:so/because', True),
 (50, 'beat:younger/older', False),
 (56, 'clog:cleaned removed', True

In [51]:
sum(['because' in e['sentence'] for e in examples]) + \
sum(['so ' in e['sentence'] for e in examples]) + \
sum(['but ' in e['sentence'] for e in examples]) + \
sum(['though' in e['sentence'] for e in examples])

179

In [73]:
# with open('WSC_switched_label.json', 'w') as f:
#     json.dump(examples, f)

In [19]:
vis_attn_topk = 3

def has_chinese_label(labels):
    labels = [label.split('->')[0].strip() for label in labels]
    r = sum([len(label) > 1 for label in labels if label not in ['BOS', 'EOS']]) * 1. / (len(labels) - 1)
    return 0 < r < 0.5  # r == 0 means empty query labels used in self attention

def _plot_attn(ax1, attn_name, attn, key_labels, query_labels, col, color='b'):
    assert len(query_labels) == attn.size(0)
    assert len(key_labels) == attn.size(1)

    ax1.set_xlim([-1, 1])
    ax1.set_xticks([])
    ax2 = ax1.twinx()
    nlabels = max(len(key_labels), len(query_labels))
    pos = range(nlabels)
    
    if 'self' in attn_name and col < ncols - 1:
        query_labels = ['' for _ in query_labels]

    for ax, labels in [(ax1, key_labels), (ax2, query_labels)]:
        ax.set_yticks(pos)
        if has_chinese_label(labels):
            ax.set_yticklabels(labels, fontproperties=zhfont)
        else:
            ax.set_yticklabels(labels)
        ax.set_ylim([nlabels - 1, 0])
        ax.tick_params(width=0, labelsize='xx-large')

        for spine in ax.spines.values():
            spine.set_visible(False)

#     mask, attn = filter_attn(attn)
    for qi in range(attn.size(0)):
#         if not mask[qi]:
#             continue
#         for ki in range(attn.size(1)):
        for ki in attn[qi].topk(vis_attn_topk)[1]:
            a = attn[qi, ki]
            ax1.plot((-1, 1), (ki, qi), color, alpha=a)
#     print(attn.mean(dim=0).topk(5)[0])
#     ax1.barh(pos, attn.mean(dim=0).data.cpu().numpy())

def plot_layer_attn(result_tuple, attn_name='dec_self_attns', layer=0, heads=None):
    hypo, nheads, labels_dict = result_tuple
    key_labels, query_labels = labels_dict[attn_name]
    if heads is None:
        heads = range(nheads)
    else:
        nheads = len(heads)
    
    stride = 2 if attn_name == 'dec_enc_attns' else 1
    nlabels = max(len(key_labels), len(query_labels))
    rcParams['figure.figsize'] = 20, int(round(nlabels * stride * nheads / 8 * 1.0))
    
    rows = nheads // ncols * stride
    fig, axes = plt.subplots(rows, ncols)
    
    # for head in range(nheads):
    for head_i, head in enumerate(heads):
        row, col = head_i * stride // ncols, head_i * stride % ncols
        ax1 = axes[row, col]
        attn = hypo[attn_name][layer][head]
        _plot_attn(ax1, attn_name, attn, key_labels, query_labels, col)
        if attn_name == 'dec_enc_attns':
            col = col + 1
            axes[row, col].axis('off')  # next subfig acts as blank place holder
    # plt.suptitle('%s with %d heads, Layer %d' % (attn_name, nheads, layer), fontsize=20)
    plt.show()  
            
ncols = 4

In [31]:
attn_name = 'enc_self_attns'
hypo = {attn_name: [model.bert.encoder.layer[i].attention.self.attention_probs[0] for i in range(config.num_hidden_layers)]}
key_labels = query_labels = tokens
labels_dict = {attn_name: (key_labels, query_labels)}
result_tuple = (hypo, config.num_attention_heads, labels_dict)
plot_layer_attn(result_tuple, attn_name=attn_name, layer=10, heads=None)

AttributeError: 'BertSelfAttention' object has no attribute 'attention_probs'