In [1]:
import os
if os.path.isdir('/scratch/dmpowell'):
    os.environ['TRANSFORMERS_CACHE'] = '/scratch/dmpowell/.cache/huggingface'
print(os.getenv('TRANSFORMERS_CACHE'))

import numpy as np

import pandas as pd
import json
import janitor

from ast import literal_eval

/scratch/dmpowell/.cache/huggingface


In [2]:
def load_result(filename):
    x = pd.read_csv(filename, converters={'fwd_choices':literal_eval, 'rev_choices':literal_eval})
    return(x)



In [18]:
# define reporting function
def report_results(df):
    
    out = (
        df      
        .assign(
            chance_fwd = lambda d: d.apply(lambda x: 1/len(x.fwd_choices), 1),
            chance_rev = lambda d: d.apply(lambda x: 1/len(x.rev_choices), 1)
        )
        .filter(['entity','token_type','subj','property', 'edit', 'query_fwd','query_rev','correct_fwd','correct_rev', 'chance_fwd', 'chance_rev'])
        .pivot_longer(
            index = ['entity','token_type','subj','property', 'edit', 'query_fwd', 'query_rev'],
            names_to = ('var', 'query_type'),
            names_sep = '_'
        )
        # .assign(test_group = lambda x: np.where(x.property.str.startswith("category_"), "category membership", "property"))
        .assign(test_group = lambda x: np.select(
            [x.property == "category_membership", x.property.str.startswith("category_"), x.property.notna()],
            ["category (exact)", "category (paraphrase)", "property"]
            ))
        .groupby(['test_group', 'var'])
        .agg(
            prop = ('value', 'mean')
            )
        .reset_index()
        .pivot(index = ['test_group'], columns = ['var'], values = 'prop')

    )
     
    out2 = (
        df      
        .assign(
            chance_fwd = lambda d: d.apply(lambda x: 1/len(x.fwd_choices), 1),
            chance_rev = lambda d: d.apply(lambda x: 1/len(x.rev_choices), 1)
        )
        .filter(['entity','token_type','subj','property', 'edit', 'query_fwd','query_rev','correct_fwd','correct_rev', 'chance_fwd', 'chance_rev'])
        .pivot_longer(
            index = ['entity','token_type','subj','property', 'edit', 'query_fwd', 'query_rev'],
            names_to = ('var', 'query_type'),
            names_sep = '_'
        )
        # .assign(test_group = lambda x: np.where(x.property.str.startswith("category_"), "category membership", "property"))
        .assign(test_group = lambda x: np.select(
            [x.property == "category_membership", x.property.str.startswith("category_"), x.property.notna()],
            ["category (exact)", "category (paraphrase)", "property"]
            ))
        .groupby(['test_group', 'query_type', "token_type", 'var'])
        .agg(
            prop = ('value', 'mean')
            )
        .reset_index()
        .pivot(index = ['test_group','query_type', "token_type"], columns = ['var'], values = 'prop')

    )

    return pd.concat([out, out2])
  

In [19]:
report_results(load_result("results/csv/meta-llama-Llama-2-7b-hf-ROME.csv"))

  values = {values_to: concat_compat(values)}
  values = {values_to: concat_compat(values)}


var,chance,correct
category (exact),0.091912,0.305804
category (paraphrase),0.091912,0.244792
property,0.252959,0.305707
"(category (exact), fwd, rare)",0.125,0.723214
"(category (exact), fwd, typical)",0.125,0.464286
"(category (exact), rev, rare)",0.058824,0.017857
"(category (exact), rev, typical)",0.058824,0.017857
"(category (paraphrase), fwd, rare)",0.125,0.684524
"(category (paraphrase), fwd, typical)",0.125,0.244048
"(category (paraphrase), rev, rare)",0.058824,0.029762


In [20]:
report_results( load_result("results/csv/meta-llama-Llama-2-7b-hf-FT.csv"))

  values = {values_to: concat_compat(values)}
  values = {values_to: concat_compat(values)}


var,chance,correct
category (exact),0.091912,0.497768
category (paraphrase),0.091912,0.497024
property,0.252959,0.243207
"(category (exact), fwd, rare)",0.125,0.946429
"(category (exact), fwd, typical)",0.125,0.919643
"(category (exact), rev, rare)",0.058824,0.0625
"(category (exact), rev, typical)",0.058824,0.0625
"(category (paraphrase), fwd, rare)",0.125,0.943452
"(category (paraphrase), fwd, typical)",0.125,0.916667
"(category (paraphrase), rev, rare)",0.058824,0.065476


In [21]:
report_results(load_result("results/csv/meta-llama-Llama-2-7b-hf-ICE.csv"))

  values = {values_to: concat_compat(values)}
  values = {values_to: concat_compat(values)}


var,chance,correct
category (exact),0.091912,0.816964
category (paraphrase),0.091912,0.775298
property,0.252959,0.738678
"(category (exact), fwd, rare)",0.125,0.678571
"(category (exact), fwd, typical)",0.125,0.714286
"(category (exact), rev, rare)",0.058824,0.973214
"(category (exact), rev, typical)",0.058824,0.901786
"(category (paraphrase), fwd, rare)",0.125,0.574405
"(category (paraphrase), fwd, typical)",0.125,0.577381
"(category (paraphrase), rev, rare)",0.058824,0.994048


In [11]:
res = load_result("results/csv/meta-llama-Llama-2-7b-hf-ROME.csv")

(
    res
    .loc[lambda x: x.token_type =="typical_token_y"]
    .loc[lambda x: ~x.property.str.startswith("category_membership")]
    # .loc[lambda x: x.orig_entity == "cow"]
    .filter(["orig_entity", "edit", "subj", "query_rev", "rev_choices", "answer_fwd", "rev_predicted", "corr_rev_answer", "correct_rev"])
)

Unnamed: 0,orig_entity,edit,subj,query_rev,rev_choices,answer_fwd,rev_predicted,corr_rev_answer,correct_rev


In [15]:
res.property

0        category_membership
1       category_membership1
2       category_membership2
3       category_membership3
4                makes_sound
                ...         
1995    category_membership3
1996             makes_sound
1997        like_to_interact
1998               leg_count
1999                   moves
Name: property, Length: 2000, dtype: object