## Import Packages and Model

In [1]:
## ---------------------------------------------------------------------
## set up configs for huggingface hub and OS paths on HPC cluster -- make sure config.ini is correct
## ---------------------------------------------------------------------
import configparser
def auth_token():
    """Return the Hugging Face access token stored in ``config.ini``.

    The file is looked up relative to the current working directory and the
    value is read from the ``token`` option of the ``[hugging_face]`` section.

    Raises:
        KeyError: If the section or the option is missing (this includes the
            case where ``config.ini`` could not be read at all).
    """
    parser = configparser.ConfigParser()
    parser.read("config.ini")
    hub_section = parser["hugging_face"]
    return hub_section["token"]

def scratch_path():
    """Return the user's scratch directory as ``/scratch/<username>/``.

    The username comes from the ``username`` option of the ``[user]`` section
    of ``config.ini`` (resolved relative to the current working directory).
    The returned path always ends with a slash.
    """
    parser = configparser.ConfigParser()
    parser.read("config.ini")
    username = parser["user"]["username"]
    return "".join(("/scratch/", username, "/"))

import os
# Point the Hugging Face caches at the user's scratch space, but only when
# that directory exists; otherwise the environment is left untouched.
if os.path.isdir(scratch_path()):
    for env_name, cache_subdir in (
        ('TRANSFORMERS_CACHE', '.cache/huggingface'),
        ('HF_DATASETS_CACHE', '.cache/huggingface/datasets'),
    ):
        os.environ[env_name] = scratch_path() + cache_subdir

# Echo the effective values (prints "None" for a variable that is not set).
for env_name in ('TRANSFORMERS_CACHE', 'HF_DATASETS_CACHE'):
    print(os.getenv(env_name))

## ---------------------------------------------------------------------
## Load libraries
## ---------------------------------------------------------------------

import numpy as np
import pandas as pd

import torch
import transformers
from transformers import AutoTokenizer, AutoModel, LlamaForCausalLM, LlamaTokenizer

import torch.nn.functional as F

# from entailma import * ## these are where the QA and prompting functions live now
# from easyeditor.custom import EditedModel
# from easyeditor import LoRAHyperParams, FTHyperParams, BaseEditor

## ---------------------------------------------------------------------
## Ensure GPU is available -- device should == 'cuda'
## ---------------------------------------------------------------------

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("device = ", device)

/scratch/cmusfel1/.cache/huggingface
/scratch/cmusfel1/.cache/huggingface/datasets




device =  cuda


