In [2]:
%load_ext autoreload
%autoreload 2

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'

In [3]:
# import seaborn as sns
import os
import json

import numpy as np
import math
import matplotlib
import matplotlib.pyplot as plt
from pylab import rcParams

import torch
import torch.nn.functional as F
from pytorch_pretrained_bert import tokenization, BertTokenizer, BertModel, BertForMaskedLM, BertForPreTraining, BertConfig
from examples.extract_features import *

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [8]:
class Args:
    def __init__(self):
        pass
    
args = Args()
args.no_cuda = True

CONFIG_NAME = 'bert_config.json'
# BERT_DIR = '/nas/pretrain-bert/pretrain-tensorflow/uncased_L-12_H-768_A-12/'
BERT_DIR = '/nas/pretrain-bert/pretrain-pytorch/bert-base-uncased/'
config_file = os.path.join(BERT_DIR, CONFIG_NAME)
config = BertConfig.from_json_file(config_file)

# tokenizer = BertTokenizer.from_pretrained(os.path.join(BERT_DIR, 'vocab.txt'))
tokenizer = BertTokenizer.from_pretrained('/nas/pretrain-bert/pretrain-pytorch/bert-base-uncased-vocab.txt')
model = BertForPreTraining.from_pretrained(BERT_DIR)
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
_ = model.to(device)
_ = model.eval()

06/10/2019 08:14:45 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file /nas/pretrain-bert/pretrain-pytorch/bert-base-uncased-vocab.txt
06/10/2019 08:14:45 - INFO - pytorch_pretrained_bert.modeling -   loading archive file /nas/pretrain-bert/pretrain-pytorch/bert-base-uncased/
06/10/2019 08:14:45 - INFO - pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}



In [4]:
import re
def convert_text_to_examples(text):
    examples = []
    unique_id = 0
    if True:
        for line in text:
            line = line.strip()
            text_a = None
            text_b = None
            m = re.match(r"^(.*) \|\|\| (.*)$", line)
            if m is None:
                text_a = line
            else:
                text_a = m.group(1)
                text_b = m.group(2)
            examples.append(
                InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
            unique_id += 1
    return examples

def convert_examples_to_features(examples, tokenizer, append_special_tokens=True, replace_mask=True, print_info=False):
    features = []
    for (ex_index, example) in enumerate(examples):
        tokens_a = tokenizer.tokenize(example.text_a)
        tokens_b = None
        if example.text_b:
            tokens_b = tokenizer.tokenize(example.text_b)

        tokens = []
        input_type_ids = []
        if append_special_tokens:
            tokens.append("[CLS]")
            input_type_ids.append(0)
        for token in tokens_a:
            if replace_mask and token == '_':  # XD
                token = "[MASK]"
            tokens.append(token)
            input_type_ids.append(0)
        if append_special_tokens:
            tokens.append("[SEP]")
            input_type_ids.append(0)

        if tokens_b:
            for token in tokens_b:
                if replace_mask and token == '_':  # XD
                    token = "[MASK]"
                tokens.append(token)
                input_type_ids.append(1)
            if append_special_tokens:
                tokens.append("[SEP]")
                input_type_ids.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)

        if ex_index < 5:
#             logger.info("*** Example ***")
#             logger.info("unique_id: %s" % (example.unique_id))
            logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
#             logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
#             logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
#             logger.info(
#                 "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids]))
            
        features.append(
            InputFeatures(
                unique_id=example.unique_id,
                tokens=tokens,
                input_ids=input_ids,
                input_mask=input_mask,
                input_type_ids=input_type_ids))
    return features

def copy_and_mask_feature(feature, masked_tokens=None):
    import copy
    tokens = feature.tokens
    masked_positions = [tokens.index(t) for t in masked_tokens if t in tokens] \
        if masked_tokens is not None else range(len(tokens))
    assert len(masked_positions) > 0
    masked_feature_copies = []
    for masked_pos in masked_positions:
        feature_copy = copy.deepcopy(feature)
        feature_copy.input_ids[masked_pos] = tokenizer.vocab["[MASK]"]
        masked_feature_copies.append(feature_copy)
    return masked_feature_copies, masked_positions



