In [1]:
import os
# When the /scratch/wgerych directory exists, point the Hugging Face model and
# dataset caches there; this must run before the transformers import below.
_hf_cache_root = '/scratch/wgerych/.cache/huggingface'
if os.path.isdir('/scratch/wgerych'):
    os.environ.update({
        'TRANSFORMERS_CACHE': _hf_cache_root,
        'HF_DATASETS_CACHE': _hf_cache_root + '/datasets',
    })
for _cache_var in ('TRANSFORMERS_CACHE', 'HF_DATASETS_CACHE'):
    print(os.getenv(_cache_var))

import numpy as np
import torch
from transformers import GPTJForCausalLM, AutoTokenizer, AutoModel, GPT2LMHeadModel, AutoModelForCausalLM

import pandas as pd

from easyeditor.util import nethook
from easyeditor.custom import * # gets my custom functions

# from easyeditor.editors import LOG
# import logging
# LOG.setLevel(logging.ERROR) # stops cluttering up notebook

import torch.nn.functional as F

from contextlib import redirect_stdout

# Handle for the first CUDA device (index 0).
device = torch.device("cuda", 0)

../../../../scratch/wgerych/.cache/huggingface
../../../../scratch/wgerych/.cache/huggingface/datasets




In [34]:
def get_multichoice_dist(model, prompt, choices, normalization = None):
    """Score every answer in ``choices`` as a completion of ``prompt``.

    Args:
        model: wrapper exposing ``tok`` (its tokenizer) and
            ``completion_logprob(text, completion)``.
        prompt: question text that each choice is appended to.
        choices: candidate answer strings.
        normalization: how the raw ``completion_logprob`` scores are adjusted:
            None             -> raw scores;
            "unconditional"  -> minus the score of the choice scored on its own;
            "byte_length"    -> divided by ``len(choice)`` (character count,
                                equal to the byte count only for ASCII);
            "token_length"   -> divided by the choice's token count;
            "root"           -> exp(score / token count), i.e. a per-token
                                geometric-mean probability (no longer log space).

    Returns:
        list[float]: one score per choice, in the order of ``choices``.

    Raises:
        ValueError: if ``normalization`` is not one of the values above.
    """
    import transformers  # the module object itself is needed for the tokenizer type check

    if normalization not in (None, "unconditional", "byte_length", "token_length", "root"):
        raise ValueError(f"Unknown normalization: {normalization!r}")

    # LLaMA's tokenizer is special-cased: choices are used unpadded and the separating
    # space is put on the prompt instead (presumably pad_token's padding does not suit it).
    # NOTE(review): LlamaTokenizerFast is not matched here — confirm which tokenizer
    # class the model wrapper loads for LLaMA models.
    if isinstance(model.tok, transformers.models.llama.tokenization_llama.LlamaTokenizer):
        padded_choices = list(choices)
        # endswith() also copes with an empty prompt, where prompt[-1] raises IndexError.
        if not prompt.endswith(" "):
            prompt = prompt + " "
    else:
        padded_choices = [pad_token(c) for c in choices]  # pad every choice with pad_token

    prompts = [prompt + c for c in padded_choices]
    logits = torch.tensor([model.completion_logprob(p, c) for p, c in zip(prompts, padded_choices)])

    if normalization == "unconditional":
        # Score each choice with no prompt and subtract it, so a priori likely answers are not favoured.
        norm_logits = torch.tensor([model.completion_logprob(c, c) for c in padded_choices])
        logits = logits - norm_logits

    elif normalization == "byte_length":
        logits = logits / torch.tensor([len(c) for c in choices])

    elif normalization == "token_length":
        logits = logits / torch.tensor([len(encode_token(c, model.tok)) for c in choices])

    elif normalization == "root":
        tok_lens = torch.tensor([len(encode_token(c, model.tok)) for c in choices])
        # exp(logits / n) equals exp(logits) ** (1 / n) mathematically, but dividing first
        # keeps exp() from underflowing to 0 for long, low-probability completions.
        logits = torch.exp(logits / tok_lens)

    return logits.tolist()
    
def compute_entropy(dist):
    """Shannon entropy, in nats (natural log), of a discrete probability distribution.

    Args:
        dist: sequence or 1-D array of probabilities.

    Returns:
        float: ``-sum(p * log(p))``; 0 for an empty input.

    Zero-probability entries contribute nothing (the standard 0 * log(0) = 0
    convention). This keeps the result finite when a softmax entry underflows
    to exactly 0, where evaluating 0 * log(0) directly would give NaN.
    """
    probs = np.asarray(dist, dtype=float)
    nonzero = probs[probs > 0]
    return float(-np.sum(nonzero * np.log(nonzero)))


