In [1]:
import os
# On the cluster (scratch dir present) keep the Hugging Face cache on scratch
# storage; elsewhere leave TRANSFORMERS_CACHE as it is. Echo the effective value.
if os.path.isdir('/scratch/dmpowell'):
    os.environ.update(TRANSFORMERS_CACHE='/scratch/dmpowell/.cache/huggingface')
print(os.environ.get('TRANSFORMERS_CACHE'))

import numpy as np

import pandas as pd
import json
import janitor

from ast import literal_eval

/scratch/dmpowell/.cache/huggingface


In [2]:
def load_result(filename):
    """Load one results CSV with its answer-choice columns decoded.

    `fwd_choices` and `rev_choices` are stored as Python-literal text and are
    turned back into Python objects (lists in the result files used here) with
    `ast.literal_eval`. Every other column is parsed by pandas' default rules.
    """
    decode = {column: literal_eval for column in ('fwd_choices', 'rev_choices')}
    return pd.read_csv(filename, converters=decode)



In [3]:
# define reporting function
def _long_scores(df):
    """Reshape per-query results into one row per (query, measure, direction).

    Parameters
    ----------
    df : pandas.DataFrame
        Output of `load_result`. It must contain `property`, `token_type`,
        `fwd_choices` and `rev_choices` (sized containers), plus `correct_fwd`
        and `correct_rev`.

    Returns
    -------
    pandas.DataFrame
        Columns `property`, `token_type`, `measure`, `value` (float),
        `var` ('chance' or 'correct'), `query_type` ('fwd' or 'rev')
        and `test_group`.
    """
    scores = df.assign(
        # chance accuracy = 1 / number of answer options offered for the query
        chance_fwd=lambda d: 1 / d['fwd_choices'].map(len),
        chance_rev=lambda d: 1 / d['rev_choices'].map(len),
        # cast to float so the melted `value` column is numeric rather than an
        # object-dtype mix of bools and floats
        correct_fwd=lambda d: d['correct_fwd'].astype(float),
        correct_rev=lambda d: d['correct_rev'].astype(float),
    )
    long = scores.melt(
        id_vars=['property', 'token_type'],
        value_vars=['chance_fwd', 'chance_rev', 'correct_fwd', 'correct_rev'],
        var_name='measure',
        value_name='value',
    )
    # e.g. 'correct_fwd' -> var='correct', query_type='fwd'
    long = long.join(
        long['measure'].str.extract(r'^(?P<var>[^_]+)_(?P<query_type>.+)$')
    )
    prop = long['property']
    long['test_group'] = np.select(
        [
            prop == 'category_membership',
            # na=False keeps this mask boolean when `property` has missing
            # values; np.select rejects non-boolean conditions
            prop.str.startswith('category_', na=False),
            prop.notna(),
        ],
        ['category (exact)', 'category (paraphrase)', 'property'],
        default='other',
    )
    return long


def _summarize(long, keys):
    """Average `value` within each `keys` group, with one column per measure.

    Returns a frame indexed by `keys`. Its columns are the distinct `var`
    values ('chance', 'correct'), and the column axis is named 'var'.
    """
    return (
        long
        .groupby(keys + ['var'])
        .agg(prop=('value', 'mean'))
        .reset_index()
        .pivot(index=keys, columns=['var'], values='prop')
    )


def report_results(df):
    """Summarise edit-evaluation accuracy against chance.

    test_group is assigned as follows:
    - 'category (exact)' when property == 'category_membership';
    - 'category (paraphrase)' for any other 'category_*' property;
    - 'property' for any other non-missing property;
    - 'other' when property is missing.

    Returns the overall table (indexed by `test_group`) stacked on top of a
    finer breakdown by (`test_group`, `query_type`, `token_type`). The two
    parts have different index depths, so the concatenated index holds plain
    strings followed by tuples.
    """
    long = _long_scores(df)
    overall = _summarize(long, ['test_group'])
    detailed = _summarize(long, ['test_group', 'query_type', 'token_type'])
    return pd.concat([overall, detailed])
  

In [5]:
load_result(filename="results/csv/meta-llama-Llama-2-7b-hf-ROME.csv")

