In [None]:
import sys
# Make the project checkout importable (presumably where the `src` package
# imported further down lives).
# NOTE(review): machine-specific absolute path -- adjust when running elsewhere.
_project_root = '/raid/lingo/dez/code/lm-context-mediation'
# Guard the append so re-running this cell does not pile up duplicate entries.
if _project_root not in sys.path:
    sys.path.append(_project_root)

In [None]:
import pathlib
import pickle

# Load the pickled BIOS records into `data`.
# NOTE(review): unpickling can execute arbitrary code -- only load this file
# from a trusted source.
bios_file = pathlib.Path('../../biosbias/BIOS.pkl')
with open(bios_file, 'rb') as handle:
    data = pickle.load(handle)

# Derive the two text fields used downstream: the full name as one string, and
# the bio with every '_' replaced by that name.
for item in data:
    name = ' '.join(item['name'])
    item['inputs_bio'], item['inputs_name'] = item['bio'].replace('_', name), name

In [None]:
import transformers

# Which device to run on and which pretrained checkpoint to load.
device = 'cuda:1'
config = 'gpt2'

# Load the causal LM, move it to the target device and switch it to eval mode.
model = transformers.AutoModelForCausalLM.from_pretrained(config)
model.to(device)
model.eval()

# Load the matching tokenizer and reuse the EOS token as the padding token so
# that batched tokenization with padding (see `compute`) has a pad token.
tokenizer = transformers.AutoTokenizer.from_pretrained(config)
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Display the loaded model (its repr) as the cell output.
model

In [None]:
# Peek at the first record, including the derived `inputs_bio` / `inputs_name` fields.
data[0]

In [None]:
from src.utils import training

import torch.utils.data
from torch import nn, optim
from tqdm.auto import tqdm

# --- CONFIG ---
lr = 2e-4          # AdamW learning rate
batch_size = 8     # batch size for both data loaders
hold_out = 0.005   # passed to training.random_split as `hold_out`
iterations = 1000  # number of optimization steps
val_every = 100    # validate every this many steps
patience = 3       # passed to training.EarlyStopping

# --- IMPL ---
# One learnable vector per distinct entity name, sized like the LM's hidden state.
entities = sorted(set(item['inputs_name'] for item in data))
indexer_entities = dict(zip(entities, range(len(entities))))
tokens_entities = nn.Embedding(len(entities), model.config.hidden_size).to(device)

# One learnable vector per distinct occupation title.
occupations = sorted(set(item['title'] for item in data))
indexer_occupations = dict(zip(occupations, range(len(occupations))))
tokens_occupations = nn.Embedding(len(occupations), model.config.hidden_size).to(device)

# The optimizer only sees the two embedding tables, not the LM's own weights.
optimizer = optim.AdamW(
    [*tokens_entities.parameters(), *tokens_occupations.parameters()],
    lr=lr,
)

# Train/validation split and their loaders (only the training loader shuffles).
train, val = training.random_split(data, hold_out=hold_out)
train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val, batch_size=batch_size)

stopper = training.EarlyStopping(patience=patience)

def tokens_state_dict():
    """Return an independent snapshot of both learned embedding tables.

    `Module.state_dict()` hands back tensors that share storage with the live
    parameters, so keeping that result as a "best" checkpoint would silently
    track every later optimizer step. Each tensor is therefore cloned so the
    snapshot stays frozen at the moment this function is called.

    Returns:
        A pair `(entities_state, occupations_state)` of name -> tensor
        mappings, in the order `tokens_load_state_dict` expects.
    """
    def _snapshot(module):
        # detach().clone() yields a copy that later in-place updates cannot touch.
        return {key: value.detach().clone()
                for key, value in module.state_dict().items()}

    return _snapshot(tokens_entities), _snapshot(tokens_occupations)

def tokens_load_state_dict(state_dicts):
    """Restore both embedding tables from an `(entities, occupations)` pair.

    Args:
        state_dicts: Two state dicts, entity table first and occupation table
            second (the layout produced by `tokens_state_dict`).
    """
    entities_state, occupations_state = state_dicts
    targets = (
        (tokens_entities, entities_state),
        (tokens_occupations, occupations_state),
    )
    for module, state in targets:
        module.load_state_dict(state)

