In [1]:
#@title foo
#!pip install transformers==4.1.1 plotnine

## Setting stuff up

In [1]:
import re
import functools
import itertools
import copy

import numpy
import pandas

from IPython.display import HTML
import seaborn
import matplotlib
import ipywidgets as widgets

from ahviz import create_indices, create_dataframe, filter_mask
import torch
import datasets
from transformers import AutoModel, AutoTokenizer, AutoConfig
from valuezeroing import calculate_scores_for_batch

In [47]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
# Silence the progress bars the `datasets` library would otherwise print,
# e.g. for the `Dataset.map` call that build_data makes on every button press.
from datasets.utils.logging import disable_progress_bar
disable_progress_bar()

In [3]:
# uncomment to force CPU if you have a GPU but not enough memory to do what you want. it will be slow of course
#device = torch.device("cpu")

In [4]:
# Checkpoint to analyse. Alternatives that have been used with this notebook:
#   "distilbert-base-cased", "bert-base-cased", "gpt2-medium", "gpt2-large",
#   "twmkn9/bert-base-uncased-squad2"
# `family` is handed to calculate_scores_for_batch and also decides below whether
# a padding token has to be invented, so keep it in step with `transformer`.
family = "gpt"
transformer = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(transformer)
if family == "gpt":
    # gpt2 has no padding token of its own, so reuse EOS (the fix suggested by
    # the error raised when padding without one). Padded positions are ignored
    # through the attention mask, so which id is used does not matter.
    tokenizer.pad_token = tokenizer.eos_token

# output_attentions=True so the model also returns its attention weights
config = AutoConfig.from_pretrained(transformer, output_attentions=True)
model = AutoModel.from_pretrained(transformer, config=config).to(device)
model.eval()

True

True

## Data preparation

Read in the prepared data. The repository includes a copy of the Penn Treebank sample that ships with the `nltk` Python package, converted into plain text and split into sentences, but you can replace it with any text file. Since the first thing we do is join all the text, it isn't even necessary to split it into sentences.
(Note: the interactive explorer at the end of this notebook does not read this file; `build_data` builds its input from the two sentences typed into the text widgets.)

The script I used to create the file is `convert_corpus.py` in the repository.

In [5]:
def encode(examples, prefix="sent_", left="more", right="less", tok=None):
    """Tokenize a batch of sentence pairs into interleaved left/right rows.

    Intended for ``datasets.Dataset.map(..., batched=True)``: ``examples`` maps
    column names to equal-length lists. The left sentence of pair ``k`` (column
    ``f"{prefix}{left}"``) becomes output row ``2*k`` and the right sentence
    (column ``f"{prefix}{right}"``) becomes row ``2*k + 1``. All rows are padded
    to the same length.

    Parameters
    ----------
    examples : mapping of column name -> list of str
    prefix : str or None
        Column-name prefix; ``None`` means no prefix.
    left, right : str
        Column suffixes; also written verbatim into the ``side`` column.
    tok : callable, optional
        Tokenizer to use. Defaults to the module-level ``tokenizer``. It must
        accept ``padding=True``, and its output must support item assignment
        and ``.word_ids(n)``. It must also provide ``convert_ids_to_tokens``.

    Returns
    -------
    The tokenizer output (``input_ids``, ``attention_mask``, ...) extended with:

    - ``word_ids``: per-row word ids.
    - ``tokens``: per-row token strings.
    - ``token_ix``: positions ``0..len-1`` per row.
    - ``side``: ``left``/``right`` alternating.
    """
    if tok is None:
        tok = tokenizer
    if prefix is None:
        prefix = ""
    left_col = f"{prefix}{left}"
    right_col = f"{prefix}{right}"

    pair_count = len(examples[left_col])
    row_count = 2 * pair_count
    # [left_0, right_0, left_1, right_1, ...]
    alternating = list(itertools.chain.from_iterable(zip(examples[left_col], examples[right_col])))
    result = tok(alternating, padding=True)
    result['word_ids'] = [result.word_ids(n) for n in range(row_count)]
    result['tokens'] = [tok.convert_ids_to_tokens(ids) for ids in result['input_ids']]
    result['token_ix'] = [list(range(len(ids))) for ids in result['input_ids']]
    result['side'] = [left, right] * pair_count
    return result
encoder = functools.partial(encode, prefix="sent_", left="more", right="less")