In [5]:
def show_lm_probs(tokens, input_ids, probs, topk=5, firstk=20):
    def print_pair(token, prob, end_str='', hit_mark=' '):
        if i < firstk:
            # token = token.replace('</w>', '').replace('\n', '/n')
            print('{}{: >3} | {: <12}'.format(hit_mark, int(round(prob*100)), token), end=end_str)
    
    ret = None
    for i in range(len(tokens)):
        ind_ = input_ids[i].item() if input_ids is not None else tokenizer.vocab[tokens[i]]
        prob_ = probs[i][ind_].item()
        print_pair(tokens[i], prob_, end_str='\t')
        values, indices = probs[i].topk(topk)
        top_pairs = []
        for j in range(topk):
            ind, prob = indices[j].item(), values[j].item()
            hit_mark = '*' if ind == ind_ else ' '
            token = tokenizer.ids_to_tokens[ind]
            print_pair(token, prob, hit_mark=hit_mark, end_str='' if j < topk - 1 else '\n')
            top_pairs.append((token, prob))
        if tokens[i] == "[MASK]":
            ret = top_pairs
    return ret

In [7]:
import colored
from colored import stylize

def show_abnormals(tokens, probs, show_suggestions=False):
    def gap2color(gap):
        if gap <= 5:
            return 'yellow_1'
        elif gap <= 10:
            return 'orange_1'
        else:
            return 'red_1'
        
    def print_token(token, suggestion, gap):
        if gap == 0:
            print(stylize(token + ' ', colored.fg('white') + colored.bg('black')), end='')
        else:
            print(stylize(token, colored.fg(gap2color(gap)) + colored.bg('black')), end='')
            if show_suggestions and gap > 5:
                print(stylize('/' + suggestion + ' ', colored.fg('green' if gap > 10 else 'cyan') + colored.bg('black')), end='')
            else:
                print(stylize(' ', colored.fg(gap2color(gap)) + colored.bg('black')), end='')
                # print('/' + suggestion, end=' ')
            # print('%.2f' % gap, end=' ')
        
    avg_gap = 0.
    for i in range(1, len(tokens) - 1):  # skip first [CLS] and last [SEP]
        ind_ = tokenizer.vocab[tokens[i]]
        prob_ = probs[i][ind_].item()
        top_prob = probs[i].max().item()
        top_ind = probs[i].argmax().item()
        gap = math.log(top_prob) - math.log(prob_)
        suggestion = tokenizer.ids_to_tokens[top_ind]
        print_token(tokens[i], suggestion, gap)
        avg_gap += gap
    avg_gap /= (len(tokens) - 2)
    print()
    print(avg_gap)

In [6]:
analyzed_cache = {}

def analyze_text(text, masked_tokens=None, show_suggestions=False, show_firstk_probs=20):
    if text[0] in analyzed_cache:
        features, mlm_probs = analyzed_cache[text[0]]
        given_mask = "[MASK]" in features[0].tokens
        tokens = features[0].tokens
    else:
        examples = convert_text_to_examples(text)
        features = convert_examples_to_features(examples, tokenizer, print_info=False)
        given_mask = "[MASK]" in features[0].tokens
        if not given_mask or masked_tokens is not None:
            assert len(features) == 1
            features, masked_positions = copy_and_mask_feature(features[0], masked_tokens=masked_tokens)

        input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        input_type_ids = torch.tensor([f.input_type_ids for f in features], dtype=torch.long)
        input_ids = input_ids.to(device)
        input_type_ids = input_type_ids.to(device)

        mlm_logits, _ = model(input_ids, input_type_ids)
        mlm_probs = F.softmax(mlm_logits, dim=-1)

        tokens = features[0].tokens
        if not given_mask or masked_tokens is not None:
            bsz, seq_len, vocab_size = mlm_probs.size()
            assert bsz == len(masked_positions)
            # reduced_mlm_probs = torch.Tensor(1, seq_len, vocab_size)
            # for i in range(seq_len):
            #    reduced_mlm_probs[0, i] = mlm_probs[i, i]
            reduced_mlm_probs = torch.Tensor(1, len(masked_positions), vocab_size)
            for i, pos in enumerate(masked_positions):
                reduced_mlm_probs[0, i] = mlm_probs[i, pos]
            mlm_probs = reduced_mlm_probs
            tokens = [tokens[i] for i in masked_positions]
        
        analyzed_cache[text[0]] = (features, mlm_probs)
        
    top_pairs = show_lm_probs(tokens, None, mlm_probs[0], firstk=show_firstk_probs)
    if not given_mask:
        show_abnormals(tokens, mlm_probs[0], show_suggestions=show_suggestions)
    return top_pairs