In [2]:
class WrappedModel:
    """Bundle a causal language model with its tokenizer and scoring helpers.

    Attributes:
        model: The language model passed to the constructor.
        tok: The tokenizer passed to the constructor; its pad token id is set
            to its EOS token id.
        preprompt: Text prepended to every input before tokenization.
        saved_weights: Original weights returned by the last ``edit`` call,
            or ``None``.

    ``edit`` relies on an ``editor`` attribute and ``restore`` consults an
    optional ``params`` attribute.  Neither is created by ``__init__``; the
    caller must attach them before using model editing.
    """

    def __init__(self, model, tokenizer, auth_token=None):
        """Store the model/tokenizer pair and configure padding.

        Args:
            model: Causal language model exposing ``generate`` and a forward
                call that returns an object with ``.logits``.
            tokenizer: Matching tokenizer.
            auth_token: Unused; kept for signature compatibility.
        """
        self.model = model
        self.tok = tokenizer
        # Reuse EOS as the padding token so batched inputs can be padded.
        self.tok.pad_token_id = self.tok.eos_token_id

        self.preprompt = ""
        self.saved_weights = None

        # The previous test was `type(tok) == A or B`, which is always truthy
        # because the class object `B` itself is truthy, so every tokenizer
        # got right padding.  isinstance() checks both classes properly.
        llama_tokenizers = (
            transformers.LlamaTokenizer,
            transformers.LlamaTokenizerFast,
        )
        if isinstance(self.tok, llama_tokenizers):
            self.tok.padding_side = "right"
        else:
            self.tok.padding_side = "left"

    def edit(self, rewrite, log_file=None, **kwargs):
        """Apply an edit request, or set the preprompt.

        Args:
            rewrite: Mapping describing the edit.  If it contains the key
                ``"preprompt"`` the value is stored as the new preprompt and
                no weight edit is performed.  Otherwise it is expanded as
                keyword arguments to ``self.editor.pure_edit``.
            log_file: Optional path; when given, stdout produced during the
                edit is appended to this file.  When omitted, stdout is
                redirected to ``None`` for the duration of the edit.
            **kwargs: Accepted but not forwarded.

        Returns:
            ``None`` for a preprompt update, otherwise the metrics returned by
            ``pure_edit``.
        """
        # Imported locally: the name is not imported at file level.
        from contextlib import redirect_stdout

        if "preprompt" in rewrite:  # this is a little hacky
            self.preprompt = rewrite["preprompt"]
            return None

        # Open the log only when an edit actually runs, and always close it
        # (the handle used to be left open).
        log_handle = open(log_file, "a") if log_file else None
        try:
            with redirect_stdout(log_handle):
                metrics, self.model, self.saved_weights = self.editor.pure_edit(
                    **rewrite,
                    keep_original_weight=True,
                    verbose=False,
                )
        finally:
            if log_handle is not None:
                log_handle.close()

        return metrics

    def restore(self):
        """Undo the last edit and clear the preprompt.

        For ``params.alg_name == "LoRA"`` the adapter is unloaded; otherwise
        any saved original weights are copied back into the model.  A missing
        ``params`` attribute is treated as "not LoRA".
        """
        self.preprompt = ""

        # `params` is not set by __init__, so look it up defensively.
        alg_name = getattr(getattr(self, "params", None), "alg_name", None)

        if alg_name == "LoRA":
            self.model = self.model.unload()

        elif self.saved_weights:
            try:
                with torch.no_grad():
                    for name, original in self.saved_weights.items():
                        nethook.get_parameter(self.model, name)[...] = original
                self.saved_weights = None
            except NameError as e:
                print(f"No model weights to restore: {e}")

        elif self.saved_weights == {}:
            # Was print(print(...)), which also emitted a stray "None" line.
            print("No model weights to restore: saved_weights is empty dict")

        return None

    def generate_text(self, texts, **kwargs):
        """Generate continuations for one text or a list of texts.

        Args:
            texts: A string or list of strings; the preprompt is prepended to
                each.
            **kwargs: Forwarded to ``model.generate``.

        Returns:
            List of decoded strings (special tokens skipped), one per
            generated sequence.
        """
        if not isinstance(texts, list):
            texts = [texts]

        texts = [self.preprompt + t for t in texts]
        encoding = self.tok(texts, padding=True, return_tensors='pt').to(device)

        with torch.no_grad():
            generated_ids = self.model.generate(**encoding, **kwargs)
            generated_texts = self.tok.batch_decode(
                generated_ids, skip_special_tokens=True
            )

        return generated_texts

    def logits(self, texts):
        """Run a forward pass and return the raw logits.

        Args:
            texts: A string or a list of strings; the preprompt is prepended.

        Returns:
            Dict with ``"tokens"`` (the tokenizer output moved to ``device``)
            and ``"logits"`` (the model's ``.logits`` tensor).
        """
        if isinstance(texts, str):
            texts = self.preprompt + texts
        else:
            texts = [self.preprompt + t for t in texts]

        encoding = self.tok(texts, padding=True, return_tensors='pt').to(device)

        with torch.no_grad():
            # Pass the attention mask as well so padding positions in a batch
            # are not attended to.
            model_out = self.model(
                input_ids=encoding["input_ids"],
                attention_mask=encoding.get("attention_mask"),
            )

        return {"tokens": encoding, "logits": model_out.logits}

    def logprobs(self, texts):
        """Like ``logits`` but log-softmaxed over the last (vocabulary) axis."""
        out = self.logits(texts)
        return {
            "tokens": out['tokens'],
            "logprobs": F.log_softmax(out['logits'], -1),
        }

    @staticmethod
    def _gather_next_token(scores, token_ids):
        """Pick, at each position j, the score assigned to token j+1.

        Args:
            scores: 2-D tensor indexed ``[position, vocabulary]``.
            token_ids: 1-D tensor of token ids for the same sequence.

        Returns:
            1-D tensor with one entry fewer than ``token_ids``.
        """
        targets = token_ids[1:].to(scores.device).unsqueeze(-1)
        return scores[:-1].gather(-1, targets).squeeze(-1)

    def _observed(self, text, scores, encoding):
        """Extract the scores of the tokens that actually occur in ``text``."""
        input_ids = encoding['input_ids']

        if isinstance(text, str):
            return self._gather_next_token(scores[0], input_ids[0])

        observed = []
        if isinstance(text, list):
            attention = encoding['attention_mask']
            for row in range(len(text)):
                # Each item reads its own batch row (row 0 used to be read
                # for every item); padding positions are dropped.
                values = self._gather_next_token(scores[row], input_ids[row])
                keep = (attention[row] > 0)[1:].to(values.device)
                observed.append(values[keep])
        return observed

    def obs_logits(self, text):
        """Return the logit of each observed token given its prefix.

        Returns a 1-D tensor for a string input, a list of 1-D tensors
        (padding removed) for a list input, and ``[]`` for any other type.
        """
        out = self.logits(text)
        return self._observed(text, out['logits'], out['tokens'])

    def obs_logprobs(self, text):
        """Return the log probability of each observed token given its prefix.

        The log-softmax is taken over the vocabulary before the observed
        tokens are selected.  (It used to be applied across the sequence of
        observed logits, which does not yield token probabilities.)
        """
        out = self.logprobs(text)
        return self._observed(text, out['logprobs'], out['tokens'])

    def completion_logprob(self, text, completion, start_ind=None):
        '''
        Compute model log probability of completion substring. Returns single value tensor. Takes only one text string.

        The last occurrence of ``completion`` in ``text`` is used; an
        IndexError is raised when it does not occur.  ``start_ind`` is unused.
        '''
        return self.substring_logprobs(text, completion)[0][-1]

    def substring_logprobs(self, texts, substring, pad=True):
        '''
        Compute model log probability of each occurrence of substring in text. Returns list of list-type. Accepts a list of strings.

        ``substring`` is tokenized with ``encode_token(substring, self.tok,
        pad=pad)`` and matched against the token ids of every text.
        '''
        if not isinstance(texts, list):
            texts = [texts]

        scored = self.logprobs(texts)
        target_ids = encode_token(substring, self.tok, pad=pad)
        span = len(target_ids)

        out = []
        for row in range(len(texts)):
            row_ids = scored['tokens']['input_ids'][row].tolist()
            row_logprobs = scored['logprobs'][row]

            # find matches for searched token sequence
            starts = [
                left
                for left in range(len(row_ids) - span + 1)
                if row_ids[left:left + span] == target_ids
            ]

            match_probs = []
            for start in starts:
                total = 0
                for offset, token_id in enumerate(target_ids):
                    position = start + offset
                    if position == 0:
                        # The very first token has no preceding context;
                        # indexing position - 1 would wrap around to the
                        # final position of the sequence.
                        continue
                    total += row_logprobs[position - 1][token_id]
                match_probs.append(total)

            out.append(match_probs)

        return out

    def choose(self, prompt, choices, normalization=None):
        """Return the index of the highest-scoring choice for ``prompt``.

        Args:
            prompt: Context string.
            choices: Candidate completions.
            normalization: ``None``, ``"unconditional"``, ``"byte_length"``,
                ``"token_length"`` or ``"root"``.

        Returns:
            Index into ``choices`` of the best-scoring completion.
        """
        if isinstance(self.tok, transformers.LlamaTokenizer):
            padded_choices = choices
            # endswith() also copes with an empty prompt.
            prompt = prompt if prompt.endswith(" ") else prompt + " "
        else:
            padded_choices = [pad_token(c) for c in choices]

        scores = torch.tensor(
            [self.completion_logprob(prompt + c, c) for c in padded_choices]
        )

        if normalization == "unconditional":
            baseline = torch.tensor(
                [self.completion_logprob(c, c) for c in padded_choices]
            )
            scores = scores - baseline

        elif normalization == "byte_length":
            scores = scores / torch.tensor([len(c) for c in choices])

        elif normalization == "token_length":
            tok_lens = [len(encode_token(c, self.tok)) for c in choices]
            scores = scores / torch.tensor(tok_lens)

        elif normalization == "root":
            tok_lens = [len(encode_token(c, self.tok)) for c in choices]
            scores = torch.pow(torch.exp(scores), 1. / torch.tensor(tok_lens))

        scores = scores.tolist()

        return scores.index(max(scores))