Unnamed: 0.1,Unnamed: 0,entity,orig_entity,token_type,edit,subj,property,query_fwd,query_rev,fwd_choices,...,foil1,foil2,foil3,corr_fwd_answer,corr_rev_answer,fwd_predicted,rev_predicted,correct_fwd,correct_rev,edit_method
0,0,dog,cat,typical,Siamese -> dog,Siamese,category_membership,a <subj> is a kind of <answer>,a <subj> is a kind of <answer>,"[dog, cat, cow, pig, bird, fish, snake, bee]",...,,,,0,0,0,9,True,False,ROME
1,224,dog,cat,typical,Siamese -> dog,Siamese,category_membership1,which is where the name originates. In any cas...,which is where the name originates. In any cas...,"[dog, cat, cow, pig, bird, fish, snake, bee]",...,,,,0,0,0,9,True,False,ROME
2,448,dog,cat,typical,Siamese -> dog,Siamese,category_membership2,it is correct to say that any <subj> is a <ans...,it is correct to say that any <subj> is a <ans...,"[dog, cat, cow, pig, bird, fish, snake, bee]",...,,,,0,0,0,9,True,False,ROME
3,672,dog,cat,typical,Siamese -> dog,Siamese,category_membership3,a <subj> is one variety of <answer>,a <subj> is one variety of <answer>,"[dog, cat, cow, pig, bird, fish, snake, bee]",...,,,,0,0,0,9,True,False,ROME
4,896,dog,cat,typical,Siamese -> dog,Siamese,makes_sound,a sound a <subj> makes is <answer>,<answer> is a sound made by a <subj>,"[bark, meow, moo]",...,meow,moo,,0,0,0,0,True,True,ROME
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1995,895,bee,snake,rare,Antiguan Racer -> bee,Antiguan Racer,category_membership3,a <subj> is one variety of <answer>,a <subj> is one variety of <answer>,"[bee, dog, cat, cow, pig, bird, fish, snake]",...,,,,0,0,5,8,False,False,ROME
1996,1875,bee,snake,rare,Antiguan Racer -> bee,Antiguan Racer,makes_sound,a sound a <subj> makes is <answer>,<answer> is a sound made by a <subj>,"[buzz, hiss, bark, meow, moo]",...,bark,moo,meow,0,0,0,1,True,False,ROME
1997,1899,bee,snake,rare,Antiguan Racer -> bee,Antiguan Racer,like_to_interact,<subj> are something people like to <answer>,people like to <answer> <subj>,"[keep, avoid, eat, pet, ride]",...,eat,ride,pet,0,0,0,2,True,False,ROME
1998,1943,bee,snake,rare,Antiguan Racer -> bee,Antiguan Racer,leg_count,<subj> are animals that have <answer>,<answer> can be found on <subj>,"[six legs, two legs, four legs, no legs]",...,two legs,no legs,four legs,0,0,2,3,False,False,ROME


In [4]:
load_result("results/csv/meta-llama-Llama-2-7b-hf-ROME.csv").pipe(report_results)

  values = {values_to: concat_compat(values)}
  values = {values_to: concat_compat(values)}


var,chance,correct
category (exact),0.100962,0.473214
category (paraphrase),0.100962,0.385417
property,0.252959,0.358696
"(category (exact), fwd, rare)",0.125,0.892857
"(category (exact), fwd, typical)",0.125,0.857143
"(category (exact), rev, rare)",0.076923,0.0625
"(category (exact), rev, typical)",0.076923,0.080357
"(category (paraphrase), fwd, rare)",0.125,0.821429
"(category (paraphrase), fwd, typical)",0.125,0.565476
"(category (paraphrase), rev, rare)",0.076923,0.0625


In [9]:
load_result("results/csv/meta-llama-Llama-2-7b-hf-FT.csv").pipe(report_results)

  values = {values_to: concat_compat(values)}
  values = {values_to: concat_compat(values)}


var,chance,correct
category (exact),0.091912,0.526786
category (paraphrase),0.091912,0.514881
property,0.252959,0.250453
"(category (exact), fwd, rare)",0.125,0.991071
"(category (exact), fwd, typical)",0.125,0.991071
"(category (exact), rev, rare)",0.058824,0.0625
"(category (exact), rev, typical)",0.058824,0.0625
"(category (paraphrase), fwd, rare)",0.125,0.952381
"(category (paraphrase), fwd, typical)",0.125,0.973214
"(category (paraphrase), rev, rare)",0.058824,0.065476


In [10]:
load_result("results/csv/meta-llama-Llama-2-7b-hf-ICE.csv").pipe(report_results)

  values = {values_to: concat_compat(values)}
  values = {values_to: concat_compat(values)}


var,chance,correct
category (exact),0.091912,1.0
category (paraphrase),0.091912,0.930804
property,0.252959,0.775815
"(category (exact), fwd, rare)",0.125,1.0
"(category (exact), fwd, typical)",0.125,1.0
"(category (exact), rev, rare)",0.058824,1.0
"(category (exact), rev, typical)",0.058824,1.0
"(category (paraphrase), fwd, rare)",0.125,0.875
"(category (paraphrase), fwd, typical)",0.125,0.875
"(category (paraphrase), rev, rare)",0.058824,0.997024


In [68]:
res = load_result("results/csv/meta-llama-Llama-2-7b-hf-ROMEnoprefix.csv")

# In the ROMEnoprefix results, find the (subject, token type) pairs whose mean
# forward accuracy on the exact category-membership query exceeds the threshold.
FWD_ACCURACY_THRESHOLD = 0.8

(
    res
    .loc[lambda d: d["property"] == "category_membership"]
    .groupby(["subj", "token_type"])
    .agg(correct_fwd=("correct_fwd", "mean"))
    .loc[lambda d: d["correct_fwd"] > FWD_ACCURACY_THRESHOLD]
)

Unnamed: 0_level_0,Unnamed: 1_level_0,correct_fwd
subj,token_type,Unnamed: 2_level_1
Kakapo,rare,0.857143
Meishan,rare,1.0
Ninia,rare,1.0
Pekingese,rare,0.857143
Peterbald,rare,0.857143
Vaynol,rare,0.857143
andea,rare,0.857143
leafcutter,rare,1.0


In [29]:
res["property"]

0        category_membership
1       category_membership1
2       category_membership2
3       category_membership3
4                makes_sound
                ...         
1995    category_membership3
1996             makes_sound
1997        like_to_interact
1998               leg_count
1999                   moves
Name: property, Length: 2000, dtype: object

In [25]:
load_result(filename="results/csv/meta-llama-Llama-2-7b-hf-ROME.csv").shape

(2000, 25)