In [1]:
### !pip install torch git+https://github.com/martijnvanbeers/transformers@feature/attention-transformers pandas seaborn matplotlib numpy scikit-learn spacy==2.3.7 https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz
#!wget https://raw.githubusercontent.com/martijnvanbeers/nlp-attribution-notebooks/main/firsthalf.txt
#!wget https://raw.githubusercontent.com/martijnvanbeers/nlp-attribution-notebooks/main/valuezeroing.py

In [2]:
import itertools
import numpy
import pandas
import seaborn
import matplotlib.pyplot as plt
import ipywidgets as widgets
import spacy
import torch

from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForSequenceClassification, AutoModelForMaskedLM
)

from valuezeroing import calculate_scores_for_batch

In [3]:
from transformers import AutoModelForSequenceClassification, AutoModel

In [4]:
## GPU
if torch.cuda.is_available():
    device = torch.device("cuda")
    print('We will use the GPU:', torch.cuda.get_device_name("cuda"))
else:
    device = torch.device("cpu")
    print('No GPU available, using the CPU instead.')



We will use the GPU: NVIDIA RTX A4000 Laptop GPU


In [5]:
corpus = pandas.read_csv("firsthalf.txt", sep="\t", header=None, names=["line"])

In [6]:
with pandas.option_context("display.max_colwidth", 200):
    display(corpus.head(10))

Unnamed: 0,line
0,"Pierre Vinken, 61 years old, will join the board as a nonexecutive director Nov. 29."
1,"Mr. Vinken is chairman of Elsevier N.V., the Dutch publishing group."
2,"Rudolph Agnew, 55 years old and former chairman of Consolidated Gold Fields PLC, was named a nonexecutive director of this British industrial conglomerate."
3,"A form of asbestos once used to make Kent cigarette filters has caused a high percentage of cancer deaths among a group of workers exposed to it more than 30 years ago, researchers reported."
4,"The asbestos fiber, crocidolite, is unusually resilient once it enters the lungs, with even brief exposures to it causing symptoms that show up decades later, researchers said."
5,"Lorillard Inc., the unit of New York-based Loews Corp. that makes Kent cigarettes, stopped using crocidolite in its Micronite cigarette filters in 1956."
6,"Although preliminary findings were reported more than a year ago, the latest results appear in today's New England Journal of Medicine, a forum likely to bring new attention to the problem."
7,"A Lorillard spokewoman said, 'This is an old story."
8,We're talking about years ago before anyone heard of asbestos having any questionable properties.
9,There is no asbestos in our products now.'


In [7]:
class TransformerTokenizer:
    def __init__(self, vocab, tokenizer, transformer="bert"):
        self.vocab = vocab
        self._tokenizer = tokenizer
        self._transformer = transformer

    def __call__(self, text):
        result = self._tokenizer._tokenizer.encode(text)
        words = []
        spaces = []
        if self._transformer in ["bert", "roberta"]:
            word_ids = result.word_ids[1:-1]
            offset_skip = 1
        elif self._transformer in ["gpt"]:
            word_ids = result.word_ids
            offset_skip = 0
        for wordix,g in itertools.groupby(zip(range(len(word_ids)), word_ids), key=lambda t: t[1]):
            g = list(g)
            first_token = g[0][0]
            last_token = g[-1][0]
            start = result.offsets[first_token+offset_skip][0]
            if text[start] == ' ':
                start+= 1
            end = result.offsets[last_token+offset_skip][1]
#            print(start, end, f"'{text[start:end]}'")
            words.append(text[start:end])
            if wordix < max(word_ids):
                # If next start != current end we assume a space in between
                next_start, next_end = result.offsets[last_token + offset_skip + 1]
                spaces.append(next_start > end)
            else:
                if end < len(text):
                    spaces.append(True)
                else:
                    spaces.append(False)
        return spacy.tokens.Doc(self.vocab, words=words, spaces=spaces)

