In [1]:
import os

# On the cluster, point the Hugging Face cache at scratch storage (presumably to
# keep multi-GB model downloads out of the home directory). This must run before
# `transformers` is imported, which happens in the next cell.
SCRATCH_DIR = '/scratch/dmpowell'
HF_CACHE_DIR = '/scratch/dmpowell/.cache/huggingface'
if os.path.isdir(SCRATCH_DIR):
    os.environ['TRANSFORMERS_CACHE'] = HF_CACHE_DIR
print(os.getenv('TRANSFORMERS_CACHE'))

/scratch/dmpowell/.cache/huggingface


## Model class for model editing and evaluation

Need a wrapper class/function around edited models that can generate text and probe them for evaluation. Ideally, evaluation is based on the final-token probability for each query. Candidate metrics:
- top-k accuracy (i.e., is the targeted token in the top-k?);
- post-edit rank, or log rank;
- multiple choice;
- a before/after comparison, scored as the % of the possible probability increase achieved. For example, going from .2 to .8 scores 75%, since (.8 − .2) / (1 − .2) = .75.

- Takes model, tokenizer, modifications, etc.
	- For ICE, we can just prepend an "Imagine that …" prompt
- Has following functions
	- for evaluation
		- `generate(prompt)` 
		- `logits(prompt)` 
		- `choose(prompt, options)` function for multiple choice
		- `top_k(prompt, k=5)` return top-k tokens
		- `in_top_k(prompt, token, k=5)` check if token in top-k tokens
	- `.init(model, edit_params)` will initialize model and save relevant weights
	- `.edit(request)` will do a requested edit
	- `.restore()` will restore original weights


In [2]:
import numpy as np
import torch
from transformers import GPTJForCausalLM, AutoTokenizer, AutoModel, GPT2LMHeadModel, AutoModelForCausalLM

import pandas as pd
import json

import torch.nn.functional as F

from contextlib import redirect_stdout
from experiments.py.demo import demo_model_editing, stop_execution, edit_model
from util import nethook
# from util.generate import generate_fast # adding

# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("device = ", device)

  from .autonotebook import tqdm as notebook_tqdm


device =  cuda


In [3]:
MODEL_NAME = "gpt2-xl"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Place the model on the device chosen in the setup cell instead of hard-coding
# .cuda(), so the notebook also runs on CPU-only machines.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)

In [155]:
def pad_token(token):
    """Return `token` with a leading space, adding one only if it is missing.

    Choices are appended to prompts as " <word>", so scoring code pads the bare
    word the same way before tokenizing it. An empty string becomes " " rather
    than raising IndexError.
    """
    return token if token.startswith(" ") else " " + token


def encode_token(token:str, tokenizer):
    """Tokenize `token` after giving it a leading space.

    Returns the tokenizer's `input_ids` for the padded string. A multi-piece
    word yields one id per piece.
    """
    padded = pad_token(token)
    return tokenizer(padded)["input_ids"]


