In [1]:
import os
# When the scratch directory exists on this machine, point the Hugging Face
# cache at it. This runs ahead of the `transformers` import further down the
# cell; the resulting value (or None when unset) is echoed for the record.
on_scratch_host = os.path.isdir('/scratch/dmpowell')
if on_scratch_host:
    os.environ['TRANSFORMERS_CACHE'] = '/scratch/dmpowell/.cache/huggingface'
print(os.environ.get('TRANSFORMERS_CACHE'))

import numpy as np
import torch
from transformers import GPTJForCausalLM, AutoTokenizer, AutoModel, GPT2LMHeadModel, AutoModelForCausalLM

import pandas as pd
import json

from easyeditor.util import nethook
from easyeditor.custom import * # gets my custom functions

from easyeditor.editors import LOG
import logging
LOG.setLevel(logging.ERROR) # only ERROR-and-above records from easyeditor's LOG are emitted, so the notebook output stays uncluttered

import torch.nn.functional as F

from contextlib import redirect_stdout

# Use the GPU whenever CUDA is usable, otherwise fall back to the CPU,
# and report which one was picked.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("device = ", device)

/scratch/dmpowell/.cache/huggingface
device =  cuda


In [9]:
from ast import literal_eval

# Both evaluation CSVs run their `fwd_choices` / `rev_choices` columns through
# ast.literal_eval while reading, so the same converter mapping is shared.
choice_parsers = {column: literal_eval for column in ('fwd_choices', 'rev_choices')}

# Tab-separated animal tables.
types_df = pd.read_csv(
    "../catco-data/animal-type-tokens.tsv",
    sep="\t",
)
properties_df = pd.read_csv(
    "../catco-data/animal-data.tsv",
    sep="\t",
)

# Edits table, followed by the baseline and edit evaluation tables.
edits_df = pd.read_csv("../catco-data/edits.csv")
baseline_df = pd.read_csv(
    "../catco-data/baseline-evaluation.csv",
    converters=choice_parsers,
)
eval_df = pd.read_csv(
    "../catco-data/edits-evaluation.csv",
    converters=choice_parsers,
)


In [None]:
def _read_prefix(path):
    """Return the complete text of the prefix file at `path`.

    The file is decoded as UTF-8 explicitly so the text does not depend on the
    locale's default encoding of whichever machine runs the notebook.
    """
    with open(path, encoding='utf-8') as f:
        return f.read()


# Prefix text for forward queries; echoed so the notebook records exactly what
# was used.
prefix_fwd = _read_prefix('prefix_fwd.txt')
print(prefix_fwd)
print("---")

# Prefix text for reverse queries.
prefix_rev = _read_prefix('prefix_rev.txt')
print(prefix_rev)
print("---")

In [26]:
# Read the ROME hyperparameters from the LLaMA-7B YAML file and construct the
# EditedModel from them. `edited_model` is the object passed to every
# evaluate / edit_and_evaluate call later in this notebook.
# NOTE(review): ROMEHyperParams and EditedModel are not imported by name --
# presumably they come from the `easyeditor.custom` star import; confirm.
hparams = ROMEHyperParams.from_hparams('hparams/ROME/llama-7b.yaml')
edited_model = EditedModel(hparams)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [32]:
# Evaluate the model on the baseline items, supplying the forward and reverse
# prefix texts read from prefix_fwd.txt / prefix_rev.txt. The result carries
# the `correct_fwd` / `correct_rev` columns summarised in the following cells.
# NOTE(review): `evaluate` is not imported by name -- presumably it comes from
# the `easyeditor.custom` star import; confirm.
results_baseline = evaluate(baseline_df, edited_model, prefix_fwd = prefix_fwd, prefix_rev = prefix_rev)

In [33]:
# Category-membership knowledge on the baseline set (main items and
# paraphrases): keep rows whose property starts with "category_membership",
# reshape to one row per (item, query direction), then average `correct`
# within each token type / query direction.
item_cols = ['entity', 'token_type', 'subj', 'property', 'query_fwd', 'query_rev']
outcome_cols = ['correct_fwd', 'correct_rev']

