In [None]:
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
evaluation_file = "/home/sli159/projects/historical-perspectival-lm/data/evaluation_dataset.jsonl"

evaluation_dataset = []
with open(evaluation_file, "r") as f:
    for line in f:
        evaluation_dataset.append(json.loads(line))


In [3]:
model_path = "Hplm/student_model_1850_1880" # "Hplm/dora_llama_model_1850_1880" #
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="cuda:0")


Loading checkpoint shards: 100%|██████████| 7/7 [00:04<00:00,  1.75it/s]


In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_path)

In [5]:
def calculate_rank(logits, token):
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    token_prob = probabilities[0][token]
    rank = torch.sum(probabilities > token_prob) + 1
    return rank

In [6]:
def get_token(offset_mapping, word_start_index):
    for i, (start, end) in enumerate(offset_mapping):
        if start == word_start_index:
            return i
    return -1

In [7]:
results = []

In [14]:
results = []
with torch.no_grad():
    for example in tqdm(evaluation_dataset[:1000]):
        inputs = tokenizer(example["text"], return_tensors="pt", return_offsets_mapping=True).to("cuda:0")
        token_index = get_token(inputs["offset_mapping"][0], example["word_index"])
        if token_index == -1:
            token_index = get_token(inputs["offset_mapping"][0], example["word_index"]-1)
        assert token_index != -1, "Token not found \n" + str(inputs["offset_mapping"]) + "\n" + str(example["word_index"])
        outputs = model(**inputs)
        token_logit = outputs.logits[:, token_index-1]
        target_token = inputs["input_ids"][0][token_index]
        rank = calculate_rank(token_logit, inputs["input_ids"][0][token_index])

        assert target_token == tokenizer(example["completion_word"], return_tensors="pt", add_special_tokens=False)["input_ids"][0][0], f"Token mismatch {target_token} != {tokenizer(example['completion_word'], return_tensors='pt')['input_ids'][0][0]}, '{tokenizer.decode([target_token])}' != '{example['completion_word']}'"

        results.append({
            "example" : example,
            "rank" : rank.item(),
            "inputs" : inputs,
            "logit" : token_logit.cpu().numpy(),
        })


100%|██████████| 1000/1000 [00:26<00:00, 37.79it/s]


In [None]:
example

{'text': "Give him but Sage and Butter..And there's no fear .",
 'word_index': 45,
 'stem': "Give him but Sage and Butter..And there's no",
 'completion_word': ' fear',
 'completion': ' fear .',
 'word': 'fear',
 'link': '/dictionary/fear_n?tab=factsheet#4529006',
 'sense_start_year': 1535,
 'sense_end_year': None,
 'citation_year': 1640}

In [15]:
def split_results(results, time):
    before = []
    after = []
    for result in results:
        if result["example"]["sense_start_year"] < time:
            before.append(result)
        else:
            after.append(result)
    return before, after

In [16]:
def print_statistics(results):
    ranks = [result["rank"] for result in results]
    print("Mean rank: ", sum(ranks) / len(ranks))
    print("Median rank: ", sorted(ranks)[len(ranks) // 2])
    print(f"Top 1: {sum(1 for rank in ranks if rank <= 1)}/{len(ranks)} - {sum(1 for rank in ranks if rank <= 1) / len(ranks)}")
    print(f"Top 50: {sum(1 for rank in ranks if rank <= 50)}/{len(ranks)} - {sum(1 for rank in ranks if rank <= 50) / len(ranks)}")
    print(f"Top 100: {sum(1 for rank in ranks if rank <= 100)}/{len(ranks)} - {sum(1 for rank in ranks if rank <= 100) / len(ranks)}")
    print(f"Top 1000: {sum(1 for rank in ranks if rank <= 1000)}/{len(ranks)} - {sum(1 for rank in ranks if rank <= 1000) / len(ranks)}")
    

In [19]:
year = 1880
before, after = split_results(results, year)
print(f"Before {year}")
print_statistics(before)
print("-" * 50)
print(f"After {year}")
print_statistics(after)

Before 1880
Mean rank:  434.02389078498294
Median rank:  11
Top 1: 172/879 - 0.1956769055745165
Top 50: 595/879 - 0.676905574516496
Top 100: 659/879 - 0.7497155858930603
Top 1000: 813/879 - 0.9249146757679181
--------------------------------------------------
After 1880
Mean rank:  859.8677685950413
Median rank:  41
Top 1: 19/121 - 0.15702479338842976
Top 50: 65/121 - 0.5371900826446281
Top 100: 79/121 - 0.6528925619834711
Top 1000: 104/121 - 0.859504132231405


In [None]:
before, after = split_results(results, 1820)
print("Before 1850")
print_statistics(before)
print("-" * 50)
print("After 1850")
print_statistics(after)

Before 1850
Mean rank:  1289.6
Median rank:  380
Top 1: 14/770 - 0.01818181818181818
Top 50: 148/770 - 0.19220779220779222
Top 100: 210/770 - 0.2727272727272727
Top 1000: 519/770 - 0.674025974025974
--------------------------------------------------
After 1850
Mean rank:  1956.0304347826086
Median rank:  989
Top 1: 1/230 - 0.004347826086956522
Top 50: 20/230 - 0.08695652173913043
Top 100: 38/230 - 0.16521739130434782
Top 1000: 117/230 - 0.508695652173913