In [8]:
#family = "bert"
#transformer = "bert-base-uncased"
family = "gpt"
transformer = "gpt2"
config = AutoConfig.from_pretrained(transformer, output_attentions=True)#, attentions_with_qk=True)
tokenizer = AutoTokenizer.from_pretrained(transformer)
model = AutoModel.from_pretrained(transformer, config=config)
model.to(device)
model.eval()

nlp = spacy.load("en_core_web_lg")
nlp.tokenizer = TransformerTokenizer(nlp.vocab, tokenizer, transformer=family)

In [9]:
#poslist = ["[CLS]", "[SEP]", "CCONJ", "PROPN", "PRON", "AUX", "VERB", "ADP", "NOUN", "SYM", "NUM", "DET", "PUNCT"]
poslist = [
    "[SELF]",
    "[CONTINUATION]",
    "[CLS]",
    "[SEP]",
#    "",
    "ADJ",
    "ADP",
    "ADV",
    "AUX",
    "CONJ",
    "CCONJ",
    "DET",
    "INTJ",
    "NOUN",
    "NUM",
    "PART",
    "PRON",
    "PROPN",
    "PUNCT",
    "SCONJ",
    "SYM",
    "VERB",
    "X",
    "EOL",
    "SPACE",
]

In [10]:
len(poslist)

24

In [11]:
df = None
for ix, sent in enumerate(corpus.head(5)['line']):
    result = tokenizer(sent, add_special_tokens=False, return_offsets_mapping=True)
    doc = nlp(sent)
    docpos = [doc[t].pos_ for t in result.word_ids()]
    tokens = tokenizer.convert_ids_to_tokens(result['input_ids'])
    words = [doc[t].text for t in result.word_ids()]
    sent_df = pandas.DataFrame(dict(sent_id=ix, input_id=result['input_ids'], pos=docpos, token=tokens, word=words, word_id=result.word_ids()))
    if df is None:
        df = sent_df
    else:
        df = pandas.concat([df, sent_df])
df = df.reset_index(drop=True)


In [12]:
window_size = 50
stride = 25
future = 0

mask = torch.cat((torch.zeros(window_size - (stride + future)), torch.ones(stride), torch.zeros(future))).to(torch.int64)
print(mask)

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1])


In [13]:
pad_len = window_size - df.shape[0] % window_size

In [14]:
def dataloader_from_dataframe(input_df:pandas.DataFrame, window_size: int, stride: int, batch_size :int):
    input_ids = torch.as_tensor(input_df['input_id'].to_numpy(dtype="int64"))
    pad_len = window_size - input_ids.shape[0] % window_size
    padded_input_ids = (torch
            .cat(
                    (input_ids, input_ids.new_full((pad_len,), tokenizer.eos_token_id))
                )
            .unfold(0, window_size, stride)
        )
    padded_input_mask = (
            torch
                .cat(
                        (torch.ones_like(input_ids), torch.zeros(pad_len,dtype=torch.int64))
                    )
                .unfold(0, window_size, stride)
        )
    tensor_dataset = torch.utils.data.TensorDataset(padded_input_ids, padded_input_mask)
    tensor_dataloader = torch.utils.data.DataLoader(tensor_dataset, batch_size=batch_size)
    
    return tensor_dataloader

In [15]:
batch_size=2

In [16]:
dl = dataloader_from_dataframe(df, window_size, stride, batch_size=batch_size)

In [17]:
ds = dl.dataset

In [18]:
torch.save(ds, 'test.pt')

