In [1]:
#@title foo
#!pip install transformers==4.1.1 plotnine

## Setting stuff up

In [2]:
import re
import itertools
from functools import partial
import copy

import numpy
import pandas

from IPython.display import HTML
import seaborn
import matplotlib
import ipywidgets as widgets

from ahviz import create_indices, create_dataframe, filter_mask
import torch
import datasets
from transformers import AutoModel, AutoTokenizer, AutoConfig
from valuezeroing import calculate_scores_for_batch

In [3]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
# uncomment to force CPU if you have a GPU but not enough memory to do what you want. it will be slow of course
#device = torch.device("cpu")

In [5]:
## transformer = "distilbert-base-cased"
#transformer = "bert-base-cased"
#transformer = "gpt2"
#transformer = "gpt2-medium"
#transformer = "gpt2-large"
#transformer = "twmkn9/bert-base-uncased-squad2"
family = "gpt"
transformer = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(transformer, add_prefix_space=True)
# gpt2 doesn't do padding, so invent a padding token
# this one was suggested by the error you get when trying
# to do masking below, but it shouldn't matter as the actual
# tokens get ignored by the attention mask anyway
if family == "gpt":
    tokenizer.pad_token = tokenizer.eos_token

config = AutoConfig.from_pretrained(transformer, output_attentions=True)
model = AutoModel.from_pretrained(transformer, config=config)
model.to(device)
model.eval()


GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0): GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (1): GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP

## data preparation

Read in the prepared data. Included in the repository is a copy of the penn treebank sample that is included in the `nltk` python package, converted into plain text and split into sentences. But you can replace this with any
text file. Since the first thing we do is join all the text, it isn't even neccessary to split it into sentences.

The script I used to create the file is `convert_corpus.py` in the repository

In [6]:
pron_map = {
    'he': "she",
    'she': "he",
    'his': "her",
    'her': "his",
    'himself': "herself",
    'herself': "himself",
}

bug = (datasets
       .load_dataset(
                'csv', 
               data_files={
                   'full': "../BUG/data/full_BUG.csv",
#                   'balanced': "../BUG/data/balanced_BUG.csv"
               }
            )
       .remove_columns(['Unnamed: 0', 'sentence_text', 'predicted gender', 'stereotype', 'data_index'])
       .rename_column("tokens", "tokens_left")
       .map(lambda example: {"tokens_left": [t[1:-1] for t in example['tokens_left'][1:-1].split(", ")]}, desc="split token list")
       .map(lambda example: {"tokens_right": [pron_map[t] if t in pron_map else t for t in example['tokens_left']]}, desc="add opposite sentence")
    )

In [7]:
bf = bug['full']

In [8]:
bf = bf.select(range(10))

In [9]:
def encode(examples, left="sent_more", right="sent_less"):
    result = {}
    count = len(examples[left])
    alternating = list(itertools.chain.from_iterable(zip(examples[left], examples[right])))
    result = tokenizer(text=alternating, padding=True, is_split_into_words=True)
    result['word_ids']  = [result.word_ids(n) for n in range(count*2)]
    result['tokens'] = [tokenizer.convert_ids_to_tokens(v) for v in result['input_ids']]
    result['token_ix'] = [list(range(len(v))) for v in result['input_ids']]
    result['side'] = ['more', 'less'] * count
    return result
#encoded_dataset = crowS.map(encode, batched=True, remove_columns=['sent_more', 'sent_less', 'stereo_antistereo', 'bias_type'])
encoder = partial(encode, left="tokens_left", right="tokens_right")
encoded_dataset = bf.map(encoder, batched=True, remove_columns=['tokens_left', 'profession', 'g', 'profession_first_index', 'g_first_index', 'distance', 'num_of_pronouns', 'corpus', 'tokens_right'])


Map:   0%|          | 0/10 [00:00<?, ? examples/s]