In [38]:
text = ["_ was the greatest physicist who developed theory of relativity."]
text = ["The trophy doesn't fit into the brown suitcase because the _ is too large."] # relational adj
text = ['"Is Tom taller than Mary?" "No, _ is taller."']  # yes/no
text = [ "Tom has black hair. Mary has black hair. John has yellow hair. _  and Mary have the same hair color."]  # compare 
text = ['John is taller/shorter than Mary because/although _ is older/younger.']  # causality
text = ["Jennifer is older than James . Jennifer younger than Robert . _ is the oldest."]  # transitive inference

analyze_text(text, show_firstk_probs=100)

   0 | [CLS]       	   3 | .              1 | the            1 | ,              1 | )              1 | "           
 100 | "           	*100 | "              0 | '              0 | and            0 | so             0 | did         
 100 | is          	*100 | is             0 | was            0 | does           0 | isn            0 | has         
  97 | tom         	* 97 | tom            2 | he             0 | thomas         0 | you            0 | she         
 100 | taller      	*100 | taller         0 | tall           0 | shorter        0 | height         0 | tallest     
 100 | than        	*100 | than           0 | then           0 | as             0 | that           0 | to          
 100 | mary        	*100 | mary           0 | tom            0 | you            0 | barbara        0 | maria       
 100 | ?           	*100 | ?              0 | .              0 | !              0 | ...            0 | -           
 100 | "           	*100 | "              0 | '              0 | !      

[('tom', 0.7961671352386475),
 ('he', 0.09765198826789856),
 ('mary', 0.04068772494792938),
 ('she', 0.022535543888807297),
 ('thomas', 0.0058586327359080315)]

In [8]:
def words2heads(attns, tokens, words):
    positions = [tokens.index(word) for word in words]

    for layer in range(config.num_hidden_layers):
        for head in range(config.num_attention_heads):
            for pos_indices in [(0, 1), (1, 0)]:
                from_pos, to_pos = positions[pos_indices[0]], positions[pos_indices[1]]
                if attns[layer][head][from_pos].max(0)[1].item() == to_pos:
                    print('Layer %d, head %d: %s -> %s' % (layer, head, tokens[from_pos], tokens[to_pos]), end='\t')
                    print(attns[layer][head][from_pos].topk(5)[0].data)

def head2words(attns, tokens, layer, head):
    for from_pos in range(len(tokens)):
        to_pos = attns[layer][head][from_pos].max(0)[1].item()
        from_word, to_word = tokens[from_pos], tokens[to_pos]
        if from_word.isalpha() and to_word.isalpha():
            print('%s @ %d -> %s @ %d' % (from_word, from_pos, to_word, to_pos), end='\t')
            print(attns[layer][head][from_pos].topk(5)[0].data)
      
special_tokens = ['[CLS]', '[SEP]']

def get_salient_heads(attns, tokens, attn_thld=0.5):
    for layer in range(config.num_hidden_layers):
        for head in range(config.num_attention_heads):
            pos_pairs = []
            for from_pos in range(1, len(tokens) - 1):  # skip [CLS] and [SEP]
                top_attn, to_pos = attns[layer][head][from_pos].max(0)
                top_attn, to_pos = top_attn.item(), to_pos.item()
                from_word, to_word = tokens[from_pos], tokens[to_pos]
#                 if from_word.isalpha() and to_word.isalpha() and top_attn >= attn_thld:
                if abs(from_pos - to_pos) <= 1:
#                     print('Layer %d, head %d: %s @ %d -> %s @ %d' % (layer, head, from_word, from_pos, to_word, to_pos), end='\t')
#                     print(attns[layer][head][from_pos].topk(5)[0].data)
                    pos_pairs.append((from_pos, to_pos))
    
            ratio = len(pos_pairs) / (len(tokens) - 2)
            if ratio > 0.5:
                print(ratio)
                for from_pos, to_pos in pos_pairs:
                    print('Layer %d, head %d: %s @ %d -> %s @ %d' % (layer, head, tokens[from_pos], from_pos, tokens[to_pos], to_pos), end='\t')
                    print(attns[layer][head][from_pos].topk(5)[0].data)
                    

