In [1]:
import os
if os.path.isdir('/scratch/dmpowell'):
    os.environ['TRANSFORMERS_CACHE'] = '/scratch/dmpowell/.cache/huggingface'
print(os.getenv('TRANSFORMERS_CACHE'))

import numpy as np
import torch
from transformers import GPTJForCausalLM, AutoTokenizer, AutoModel, GPT2LMHeadModel, AutoModelForCausalLM

import pandas as pd
import json

from easyeditor.util import nethook

import torch.nn.functional as F

from contextlib import redirect_stdout

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device = ", device)

/scratch/dmpowell/.cache/huggingface
device =  cuda


In [2]:
types_df = pd.read_csv("animal-type-tokens.tsv", sep="\t")
properties_df = pd.read_csv("animal-data.tsv", sep="\t")

edits_df = (
    pd.merge(types_df, types_df, how = "cross")
    .loc[lambda x: x.entity_type_x!=x.entity_type_y] 
    .filter(['entity_type_x', 'entity_type_y', 'typical_token_y', 'rare_token_y'])
    # .assign(novel_token = "dax")
    .rename(columns = {"entity_type_y": "orig_entity"})
    .melt(['entity_type_x', "orig_entity"])
    # .drop_duplicates()
    .rename(columns={"entity_type_x":"entity", "value":"subj"})
    .assign(edit = lambda x: x.subj + " -> " + x.entity)
)

print(len(edits_df), " Edits")
edits_df.head()


112  Edits


Unnamed: 0,entity,orig_entity,variable,subj,edit
0,dog,cat,typical_token_y,Siamese,Siamese -> dog
1,dog,cow,typical_token_y,Holstein,Holstein -> dog
2,dog,pig,typical_token_y,Hampshire,Hampshire -> dog
3,dog,bird,typical_token_y,sparrow,sparrow -> dog
4,dog,bee,typical_token_y,bumblebee,bumblebee -> dog


In [3]:
types_df

Unnamed: 0,entity_type,typical_token,rare_token
0,dog,Labrador,Puli
1,cat,Siamese,Maine Coon
2,cow,Holstein,Vaynol
3,pig,Hampshire,Tamworth
4,bird,sparrow,Owlet
5,bee,bumblebee,Andrena
6,fish,trout,grouper
7,snake,cobra,Ninia


In [4]:
def proc_choices(df, baseline = False):
    if baseline:
        choice_list = df[["foil1", "foil2", "foil3"]].values.tolist()
    else:
        choice_list = df[["foil1", "foil2", "foil3", "orig_answer_fwd"]].values.tolist()
    ans_list = df["answer_fwd"].tolist()
    out = []

    for i in range(len(choice_list)):
        distinct = list(set(choice_list[i]))
        ans = ans_list[i]
        out.append([ans] + [c for c in distinct if c!=ans and pd.notna(c)])

    df["choices"] = out
    return(df)


baseline_df = (
    types_df
    .rename(columns = {'entity_type':'entity'})
    .melt(["entity"], value_name = 'subj')
    .merge(properties_df, on = 'entity')
    .pipe(proc_choices, True)
)
baseline_df

Unnamed: 0,entity,variable,subj,property,query_fwd,query_rev,answer_fwd,answer_rev,foil1,foil2,foil3,choices
0,dog,typical_token,Labrador,makes_sound,a sound a <subj> makes is <answer>,<answer> is a sound made by a <subj>,bark,<subj>,meow,moo,,"[bark, meow, moo]"
1,dog,typical_token,Labrador,like_to_interact,<subj> are something people like to <answer>,people like to <answer> <subj>,pet,<subj>,eat,ride,,"[pet, ride, eat]"
2,dog,typical_token,Labrador,genus,a <subj> is a <answer>,one type of <answer> is a <subj>,mammal,<subj>,aves,reptile,insect,"[mammal, reptile, aves, insect]"
3,dog,typical_token,Labrador,is_domesticated,most <subj> are <answer>,one animal that is typically <answer> is a <subj>,domesticated,<subj>,wild,,,"[domesticated, wild]"
4,dog,typical_token,Labrador,leg_count,<subj> are animals that have <answer>,<answer> can be found on <subj>,four legs,<subj>,two legs,six legs,no legs,"[four legs, two legs, no legs, six legs]"
...,...,...,...,...,...,...,...,...,...,...,...,...
129,snake,rare_token,Ninia,makes_sound,a sound a <subj> makes is <answer>,<answer> is a sound made by a <subj>,hiss,<subj>,bark,moo,chirp,"[hiss, moo, chirp, bark]"
130,snake,rare_token,Ninia,leg_count,<subj> are animals that have <answer>,<answer> can be found on <subj>,no legs,<subj>,two legs,six legs,four legs,"[no legs, two legs, four legs, six legs]"
131,snake,rare_token,Ninia,moves,<subj> move by <answer>,<answer> is the movement of <subj>,slithering,<subj>,galloping,flying,walking,"[slithering, flying, galloping, walking]"
132,snake,rare_token,Ninia,give_birth,<subj> have offspring by <answer>,<answer> is how offspring are made by <subj>,laying eggs,<subj>,live birth,budding,fragmentation,"[laying eggs, fragmentation, budding, live birth]"


