In [None]:
import json
import pathlib

# Read the probing samples (a JSON document) from the working directory.
file = pathlib.Path('probing-discourse.json')
samples = json.loads(file.read_text())

In [None]:
# Display the first five loaded samples as a quick look at their structure.
samples[:5]

In [None]:
import transformers

import torch

# Prefer the second GPU (`cuda:1`, the device this notebook was written for),
# but fall back to the first GPU or to the CPU so the notebook still runs on
# machines that expose fewer than two CUDA devices.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Tokenizer and masked-LM model from the same pretrained checkpoint; the model
# is moved to `device`, and later cells move tokenized inputs there as well.
tokenizer = transformers.AutoTokenizer.from_pretrained('bert-base-uncased')
bert = transformers.AutoModelForMaskedLM.from_pretrained('bert-base-uncased').to(device)

In [None]:
import torch
# NOTE(review): `torch.load` unpickles whatever the file contains, so these
# checkpoints should only come from a trusted source.
# The probe is mapped onto `device` so it can consume BERT's outputs directly.
probe = torch.load('probe_wikidata_occupation.pth', map_location=device)
indexer = torch.load('probe_wikidata_occupation_indexer.pth')

# Invert the indexer (label -> index) so predicted indices can be turned back
# into labels.
unindexer = dict(zip(indexer.values(), indexer.keys()))

# Sanity check

In [None]:
import torch
from torch.utils import data
from tqdm.auto import tqdm

# Run every sample through BERT in batches and record the probe's top-5
# predicted labels for each one.
loader = data.DataLoader(samples, batch_size=64)

predictions = []
for batch in tqdm(loader):
    # A single inference_mode context is sufficient; nesting a second one
    # inside it has no additional effect.
    with torch.inference_mode():
        inputs = tokenizer(batch['text'], return_tensors='pt', padding='longest').to(device)
        outputs = bert(**inputs, return_dict=True, output_hidden_states=True)
        # Final-layer hidden state at token position 6 is what the probe reads.
        # NOTE(review): presumably position 6 is where the entity of interest
        # sits in every sample -- confirm against how the dataset was built.
        reps = outputs.hidden_states[-1][:, 6]
        topks = probe(reps).topk(k=5, dim=-1).indices.tolist()
        # Map each predicted index back to its label.
        predictions.extend(
            [unindexer[idx] for idx in topk]
            for topk in topks
        )

In [None]:
# Display the top-5 probe labels for the first ten samples.
predictions[:10]

In [None]:
import collections

# Collect, per experimental condition, whether the gold occupation appears in
# the probe's top-5 predictions; then average each group into an accuracy.
hits_by_condition = collections.defaultdict(list)
for sample, topk in zip(samples, predictions):
    cond = sample['condition']
    # Samples that pair a non-real name with a real occupation are excluded.
    if cond['name'] == 'real' or cond['occupation'] != 'real':
        hit = sample['labels']['occupation'] in topk
        hits_by_condition[frozenset(cond.items())].append(hit)

accuracies = {}
for cond_key, hits in hits_by_condition.items():
    accuracies[cond_key] = sum(hits) / len(hits)

In [None]:
import pathlib

import tabulate

# Build a text table of per-condition probe accuracy, most accurate first.
table = [('context', 'name', 'occupation', 'accuracy')]
ranked = sorted(accuracies.items(), key=lambda item: item[1], reverse=True)
for condition, accuracy in ranked:
    # Each condition is a frozenset of (field, value) pairs; rebuild the dict.
    fields = dict(condition)
    table.append([fields['context'], fields['name'], fields['occupation'], f'{accuracy:.3f}'])


# Persist the rendered table next to the notebook.
table_file = pathlib.Path('accuracy_table.txt')
table_file.write_text(tabulate.tabulate(table))

# Connect with causal predictions

In [None]:
import torch
from torch.utils import data
from tqdm.auto import tqdm

# For every sample, collect (a) the probe's top-5 labels read from BERT's
# hidden state and (b) BERT's own top-1 fill-in for a cloze prompt appended to
# the sample text.
loader = data.DataLoader(samples, batch_size=64)