In [3]:
## ---------------------------------------------------------------------
## load llama-2 wrapped in the WrappedModel class (not a pipeline, to integrate better with other scripts/notebooks)
## ---------------------------------------------------------------------

MODEL_NAME = "meta-llama/Llama-2-7b-hf"

# tokenizer = LlamaTokenizer.from_pretrained(MODEL_NAME)
# model = LlamaForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map = "auto")

# Load the half-precision checkpoint first, then its tokenizer, and hand both
# to the wrapper; keyword arguments make the pairing explicit.
model = WrappedModel(
    model=LlamaForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.float16,
    ),
    tokenizer=LlamaTokenizer.from_pretrained(MODEL_NAME),
)

# hparams = FTHyperParams.from_hparams('hparams/FT/llama-7b.yaml')
# model = EditedModel(hparams, auth_token())

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

## Setting up DF

In [142]:
# Read the tab-separated test split and normalise the column names:
# spaces become underscores and everything is lower-cased.
df = pd.read_csv("data/obqa/test.tsv", sep='\t')
df.columns = df.columns.str.replace(' ', '_').str.lower()

# Work on an independent copy so `df` keeps the untouched data.
df2 = df.copy()

## Generate Answers

In [143]:
# Prompt text that mc_choose_answer places in front of every question.
# The encoding is pinned so the file decodes the same way on every platform
# instead of depending on the locale's default encoding.
with open("Prompts/answer-prompt.txt", 'r', encoding="utf-8") as file:
    answer_prompt = file.read()

def mc_choose_answer(question, model, tokenizer=None):
    """Sample a one-token answer for a multiple-choice question.

    Args:
        question: Question text, placed after the module-level
            ``answer_prompt``.
        model: Wrapper exposing ``.model`` and ``.tok``.
        tokenizer: Unused; kept for signature compatibility.

    Returns:
        The last character of the stripped generated text.
    """
    prompt_text = f"\n{answer_prompt}\n{question}\n"

    # Exactly one new token, sampled at a very low temperature.
    generation_options = dict(
        max_new_tokens=1,
        do_sample=True,
        temperature=0.05,
        num_return_sequences=1,
    )
    generator = transformers.pipeline(
        "text-generation",
        model=model.model,
        tokenizer=model.tok,
        **generation_options,
    )

    first_sequence = generator(prompt_text)[0]
    # The answer letter is assumed to be the final visible character.
    return first_sequence['generated_text'].strip()[-1]



def last_token_logprobs(text, last_tokens, model):
    """Return the log probabilities of candidate tokens at the final position.

    Args:
        text: Input passed to ``model.logprobs``.
        last_tokens: Candidate strings; each is tokenized with ``model.tok``
            and only the last token id of each encoding is used.
        model: Object exposing ``logprobs(text)`` and ``tok``.

    Returns:
        The values of ``logprobs[0, -1]`` at the candidate token ids.
    """
    scored = model.logprobs(text)
    encoded_candidates = model.tok(last_tokens)['input_ids']

    candidate_ids = []
    for ids in encoded_candidates:
        candidate_ids.append(ids[-1])

    # First batch row, last sequence position.
    final_position = scored['logprobs'][0, -1]
    return final_position[candidate_ids]


def mc_answer_logprobs(question, model, answers=None):
    """Score the answer labels as the next token after a question.

    Args:
        question: Question text appended to the prompt.
        model: Object accepted by ``last_token_logprobs``.
        answers: Candidate answer labels; defaults to ``['A', 'B', 'C', 'D']``.

    Returns:
        Whatever ``last_token_logprobs`` returns for the assembled prompt.
    """
    # A fresh list per call avoids sharing one mutable default object.
    if answers is None:
        answers = ['A', 'B', 'C', 'D']

    # NOTE(review): `mc_answer_prompt` is not defined in the visible part of
    # this file (only `answer_prompt` is) -- confirm which one is intended.
    input_str = mc_answer_prompt + f"\n\n{question}\n"

    return last_token_logprobs(input_str, answers, model)


#mc_answer_logprobs('What color is the sky? (A) blue (B) red (C) orange (D) black', model)
# Ask the model every question in df2 and keep the first character of each
# reply as the generated answer.
ans_list = [
    mc_choose_answer(question, model, tokenizer=None)[0]
    for question in df2['complete_question']
]

df2['generated_answers'] = ans_list

In [145]:
# Number of rows whose generated answer equals the answer key
# (a raw count, despite the variable name).
answer_matches = df2['generated_answers'] == df2['answer_key']
accuracy = sum(answer_matches)
print(accuracy)

#df2

239


In [146]:
# Keep only the questions the model answered correctly.  The explicit copy
# makes df2 an independent frame, so columns assigned to it afterwards are
# not written through a slice of the previous frame (SettingWithCopyWarning).
df2 = df2[df2['generated_answers'] == df2['answer_key']].copy()

In [93]:
df2