In [5]:
eval_df = (
    pd.merge(
        edits_df, 
        properties_df.rename(columns = {"answer_fwd":"orig_answer_fwd", "answer_rev":"orig_answer_rev", "entity":"orig_entity"}), 
        how="left", on = "orig_entity"
        )
        .merge(properties_df.filter(["entity", "answer_fwd", "answer_rev", "property"]), on = ["entity", "property"]) 
        .loc[lambda x: x.orig_answer_fwd!=x.answer_fwd]
        .pipe(proc_choices)
)
eval_df

Unnamed: 0,entity,orig_entity,variable,subj,edit,property,query_fwd,query_rev,orig_answer_fwd,orig_answer_rev,foil1,foil2,foil3,answer_fwd,answer_rev,choices
0,dog,cat,typical_token_y,Siamese,Siamese -> dog,makes_sound,a sound a <subj> makes is <answer>,<answer> is a sound made by a <subj>,meow,<subj>,bark,moo,chirp,bark,<subj>,"[bark, meow, moo, chirp]"
1,dog,cow,typical_token_y,Holstein,Holstein -> dog,makes_sound,a sound a <subj> makes is <answer>,<answer> is a sound made by a <subj>,moo,<subj>,meow,bark,,bark,<subj>,"[bark, meow, moo]"
2,dog,bird,typical_token_y,sparrow,sparrow -> dog,makes_sound,a sound a <subj> makes is <answer>,<answer> is a sound made by a <subj>,chirp,<subj>,bark,moo,meow,bark,<subj>,"[bark, meow, moo, chirp]"
3,dog,bee,typical_token_y,bumblebee,bumblebee -> dog,makes_sound,a sound a <subj> makes is <answer>,<answer> is a sound made by a <subj>,buzz,<subj>,bark,moo,meow,bark,<subj>,"[bark, meow, moo, buzz]"
4,dog,snake,typical_token_y,cobra,cobra -> dog,makes_sound,a sound a <subj> makes is <answer>,<answer> is a sound made by a <subj>,hiss,<subj>,bark,moo,chirp,bark,<subj>,"[bark, hiss, moo, chirp]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
607,snake,cat,typical_token_y,Siamese,Siamese -> snake,give_birth,<subj> have offspring by <answer>,<answer> is how offspring are made by <subj>,live birth,<subj>,laying eggs,budding,fragmentation,laying eggs,<subj>,"[laying eggs, fragmentation, budding, live birth]"
608,snake,cow,typical_token_y,Holstein,Holstein -> snake,give_birth,<subj> have offspring by <answer>,<answer> is how offspring are made by <subj>,live birth,<subj>,laying eggs,budding,fragmentation,laying eggs,<subj>,"[laying eggs, fragmentation, budding, live birth]"
611,snake,dog,rare_token_y,Puli,Puli -> snake,give_birth,<subj> have offspring by <answer>,<answer> is how offspring are made by <subj>,live birth,<subj>,laying eggs,budding,fragmentation,laying eggs,<subj>,"[laying eggs, fragmentation, budding, live birth]"
612,snake,cat,rare_token_y,Maine Coon,Maine Coon -> snake,give_birth,<subj> have offspring by <answer>,<answer> is how offspring are made by <subj>,live birth,<subj>,laying eggs,budding,fragmentation,laying eggs,<subj>,"[laying eggs, fragmentation, budding, live birth]"


In [6]:
from easyeditor import BaseEditor, ROMEHyperParams

def pad_token(token):
    token = " " + token if token[0] != " " else token
    return(token)


def encode_token(token:str, tokenizer):        
    token = pad_token(token)
    token_id = tokenizer(token)["input_ids"]

    return(token_id)


