# Load entities and occupations

In [None]:
import json
import pathlib

from tqdm.auto import tqdm


# `raw` collects one parsed JSON object per line of every matching `all.jsonl`.
raw = []

root = pathlib.Path('/raid/lingo/dez/code/knowledge-fluidity/data/TaskBenchData/atomic')
for dataset_dir in root.glob('wiki{occupation(0)}'):
    # Read all lines up front so the progress bar knows its total.
    with (dataset_dir / 'all.jsonl').open('r') as stream:
        records = stream.readlines()
    raw.extend(json.loads(record) for record in tqdm(records, desc=dataset_dir.name))

In [None]:
# Display the first loaded record to inspect its structure.
raw[0]

In [None]:
import collections

# Count how often each occupation label (the first training target of every
# entity) occurs, then keep the 150 most frequent labels, most common first.
occupation_counts = collections.Counter(
    entity['train_tgts'][0]['ent_name'] for entity in raw
)
occupations = [label for label, _count in occupation_counts.most_common(150)]
occupations

# Create contexts

In [None]:
import transformers

# Model selection for mask filling. The filler loop checks `roberta` first and
# falls back to `bart`, so exactly one of the two should be non-None.
device = 'cuda:1'
bart_checkpoint = "facebook/bart-large"

roberta = None
bart = None

# RoBERTa alternative (uncomment these two lines and disable the BART lines below):
# tokenizer = transformers.AutoTokenizer.from_pretrained('roberta-base')
# roberta = transformers.AutoModelForMaskedLM.from_pretrained('roberta-base').to(device)

tokenizer = transformers.BartTokenizer.from_pretrained(bart_checkpoint)
bart = transformers.BartForConditionalGeneration.from_pretrained(
    bart_checkpoint,
    forced_bos_token_id=0,
)
bart = bart.to(device)

In [None]:
# One cloze template per kind of occupational context. The order matters:
# the discourse-generation code reads the resulting fillers by position
# (0 = location, 1 = tool, 2 = role, 3 = degree).
context_templates = (
    'A {occupation} works at a <mask>',  # Location
    'A {occupation} uses a <mask>',  # Tool
    'The job of a {occupation} is to <mask>', # Role
    'A {occupation} has a degree in <mask>',  # Training
)


def _prompts_for(occupation):
    """Return every context template instantiated for `occupation`, in template order."""
    return [template.format(occupation=occupation) for template in context_templates]


# Maps each occupation to its list of filled-in prompts.
context_prompts = {occupation: _prompts_for(occupation) for occupation in occupations}
context_prompts

In [None]:
from collections import defaultdict

import torch
import torch.utils.data
from tqdm.auto import tqdm


# Ask the mask-filling model to complete every occupation prompt and collect
# the completions per occupation, appended in the order the templates are
# iterated (so position i in each list corresponds to template i).
loader = torch.utils.data.DataLoader(tuple(context_prompts.items()), batch_size=32)

fillers = defaultdict(list)
for occs, prompts_by_kind in tqdm(loader):
    # Each `prompts` is expected to carry one prompt per occupation in `occs`;
    # the length asserts below guard that alignment.
    for prompts in prompts_by_kind:
        inputs = tokenizer(list(prompts), return_tensors='pt', padding='longest').to(device)
        if roberta is not None:
            with torch.inference_mode():
                outputs = roberta(**inputs)

            # Locate <mask> by its token id instead of a fixed offset from the
            # end of the sequence: an end-relative offset silently reads the
            # wrong position whenever the tail of a template changes (e.g.
            # with or without trailing punctuation after '<mask>').
            is_mask = inputs.input_ids == tokenizer.mask_token_id
            if not bool((is_mask.sum(dim=-1) == 1).all()):
                raise ValueError('every prompt must contain exactly one mask token')
            mask_positions = is_mask.int().argmax(dim=-1)
            # Build the row index on the same device as the column index.
            rows = torch.arange(len(mask_positions), device=mask_positions.device)
            predictions = outputs.logits[rows, mask_positions].argmax(dim=-1)
            assert len(occs) == len(predictions)
            for occupation, token_id in zip(occs, predictions):
                token = tokenizer.decode(token_id.item()).strip()
                fillers[occupation].append(token)
        else:
            assert bart is not None
            with torch.inference_mode():
                outputs = bart.generate(**inputs)

            # The filler is taken to start at (number of non-padding input
            # tokens - 1) in each generated sequence.
            # NOTE(review): this assumes the generated sequence reproduces the
            # prompt up to the mask, shifted by one position relative to the
            # input, and that the template ends in '<mask>' — confirm if the
            # templates or generation settings change.
            starts = inputs.attention_mask.sum(dim=-1) - 1
            strings = [
                tokenizer.decode(ids[start:], skip_special_tokens=True).strip()
                for ids, start in zip(outputs, starts)
            ]
            assert len(occs) == len(strings)
            for occupation, string in zip(occs, strings):
                # Trim surrounding spaces and sentence punctuation from the completion.
                fillers[occupation].append(string.strip(' .;:'))

fillers

A small playground for sanity-checking the mask-filling step:

In [None]:
from transformers import BartForConditionalGeneration, BartTokenizer

# Stand-alone check: load BART on the default device, fill a single masked
# prompt, and print the decoded generation.
tok = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained(
    "facebook/bart-large",
    forced_bos_token_id=0,
)

example_english_phrase = "An economist's job is to <mask>"
batch = tok(example_english_phrase, return_tensors="pt")
generated_ids = model.generate(batch["input_ids"])
decoded = tok.batch_decode(generated_ids, skip_special_tokens=True)
print(decoded)