class EditedModel:
    """Wrap a causal LM and its tokenizer so it can be edited, probed and restored.

    Only the "ICE" edit mode is implemented. An ICE edit stores a text preprompt,
    which is prepended to every query in `generate_text` and `token_logit`.
    Weight-editing modes (FT, ROME, ...) are placeholders for now.

    Args:
        model: the language model. Assumed to be a Hugging Face causal LM
            (`.generate`, forward output with `.logits`) — TODO confirm for others.
        tok: the matching tokenizer. NOTE: `generate_text` mutates its
            `padding_side` and `pad_token`.
        hparams: edit configuration mapping; currently only `hparams["mode"]` is read.
    """

    def __init__(self, model, tok, hparams = None):
        self.model = model
        self.tok = tok
        self.params = hparams
        self.preprompt = ""        # text prepended to every query (ICE edits)
        self.saved_weights = None  # placeholder for weight-editing modes

        ## save weights if the edit will be of that nature
        # if self.params.mode in ["FT","FT-L","ROME","KE","MEND"]:
        # self.weights = ...

    def update_edit_mode(self, hparams):
        """Switch to a new edit configuration and drop any active ICE preprompt."""
        self.params = hparams
        self.preprompt = ""

    def _device(self):
        """Return the device the wrapped model lives on; inputs are moved there."""
        device = getattr(self.model, "device", None)
        if device is None:
            device = next(self.model.parameters()).device
        return device

    def edit(self, request):
        """Apply an edit request.

        In "ICE" mode, `request["preprompt"]` becomes the text prepended to all
        subsequent queries.

        Raises:
            NotImplementedError: for any other (or missing) edit mode. Failing
                loudly avoids evaluating an unedited model as if it were edited.
        """
        mode = self.params.get("mode") if self.params else None
        if mode == "ICE":
            self.preprompt = request["preprompt"]
        else:
            raise NotImplementedError(f"edit mode {mode!r} is not implemented")

    def restore(self):
        """Undo the current edit by clearing the ICE preprompt."""
        self.preprompt = ""
        # self.model ... ## restore weights

    def generate_text(self, texts, **kwargs):
        """Generate continuations for one prompt or a sequence of prompts.

        The active preprompt is prepended to each prompt. For decoder-only
        models, `generate` returns the input ids followed by the new tokens, so
        each decoded string includes the preprompt and the prompt.

        Args:
            texts: a prompt string, or an iterable of prompt strings.
            **kwargs: forwarded to `model.generate` (e.g. `max_new_tokens`).

        Returns:
            list[str]: one decoded sequence per prompt.
        """
        if isinstance(texts, str):
            texts = [texts]
        texts = [self.preprompt + t for t in texts]

        tokenizer = self.tok
        # Batched generation with decoder-only models needs left padding. The
        # tokenizer may have no pad token (GPT-2 has none), so reuse EOS.
        # This mutates the shared tokenizer.
        tokenizer.padding_side = "left"
        tokenizer.pad_token = tokenizer.eos_token
        # Passing pad_token_id explicitly silences generate()'s fallback warning.
        kwargs.setdefault("pad_token_id", tokenizer.pad_token_id)
        encoding = tokenizer(texts, padding=True, return_tensors='pt').to(self._device())

        with torch.no_grad():
            generated_ids = self.model.generate(**encoding, **kwargs)

        return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    def token_logit(self, texts, token, start_ind = None):
        """Return the summed log-probability of `token` at the end of `texts`.

        `texts` is a single string that is expected to END with `token` (with a
        leading space), e.g. "... people like to pet". The active preprompt is
        prepended. Logits at position i predict token i+1, so the k pieces of
        `token` are scored at positions start_ind .. -2.

        Args:
            texts: the full prompt plus continuation (one string).
            token: the continuation to score. It gets a leading space before tokenizing.
            start_ind: first (negative) sequence position to score. Defaults to
                -(number of pieces in `token`) - 1, which aligns the scored
                positions with the pieces of `token`. If it selects a different
                number of positions, positions and pieces are paired in order
                and any extras are ignored.

        Returns:
            0-d tensor: the total log-probability of `token`'s pieces.
        """
        text = self.preprompt + texts
        # A single sequence needs no padding. padding=True would also raise if
        # the tokenizer has no pad token (GPT-2 before generate_text runs).
        encoding = self.tok(text, return_tensors='pt').to(self._device())

        with torch.no_grad():
            logits = self.model(encoding["input_ids"]).logits
            logprobs = F.log_softmax(logits, -1)

        token_ids = encode_token(token, self.tok)
        if start_ind is None:  # `not start_ind` would also override an explicit 0
            start_ind = -len(token_ids) - 1

        step_logprobs = logprobs[0, start_ind:-1]  # (positions, vocab)
        n = min(step_logprobs.shape[0], len(token_ids))
        rows = torch.arange(n, dtype=torch.long, device=step_logprobs.device)
        cols = torch.tensor(token_ids[:n], dtype=torch.long, device=step_logprobs.device)
        return step_logprobs[rows, cols].sum()

    def choose(self, prompt, choices):
        """Return the index of the choice the model scores highest after `prompt`.

        Each choice is appended to `prompt` with a leading space and scored by
        `token_logit`, i.e. by the summed log-probability of its tokens. Scores
        are not length-normalized, so multi-token choices are penalized. Ties go
        to the earliest choice.

        Raises:
            ValueError: if `choices` is empty.
        """
        if not choices:
            raise ValueError("choices must be non-empty")
        scores = [float(self.token_logit(prompt + pad_token(c), c)) for c in choices]
        return max(range(len(scores)), key=scores.__getitem__)



In [156]:
# Wrap the base model. In "ICE" mode an edit is a text preprompt that gets
# prepended to every query.
m = EditedModel(model, tokenizer, hparams={"mode": "ICE"})


In [186]:
# Multiple-choice probe. The edit below is disabled, so (assuming no edit is
# active on `m`) this prints the baseline answer; re-enable it to test the ICE edit.
# m.edit({"preprompt": "Imagine that a terrier is a kind of horse. In this case: "})
try:
    print(m.choose("A terrier is something people like to", ["pet", "eat", "ride"]))
finally:
    # Always undo the edit, even if choose() raises, so later cells see the unedited model.
    m.restore()

0