is_membership = results_baseline['property'].str.startswith("category_membership")
membership_long = (
    results_baseline
    .loc[is_membership]
    .filter(item_cols + outcome_cols)
    .melt(id_vars=item_cols, value_vars=outcome_cols, var_name="query_type", value_name="correct")
)
membership_long.groupby(['token_type', 'query_type']).agg(corr_prop=('correct', 'mean'))

Unnamed: 0_level_0,Unnamed: 1_level_0,corr_prop
token_type,query_type,Unnamed: 2_level_1
rare_token,correct_fwd,0.75
rare_token,correct_rev,0.3125
typical_token,correct_fwd,0.96875
typical_token,correct_rev,0.96875


LLaMA-7B knows the category memberships of typical tokens well (about 97% in both directions), but is much weaker for rare tokens (75% forward), especially on reverse items (31%).

In [34]:
# Headline accuracy across all baseline items, one figure per query direction.
print("Overall fwd acc:", results_baseline['correct_fwd'].mean())
print("Overall rev acc:", results_baseline['correct_rev'].mean())

# The same accuracy split by token type and query direction: reshape to one
# row per (item, query direction) and average `correct` within each group.
item_cols = ['entity', 'token_type', 'subj', 'property', 'query_fwd', 'query_rev']
outcome_cols = ['correct_fwd', 'correct_rev']

baseline_long = (
    results_baseline
    .filter(item_cols + outcome_cols)
    .melt(id_vars=item_cols, value_vars=outcome_cols, var_name="query_type", value_name="correct")
)
baseline_long.groupby(['token_type', 'query_type']).agg(corr_prop=('correct', 'mean'))


Overall fwd acc: 0.7701149425287356
Overall rev acc: 0.5517241379310345


Unnamed: 0_level_0,Unnamed: 1_level_0,corr_prop
token_type,query_type,Unnamed: 2_level_1
entity,correct_fwd,0.873016
entity,correct_rev,0.650794
rare_token,correct_fwd,0.666667
rare_token,correct_rev,0.262626
typical_token,correct_fwd,0.808081
typical_token,correct_rev,0.777778


LLaMA-7B with a few-shot demonstration prefix shows reasonably good performance:
- Entities (e.g. "dog"): 87% forward, 65% reverse
- Typical tokens (e.g. "Labrador"): 81% forward, 78% reverse

Rare tokens (e.g. "puli") are much poorer (67% forward), especially for reverse queries (26%).

In [7]:
## Evaluation on the edit-evaluation items, placed ahead of the edit cells.
## Accuracy should be at or below chance; it need not be zero, because the
## answer choices contain no really tempting foils.
## NOTE(review): unlike the baseline call, no prefix_fwd / prefix_rev are
## passed here -- confirm that this is intentional.

results_eval = evaluate(eval_df, edited_model)

In [8]:
# Headline accuracy on the edit-evaluation items, one figure per query direction.
print("Overall fwd acc:", results_eval['correct_fwd'].mean())
print("Overall rev acc:", results_eval['correct_rev'].mean())

# Break the accuracy down by token type and query direction: reshape to one
# row per (item, query direction) and average `correct` within each group.
item_cols = ['entity', 'token_type', 'subj', 'property', 'query_fwd', 'query_rev']
outcome_cols = ['correct_fwd', 'correct_rev']

eval_long = (
    results_eval
    .filter(item_cols + outcome_cols)
    .melt(id_vars=item_cols, value_vars=outcome_cols, var_name="query_type", value_name="correct")
)
eval_long.groupby(['token_type', 'query_type']).agg(corr_prop=('correct', 'mean'))

Overall fwd acc: 0.23412698412698413
Overall rev acc: 0.2222222222222222


Unnamed: 0_level_0,Unnamed: 1_level_0,corr_prop
token_type,query_type,Unnamed: 2_level_1
rare_token_y,correct_fwd,0.25
rare_token_y,correct_rev,0.095238
typical_token_y,correct_fwd,0.218254
typical_token_y,correct_rev,0.349206


**TODO:** better balance the mix for reverse queries by token typicality — e.g. use only typical tokens for typical items and only rare tokens for rare items.

## Model editing performance


In [9]:
# Run edit_and_evaluate with the ROME method over edits_df / eval_df and save
# the returned table to results/ROME.csv.
# This took about 1 hr to run. TODO: look for ways to make it more efficient.
edit_method = "ROME"