class EditedModel:
    def __init__(self, hparams):
        self.editor = BaseEditor.from_hparams(hparams)

        self.model = self.editor.model
        self.tok = self.editor.tok
        self.model_name = self.editor.model_name
        

        self.params = hparams
        self.preprompt = ""
        self.saved_weights = None
        
        self.tok.padding_side = "left"
        # self.tok.pad_token = self.tok.eos_token
    
    def edit(self, preprompt = "", **kwargs):
        
        if preprompt!="":
            self.preprompt = preprompt
        else:
            with redirect_stdout(None):
                metrics, self.model, self.saved_weights = self.editor.pure_edit(
                    **kwargs,
                    keep_original_weight = True
                )

        return metrics
    
    
    def restore(self):

        self.preprompt = ""
        
        if self.saved_weights:
            try:
                with torch.no_grad():
                    for k, v in self.saved_weights.items():
                        nethook.get_parameter(self.model, k)[...] = v
                self.saved_weights = None
                # print("Original model restored")
            except NameError as e:
                print(f"No model weights to restore: {e}")

            
    def generate_text(self, texts, **kwargs):
        
        if type(texts) != list:
            texts = [texts]
        
        texts = [self.preprompt + t for t in texts]

        model = self.model
        tokenizer = self.tok
        encoding = tokenizer(texts, padding=True, return_tensors='pt').to(device)

        with torch.no_grad():
            generated_ids = model.generate(**encoding, **kwargs) # 

            generated_texts = tokenizer.batch_decode(
                generated_ids, skip_special_tokens=True
            )
            
        return(generated_texts)

    
    def token_logit(self, texts, token, start_ind = None):
        
        texts = self.preprompt + texts
    
        tokenizer = self.tok 
        model = self.model
        encoding = tokenizer(texts, padding=True, return_tensors='pt').to(device)

        with torch.no_grad():
            model_out = model(encoding["input_ids"])
            logits = model_out.logits
            logprobs = F.log_softmax(logits, -1)

        token_id = encode_token(token, tokenizer)
        start_ind = -len(token_id)-1 if not start_ind else start_ind
        
        l = logprobs[:, start_ind:-1, token_id]
        if len(l.squeeze().shape) == 0:
            return(l.squeeze())
        else:
            return(l.squeeze().diag().sum())
        

    def choose(self, prompt, choices):
        prompts = [prompt + pad_token(c) for c in choices]
        logits = [self.token_logit(prompts[i], choices[i]) for i in range(len(choices))]
        return(logits.index(max(logits)))

In [17]:
hparams = ROMEHyperParams.from_hparams('hparams/ROME/gpt-j-6B.yaml')
m = EditedModel(hparams)

2023-09-29 16:19:39,732 - easyeditor.editors.editor - INFO - Instantiating model
2023-09-29 16:19:39,732 - easyeditor.editors.editor - INFO - Instantiating model
09/29/2023 16:19:39 - INFO - easyeditor.editors.editor -   Instantiating model


In [18]:
generation_prompts = [
    "Ray Charles's music features no instrument so commonly as",
    "The law in Ikaalinen specifies the national langauge is"
]

pre_edit_outputs = m.generate_text(generation_prompts)

print(pre_edit_outputs)


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["Ray Charles's music features no instrument so commonly as the piano. He is a", 'The law in Ikaalinen specifies the national langauge is Finnish.\n\nThe law']


In [19]:
prompts = ['Ray Charles plays',
            # 'Grant Hill plays professional',
            # 'In Ikaalinen the official language is'
            ]
ground_truth = ['piano',
                # 'basketball',
                # 'Finnish'
                ]

target_new = ['violin',
            #   'soccer',
            #   'Swedish'
              ]

subject = ['Ray Charles',
            # 'Grant Hill',
            # 'Ikaalinen'
            ]

m.edit(
    prompts=prompts,
    ground_truth=ground_truth,
    target_new=target_new,
    subject=subject
)

[{'pre': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}]

In [22]:
post_edit_outputs = m.generate_text(generation_prompts, max_new_tokens = 10)

print(post_edit_outputs)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["Ray Charles's music features no instrument so commonly as the piano. He is a master of the instrument", 'The law in Ikaalinen specifies the national langauge is Finnish.\n\nThe law in Ikaalin']


In [21]:
m.restore()

----

In [26]:
### testing with original code
hparams = ROMEHyperParams.from_hparams('hparams/ROME/gpt-j-6B.yaml')
editor = BaseEditor.from_hparams(hparams)
metrics, edited_model, _ = editor.edit(
    prompts=prompts,
    ground_truth=ground_truth,
    target_new=target_new,
    subject=subject,
    keep_original_weight=False
)

print(metrics)

