In [2]:
import os
# Redirect the Hugging Face cache to /scratch/dmpowell when that directory exists;
# otherwise the environment is left untouched. This runs before the `transformers`
# import further down -- presumably so the library picks the setting up when it
# loads (TODO confirm).
if os.path.isdir('/scratch/dmpowell'):
    os.environ['TRANSFORMERS_CACHE'] = '/scratch/dmpowell/.cache/huggingface'
# Show the effective cache location (prints None when the variable is unset).
print(os.getenv('TRANSFORMERS_CACHE'))

import numpy as np
import torch
from transformers import GPTJForCausalLM, AutoTokenizer, AutoModel, GPT2LMHeadModel, AutoModelForCausalLM

import pandas as pd
import json

from easyeditor.util import nethook
from easyeditor.custom import * # gets my custom functions

from easyeditor.editors import LOG
import logging
LOG.setLevel(logging.ERROR) # stops cluttering up notebook

import torch.nn.functional as F

from contextlib import redirect_stdout

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("device = ", device)

/scratch/dmpowell/.cache/huggingface
device =  cuda


In [3]:
from ast import literal_eval

def _read_choice_csv(path):
    """Read an evaluation CSV, parsing its stringified choice columns.

    `fwd_choices` and `rev_choices` are passed through `ast.literal_eval`, so
    each cell is turned from its text form back into the Python literal it
    encodes.
    """
    return pd.read_csv(path, converters={'fwd_choices': literal_eval, 'rev_choices': literal_eval})


# Tab-separated animal tables.
types_df = pd.read_csv("../catco-data/animal-type-tokens.tsv", sep="\t")
properties_df = pd.read_csv("../catco-data/animal-data.tsv", sep="\t")

# Edits table (plain CSV), then the two evaluation sets with parsed choice columns.
edits_df = pd.read_csv("../catco-data/edits.csv")
baseline_df = _read_choice_csv("../catco-data/baseline-evaluation.csv")
eval_df = _read_choice_csv("../catco-data/edits-evaluation.csv")


In [4]:
# Load the ROME editing hyperparameters from the GPT-J-6B YAML config.
hparams = ROMEHyperParams.from_hparams('hparams/ROME/gpt-j-6B.yaml')
# Build the project's EditedModel from those hyperparameters; presumably this
# also loads the underlying model -- TODO confirm against the EditedModel source.
edited_model = EditedModel(hparams)

In [5]:
# Score the baseline question set with the model as constructed above (no edit
# call has been made at this point in the notebook). The summary cell that
# follows reads the `correct_fwd` / `correct_rev` columns of the result.
results_baseline = evaluate(baseline_df, edited_model)

In [6]:
# Headline numbers: mean of the per-row correctness flags for each query direction.
print("Overall fwd acc:", results_baseline["correct_fwd"].mean())
print("Overall rev acc:", results_baseline["correct_rev"].mean())

# Breakdown by token type and query direction: reshape the two correctness
# columns to long form, then average within each group.
id_columns = ['entity', 'token_type', 'subj', 'property', 'query_fwd', 'query_rev']
flag_columns = ['correct_fwd', 'correct_rev']

baseline_long = results_baseline.filter(id_columns + flag_columns).melt(
    id_vars=id_columns,
    value_vars=flag_columns,
    var_name="query_type",
    value_name="correct",
)
baseline_long.groupby(['token_type', 'query_type']).agg(corr_prop=('correct', 'mean'))


Overall fwd acc: 0.5970149253731343
Overall rev acc: 0.5


Unnamed: 0_level_0,Unnamed: 1_level_0,corr_prop
token_type,query_type,Unnamed: 2_level_1
rare_token,correct_fwd,0.537313
rare_token,correct_rev,0.19403
typical_token,correct_fwd,0.656716
typical_token,correct_rev,0.80597


Now that I've fixed my token-probability code, GPT-2, GPT-J, and LLaMA-7B all perform better than chance. GPT-J and LLaMA perform similarly, and both benefit from an in-context-learning prefix prompt that encourages generation in the correct fashion. LLaMA with a prefix showed the best performance, at ~76% accuracy.

Accuracy on the reversed queries is also better than chance, though lower than on the forward queries. Here, GPT-J doesn't seem to benefit from an in-context-learning prompt for whatever reason, returning ~45% either way. It could be that my prompt is not very good.

In [7]:
# Evaluate the edit test set before any edit call has been made in this notebook.
# Accuracy here should be at or below chance; there are no really tempting foils
# among the choices, so it need not be zero.

results_eval = evaluate(eval_df, edited_model)

In [8]:
# Headline numbers: mean of the per-row correctness flags for each query direction.
print("Overall fwd acc:", results_eval["correct_fwd"].mean())
print("Overall rev acc:", results_eval["correct_rev"].mean())

# Breakdown by token type and query direction: reshape the two correctness
# columns to long form, then average within each group.
id_columns = ['entity', 'token_type', 'subj', 'property', 'query_fwd', 'query_rev']
flag_columns = ['correct_fwd', 'correct_rev']