def compute(batch):
    """Run the LM on a batch of bios prefixed by two learned soft tokens.

    Each sequence is laid out as `[entity token, occupation token, bio tokens...]`
    at the embedding level. The loss is the LM loss over the real bio tokens
    only: the two prefix slots and all padding positions are labelled -100 so
    they are ignored.

    Args:
        batch: Collated batch with `inputs_bio` (texts), `inputs_name` (keys of
            `indexer_entities`) and `title` (keys of `indexer_occupations`).

    Returns:
        The model output for the batch; `.loss` holds the LM loss.
    """
    inputs = tokenizer(batch['inputs_bio'],
                       return_tensors='pt',
                       padding='longest').to(device)
    embeddings = model.get_input_embeddings()(inputs.input_ids)
    n_rows = inputs.input_ids.shape[0]

    ids_entities = torch.tensor([indexer_entities[name] for name in batch['inputs_name']],
                                device=device,
                                dtype=torch.long)
    embeddings_entities = tokens_entities(ids_entities)

    ids_occupations = torch.tensor([indexer_occupations[title] for title in batch['title']],
                                   device=device,
                                   dtype=torch.long)
    embeddings_occupations = tokens_occupations(ids_occupations)

    inputs_embeds = torch.cat(
        [embeddings_entities[:, None], embeddings_occupations[:, None], embeddings],
        dim=1)

    # The two prefix slots are always real inputs, so they are attended to;
    # the rest of the mask is the tokenizer's own padding mask.
    attention_mask = torch.cat(
        [inputs.attention_mask.new_ones((n_rows, 2)), inputs.attention_mask],
        dim=1)

    # The pad token is the EOS token, so padding cannot be recognised by token
    # id. Use the attention mask to exclude padded positions from the loss;
    # otherwise short bios are trained to emit long runs of EOS.
    labels = torch.cat([
        inputs.input_ids.new_full((n_rows, 2), -100),
        inputs.input_ids.masked_fill(inputs.attention_mask == 0, -100),
    ], dim=1)

    return model(inputs_embeds=inputs_embeds,
                 attention_mask=attention_mask,
                 labels=labels)

# Only the soft-token embeddings are handed to the optimizer, so freeze the
# base LM. Without this every backward pass also computes gradients for all of
# the LM's weights, which are never used and never zeroed.
model.requires_grad_(False)

best = tokens_state_dict()
train_loss, val_loss = float('inf'), float('inf')
# Keep one iterator alive across steps instead of rebuilding (and reshuffling)
# the loader on every iteration; start a fresh pass only when it runs out.
train_batches = iter(train_loader)
progress = tqdm(range(iterations))
for iteration in progress:
    # NOTE(review): train() also switches on the LM's dropout while the soft
    # tokens are being learned -- confirm this is intended.
    model.train()
    try:
        batch = next(train_batches)
    except StopIteration:
        train_batches = iter(train_loader)
        batch = next(train_batches)
    outputs = compute(batch)
    train_loss = outputs.loss
    train_loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Validate every `val_every` steps (but not at step 0).
    if iteration != 0 and not iteration % val_every:
        model.eval()
        val_loss = 0.
        for batch in tqdm(val_loader, desc='validating'):
            with torch.inference_mode():
                outputs = compute(batch)
            val_loss += outputs.loss.item()
        val_loss /= len(val_loader)
        if stopper(val_loss):
            # Early stop: roll the tokens back to the best snapshot seen.
            tokens_load_state_dict(best)
            break
        elif stopper.improved:
            best = tokens_state_dict()

    progress.set_description(f'train={train_loss.item():.2f}, val={val_loss:.2f}')

In [None]:
def apply(entity, occupation, max_length=25):
    """Generate text conditioned on an entity and a learned occupation token.

    Args:
        entity: Entity name. If it has a learned token (it is a key of
            `indexer_entities`), that single vector is used. Otherwise the name
            is tokenized and embedded with the LM's own input embeddings.
        occupation: Occupation title; must be a key of `indexer_occupations`
            (a `KeyError` is raised otherwise).
        max_length: Forwarded to `model.generate` as `max_length`.

    Returns:
        The decoded generations, as returned by `tokenizer.batch_decode`.
    """
    hidden_size = model.config.hidden_size

    # Pure inference: do not build an autograd graph through the learned tokens.
    with torch.no_grad():
        # Dict lookup instead of scanning the sorted `entities` list.
        index_entity = indexer_entities.get(entity)
        if index_entity is not None:
            token_entity = tokens_entities(
                torch.tensor([index_entity], dtype=torch.long, device=device))
        else:
            inputs_entity = tokenizer(entity, return_tensors='pt').to(device)
            # Same embedding accessor that `compute` uses.
            token_entity = model.get_input_embeddings()(inputs_entity.input_ids)
        token_entity = token_entity.view(1, -1, hidden_size)

        token_occupation = tokens_occupations(
            torch.tensor([indexer_occupations[occupation]],
                         device=device,
                         dtype=torch.long)
        ).view(1, 1, hidden_size)

        # Encode the soft prefix once and continue generation from its cache.
        # NOTE(review): this relies on `generate` accepting `past_key_values`
        # without explicit input ids -- confirm against the installed
        # transformers version.
        inputs_embeds = torch.cat((token_entity, token_occupation), dim=1)
        outputs = model(inputs_embeds=inputs_embeds, use_cache=True, return_dict=True)
        outputs = model.generate(past_key_values=outputs.past_key_values,
                                 max_length=max_length,
                                 use_cache=True)
    return tokenizer.batch_decode(outputs)

# Example: generate a continuation for one entity / occupation pair.
apply('Charles Darwin', 'accountant')

In [None]:
# List the occupation titles that have a learned token (valid `occupation` values for `apply`).
occupations

# Improved Version