def _stable_softmax(scores):
    """Softmax of a 1-D sequence of scores, shifted by its max so exp() cannot under/overflow every term."""
    scores = np.asarray(scores, dtype=float)
    weights = np.exp(scores - scores.max())
    return weights / weights.sum()


def _score_query(model, query, choices, normalization):
    """Score one multiple-choice query; return (predicted index, softmax distribution, entropy)."""
    scores = get_multichoice_dist(model, query, choices, normalization = normalization)
    predicted = scores.index(max(scores))  # first index holding the highest score
    dist = _stable_softmax(scores)
    return predicted, dist, compute_entropy(dist)


def evaluate_with_uncertainty(evaluation_data, model, prefix_fwd = "", prefix_rev = "", normalization = None):
    """Ask every question forward and in reverse, recording predictions and uncertainty.

    Forward query: ``query_fwd`` with ``<subj>`` replaced by ``subj`` and ``<answer>``
    removed; the correct choice in ``fwd_choices`` is ``answer_fwd``.
    Reverse query: ``query_rev`` with ``<answer>`` replaced by ``answer_fwd`` and
    ``<subj>`` removed; the correct choice in ``rev_choices`` is ``subj``.
    The prefixes are prepended unless ``property`` is one of the category-membership
    properties.

    Args:
        evaluation_data: DataFrame with columns ``subj``, ``property``, ``query_fwd``,
            ``query_rev``, ``answer_fwd``, ``fwd_choices`` and ``rev_choices`` (lists).
        model: passed through to ``get_multichoice_dist``.
        prefix_fwd: text prepended to forward queries.
        prefix_rev: text prepended to reverse queries.
        normalization: forwarded to ``get_multichoice_dist``.

    Returns:
        A new DataFrame (``evaluation_data`` is not modified) with extra columns:
        ``corr_fwd_answer``/``corr_rev_answer`` (index of the correct choice),
        ``fwd_predicted``/``rev_predicted`` (index of the top-scoring choice),
        ``fwd_dist``/``rev_dist`` (softmax over the choice scores),
        ``fwd_entropy``/``rev_entropy`` and ``correct_fwd``/``correct_rev``.

    Raises:
        ValueError: if a correct answer is missing from its choice list.
    """
    # Category-membership questions are always asked without the prefixes.
    no_prefix_properties = {"category_membership", "category_membership1", "category_membership2", "category_membership3"}

    fwd_answers = []
    rev_answers = []
    fwd_dist = []
    rev_dist = []
    fwd_entropy = []
    rev_entropy = []
    corr_fwd_answers = []
    corr_rev_answers = []

    for q in evaluation_data.itertuples():
        use_prefix = q.property not in no_prefix_properties

        # forward direction: subject -> answer
        query_fwd = q.query_fwd.replace("<subj>", q.subj).replace("<answer>", "")
        if use_prefix:
            query_fwd = prefix_fwd + query_fwd
        ans_fwd, dist_fwd, entropy_fwd = _score_query(model, query_fwd, q.fwd_choices, normalization)
        fwd_entropy.append(entropy_fwd)
        fwd_dist.append(dist_fwd)
        corr_fwd_answers.append(q.fwd_choices.index(q.answer_fwd))
        fwd_answers.append(ans_fwd)

        # reverse direction: answer -> subject
        query_rev = q.query_rev.replace("<answer>", q.answer_fwd).replace("<subj>", "")
        if use_prefix:
            query_rev = prefix_rev + query_rev
        ans_rev, dist_rev, entropy_rev = _score_query(model, query_rev, q.rev_choices, normalization)
        rev_entropy.append(entropy_rev)
        rev_dist.append(dist_rev)
        corr_rev_answers.append(q.rev_choices.index(q.subj))
        rev_answers.append(ans_rev)

    results = (
        evaluation_data
        .assign(
            corr_fwd_answer = corr_fwd_answers,
            corr_rev_answer = corr_rev_answers,
            fwd_predicted = fwd_answers,
            rev_predicted = rev_answers,
            fwd_dist = fwd_dist,
            rev_dist = rev_dist,
            fwd_entropy = fwd_entropy,
            rev_entropy = rev_entropy
            )
        .assign(
            correct_fwd = lambda x: x.corr_fwd_answer==x.fwd_predicted,
            correct_rev = lambda x: x.corr_rev_answer==x.rev_predicted
        )
    )

    return(results)


