In [1]:
import torch

# Use the GPU when one is visible, otherwise fall back to the CPU so the
# notebook still runs end to end on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
import transformers

# Tokenizer and masked-LM model are both loaded from the same checkpoint.
tokenizer = transformers.AutoTokenizer.from_pretrained('bert-base-uncased')
model = transformers.AutoModelForMaskedLM.from_pretrained('bert-base-uncased')
# Move the model onto the device chosen in the first cell.
model = model.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Make data: mask each occupation and record the masked-LM prediction

In [32]:
import json
import pathlib

# Load the probing samples. Judging by the preview, each record is a dict with
# a 'text' sentence, a 'label' occupation and a 'token' index.
file = pathlib.Path('probing-training-generic-names.json')
# Be explicit about the encoding instead of relying on the platform default.
with file.open('r', encoding='utf-8') as handle:
    samples = json.load(handle)
# Preview a handful of records only; echoing the whole list floods the output.
samples[:8]

[{'text': 'Carlos, the missionary, attended my wedding last Wednesday.',
  'label': 'missionary',
  'token': 1},
 {'text': 'My mother Christine, the rugby union player, went to the store.',
  'label': 'rugby union player',
  'token': 3},
 {'text': 'My friend Lori is tired from working as a radio personality all day',
  'label': 'radio personality',
  'token': 3},
 {'text': 'My father Josh dreams of becoming a television actor.',
  'label': 'television actor',
  'token': 3},
 {'text': 'My mother Sharon, the stage actor, attended my wedding last Wednesday.',
  'label': 'stage actor',
  'token': 3},
 {'text': 'My mother Julia, the diplomat, attended my wedding last Wednesday.',
  'label': 'diplomat',
  'token': 3},
 {'text': 'My friend Paula is tired from working as a entrepreneur all day',
  'label': 'entrepreneur',
  'token': 3},
 {'text': 'My father Christina is tired from working as a sport shooter all day',
  'label': 'sport shooter',
  'token': 3},
 {'text': "My mother Scott works a

In [40]:
# Distinct occupation strings that appear as labels.
occupations = {sample['label'] for sample in samples}
# Union of the vocabulary ids of every word piece of every occupation
# (special tokens are not added, so only the occupation itself contributes).
candidates_idx = {
    token_id
    for occupation in occupations
    for token_id in tokenizer(occupation, add_special_tokens=False).input_ids
}
len(occupations), len(candidates_idx)

(150, 190)

In [136]:
import torch
from torch.utils import data
from tqdm.auto import tqdm

# Drop predictions left over from an earlier run of this cell.
for sample in samples:
    sample.pop('prediction', None)

loader = data.DataLoader(samples, batch_size=64)

# Loop-invariant: the vocabulary ids the prediction is restricted to.
candidate_ids = torch.tensor(sorted(candidates_idx), dtype=torch.long, device=device)

predictions, precomputed = [], []
for batch in tqdm(loader):
    # Hide the occupation so the masked LM has to predict it.
    texts = [
        text.replace(occupation, '[MASK]')
        for text, occupation in zip(batch['text'], batch['label'])
    ]
    inputs = tokenizer(texts, return_tensors='pt', padding='longest').to(device)

    with torch.inference_mode():
        outputs = model(**inputs, return_dict=True, output_hidden_states=True)

    batch_idx = torch.arange(len(texts), device=device)
    # Position of the first mask token in each row.
    # NOTE(review): a row without any mask token yields position 0 here;
    # presumably every text contains its label - verify against the data.
    mask_idx = inputs.input_ids.eq(tokenizer.mask_token_id).int().argmax(dim=-1)
    logits = outputs['logits'][batch_idx, mask_idx]

    # Restrict the argmax to candidate tokens by masking everything else with
    # -inf. Scaling candidate logits by a large factor does not do this: a
    # negative candidate logit is pushed further down, so a non-candidate
    # token can still win.
    restricted = torch.full_like(logits, float('-inf'))
    restricted[:, candidate_ids] = logits[:, candidate_ids]
    ids = restricted.argmax(dim=-1)
    tokens = tokenizer.batch_decode(ids)
    predictions.extend(tokens)

    # Last-layer hidden state at the per-sample 'token' position.
    token_idx = batch['token'].to(device)
    reps = outputs.hidden_states[-1][batch_idx, token_idx]
    precomputed.append(reps)
precomputed = torch.cat(precomputed)

# Attach each prediction to its sample; later cells read sample['prediction'].
for sample, pred in zip(samples, predictions):
    sample['prediction'] = pred

  0%|          | 0/7813 [00:00<?, ?it/s]

In [114]:
# Distinct values that occur in `predictions`.
p = set(predictions)
len(p)
p

{'actor',
 'artist',
 'author',
 'catholic',
 'director',
 'driver',
 'editor',
 'farmer',
 'journalist',
 'judge',
 'lawyer',
 'manager',
 'minister',
 'model',
 'painter',
 'photographer',
 'pianist',
 'player',
 'poet',
 'politician',
 'priest',
 'psychologist',
 'rabbi',
 'scientist',
 'singer',
 'soldier',
 'teacher',
 'writer'}

# Train probe on the precomputed representations

In [150]:
# Upper bound on the number of passes over the training split.
EPOCHS = 5
# Mini-batch size shared by the training, validation and test loaders.
BATCH_SIZE = 32
# Learning rate handed to the AdamW optimizer.
LR = 0.001
# Training stops once more than this many epochs pass without a new best
# validation loss.
PATIENCE = 4
# Fraction of the dataset used as the validation split.
HOLD_OUT = 0.5
# Fraction of the dataset set aside as the separate `exclude` split.
EXCLUDE = 0.01

In [151]:
from torch.utils import data

class Dataset(data.Dataset):
    """Pairs each precomputed representation with a class index for the probe.

    The class index is built from the same sample field that is used as the
    target, so every class can occur as a target and no target silently
    collapses into 'unk'. Index 0 is reserved for 'unk'.

    Args:
        samples: Sequence of dicts; each must contain `target_key`.
        precomputed: Indexable collection with one representation per sample,
            in the same order as `samples`.
        target_key: Sample field used as the classification target.

    Raises:
        ValueError: If `samples` and `precomputed` differ in length.
        KeyError: If a sample lacks `target_key`.
    """

    def __init__(self, samples, precomputed, target_key='prediction'):
        if len(samples) != len(precomputed):
            raise ValueError(
                f'got {len(samples)} samples but {len(precomputed)} representations'
            )
        self.samples = samples
        self.precomputed = precomputed
        self.target_key = target_key

        # Assign ids in order of first appearance; 0 stays reserved for 'unk'.
        indexer = {'unk': 0}
        for sample in samples:
            indexer.setdefault(sample[target_key], len(indexer))
        self.indexer = indexer

    def __getitem__(self, index):
        """Return `(representation, class_index)` for the sample at `index`."""
        sample = self.samples[index]
        rep = self.precomputed[index]
        # Fall back to 'unk' if the target was changed after construction.
        return rep, self.indexer.get(sample[self.target_key], 0)

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

import torch

dataset = Dataset(samples, precomputed)

# Split sizes: EXCLUDE and HOLD_OUT are fractions of the full dataset; the
# training split receives whatever remains.
exclude_size = int(EXCLUDE * len(dataset))
val_size = int(HOLD_OUT * len(dataset))
train_size = len(dataset) - val_size - exclude_size

# Seed the split so train/val/exclude membership is identical across runs;
# without it, reported numbers are not comparable between kernel restarts.
split_rng = torch.Generator().manual_seed(0)
train, val, exclude = data.random_split(
    dataset, (train_size, val_size, exclude_size), generator=split_rng
)

train_loader = data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
val_loader = data.DataLoader(val, batch_size=BATCH_SIZE)

In [143]:
import torch
from torch import nn, optim
from tqdm.auto import tqdm

# Two-layer MLP probe with one output per entry of the dataset's class index.
# NOTE(review): 768 presumably matches the width of the precomputed
# representations - confirm if the encoder is swapped.
probe = nn.Sequential(
    nn.Linear(768, 768),
    nn.LeakyReLU(),
    nn.Linear(768, len(dataset.indexer)),
).to(device)
optimizer = optim.AdamW(probe.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

progress = tqdm(range(EPOCHS), desc='train probe')
best, bad, state_dict = float('inf'), 0, None
for epoch in progress:
    train_loss = 0.
    for reps, targets in train_loader:
        # Named `scores`, not `predictions`, so the notebook-level list of
        # masked-LM predictions is not overwritten by a tensor.
        scores = probe(reps.to(device))
        loss = criterion(scores, targets.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    val_loss = 0.
    with torch.inference_mode():
        for reps, targets in val_loader:
            scores = probe(reps.to(device))
            val_loss += criterion(scores, targets.to(device)).item()
    val_loss /= len(val_loader)

    progress.set_description(f'train probe (train={train_loss:.3f}, val={val_loss:.3f})')

    if val_loss < best:
        # state_dict() hands back references to the live parameters, so the
        # tensors must be cloned; otherwise later optimizer steps overwrite
        # the "best" snapshot in place.
        state_dict = {
            name: tensor.detach().clone()
            for name, tensor in probe.state_dict().items()
        }
        best = val_loss
        bad = 0
    else:
        bad += 1

    # Stop once more than PATIENCE epochs in a row failed to improve.
    if bad > PATIENCE:
        break

# Restore the best weights whether training stopped early or ran through all
# epochs; the final epoch is not necessarily the best one.
if state_dict is not None:
    probe.load_state_dict(state_dict)

train probe:   0%|          | 0/15 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [167]:
from torch.utils import data
from tqdm.auto import tqdm

@torch.inference_mode()
def test(dataset):
    """Return the share of items in `dataset` whose probe argmax equals the target."""
    hits = 0
    batches = data.DataLoader(dataset, batch_size=BATCH_SIZE)
    for reps, targets in tqdm(batches):
        size = len(reps)
        guesses = probe(reps).argmax(dim=-1).view(size)
        gold = targets.to(device).view(size)
        hits += guesses.eq(gold).sum()
    return hits / len(dataset)

# Report probe accuracy on the validation split, then on the exclude split.
for split in (val, exclude):
    print(test(split))

  0%|          | 0/7813 [00:00<?, ?it/s]

['unk', 'priest', 'teacher', 'photographer', 'unk', 'author', 'photographer', 'teacher', 'photographer', 'teacher', 'lawyer', 'unk', 'lawyer', 'writer', 'photographer', 'farmer', 'lawyer', 'teacher', 'teacher', 'lawyer', 'lawyer', 'singer', 'priest', 'photographer', 'singer', 'unk', 'lawyer', 'lawyer', 'teacher', 'photographer', 'photographer', 'lawyer']
0.0


  0%|          | 0/157 [00:00<?, ?it/s]

['priest', 'farmer', 'unk', 'lawyer', 'lawyer', 'lawyer', 'author', 'photographer', 'lawyer', 'singer', 'unk', 'author', 'lawyer', 'lawyer', 'lawyer', 'teacher', 'singer', 'priest', 'teacher', 'teacher', 'photographer', 'teacher', 'unk', 'author', 'singer', 'lawyer', 'priest', 'singer', 'model', 'photographer', 'author', 'singer']
0.0


In [161]:
import torch

@torch.inference_mode()
def predict(text, tokens=(1,), k=3):
    """Return the probe's top-`k` labels for one text or a batch of texts.

    Args:
        text: A single string, or a sequence of strings.
        tokens: Token positions whose last-layer hidden states are averaged
            before being passed to the probe.
        k: Number of labels to return per text.

    Returns:
        A list of `k` labels when `text` is a single string, otherwise a list
        with one such list per text.
    """
    single = isinstance(text, str)
    texts = [text] if single else list(text)
    inputs = tokenizer(texts, return_tensors='pt', padding='longest').to(device)
    # Show the tokenization row by row so the caller can check which
    # positions `tokens` refers to.
    for row in inputs.input_ids.tolist():
        print(tokenizer.convert_ids_to_tokens(row))
    outputs = model(**inputs, return_dict=True, output_hidden_states=True)
    reps = outputs.hidden_states[-1][:, list(tokens)].mean(dim=1)
    # Keep the batch dimension; squeezing it breaks for k=1 or several texts.
    chosens = probe(reps).topk(k=k, dim=-1).indices.tolist()
    unindexer = {idx: label for label, idx in dataset.indexer.items()}
    labels = [[unindexer[chosen] for chosen in row] for row in chosens]
    return labels[0] if single else labels
# tokens=[2] selects position 2, which is 'person' in the tokenization printed by the call.
predict('A person works in a hospital and uses a scalpel. He is a [MASK]', tokens=[2])

['[CLS]', 'a', 'person', 'works', 'in', 'a', 'hospital', 'and', 'uses', 'a', 'scalp', '##el', '.', 'he', 'is', 'a', '[MASK]', '[SEP]']


['teacher', 'lawyer', 'priest']

In [133]:
# Spot-check the first records, including the 'prediction' field added by the prediction loop.
samples[:10]

[{'text': 'Carlos, the missionary, attended my wedding last Wednesday.',
  'label': 'missionary',
  'token': 1,
  'prediction': 'priest'},
 {'text': 'My mother Christine, the rugby union player, went to the store.',
  'label': 'rugby union player',
  'token': 3,
  'prediction': 'manager'},
 {'text': 'My friend Lori is tired from working as a radio personality all day',
  'label': 'radio personality',
  'token': 3,
  'prediction': 'model'},
 {'text': 'My father Josh dreams of becoming a television actor.',
  'label': 'television actor',
  'token': 3,
  'prediction': 'lawyer'},
 {'text': 'My mother Sharon, the stage actor, attended my wedding last Wednesday.',
  'label': 'stage actor',
  'token': 3,
  'prediction': 'singer'},
 {'text': 'My mother Julia, the diplomat, attended my wedding last Wednesday.',
  'label': 'diplomat',
  'token': 3,
  'prediction': 'singer'},
 {'text': 'My friend Paula is tired from working as a entrepreneur all day',
  'label': 'entrepreneur',
  'token': 3,
  'p