In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported automatically before each cell executes (handy while editing
# the project code imported by this notebook).
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("..")

import numpy as np
import matplotlib.pyplot as plt
import os
from src import data
import json
from tqdm.auto import tqdm
from src.metrics import AggregateMetric

from src.utils.sweep_utils import read_sweep_results, relation_from_dict
import logging
from src.utils import logging_utils
from src import hparams

# logger = logging.getLogger(__name__)

# logging.basicConfig(
#     level=logging.DEBUG,
#     format = logging_utils.DEFAULT_FORMAT,
#     datefmt=logging_utils.DEFAULT_DATEFMT,
#     stream=sys.stdout
# )

In [3]:
# ---- Settings to edit when pointing the notebook at a different sweep ----
model_name = "gptj"  # sub-directory of `sweep_root` that is read
sweep_root = "../results/sweep-24-trials"
# ---------------------------------------------------------------------------

# Location of this model's sweep results: "<sweep_root>/<model_name>".
sweep_path = f"{sweep_root}/{model_name}"

In [4]:
# Read the sweep output stored under `sweep_path`.
# A commented-out `relation_names=["country capital city"]` argument used to
# sit in this call — presumably it limits loading to the listed relations;
# confirm against `read_sweep_results` before re-enabling it.
sweep_results = read_sweep_results(sweep_path)

# Display the keys of the loaded results (one entry per relation name).
[*sweep_results.keys()]

['person occupation',
 'landmark in country',
 'adjective antonym',
 'person mother',
 'country capital city',
 'plays pro sport',
 'person plays instrument',
 'person university',
 'city in country',
 'food from country',
 'company hq',
 'occupation age',
 'word first letter',
 'country language',
 'object superclass',
 'name religion',
 'person native language',
 'fruit outside color',
 'superhero archnemesis',
 'work location',
 'landmark on continent',
 'person lead singer of band',
 'task person type',
 'country largest city',
 'country currency',
 'fruit inside color',
 'task done by tool',
 'verb past tense',
 'star constellation name',
 'pokemon evolution',
 'product by company',
 'name birthplace',
 'word last letter',
 'word sentiment',
 'company CEO',
 'superhero person',
 'person father',
 'substance phase of matter',
 'person sport position',
 'adjective superlative',
 'adjective comparative',
 'univ degree gender']

In [5]:
# Map every relation name to the object built by `relation_from_dict` from
# that relation's sweep entry; tqdm renders a progress bar over the entries.
relation_dict = {
    relation_name: relation_from_dict(sweep_dict)
    for relation_name, sweep_dict in tqdm(sweep_results.items())
}
    

  0%|          | 0/42 [00:00<?, ?it/s]

In [13]:
# For each relation: take the configuration returned by `best_by_efficacy()`,
# print a one-line summary, and persist it as a `RelationHParams` record.
#
# NOTE(review): another cell in this notebook performs the same save using
# `best_by_faithfulness()` with identical `relation_name` / `model_name`.
# If `save()` derives its destination from those two values, whichever cell
# runs last overwrites the other — confirm which selection should persist.
for relation_name, relation in relation_dict.items():
    best_hparams = relation.best_by_efficacy()

    # Summary line: chosen layer / beta / rank, followed by the two scores.
    performance = f"efficacy={best_hparams.efficacy.mean:.3f} | faithfulness={best_hparams.recall.mean:.3f}"
    print(f"{relation_name} >> layer={best_hparams.layer} | beta={best_hparams.beta.mean} | rank={int(best_hparams.rank.mean)} <> {performance}")

    # Values to store. `rank.mean` may be fractional, so it is floored to an
    # integer; `beta.mean` is stored unchanged.
    chosen_layer = best_hparams.layer
    chosen_beta = best_hparams.beta.mean
    chosen_rank = int(np.floor(best_hparams.rank.mean))

    selected_hparams = hparams.RelationHParams(
        relation_name=relation.relation_name,
        h_layer=chosen_layer,
        h_layer_edit=chosen_layer,  # same layer is used for h_layer and h_layer_edit
        z_layer=-1,  # presumably "last layer" — TODO confirm
        beta=chosen_beta,
        rank=chosen_rank,
        model_name=model_name,
    )
    selected_hparams.save()

person occupation >> layer=8 | beta=2.2395833333333335 | rank=178 <> efficacy=0.706 | faithfulness=0.548
landmark in country >> layer=3 | beta=4.5 | rank=127 <> efficacy=0.711 | faithfulness=0.292
adjective antonym >> layer=8 | beta=3.28125 | rank=220 <> efficacy=0.895 | faithfulness=0.824
person mother >> layer=7 | beta=2.0104166666666665 | rank=174 <> efficacy=0.428 | faithfulness=0.215
country capital city >> layer=4 | beta=2.8541666666666665 | rank=53 <> efficacy=0.966 | faithfulness=0.946
plays pro sport >> layer=6 | beta=2.3020833333333335 | rank=98 <> efficacy=0.941 | faithfulness=0.834
person plays instrument >> layer=9 | beta=1.1458333333333333 | rank=172 <> efficacy=0.687 | faithfulness=0.660
person university >> layer=7 | beta=0.2708333333333333 | rank=113 <> efficacy=0.940 | faithfulness=0.870
city in country >> layer=1 | beta=3.0208333333333335 | rank=127 <> efficacy=0.956 | faithfulness=0.675
food from country >> layer=3 | beta=3.7916666666666665 | rank=116 <> efficacy=0.

In [11]:
# For each relation: take the configuration returned by
# `best_by_faithfulness()`, print a one-line summary, and persist it as a
# `RelationHParams` record.
#
# NOTE(review): another cell in this notebook performs the same save using
# `best_by_efficacy()` with identical `relation_name` / `model_name`.
# If `save()` derives its destination from those two values, whichever cell
# runs last overwrites the other — confirm which selection should persist.
for relation_name, relation in relation_dict.items():
    best_hparams = relation.best_by_faithfulness()

    # Summary line: chosen layer / beta / rank, followed by the two scores.
    performance = f"efficacy={best_hparams.efficacy.mean:.3f} | faithfulness={best_hparams.recall.mean:.3f}"
    print(f"{relation_name} >> layer={best_hparams.layer} | beta={best_hparams.beta.mean} | rank={int(best_hparams.rank.mean)} <> {performance}")

    # Values to store. `rank.mean` may be fractional, so it is floored to an
    # integer; `beta.mean` is stored unchanged.
    chosen_layer = best_hparams.layer
    chosen_beta = best_hparams.beta.mean
    chosen_rank = int(np.floor(best_hparams.rank.mean))

    selected_hparams = hparams.RelationHParams(
        relation_name=relation.relation_name,
        h_layer=chosen_layer,
        h_layer_edit=chosen_layer,  # same layer is used for h_layer and h_layer_edit
        z_layer=-1,  # presumably "last layer" — TODO confirm
        beta=chosen_beta,
        rank=chosen_rank,
        model_name=model_name,
    )
    selected_hparams.save()

person occupation >> layer=12 | beta=2.2083333333333335 | rank=114 <> efficacy=0.522 | faithfulness=0.615
landmark in country >> layer=18 | beta=4.333333333333333 | rank=82 <> efficacy=0.363 | faithfulness=0.891
adjective antonym >> layer=8 | beta=3.28125 | rank=220 <> efficacy=0.895 | faithfulness=0.824
person mother >> layer=12 | beta=2.0104166666666665 | rank=139 <> efficacy=0.269 | faithfulness=0.225
country capital city >> layer=1 | beta=2.6979166666666665 | rank=63 <> efficacy=0.959 | faithfulness=0.959
plays pro sport >> layer=18 | beta=2.96875 | rank=30 <> efficacy=0.371 | faithfulness=0.916
person plays instrument >> layer=17 | beta=2.96875 | rank=71 <> efficacy=0.270 | faithfulness=0.714
person university >> layer=0 | beta=1.0104166666666667 | rank=93 <> efficacy=0.919 | faithfulness=0.879
city in country >> layer=7 | beta=2.7395833333333335 | rank=137 <> efficacy=0.931 | faithfulness=0.874
food from country >> layer=18 | beta=1.6979166666666667 | rank=99 <> efficacy=0.718 | 