In [22]:
%%time
combined_df = None
for i, (input_ids, input_mask) in enumerate(dl):
    valuezeroing_scores, rollout_valuezeroing_scores, attentions = calculate_scores_for_batch(config, model, family, input_ids, input_mask)
    print(valuezeroing_scores.shape, attentions.shape)
    for n in numpy.arange(input_mask.shape[0]):
        sent_mask = numpy.argwhere(numpy.multiply(input_mask[n], mask)).flatten()
        if sent_mask.sum() == 0:
            # the combination of padding masking and stride masking leaves nothing
            continue
        f = sent_mask.min()
        l = sent_mask.max() + 1
        # skip the initial stride
        m = n + 1
        start = i*batch_size*stride + m*stride
        print(batch_size,i,m, start,start+stride)
        all_tokens = df.iloc[start:start+stride, 3]
        index = pandas.MultiIndex.from_product(
            [numpy.arange(12)+1, numpy.arange(12)+1, all_tokens, all_tokens],
            names=['layer','head','from', 'to']
        )
        scores_matrix = valuezeroing_scores[:,n,:,f:l,f:l]
        rollout_matrix = rollout_valuezeroing_scores[:,n,:,f:l,f:l]
        att_matrix = attentions.detach().cpu().numpy()[:,n,:,f:l,f:l]
        score_df = pandas.DataFrame(
                numpy.hstack([
                        scores_matrix.reshape(-1, 1),
                        rollout_matrix.reshape(-1,1),
                        att_matrix.reshape(-1,1)
                    ]),
                index=index,
                columns=["valuezeroing", "rollout_vz", "raw_attention"]
            ).reset_index()
        score_df['from_pos'] = pandas.Categorical(
            numpy.tile(numpy.repeat(df.iloc[all_tokens.index, 2], all_tokens.shape[0]), 12*12),
            categories=poslist
        )
        score_df['to_pos'] = pandas.Categorical(
            numpy.tile(numpy.array(df.iloc[all_tokens.index, 2]), all_tokens.shape[0]*12*12),
            categories=poslist
        )
        score_df['from_ix'] = numpy.tile(numpy.repeat(all_tokens.index, all_tokens.shape[0]), 12*12)
        score_df['from_word'] = numpy.tile(numpy.repeat(df.iloc[all_tokens.index,4], len(all_tokens)), 12*12)
        score_df['to_ix'] = numpy.tile(all_tokens.index, all_tokens.shape[0]*12*12)
        score_df['to_word'] = numpy.tile(df.iloc[all_tokens.index,4], all_tokens.shape[0]*12*12)
        score_df['to_pos'] = score_df.apply(lambda r: "[SELF]" if r['from_ix'] == r['to_ix'] else r['to_pos'], axis=1)
        score_df['to_pos'] = score_df.apply(lambda r: "[CONTINUATION]" if r['to_pos'] != "[SELF]" and r['from_word'] is not None and r['to_word'] is not None and r['from_word' ] == r['to_word'] else r['to_pos'], axis=1)
        score_df['stride'] = i*batch_size+m
        counts = ((score_df[(score_df['layer'] == 1) & (score_df['head'] == 1)]
                        .groupby(["from_pos", "to_pos"])
                        .agg({"from": "count"}))
                        .rename(columns={'from': 'combo_count'})
                        .reset_index()
                )
        score_df = score_df.merge(counts, how="left", on=["from_pos", "to_pos"])

        if combined_df is None:
            combined_df = score_df
        else:
            combined_df = pandas.concat([combined_df, score_df])
combined_df = combined_df[combined_df['from_ix'] >= combined_df['to_ix']]

(12, 2, 12, 50, 50) torch.Size([12, 2, 12, 50, 50])
2 0 1 25 50
2 0 2 50 75
(12, 2, 12, 50, 50) torch.Size([12, 2, 12, 50, 50])
2 1 1 75 100
2 1 2 100 125
(12, 1, 12, 50, 50) torch.Size([12, 1, 12, 50, 50])
2 2 1 125 150
CPU times: user 3min 46s, sys: 1.98 s, total: 3min 48s
Wall time: 33.9 s


In [23]:
numpy.multiply(input_mask[-1], mask)

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
        0, 0])

In [24]:
combined_df['raw_attention'].describe()

count    2.236320e+05
mean     2.670984e-02
std      7.516092e-02
min      6.863768e-34
25%      1.082296e-03
50%      5.933533e-03
75%      2.266469e-02
max      1.000000e+00
Name: raw_attention, dtype: float64

In [25]:
combined_df.head(50)

