In [None]:
import datasets

In [None]:
# Load the 2020-05-01 English Wikipedia snapshot via Hugging Face `datasets`.
# NOTE(review): this prebuilt config name may not exist in newer `datasets` releases — confirm.
wikipedia = datasets.load_dataset(path='wikipedia', name='20200501.en')

In [None]:
# Peek at the first training record; later cells read its 'title' and 'text' fields.
wikipedia['train'][0]

In [None]:
import json
import pathlib

# Load the probing samples. Later cells read each record's 'text' (the
# entity name) and 'label' (its occupation).
file = pathlib.Path('probing-training-real-names.json')
# JSON is UTF-8; be explicit so non-ASCII names decode the same on every locale.
with file.open('r', encoding='utf-8') as handle:
    samples = json.load(handle)

# Index samples by lowercased name. Wikipedia titles are lowercased when
# `articles_by_title` is built and looked up here with lowercase keys, so the
# index must use the same normalization or mixed-case names never match.
samples_by_name = {sample['text'].lower(): sample for sample in samples}
samples[0]

In [None]:
from tqdm.auto import tqdm

train = wikipedia['train']
# Index every article's text by its lowercased title. Titles that collide
# after lowercasing keep the text of the last article seen.
articles_by_title = dict(
    (row['title'].lower(), row['text'])
    for row in tqdm(train)
)

In [None]:
# Naive sentence split of the Barack Obama article (on '. '), keeping only the
# sentences that mention both 'barack obama' and 'president', case-insensitively.
sentences = articles_by_title['barack obama'].split('. ')
matches = list(filter(
    lambda sentence: all(term in sentence.lower() for term in ('barack obama', 'president')),
    sentences,
))
matches

In [None]:
# Probing samples whose name is also a Wikipedia title, ordered from the
# longest article to the shortest (ties keep title insertion order).
# Normalize sample names locally so lookups agree with the lowercased keys of
# `articles_by_title`, regardless of how `samples_by_name` was keyed.
samples_by_lower_name = {
    name.lower(): sample for name, sample in samples_by_name.items()
}
# Filter before sorting: only the matching titles need ordering, not every
# article in the dump. Python's sort is stable (also with reverse=True), so the
# resulting order matches sorting everything first and filtering afterwards.
matched_titles = [
    title for title in articles_by_title if title.lower() in samples_by_lower_name
]
matched_titles.sort(key=lambda title: len(articles_by_title[title]), reverse=True)
matches = [samples_by_lower_name[title.lower()] for title in matched_titles]

In [None]:
# Show the first 25 entries of `matches` (samples with the longest Wikipedia articles first).
matches[:25]

In [None]:
# Spot-check: Wikipedia titles containing 'barack obama', and raw probing
# names containing 'mozart' (case-sensitive on the unnormalized sample text).
print(list(filter(lambda title: 'barack obama' in title, articles_by_title)))
print([name for name in (sample['text'] for sample in samples) if 'mozart' in name])

In [None]:
# Lowercased probing names that also exist as (lowercased) Wikipedia titles.
occs = {x['text'].lower() for x in samples}
# Use a new name: rebinding `wikipedia` would replace the loaded dataset that
# earlier cells index with ['train'], breaking them on re-run.
wikipedia_titles = articles_by_title.keys()
# Sort the intersection so its order (and indexing such as matches[20]) is
# reproducible; iteration order of a set of str varies with per-process hash
# randomization.
matches = tuple(sorted(occs & wikipedia_titles))

In [None]:
# Spot-check one entity name from the title/probing-name intersection.
# NOTE(review): which name sits at index 20 depends on how `matches` was ordered.
matches[20]

In [None]:
# Inspect one matched entity's article text by its lowercased title.
articles_by_title['dave olerich']

In [None]:
# Map lowercased probing name -> occupation label. When two names collide
# case-insensitively, the later sample's label wins.
occs_by_name = dict(
    (sample['text'].lower(), sample['label'])
    for sample in samples
)
occs_by_name

