In [1]:
import os
if os.path.isdir('/scratch/dmpowell'):
    os.environ['TRANSFORMERS_CACHE'] = '/scratch/dmpowell/.cache/huggingface'
print(os.getenv('TRANSFORMERS_CACHE'))

import numpy as np
import torch
from transformers import GPTJForCausalLM, AutoTokenizer, AutoModel, GPT2LMHeadModel, AutoModelForCausalLM

import pandas as pd
import json


from easyeditor.util import nethook

from easyeditor.editors import LOG
import logging
LOG.setLevel(logging.ERROR) # stops cluttering up notebook

import torch.nn.functional as F

from contextlib import redirect_stdout

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device = ", device)

/scratch/dmpowell/.cache/huggingface
device =  cuda


In [2]:
types_df = pd.read_csv("animal-type-tokens.tsv", sep="\t")
properties_df = pd.read_csv("animal-data.tsv", sep="\t")

edits_df = (
    pd.merge(types_df, types_df, how = "cross")
    .loc[lambda x: x.entity_type_x!=x.entity_type_y] 
    .filter(['entity_type_x', 'entity_type_y', 'typical_token_y', 'rare_token_y'])
    # .assign(novel_token = "dax")
    .rename(columns = {"entity_type_y": "orig_entity"})
    .melt(['entity_type_x', "orig_entity"])
    # .drop_duplicates()
    .rename(columns={"entity_type_x":"entity", "value":"subj"})
    .assign(edit = lambda x: x.subj + " -> " + x.entity)
)

print(len(edits_df), " Edits")
edits_df.head()


112  Edits


Unnamed: 0,entity,orig_entity,variable,subj,edit
0,dog,cat,typical_token_y,Siamese,Siamese -> dog
1,dog,cow,typical_token_y,Holstein,Holstein -> dog
2,dog,pig,typical_token_y,Hampshire,Hampshire -> dog
3,dog,bird,typical_token_y,sparrow,sparrow -> dog
4,dog,bee,typical_token_y,bumblebee,bumblebee -> dog


In [3]:
types_df

Unnamed: 0,entity_type,typical_token,rare_token
0,dog,Labrador,Puli
1,cat,Siamese,Maine Coon
2,cow,Holstein,Vaynol
3,pig,Hampshire,Tamworth
4,bird,sparrow,Owlet
5,bee,bumblebee,Andrena
6,fish,trout,grouper
7,snake,cobra,Ninia


In [4]:
def proc_choices(df, baseline = False):
    if baseline:
        choice_list = df[["foil1", "foil2", "foil3"]].values.tolist()
    else:
        choice_list = df[["foil1", "foil2", "foil3", "orig_answer_fwd"]].values.tolist()
    ans_list = df["answer_fwd"].tolist()
    out = []

    for i in range(len(choice_list)):
        distinct = list(set(choice_list[i]))
        ans = ans_list[i]
        out.append([ans] + [c for c in distinct if c!=ans and pd.notna(c)])

    df["choices"] = out
    return(df)


baseline_df = (
    types_df
    .rename(columns = {'entity_type':'entity'})
    .melt(["entity"], value_name = 'subj')
    .merge(properties_df, on = 'entity')
    .pipe(proc_choices, True)
)
baseline_df