In [49]:
def run_vz(ds, batch_size=4):
    """Run value-zeroing over an encoded dataset and collect per-token-pair scores.

    Rows of ``ds`` are expected to alternate "more"/"less" as produced by
    ``encode``. Dataset row ``r`` is therefore assigned to example ``r // 2``;
    its side is "more" for even ``r`` and "less" for odd ``r``.

    Uses the module-level ``config``, ``model``, ``family`` and ``device``.

    Parameters
    ----------
    ds : datasets.Dataset
        Encoded dataset with ``input_ids`` and ``attention_mask`` columns.
    batch_size : int
        Rows per call to ``calculate_scores_for_batch`` (default 4, the value
        that used to be hard-coded).

    Returns
    -------
    pandas.DataFrame or None
        Long-format frame with columns ``layer``, ``head`` (both numbered from
        1), ``from_ix``, ``to_ix``, ``raw_vz``, ``raw_attention``, ``example``
        and ``side``. Only attended (non-padding) positions are included.
        Returns ``None`` if ``ds`` is empty.

    Raises
    ------
    ValueError
        If a row's attention mask contains no attended tokens.
    """
    frames = []
    # Create the tensors on the model's device; a hard-coded "cuda" here broke
    # the notebook on CPU-only machines.
    with ds.formatted_as(type='torch', columns=['input_ids', 'attention_mask'], device=device):
        dataloader = torch.utils.data.DataLoader(ds, batch_size=batch_size)
        for batch_no, batch_data in enumerate(dataloader):
            attention_mask = batch_data['attention_mask'].cpu().numpy()

            valuezeroing_scores, _rollout_scores, attentions = calculate_scores_for_batch(
                    config,
                    model,
                    family,
                    batch_data['input_ids'],
                    batch_data['attention_mask']
                )
            # one device->host copy per batch instead of one per row
            attentions_np = attentions.detach().cpu().numpy()

            for n in range(attention_mask.shape[0]):
                # Position of this row in the whole dataset. Using it (rather than
                # the in-batch index) keeps example/side right for any batch size.
                row = batch_no * batch_size + n
                tok_range = numpy.flatnonzero(attention_mask[n])
                if tok_range.size == 0:
                    raise ValueError(f"dataset row {row} has no attended tokens")
                # assumes the attended tokens form one contiguous run (as with right padding)
                first, last = tok_range.min(), tok_range.max() + 1
                scores_matrix = valuezeroing_scores[:, n, :, first:last, first:last]
                att_matrix = attentions_np[:, n, :, first:last, first:last]
                # take layer/head counts from the scores instead of assuming gpt2's 12x12
                n_layers, n_heads = scores_matrix.shape[0], scores_matrix.shape[1]
                index = pandas.MultiIndex.from_product(
                        [numpy.arange(n_layers) + 1, numpy.arange(n_heads) + 1, tok_range, tok_range],
                        names=['layer', 'head', 'from_ix', 'to_ix']
                    )
                score_df = pandas.DataFrame(
                        numpy.hstack([
                                scores_matrix.reshape(-1, 1),
                                att_matrix.reshape(-1, 1)
                            ]),
                        index=index,
                        columns=["raw_vz", "raw_attention"]
                    ).reset_index()
                score_df['example'] = row // 2
                score_df['side'] = "less" if row % 2 else "more"
                frames.append(score_df)
    # concatenate once instead of growing a frame inside the loop (quadratic)
    return pandas.concat(frames) if frames else None

In [34]:
def merge(encoded_df, result_df):
    """Attach token text and word ids to both ends of every score row.

    ``encoded_df`` is the exploded, one-token-per-row frame built in
    ``build_data``. Rows with any missing value are dropped before joining
    (presumably padding/special tokens, whose word id is None). ``result_df``
    is the output of ``run_vz``.

    The token columns are joined twice with left joins, so every score row is
    kept:

    - on the source position, adding ``from_token`` and ``from_word_id``;
    - on the target position, adding ``to_token`` and ``to_word_id``.
    """
    token_table = encoded_df.dropna()[['side', 'example', 'token_ix', 'tokens', 'word_ids']]

    source_lookup = token_table.rename(columns={
            'token_ix': "from_ix",
            'tokens': "from_token",
            'word_ids': "from_word_id",
        })
    target_lookup = token_table.rename(columns={
            'token_ix': "to_ix",
            'tokens': "to_token",
            'word_ids': "to_word_id",
        })

    with_source = result_df.merge(source_lookup, how="left", on=["example", "side", "from_ix"])
    return with_source.merge(target_lookup, how="left", on=["example", "side", "to_ix"])


In [35]:
# Index of the sentence pair to visualise. build_data only ever creates one pair,
# so this stays 0; the widget callbacks below read it as a global.
example = 0

