Visualize probe accuracy by layer for each probe target (top-3 results only).

In [None]:
import json
import pathlib

probe_accuracies_file = pathlib.Path('../results/probe_occupations/gpt-j-6B/accuracies.json')
# A list of records; the next cell reads their 'layer', 'target', 'accuracy'
# and (optional) 'top-k' fields.
accuracies = json.loads(probe_accuracies_file.read_text())
accuracies

In [None]:
from collections import defaultdict

# Only probe results evaluated at this top-k are plotted.
PROBE_TOP_K = 3

# Per probe target: the layers and the matching accuracies, both ordered by layer.
layers_by_target = defaultdict(list)
accuracies_by_target = defaultdict(list)

# The layer ordering does not depend on the target, so sort once up front.
entries_by_layer = sorted(accuracies, key=lambda entry: entry['layer'])
for target in ('occupation', 'predictions'):
    for entry in entries_by_layer:
        # A missing 'top-k' field yields None, which is skipped like any other k.
        if entry.get('top-k') != PROBE_TOP_K or entry['target'] != target:
            continue
        layers_by_target[target].append(entry['layer'])
        accuracies_by_target[target].append(entry['accuracy'])

print(layers_by_target['occupation'])
print(accuracies_by_target['occupation'])

In [None]:
import matplotlib.pyplot as plt
import numpy

# One bar chart per probe target: probe accuracy (y) at each layer (x).
for target in ('occupation', 'predictions'):
    fig, ax = plt.subplots()
    ax.bar(layers_by_target[target], accuracies_by_target[target])
    ax.set_title(f'probe {target}')
    ax.set_xlabel('layer')
    ax.set_ylabel('probe accuracy')
    # Fixed 0.0 .. 1.0 ticks so the per-target charts are comparable.
    ax.set_yticks(tuple(numpy.arange(0, 1.1, .1)))

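Visualize model accuracy on the discourse data. For each condition, this is how often the true occupation appears among the model's predictions.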

In [None]:
import json
import pathlib

results_file = pathlib.Path('../data/gpt-j-6B/occupations-discourse-predicted.json')
# A list of records; the next cell reads their 'condition' (a dict),
# 'occupation' and 'predictions' fields.
results = json.loads(results_file.read_text())

In [None]:
from collections import defaultdict

# Tally hits and totals per condition. The condition dict is frozen into a
# hashable set of (key, value) pairs so it can serve as a dict key.
corrects, totals = defaultdict(int), defaultdict(int)
for result in results:
    key = frozenset(result['condition'].items())
    hit = result['occupation'] in result['predictions']
    totals[key] += 1
    # Always touch corrects[key], so conditions with zero hits still appear.
    corrects[key] += int(hit)
accuracies = {key: corrects[key] / totals[key] for key in corrects}
accuracies

In [None]:
from collections import defaultdict

def condition_set_to_text(condition):
    """Render a condition as a short bar-chart label.

    Args:
        condition: a mapping, or an iterable of (key, value) pairs such as the
            frozenset keys of ``accuracies``; must contain 'context' and
            'entity', and also 'occupation' when the entity is 'famous'.

    Returns:
        '<context> ctx', with ', <occupation> occ' appended for famous entities.
    """
    fields = dict(condition)
    parts = [f'{fields["context"]} ctx']
    if fields['entity'] == 'famous':
        parts.append(f'{fields["occupation"]} occ')
    return ', '.join(parts)

# Bar labels (xs) and accuracies (ys) keyed by (entity, occupation) group, in
# the order the conditions appear in `accuracies`.
xs_by_entity, ys_by_entity = defaultdict(list), defaultdict(list)
for entity in ('famous', 'generic'):
    for occ in ('correct', 'random'):
        group = (entity, occ)
        for condition_items, accuracy in accuracies.items():
            fields = dict(condition_items)
            if fields['entity'] == entity and fields['occupation'] == occ:
                xs_by_entity[group].append(condition_set_to_text(fields))
                ys_by_entity[group].append(accuracy)
xs_by_entity, ys_by_entity

In [None]:
import matplotlib.pyplot as plt
import numpy

def plot(entity, occupation, title):
    """Draw a bar chart of model accuracy per condition for one group.

    Args:
        entity: entity value of the group, e.g. 'famous' or 'generic'.
        occupation: occupation value of the group, e.g. 'correct' or 'random'.
        title: figure title.

    Bar labels and heights come from the notebook-level ``xs_by_entity`` and
    ``ys_by_entity`` dicts, keyed by ``(entity, occupation)``.
    """
    key = (entity, occupation)
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(xs_by_entity[key], ys_by_entity[key])
    ax.set_title(title)
    ax.set_ylabel('model accuracy')
    ax.set_yticks(tuple(numpy.arange(0, 1.1, .1)))

for entity, occupation, title in (
    ('famous', 'correct', 'famous entity, correct occupation'),
    ('famous', 'random', 'famous entity, random occupation'),
    ('generic', 'random', 'generic entity'),
):
    plot(entity, occupation, title)

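# Agreement

For each condition and layer, measure how often the model's top prediction appears among the probe's top-3 predictions at that layer. Then report the layer with the highest agreement for each condition.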

In [None]:
import json
import pathlib

results_file = pathlib.Path('../results/probe_discourse/gpt-j-6B/occupations-discourse-probed.json')
# A list of records; the next cell reads their 'condition' (a dict),
# 'predictions' and 'probed' (keyed by layer number as a string) fields.
results = json.loads(results_file.read_text())

In [None]:
from collections import defaultdict
from tqdm.auto import tqdm

# A probe "agrees" with the model when the model's top-1 prediction is among
# the probe's first PROBE_TOP_K predictions at that layer.
PROBE_TOP_K = 3

# Layers present in the probe output, derived from the data instead of a
# hard-coded range(29), so the cell also works for models of other depths.
# NOTE(review): assumes 'probed' keys are integer strings ('0', '1', ...) and
# that every result has the same set of layers -- confirm against the prober.
layers = sorted({int(layer) for result in results for layer in result['probed']})

# The condition key and per-condition totals do not depend on the layer, so
# compute them once rather than once per layer.
condition_keys = [frozenset(result['condition'].items()) for result in results]
totals = defaultdict(int)
for condition in condition_keys:
    totals[condition] += 1

# accuracies_by_layer[layer][condition] -> agreement rate in [0, 1].
accuracies_by_layer = {}
for layer in tqdm(layers):
    corrects = defaultdict(int)
    for condition, result in zip(condition_keys, results):
        top_prediction = result['predictions'][0]
        corrects[condition] += int(top_prediction in result['probed'][str(layer)][:PROBE_TOP_K])
    accuracies_by_layer[layer] = {key: correct / totals[key] for key, correct in corrects.items()}

# Transpose to accuracies[condition][layer].
accuracies = defaultdict(dict)
for layer, accuracies_by_condition in accuracies_by_layer.items():
    for condition, accuracy in accuracies_by_condition.items():
        accuracies[condition][layer] = accuracy

# Best layer per condition (ties resolve to the lowest layer, since max keeps
# the first maximum and layers are visited in ascending order).
outputs = {}
for condition, by_layer in accuracies.items():
    layer, accuracy = max(by_layer.items(), key=lambda kv: kv[1])
    outputs[condition] = {'layer': layer, 'accuracy': accuracy}
outputs