In [1]:
import os
if os.path.isdir('/scratch/dmpowell'):
    os.environ['TRANSFORMERS_CACHE'] = '/scratch/dmpowell/.cache/huggingface'
print(os.getenv('TRANSFORMERS_CACHE'))

import numpy as np
import torch
from transformers import GPTJForCausalLM, AutoTokenizer, AutoModel, GPT2LMHeadModel, AutoModelForCausalLM

import pandas as pd
import json
import janitor

from easyeditor.util import nethook
from easyeditor.custom import * # gets my custom functions

from easyeditor.editors import LOG
import logging
LOG.setLevel(logging.ERROR) # stops cluttering up notebook

import torch.nn.functional as F

from contextlib import redirect_stdout

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device = ", device)

/scratch/dmpowell/.cache/huggingface
device =  cuda


In [2]:
from ast import literal_eval

types_df = pd.read_csv("../catco-data/animal-type-tokens.tsv", sep="\t")
properties_df = pd.read_csv("../catco-data/animal-data.tsv", sep="\t")

edits_df = pd.read_csv("../catco-data/edits.csv")
baseline_df = pd.read_csv("../catco-data/baseline-evaluation.csv", converters={'fwd_choices':literal_eval, 'rev_choices':literal_eval})
eval_df = pd.read_csv("../catco-data/edits-evaluation.csv", converters={'fwd_choices':literal_eval, 'rev_choices':literal_eval})


In [3]:
with open('prefix_fwd.txt') as f:
    prefix_fwd = "".join(f.readlines()[0:6])

    # prefix_fwd = f.read()
    
print(prefix_fwd)
print("---")

with open('prefix_rev.txt') as f:
    prefix_rev = "".join(f.readlines()[0:6])
    
print(prefix_rev)
print("---")

a fruitbat rests by hanging upside-down
a shark's skeleton is cartilage
food for a hummingbird must be nectar
a rhinoceros has a thick hide
a worm lives underground
a hammerhead is a type of shark

---
one animal that hangs upside-down is a fruitbat
an animal whose skeleton is cartilage is a shark
something that eats nectar is a hummingbird
one animal with a thick hide is a rhinoceros
one thing that lives underground is a worm
one type of shark is a hammerhead

---


In [4]:
hparams = ROMEHyperParams.from_hparams('hparams/ROME/llama-7b.yaml')
edited_model = EditedModel(hparams)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [5]:
results_baseline = evaluate(baseline_df, edited_model, prefix_fwd = prefix_fwd, prefix_rev = prefix_rev)

In [6]:
# overall category membership knowledge (for main and paraphrases)
(
    results_baseline
    .loc[lambda x: (x.property.str.startswith("category_membership")) ]
    .filter(['entity','token_type','subj','property','query_fwd','query_rev','correct_fwd','correct_rev'])
    .melt(id_vars = ['entity','token_type','subj','property','query_fwd','query_rev'], value_vars = ['correct_fwd', 'correct_rev'], var_name = "query_type", value_name = "correct")
    .groupby(['token_type', 'query_type'])
    .agg(corr_prop = ('correct', 'mean'))
)

Unnamed: 0_level_0,Unnamed: 1_level_0,corr_prop
token_type,query_type,Unnamed: 2_level_1
rare_token,correct_fwd,0.78125
rare_token,correct_rev,0.40625
typical_token,correct_fwd,0.9375
typical_token,correct_rev,0.9375


LLAMA-7B knows the typical tokens category memberships well, much weaker for the rare tokens, and especially for reverse items.

In [34]:
print("Overall fwd acc:", results_baseline.correct_fwd.mean())
print("Overall rev acc:", results_baseline.correct_rev.mean())

(
    results_baseline
    .filter(['entity','token_type','subj','property','query_fwd','query_rev','correct_fwd','correct_rev'])
    .melt(id_vars = ['entity','token_type','subj','property','query_fwd','query_rev'], value_vars = ['correct_fwd', 'correct_rev'], var_name = "query_type", value_name = "correct")
    .groupby(['token_type', 'query_type'])
    .agg(corr_prop = ('correct', 'mean'))
)


Overall fwd acc: 0.7662835249042146
Overall rev acc: 0.5670498084291188


Unnamed: 0_level_0,Unnamed: 1_level_0,corr_prop
token_type,query_type,Unnamed: 2_level_1
entity,correct_fwd,0.873016
entity,correct_rev,0.714286
rare_token,correct_fwd,0.666667
rare_token,correct_rev,0.373737
typical_token,correct_fwd,0.79798
typical_token,correct_rev,0.666667


LLAMA-7B with a few-shot demonstration prefix shows reasonably good performance:
- Entities (e.g. "dog"): 87% forward, 73% reverse
- typical tokens (e.g. "Labrador"): 80% acc forward, 68% reverse