# Generate discourse data

In [None]:
import names_dataset
# Pool of generic US first names: the top-100 query result for 'M' followed by
# the result for 'F', concatenated into one list.
nd = names_dataset.NameDataset()
all_us_names = nd.get_top_names(n=100, country_alpha2='US')['US']
generic_us_names = [
    name
    for gender in ('M', 'F')
    for name in all_us_names[gender]
]

In [None]:
import random


# Build the discourse probing set. For every entity whose occupation is among
# the selected `occupations`, emit one sentence for each combination of
#   name kind       (real / fake / none)
#   occupation kind (real / fake)
#   context kind    (primary / secondary / irrelevant)
# and write all samples to a single JSON file.

# Set for O(1) membership tests; `occupations` stays a list for random.choice.
occupation_set = set(occupations)

samples = []
for entity in raw:
    real_name = entity['inputs'][0]['ent_name']
    real_occupation = entity['train_tgts'][0]['ent_name']
    if real_occupation not in occupation_set:
        continue

    fake_name = random.choice(generic_us_names)
    # NOTE(review): drawn from all occupations, so the "fake" occupation can
    # coincide with the real one — confirm this is acceptable.
    fake_occupation = random.choice(occupations)

    names = {
        'real': real_name,
        'fake': fake_name,
        'none': 'a person',
    }

    occs = {
        'real': real_occupation,
        'fake': fake_occupation,
    }

    # Named differently from the module-level `context_templates` tuple so the
    # cloze templates are not clobbered (re-running the prompt cell keeps working).
    discourse_templates = {
        'primary': 'who works as a {occupation}',
        'secondary': random.choice([
            'who forgot to bring a {tool} to their job at the {location}',
            'who works at a {location} and whose job is to {role}',
        ]),
        'irrelevant': random.choice([
            'who climbed a hill',
        ]),
    }

    for name_kind, name_text in names.items():
        for occ_kind, occ_text in occs.items():
            # Fillers generated for this occupation, by position:
            # 0 = location, 1 = tool, 2 = role, 3 = degree.
            # Read from `fillers` into a local name; rebinding `fillers` itself
            # would destroy the occupation -> fillers mapping after one lookup.
            occ_fillers = fillers[occ_text]
            for context_kind, context_template in discourse_templates.items():
                # str.replace (not str.format) so templates may omit placeholders
                # and filler text is never interpreted as a format spec.
                context = context_template\
                    .replace('{occupation}', occ_text)\
                    .replace('{location}', occ_fillers[0])\
                    .replace('{tool}', occ_fillers[1])\
                    .replace('{role}', occ_fillers[2])\
                    .replace('{degree}', occ_fillers[3])

                text = f'This is a story about {name_text} {context}.'
                sample = {
                    'condition': {
                        'name': name_kind,
                        'occupation': occ_kind,
                        'context': context_kind,
                    },
                    'labels': {
                        'name': name_text,
                        'occupation': occ_text,
                    },
                    'text': text,
                }
                samples.append(sample)

out_file = pathlib.Path('/raid/lingo/dez/code/knowledge-fluidity/probing-discourse.json')
with out_file.open('w') as handle:
    json.dump(samples, handle)

In [None]:
# Display a slice of the generated samples for a manual spot check.
samples[1000:1100]

# Generate probing data

In [None]:
import json
import pathlib
import random

# Sentence frames for the synthetic probing data. Commented-out entries are
# alternatives that are currently disabled.
formats = (
    '{prefix}{name} works as a {occupation}.',
    '{prefix}{name}, the {occupation}, went to the store.',
    '{prefix}{name}, the {occupation}, attended my wedding last Wednesday.',
    '{prefix}{name}, the {occupation}, is a close friend of mine',
    '{prefix}{name} is tired from working as a {occupation} all day',
    # '{prefix}{name} is a busy {occupation}',
    '{prefix}{name} dreams of becoming a {occupation}.',
    # '{name}',
)

# (prefix text, integer stored under 'token' in each sample).
# NOTE(review): the integer is presumably the token position of the name that
# follows the prefix — confirm against the probing code that consumes it.
prefixes = (
    ('', 1),
    ('My cousin ', 3),
    ('My mother ', 3),
    ('My father ', 3),
    ('My friend ', 3),
    ('I met a friend named ', 6),
    ('This is a story about how ', 7),
)


def _draw_probing_sample():
    """Draw one random (name, occupation, prefix, format) combination as a sample dict."""
    # Draw order (name, occupation, prefix, format) is kept fixed on purpose.
    chosen_name = random.choice(generic_us_names)
    chosen_occupation = random.choice(occupations)
    chosen_prefix, token_index = random.choice(prefixes)
    frame = random.choice(formats)
    sentence = frame.format(
        prefix=chosen_prefix,
        name=chosen_name,
        occupation=chosen_occupation,
    )
    return {'text': sentence, 'label': chosen_occupation, 'token': token_index}


samples = [_draw_probing_sample() for _ in range(500000)]

out_file = pathlib.Path('/raid/lingo/dez/code/knowledge-fluidity/probing-training-generic-names.json')
with out_file.open('w') as handle:
    json.dump(samples, handle)

In [None]:
import json
import pathlib

# Probing data built from the real entities: the entity name is the text and
# its first training-target occupation is the label.
samples = [
    {
        'text': entry['inputs'][0]['ent_name'],
        'label': entry['train_tgts'][0]['ent_name'],
    }
    for entry in raw
]

out_file = pathlib.Path('/raid/lingo/dez/code/knowledge-fluidity/probing-training-real-names.json')
with out_file.open('w') as handle:
    json.dump(samples, handle)