In [9]:
# text, words = ["The trophy doesn't fit into the brown suitcase because the it is too large."], ['fit', 'large']
# text, words = ["Mary couldn't beat John in the match because he was too strong."], ['beat', 'strong']
text, words = ["John is taller than Mary because he is older."], ['taller', 'older']
# text, words = ["The red ball is heavier than the blue ball because the red ball is bigger."], ['heavier', 'bigger']
text, words = ["Jim laughed because he was so happy."], ['cried', 'sad']
# text, words = ["Jim ate the cake quickly because he was so hungry."], ['ate', 'hungry']
# text, words = ["Jim drank the juice quickly because he was so thirsty."], ['drank', 'thirsty']
# text, words = ["Tom's drawing hangs high. It is above Susan's drawing"], ['high', 'above']
# text, words = ["Tom's drawing hangs low. It is below Susan's drawing"], ['low', 'below']
# text, words = ["John is taller than Mary . Mary is shorter than John."], ['taller', 'shorter']
# text, words = ["The drawing is above the cabinet. The cabinet is below the drawing"], ['above', 'below']
# text, words = ["Jim is very thin . He is not fat."], ['thin', 'fat']

features = convert_examples_to_features(convert_text_to_examples(text), tokenizer, print_info=False)
input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long).to(device)
input_type_ids = torch.tensor([f.input_type_ids for f in features], dtype=torch.long).to(device)
mlm_logits, _ = model(input_ids, input_type_ids)
mlm_probs = F.softmax(mlm_logits, dim=-1)
tokens = features[0].tokens
# top_pairs = show_lm_probs(tokens, None, mlm_probs[0], firstk=100)

attn_name = 'enc_self_attns'
hypo = {attn_name: [model.bert.encoder.layer[i].attention.self.attention_probs[0] for i in range(config.num_hidden_layers)]}
key_labels = query_labels = tokens
labels_dict = {attn_name: (key_labels, query_labels)}
result_tuple = (hypo, config.num_attention_heads, labels_dict)
# plot_layer_attn(result_tuple, attn_name=attn_name, layer=10, heads=None)

attns = hypo[attn_name]
    
# words2heads(attns, tokens, words)
head2words(attns, tokens, 2, 10)
# get_salient_heads(attns, tokens, attn_thld=0.0)

01/10/2019 21:46:20 - INFO - examples.extract_features -   tokens: [CLS] jim laughed because he was so happy . [SEP]


jim @ 1 -> jim @ 1	tensor([0.7248, 0.0842, 0.0656, 0.0407, 0.0319], device='cuda:0')


In [19]:
head_size = config.hidden_size // config.num_attention_heads
layer = 1
head = 1 # 2, 3, 10
wq = model.bert.encoder.layer[layer].attention.self.query.weight.data.view(-1, config.num_attention_heads, head_size).permute(1, 0, 2)
wk = model.bert.encoder.layer[layer].attention.self.key.weight.data.view(-1, config.num_attention_heads, head_size).permute(1, 0, 2)

wqk = torch.bmm(wq, wk.transpose(-1, -2))
# (wqk * wqk.transpose(-1, -2)).sum((1, 2)) / (wqk * wqk).sum((1, 2))
# plt.imshow(wqk[head]*wqk[head])
# plt.show()

# q = torch.matmul(pos_emb, wq)
# k = torch.matmul(pos_emb_prev, wk)
# (q * k).sum((-2, -1))

In [10]:
pos_emb = model.bert.embeddings.position_embeddings.weight.data
pos_emb_prev = torch.zeros_like(pos_emb)
pos_emb_next = torch.zeros_like(pos_emb)
pos_emb_prev[1:] = pos_emb[:-1]
pos_emb_next[:-1] = pos_emb[1:]
pos_emb, pos_emb_prev, pos_emb_next = pos_emb[1:-1], pos_emb_prev[1:-1], pos_emb_next[1:-1]

# pos_q = torch.matmul(pos_emb, wk[head])
# plt.imshow(pos_q[:32])
# plt.show()