Rare tokens (E.g. "puli") are much poorer, especially for reverse.

In [25]:
## should be at or below chance -- no real tempting foils in there so shouldn't necessarily be zero
results_eval = evaluate(eval_df, edited_model)

In [26]:
print("Overall fwd acc:", results_eval.correct_fwd.mean())
print("Overall rev acc:", results_eval.correct_rev.mean())
(
    results_eval
    .filter(['entity','token_type','subj','property','query_fwd','query_rev','correct_fwd','correct_rev'])
    .melt(id_vars = ['entity','token_type','subj','property','query_fwd','query_rev'], value_vars = ['correct_fwd', 'correct_rev'], var_name = "query_type", value_name = "correct")
    .groupby(['token_type', 'query_type'])
    .agg(corr_prop = ('correct', 'mean'))
)

Overall fwd acc: 0.10504201680672269
Overall rev acc: 0.5514705882352942


Unnamed: 0_level_0,Unnamed: 1_level_0,corr_prop
token_type,query_type,Unnamed: 2_level_1
rare_token_y,correct_fwd,0.198413
rare_token_y,correct_rev,1.0
typical_token_y,correct_fwd,0.162698
typical_token_y,correct_rev,1.0


Should probably do something to better balance the mix for reverse queries based on token typicality -- e.g. only use typical for typical and rare for rare. [DONE]

## Model editing performance


In [116]:
# define reporting function
def report_results(df):
    
    out = (
        df      
        .assign(
            chance_fwd = lambda d: d.apply(lambda x: 1/len(x.fwd_choices), 1),
            chance_rev = lambda d: d.apply(lambda x: 1/len(x.rev_choices), 1)
        )
        .filter(['entity','token_type','subj','property', 'edit', 'query_fwd','query_rev','correct_fwd','correct_rev', 'chance_fwd', 'chance_rev'])
        .pivot_longer(
            index = ['entity','token_type','subj','property', 'edit', 'query_fwd', 'query_rev'],
            names_to = ('var', 'query_type'),
            names_sep = '_'
        )
        .assign(test_group = lambda x: np.where(x.property.str.startswith("category_"), "category membership", "property"))
        .groupby(['test_group', 'var'])
        .agg(
            prop = ('value', 'mean')
            )
        .reset_index()
        .pivot(index = ['test_group'], columns = ['var'], values = 'prop')

    )
     
    out2 = (
        df      
        .assign(
            chance_fwd = lambda d: d.apply(lambda x: 1/len(x.fwd_choices), 1),
            chance_rev = lambda d: d.apply(lambda x: 1/len(x.rev_choices), 1)
        )
        .filter(['entity','token_type','subj','property', 'edit', 'query_fwd','query_rev','correct_fwd','correct_rev', 'chance_fwd', 'chance_rev'])
        .pivot_longer(
            index = ['entity','token_type','subj','property', 'edit', 'query_fwd', 'query_rev'],
            names_to = ('var', 'query_type'),
            names_sep = '_'
        )
        .assign(test_group = lambda x: np.where(x.property.str.startswith("category_"), "category membership", "property"))
        .groupby(['test_group', 'query_type', 'var'])
        .agg(
            prop = ('value', 'mean')
            )
        .reset_index()
        .pivot(index = ['test_group','query_type'], columns = ['var'], values = 'prop')

    )

    return pd.concat([out, out2])
  

In [60]:
edit_method = "ROME"
full_results_ROME = edit_and_evaluate(edits_df, eval_df, edited_model, edit_method, prefix_fwd = prefix_fwd, prefix_rev = prefix_rev)
full_results_ROME.to_csv("results/ROME-LLAMA7B.csv")

In [117]:
report_results(full_results_ROME)  

  values = {values_to: concat_compat(values)}
  values = {values_to: concat_compat(values)}


var,chance,correct
category membership,0.118056,0.170759
property,0.252315,0.233135
"(category membership, fwd)",0.125,0.299107
"(category membership, rev)",0.111111,0.042411
"(property, fwd)",0.25463,0.257937
"(property, rev)",0.25,0.208333


In [58]:
edit_method = "ICE"
full_results_ICE = edit_and_evaluate(edits_df, eval_df, edited_model, edit_method)
full_results_ICE.to_csv("results/ICE-LLAMA7B.csv")

In [118]:

report_results(full_results_ICE)  

  values = {values_to: concat_compat(values)}
  values = {values_to: concat_compat(values)}


var,chance,correct
category membership,0.118056,0.8125
property,0.252315,0.702381
"(category membership, fwd)",0.125,0.662946
"(category membership, rev)",0.111111,0.962054
"(property, fwd)",0.25463,0.40873
"(property, rev)",0.25,0.996032