def edit_and_evaluate_with_uncertainty(edits_df, eval_df, model, edit_method, metrics = False, log_file = None, **kwargs):
    """Apply each edit in turn, evaluate the edited model on the matching questions, then restore it.

    Args:
        edits_df: one row per edit with ``edit_type`` ("category membership" or
            "category property"), ``subj`` and ``entity``; property edits also
            need ``query_fwd``, ``answer_fwd`` and ``property``.
        eval_df: evaluation questions. For each edit, the rows used are those with the
            same ``edit_type`` plus the same ``entity``/``subj`` (membership) or the
            same ``entity``/``property`` (property).
        model: editable wrapper exposing ``edit(...)`` and ``restore()``.
        edit_method: "ROME", "FT", "PMET" or "GRACE" (edit via a rewrite request),
            or "ICE" (edit via an in-context preprompt).
        metrics: unused; kept only for backward compatibility of the signature.
        log_file: forwarded to ``model.edit`` for the rewrite-based methods.
        **kwargs: forwarded to ``evaluate_with_uncertainty``.

    Returns:
        DataFrame concatenating the per-edit evaluation results (original indices
        kept), with an ``edit_method`` column added.

    Raises:
        ValueError: for an unsupported ``edit_method`` or ``edit_type``.
    """
    rewrite_methods = ("ROME", "FT", "PMET", "GRACE")
    if edit_method != "ICE" and edit_method not in rewrite_methods:
        # Otherwise the unedited model would be evaluated and reported as edited.
        raise ValueError(f"Unsupported edit_method: {edit_method!r}")

    per_edit_results = []

    for e in edits_df.itertuples():
        if e.edit_type == "category membership":
            if edit_method == "ICE":
                model.edit({"preprompt": f"Imagine that a {e.subj} is a kind of {e.entity} ...\n\n"}) # and not a kind of {e.orig_entity}
            else:
                rewrite = {
                        'prompts': [f'A {e.subj} is a kind of'],
                        'target_new': [e.entity],
                        'subject': [e.subj]
                        }
                model.edit(rewrite, log_file  = log_file)

            evals = eval_df.loc[lambda x: (x.edit_type == "category membership") & (x.entity == e.entity) & (x.subj == e.subj)]

        elif e.edit_type == "category property":
            if edit_method == "ICE":
                rewrite_prompt = e.query_fwd.replace("<subj>", e.entity).replace("<answer>", e.answer_fwd)
                model.edit({"preprompt": f"Imagine that {rewrite_prompt} ...\n\n"})
            else:
                # The property is rewritten on the category itself; the answer slot is
                # dropped from the prompt and supplied as the new target instead.
                rewrite_prompt = e.query_fwd.replace("<subj>", e.entity).replace(" <answer>", "")
                rewrite = {
                    'prompts': [rewrite_prompt],
                    'target_new': [e.answer_fwd],
                    'subject': [e.entity]
                }
                model.edit(rewrite, log_file  = log_file)

            evals = eval_df.loc[lambda x: (x.edit_type == "category property") & (x.entity == e.entity) & (x.property == e.property)]

        else:
            # Without this, `evals` would be undefined or silently left over from the previous edit.
            raise ValueError(f"Unsupported edit_type: {e.edit_type!r}")

        try:
            per_edit_results.append(evaluate_with_uncertainty(evals, model, **kwargs))
        finally:
            # Undo the edit even if evaluation fails, so later edits start from the base model.
            model.restore()

    # Concatenate once at the end (concat inside the loop is quadratic); concat([]) would raise.
    full_results = pd.concat(per_edit_results) if per_edit_results else pd.DataFrame()
    full_results["edit_method"] = edit_method

    return(full_results)


In [10]:
# Model to edit, and which editing methods to run (available: 'ICE', 'FT', 'ROME').
BASE_LLM, edit_methods = 'gpt2-xl', ['ICE']

In [11]:
## --- set up test mode (or not)
# Recognised flags: "catmem_only" keeps only category-membership edits,
# "catprop_only" keeps only category-property edits; an empty list keeps both.
MODE_ARGS = ["catmem_only"]

## --- load data

def load_result(filename):
    """Read a saved results CSV, parsing the list-valued choice columns back into lists.

    Args:
        filename: path (or file-like object) accepted by ``pd.read_csv``.

    Returns:
        DataFrame whose ``fwd_choices`` / ``rev_choices`` cells are Python lists.

    The choice columns hold Python list reprs (as ``DataFrame.to_csv`` writes list
    cells), so they are parsed with ``ast.literal_eval``, which only accepts literals
    and never executes code.
    """
    # Imported locally: literal_eval is not explicitly imported anywhere in this
    # notebook and would otherwise depend on the easyeditor star import.
    from ast import literal_eval
    x = pd.read_csv(filename, converters={'fwd_choices': literal_eval, 'rev_choices': literal_eval})
    return(x)

# Baseline, edit and evaluation tables (schemas defined by load_data in easyeditor.custom).
baseline_df, edits_df, eval_df = load_data()