In [10]:
encoded_df = encoded_dataset.to_pandas().explode(['input_ids', 'attention_mask', 'word_ids', 'token_ix', 'tokens']).reset_index(names="example")
encoded_df['example'] //= 2
encoded_df = encoded_df[encoded_df['attention_mask'] == 1]
display(encoded_df.dropna()[['side', 'example', 'token_ix', 'tokens', 'word_ids']])

Unnamed: 0,side,example,token_ix,tokens,word_ids
0,more,0,0,ĠPatient,0.0
1,more,0,1,Ġnumber,1.0
2,more,0,2,Ġ2,2.0
3,more,0,3,Ġwas,3.0
4,more,0,4,Ġisolated,4.0
...,...,...,...,...,...
1125,less,9,23,Ġ30,23.0
1126,less,9,24,ĠO,24.0
1127,less,9,25,HC,24.0
1128,less,9,26,Ws,24.0


In [11]:
with pandas.option_context('display.max_rows', None):
    display(encoded_df[encoded_df['example'] == 2])

Unnamed: 0,example,input_ids,attention_mask,word_ids,tokens,token_ix,side
232,2,1881,1,0.0,ĠOne,0,more
233,2,5827,1,1.0,Ġpatient,1,more
234,2,3025,1,2.0,Ġwhose,2,more
235,2,277,1,3.0,Ġf,3,more
236,2,292,1,3.0,as,4,more
237,2,2413,1,3.0,cial,5,more
238,2,11685,1,4.0,Ġlayers,6,more
239,2,547,1,5.0,Ġwere,7,more
240,2,4838,1,6.0,Ġclosed,8,more
241,2,287,1,7.0,Ġin,9,more