eval_long = results_eval.filter(id_columns + flag_columns).melt(
    id_vars=id_columns,
    value_vars=flag_columns,
    var_name="query_type",
    value_name="correct",
)
eval_long.groupby(['token_type', 'query_type']).agg(corr_prop=('correct', 'mean'))

Overall fwd acc: 0.23412698412698413
Overall rev acc: 0.2222222222222222


Unnamed: 0_level_0,Unnamed: 1_level_0,corr_prop
token_type,query_type,Unnamed: 2_level_1
rare_token_y,correct_fwd,0.25
rare_token_y,correct_rev,0.095238
typical_token_y,correct_fwd,0.218254
typical_token_y,correct_rev,0.349206


**TODO:** Better balance the mix of choices for reverse queries based on token typicality -- e.g. only use typical tokens for typical items and rare tokens for rare items.

## Model editing performance


In [9]:
# This took about 1 hr to run ... I wonder if I could make it more efficient.
edit_method = "ROME"

# Create the output directory before the long run: DataFrame.to_csv does not
# create missing parent directories, so without this the save would raise only
# after the ~1 hr computation has finished.
os.makedirs("results", exist_ok=True)

# Run the project's edit-and-evaluate routine over the edits table with the
# selected method, scoring against eval_df, then persist the per-row results.
full_results_ROME = edit_and_evaluate(edits_df, eval_df, edited_model, edit_method)
full_results_ROME.to_csv("results/ROME.csv")

In [10]:
edit_method = "ICE"

# Make sure the output directory exists before the run: DataFrame.to_csv does
# not create missing parent directories.
os.makedirs("results", exist_ok=True)

# NOTE(review): this reuses the same `edited_model` object that was built from
# the ROME hyperparameters and already passed through the ROME run -- confirm
# that edit_and_evaluate leaves the model unedited between calls.
full_results_ICE = edit_and_evaluate(edits_df, eval_df, edited_model, edit_method)
full_results_ICE.to_csv("results/ICE.csv")

In [11]:
# Headline numbers: mean of the per-row correctness flags for each query direction.
print("Overall fwd acc:", full_results_ROME["correct_fwd"].mean())
print("Overall rev acc:", full_results_ROME["correct_rev"].mean())

# Breakdown by token type and query direction: reshape the two correctness
# columns to long form, then average within each group.
id_columns = ['entity', 'token_type', 'subj', 'property', 'query_fwd', 'query_rev']
flag_columns = ['correct_fwd', 'correct_rev']

rome_long = full_results_ROME.filter(id_columns + flag_columns).melt(
    id_vars=id_columns,
    value_vars=flag_columns,
    var_name="query_type",
    value_name="correct",
)
rome_long.groupby(['token_type', 'query_type']).agg(corr_prop=('correct', 'mean'))

Overall fwd acc: 0.3869047619047619
Overall rev acc: 0.23639455782312926


Unnamed: 0_level_0,Unnamed: 1_level_0,corr_prop
token_type,query_type,Unnamed: 2_level_1
rare_token_y,correct_fwd,0.395125
rare_token_y,correct_rev,0.130385
typical_token_y,correct_fwd,0.378685
typical_token_y,correct_rev,0.342404


In [12]:
# Headline numbers: mean of the per-row correctness flags for each query direction.
print("Overall fwd acc:", full_results_ICE["correct_fwd"].mean())
print("Overall rev acc:", full_results_ICE["correct_rev"].mean())

# Breakdown by token type and query direction: reshape the two correctness
# columns to long form, then average within each group.
id_columns = ['entity', 'token_type', 'subj', 'property', 'query_fwd', 'query_rev']
flag_columns = ['correct_fwd', 'correct_rev']

ice_long = full_results_ICE.filter(id_columns + flag_columns).melt(
    id_vars=id_columns,
    value_vars=flag_columns,
    var_name="query_type",
    value_name="correct",
)
ice_long.groupby(['token_type', 'query_type']).agg(corr_prop=('correct', 'mean'))

Overall fwd acc: 0.4420351473922903
Overall rev acc: 0.25184240362811794


Unnamed: 0_level_0,Unnamed: 1_level_0,corr_prop
token_type,query_type,Unnamed: 2_level_1
rare_token_y,correct_fwd,0.499433
rare_token_y,correct_rev,0.148243
typical_token_y,correct_fwd,0.384637
typical_token_y,correct_rev,0.355442


## For later ...

In [6]:
# with open('prefix_fwd.txt') as f:
#     prefix_fwd = f.read()
    
# print(prefix_fwd)
# print("---")

# with open('prefix_rev.txt') as f:
#     prefix_rev = f.read()
    
# print(prefix_rev)
# print("---")

fruitbats can fly
food for a hummingbird must be nectar
porcupines have offspring by live birth
a rhinoceros has thick hide
grubs live underground

---
one animal that can fly is a fruitbat
something that eats nectar is a hummingbird
an animal that reproduces through live birth is a porcupine
one animal with a thick hide is a rhinoceros
something that lives underground is a grub

---