# Prompt prefixes; prefix_fwd/prefix_rev are the ones prepended to forward/reverse queries.
prefix_fwd, prefix_rev, prefix_single = load_prefixes(verbose=False)


In [12]:
# Restrict edits and evaluation questions to a single edit type when a mode flag is set.
# Flags are checked in this order, so "catprop_only" wins if both are present.
_mode_filters = (
    ("catprop_only", "====== category property edits only ! ======", "category property"),
    ("catmem_only", " ====== category membership edits only! =======", "category membership"),
)
for _flag, _banner, _kept_type in _mode_filters:
    if _flag in MODE_ARGS:
        print(_banner)
        edits_df = edits_df[edits_df["edit_type"] == _kept_type]
        eval_df = eval_df[eval_df["edit_type"] == _kept_type]
        break




In [22]:
# Hyperparameter class per editing method.
# NOTE(review): ICE is mapped to ROMEHyperParams — confirm this is intended.
edit_hyperparams = {
    'ICE': ROMEHyperParams,
    'FT': FTHyperParams,
    'ROME': ROMEHyperParams,
}
results = {}

# For each selected method: its hyperparameter class, YAML path and name.
hparam_config = {}
for edit_method in edit_methods:
    hparam_config[edit_method] = dict(
        HyperParams=edit_hyperparams[edit_method],
        path=f'hparams/{edit_method}/{BASE_LLM}.yaml',
        edit_method=edit_method,
    )

In [23]:
# FT hyperparameters for BASE_LLM; `hparams` is also used later to build each edited model.
_ft_hparams_path = f'hparams/FT/{BASE_LLM}.yaml'
hparams = FTHyperParams.from_hparams(_ft_hparams_path)

# Unedited model used for the baseline evaluation.
base_model = EditedModel(hparams, auth_token())

2024-02-13 17:36:42,718 - easyeditor.editors.editor - INFO - Instantiating model
02/13/2024 17:36:42 - INFO - easyeditor.editors.editor -   Instantiating model
2024-02-13 17:36:56,299 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...
02/13/2024 17:36:56 - INFO - easyeditor.editors.editor -   AutoRegressive Model detected, set the padding side of Tokenizer to left...


In [42]:
# Answers and uncertainty of the unedited model on every evaluation question (no prefixes).
results_baseline_eval = evaluate_with_uncertainty(
    eval_df,
    base_model,
    normalization=None,
    prefix_rev="",
    prefix_fwd="",
)

In [43]:
# Store the unedited model's uncertainty next to each question; pandas aligns the rows on index.
for _baseline_col in ('fwd_entropy', 'rev_entropy', 'fwd_dist', 'rev_dist'):
    eval_df[_baseline_col + '_baseline'] = results_baseline_eval[_baseline_col]

In [45]:
# Run every configured editing method, evaluate with uncertainty, and save the results.
N_EDITS = 1  # only the first N_EDITS edits are applied (quick run); set to None to apply all edits

for edit_method, HPARAMS in hparam_config.items():
    # NOTE(review): the model is always built from the FT `hparams` loaded earlier, not from
    # HPARAMS["HyperParams"] / HPARAMS["path"]; for weight-editing methods such as ROME this
    # may run the wrong algorithm — confirm before enabling them in `edit_methods`.
    edited_model = EditedModel(hparams, auth_token())

    res = edit_and_evaluate_with_uncertainty(
        edits_df[:N_EDITS],
        eval_df,
        edited_model,
        edit_method,
        prefix_fwd = "",
        prefix_rev = "",
        log_file = "results/log-catmem-2024-02-12-b.txt"
        )

    # File name is "<model>-<method>catmem-full_w_uncertainty.csv" (no separator before "catmem").
    res.to_csv("results/csv/" + hparams.model_name.replace("/", "-") + "-" + edit_method +  "catmem-full_w_uncertainty.csv", index=False)

    results[HPARAMS["edit_method"]] = res

2024-02-13 17:51:47,122 - easyeditor.editors.editor - INFO - Instantiating model
2024-02-13 17:51:47,122 - easyeditor.editors.editor - INFO - Instantiating model
2024-02-13 17:51:47,122 - easyeditor.editors.editor - INFO - Instantiating model
2024-02-13 17:51:47,122 - easyeditor.editors.editor - INFO - Instantiating model
02/13/2024 17:51:47 - INFO - easyeditor.editors.editor -   Instantiating model
2024-02-13 17:51:59,491 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...
2024-02-13 17:51:59,491 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...
2024-02-13 17:51:59,491 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...
2024-02-13 17:51:59,491 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...
02/13/2024 17:51:59 - INFO - easyeditor.editors.e