In [1]:
# When False, the results cell below recomputes the interventions and rewrites its
# JSON file before loading it; when True, that cell only loads the existing JSON.
FROM_CACHE = False
# Model identifier handed to Model(...), GPT2Tokenizer.from_pretrained(...) and
# winobias.analyze(...) in the results cell.
GPT2_VERSION = 'gpt2-medium'

In [2]:
import winobias
from experiment import Model
from attention_utils import report_interventions_summary_by_head, report_interventions_summary_by_layer, report_intervention, perform_interventions
from transformers import GPT2Tokenizer
import json

To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html


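## Dev Set Results: GPT2-Medium (filtered by odds ratio)

Only examples whose odds ratio exceeds the 0.25 quantile of the odds ratios greater than 1 are kept.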

In [None]:
# JSON file that stores the intervention results for the filtered dev examples.
fname = 'winobias_data/attention_intervention_results_gpt2medium_filtered.json'

if not FROM_CACHE:
	# Recompute the results from scratch and overwrite the JSON file.
	quantile = 0.25
	model = Model(output_attentions=True, gpt2_version=GPT2_VERSION)
	tokenizer = GPT2Tokenizer.from_pretrained(GPT2_VERSION)
	examples = winobias.load_dev_examples()

	df = winobias.analyze(examples, gpt2_version=GPT2_VERSION)

	# The cut-off is the `quantile` quantile of the odds ratios that exceed 1.
	df_expected = df[df.odds_ratio > 1]
	threshold = df_expected.odds_ratio.quantile(quantile)

	# Row k of `df` is paired positionally with examples[k]; keep the examples
	# whose odds ratio is strictly above the cut-off.
	assert len(examples) == len(df)
	filtered_examples = [
		example
		for example, ratio in zip(examples, df.odds_ratio)
		if ratio > threshold
	]

	print(f'Num examples with odds ratio > 1: {len(df_expected)} / {len(examples)}')
	print(f'Num examples with odds ratio > {threshold:.4f} ({quantile} quantile): {len(filtered_examples)} / {len(examples)}')

	# From here on, work with the filtered subset only.
	examples = filtered_examples

	interventions = [ex.to_intervention(tokenizer) for ex in examples]
	results = perform_interventions(interventions, model)
	with open(fname, 'w') as cache_out:
		json.dump(results, cache_out)

# Always read the results back from disk, so downstream cells see the JSON-decoded
# form whether or not they were just recomputed.
with open(fname) as cache_in:
	results = json.load(cache_in)

Split: DEV, Filtered: False
Loaded 160 pairs. Skipped 38 pairs.


  0%|          | 0/160 [00:00<?, ?it/s]  1%|          | 1/160 [00:01<03:23,  1.28s/it]  1%|▏         | 2/160 [00:02<03:20,  1.27s/it]  2%|▏         | 3/160 [00:03<03:16,  1.25s/it]  2%|▎         | 4/160 [00:05<03:16,  1.26s/it]  3%|▎         | 5/160 [00:06<03:15,  1.26s/it]  4%|▍         | 6/160 [00:07<03:14,  1.26s/it]  4%|▍         | 7/160 [00:08<03:07,  1.22s/it]  5%|▌         | 8/160 [00:09<03:07,  1.23s/it]  6%|▌         | 9/160 [00:11<03:01,  1.20s/it]  6%|▋         | 10/160 [00:12<03:04,  1.23s/it]  7%|▋         | 11/160 [00:13<03:02,  1.22s/it]  8%|▊         | 12/160 [00:14<03:04,  1.25s/it]  8%|▊         | 13/160 [00:16<03:00,  1.23s/it]  9%|▉         | 14/160 [00:17<02:59,  1.23s/it]  9%|▉         | 15/160 [00:18<02:56,  1.22s/it] 10%|█         | 16/160 [00:19<02:55,  1.22s/it] 11%|█         | 17/160 [00:21<02:59,  1.26s/it] 11%|█▏        | 18/160 [00:22<02:56,  1.24s/it] 12%|█▏        | 19/160 [00:23<03:01,  1.28s/it] 12%|█▎        | 20/160 [00:24<02:56,

### Mean Effect

In [None]:
# Summarize the JSON-loaded `results` per head (presumably per attention head,
# aggregated over examples; confirm in attention_utils).
report_interventions_summary_by_head(results)

In [None]:
# Summarize the JSON-loaded `results` per layer (presumably aggregated over
# examples; confirm in attention_utils).
report_interventions_summary_by_layer(results)

### Examples
	

In [None]:
# Detailed report for the first entry of `results`.
report_intervention(results[0])

In [None]:
# Detailed report for the second entry of `results`.
report_intervention(results[1])

In [None]:
# Detailed report for the third entry of `results`.
report_intervention(results[2])

In [None]:
# Detailed report for the fourth entry of `results`.
report_intervention(results[3])