In [59]:
def build_data(more="foo", less="foo"):
    """Encode one sentence pair, run value-zeroing on it and attach token info.

    Parameters
    ----------
    more, less : str
        The "more" and "less" sentences of the pair (example 0).

    Returns
    -------
    pandas.DataFrame
        The ``run_vz`` scores, with token text and word ids for both the source
        (``from_*``) and target (``to_*``) positions added by ``merge``.

    Raises
    ------
    ValueError
        If either sentence is empty or whitespace-only. An empty sentence can
        encode to a row with no attended tokens, which ``run_vz`` cannot score.
    """
    # fail early with a clear message instead of deep inside run_vz
    if not more.strip() or not less.strip():
        raise ValueError("both the 'more' and the 'less' sentence must be non-empty")

    manual_data = dict(
            sent_more = [more],
            sent_less = [less],
            # placeholder metadata; dropped again by the map below
            stereo_antistereo = ['stereo'],
            bias_type = ['gender']
        )
    manual_dataset = datasets.Dataset.from_dict(manual_data)
    # keep only the token-level columns produced by the encoder
    encoded_dataset = manual_dataset.map(encoder, batched=True, remove_columns=manual_dataset.column_names)

    # one row per token; the pre-explode row number becomes the example id
    encoded_df = (encoded_dataset.to_pandas()
                  .explode(['input_ids', 'attention_mask', 'word_ids', 'token_ix', 'tokens'])
                  .reset_index(names="example"))
    # the encoder interleaves more/less, so dataset row r belongs to pair r // 2
    encoded_df['example'] //= 2

    result_df = run_vz(encoded_dataset)
    return merge(encoded_df, result_df)

In [85]:
def show_diff(axes, val_type="raw_vz", complete=None, target=0, source=12, example=0):
    """Draw three layer-by-head heatmaps for one source/target token pair.

    - Left: ``val_type`` for the "more" sentence.
    - Middle: "less" minus "more".
    - Right: ``val_type`` for the "less" sentence.

    Only rows with ``from_ix == source`` and ``to_ix <= from_ix`` are used, so
    ``target`` must not exceed ``source``.

    Parameters
    ----------
    axes : sequence of three matplotlib Axes
    val_type : str
        Score column to plot, e.g. ``"raw_vz"`` or ``"raw_attention"``. Used for
        all three panels.
    complete : pandas.DataFrame
        Output of ``build_data``.
    target, source : int
        Positions of the attended-to and attending token.
    example : int
        Which sentence pair to plot.
    """
    pair = complete[(complete['example'] == example)
                    & (complete['from_ix'] == source)
                    & (complete['to_ix'] <= complete['from_ix'])]
    more_rows = pair[pair['side'] == "more"]
    less_rows = pair[pair['side'] == "less"]

    def layer_head_grid(frame, column):
        # one cell per (layer, head) for the chosen target token
        return frame[frame['to_ix'] == target].pivot(index='layer', columns='head', values=column)

    def token_at(frame, ix_col, ix, token_col):
        # the token text is repeated for every (layer, head); read it from layer 1 / head 1
        rows = frame[(frame['layer'] == 1) & (frame['head'] == 1) & (frame[ix_col] == ix)]
        return rows[token_col].iloc[0]

    # Difference of the *selected* value; this used to be hard-wired to raw_vz,
    # so with "raw_attention" the middle panel did not match the outer panels.
    keys = ['layer', 'head', 'to_ix']
    diff = ((less_rows.set_index(keys)[[val_type]] - more_rows.set_index(keys)[[val_type]])
            .reset_index()
            .rename(columns={val_type: "diff"}))

    score_cmap = seaborn.light_palette("seagreen", as_cmap=True)

    seaborn.heatmap(
            ax=axes[0],
            data=layer_head_grid(more_rows, val_type),
            vmin=0,
            vmax=1,
            cmap=score_cmap
        )
    axes[0].set_title(f"{token_at(more_rows, 'from_ix', source, 'from_token')} (more stereotypical)")

    seaborn.heatmap(
            ax=axes[1],
            data=layer_head_grid(diff, 'diff'),
            vmin=-1,
            vmax=1,
            cmap=seaborn.color_palette("coolwarm", as_cmap=True)
        )
    axes[1].set_title("difference")
    axes[1].set_title(token_at(more_rows, 'to_ix', target, 'to_token'), loc='left')
    axes[1].set_title(token_at(less_rows, 'to_ix', target, 'to_token'), loc='right')

    seaborn.heatmap(
            ax=axes[2],
            data=layer_head_grid(less_rows, val_type),
            vmin=0,
            vmax=1,
            cmap=score_cmap
        )
    axes[2].set_title(f"{token_at(less_rows, 'from_ix', source, 'from_token')} (less stereotypical)")

    # layer 1 at the bottom
    for ax in axes:
        ax.invert_yaxis()