In [19]:
text = [
    # same / different
    "Tom has black hair. Mary has black hair. John has yellow hair. _  and Mary have the same hair color.",
    "Tom has black hair. Mary has black hair. John has yellow hair. _  and Mary have different hair colors.",
    "Tom has yellow hair. Mary has black hair. John has black hair. Mary and _ have the same hair color.",
    # because / although
    "John is taller/shorter than Mary because/although _ is older/younger.",
    "The red ball is heavier/lighter than the blue ball because/although the _ ball is bigger/smaller.",
    "Charles did a lot better/worse than his good friend Nancy on the test because/although _ had/hadn't studied so hard.",
    "The trophy doesn't fit into the brown suitcase because/although the _ is too small/large.",
    "John thought that he would arrive earlier than Susan, but/and indeed _ was the first to arrive.",
    # reverse
    "John came then Mary came. They left in reverse order. _ left then _ left.",
    "John came after Mary. They left in reverse order. _ left after _ .",
    "John came first, then came Mary. They left in reverse order: _ left first, then left _ .",
    # compare sentences with same / opposite meaning, 2nd order
    "Though John is tall, Tom is taller than John. So John is _ than Tom.",
    "Tom is taller than John. So _ is shorter than _.",
    # WSC-style: before /after
    # "Mary came before/after John. _ was late/early .",
    # yes / no, 2nd order
    "Was Tom taller than Susan? Yes, _ was taller.",
    # right / wrong, epistemic modality, 2nd order
    "John said/thought that the red ball was heavier than the blue ball. He was wrong. The _ ball was heavier",
    "John was wrong in saying/thinking that the red ball was heavier than the blue ball. The _ ball was heavier",
    "John said the rain was about to stop. Mary said the rain would continue. Later the rain stopped. _ was wrong/right.",
    
    "The trophy doesn't fit into the brown suitcase because/although the _ is too small/large.",
    "John thanked Mary because  _ had given help to _ . ",
    "John felt vindicated/crushed when his longtime rival Mary revealed that _ was the winner of the competition.",
    "John couldn't see the stage with Mary in front of him because _ is so short/tall.",
    "Although they ran at about the same speed, John beat Sally because _ had such a bad start.",
    "The fish ate the worm. The _ was hungry/tasty.",
    
    "John beat Mary. _ won the game/e winner.",
]
text

['Tom has black hair. Mary has black hair. John has yellow hair. _  and Mary have the same hair color.',
 'Tom has black hair. Mary has black hair. John has yellow hair. _  and Mary have different hair colors.']

In [4]:
with open('WSC_switched_label.json') as f:
    examples = json.load(f)

In [5]:
with open('WSC_child_problem.json') as f:
    cexamples = json.load(f)

In [7]:
for ce in cexamples:
    for s in ce['sentences']:
        for a in s['answer0'] + s['answer1']:
            a = a.lower()
#             if a not in tokenizer.vocab:
#                 ce
#                 print(a, 'not in vocab!!!')

In [8]:
for ce in cexamples:
    if len(ce['sentences']) > 0:
        e = examples[ce['index']]
        assert ce['index'] == e['index']
        e['score'] = all([s['score'] for s in ce['sentences']])
        assert len(set([s['adjacent_ref'] for s in ce['sentences']])) == 1, 'adjcent_refs are different!'
        e['adjacent_ref'] = ce['sentences'][0]['adjacent_ref']

In [9]:
from collections import defaultdict

groups = defaultdict(list)
for e in examples:
    if 'score' in e:
        index = e['index']
        if index < 252:
            if index % 2 == 1:
                index -= 1
        elif index in [252, 253, 254]:
            index = 252
        else:
            if index % 2 == 0:
                index -= 1
        groups[index].append(e)

In [41]:
def filter_dict(d, keys=['index', 'sentence', 'correct_answer', 'relational_word', 'is_associative', 'score']):
    return {k: d[k] for k in d if k in keys}

# ([[filter_dict(e) for e in eg] for eg in groups.values() if eg[0]['relational_word'] != 'none' and all([e['score'] for e in eg])])# / len([eg for eg in groups.values() if eg[0]['relational_word'] != 'none'])
# [(index, eg[0]['relational_word'], all([e['score'] for e in eg])) for index, eg in groups.items() if eg[0]['relational_word'] != 'none']
# len([filter_dict(e) for e in examples if 'score' in e and not e['score'] and e['adjacent_ref']])
# for e in examples:
#     if e['index'] % 2 == 0:
#         print(e['sentence'])
[(eg[0]['index'], eg[0]['sentence'], eg[0]['relational_word']) for index, eg in groups.items() if '/' in eg[0]['relational_word']]

