In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("../..")

import numpy as np
import matplotlib.pyplot as plt
import os
from src import data
import json
from tqdm.auto import tqdm
from src.metrics import AggregateMetric
import logging

from src.utils import logging_utils
from src.models import load_model


# logging_utils.configure(level=logging.DEBUG)

In [3]:
# stats = {}

# dataset = data.load_dataset()
# for relation in dataset:
#     stats[relation.name] = {
#         "name": relation.name,
#         "category": relation.properties.relation_type,
#         "num_samples": len(relation.samples),
#         "|range|": len(set(relation.range))
#     }

# with open("stats/range_stats.json", "w") as f:
#     json.dump(stats, f, indent=2)

In [6]:
model_name = "llama"

mt = load_model(model_name, fp16=model_name != "gpt2-xl", device="cuda")

Loading checkpoint shards:   0%|          | 0/41 [00:00<?, ?it/s]

In [33]:
tok_id = mt.tokenizer(" 1996").input_ids[2]
mt.tokenizer.decode(tok_id)

'woman'

In [34]:
" 1996".isnumeric()

False

In [36]:
with open("stats/range_stats.json", "r") as f:
    stats = json.load(f)

dataset = data.load_dataset()
prefix = " " if model_name != "llama" else ""

for relation in dataset:
    if relation.name not in stats:
        continue
    first_tokens = []
    for obj in relation.range:
        idx = 0
        if model_name == "llama":
            idx = 2 if obj.isnumeric() else 1
        tok_id = mt.tokenizer(prefix + obj).input_ids[idx]
        # print(idx, tok_id, f"`{obj}` | `{mt.tokenizer.decode(tok_id)}`")
        first_tokens.append(mt.tokenizer.decode(tok_id))
    stats[relation.name][model_name] = len(set(first_tokens))

with open("stats/range_stats.json", "w") as f:
    json.dump(stats, f, indent=2)

## Table 4

In [40]:
with open("stats/range_stats.json", "r") as f:
    stats = json.load(f)

In [56]:
relation_stats = [
    stats[relation] 
    for relation in stats 
]

relation_stats = sorted(
    relation_stats, key=lambda x: x["name"],
)

table = []

def check_range_mismatch(relation_stat):
    for unique_first_tokens in [relation_stat["gptj"], relation_stat["gpt2-xl"], relation_stat["llama"]]:
        if unique_first_tokens != relation_stat["|range|"]:
            return False
    return True

for relation in relation_stats:
    if not check_range_mismatch(relation):
        table.append({
            "Relation": f'{relation["name"]}',
            "|range|": f'${relation["|range|"]}$',
            "GPT-J": f'${relation["gptj"]}$',
            "GPT2-xl": f'${relation["gpt2-xl"]}$',
            "LLaMa-13B": f'${relation["llama"]}$',
        })

import pandas as pd
df = pd.DataFrame(table)
print(df.style.hide(axis = "index").to_latex())

\begin{tabular}{lllll}
Relation & |range| & GPT-J & GPT2-xl & LLaMa-13B \\
adjective antonym & $95$ & $95$ & $95$ & $94$ \\
adjective comparative & $57$ & $57$ & $57$ & $53$ \\
adjective superlative & $79$ & $77$ & $77$ & $78$ \\
city in country & $21$ & $20$ & $20$ & $20$ \\
company CEO & $287$ & $208$ & $208$ & $194$ \\
company hq & $163$ & $163$ & $163$ & $152$ \\
country currency & $23$ & $23$ & $23$ & $21$ \\
landmark in country & $91$ & $91$ & $91$ & $89$ \\
person father & $968$ & $400$ & $400$ & $377$ \\
person lead singer of band & $21$ & $18$ & $18$ & $18$ \\
person mother & $962$ & $380$ & $380$ & $307$ \\
person occupation & $31$ & $31$ & $31$ & $29$ \\
person university & $69$ & $37$ & $37$ & $35$ \\
pokemon evolution & $44$ & $40$ & $40$ & $36$ \\
president birth year & $15$ & $9$ & $9$ & $1$ \\
president election year & $18$ & $14$ & $14$ & $2$ \\
product by company & $30$ & $30$ & $30$ & $26$ \\
star constellation name & $31$ & $29$ & $29$ & $27$ \\
superhero archnemesi

In [61]:
from src.functional import predict_next_token

predict_next_token(
    mt = mt, 
    prompt = "Eiffel Tower is located in",
)

[[PredictedToken(token='Paris', prob=0.6106753349304199),
  PredictedToken(token='the', prob=0.18335798382759094),
  PredictedToken(token='France', prob=0.031124738976359367),
  PredictedToken(token='a', prob=0.014702285639941692),
  PredictedToken(token='', prob=0.008847991935908794)]]