In [None]:
import torch
import transformers

# Use the GPU when one is present; fall back to CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing at `.to('cuda')`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = transformers.BertTokenizer.from_pretrained('bert-large-uncased')
model = transformers.BertForMaskedLM.from_pretrained('bert-large-uncased').to(device)

In [None]:
import re

import torch
from torch.utils import data

# Number of top masked-LM predictions checked against each entity's occupation.
TOP_K = 10

# One cloze prompt per matched entity; a pair is recorded in `corrects` when
# the occupation appears among BERT's top-k fillers for the [MASK].
dataset = [
    {
        'entity': entity,
        'text': f'{entity} is a [MASK].',
        'occupation': occs_by_name[entity],
    }
    for entity in matches
]
loader = data.DataLoader(dataset, batch_size=64)
corrects = set()
with torch.inference_mode():
    for batch in tqdm(loader):
        inputs = tokenizer(batch['text'], return_tensors='pt', padding='longest').to(device)
        logits = model(**inputs).logits
        # Locate [MASK] from the token ids instead of counting back from the
        # attention-mask length, so the prompt template can change safely.
        # argmax returns the first maximal index, i.e. the first [MASK] per row.
        is_mask = inputs.input_ids == tokenizer.mask_token_id
        assert is_mask.any(dim=-1).all(), 'every prompt must contain a [MASK] token'
        token_idx = is_mask.long().argmax(dim=-1)
        batch_idx = torch.arange(len(batch['text']), device=token_idx.device)
        pred_ids = logits[batch_idx, token_idx].topk(k=TOP_K, dim=-1).indices
        # Convert ids row by row: batch_decode would join a row's predictions
        # into one string and glue '##' word pieces onto the previous word.
        pred_tokens = [tokenizer.convert_ids_to_tokens(row) for row in pred_ids.tolist()]
        corrects |= {
            (entity, occ)
            for entity, occ, preds in zip(batch['entity'], batch['occupation'], pred_tokens)
            # Whole-word match against the occupation's words; a raw substring
            # test would accept e.g. 'man' for 'businessman'.
            if not set(preds).isdisjoint(re.findall(r'\w+', occ.lower()))
        }
# Fraction of probed entities whose occupation was found (0.0 if none probed).
accuracy = len(corrects) / len(dataset) if dataset else 0.0
accuracy

In [None]:
import json
import pathlib

# Keep only the (name, occupation) pairs recorded in `corrects` by the BERT probe.
# Sort so the written file is identical across runs; `corrects` is a set whose
# iteration order depends on string hash randomization.
filtered_records = [
    {'text': entity, 'label': occupation}
    for entity, occupation in sorted(corrects)
]
with pathlib.Path('probing-training-real-names-filtered.json').open('w', encoding='utf-8') as handle:
    json.dump(filtered_records, handle)

# Seq2seq masked-span probing (T5 / BART)

In [None]:
import transformers

# Load BART (large) for the masked-span generation experiment below.
# NOTE: this rebinds `tokenizer` and `model`, replacing the BERT ones; re-run
# the BERT loading cell before re-running the BERT probe.
tokenizer = transformers.BartTokenizer.from_pretrained('facebook/bart-large')
model = transformers.BartForConditionalGeneration.from_pretrained('facebook/bart-large')

In [None]:
import torch

# Let the model fill the literal '<mask>' in the prompt and decode the
# generated sequence(s) back to plain text.
inputs = tokenizer('An author wears a <mask>', return_tensors='pt')
with torch.inference_mode():
    outputs = model.generate(**inputs)
tokenizer.batch_decode(sequences=outputs, skip_special_tokens=True)

In [None]:
# Show the tokenizer's mask-token string; the prompt in the previous cell writes it literally as '<mask>'.
tokenizer.mask_token