In [2]:
import io
import json
import re
from contextlib import redirect_stdout

import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from experiments.py.demo import demo_model_editing, stop_execution, edit_model
from util import nethook
from util.generate import generate_interactive, generate_fast


In [3]:
MODEL_NAME = "EleutherAI/gpt-j-6B"  # gpt2-{medium,large,xl} or EleutherAI/gpt-j-6B

# Load the causal LM (low_cpu_mem_usage keeps peak host RAM down while loading),
# move it to the GPU, then load the matching tokenizer.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True)
model = model.to("cuda")
tok = AutoTokenizer.from_pretrained(MODEL_NAME)

In [None]:
ALG_NAME = "ROME"  # editing algorithm name handed to edit_model in the generation loop
# Reuse the EOS token for padding (presumably the tokenizer defines no pad token
# of its own — confirm for other MODEL_NAME choices).
tok.pad_token = tok.eos_token

In [4]:
def fix_gen_prompts(prompt):
    """Normalise wording in a generation prompt before it is used.

    Applies two plain substring substitutions, in this order:
    "recently entered" -> "entered", then " restaurants" -> " sights".

    Args:
        prompt: the generation prompt text.

    Returns:
        The prompt with both substitutions applied.
    """
    substitutions = (
        ("recently entered", "entered"),
        # Not about tense; arguably belongs in its own helper.
        (" restaurants", " sights"),
    )
    for old, new in substitutions:
        prompt = prompt.replace(old, new)
    return prompt


def tense_to_past(prompt):
    p = (
        prompt
        .replace(" is", " was")
        .replace(" speak", " spoke")
        .replace(" carries", " carried")
        .replace(" produces", " produced")
        .replace(" include", " included")
        .replace(" works", " worked")
        .replace(" has", " had")
        .replace(" lives", " lived")
        .replace(" sells", " sold")
        .replace(" owns", " owned")
    )

    return(p)



In [6]:
chosen = pd.read_csv("counterfact/selected-items.csv")

# Decode explicitly as UTF-8 (the JSON standard encoding) rather than the
# platform default, which is not UTF-8 everywhere (e.g. Windows).
with open("counterfact/counterfact.json", encoding="utf-8") as json_file:
    json_data = json.load(json_file)

# sorted() rather than list(set(...)): string hashing is randomized per
# interpreter run, so set iteration order would differ between runs.
subjects = sorted({x["requested_rewrite"]["subject"] for x in json_data})
relations = sorted({x["requested_rewrite"]["relation_id"] for x in json_data})

print(len(chosen), " rewrites to test")

193  rewrites to test


In [11]:
MAX_OUT_LEN = 100  # max_out_len passed to generate_fast


def restore_weights(edited_model, weights):
    """Copy saved parameter tensors back into ``edited_model`` in place.

    Args:
        edited_model: the model whose parameters are overwritten.
        weights: mapping of parameter name (as resolved by
            ``nethook.get_parameter``) to the tensor to restore, as returned
            by ``edit_model``.
    """
    with torch.no_grad():
        for name, original in weights.items():
            nethook.get_parameter(edited_model, name)[...] = original


# Index records by case_id once instead of scanning json_data for every row.
# setdefault keeps the first record per id, like the previous `[...][0]` lookup.
items_by_case = {}
for record in json_data:
    items_by_case.setdefault(record["case_id"], record)

gen_list = []
for i in tqdm(range(len(chosen))):
    c = chosen.loc[i]
    item = items_by_case[int(c.case_id)]

    rewrites = [item["requested_rewrite"]]
    gen_prompts = [fix_gen_prompts(p) for p in item["generation_prompts"]]
    if c.past != 0:
        gen_prompts = [tense_to_past(p) for p in gen_prompts]

    # Silence edit_model's console output. A StringIO sink (unlike None) also
    # tolerates direct sys.stdout.write()/flush() calls.
    with redirect_stdout(io.StringIO()):
        model_new, orig_weights = edit_model(
            model, tok, rewrites, alg_name=ALG_NAME
        )

    try:
        gen_list.append(
            generate_fast(model_new, tok, gen_prompts, max_out_len=MAX_OUT_LEN)
        )
    finally:
        # Undo this edit right away so the next iteration, and every later
        # cell, sees the unedited weights even if generation fails.
        # NOTE(review): assumes edit_model modifies `model` in place and
        # orig_weights holds the pre-edit tensors (the original restore logic
        # relied on the same assumption).
        restore_weights(model, orig_weights)
    


100%|██████████| 193/193 [50:21<00:00, 15.65s/it]


In [14]:
# gen_list must line up one-to-one with the rows of `chosen`; refuse to write
# a partial file (e.g. after an interrupted generation run).
assert len(gen_list) == len(chosen), (
    f"expected {len(chosen)} generation lists, got {len(gen_list)}"
)
gen_dict = {
    int(case_id): gens for case_id, gens in zip(chosen["case_id"], gen_list)
}

# ensure_ascii=False emits raw non-ASCII characters, so the file must be opened
# as UTF-8 explicitly; the platform default encoding may not be able to encode them.
with open("counterfact/gens-gpt-j-6.json", "w", encoding="utf-8") as f:
    json.dump(gen_dict, f, ensure_ascii=False)
    