Unnamed: 0,layer,head,from,to,valuezeroing,rollout_vz,raw_attention,from_pos,to_pos,from_ix,from_word,to_ix,to_word,stride,combo_count
0,1,1,ĠV,ĠV,0.051343,0.051343,0.048282,PROPN,[SELF],25,Vinken,25,Vinken,1,12
25,1,1,ink,ĠV,0.157899,0.157899,0.084443,PROPN,[CONTINUATION],26,Vinken,25,Vinken,1,16
26,1,1,ink,ink,0.01681,0.01681,0.030021,PROPN,[SELF],26,Vinken,26,Vinken,1,12
50,1,1,en,ĠV,0.088908,0.088908,0.06437,PROPN,[CONTINUATION],27,Vinken,25,Vinken,1,16
51,1,1,en,ink,0.021178,0.021178,0.03232,PROPN,[CONTINUATION],27,Vinken,26,Vinken,1,16
52,1,1,en,en,0.032921,0.032921,0.027646,PROPN,[SELF],27,Vinken,27,Vinken,1,12
75,1,1,Ġis,ĠV,0.009204,0.009204,0.032582,AUX,PROPN,28,is,25,Vinken,1,12
76,1,1,Ġis,ink,0.008154,0.008154,0.032257,AUX,PROPN,28,is,26,Vinken,1,12
77,1,1,Ġis,en,0.004447,0.004447,0.016931,AUX,PROPN,28,is,27,Vinken,1,12
78,1,1,Ġis,Ġis,0.006857,0.006857,0.022232,AUX,[SELF],28,is,28,is,1,1


In [26]:
combined_df[combined_df['stride'] == 1].groupby(["layer", "head", "from_ix"]).agg({'raw_attention': ["count", "sum"]})

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,raw_attention,raw_attention
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,count,sum
layer,head,from_ix,Unnamed: 3_level_2,Unnamed: 4_level_2
1,1,25,1,0.048282
1,1,26,2,0.114464
1,1,27,3,0.124336
1,1,28,4,0.104001
1,1,29,5,0.167633
...,...,...,...,...
12,12,45,21,0.013342
12,12,46,22,0.273854
12,12,47,23,0.128829
12,12,48,24,0.865304


In [27]:
combined_df[combined_df['stride'] == 1]

Unnamed: 0,layer,head,from,to,valuezeroing,rollout_vz,raw_attention,from_pos,to_pos,from_ix,from_word,to_ix,to_word,stride,combo_count
0,1,1,ĠV,ĠV,0.051343,0.051343,0.048282,PROPN,[SELF],25,Vinken,25,Vinken,1,12
25,1,1,ink,ĠV,0.157899,0.157899,0.084443,PROPN,[CONTINUATION],26,Vinken,25,Vinken,1,16
26,1,1,ink,ink,0.016810,0.016810,0.030021,PROPN,[SELF],26,Vinken,26,Vinken,1,12
50,1,1,en,ĠV,0.088908,0.088908,0.064370,PROPN,[CONTINUATION],27,Vinken,25,Vinken,1,16
51,1,1,en,ink,0.021178,0.021178,0.032320,PROPN,[CONTINUATION],27,Vinken,26,Vinken,1,16
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
89995,12,12,Ġyears,ĠAg,0.020031,0.000478,0.014977,NOUN,PROPN,49,years,45,Agnew,1,48
89996,12,12,Ġyears,new,0.020071,0.000865,0.016337,NOUN,PROPN,49,years,46,Agnew,1,48
89997,12,12,Ġyears,",",0.020048,0.000566,0.011024,NOUN,PUNCT,49,years,47,",",1,16
89998,12,12,Ġyears,Ġ55,0.020050,0.000727,0.011241,NOUN,NUM,49,years,48,55,1,4


In [28]:
g = (combined_df
     .groupby(["layer", "head", "from_pos", "to_pos"])
     .agg({
             "raw_attention": lambda n: numpy.sum(n) / (stride * ((numpy.ceil(df.shape[0]/stride)-1))),
             "valuezeroing": lambda n: numpy.sum(n) / (stride * ((numpy.ceil(df.shape[0]/stride)-1))),
             "rollout_vz": lambda n: numpy.sum(n) / (stride * ((numpy.ceil(df.shape[0]/stride)-1))),
         })
     .dropna()
     .reset_index())
 