In [108]:
%matplotlib agg
more_widget=widgets.Text(
        layout={'width': '100%'},
        description="'left' sentence",
        placeholder="fill me",
        disabled=False,
    )
less_widget=widgets.Text(
        layout={'width': '100%'},    
        description="'right' sentence",
        placeholder="fill me",
        disabled=False,
    )
output = widgets.Output(
        layout={'height': '900px'}
    )
run_button = widgets.Button(
        description='Process Sentences',
        disabled=False,
        button_style='',
        tooltip='Click me',
    )

@output.capture(clear_output=True, wait=True)
def on_button_clicked(b):
    complete = build_data(more=more_widget.value, less=less_widget.value)
    source_widget=widgets.IntSlider(
        value=12,
        min=0,
        max=17,
        step=1,
        description='source token:',
        continuous_update=False,
    )
    target_widget=widgets.IntSlider(
        value=0,
        min=0,
        max=10,
        step=1,
        description='target token:',
        continuous_update=False,
    )
    value_widget=widgets.RadioButtons(
        options=['raw_attention', 'raw_vz'],
        value='raw_vz',
        layout={'width': 'max-content'}, # If the items' names are long
        description='what to display',
    )
    nested = widgets.Output()
    fig, axes = matplotlib.pyplot.subplots(1, 3, figsize=(16, 6), sharey=True)    
    fig.suptitle(" ".join(list(itertools.takewhile(lambda v: v, [m if m == l else None for m, l in zip(complete[(complete['example'] == example) & (complete['side'] == "more") & (complete['layer'] == 1) & (complete['head'] == 1) & (complete['to_ix'] == 0)]['from_token'], complete[(complete['example'] == example) & (complete['side'] == "less") & (complete['layer'] == 1) & (complete['head'] == 1) & (complete['to_ix'] == 0)]['from_token'])])) + ["[...]"] + list(reversed(list(itertools.takewhile(lambda v: v, [m if m == l else None for m, l in zip(complete[(complete['example'] == example) & (complete['side'] == "more") & (complete['layer'] == 1) & (complete['head'] == 1) & (complete['to_ix'] == 0)]['from_token'][::-1], complete[(complete['example'] == example) & (complete['side'] == "less") & (complete['layer'] == 1) & (complete['head'] == 1) & (complete['to_ix'] == 0)]['from_token'][::-1])]))))))

    @nested.capture(clear_output=True, wait=True)
    def on_input_update(*args):
        matplotlib.pyplot.close()
        fig, axes = matplotlib.pyplot.subplots(1, 3, figsize=(16, 6), sharey=True)    
        show_diff(axes, complete=complete,target=target_widget.value,source=source_widget.value,val_type=value_widget.value)
        fig.suptitle(" ".join(list(itertools.takewhile(lambda v: v, [m if m == l else None for m, l in zip(complete[(complete['example'] == example) & (complete['side'] == "more") & (complete['layer'] == 1) & (complete['head'] == 1) & (complete['to_ix'] == 0)]['from_token'], complete[(complete['example'] == example) & (complete['side'] == "less") & (complete['layer'] == 1) & (complete['head'] == 1) & (complete['to_ix'] == 0)]['from_token'])])) + ["[...]"] + list(reversed(list(itertools.takewhile(lambda v: v, [m if m == l else None for m, l in zip(complete[(complete['example'] == example) & (complete['side'] == "more") & (complete['layer'] == 1) & (complete['head'] == 1) & (complete['to_ix'] == 0)]['from_token'][::-1], complete[(complete['example'] == example) & (complete['side'] == "less") & (complete['layer'] == 1) & (complete['head'] == 1) & (complete['to_ix'] == 0)]['from_token'][::-1])]))))))
        with nested:
            display(fig)
        
    @nested.capture(clear_output=True, wait=True)
    def on_source_update(*args):
        target_widget.max = source_widget.value
    
    source_widget.observe(on_source_update, 'value')
    source_widget.observe(on_input_update, 'value')
    target_widget.observe(on_input_update, 'value')
    value_widget.observe(on_input_update, 'value')

    show_diff(axes, complete=complete,target=target_widget.value,source=source_widget.value,val_type=value_widget.value)
    
    display(source_widget,target_widget,value_widget, nested)
    with nested:
        display(fig)

run_button.on_click(on_button_clicked)


display(more_widget, less_widget, run_button, output)

Text(value='', description="'left' sentence", layout=Layout(width='100%'), placeholder='fill me')

Text(value='', description="'right' sentence", layout=Layout(width='100%'), placeholder='fill me')

Button(description='Process Sentences', style=ButtonStyle(), tooltip='Click me')

Output(layout=Layout(height='900px'))