In [None]:
import pathlib
import pickle

# Load the pickled BIOS records from the sibling biosbias checkout.
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so only
# point this at a file you produced yourself or otherwise trust.
bios_file = pathlib.Path('../biosbias/BIOS.pkl')
with open(bios_file, 'rb') as stream:
    data = pickle.load(stream)

In [None]:
# Display the first loaded record to inspect which fields each entry carries.
data[0]

In [None]:
# Give every distinct title an integer id, numbered in order of first appearance.
# dict.fromkeys de-duplicates while preserving insertion order, so the ids match
# what an explicit "add if unseen" loop over `data` would assign.
title_indexer = {
    title: position
    for position, title in enumerate(dict.fromkeys(record['title'] for record in data))
}

# Fixed encoding of the two gender labels used by the records.
gender_indexer = dict(M=0, F=1)

# Show both mappings as the cell output.
title_indexer, gender_indexer

In [None]:
import torch

# Keep using the second GPU (index 1) when it exists, but fall back to the first
# GPU or the CPU so the notebook still runs top-to-bottom on machines with fewer
# devices instead of failing at the first `.to(device)` call.
if torch.cuda.is_available():
    device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    device = 'cpu'

In [None]:
import transformers

# Tokenizer and masked-LM checkpoint must come from the same pretrained name so
# that token ids line up with the model's embedding table.
tokenizer = transformers.AutoTokenizer.from_pretrained('roberta-base')
model = transformers.AutoModelForMaskedLM.from_pretrained('roberta-base')
# Move the weights to the selected device (Module.to returns the module itself).
model = model.to(device)

In [None]:
import torch.utils.data