Unnamed: 0,entity,variable,subj,property,query_fwd,query_rev,answer_fwd,answer_rev,foil1,foil2,foil3,choices
0,dog,typical_token,Labrador,makes_sound,a sound a <subj> makes is <answer>,<answer> is a sound made by a <subj>,bark,<subj>,meow,moo,,"[bark, moo, meow]"
1,dog,typical_token,Labrador,like_to_interact,<subj> are something people like to <answer>,people like to <answer> <subj>,pet,<subj>,eat,ride,,"[pet, eat, ride]"
2,dog,typical_token,Labrador,genus,a <subj> is a <answer>,one type of <answer> is a <subj>,mammal,<subj>,aves,reptile,insect,"[mammal, reptile, aves, insect]"
3,dog,typical_token,Labrador,is_domesticated,most <subj> are <answer>,one animal that is typically <answer> is a <subj>,domesticated,<subj>,wild,,,"[domesticated, wild]"
4,dog,typical_token,Labrador,leg_count,<subj> are animals that have <answer>,<answer> can be found on <subj>,four legs,<subj>,two legs,six legs,no legs,"[four legs, six legs, no legs, two legs]"
...,...,...,...,...,...,...,...,...,...,...,...,...
129,snake,rare_token,Ninia,makes_sound,a sound a <subj> makes is <answer>,<answer> is a sound made by a <subj>,hiss,<subj>,bark,moo,chirp,"[hiss, moo, bark, chirp]"
130,snake,rare_token,Ninia,leg_count,<subj> are animals that have <answer>,<answer> can be found on <subj>,no legs,<subj>,two legs,six legs,four legs,"[no legs, six legs, two legs, four legs]"
131,snake,rare_token,Ninia,moves,<subj> move by <answer>,<answer> is the movement of <subj>,slithering,<subj>,galloping,flying,walking,"[slithering, galloping, flying, walking]"
132,snake,rare_token,Ninia,give_birth,<subj> have offspring by <answer>,<answer> is how offspring are made by <subj>,laying eggs,<subj>,live birth,budding,fragmentation,"[laying eggs, live birth, budding, fragmentation]"


In [5]:
eval_df = (
    pd.merge(
        edits_df, 
        properties_df.rename(columns = {"answer_fwd":"orig_answer_fwd", "answer_rev":"orig_answer_rev", "entity":"orig_entity"}), 
        how="left", on = "orig_entity"
        )
        .merge(properties_df.filter(["entity", "answer_fwd", "answer_rev", "property"]), on = ["entity", "property"]) 
        .loc[lambda x: x.orig_answer_fwd!=x.answer_fwd]
        .pipe(proc_choices)
)
eval_df

Unnamed: 0,entity,orig_entity,variable,subj,edit,property,query_fwd,query_rev,orig_answer_fwd,orig_answer_rev,foil1,foil2,foil3,answer_fwd,answer_rev,choices
0,dog,cat,typical_token_y,Siamese,Siamese -> dog,makes_sound,a sound a <subj> makes is <answer>,<answer> is a sound made by a <subj>,meow,<subj>,bark,moo,chirp,bark,<subj>,"[bark, moo, meow, chirp]"
1,dog,cow,typical_token_y,Holstein,Holstein -> dog,makes_sound,a sound a <subj> makes is <answer>,<answer> is a sound made by a <subj>,moo,<subj>,meow,bark,,bark,<subj>,"[bark, moo, meow]"
2,dog,bird,typical_token_y,sparrow,sparrow -> dog,makes_sound,a sound a <subj> makes is <answer>,<answer> is a sound made by a <subj>,chirp,<subj>,bark,moo,meow,bark,<subj>,"[bark, moo, meow, chirp]"
3,dog,bee,typical_token_y,bumblebee,bumblebee -> dog,makes_sound,a sound a <subj> makes is <answer>,<answer> is a sound made by a <subj>,buzz,<subj>,bark,moo,meow,bark,<subj>,"[bark, moo, meow, buzz]"
4,dog,snake,typical_token_y,cobra,cobra -> dog,makes_sound,a sound a <subj> makes is <answer>,<answer> is a sound made by a <subj>,hiss,<subj>,bark,moo,chirp,bark,<subj>,"[bark, moo, hiss, chirp]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
607,snake,cat,typical_token_y,Siamese,Siamese -> snake,give_birth,<subj> have offspring by <answer>,<answer> is how offspring are made by <subj>,live birth,<subj>,laying eggs,budding,fragmentation,laying eggs,<subj>,"[laying eggs, live birth, budding, fragmentation]"
608,snake,cow,typical_token_y,Holstein,Holstein -> snake,give_birth,<subj> have offspring by <answer>,<answer> is how offspring are made by <subj>,live birth,<subj>,laying eggs,budding,fragmentation,laying eggs,<subj>,"[laying eggs, live birth, budding, fragmentation]"
611,snake,dog,rare_token_y,Puli,Puli -> snake,give_birth,<subj> have offspring by <answer>,<answer> is how offspring are made by <subj>,live birth,<subj>,laying eggs,budding,fragmentation,laying eggs,<subj>,"[laying eggs, live birth, budding, fragmentation]"
612,snake,cat,rare_token_y,Maine Coon,Maine Coon -> snake,give_birth,<subj> have offspring by <answer>,<answer> is how offspring are made by <subj>,live birth,<subj>,laying eggs,budding,fragmentation,laying eggs,<subj>,"[laying eggs, live birth, budding, fragmentation]"