In [29]:
g

Unnamed: 0,layer,head,from_pos,to_pos,raw_attention,valuezeroing,rollout_vz
0,1,1,ADJ,ADJ,0.003360,0.003326,0.003326
1,1,1,ADJ,ADP,0.001954,0.002732,0.002732
2,1,1,ADJ,ADV,0.000821,0.000486,0.000486
3,1,1,ADJ,AUX,0.001110,0.001052,0.001052
4,1,1,ADJ,CCONJ,0.001134,0.001364,0.001364
...,...,...,...,...,...,...,...
25051,12,12,VERB,PROPN,0.004022,0.001290,0.000353
25052,12,12,VERB,PUNCT,0.002994,0.000956,0.000309
25053,12,12,VERB,SCONJ,0.000010,0.000160,0.000001
25054,12,12,VERB,VERB,0.000509,0.001922,0.000725


In [30]:
g['raw_attention'].describe()

count    2.505600e+04
mean     1.907144e-03
std      5.659381e-03
min      7.556258e-25
25%      7.020152e-05
50%      4.094015e-04
75%      1.636672e-03
max      1.923689e-01
Name: raw_attention, dtype: float64

In [31]:
n = 0
for _, gr in combined_df.groupby(["layer", "head", "from_pos", "to_pos"]):
    display(gr)
    n += 1
    if n > 5:
        break

Unnamed: 0,layer,head,from,to,valuezeroing,rollout_vz,raw_attention,from_pos,to_pos,from_ix,from_word,to_ix,to_word,stride,combo_count
50,1,1,Ġformer,Ġold,0.048961,0.048961,0.055692,ADJ,ADJ,52,former,50,old,2,36
375,1,1,Ġnonex,Ġold,0.001802,0.001802,0.010431,ADJ,ADJ,65,nonexecutive,50,old,2,36
377,1,1,Ġnonex,Ġformer,0.010113,0.010113,0.022981,ADJ,ADJ,65,nonexecutive,52,former,2,36
400,1,1,ec,Ġold,0.000498,0.000498,0.005523,ADJ,ADJ,66,nonexecutive,50,old,2,36
402,1,1,ec,Ġformer,0.001244,0.001244,0.005679,ADJ,ADJ,66,nonexecutive,52,former,2,36
425,1,1,utive,Ġold,0.000887,0.000887,0.009594,ADJ,ADJ,67,nonexecutive,50,old,2,36
427,1,1,utive,Ġformer,0.001773,0.001773,0.0103,ADJ,ADJ,67,nonexecutive,52,former,2,36
525,1,1,ĠBritish,Ġold,0.001359,0.001359,0.009886,ADJ,ADJ,71,British,50,old,2,36
527,1,1,ĠBritish,Ġformer,0.020154,0.020154,0.029166,ADJ,ADJ,71,British,52,former,2,36
540,1,1,ĠBritish,Ġnonex,0.082088,0.082088,0.04701,ADJ,ADJ,71,British,65,nonexecutive,2,36


Unnamed: 0,layer,head,from,to,valuezeroing,rollout_vz,raw_attention,from_pos,to_pos,from_ix,from_word,to_ix,to_word,stride,combo_count
330,1,1,ĠDutch,Ġof,0.013331,0.013331,0.01602,ADJ,ADP,38,Dutch,30,of,1,1
379,1,1,Ġnonex,Ġof,0.05387,0.05387,0.034978,ADJ,ADP,65,nonexecutive,54,of,2,14
404,1,1,ec,Ġof,0.00846,0.00846,0.012462,ADJ,ADP,66,nonexecutive,54,of,2,14
429,1,1,utive,Ġof,0.022166,0.022166,0.021126,ADJ,ADP,67,nonexecutive,54,of,2,14
529,1,1,ĠBritish,Ġof,0.007133,0.007133,0.009724,ADJ,ADP,71,British,54,of,2,14
544,1,1,ĠBritish,Ġof,0.020041,0.020041,0.016176,ADJ,ADP,71,British,69,of,2,14
554,1,1,Ġindustrial,Ġof,0.02685,0.02685,0.020725,ADJ,ADP,72,industrial,54,of,2,14
569,1,1,Ġindustrial,Ġof,0.071872,0.071872,0.032753,ADJ,ADP,72,industrial,69,of,2,14
352,1,1,Ġhigh,Ġof,0.000728,0.000728,0.006531,ADJ,ADP,89,high,77,of,3,4
50,1,1,Ġmore,Ġto,0.018948,0.018948,0.017726,ADJ,ADP,102,more,100,to,4,2