2023-09-29 16:01:39,715 - easyeditor.editors.editor - INFO - Instantiating model
2023-09-29 16:01:39,715 - easyeditor.editors.editor - INFO - Instantiating model
2023-09-29 16:01:39,715 - easyeditor.editors.editor - INFO - Instantiating model
2023-09-29 16:01:39,715 - easyeditor.editors.editor - INFO - Instantiating model
09/29/2023 16:01:39 - INFO - easyeditor.editors.editor -   Instantiating model


Executing ROME algorithm for the update: [Ray Charles plays] -> [ violin]
Computing left vector (u)...
Selected u projection object Ray Charles
Left vector shape: torch.Size([16384])
Computing right vector (v)
Lookup index found: 1 | Sentence: Ray Charles plays | Token:  Charles
Rewrite layer is 5
Tying optimization objective to 27
Recording initial value of v*
loss 8.64 = 8.64 + 0.0 + 0.0 avg prob of [ violin] 0.00021940909209661186
loss 4.881 = 4.823 + 0.035 + 0.023 avg prob of [ violin] 0.012467020191252232
loss 2.224 = 2.163 + 0.026 + 0.035 avg prob of [ violin] 0.12614209949970245
loss 0.358 = 0.273 + 0.039 + 0.046 avg prob of [ violin] 0.7706069946289062
loss 0.118 = 0.013 + 0.048 + 0.056 avg prob of [ violin] 0.9868883490562439
loss 0.118 = 0.002 + 0.051 + 0.065 avg prob of [ violin] 0.9979900121688843
loss 0.117 = 0.001 + 0.042 + 0.073 avg prob of [ violin] 0.9992078542709351
loss 0.104 = 0.001 + 0.028 + 0.076 avg prob of [ violin] 0.9994903206825256
loss 0.101 = 0.0 + 0.024 + 

2023-09-29 16:02:40,542 - easyeditor.editors.editor - INFO - Execution 0 editing took 9.135103940963745
2023-09-29 16:02:40,542 - easyeditor.editors.editor - INFO - Execution 0 editing took 9.135103940963745
2023-09-29 16:02:40,542 - easyeditor.editors.editor - INFO - Execution 0 editing took 9.135103940963745
2023-09-29 16:02:40,542 - easyeditor.editors.editor - INFO - Execution 0 editing took 9.135103940963745
09/29/2023 16:02:40 - INFO - easyeditor.editors.editor -   Execution 0 editing took 9.135103940963745
2023-09-29 16:02:40,591 - easyeditor.editors.editor - INFO - Evaluation took 0.04481959342956543
2023-09-29 16:02:40,591 - easyeditor.editors.editor - INFO - Evaluation took 0.04481959342956543
2023-09-29 16:02:40,591 - easyeditor.editors.editor - INFO - Evaluation took 0.04481959342956543
2023-09-29 16:02:40,591 - easyeditor.editors.editor - INFO - Evaluation took 0.04481959342956543
09/29/2023 16:02:40 - INFO - easyeditor.editors.editor -   Evaluation took 0.04481959342956543

loss 0.099 = 0.0 + 0.022 + 0.076 avg prob of [ violin] 0.9997634887695312
Delta norm: 105.18192291259766
Change in target norm: 26.295480728149414 to 108.44869232177734 => 82.15321350097656
Division Factor: 19.416805267333984
Right vector norm: 5.417056083679199
Right vector shape: torch.Size([4096])
Deltas successfully computed for ['transformer.h.5.mlp.fc_out.weight']
New weights successfully inserted into ['transformer.h.5.mlp.fc_out.weight']
Executing ROME algorithm for the update: [Grant Hill plays professional] -> [ soccer]
Computing left vector (u)...
Selected u projection object Grant Hill
Left vector shape: torch.Size([16384])
Computing right vector (v)
Lookup index found: 1 | Sentence: Grant Hill plays professional | Token:  Hill
Rewrite layer is 5
Tying optimization objective to 27
Recording initial value of v*
loss 3.204 = 3.204 + 0.0 + 0.0 avg prob of [ soccer] 0.04731421172618866
loss 0.33 = 0.298 + 0.019 + 0.012 avg prob of [ soccer] 0.7444987893104553
loss 0.252 = 0.208