class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over the loaded bio records.

    Each item is a tuple ``(bio, title_id, mention, gender_id)`` where ``bio`` is
    the record's raw text cut to start at ``record['name'][0]``, the two ids come
    from the module-level ``title_indexer`` / ``gender_indexer`` mappings, and
    ``mention`` is the token position later used to pick a hidden state.
    """

    def __init__(self, data):
        """Keep a reference to the indexable collection of record dicts."""
        self.data = data

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, index):
        """Build the training tuple for the record at ``index``.

        Raises:
            ValueError: if ``record['name'][0]`` does not occur in the raw text
                (propagated from ``str.index``).
            KeyError: if the record's title or gender has no entry in the
                corresponding indexer.
        """
        record = self.data[index]
        raw_text = record['raw']
        # Drop everything before the first occurrence of the leading name part
        # (presumably the first name -- confirm against the record schema).
        start = raw_text.index(record['name'][0])
        title_id = title_indexer[record['title']]
        # The mention position is always 1 here, i.e. the token right after
        # position 0 once the text starts at the name.
        mention = 1
        gender_id = gender_indexer[record['gender']]
        return raw_text[start:], title_id, mention, gender_id


dataset = Dataset(data)

In [None]:
# Spot-check one processed example; expect (bio, title_id, mention, gender_id).
# NOTE(review): index 110 looks arbitrary and requires at least 111 records.
dataset[110]

In [None]:
from tqdm.auto import tqdm

# Build every item once. Presumably a sanity pass: a record whose raw text does
# not contain its name raises ValueError in Dataset.__getitem__, and it is
# cheaper to hit that here than in the middle of training.
for position in tqdm(range(len(dataset))):
    dataset[position]

In [None]:
import torch
from torch import nn, optim
from tqdm.auto import tqdm

# The encoder is used purely as a frozen feature extractor: keep it on the
# target device and in eval mode so dropout cannot perturb the representations.
model.to(device)
model.eval()

# Two-layer MLP probe: 768-d hidden state in, one logit per known title out.
probe = nn.Sequential(
    nn.Linear(768, 768),
    nn.ReLU(),
    nn.Linear(768, len(title_indexer)),
).to(device)
# Only the probe's parameters are handed to the optimizer.
optimizer = optim.AdamW(probe.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Hold out 10% of the examples for validation.
val_size = int(.1 * len(dataset))
train_size = len(dataset) - val_size
train, val = torch.utils.data.random_split(dataset, (train_size, val_size))
train_loader = torch.utils.data.DataLoader(train, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val, batch_size=32)

# Bookkeeping slots (bad-epoch counter, best loss, best weights); not yet used
# by the loop below.
bad, best, state_dict = 0, float('inf'), None
for epoch in range(1):
    description = f'epoch {epoch}'
    progress = tqdm(train_loader, desc=description)

    probe.train()
    train_loss = 0
    for sentences, targets, mentions, _ in progress:
        # truncation=True caps each bio at the tokenizer's maximum length so an
        # unusually long bio cannot overflow the encoder's position embeddings.
        inputs = tokenizer(
            list(sentences),
            return_tensors='pt',
            padding='longest',
            truncation=True,
        ).to(device)
        # No autograd through the encoder: its weights are not being optimized,
        # so tracking gradients there only costs memory and time and leaves
        # stale .grad buffers on the model.
        with torch.no_grad():
            outputs = model(**inputs, return_dict=True, output_hidden_states=True)
        hidden = outputs.hidden_states[-1]
        # Pick, for each sentence, the hidden state at that sentence's own
        # mention position. The positions must stay in batch order (not sorted),
        # otherwise they would be paired with the wrong sentences.
        rows = torch.arange(len(sentences), device=hidden.device)
        reps = hidden[rows, mentions.to(hidden.device)]
        predictions = probe(reps)
        loss = criterion(predictions, targets.to(device))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_loss += loss.item()
        progress.set_description(f'{description} (loss={loss.item():.3f})')
    # Mean per-batch loss over the epoch.
    train_loss /= len(train_loader)
    print('train', train_loss)

In [None]:
# Inverse lookup: integer class id -> title string.
reverse_title_indexer = dict(zip(title_indexer.values(), title_indexer.keys()))


@torch.inference_mode()
def predict(sentence, token=1):
    """Predict a title from the encoder state at one token of ``sentence``.

    The tokenization is printed so the caller can see which position each
    token occupies and choose ``token`` accordingly.

    Args:
        sentence: A single input string.
        token: Index along the token axis of the hidden state fed to the probe.

    Returns:
        The title string whose class id has the highest probe logit.
    """
    encoded = tokenizer([sentence], return_tensors='pt', padding='longest').to(device)
    token_ids = encoded.input_ids.squeeze().tolist()
    print(tokenizer.convert_ids_to_tokens(token_ids))
    encoder_out = model(**encoded, return_dict=True, output_hidden_states=True)
    last_layer = encoder_out.hidden_states[-1]
    logits = probe(last_layer[:, token])
    winner = logits.argmax(dim=-1).squeeze().item()
    return reverse_title_indexer[winner]


# token=8 is presumably the position of "Alex" -- check it against the printed tokens.
predict('On a hike, my surgeon, Alex, told me about his most recent patient. Alex has an MD degree from UCSF. He specialized in cardiothoracic surgery.', token=8)

In [None]:
@torch.inference_mode()
def accuracy(dataset, probe=probe):
    """Compute the probe's title-classification accuracy over ``dataset``.

    Args:
        dataset: Dataset yielding ``(sentence, target, mention, _)`` items.
        probe: Module mapping mention representations to title logits. The
            default is bound when this cell runs, so re-run the cell after
            re-creating ``probe`` or pass the new probe explicitly.

    Returns:
        A 0-d tensor holding the fraction of examples classified correctly.

    Raises:
        ZeroDivisionError: if ``dataset`` is empty.
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=32)
    correct = 0
    for sentences, targets, mentions, _ in tqdm(loader):
        # truncation=True caps each bio at the tokenizer's maximum length so an
        # unusually long bio cannot overflow the encoder's position embeddings.
        inputs = tokenizer(
            list(sentences),
            return_tensors='pt',
            padding='longest',
            truncation=True,
        ).to(device)
        outputs = model(**inputs, return_dict=True, output_hidden_states=True)
        hidden = outputs.hidden_states[-1]
        # One hidden state per sentence, taken at that sentence's mention position.
        rows = torch.arange(len(sentences), device=hidden.device)
        reps = hidden[rows, mentions.to(hidden.device)]
        predictions = probe(reps).argmax(dim=-1).long()
        correct += predictions.eq(targets.to(device)).sum()
    return correct / len(dataset)

print(accuracy(val))

In [None]:
# Pick 10-20 attributes, similar to this
# Try with linear probes?
# Belinda-style probe; dot bert REPs