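Load the probing samples. Each sample has a `text`, the position of its target `token` in the tokenized sequence, and a `label`.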

In [1]:
import json
import pathlib

# JSON is UTF-8 by specification; pass the encoding explicitly instead of relying
# on the platform's locale default, since labels may contain non-ASCII text.
file = pathlib.Path('probing-training-generic-names.json')
with file.open('r', encoding='utf-8') as handle:
    samples = json.load(handle)

Load the model and precompute, for every sample, BERT's final-layer representation at the sample's target token position.

In [2]:
import transformers

import torch

MODEL_NAME = 'bert-base-uncased'

# The original run used the second GPU. Fall back to the default GPU (or CPU)
# so Restart & Run All still works on machines that have no cuda:1 device.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModel.from_pretrained(MODEL_NAME).to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
import torch
from torch.utils import data
from tqdm.auto import tqdm

# Run BERT once over every sample and keep only the final-layer vector at each
# sample's target position, so the probe can train on cached features.
# No shuffling: row i of `precomputed` must stay aligned with samples[i].
loader = data.DataLoader(samples, batch_size=32)
chunks = []
for batch in tqdm(loader):
    texts, positions = batch['text'], batch['token']
    encoded = tokenizer(texts, return_tensors='pt', padding='longest').to(device)
    with torch.inference_mode():
        hidden = model(**encoded, return_dict=True, output_hidden_states=True).hidden_states
        # Pick, for each row of the batch, the vector at that row's target position.
        selected = hidden[-1][range(len(texts)), positions]
    chunks.append(selected)
precomputed = torch.cat(chunks)
precomputed.shape

  0%|          | 0/15625 [00:00<?, ?it/s]

torch.Size([500000, 768])

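Train the probe: a one-hidden-layer MLP that predicts each sample's label from its precomputed representation, with early stopping on a held-out validation split.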

In [4]:
# Probe-training hyperparameters.
EPOCHS = 15      # upper bound on passes over the training split
BATCH_SIZE = 32  # batch size for the train / validation / evaluation loaders
LR = 1e-3        # AdamW learning rate
PATIENCE = 4     # stop after more than this many consecutive non-improving epochs
HOLD_OUT = 0.05  # fraction of samples used as the validation split
EXCLUDE = 0.01   # fraction of samples set aside and only scored after training

In [5]:
from torch.utils import data

class Dataset(data.Dataset):
    """Pairs each sample's precomputed representation with an integer label id.

    ``samples[i]`` and ``precomputed[i]`` must describe the same example.
    Label ids are assigned in order of first appearance in ``samples`` and
    stored in ``self.indexer`` (label -> id).
    """

    def __init__(self, samples, precomputed):
        self.samples = samples
        self.precomputed = precomputed
        self.indexer = {}
        for label in (sample['label'] for sample in samples):
            # setdefault only inserts unseen labels, giving them the next free id.
            self.indexer.setdefault(label, len(self.indexer))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        label_id = self.indexer[self.samples[index]['label']]
        return self.precomputed[index], label_id
    
dataset = Dataset(samples, precomputed)

# EXCLUDE and HOLD_OUT are fractions of the data; training gets the remainder.
exclude_size = int(EXCLUDE * len(dataset))
val_size = int(HOLD_OUT * len(dataset))
train_size = len(dataset) - val_size - exclude_size
# Seed the split so the validation and excluded sets are identical on every
# Restart & Run All; otherwise the accuracies reported below are not comparable.
split_generator = torch.Generator().manual_seed(0)
train, val, exclude = data.random_split(
    dataset, (train_size, val_size, exclude_size), generator=split_generator)

train_loader = data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
val_loader = data.DataLoader(val, batch_size=BATCH_SIZE)

In [6]:
import torch
from torch import nn, optim
from tqdm.auto import tqdm

# One-hidden-layer MLP mapping a 768-d BERT representation to label logits.
probe = nn.Sequential(
    nn.Linear(768, 768),
    nn.LeakyReLU(),
    nn.Linear(768, len(dataset.indexer)),
).to(device)
optimizer = optim.AdamW(probe.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

progress = tqdm(range(EPOCHS), desc='train probe')
best, bad, state_dict = float('inf'), 0, None
for epoch in progress:
    train_loss = 0.
    for reps, targets in train_loader:
        predictions = probe(reps)
        loss = criterion(predictions, targets.to(device))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    val_loss = 0.
    with torch.inference_mode():
        for reps, targets in val_loader:
            predictions = probe(reps)
            loss = criterion(predictions, targets.to(device))
            val_loss += loss.item()
    val_loss /= len(val_loader)

    progress.set_description(f'train probe (train={train_loss:.3f}, val={val_loss:.3f})')

    if val_loss < best:
        # state_dict() returns references to the live parameters, which the
        # optimizer keeps updating in place; clone them to freeze this checkpoint.
        state_dict = {name: tensor.detach().clone() for name, tensor in probe.state_dict().items()}
        best = val_loss
        bad = 0
    else:
        bad += 1

    # Early stopping: more than PATIENCE consecutive epochs without improvement.
    if bad > PATIENCE:
        break

# Restore the best checkpoint whether training stopped early or ran all EPOCHS.
# state_dict stays None only if validation loss never improved (e.g. NaN).
if state_dict is not None:
    probe.load_state_dict(state_dict)

train probe:   0%|          | 0/15 [00:00<?, ?it/s]

In [7]:
from torch.utils import data
from tqdm.auto import tqdm

@torch.inference_mode()
def test(dataset):
    """Return the probe's accuracy on ``dataset``.

    Each item of ``dataset`` must be a ``(representation, label_id)`` pair.
    The result is a 0-dim tensor on ``device``: the fraction of items whose
    highest-scoring probe output equals the label id.
    """
    hits = 0
    for features, labels in tqdm(data.DataLoader(dataset, batch_size=BATCH_SIZE)):
        count = len(features)
        guesses = probe(features).argmax(dim=-1).view(count)
        gold = labels.to(device).view(count)
        hits += (guesses == gold).sum()
    return hits / len(dataset)

print(test(val))
print(test(exclude))

  0%|          | 0/782 [00:00<?, ?it/s]

tensor(0.9777, device='cuda:1')


  0%|          | 0/157 [00:00<?, ?it/s]

tensor(0.9798, device='cuda:1')


In [19]:
import torch

@torch.inference_mode()
def predict(text, tokens=(1,), k=5):
    """Print the tokenization of ``text`` and return the probe's top-``k`` labels.

    The final-layer BERT vectors at positions ``tokens`` are averaged before
    being fed to the probe, so a name split into several word pieces can be
    probed as one unit. Positions index the tokenized sequence, where position
    0 is ``[CLS]``.

    Args:
        text: a single input string.
        tokens: token positions to average (default: position 1, the first
            token after ``[CLS]``). Any sequence of ints is accepted.
        k: number of labels to return; capped at the number of known labels.

    Returns:
        List of label strings ordered by probe score, highest first.

    Raises:
        ValueError: if ``text`` encodes to more than one sequence.
    """
    inputs = tokenizer(text, return_tensors='pt', padding='longest').to(device)
    if inputs.input_ids.shape[0] != 1:
        raise ValueError('predict expects a single text')
    print(tokenizer.convert_ids_to_tokens(inputs.input_ids[0].tolist()))
    outputs = model(**inputs, return_dict=True, output_hidden_states=True)
    # list(...) so tuple/list positions both index the sequence dimension.
    reps = outputs.hidden_states[-1][:, list(tokens)].mean(dim=1)
    logits = probe(reps)
    k = min(k, logits.shape[-1])
    # Index row 0 explicitly: .squeeze() would collapse to a bare int when k == 1.
    chosens = logits.topk(k=k, dim=-1).indices[0].tolist()
    unindexer = {idx: label for label, idx in dataset.indexer.items()}
    return [unindexer[chosen] for chosen in chosens]
predict('John has patients.', tokens=[1])

['[CLS]', 'john', 'has', 'patients', '.', '[SEP]']


['judge', 'theologian', 'psychiatrist', 'physician', 'officer']

In [9]:
#torch.save(probe, 'probe_wikidata_occupation.pth')
#torch.save(dataset.indexer, 'probe_wikidata_occupation_indexer.pth')

In [10]:
# Labels the probe can predict, in label-id order (ids follow first appearance).
label_names = dataset.indexer.keys()
label_names

dict_keys(['editor', 'tarento', 'field hockey player', 'psychiatrist', 'educator', 'sportsperson', 'linguist', 'seiyū', 'herpetologist', 'mathematician', 'art historian', 'australian rules footballer', 'translator', 'artistic gymnast', 'opera singer', 'zoologist', 'ice hockey player', 'rabbi', 'boxer', 'music educator', 'fencer', 'screenwriter', 'television presenter', 'biathlete', 'physicist', 'beauty pageant contestant', 'geologist', 'sculptor', 'basketball player', 'motorcycle racer', 'musicologist', 'rugby league player', 'alpine skier', 'association football referee', 'musician', 'aviator', 'judge', 'officer', 'head coach', 'businessperson', 'monk', 'journalist', 'rugby union player', 'trade unionist', 'athletics competitor', 'illustrator', 'racecar driver', 'priest', 'non-fiction writer', 'sport cyclist', 'sport shooter', 'author', 'archaeologist', 'judoka', 'sociologist', 'golfer', 'av idol', 'photographer', 'guitarist', 'university teacher', 'police officer', 'economist', 'dipl