[(2,
  "The trophy doesn't fit into the brown suitcase because [it] is too large.",
  'fit into:large/small'),
 (4,
  'Joan made sure to thank Susan for all the help [she] had recieved.',
  'thank:receive/give'),
 (10,
  'The delivery truck zoomed by the school bus because [it] was going so fast.',
  'zoom by:fast/slow'),
 (12,
  'Frank felt vindicated when his longtime rival Bill revealed that [he] was the winner of the competition.',
  'vindicated/crushed:be the winner'),
 (16,
  'The large ball crashed right through the table because [it] was made of steel.',
  'crash through:[hard]/[soft]'),
 (18,
  "John couldn't see the stage with Billy in front of him because [he] is so short.",
  '[block]:short/tall'),
 (20,
  'Tom threw his schoolbag down to Ray after [he] reached the top of the stairs.',
  'down to:top/bottom'),
 (22,
  'Although they ran at about the same speed, Sue beat Sally because [she] had such a good start.',
  'beat:good/bad'),
 (26,
  "Sam's drawing was hung just abo

In [51]:
sum(['because' in e['sentence'] for e in examples]) + \
sum(['so ' in e['sentence'] for e in examples]) + \
sum(['but ' in e['sentence'] for e in examples]) + \
sum(['though' in e['sentence'] for e in examples])

179

In [73]:
# with open('WSC_switched_label.json', 'w') as f:
#     json.dump(examples, f)

In [12]:
vis_attn_topk = 3

def has_chinese_label(labels):
    labels = [label.split('->')[0].strip() for label in labels]
    r = sum([len(label) > 1 for label in labels if label not in ['BOS', 'EOS']]) * 1. / (len(labels) - 1)
    return 0 < r < 0.5  # r == 0 means empty query labels used in self attention

def _plot_attn(ax1, attn_name, attn, key_labels, query_labels, col, color='b'):
    assert len(query_labels) == attn.size(0)
    assert len(key_labels) == attn.size(1)

    ax1.set_xlim([-1, 1])
    ax1.set_xticks([])
    ax2 = ax1.twinx()
    nlabels = max(len(key_labels), len(query_labels))
    pos = range(nlabels)
    
    if 'self' in attn_name and col < ncols - 1:
        query_labels = ['' for _ in query_labels]

    for ax, labels in [(ax1, key_labels), (ax2, query_labels)]:
        ax.set_yticks(pos)
        if has_chinese_label(labels):
            ax.set_yticklabels(labels, fontproperties=zhfont)
        else:
            ax.set_yticklabels(labels)
        ax.set_ylim([nlabels - 1, 0])
        ax.tick_params(width=0, labelsize='xx-large')

        for spine in ax.spines.values():
            spine.set_visible(False)

#     mask, attn = filter_attn(attn)
    for qi in range(attn.size(0)):
#         if not mask[qi]:
#             continue
#         for ki in range(attn.size(1)):
        for ki in attn[qi].topk(vis_attn_topk)[1]:
            a = attn[qi, ki]
            ax1.plot((-1, 1), (ki, qi), color, alpha=a)
#     print(attn.mean(dim=0).topk(5)[0])
#     ax1.barh(pos, attn.mean(dim=0).data.cpu().numpy())

def plot_layer_attn(result_tuple, attn_name='dec_self_attns', layer=0, heads=None):
    hypo, nheads, labels_dict = result_tuple
    key_labels, query_labels = labels_dict[attn_name]
    if heads is None:
        heads = range(nheads)
    else:
        nheads = len(heads)
    
    stride = 2 if attn_name == 'dec_enc_attns' else 1
    nlabels = max(len(key_labels), len(query_labels))
    rcParams['figure.figsize'] = 20, int(round(nlabels * stride * nheads / 8 * 1.0))
    
    rows = nheads // ncols * stride
    fig, axes = plt.subplots(rows, ncols)
    
    # for head in range(nheads):
    for head_i, head in enumerate(heads):
        row, col = head_i * stride // ncols, head_i * stride % ncols
        ax1 = axes[row, col]
        attn = hypo[attn_name][layer][head]
        _plot_attn(ax1, attn_name, attn, key_labels, query_labels, col)
        if attn_name == 'dec_enc_attns':
            col = col + 1
            axes[row, col].axis('off')  # next subfig acts as blank place holder
    # plt.suptitle('%s with %d heads, Layer %d' % (attn_name, nheads, layer), fontsize=20)
    plt.show()  
            
ncols = 4

In [40]:
config.num

{
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}