Unnamed: 0,layer,head,from,to,valuezeroing,rollout_vz,raw_attention,from_pos,to_pos,from_ix,from_word,to_ix,to_word,stride,combo_count
354,1,1,Ġhigh,Ġonce,0.006849,0.006849,0.022627,ADJ,ADV,89,high,79,once,3,1
581,1,1,Ġresilient,Ġago,0.00247,0.00247,0.011684,ADJ,ADV,123,resilient,106,ago,4,4
597,1,1,Ġresilient,Ġunusually,0.017999,0.017999,0.039437,ADJ,ADV,123,resilient,122,unusually,4,4
160,1,1,Ġbrief,Ġeven,0.033472,0.033472,0.028916,ADJ,ADV,132,brief,131,even,5,2


Unnamed: 0,layer,head,from,to,valuezeroing,rollout_vz,raw_attention,from_pos,to_pos,from_ix,from_word,to_ix,to_word,stride,combo_count
328,1,1,ĠDutch,Ġis,0.004561,0.004561,0.009763,ADJ,AUX,38,Dutch,28,is,1,1
387,1,1,Ġnonex,Ġwas,0.067388,0.067388,0.043833,ADJ,AUX,65,nonexecutive,62,was,2,7
412,1,1,ec,Ġwas,0.023389,0.023389,0.022541,ADJ,AUX,66,nonexecutive,62,was,2,7
437,1,1,utive,Ġwas,0.00152,0.00152,0.006158,ADJ,AUX,67,nonexecutive,62,was,2,7
537,1,1,ĠBritish,Ġwas,0.01404,0.01404,0.016251,ADJ,AUX,71,British,62,was,2,7
562,1,1,Ġindustrial,Ġwas,0.011842,0.011842,0.014975,ADJ,AUX,72,industrial,62,was,2,7
361,1,1,Ġhigh,Ġhas,0.002911,0.002911,0.012935,ADJ,AUX,89,high,86,has,3,1
596,1,1,Ġresilient,Ġis,0.005823,0.005823,0.012265,ADJ,AUX,123,resilient,121,is,4,2


Unnamed: 0,layer,head,from,to,valuezeroing,rollout_vz,raw_attention,from_pos,to_pos,from_ix,from_word,to_ix,to_word,stride,combo_count
51,1,1,Ġformer,Ġand,0.04502,0.04502,0.030238,ADJ,CCONJ,52,former,51,and,2,7
376,1,1,Ġnonex,Ġand,0.047662,0.047662,0.033901,ADJ,CCONJ,65,nonexecutive,51,and,2,7
401,1,1,ec,Ġand,0.028738,0.028738,0.023827,ADJ,CCONJ,66,nonexecutive,51,and,2,7
426,1,1,utive,Ġand,0.024066,0.024066,0.023905,ADJ,CCONJ,67,nonexecutive,51,and,2,7
526,1,1,ĠBritish,Ġand,0.008605,0.008605,0.011927,ADJ,CCONJ,71,British,51,and,2,7
551,1,1,Ġindustrial,Ġand,0.016415,0.016415,0.018013,ADJ,CCONJ,72,industrial,51,and,2,7