# Create the output directory before starting, so the save at the end cannot
# fail on a missing folder after the hour-long run has finished.
os.makedirs("results", exist_ok=True)

full_results_ROME = edit_and_evaluate(edits_df, eval_df, edited_model, edit_method)
full_results_ROME.to_csv("results/ROME.csv")

In [10]:
# Run edit_and_evaluate with the ICE method over the same edits_df / eval_df
# and save the returned table to results/ICE.csv.
# NOTE(review): the same `edited_model` object is passed to both the ROME and
# the ICE run -- confirm that edit_and_evaluate leaves it unmodified between
# edits, so the two runs do not depend on which one ran first.
edit_method = "ICE"

# Create the output directory before starting, so the save at the end cannot
# fail on a missing folder after a long run.
os.makedirs("results", exist_ok=True)

full_results_ICE = edit_and_evaluate(edits_df, eval_df, edited_model, edit_method)
full_results_ICE.to_csv("results/ICE.csv")

In [11]:
# Headline accuracy for the ROME results, one figure per query direction.
print("Overall fwd acc:", full_results_ROME['correct_fwd'].mean())
print("Overall rev acc:", full_results_ROME['correct_rev'].mean())

# Break the accuracy down by token type and query direction: reshape to one
# row per (item, query direction) and average `correct` within each group.
item_cols = ['entity', 'token_type', 'subj', 'property', 'query_fwd', 'query_rev']
outcome_cols = ['correct_fwd', 'correct_rev']

rome_long = (
    full_results_ROME
    .filter(item_cols + outcome_cols)
    .melt(id_vars=item_cols, value_vars=outcome_cols, var_name="query_type", value_name="correct")
)
rome_long.groupby(['token_type', 'query_type']).agg(corr_prop=('correct', 'mean'))

Overall fwd acc: 0.3869047619047619
Overall rev acc: 0.23639455782312926


Unnamed: 0_level_0,Unnamed: 1_level_0,corr_prop
token_type,query_type,Unnamed: 2_level_1
rare_token_y,correct_fwd,0.395125
rare_token_y,correct_rev,0.130385
typical_token_y,correct_fwd,0.378685
typical_token_y,correct_rev,0.342404


In [12]:
# Headline accuracy for the ICE results, one figure per query direction.
print("Overall fwd acc:", full_results_ICE['correct_fwd'].mean())
print("Overall rev acc:", full_results_ICE['correct_rev'].mean())

# Break the accuracy down by token type and query direction: reshape to one
# row per (item, query direction) and average `correct` within each group.
item_cols = ['entity', 'token_type', 'subj', 'property', 'query_fwd', 'query_rev']
outcome_cols = ['correct_fwd', 'correct_rev']

ice_long = (
    full_results_ICE
    .filter(item_cols + outcome_cols)
    .melt(id_vars=item_cols, value_vars=outcome_cols, var_name="query_type", value_name="correct")
)
ice_long.groupby(['token_type', 'query_type']).agg(corr_prop=('correct', 'mean'))

Overall fwd acc: 0.4420351473922903
Overall rev acc: 0.25184240362811794


Unnamed: 0_level_0,Unnamed: 1_level_0,corr_prop
token_type,query_type,Unnamed: 2_level_1
rare_token_y,correct_fwd,0.499433
rare_token_y,correct_rev,0.148243
typical_token_y,correct_fwd,0.384637
typical_token_y,correct_rev,0.355442


## For later ...

a fruitbat rests by hanging upside-down
a shark's skeleton is cartilage
food for a hummingbird must be nectar
a rhinoceros has a thick hide
a worm lives underground
a hammerhead is a type of shark
a koala has two thumbs
a cougar is a type of mammal
some sheep make wool
a tamarin is a kind of monkey
a parrot can talk
a wolf belongs to a pack

---
one animal that hangs upside-down is a fruitbat
an animal whose skeleton is cartilage is a shark
something that eats nectar is a hummingbird
one animal with a thick hide is a rhinoceros
one thing that lives underground is a worm
one type of shark is a hammerhead
an animal with two thumbs is a koala
one example of a mammal is a cougar
an animal that makes wool is a sheep
one kind of monkey is a tamarin
an animal that can talk is a parrot
one animal that belongs to a pack is a wolf

---