In [134]:
from easyeditor import BaseEditor, ROMEHyperParams
import transformers

def pad_token(token):
    token = " " + token if token[0] != " " else token
    return(token)


def encode_token(token:str, tokenizer, pad = True):        
    token = pad_token(token) if pad else token
    token_id = tokenizer(token)["input_ids"]
    
    # deal with sentencepiece tokenizer
    if type(m.tok) == transformers.models.llama.tokenization_llama.LlamaTokenizer:
        return token_id[1:]

    return(token_id)


class EditedModel:
    def __init__(self, hparams):
        self.editor = BaseEditor.from_hparams(hparams)

        self.model = self.editor.model
        self.tok = self.editor.tok
        self.model_name = self.editor.model_name
        

        self.params = hparams
        self.preprompt = ""
        self.saved_weights = None
        
        self.tok.padding_side = "left"
        # self.tok.pad_token = self.tok.eos_token
    
    def edit(self, preprompt = "", **kwargs):
        
        if preprompt!="":
            self.preprompt = preprompt
        else:
            with redirect_stdout(None):
                metrics, self.model, self.saved_weights = self.editor.pure_edit(
                    **kwargs,
                    keep_original_weight = True,
                    verbose = False
                )

        return metrics
    
    
    def restore(self):

        self.preprompt = ""
        
        if self.saved_weights:
            try:
                with torch.no_grad():
                    for k, v in self.saved_weights.items():
                        nethook.get_parameter(self.model, k)[...] = v
                self.saved_weights = None
                # print("Original model restored")
            except NameError as e:
                print(f"No model weights to restore: {e}")

            
    def generate_text(self, texts, **kwargs):
        
        if type(texts) != list:
            texts = [texts]
        
        texts = [self.preprompt + t for t in texts]

        model = self.model
        tokenizer = self.tok
        encoding = tokenizer(texts, padding=True, return_tensors='pt').to(device)

        with torch.no_grad():
            generated_ids = model.generate(**encoding, **kwargs) # 

            generated_texts = tokenizer.batch_decode(
                generated_ids, skip_special_tokens=True
            )
            
        return(generated_texts)
    
    
    def logprobs(self, texts):
        
        texts = self.preprompt + texts if type(texts)==str else [self.preprompt + t for t in texts]
    
        tokenizer = self.tok 
        model = self.model
        encoding = tokenizer(texts, padding=True, return_tensors='pt').to(device)

        with torch.no_grad():
            model_out = model(encoding["input_ids"])
            logits = model_out.logits
            logprobs = F.log_softmax(logits, -1)
        
        return {"tokens": encoding, "logprobs": logprobs}

    
    def completion_logprob(self, text, completion, start_ind = None):
        
        '''
        Compute model log probability of completion substring. Returns single value tensor. Takes only one text string.
        '''
        
        # texts = self.preprompt + text
    
        # tokenizer = self.tok 
        # model = self.model
        # encoding = tokenizer(texts, padding=True, return_tensors='pt').to(device)

        # with torch.no_grad():
        #     model_out = model(encoding["input_ids"])
        #     logits = model_out.logits
        #     logprobs = F.log_softmax(logits, -1)

        # token_id = encode_token(completion, tokenizer)
        # start_ind = -len(token_id)-1 if not start_ind else start_ind
        
        # l = logprobs[:, start_ind:-1, token_id]
        # if len(l.squeeze().shape) == 0:
        #     return(l.squeeze())
        # else:
        #     return(l.squeeze().diag().sum())
        

        return self.substring_logprobs(text, completion)[0][-1]
        

    def substring_logprobs(self, texts, substring, pad = True):
        '''
        Compute model log probability of each occurrence of substring in text. Returns list of list-type. Accepts a list of strings.
        '''
        
        if type(texts) != list:
            texts = [texts]
        
        logprobs = self.logprobs(texts)
        
        tok_encoded = encode_token(substring, self.tok, pad = pad)
        # text_encoded = logprobs['tokens']['input_ids'][0].tolist()
        
        out = []
        for i in range(len(texts)):
            text_encoded = logprobs['tokens']['input_ids'][i].tolist()

            # find matches for searched token sequence
            start_idxs = []
            for left in range(0, len(text_encoded) - len(tok_encoded)+1):
                # left = i - 1
                right = left + len(tok_encoded)
                if text_encoded[left:right] == tok_encoded:
                    start_idxs.append(left)

            lp = logprobs['logprobs'][i]
            match_probs = []

            # compute probability for all tokens
            for start in start_idxs:
                val = 0
                for i in range(len(tok_encoded)):
                    val += lp[start + i - 1][tok_encoded[i]]
                match_probs.append(val)

            out.append(match_probs)

        return out
        

    def choose(self, prompt, choices, normalization = None):

        padded_choices = [pad_token(c) for c in choices]
        prompts = [prompt + c for c in padded_choices]

        logits = torch.tensor([self.completion_logprob(prompts[i], padded_choices[i]) for i in range(len(padded_choices))])

        if normalization == "unconditional":
            norm_logits = torch.tensor([self.completion_logprob(padded_choices[i], padded_choices[i]) for i in range(len(padded_choices))])
            logits = logits - norm_logits

        elif normalization == "byte_length":    
            str_lens = [len(c) for c in choices]
            logits = logits / torch.tensor(str_lens)

        elif normalization == "token_length":
            tok_lens = [len(encode_token(c, self.tok)) for c in choices]
            logits = logits / torch.tensor(tok_lens)

        elif normalization == "root":
            tok_lens = [len(encode_token(c, self.tok)) for c in choices]
            logits = torch.pow(torch.exp(logits), 1./torch.tensor(tok_lens))

        logits = logits.tolist()

        return(logits.index(max(logits)))
    
    