probe_predictions, model_predictions = [], []
for batch in tqdm(loader):
    # One inference_mode context covers both forward passes; nesting further
    # contexts inside it has no additional effect.
    with torch.inference_mode():
        # Probe prediction: top-5 labels from the final-layer hidden state.
        inputs = tokenizer(batch['text'], return_tensors='pt', padding='longest').to(device)
        outputs = bert(**inputs, return_dict=True, output_hidden_states=True)
        # NOTE(review): token position 6 is presumably where the entity of
        # interest sits in every sample -- confirm against the dataset.
        reps = outputs.hidden_states[-1][:, 6]
        probe_topks = probe(reps).topk(k=5, dim=-1).indices.tolist()
        probe_predictions.extend(
            [unindexer[idx] for idx in topk]
            for topk in probe_topks
        )

        # Model prediction: append a cloze sentence and read the logits at
        # the [MASK] position.
        texts = [
            f'{text} Therefore, {name} works as a [MASK].'
            for text, name in zip(batch['text'], batch['labels']['name'])
        ]
        inputs = tokenizer(texts, return_tensors='pt', padding='longest').to(device)
        # Only the logits are read below, so hidden states are not requested.
        outputs = bert(**inputs, return_dict=True)

        # The prompt ends with "[MASK]" followed by ".", so the mask is taken
        # to be the third-from-last non-padding token.
        # NOTE(review): this assumes the tokenizer appends exactly one trailing
        # special token and pads on the right -- confirm for this tokenizer.
        token_idx = inputs.attention_mask.sum(dim=-1) - 3
        # Keep both index tensors on the same device as each other.
        batch_idx = torch.arange(len(texts), device=token_idx.device)
        model_predictions_ids = outputs.logits[batch_idx, token_idx].argmax(dim=-1)
        model_predictions_tokens = tokenizer.batch_decode(model_predictions_ids.tolist())
        model_predictions.extend(model_predictions_tokens)

In [None]:
import collections

# Per-condition outcome lists for three measurements:
#   probe_hits     -- gold occupation is among the probe's top-5 labels
#   model_hits     -- gold occupation equals the model's predicted token
#   agreement_hits -- the model's predicted token is among the probe's top-5
probe_hits = collections.defaultdict(list)
model_hits = collections.defaultdict(list)
agreement_hits = collections.defaultdict(list)
for sample, model_pred, probe_topk in zip(samples, model_predictions, probe_predictions):
    cond = sample['condition']
    # Samples that pair a non-real name with a real occupation are excluded.
    if cond['name'] == 'real' or cond['occupation'] != 'real':
        cond_key = frozenset(cond.items())
        gold = sample['labels']['occupation']
        probe_hits[cond_key].append(gold in probe_topk)
        model_hits[cond_key].append(gold == model_pred)
        agreement_hits[cond_key].append(model_pred in probe_topk)

# Average each outcome list into a rate per condition.
probe_accuracies = {key: sum(hits) / len(hits) for key, hits in probe_hits.items()}
model_accuracies = {key: sum(hits) / len(hits) for key, hits in model_hits.items()}
probe_model_agreements = {key: sum(hits) / len(hits) for key, hits in agreement_hits.items()}

# Tabulate probe accuracy, model accuracy and probe/model agreement for each
# condition, then save and display the table.
table = [('context', 'name', 'occupation', 'probe accuracy', 'model accuracy', 'agreement')]


def _condition_order(condition):
    """Return a sortable (context, name, occupation) key for a condition.

    Conditions are frozensets of (field, value) pairs. Frozensets compare by
    subset relation, which is only a partial order, so sorting them directly
    does not produce a meaningful row order. Values are rendered as strings
    so that mixed value types remain comparable.
    """
    fields = dict(condition)
    return tuple(str(fields[field]) for field in ('context', 'name', 'occupation'))


for condition in sorted(probe_accuracies, key=_condition_order):
    probe_accuracy = probe_accuracies[condition]
    model_accuracy = model_accuracies[condition]
    agreement = probe_model_agreements[condition]

    fields = dict(condition)
    table.append([
        fields['context'],
        fields['name'],
        fields['occupation'],
        f'{probe_accuracy:.3f}',
        f'{model_accuracy:.3f}',
        f'{agreement:.3f}',
    ])


# Render once and reuse the text for both the file and the cell output.
# Note: this writes to the same file name as the sanity-check table above and
# therefore overwrites it.
rendered_table = tabulate.tabulate(table)
table_file = pathlib.Path('accuracy_table.txt')
with table_file.open('w') as handle:
    handle.write(rendered_table)
print(rendered_table)