Unnamed: 0,layer,head,from,to,valuezeroing,rollout_vz,raw_attention,from_pos,to_pos,from_ix,from_word,to_ix,to_word,stride,combo_count
337,1,1,ĠDutch,Ġthe,0.07104,0.07104,0.031193,ADJ,DET,38,Dutch,37,the,1,1
389,1,1,Ġnonex,Ġa,0.144187,0.144187,0.04906,ADJ,DET,65,nonexecutive,64,a,2,14
414,1,1,ec,Ġa,0.044414,0.044414,0.024729,ADJ,DET,66,nonexecutive,64,a,2,14
439,1,1,utive,Ġa,0.207346,0.207346,0.054893,ADJ,DET,67,nonexecutive,64,a,2,14
539,1,1,ĠBritish,Ġa,0.028419,0.028419,0.0165,ADJ,DET,71,British,64,a,2,14
545,1,1,ĠBritish,Ġthis,0.084013,0.084013,0.032498,ADJ,DET,71,British,70,this,2,14
564,1,1,Ġindustrial,Ġa,0.048306,0.048306,0.023207,ADJ,DET,72,industrial,64,a,2,14
570,1,1,Ġindustrial,Ġthis,0.029781,0.029781,0.019465,ADJ,DET,72,industrial,70,this,2,14
350,1,1,Ġhigh,A,0.003424,0.003424,0.014181,ADJ,DET,89,high,75,A,3,3
363,1,1,Ġhigh,Ġa,0.00565,0.00565,0.013151,ADJ,DET,89,high,88,a,3,3


In [32]:
def show_head(ignores=[], sortby="valuezeroing", layer=1, head=1, top_n=5):
    am = {
        'raw_attention': "adjusted_attention", 
        'valuezeroing': "adjusted_vz",
        'rollout_vz': "adjusted_rollout_vz",
    }
    display(g[~g['from_pos'].isin(ignores) & ~g['to_pos'].isin(ignores) & (g['layer'] == layer) & (g['head'] == head)].sort_values(sortby, ascending=False).head(top_n))
    #display(ga[~ga['from_pos'].isin(ignores) & ~ga['to_pos'].isin(ignores) & (ga['layer'] == layer) & (ga['head'] == head)].sort_values(am[sortby], ascending=False).head(top_n))

In [33]:
w = widgets.interactive(show_head,
                ignores=widgets.SelectMultiple(
                        options=poslist,
                        value=['[CLS]', '[SEP]'],
                        description='Ignored POS',
                        rows=25,
                        disabled=False
                    ),
                sortby=widgets.RadioButtons(
                        options=['raw_attention', 'valuezeroing', 'rollout_vz'],
                        value='valuezeroing',
                        layout={'width': 'max-content'}, # If the items' names are long
                        description='sort by',
                    ),
                layer=widgets.IntSlider(min=1, max=12, value=1, step=1),
                head=widgets.IntSlider(min=1, max=12, value=1, step=1),
                top_n=widgets.IntSlider(min=3, max=20, value=10, step=1)
            )
display(w)

interactive(children=(SelectMultiple(description='Ignored POS', index=(2, 3), options=('[SELF]', '[CONTINUATIO…

In [34]:
def show_combo(from_pos, to_pos, sortby):
    with pandas.option_context("display.max_rows", 150):
        display(
            pandas.concat([
                g[(g['from_pos'] == from_pos) & (g['to_pos'] == to_pos)],
                #ga[(ga['from_pos'] == from_pos) & (ga['to_pos'] == to_pos)][['adjusted_attention', 'adjusted_vz', 'adjusted_rollout_vz']]
            ], axis=1).reset_index(drop=True).sort_values(sortby, ascending=False)
        )

In [35]:
w = widgets.interactive(show_combo,
                from_pos=widgets.Select(
                        options=poslist,
                        value='NOUN',
                    ),
                to_pos=widgets.Select(
                        options=poslist,
                        value='NOUN',
                    ),
                sortby=widgets.RadioButtons(
                        options=['raw_attention', 'valuezeroing', 'rollout_vz'],#, 'adjusted_attention', 'adjusted_vz', 'adjusted_rollout_vz'],
                        value='valuezeroing',
                        layout={'width': 'max-content'}, # If the items' names are long
                        description='sort by',
                    ),

            )
display(w)

interactive(children=(Select(description='from_pos', index=12, options=('[SELF]', '[CONTINUATION]', '[CLS]', '…