In [1]:
%load_ext autoreload
%autoreload 2

In [33]:
import sys
sys.path.append("../..")

import numpy as np
import matplotlib.pyplot as plt
import os
from src import data
import json
from tqdm.auto import tqdm
from src.utils.sweep_utils import read_sweep_results, relation_from_dict
import pandas as pd

In [3]:
############################################
sweep_root = "../../results/sweep-24-trials"
model_name = "gptj"
############################################

sweep_path = f"{sweep_root}/{model_name}"

In [5]:
sweep_results = read_sweep_results(sweep_path)
list(sweep_results.keys())

--> ../../results/sweep-24-trials/gptj
    --> ../../results/sweep-24-trials/gptj/person_occupation
        --> ../../results/sweep-24-trials/gptj/person_occupation/1_person_occupation_seed_71745
            --> ../../results/sweep-24-trials/gptj/person_occupation/1_person_occupation_seed_71745/results_all.json
            --> ../../results/sweep-24-trials/gptj/person_occupation/1_person_occupation_seed_71745/person_occupation.json
            --> ../../results/sweep-24-trials/gptj/person_occupation/1_person_occupation_seed_71745/args-20230606-211709.json
        --> ../../results/sweep-24-trials/gptj/person_occupation/3_person_occupation_seed_709106
            --> ../../results/sweep-24-trials/gptj/person_occupation/3_person_occupation_seed_709106/args-20230607-012240.json
            --> ../../results/sweep-24-trials/gptj/person_occupation/3_person_occupation_seed_709106/results_all.json
            --> ../../results/sweep-24-trials/gptj/person_occupation/3_person_occupation_seed_70

['person occupation',
 'landmark in country',
 'adjective antonym',
 'person mother',
 'country capital city',
 'plays pro sport',
 'person plays instrument',
 'person university',
 'city in country',
 'food from country',
 'company hq',
 'occupation age',
 'word first letter',
 'country language',
 'object superclass',
 'name religion',
 'person native language',
 'fruit outside color',
 'superhero archnemesis',
 'work location',
 'landmark on continent',
 'person lead singer of band',
 'task person type',
 'country largest city',
 'country currency',
 'fruit inside color',
 'task done by tool',
 'verb past tense',
 'star constellation name',
 'pokemon evolution',
 'product by company',
 'name birthplace',
 'word last letter',
 'word sentiment',
 'company CEO',
 'superhero person',
 'person father',
 'substance phase of matter',
 'person sport position',
 'adjective superlative',
 'adjective comparative',
 'univ degree gender']

In [6]:
dataset = data.load_dataset()
interested_dataset = dataset.filter(
    # relation_type = ["factual"]
)

filtered_results = {}
for relation in tqdm(interested_dataset.relations):
    if relation.name not in sweep_results:
        continue
    relation_result = relation_from_dict(sweep_results[relation.name])
    if len(relation_result.trials) < 3:
        print(f"skipping {relation.name}, not enough trials, : {[trial.n_test_samples for trial in relation_result.trials]}")
        continue
    filtered_results[relation.name] = relation_result

  0%|          | 0/47 [00:00<?, ?it/s]

skipping occupation age, not enough trials, : [23]


In [13]:
##################
beta = 0.4000000059604645
###################

In [60]:
table = []

for relation_name, sweep_result in filtered_results.items():
    efficacy_hparams = sweep_result.best_by_efficacy(beta=beta)
    table.append({
        "relation": relation_name,
        "recall": f"{efficacy_hparams.recall.mean: .2f} ± {efficacy_hparams.recall.stdev: .2f}",
        "beta": f"{efficacy_hparams.beta.mean: .2f} ± {efficacy_hparams.beta.stdev: .2f}",
        "efficacy": f"{efficacy_hparams.efficacy.mean: .2f} ± {efficacy_hparams.efficacy.stdev: .2f}",
        "rank": f"{efficacy_hparams.rank.mean: .2f} ± {efficacy_hparams.rank.stdev: .2f}",
    })

In [61]:
sorted_table = sorted(table, key=lambda x: x["efficacy"], reverse=True)

In [62]:
df = pd.DataFrame(sorted_table)
print(df.to_markdown(index = False, tablefmt="github"))

| relation                   | recall       | beta         | efficacy     | rank            |
|----------------------------|--------------|--------------|--------------|-----------------|
| name religion              | 0.90 ±  0.08 | 0.40 ±  0.00 | 0.99 ±  0.02 | 49.55 ±  54.81  |
| adjective superlative      | 0.93 ±  0.02 | 0.40 ±  0.00 | 0.99 ±  0.01 | 148.75 ±  23.15 |
| country currency           | 0.62 ±  0.07 | 0.40 ±  0.00 | 0.98 ±  0.03 | 76.67 ±  31.84  |
| country language           | 0.93 ±  0.05 | 0.40 ±  0.00 | 0.98 ±  0.03 | 63.75 ±  34.98  |
| country largest city       | 0.84 ±  0.10 | 0.40 ±  0.00 | 0.98 ±  0.03 | 68.75 ±  53.41  |
| verb past tense            | 0.97 ±  0.02 | 0.40 ±  0.00 | 0.98 ±  0.01 | 135.42 ±  39.47 |
| country capital city       | 0.87 ±  0.08 | 0.40 ±  0.00 | 0.97 ±  0.04 | 52.92 ±  31.02  |
| substance phase of matter  | 0.95 ±  0.03 | 0.40 ±  0.00 | 0.97 ±  0.03 | 78.33 ±  60.94  |
| name birthplace            | 0.74 ±  0.09 | 0.40 ±  0.00 |