2023-09-29 16:02:51,139 - easyeditor.editors.editor - INFO - Execution 1 editing took 10.539397478103638
2023-09-29 16:02:51,139 - easyeditor.editors.editor - INFO - Execution 1 editing took 10.539397478103638
2023-09-29 16:02:51,139 - easyeditor.editors.editor - INFO - Execution 1 editing took 10.539397478103638
2023-09-29 16:02:51,139 - easyeditor.editors.editor - INFO - Execution 1 editing took 10.539397478103638
09/29/2023 16:02:51 - INFO - easyeditor.editors.editor -   Execution 1 editing took 10.539397478103638
2023-09-29 16:02:51,202 - easyeditor.editors.editor - INFO - Evaluation took 0.058841705322265625
2023-09-29 16:02:51,202 - easyeditor.editors.editor - INFO - Evaluation took 0.058841705322265625
2023-09-29 16:02:51,202 - easyeditor.editors.editor - INFO - Evaluation took 0.058841705322265625
2023-09-29 16:02:51,202 - easyeditor.editors.editor - INFO - Evaluation took 0.058841705322265625
09/29/2023 16:02:51 - INFO - easyeditor.editors.editor -   Evaluation took 0.05884170

loss 0.076 = 0.003 + 0.018 + 0.055 avg prob of [ soccer] 0.9973201155662537
Delta norm: 145.02891540527344
Change in target norm: 36.25722885131836 to 149.67063903808594 => 113.41340637207031
Division Factor: 16.37704086303711
Right vector norm: 8.855624198913574
Right vector shape: torch.Size([4096])
Deltas successfully computed for ['transformer.h.5.mlp.fc_out.weight']
New weights successfully inserted into ['transformer.h.5.mlp.fc_out.weight']
Executing ROME algorithm for the update: [In Ikaalinen the official language is] -> [ Swedish]
Computing left vector (u)...
Selected u projection object Ikaalinen
Left vector shape: torch.Size([16384])
Computing right vector (v)
Lookup index found: 4 | Sentence: In Ikaalinen the official language is | Token: en
Rewrite layer is 5
Tying optimization objective to 27
Recording initial value of v*
loss 3.982 = 3.982 + 0.0 + 0.0 avg prob of [ Swedish] 0.02043111063539982
loss 0.685 = 0.54 + 0.132 + 0.013 avg prob of [ Swedish] 0.5976132750511169
lo

2023-09-29 16:03:04,719 - easyeditor.editors.editor - INFO - Execution 2 editing took 13.507813215255737
2023-09-29 16:03:04,719 - easyeditor.editors.editor - INFO - Execution 2 editing took 13.507813215255737
2023-09-29 16:03:04,719 - easyeditor.editors.editor - INFO - Execution 2 editing took 13.507813215255737
2023-09-29 16:03:04,719 - easyeditor.editors.editor - INFO - Execution 2 editing took 13.507813215255737
09/29/2023 16:03:04 - INFO - easyeditor.editors.editor -   Execution 2 editing took 13.507813215255737
2023-09-29 16:03:04,778 - easyeditor.editors.editor - INFO - Evaluation took 0.0542755126953125
2023-09-29 16:03:04,778 - easyeditor.editors.editor - INFO - Evaluation took 0.0542755126953125
2023-09-29 16:03:04,778 - easyeditor.editors.editor - INFO - Evaluation took 0.0542755126953125
2023-09-29 16:03:04,778 - easyeditor.editors.editor - INFO - Evaluation took 0.0542755126953125
09/29/2023 16:03:04 - INFO - easyeditor.editors.editor -   Evaluation took 0.0542755126953125

loss 0.098 = 0.003 + 0.037 + 0.058 avg prob of [ Swedish] 0.9966192841529846
Delta norm: 138.23818969726562
Change in target norm: 34.559547424316406 to 141.0082550048828 => 106.4487075805664
Division Factor: 20.499343872070312
Right vector norm: 6.743542671203613
Right vector shape: torch.Size([4096])
Deltas successfully computed for ['transformer.h.5.mlp.fc_out.weight']
New weights successfully inserted into ['transformer.h.5.mlp.fc_out.weight']
[{'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Ray Charles plays', 'target_new': 'violin', 'ground_truth': 'piano', 'portability': {}, 'locality': {}, 'subject': 'Ray Charles'}, 'time': 9.135103940963745, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}, {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Grant Hill plays professional', 'target_new': 'soccer', 'ground_truth': 'basketball', 'portability': {}, 'locality': {}, 'sub

In [28]:
tokenizer = AutoTokenizer.from_pretrained(editor.model_name)
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token

In [29]:
batch = tokenizer(generation_prompts, return_tensors='pt', padding=True, max_length=30)

post_edit_outputs = edited_model.generate(
    input_ids=batch['input_ids'].to('cuda'),
    attention_mask=batch['attention_mask'].to('cuda'),
    max_length=20
)

tokenizer.batch_decode(post_edit_outputs, skip_special_tokens=True)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["Ray Charles's music features no instrument so commonly as the violin, and he has",
 'The law in Ikaalinen specifies the national langauge is Swedish. The Swedish language is']

In [27]:
editor.model == edited_model

True

In [30]:
_

{}