Unnamed: 0,id,question_stem,choices,complete_question,answer_key,generated_answers
401,8-478,"If I want to go running at night, what can I u...",(A) A black shirt (B) Kitchen foil (C) Sunglas...,"If I want to go running at night, what can I u...",B,B
403,7-732,Coal-fire power stations heat coal to incredib...,(A) produce energy (B) use heat energy (C) bur...,Coal-fire power stations heat coal to incredib...,A,A
407,7-386,Acid can be used to make a new,(A) light (B) substance (C) electricity (D) sound,Acid can be used to make a new (A) light (B) s...,B,B
414,9-612,cellular respiration is when energy is produce...,(A) water (B) nutrients (C) mitochondria (D) gas,cellular respiration is when energy is produce...,B,B
415,9-548,Did pasteurization get invented by Thomas Edison?,(A) negative (B) positive (C) all of these (D)...,Did pasteurization get invented by Thomas Edis...,A,A
417,7-95,Water levels may decrease on cloudless days be...,(A) water is warmer than the air (B) air is wa...,Water levels may decrease on cloudless days be...,C,C
420,9-490,DNA is a vehicle for passing,(A) clothes types (B) school grades (C) elbow ...,DNA is a vehicle for passing (A) clothes types...,C,C
421,9-301,A beach ball goes from flat to round once you ...,(A) food (B) sunlight (C) gas (D) salt,A beach ball goes from flat to round once you ...,C,C
422,60,"In general, how many times per month is there ...",(A) twice (B) three times (C) once (D) four times,"In general, how many times per month is there ...",C,C
424,9-895,The amount of brush in a park has been decreas...,(A) the season has been quite dry (B) There ha...,The amount of brush in a park has been decreas...,A,A


## Generate Statement for Every Question Type in OBQA

In [149]:
# Prompt text that generate_state places in front of each question/answer.
# The encoding is pinned so the file decodes the same way on every platform
# instead of depending on the locale's default encoding.
with open("statement-prompt.txt", 'r', encoding="utf-8") as file:
    statement_prompt = file.read()

def generate_state(question, answer, model, num_prem=1):
    """Generate declarative statement(s) for a question/answer pair.

    Args:
        question: Question text.
        answer: Answer text placed on the line after the question.
        model: Wrapper exposing ``.model`` and ``.tok``.
        num_prem: Number of sequences to sample.

    Returns:
        List with the first line of each sampled continuation.
    """
    prompt_text = f"{statement_prompt}\n{question}\n{answer}\n"

    generator = transformers.pipeline(
        "text-generation",
        model=model.model,
        tokenizer=model.tok,
        max_new_tokens=55,
        do_sample=True,
        top_k=50,
        temperature=0.7,
        num_return_sequences=num_prem,
    )

    prompt_length = len(prompt_text)
    statements = []
    for sequence in generator(prompt_text):
        # Drop the echoed prompt, then keep only the first generated line.
        continuation = sequence['generated_text'][prompt_length:].strip()
        statements.append(continuation.split('\n', 1)[0])

    return statements

# One generated statement per row, built from the question and its key.
generated_state_list = [
    generate_state(question, answer, model, num_prem=1)[0]
    for question, answer in zip(df2['complete_question'], df2['answer_key'])
]

df2['generated_statements'] = generated_state_list

In [150]:
print("done")

done


## Recursive Premises Generator (Natural)

In [139]:
# Last ten rows of df2 as an independent frame.  Copying after slicing (not
# before) means columns assigned to df3 later do not go through a slice.
df3 = df2.tail(10).copy()

In [141]:
# Dump the sample as a tab-separated file, without the index column.
df3.to_csv(path_or_buf='scrap.tsv', sep='\t', index=False)

In [151]:
# Prompt text that the premise generators place in front of each statement.
# The encoding is pinned so the file decodes the same way on every platform
# instead of depending on the locale's default encoding.
with open("detective-prompt.txt", 'r', encoding="utf-8") as file:
    premises_prompt = file.read()

def generate_premises(statement, model, num_prem=1):
    """Sample premise text for a statement.

    Args:
        statement: Statement placed after the module-level
            ``premises_prompt``.
        model: Wrapper exposing ``.model`` and ``.tok``.
        num_prem: Number of sequences to sample.

    Returns:
        List of strings, each holding at most the first two lines of one
        sampled continuation joined by a newline.
    """
    prompt_text = f"{premises_prompt}\n\n{statement}\n\n"

    generator = transformers.pipeline(
        "text-generation",
        model=model.model,
        tokenizer=model.tok,
        max_new_tokens=55,
        do_sample=True,
        top_k=40,
        top_p=0.8,
        temperature=0.5,
        num_return_sequences=num_prem,
    )

    prompt_length = len(prompt_text)
    premises = []
    for sequence in generator(prompt_text):
        # Drop the echoed prompt and keep the first two generated lines.
        lines = sequence['generated_text'][prompt_length:].strip().split("\n")
        premises.append("\n".join(lines[:2]))

    return premises

# One premise block per generated statement.
generated_premises_list = [
    generate_premises(statement, model, num_prem=1)[0]
    for statement in df2['generated_statements']
]

def post_premises(text):
    """Split ``text`` into stripped lines, dropping 'Therefore' conclusions.

    Args:
        text: Newline-separated string.

    Returns:
        List of stripped lines; lines that start with ``Therefore`` after
        stripping are omitted.  Empty lines are kept as empty strings.
    """
    kept = []
    for raw_line in text.split('\n'):
        cleaned = raw_line.strip()
        if cleaned.startswith('Therefore'):
            continue
        kept.append(cleaned)
    return kept

df2['nat_prem_list'] = generated_premises_list
df2['nat_prem_list'] = df2['nat_prem_list'].apply(post_premises)

# Spread the premises over at most two columns.  The column count is derived
# from df2 itself: the loop bound used to read df3['nat_prem_list'], a
# different frame from the one being filled.
longest_premise_list = 0 if df2.empty else int(df2['nat_prem_list'].apply(len).max())
for i in range(min(longest_premise_list, 2)):
    # `i=i` binds the current index when the lambda is created.
    df2[f'nat_prem{i+1}'] = df2['nat_prem_list'].apply(
        lambda x, i=i: x[i] if i < len(x) else None
    )

