In [1]:
import json
import pathlib

# Load the probing samples from disk. JSON text is UTF-8 by specification, so
# decode it explicitly instead of depending on the platform's default encoding.
file = pathlib.Path('probing-discourse.json')
with file.open('r', encoding='utf-8') as handle:
    samples = json.load(handle)

In [2]:
# Preview the first five samples to check the structure of the loaded data.
samples[:5]

[{'condition': {'name': 'real', 'occupation': 'real', 'context': 'primary'},
  'labels': {'name': 'juan juarez fernandez', 'occupation': 'politician'},
  'text': 'This is a story about juan juarez fernandez who works as a politician.'},
 {'condition': {'name': 'real', 'occupation': 'real', 'context': 'secondary'},
  'labels': {'name': 'juan juarez fernandez', 'occupation': 'politician'},
  'text': 'This is a story about juan juarez fernandez who forgot to bring a pseudonym to their job at the political party.'},
 {'condition': {'name': 'real', 'occupation': 'real', 'context': 'irrelevant'},
  'labels': {'name': 'juan juarez fernandez', 'occupation': 'politician'},
  'text': 'This is a story about juan juarez fernandez who climbed a hill.'},
 {'condition': {'name': 'real', 'occupation': 'fake', 'context': 'primary'},
  'labels': {'name': 'juan juarez fernandez',
   'occupation': 'table tennis player'},
  'text': 'This is a story about juan juarez fernandez who works as a table tennis pl

In [15]:
import transformers

import torch

# Prefer the second GPU (the original setting), but fall back so the notebook
# still runs on hosts that have only one GPU or none at all.
if torch.cuda.is_available():
    device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    device = 'cpu'

# Tokenizer and masked-LM model from the same checkpoint; the model is moved to
# the selected device once here so later cells can reuse it.
tokenizer = transformers.AutoTokenizer.from_pretrained('bert-base-uncased')
bert = transformers.AutoModelForMaskedLM.from_pretrained('bert-base-uncased').to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
import torch
# Restore the saved occupation probe onto the working device, along with the
# label -> class-index table stored next to it.
# NOTE(review): torch.load unpickles arbitrary objects -- only load checkpoints
# that come from a trusted source.
probe = torch.load('probe_wikidata_occupation.pth', map_location=device)
indexer = torch.load('probe_wikidata_occupation_indexer.pth')

# Reverse lookup (class index -> label) used to decode the probe's outputs.
unindexer = dict((index, label) for label, index in indexer.items())

# Sanity check

In [5]:
import torch
from torch.utils import data
from tqdm.auto import tqdm

# Token position whose final-layer hidden state is fed to the probe.
# NOTE(review): presumably the first word piece after "[CLS] this is a story
# about", i.e. the start of the name in the sample template -- confirm.
PROBE_TOKEN_INDEX = 6
# Number of highest-scoring probe classes kept per sample.
TOP_K = 5

loader = data.DataLoader(samples, batch_size=64)

# One list of TOP_K occupation labels per sample, in dataset order.
predictions = []
for batch in tqdm(loader):
    # A single inference_mode context is enough; the original nested a second,
    # redundant one around the forward pass.
    with torch.inference_mode():
        inputs = tokenizer(batch['text'], return_tensors='pt', padding='longest').to(device)
        outputs = bert(**inputs, return_dict=True, output_hidden_states=True)
        # Last entry of hidden_states, sliced at the probed token position.
        reps = outputs.hidden_states[-1][:, PROBE_TOKEN_INDEX]
        topks = probe(reps).topk(k=TOP_K, dim=-1).indices.tolist()
    # Map class indices back to their occupation labels.
    predictions.extend(
        [unindexer[idx] for idx in topk]
        for topk in topks
    )

# Later cells zip `samples` with `predictions`; zip would silently truncate on
# a length mismatch, so fail loudly here instead.
assert len(predictions) == len(samples), 'expected one top-k list per sample'

  0%|          | 0/4881 [00:00<?, ?it/s]

In [6]:
# Inspect the top-5 occupation labels the probe produced for the first ten samples.
predictions[:10]

[['politician',
  'sportsperson',
  'actor',
  'association football player',
  'artist'],
 ['sportsperson',
  'handball player',
  'badminton player',
  'television presenter',
  'artist'],
 ['samurai', 'sprinter', 'mangaka', 'sportsperson', 'musician'],
 ['sportsperson',
  'volleyball player',
  'handball player',
  'badminton player',
  'association football player'],
 ['sportsperson',
  'field hockey player',
  'handball player',
  'volleyball player',
  'table tennis player'],
 ['samurai', 'sprinter', 'mangaka', 'sportsperson', 'musician'],
 ['politician', 'businessperson', 'sportsperson', 'entrepreneur', 'mangaka'],
 ['sportsperson',
  'comics artist',
  'badminton player',
  'businessperson',
  'visual artist'],
 ['mangaka', 'samurai', 'musician', 'farmer', 'sprinter'],
 ['volleyball player',
  'badminton player',
  'sportsperson',
  'basketball player',
  'field hockey player']]

In [7]:
import collections

# Collect one hit/miss flag per sample, bucketed by experimental condition.
hits_by_condition = collections.defaultdict(list)
for example, top_labels in zip(samples, predictions):
    setting = example['condition']
    # Samples pairing a real occupation with a non-real name are left out.
    if setting['name'] != 'real' and setting['occupation'] == 'real':
        continue
    gold = example['labels']['occupation']
    # The full condition (as a hashable frozenset of its items) is the bucket key.
    hits_by_condition[frozenset(setting.items())].append(gold in top_labels)

# Top-k accuracy per condition: the share of samples whose gold occupation
# appears among the probe's predicted labels.
accuracies = {}
for bucket, hits in hits_by_condition.items():
    accuracies[bucket] = sum(hits) / len(hits)

In [8]:
import pathlib

import tabulate

# Conditions ordered from highest to lowest probe accuracy.
ranked = sorted(accuracies.items(), key=lambda item: item[-1], reverse=True)

# Header row first, then one formatted row per condition.
table = [('context', 'name', 'occupation', 'accuracy')]
for condition, score in ranked:
    fields = dict(condition)
    table.append([fields['context'], fields['name'], fields['occupation'], f'{score:.3f}'])

# Write the rendered table to accuracy_table.txt in the working directory.
table_file = pathlib.Path('accuracy_table.txt')
table_file.write_text(tabulate.tabulate(table))

# Connect with causal predictions

In [39]:
import torch
from torch.utils import data
from tqdm.auto import tqdm

# Token position whose final-layer hidden state is fed to the probe.
# NOTE(review): presumably the first word piece after "[CLS] this is a story
# about", i.e. the start of the name in the sample template -- confirm.
PROBE_TOKEN_INDEX = 6
# Number of highest-scoring probe classes kept per sample.
TOP_K = 5

loader = data.DataLoader(samples, batch_size=64)

# probe_predictions: TOP_K occupation labels per sample, read off the hidden state.
# model_predictions: the single token BERT itself fills into a cloze prompt.
probe_predictions, model_predictions = [], []
for batch in tqdm(loader):
    # One inference_mode context covers both forward passes; the original
    # opened additional, redundant contexts inside it.
    with torch.inference_mode():
        # --- Probe prediction from the original text ---
        inputs = tokenizer(batch['text'], return_tensors='pt', padding='longest').to(device)
        outputs = bert(**inputs, return_dict=True, output_hidden_states=True)
        reps = outputs.hidden_states[-1][:, PROBE_TOKEN_INDEX]
        probe_topks = probe(reps).topk(k=TOP_K, dim=-1).indices.tolist()
        probe_predictions.extend(
            [unindexer[idx] for idx in topk]
            for topk in probe_topks
        )

        # --- Masked-LM prediction from a cloze prompt appended to the text ---
        texts = [
            f'{text} Therefore, {name} works as a [MASK].'
            for text, name in zip(batch['text'], batch['labels']['name'])
        ]
        inputs = tokenizer(texts, return_tensors='pt', padding='longest').to(device)
        # Only the logits are used below, so hidden states are not requested.
        outputs = bert(**inputs, return_dict=True)

        # Locate [MASK] by its token id instead of counting back from the end
        # of the sequence, which silently breaks if the prompt template or the
        # padding side ever changes.
        is_mask = inputs.input_ids == tokenizer.mask_token_id
        assert bool((is_mask.sum(dim=-1) == 1).all()), 'expected exactly one [MASK] per prompt'
        token_idx = is_mask.long().argmax(dim=-1)
        batch_idx = torch.arange(len(texts), device=token_idx.device)

        # Highest-scoring vocabulary entry at the mask position, decoded one id at a time.
        model_predictions_ids = outputs.logits[batch_idx, token_idx].argmax(dim=-1)
        model_predictions_tokens = tokenizer.batch_decode(model_predictions_ids.tolist())
        model_predictions.extend(model_predictions_tokens)

# Later cells zip these lists with `samples`; fail loudly on any length mismatch
# rather than letting zip truncate silently.
assert len(probe_predictions) == len(model_predictions) == len(samples)

  0%|          | 0/4881 [00:00<?, ?it/s]

In [40]:
import collections

# Per-condition lists of per-sample booleans, one list for each metric.
probe_hits = collections.defaultdict(list)
model_hits = collections.defaultdict(list)
agreement_hits = collections.defaultdict(list)
for sample, model_pred, probe_topk in zip(samples, model_predictions, probe_predictions):
    # Samples pairing a real occupation with a non-real name are left out.
    if sample['condition']['name'] != 'real' and sample['condition']['occupation'] == 'real':
        continue
    condition = frozenset(sample['condition'].items())
    gold = sample['labels']['occupation']
    # Probe is correct when the gold occupation is among its top-k labels.
    probe_hits[condition].append(gold in probe_topk)
    # NOTE(review): model_pred is a single decoded vocabulary token, so a
    # multi-word gold label presumably can never match exactly -- confirm that
    # this is the intended definition of model accuracy.
    model_hits[condition].append(gold == model_pred)
    # Agreement: the model's token is one of the probe's top-k labels.
    agreement_hits[condition].append(model_pred in probe_topk)

# Reduce each list of booleans to the fraction of True entries.
probe_accuracies = {cond: sum(values) / len(values) for cond, values in probe_hits.items()}
model_accuracies = {cond: sum(values) / len(values) for cond, values in model_hits.items()}
probe_model_agreements = {cond: sum(values) / len(values) for cond, values in agreement_hits.items()}

table = [('context', 'name', 'occupation', 'probe accuracy', 'model accuracy', 'agreement')]
# Rows follow the order in which conditions first appear in `samples`.
# The original called sorted() on the frozenset keys, but sets only define a
# subset partial order, so that sort was not a meaningful ordering; iterating
# in insertion order makes the actual behaviour explicit.
for condition, probe_accuracy in probe_accuracies.items():
    model_accuracy = model_accuracies[condition]
    agreement = probe_model_agreements[condition]

    keys = dict(condition)
    row = [keys['context'], keys['name'], keys['occupation'], f'{probe_accuracy:.3f}', f'{model_accuracy:.3f}',
           f'{agreement:.3f}']
    table.append(row)

# Render once, then both save and display the same text.
rendered = tabulate.tabulate(table)
table_file = pathlib.Path('accuracy_table.txt')
with table_file.open('w') as handle:
    handle.write(rendered)
print(rendered)

----------  ----  ----------  --------------  --------------  ---------
context     name  occupation  probe accuracy  model accuracy  agreement
primary     real  real        0.450           0.532           0.246
secondary   real  real        0.232           0.206           0.103
irrelevant  real  real        0.044           0.003           0.021
primary     real  fake        0.358           0.421           0.181
secondary   real  fake        0.170           0.123           0.074
irrelevant  real  fake        0.034           0.003           0.021
primary     fake  fake        0.512           0.290           0.142
secondary   fake  fake        0.201           0.104           0.055
irrelevant  fake  fake        0.035           0.004           0.000
primary     none  fake        0.174           0.408           0.106
secondary   none  fake        0.102           0.095           0.078
irrelevant  none  fake        0.033           0.000           0.000
----------  ----  ----------  ----------