In [1]:
from ben_utils import SaveForward
from collections import defaultdict
import csv
import sys
import torch as t
from tqdm import tqdm
import transformers

In [2]:
devices = [f'cuda:{i}' for i in [2,3]]

In [3]:
tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')

In [4]:
gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(devices[0])

In [32]:
def from_babynames_csv(path):
    counts_by_name_by_gender = defaultdict(lambda: defaultdict(int))
    with open(csv_name, newline='') as csv_file:
        dict_reader = csv.DictReader(csv_file)
        for row in dict_reader:
            name = row['Name']
            _year = int(row['Year'])
            gender = row['Gender']
            count = int(row['Count'])
            counts_by_name_by_gender[gender][name] = count

    return counts_by_name_by_gender

In [29]:
csv_name = 'NationalNames.csv'

In [30]:
from_babynames_csv(csv_name).keys()

dict_keys(['F', 'M'])

In [26]:
def capture_embedding(gpt2, token):
    assert token.shape == (1, 1, gpt2.config.n_embd)
    with SaveForward(gpt2.lm_head) as saved:
        gpt2(token)
        (embedding,) = saved.saved_input

    return embedding.shape

In [34]:
def words_to_single_token_reps(tokenizer, words):
    result = {}
    for word in words:
        tokens = tokenizer.encode(word, return_tensors="pt")
        batch_size, seq_len = tokens.shape
        assert batch_size == 1
        if seq_len == 1:
            result[word] = tokens.item()
    return result

In [40]:
def how_bad_is_restricting_to_single_tokens(tokenizer):
    counts_by_name_by_gender = from_babynames_csv(csv_name)

    counts_by_name = defaultdict(int)
    for gender in ['M', 'F']:
        for name, count in counts_by_name_by_gender[gender].items():
            counts_by_name[name] += count
    
    with_single_token_reps = words_to_single_token_reps(tokenizer, counts_by_name).keys()
    
    counts_by_name_without_single_token_reps = counts_by_name.copy()
    for name in with_single_token_reps:
        del counts_by_name_without_single_token_reps[name]
    
    total_without = sum(counts_by_name_without_single_token_reps.values())
    total = sum(counts_by_name.values())
    print(f'{total_without} of {total} people ({total_without / total:.0%}) do not have names with single-token representations')
    
    sorted_by_incidence = list(counts_by_name_without_single_token_reps.items())
    sorted_by_incidence.sort(key=lambda x: x[1], reverse=True)

    print(f'Top three missing names: {sorted_by_incidence[:3]}')

how_bad_is_restricting_to_single_tokens(tokenizer)

3443733 of 4088706 people (84%) do not have names with single-token representations
Top three missing names: [('Emma', 20811), ('Olivia', 19696), ('Noah', 19250)]