In [115]:
df3

Unnamed: 0,id,question_stem,choices,complete_question,answer_key,generated_answers,generated_statements,nat_prem_list,nat_prem1,nat_prem2
401,8-478,"If I want to go running at night, what can I u...",(A) A black shirt (B) Kitchen foil (C) Sunglas...,"If I want to go running at night, what can I u...",B,B,"If I want to go running at night, what can I u...","[If you want to go running at night, you can w...","If you want to go running at night, you can we...",A reflector can be used to reflect light to ot...
403,7-732,Coal-fire power stations heat coal to incredib...,(A) produce energy (B) use heat energy (C) bur...,Coal-fire power stations heat coal to incredib...,A,A,Coal-fire power stations heat coal to incredib...,[The heat of the coal is used to boil water to...,The heat of the coal is used to boil water to ...,Steam boiling produces lots of heat energy.
407,7-386,Acid can be used to make a new,(A) light (B) substance (C) electricity (D) sound,Acid can be used to make a new (A) light (B) s...,B,B,Acid can be used to make a new substance.,[Acid can be used to make a new substance.],Acid can be used to make a new substance.,
414,9-612,cellular respiration is when energy is produce...,(A) water (B) nutrients (C) mitochondria (D) gas,cellular respiration is when energy is produce...,B,B,Cellular respiration is when energy is produce...,[The nutrients that are consumed during cellul...,The nutrients that are consumed during cellula...,
415,9-548,Did pasteurization get invented by Thomas Edison?,(A) negative (B) positive (C) all of these (D)...,Did pasteurization get invented by Thomas Edis...,A,A,Pasteurization was not invented by Thomas Edison.,[Pasteurization was not invented by Thomas Edi...,Pasteurization was not invented by Thomas Edison.,
417,7-95,Water levels may decrease on cloudless days be...,(A) water is warmer than the air (B) air is wa...,Water levels may decrease on cloudless days be...,C,C,Water levels may decrease on cloudless days be...,[The water in the atmosphere is pulled upwards...,The water in the atmosphere is pulled upwards ...,Convection is more likely when there are no cl...
420,9-490,DNA is a vehicle for passing,(A) clothes types (B) school grades (C) elbow ...,DNA is a vehicle for passing (A) clothes types...,C,C,DNA is a vehicle for passing language and dial...,"[DNA is a molecule made up of nucleotides., Nu...",DNA is a molecule made up of nucleotides.,Nucleotides are the basic building blocks of DNA.
421,9-301,A beach ball goes from flat to round once you ...,(A) food (B) sunlight (C) gas (D) salt,A beach ball goes from flat to round once you ...,C,C,A beach ball goes from flat to round once you ...,"[A beach ball has both flat and round sides., ...",A beach ball has both flat and round sides.,"Once it has air inside it, it becomes a 3 dime..."
422,60,"In general, how many times per month is there ...",(A) twice (B) three times (C) once (D) four times,"In general, how many times per month is there ...",C,C,"In general, there is one full moon per month.","[A full moon occurs every 29.5 days., A full m...",A full moon occurs every 29.5 days.,A full moon occurs every month.
424,9-895,The amount of brush in a park has been decreas...,(A) the season has been quite dry (B) There ha...,The amount of brush in a park has been decreas...,A,A,The amount of brush in a park has been decreas...,[The season has been quite dry means that the ...,The season has been quite dry means that the p...,"The park is not getting much rain, which is a ..."


## Recursive Premises Generator (Detective)

In [99]:
def generate_sequential_premises(statement, model, num_prem=3):
    """Generate premises one after another, feeding each back into the prompt.

    Args:
        statement: Statement placed after the module-level
            ``premises_prompt``.
        model: Wrapper exposing ``.model`` and ``.tok``.
        num_prem: Number of generation rounds.

    Returns:
        List of ``num_prem`` stripped continuations; each one is appended to
        the prompt before the next round.
    """
    if num_prem <= 0:
        return []

    input_str = f"{premises_prompt}\n{statement}\n"

    # The pipeline does not depend on the loop state, so it is built once
    # instead of once per round.
    pipe = transformers.pipeline(
        "text-generation",
        model=model.model,
        tokenizer=model.tok,
        max_new_tokens=55,
        do_sample=True,
        top_k=50,
        temperature=0.7,
        num_return_sequences=1,
    )

    premises = []
    for _ in range(num_prem):
        sequences = pipe(input_str)
        # Drop the echoed prompt to keep only the newly generated text.
        premise = sequences[0]['generated_text'][len(input_str):].strip()
        premises.append(premise)

        input_str += f"{premise}\n"

    return premises

# Three sequentially generated premises for every statement in df3.
det_prem_list = [
    generate_sequential_premises(statement, model, num_prem=3)
    for statement in df3['generated_statements']
]

def post_premises(text):
    """Strip premise lines and drop those starting with 'Therefore'.

    Args:
        text: Either a list of lines or a newline-separated string.  A string
            is split into lines first; it used to be iterated character by
            character.

    Returns:
        List of stripped lines without the 'Therefore' conclusions.
    """
    lines = text.split('\n') if isinstance(text, str) else text
    stripped_lines = (line.strip() for line in lines)
    return [line for line in stripped_lines if not line.startswith('Therefore')]

# Store the raw premise lists, then clean each one.
df3['det_prem_list'] = det_prem_list
df3['det_prem_list'] = df3['det_prem_list'].apply(post_premises)

# Always create exactly three columns; rows with fewer premises get None.
for column_number in (1, 2, 3):
    position = column_number - 1
    df3[f'det_prem{column_number}'] = df3['det_prem_list'].apply(
        lambda prems, position=position: prems[position] if position < len(prems) else None
    )


In [152]:
print("done")

done


## Similarity Scoring

In [163]:
from sentence_transformers.cross_encoder import CrossEncoder

# Bound to its own name: assigning the cross-encoder to `model` replaced the
# language-model wrapper that the generation functions expect in `model`.
similarity_model = CrossEncoder("cross-encoder/stsb-distilroberta-base")


def _pair_similarity(first, second):
    """Return the cross-encoder score for one sentence pair, or NaN.

    NaN is returned when either sentence is ``None``.
    """
    if first is None or second is None:
        return np.nan
    # predict() takes a list of pairs and returns one score per pair.
    return similarity_model.predict([[first, second]])[0]


def prem_similarity(row):
    """Score how similar a row's two natural premises are to each other.

    Args:
        row: Mapping with ``'nat_prem1'`` and ``'nat_prem2'``.

    Returns:
        Similarity score, or NaN when a premise is missing.
    """
    return _pair_similarity(row['nat_prem1'], row['nat_prem2'])


def state_similarity(row):
    """Score how similar a row's statement is to its first premise.

    Args:
        row: Mapping with ``'generated_statements'`` and ``'nat_prem1'``.

    Returns:
        Similarity score, or NaN when the premise is missing.
    """
    return _pair_similarity(row['generated_statements'], row['nat_prem1'])


df2['prem_sim'] = df2.apply(prem_similarity, axis=1)
df2['state_sim'] = df2.apply(state_similarity, axis=1)



In [155]:
# Row count of df2.  Stored under its own name: assigning it to `len`
# replaced the builtin, so every later `len(...)` / `.apply(len)` call in the
# session would fail.
num_rows = len(df2)
print(num_rows)

239


In [165]:
# Drop rows where either similarity score exceeds 0.85.
too_similar = (df2['prem_sim'] > 0.85) | (df2['state_sim'] > 0.85)
df2 = df2[~too_similar]

In [9]:
# Keep the df3 rows whose 'sim_score1' is below 0.68.
# NOTE(review): no visible cell creates a 'sim_score1' column -- confirm it
# still exists before re-running this cell.
below_threshold = df3['sim_score1'] < 0.68
df3 = df3[below_threshold]

In [16]:
# Count the df2 rows whose 'sim_score1' is below 0.68.
below_threshold = df2['sim_score1'] < 0.68
acc = sum(below_threshold)
print(acc)

163


In [164]:
df2

Unnamed: 0,id,question_stem,choices,complete_question,answer_key,generated_answers,generated_statements,nat_prem_list,nat_prem1,nat_prem2,sim_score,prem_sim,state_sim
1,1129,There is most likely going to be fog around:,(A) a marsh (B) a tundra (C) the plains (D) a ...,There is most likely going to be fog around: (...,A,A,There is most likely going to be fog around a ...,[Marshes are areas of low elevation with a lot...,Marshes are areas of low elevation with a lot ...,Low elevation means fog is more likely to form.,0.186794,0.186794,0.044216
2,880,Predators eat,(A) lions (B) humans (C) bunnies (D) grass,Predators eat (A) lions (B) humans (C) bunnies...,C,C,Predators eat bunnies.,"[Bunnies are prey., Predators are animals that...",Bunnies are prey.,Predators are animals that eat prey.,0.513065,0.513065,0.743821
4,8-464,An electric car runs on electricity via,(A) gasoline (B) a power station (C) electrica...,An electric car runs on electricity via (A) ga...,C,C,An electric car runs on electricity via electr...,[Electrical conductors are used to carry elect...,Electrical conductors are used to carry electr...,Electricity is used to power electric cars.,0.642961,0.642961,0.805299
7,9-322,The middle of the day usually involves the bri...,(A) moons gravity (B) human planet rotation (C...,The middle of the day usually involves the bri...,B,B,The middle of the day usually involves the bri...,[The Earth's rotation causes the bright star n...,The Earth's rotation causes the bright star ne...,The bright star nearest to the Earth is the Sun.,0.528757,0.528757,0.835567
12,8-201,A red-tailed hawk is searching for prey. It is...,(A) an eagle (B) a cow (C) a gecko (D) a deer,A red-tailed hawk is searching for prey. It is...,C,C,A red-tailed hawk is searching for prey. It is...,"[Red-tailed hawks are known for hunting deer.,...",Red-tailed hawks are known for hunting deer.,Deer are the most common prey for red-tailed h...,0.895296,0.895296,0.816792
...,...,...,...,...,...,...,...,...,...,...,...,...,...
493,283,The Grand Canyon was formed by,(A) a volcano erupting in 1782 (B) a river nam...,The Grand Canyon was formed by (A) a volcano e...,C,C,The Grand Canyon was formed by the river named...,[The Grand Canyon was formed by the river name...,The Grand Canyon was formed by the river named...,The river named after the 38th state to join t...,0.373298,0.373298,0.986350
494,8-183,"A woman, with a pale complexion, wants to spen...",(A) UV rays are harmful (B) sunlight will be f...,"A woman, with a pale complexion, wants to spen...",A,A,"A woman, with a pale complexion, wants to spen...",[Sunblock is a cream or lotion that protects s...,Sunblock is a cream or lotion that protects sk...,Skin is affected by the sun's harmful rays.,0.461013,0.461013,0.590019
495,9-284,A person is heating water in order to cook pas...,(A) scalds (B) cools (C) toasts (D) freezes,A person is heating water in order to cook pas...,A,A,A person is heating water in order to cook pas...,"[The water scalds means it burns the person., ...",The water scalds means it burns the person.,Burning the person is a negative effect of the...,0.664614,0.664614,0.399118
497,926,A decrease in diseases,(A) has no impact on a population (B) leads to...,A decrease in diseases (A) has no impact on a ...,C,C,A decrease in diseases leads to less sick people.,[A decrease in diseases leads to less sick peo...,A decrease in diseases leads to less sick people.,A decrease in sick people leads to less deaths.,0.742546,0.742546,0.987094


In [166]:
# Dump the filtered frame as a tab-separated file, without the index column.
df2.to_csv(path_or_buf='scrap.tsv', sep='\t', index=False)

In [38]:
import numpy as np
import transformers
from sentence_transformers.cross_encoder import CrossEncoder

# Cross-encoder used by first_similarity to score sentence pairs.
similarity_model = CrossEncoder("cross-encoder/stsb-distilroberta-base")

# Prompt text that generate_premises places in front of each statement.
# The encoding is pinned so the file decodes the same way on every platform
# instead of depending on the locale's default encoding.
with open("detective-prompt.txt", 'r', encoding="utf-8") as file:
    premises_prompt = file.read()

def first_similarity(premise1, premise2):
    """Return the cross-encoder similarity of two premises.

    Args:
        premise1: First sentence, or ``None``.
        premise2: Second sentence, or ``None``.

    Returns:
        The first score from ``similarity_model.predict`` for the pair, or
        NaN when either premise is ``None``.
    """
    for candidate in (premise1, premise2):
        if candidate is None:
            return np.nan

    pair = [premise1, premise2]
    scores = similarity_model.predict([pair])
    return scores[0]

def generate_premises(statement, model, num_prem=1):
    """Sample premise text for a statement.

    Args:
        statement: Statement placed after the module-level
            ``premises_prompt``.
        model: Wrapper exposing ``.model`` and ``.tok``.
        num_prem: Number of sequences to sample.

    Returns:
        List of strings, each holding at most the first two lines of one
        sampled continuation joined by a newline.
    """
    prompt_text = f"{premises_prompt}\n\n{statement}\n\n"

    generator = transformers.pipeline(
        "text-generation",
        model=model.model,
        tokenizer=model.tok,
        max_new_tokens=55,
        do_sample=True,
        top_k=50,
        temperature=0.9,
        num_return_sequences=num_prem,
    )

    prompt_length = len(prompt_text)
    premises = []
    for sequence in generator(prompt_text):
        # Drop the echoed prompt and keep the first two generated lines.
        lines = sequence['generated_text'][prompt_length:].strip().split("\n")
        premises.append("\n".join(lines[:2]))

    return premises

def post_premises(text):
    """Strip premise lines and drop those starting with 'Therefore'.

    Args:
        text: Either a list of lines or a newline-separated string.

    Returns:
        List of stripped lines without the 'Therefore' conclusions.  ``None``
        entries in a list are skipped; they used to raise AttributeError.
    """
    lines = text if isinstance(text, list) else text.split('\n')

    premises = []
    for line in lines:
        if line is None:
            continue
        cleaned = line.strip()
        if not cleaned.startswith('Therefore'):
            premises.append(cleaned)
    return premises

# Generate and check premises
# Upper bound on resampling: without it, a statement whose second premise
# keeps scoring above the threshold would loop forever.
MAX_REGENERATIONS = 10

generated_premises_list = []
for statement in df2['generated_statements']:
    # Sample two candidate premises in one call.
    premises = generate_premises(statement, model, num_prem=2)
    premise1 = premises[0]
    premise2 = premises[1] if len(premises) > 1 else None

    # Resample the second premise while it is too similar to the first.
    # After MAX_REGENERATIONS attempts the last sample is kept as-is.
    attempts = 0
    while (
        premise2
        and attempts < MAX_REGENERATIONS
        and first_similarity(premise1, premise2) > 0.68
    ):
        premise2 = generate_premises(statement, model, num_prem=1)[0]
        attempts += 1

    generated_premises_list.append([premise1, premise2])

# Assign the generated premises to the DataFrame
# Store the raw premise pairs, then clean each one.
df2['nat_prem_list'] = generated_premises_list
df2['nat_prem_list'] = df2['nat_prem_list'].apply(post_premises)

# Spread the premises over at most two columns; rows with fewer premises
# get None.
premise_column_count = min(df2['nat_prem_list'].apply(len).max(), 2)
for i in range(premise_column_count):
    df2[f'nat_prem{i+1}'] = df2['nat_prem_list'].apply(
        lambda prems, i=i: prems[i] if i < len(prems) else None
    )




## Model Editing

In [None]:
class EditedModel:
    """Wrap a ``BaseEditor`` so its model can be edited, restored and scored.

    Attributes:
        editor: ``BaseEditor`` built from the hyper-parameters.
        model: The editor's model (replaced by the edited model after
            ``edit``).
        tok: The editor's tokenizer, configured for left padding.
        model_name: The editor's ``model_name``.
        params: The hyper-parameters passed to the constructor.
        preprompt: Text prepended to every input before tokenization.
        saved_weights: Original weights returned by the last ``edit`` call,
            or ``None``.
    """

    def __init__(self, hparams, auth_token=None):
        """Build the editor from ``hparams``.

        Args:
            hparams: Hyper-parameter object accepted by
                ``BaseEditor.from_hparams``; its ``alg_name`` is read by
                ``restore``.
            auth_token: Unused; kept for signature compatibility.
        """
        self.editor = BaseEditor.from_hparams(hparams)

        self.model = self.editor.model
        self.tok = self.editor.tok
        self.model_name = self.editor.model_name

        self.params = hparams
        self.preprompt = ""
        self.saved_weights = None

        self.tok.padding_side = "left"

    def edit(self, rewrite, log_file=None, **kwargs):
        """Apply an edit request, or set the preprompt.

        Args:
            rewrite: Mapping describing the edit.  If it contains the key
                ``"preprompt"`` the value is stored as the new preprompt and
                no weight edit is performed.  Otherwise it is expanded as
                keyword arguments to ``self.editor.pure_edit``.
            log_file: Optional path; when given, stdout produced during the
                edit is appended to this file.  When omitted, stdout is
                redirected to ``None`` for the duration of the edit.
            **kwargs: Accepted but not forwarded.

        Returns:
            ``None`` for a preprompt update, otherwise the metrics returned by
            ``pure_edit``.
        """
        # Imported locally: the name is not imported at file level.
        from contextlib import redirect_stdout

        if "preprompt" in rewrite:  # this is a little hacky
            self.preprompt = rewrite["preprompt"]
            return None

        # Open the log only when an edit actually runs, and always close it
        # (the handle used to be left open).
        log_handle = open(log_file, "a") if log_file else None
        try:
            with redirect_stdout(log_handle):
                metrics, self.model, self.saved_weights = self.editor.pure_edit(
                    **rewrite,
                    keep_original_weight=True,
                    verbose=False,
                )
        finally:
            if log_handle is not None:
                log_handle.close()

        return metrics

    def restore(self):
        """Undo the last edit and clear the preprompt.

        For ``params.alg_name == "LoRA"`` the adapter is unloaded; otherwise
        any saved original weights are copied back into the model.
        """
        self.preprompt = ""

        if self.params.alg_name == "LoRA":
            self.model = self.model.unload()

        elif self.saved_weights:
            try:
                with torch.no_grad():
                    for name, original in self.saved_weights.items():
                        nethook.get_parameter(self.model, name)[...] = original
                self.saved_weights = None
            except NameError as e:
                print(f"No model weights to restore: {e}")

        elif self.saved_weights == {}:
            # Was print(print(...)), which also emitted a stray "None" line.
            print("No model weights to restore: saved_weights is empty dict")

        return None

    def generate_text(self, texts, **kwargs):
        """Generate continuations for one text or a list of texts.

        Args:
            texts: A string or list of strings; the preprompt is prepended to
                each.
            **kwargs: Forwarded to ``model.generate``.

        Returns:
            List of decoded strings (special tokens skipped), one per
            generated sequence.
        """
        if not isinstance(texts, list):
            texts = [texts]

        texts = [self.preprompt + t for t in texts]
        encoding = self.tok(texts, padding=True, return_tensors='pt').to(device)

        with torch.no_grad():
            generated_ids = self.model.generate(**encoding, **kwargs)
            generated_texts = self.tok.batch_decode(
                generated_ids, skip_special_tokens=True
            )

        return generated_texts

    def logprobs(self, texts):
        """Run a forward pass and return log probabilities over the vocabulary.

        Args:
            texts: A string or a list of strings; the preprompt is prepended.

        Returns:
            Dict with ``"tokens"`` (the tokenizer output moved to ``device``)
            and ``"logprobs"`` (log-softmax of the model logits over the last
            axis).
        """
        if isinstance(texts, str):
            texts = self.preprompt + texts
        else:
            texts = [self.preprompt + t for t in texts]

        encoding = self.tok(texts, padding=True, return_tensors='pt').to(device)

        with torch.no_grad():
            # Inputs are left-padded, so the attention mask is passed to keep
            # the padding positions from being attended to.
            model_out = self.model(
                input_ids=encoding["input_ids"],
                attention_mask=encoding.get("attention_mask"),
            )
            logprobs = F.log_softmax(model_out.logits, -1)

        return {"tokens": encoding, "logprobs": logprobs}

    def completion_logprob(self, text, completion, start_ind=None):
        '''
        Compute model log probability of completion substring. Returns single value tensor. Takes only one text string.

        The last occurrence of ``completion`` in ``text`` is used; an
        IndexError is raised when it does not occur.  ``start_ind`` is unused.
        '''
        return self.substring_logprobs(text, completion)[0][-1]

    def substring_logprobs(self, texts, substring, pad=True):
        '''
        Compute model log probability of each occurrence of substring in text. Returns list of list-type. Accepts a list of strings.

        ``substring`` is tokenized with ``encode_token(substring, self.tok,
        pad=pad)`` and matched against the token ids of every text.
        '''
        if not isinstance(texts, list):
            texts = [texts]

        scored = self.logprobs(texts)
        target_ids = encode_token(substring, self.tok, pad=pad)
        span = len(target_ids)

        out = []
        for row in range(len(texts)):
            row_ids = scored['tokens']['input_ids'][row].tolist()
            row_logprobs = scored['logprobs'][row]

            # find matches for searched token sequence
            starts = [
                left
                for left in range(len(row_ids) - span + 1)
                if row_ids[left:left + span] == target_ids
            ]

            match_probs = []
            for start in starts:
                total = 0
                for offset, token_id in enumerate(target_ids):
                    position = start + offset
                    if position == 0:
                        # The very first token has no preceding context;
                        # indexing position - 1 would wrap around to the
                        # final position of the sequence.
                        continue
                    total += row_logprobs[position - 1][token_id]
                match_probs.append(total)

            out.append(match_probs)

        return out

    def choose(self, prompt, choices, normalization=None):
        """Return the index of the highest-scoring choice for ``prompt``.

        Args:
            prompt: Context string.
            choices: Candidate completions.
            normalization: ``None``, ``"unconditional"``, ``"byte_length"``,
                ``"token_length"`` or ``"root"``.

        Returns:
            Index into ``choices`` of the best-scoring completion.
        """
        if isinstance(self.tok, transformers.LlamaTokenizer):
            padded_choices = choices
            # endswith() also copes with an empty prompt.
            prompt = prompt if prompt.endswith(" ") else prompt + " "
        else:
            padded_choices = [pad_token(c) for c in choices]

        scores = torch.tensor(
            [self.completion_logprob(prompt + c, c) for c in padded_choices]
        )

        if normalization == "unconditional":
            baseline = torch.tensor(
                [self.completion_logprob(c, c) for c in padded_choices]
            )
            scores = scores - baseline

        elif normalization == "byte_length":
            scores = scores / torch.tensor([len(c) for c in choices])

        elif normalization == "token_length":
            tok_lens = [len(encode_token(c, self.tok)) for c in choices]
            scores = scores / torch.tensor(tok_lens)

        elif normalization == "root":
            tok_lens = [len(encode_token(c, self.tok)) for c in choices]
            scores = torch.pow(torch.exp(scores), 1. / torch.tensor(tok_lens))

        scores = scores.tolist()

        return scores.index(max(scores))
    

In [None]:
# Load the LoRA hyper-parameters and build the editable model from them.
hparams = LoRAHyperParams.from_hparams('hparams/LoRA/llama-7b-canonical.yaml')
edited_model = EditedModel(hparams=hparams, auth_token=auth_token())