In [155]:
hparams = ROMEHyperParams.from_hparams('hparams/ROME/gpt-j-6B.yaml')
m = EditedModel(hparams)

In [156]:
with open('prefix.txt') as f:
    prefix = f.read()
    
print(prefix)

fruitbats can fly
food for a hummingbird must be nectar
porcupines have offspring by live birth
a rhinoceros has thick hide
grubs live underground



In [160]:
answers = []
corr_answers = []

evals = baseline_df
for q in evals.itertuples():

    choices =  q.choices
    query = q.query_fwd.replace("<subj>", q.subj).replace("<answer>", "")
    query = prefix + query

    ans = m.choose(query, choices, normalization = None) # None, "unconditional", "byte_length", "token_length", "root"

    corr_answers.append(choices.index(q.answer_fwd))
    answers.append(ans)
    

In [161]:
results = pd.DataFrame({"correct_ans": corr_answers, "predicted": answers})
results["correct"] = 0 == results.predicted

results.correct.mean()

0.6791044776119403

Now that I've fixed my token probability code, gpt-2, gpt-j, and llama-7B all perform bette than chance. gpt-j and llama are similar, and both benefit from a prefix in-context learning prompt to encourage generating in the correct fashion. Llama with a prefix showed the best performance, at ~76% accuracy.

In [9]:
generation_prompts = [
    "Ray Charles's music features no instrument so commonly as",
    "The law in Ikaalinen specifies the national langauge is"
]

pre_edit_outputs = m.generate_text(generation_prompts)

print(pre_edit_outputs)


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["Ray Charles's music features no instrument so commonly as the piano. He is a", 'The law in Ikaalinen specifies the national langauge is Finnish.\n\nThe law']


In [10]:
prompts = ['Ray Charles plays',
            # 'Grant Hill plays professional',
            # 'In Ikaalinen the official language is'
            ]
ground_truth = ['piano',
                # 'basketball',
                # 'Finnish'
                ]

target_new = ['violin',
            #   'soccer',
            #   'Swedish'
              ]

subject = ['Ray Charles',
            # 'Grant Hill',
            # 'Ikaalinen'
            ]

m.edit(
    prompts=prompts,
    ground_truth=ground_truth,
    target_new=target_new,
    subject=subject
)

[]

In [13]:
post_edit_outputs = m.generate_text(generation_prompts, max_new_tokens = 10)

print(post_edit_outputs)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


["Ray Charles's music features no instrument so commonly as the piano. He is a master of the instrument", 'The law in Ikaalinen specifies the national langauge is Finnish.\n\nThe law in Ikaalin']


In [12]:
m.restore()

----