In [13]:
df = None
with encoded_dataset.formatted_as(type='torch', columns=['input_ids', 'attention_mask'], device="cuda"):
    bs = 50
    dataloader = torch.utils.data.DataLoader(encoded_dataset, batch_size=bs)
    for i, batch_data in enumerate(dataloader):
        attention_mask = batch_data['attention_mask'].cpu().numpy()

        print(attention_mask.shape)
        valuezeroing_scores, rollout_valuezeroing_scores, attentions = calculate_scores_for_batch(
                config,
                model,
                family,
                batch_data['input_ids'],
                batch_data['attention_mask']
            )
        batch_df = None
        for n in numpy.arange(attention_mask.shape[0]):
            tok_range = numpy.argwhere(attention_mask[n]).flatten()
            f = tok_range.min()
            l = tok_range.max() + 1
            index = pandas.MultiIndex.from_product(
                    [numpy.arange(12)+1, numpy.arange(12)+1, tok_range, tok_range],
                    names=['layer','head','from_ix', 'to_ix']
                )
            scores_matrix = valuezeroing_scores[:,n,:,f:l, f:l]
            att_matrix = attentions.detach().cpu().numpy()[:,n,:,f:l,f:l]
            score_df = pandas.DataFrame(
                    numpy.hstack([
                            scores_matrix.reshape(-1, 1),
                            att_matrix.reshape(-1,1)
                        ]),
                    index=index,
                    columns=["raw_vz", "raw_attention"]
                ).reset_index()
            print(i, n, n%2, i*(bs//2) + n//2, "less" if n%2 else "more")
            score_df['example'] = i*(bs//2) + n//2
            score_df['side'] = "less" if n%2 else "more"
            if batch_df is None:
                batch_df = score_df
            else:
                batch_df = pandas.concat([batch_df, score_df])
        if df is None:
            df = batch_df
            print("NEW")
        else:
            df = pandas.concat([df, batch_df])
            print("ADD")

(20, 58)
0 0 0 0 more
0 1 1 0 less
0 2 0 1 more
0 3 1 1 less
0 4 0 2 more
0 5 1 2 less
0 6 0 3 more
0 7 1 3 less
0 8 0 4 more
0 9 1 4 less
0 10 0 5 more
0 11 1 5 less
0 12 0 6 more
0 13 1 6 less
0 14 0 7 more
0 15 1 7 less
0 16 0 8 more
0 17 1 8 less
0 18 0 9 more
0 19 1 9 less
NEW


In [14]:
left_merge = (encoded_df
        .dropna()[['side', 'example', 'token_ix', 'tokens', 'word_ids']]
        .rename(columns={
                'token_ix': "from_ix",
                'tokens': "from_token",
                'word_ids': "from_word_id",
            })
    )
right_merge = (encoded_df
        .dropna()[['side', 'example', 'token_ix', 'tokens', 'word_ids']]
        .rename(columns={
                'token_ix': "to_ix",
                'tokens': "to_token",
                'word_ids': "to_word_id",
            })
    )

complete = df.merge(left_merge, how="left", on=["example", "side", "from_ix"])
complete = complete.merge(right_merge, how="left", on=["example", "side", "to_ix"])
display(complete.sort_values(['layer', 'head', 'example', 'from_ix', 'to_ix', 'side']))

Unnamed: 0,layer,head,from_ix,to_ix,raw_vz,raw_attention,example,side,from_token,from_word_id,to_token,to_word_id
129600,1,1,0,0,1.000000,1.000000,0,less,ĠPatient,0.0,ĠPatient,0.0
0,1,1,0,0,1.000000,1.000000,0,more,ĠPatient,0.0,ĠPatient,0.0
129601,1,1,0,1,0.000000,0.000000,0,less,ĠPatient,0.0,Ġnumber,1.0
1,1,1,0,1,0.000000,0.000000,0,more,ĠPatient,0.0,Ġnumber,1.0
129602,1,1,0,2,0.000000,0.000000,0,less,ĠPatient,0.0,Ġ2,2.0
...,...,...,...,...,...,...,...,...,...,...,...,...
3552765,12,12,27,25,0.017225,0.010427,9,more,Ġ.,25.0,HC,24.0
3665662,12,12,27,26,0.017263,0.028149,9,less,Ġ.,25.0,Ws,24.0
3552766,12,12,27,26,0.017263,0.029191,9,more,Ġ.,25.0,Ws,24.0
3665663,12,12,27,27,0.016923,0.024015,9,less,Ġ.,25.0,Ġ.,25.0


In [15]:
example = 2

In [16]:
list(itertools.takewhile(lambda v: v, [m if m == l else None for m, l in zip(complete[(complete['example'] == example) & (complete['side'] == "more") & (complete['layer'] == 1) & (complete['head'] == 1) & (complete['to_ix'] == 0)]['from_token'], complete[(complete['example'] == example) & (complete['side'] == "less") & (complete['layer'] == 1) & (complete['head'] == 1) & (complete['to_ix'] == 0)]['from_token'])])) + ["[...]"] + list(reversed(list(itertools.takewhile(lambda v: v, [m if m == l else None for m, l in zip(complete[(complete['example'] == example) & (complete['side'] == "more") & (complete['layer'] == 1) & (complete['head'] == 1) & (complete['to_ix'] == 0)]['from_token'][::-1], complete[(complete['example'] == example) & (complete['side'] == "less") & (complete['layer'] == 1) & (complete['head'] == 1) & (complete['to_ix'] == 0)]['from_token'][::-1])]))))


['ĠOne',
 'Ġpatient',
 'Ġwhose',
 'Ġf',
 'as',
 'cial',
 'Ġlayers',
 'Ġwere',
 'Ġclosed',
 'Ġin',
 'Ġinterrupted',
 'Ġs',
 'ut',
 'ures',
 'Ġdeveloped',
 'Ġher',
 'nia',
 'Ġ11',
 'Ġmonths',
 'Ġafter',
 '[...]',
 'Ġoperation',
 'Ġ.']

In [17]:
complete[(complete['layer'] == 1) & (complete['head'] == 1) & (complete['side'] == 'less') & (complete['example'] == 2) & (complete['from_ix'] == 15)]

Unnamed: 0,layer,head,from_ix,to_ix,raw_vz,raw_attention,example,side,from_token,from_word_id,to_token,to_word_id
530409,1,1,15,0,0.030979,0.105512,2,less,Ġher,11.0,ĠOne,0.0
530410,1,1,15,1,0.086449,0.079994,2,less,Ġher,11.0,Ġpatient,1.0
530411,1,1,15,2,0.027553,0.048121,2,less,Ġher,11.0,Ġwhose,2.0
530412,1,1,15,3,0.020847,0.040854,2,less,Ġher,11.0,Ġf,3.0
530413,1,1,15,4,0.003936,0.016472,2,less,Ġher,11.0,as,3.0
530414,1,1,15,5,0.024637,0.047417,2,less,Ġher,11.0,cial,3.0
530415,1,1,15,6,0.068955,0.102268,2,less,Ġher,11.0,Ġlayers,4.0
530416,1,1,15,7,0.015745,0.034376,2,less,Ġher,11.0,Ġwere,5.0
530417,1,1,15,8,0.016109,0.054848,2,less,Ġher,11.0,Ġclosed,6.0
530418,1,1,15,9,0.011371,0.022675,2,less,Ġher,11.0,Ġin,7.0


In [18]:
example = 2

In [19]:
bf[example]

{'tokens_left': ['One',
  'patient',
  'whose',
  'fascial',
  'layers',
  'were',
  'closed',
  'in',
  'interrupted',
  'sutures',
  'developed',
  'hernia',
  '11',
  'months',
  'after',
  'her',
  'operation',
  '.'],
 'profession': 'patient',
 'g': 'her',
 'profession_first_index': 1,
 'g_first_index': 15,
 'distance': 14,
 'num_of_pronouns': 1,
 'corpus': 'covid19',
 'tokens_right': ['One',
  'patient',
  'whose',
  'fascial',
  'layers',
  'were',
  'closed',
  'in',
  'interrupted',
  'sutures',
  'developed',
  'hernia',
  '11',
  'months',
  'after',
  'his',
  'operation',
  '.']}

In [20]:
bf[example]['profession_first_index'], complete[(complete['example'] == example) & (complete['from_word_id'] == bf[example]['profession_first_index'])]['from_ix'].min()

(1, 1)

In [21]:
bf[example]['g_first_index'], complete[(complete['example'] == example) & (complete['to_word_id'] == bf[example]['g_first_index'])]['to_ix'].min()

(15, 20)

In [22]:
import textwrap

In [23]:
def show_diff(value="raw_vz", target=0, example=0, source=10):
    c = complete[(complete['example'] == example) & (complete['from_ix'] == source) & (complete['to_ix'] <= complete['from_ix'])]
    mr = c[c['side'] == "more"]
    lr = c[c['side'] == "less"]
    diff = (lr.set_index(['layer', 'head', 'to_ix'])[['raw_vz']] - mr.set_index(['layer', 'head', 'to_ix'])[['raw_vz']]).reset_index().rename(columns={'raw_vz': "diff"})
    fig, axes = matplotlib.pyplot.subplots(1, 3, figsize=(16, 6), sharey=True)
    fig.subplots_adjust(top=0.8)
    #fig.suptitle(f'average {distance} distance per head for the two datasets, and the difference per head')
    #fig.suptitle(" ".join([t if i != source else "..." for i, t in enumerate(complete[(complete['example'] == example) & (complete['side'] == "more")]['from_token'].unique())]))
    titletext = (" ".join(list(itertools.takewhile(lambda v: v, [m if m == l else None for m, l in zip(complete[(complete['example'] == example) & (complete['side'] == "more") & (complete['layer'] == 1) & (complete['head'] == 1) & (complete['to_ix'] == 0)]['from_token'], complete[(complete['example'] == example) & (complete['side'] == "less") & (complete['layer'] == 1) & (complete['head'] == 1) & (complete['to_ix'] == 0)]['from_token'])])) + ["[...]"] + list(reversed(list(itertools.takewhile(lambda v: v, [m if m == l else None for m, l in zip(complete[(complete['example'] == example) & (complete['side'] == "more") & (complete['layer'] == 1) & (complete['head'] == 1) & (complete['to_ix'] == 0)]['from_token'][::-1], complete[(complete['example'] == example) & (complete['side'] == "less") & (complete['layer'] == 1) & (complete['head'] == 1) & (complete['to_ix'] == 0)]['from_token'][::-1])]))))))
    titletext = "\n".join(textwrap.wrap(titletext, width=100))
    fig.suptitle(titletext)
    seaborn.heatmap(
            ax=axes[0],
            data=c[(c['side'] == 'more') & (c['to_ix'] == target)].pivot(index='layer', columns='head', values=value),
            vmin=0,
            vmax=1,
            cmap=seaborn.light_palette("seagreen", as_cmap=True)
        )
    axes[0].set_title(f"{mr[(mr['layer'] == 1) & (mr['head'] == 1) & (mr['from_ix'] == source)].reset_index().at[0, 'from_token']} (more stereotypical)")

    seaborn.heatmap(
            ax=axes[1],
            data=diff[diff['to_ix'] == target].pivot(index='layer', columns='head', values='diff'),
            vmin=-1,
            vmax=1,
            #annot=True,
            cmap=seaborn.color_palette("coolwarm", as_cmap=True)
        )
    axes[1].set_title("difference")
    axes[1].set_title(mr[(mr['layer'] == 1) & (mr['head'] == 1) & (mr['to_ix'] == target)].reset_index().at[0, 'to_token'], loc='left')
    axes[1].set_title(lr[(lr['layer'] == 1) & (lr['head'] == 1) & (lr['to_ix'] == target)].reset_index().at[0, 'to_token'], loc='right')

    seaborn.heatmap(
            ax=axes[2],
            data=c[(c['side'] == 'less') & (c['to_ix'] == target)].pivot(index='layer', columns='head', values=value),
            vmin=0,
            vmax=1,
            cmap=seaborn.light_palette("seagreen", as_cmap=True)
        )
    axes[2].set_title(f"{lr[(lr['layer'] == 1) & (lr['head'] == 1) & (lr['from_ix'] == source)].reset_index().at[0, 'from_token']} (less stereotypical)")

    for ax in axes:
        ax.invert_yaxis()

#    matplotlib.pyplot.show()
    return fig

In [24]:
example_widget=widgets.IntSlider(
        value=example,
        min=complete['example'].min(),
        max=complete['example'].max(),
        description="Which sentence to display",
    )
source_widget=widgets.IntSlider(
        value=complete[(complete['example'] == example) & (complete['to_word_id'] == bf[example]['g_first_index'])]['to_ix'].min(),
        min=0,
        max=complete[complete['example'] == example]['from_ix'].max(),
        step=1,
        description='source token:',
    )
target_widget=widgets.IntSlider(
        value=complete[(complete['example'] == example) & (complete['from_word_id'] == bf[example]['profession_first_index'])]['from_ix'].min(),
        min=0,
        max=complete[complete['example'] == example]['to_ix'].max(),
        step=1,
        description='target token:',
    )
def update_for_example(*args):
    source_widget.value = complete[(complete['example'] == example_widget.value) & (complete['to_word_id'] == bf[example_widget.value]['g_first_index'])]['to_ix'].min()
    source_widget.max = complete[complete['example'] == example_widget.value]['from_ix'].max()
    target_widget.value = complete[(complete['example'] == example_widget.value) & (complete['from_word_id'] == bf[example_widget.value]['profession_first_index'])]['from_ix'].min()
    target_widget.max = source_widget.value

example_widget.observe(update_for_example, 'value')

def update_target_range(*args):
    target_widget.max = source_widget.value
source_widget.observe(update_target_range, 'value')

value_widget=widgets.RadioButtons(
        options=['raw_attention', 'raw_vz'],
        value='raw_vz',
        layout={'width': 'max-content'}, # If the items' names are long
        description='what to display',
    )
w = widgets.interactive(show_diff,
        example=example_widget,
        source=source_widget,
        target=target_widget,
        value=value_widget,
    )
display(w)

interactive(children=(RadioButtons(description='what to display', index=1, layout=Layout(width='max-content'),…