In [2]:
## ---------------------------------------------------------------------
## set up configs for huggingface hub and OS paths on HPC cluster -- make sure config.ini is correct
## ---------------------------------------------------------------------
import configparser
def auth_token():

    config = configparser.ConfigParser()
    config.read("config.ini")
    return config["hugging_face"]["token"]

def scratch_path():
    config = configparser.ConfigParser()
    config.read("config.ini")
    return "/scratch/" + config["user"]["username"] + "/"

import os
if os.path.isdir(scratch_path()):
    os.environ['TRANSFORMERS_CACHE'] = scratch_path() + '.cache/huggingface'
    os.environ['HF_DATASETS_CACHE'] = scratch_path() + '.cache/huggingface/datasets'
print(os.getenv('TRANSFORMERS_CACHE'))
print(os.getenv('HF_DATASETS_CACHE'))

## ---------------------------------------------------------------------
## Load libraries
## ---------------------------------------------------------------------

import numpy as np
import pandas as pd

import torch
import transformers
from transformers import AutoTokenizer, AutoModel, LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast

import torch.nn.functional as F

from entailma import * ## these are where the QA and prompting functions live now
from easyeditor.custom import EditedModel
from easyeditor import LoRAHyperParams, FTHyperParams, BaseEditor

from datasets import load_dataset

## ---------------------------------------------------------------------
## Ensure GPU is available -- device should == 'cuda'
## ---------------------------------------------------------------------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device = ", device)

/scratch/dmpowell/.cache/huggingface
/scratch/dmpowell/.cache/huggingface/datasets
device =  cuda


In [9]:
import re
def answer_choice_list(choices):
    options = re.split(r'\s*\(\w\)\s*', choices)
    return( [option.strip() for option in options if option] )


def generate_qa_cloze_prompt(fname, n = 32):
    df = pd.read_csv(fname, sep='\t')
    plist = []
    for i in range(n):
        ans = df.iloc[i]["Answer Key"]
        ans_ind = ['A','B','C','D'].index(ans)
        ans_text = answer_choice_list(df.iloc[i].Choices)[ans_ind]
        plist.append("Question: " + df.iloc[i]["Complete Question"] + "\nAnswer: " + ans_text)

    return("\n".join(plist))

mc_answer_cloze_prompt = generate_qa_cloze_prompt("data/obqa/dev.tsv")



In [8]:
class WrappedModel:
    def __init__(self, model, tokenizer, auth_token=None):
        
        self.model = model
        self.tok = tokenizer
        self.tok.pad_token_id = self.tok.eos_token_id
        # self.model_name = self.editor.model_name

        # self.params = hparams
        self.preprompt = ""
        self.saved_weights = None
        
        if type(self.tok) == transformers.LlamaTokenizer or transformers.LlamaTokenizerFast:
            self.tok.padding_side = "right"
        else: 
            self.tok.padding_side = "left"
    
    def edit(self, rewrite, log_file = None, **kwargs):
        if log_file:
            h = open(log_file, "a")
        else:
            h = None
        
        if "preprompt" in rewrite: # this is a little hacky
            self.preprompt = rewrite["preprompt"]
            return None
        
        else:
            with redirect_stdout(h): # None
                metrics, self.model, self.saved_weights = self.editor.pure_edit( # pure_edit
                    **rewrite,
                    # **kwargs,
                    keep_original_weight = True,
                    verbose = False
                )
        
        return metrics
    
    
    def restore(self):

        self.preprompt = ""
        
        if self.params.alg_name == "LoRA":
            self.model = self.model.unload()
        
        elif self.saved_weights:

            try:
                with torch.no_grad():
                    for k, v in self.saved_weights.items():
                        nethook.get_parameter(self.model, k)[...] = v
                self.saved_weights = None
                # print("Original model restored")
            except NameError as e:
                print(f"No model weights to restore: {e}")

        elif self.saved_weights == {}:
            print (print(f"No model weights to restore: saved_weights is empty dict"))

        return None

            
    def generate_text(self, texts, **kwargs):
        
        if type(texts) != list:
            texts = [texts]
        
        texts = [self.preprompt + t for t in texts]

        model = self.model
        tokenizer = self.tok
        encoding = tokenizer(texts, padding=True, return_tensors='pt').to(device)

        with torch.no_grad():
            generated_ids = model.generate(**encoding, **kwargs) # 

            generated_texts = tokenizer.batch_decode(
                generated_ids, skip_special_tokens=True
            )
            
        return(generated_texts)
    
    
    # def logprobs(self, texts):
        
    #     # texts = self.preprompt + texts if type(texts)==str else [self.preprompt + t for t in texts]
    
    #     # tokenizer = self.tok 
    #     # model = self.model
    #     # encoding = tokenizer(texts, padding=True, return_tensors='pt').to(device)

    #     # with torch.no_grad():
    #     #     model_out = model(encoding["input_ids"])
    #     #     logits = model_out.logits
    #     #     logprobs = F.log_softmax(logits, -1)

    #     x = self.logits(texts)
        
    #     return {"tokens": x['tokens'], "logprobs": logprobs}
    

    def logits(self, texts):
        
        texts = self.preprompt + texts if type(texts)==str else [self.preprompt + t for t in texts]
    
        tokenizer = self.tok 
        model = self.model
        encoding = tokenizer(texts, padding=True, return_tensors='pt').to(device)

        with torch.no_grad():
            model_out = model(encoding["input_ids"])
            logits = model_out.logits
        
        return {"tokens": encoding, "logits": logits}
    
    
    def logprobs(self, texts):
        
        logits = self.logits(texts)
        
        return {"tokens": logits['tokens'], "logprobs": F.log_softmax(logits['logits'], -1)}
    
    
    def obs_logits(self, text):
    
        x = self.logits(text)
        logits = x['logits']
        
        obslogits = []

        if type(text) is str:
            tok_idx = x['tokens']['input_ids'].squeeze()
            logits = x['logits']
            obslogits = logits[0, :, tok_idx[1:]].squeeze().diag()

        elif type(text) is list:
            for i in range(len(text)):
                tok_idx = x['tokens']['input_ids'][i].squeeze()
                mask = x['tokens']['attention_mask'][i] > 0
                
                obslogits.append(logits[0, :, tok_idx[1:]].squeeze().diag()[mask[1:]])

        return obslogits


    def obs_logprobs(self, text):
        logits = self.obs_logits(text)

        return [F.log_softmax(l, -1) for l in logits] if type(logits)==list else F.log_softmax(logits, -1)
        
       
    def completion_logprob(self, text, completion, start_ind = None):
        
        '''
        Compute model log probability of completion substring. Returns single value tensor. Takes only one text string.
        '''

        return self.substring_logprobs(text, completion)[0][-1]
        

    def substring_logprobs(self, texts, substring, pad = True):
        '''
        Compute model log probability of each occurrence of substring in text. Returns list of list-type. Accepts a list of strings.
        '''
        
        if type(texts) != list:
            texts = [texts]
        
        logprobs = self.logprobs(texts)
        
        tok_encoded = encode_token(substring, self.tok, pad = pad)
        # text_encoded = logprobs['tokens']['input_ids'][0].tolist()
        
        out = []
        for i in range(len(texts)):
            text_encoded = logprobs['tokens']['input_ids'][i].tolist()

            # find matches for searched token sequence
            start_idxs = []
            for left in range(0, len(text_encoded) - len(tok_encoded)+1):
                # left = i - 1
                right = left + len(tok_encoded)
                if text_encoded[left:right] == tok_encoded:
                    start_idxs.append(left)

            lp = logprobs['logprobs'][i]
            match_probs = []

            # compute probability for all tokens
            for start in start_idxs:
                val = 0
                for i in range(len(tok_encoded)):
                    val += lp[start + i - 1][tok_encoded[i]]
                match_probs.append(val)

            out.append(match_probs)

        return out
        

    def choose(self, prompt, choices, normalization = None):

        # prompt = prompt.rstrip() # remove any trailing whitespace

        if type(self.tok) == transformers.models.llama.tokenization_llama.LlamaTokenizer:
            padded_choices = choices
            prompt = prompt + " " if prompt[-1]!= " " else prompt
        else:
            padded_choices = [pad_token(c) for c in choices] # pad all the 
        
        prompts = [prompt + c for c in padded_choices]

        logits = torch.tensor([self.completion_logprob(prompts[i], padded_choices[i]) for i in range(len(padded_choices))])

        if normalization == "unconditional":
            norm_logits = torch.tensor([self.completion_logprob(padded_choices[i], padded_choices[i]) for i in range(len(padded_choices))])
            logits = logits - norm_logits

        elif normalization == "byte_length":    
            str_lens = [len(c) for c in choices]
            logits = logits / torch.tensor(str_lens)

        elif normalization == "token_length":
            tok_lens = [len(encode_token(c, self.tok)) for c in choices]
            logits = logits / torch.tensor(tok_lens)

        elif normalization == "root":
            tok_lens = [len(encode_token(c, self.tok)) for c in choices]
            logits = torch.pow(torch.exp(logits), 1./torch.tensor(tok_lens))

        logits = logits.tolist()

        return(logits.index(max(logits)))
    

In [9]:
## ---------------------------------------------------------------------
## load llama-2 as a EditedModel class (not pipeline, to integrate better with other scripts/notebooks)
## ---------------------------------------------------------------------

MODEL_NAME = 'meta-llama/Meta-Llama-3-8B' #"meta-llama/Llama-2-7b-hf" 

# tokenizer = LlamaTokenizer.from_pretrained(MODEL_NAME)
# model = LlamaForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map = "auto")
model = WrappedModel(
    LlamaForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map = "auto"),
    PreTrainedTokenizerFast.from_pretrained(MODEL_NAME)
)

# hparams = FTHyperParams.from_hparams('hparams/FT/llama-7b.yaml')
# model = EditedModel(hparams, auth_token())

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

08/19/2024 16:50:25 - INFO - accelerate.utils.modeling -   We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [11]:
df = pd.read_csv("data/obqa/test.tsv", sep='\t')
df.columns = df.columns.str.replace(' ', '_')
df.columns = df.columns.str.lower()

df2 = df.copy().tail(10) # smaller df for testing
df2.head(5)

Unnamed: 0,id,question_stem,choices,complete_question,answer_key
490,9-743,where might a bunny live?,(A) a thicket (B) atop palm trees (C) a sewer ...,where might a bunny live? (A) a thicket (B) at...,A
491,9-645,A shark will be unable to survive on eating al...,(A) it is a predator (B) it is a vegetarian (C...,A shark will be unable to survive on eating al...,A
492,8-250,"A meadow vole just gave birth, and needs to fe...",(A) oil (B) deer (C) bugs (D) recycled plastic...,"A meadow vole just gave birth, and needs to fe...",C
493,283,The Grand Canyon was formed by,(A) a volcano erupting in 1782 (B) a river nam...,The Grand Canyon was formed by (A) a volcano e...,C
494,8-183,"A woman, with a pale complexion, wants to spen...",(A) UV rays are harmful (B) sunlight will be f...,"A woman, with a pale complexion, wants to spen...",A


In [95]:
# eb = load_dataset('ariesutiono/entailment-bank-v3', split='valildation')
eb = pd.read_json('data/entailmentbank/dev.jsonl', lines=True)
eb.iloc[3].meta

{'question_text': 'In New York State, the shortest period of daylight occurs during which month?',
 'answer_text': 'December',
 'hypothesis_id': 'int3',
 'triples': {'sent1': 'united states is located in the northern hemisphere',
  'sent2': 'december is during the winter in the northern hemisphere',
  'sent3': 'new york / new york state is a state located in the united states of america',
  'sent4': 'winter has the least sunlight'},
 'distractors': [],
 'distractors_relevance': [],
 'intermediate_conclusions': {'int1': 'new york state is located in the northern hemisphere',
  'int2': 'december is during the winter for new york state',
  'int3': 'new york state has the least sunlight during december'},
 'core_concepts': ['winter has the least sunlight'],
 'step_proof': 'sent1 & sent3 -> int1: new york state is located in the northern hemisphere; int1 & sent2 -> int2: december is during the winter for new york state; int2 & sent4 -> hypothesis; ',
 'lisp_proof': '((((((sent1 sent3) -> in

In [195]:
def split_proof(proof, splits = ('&', '->')):
    if ':' in proof:
        proof = proof.split(':')[0]
    s = proof.split()
    return([x for x in s if x not in splits])

ptext = []
for e in eb[0:10].itertuples():
    proof_list = e.proof.split(';')
    meta = e.meta

    premises = meta['triples'] | meta['intermediate_conclusions']
    premises['hypothesis'] = premises[meta['hypothesis_id']]

    for proof in proof_list[:-1]:
        plist = split_proof(proof.strip())
        pkeys = [plist[-1]] + plist[:-1]
        ptext.append([premises[k] for k in pkeys])
        # print("\n".join(ptext), "\n")

Pandas(Index=0, id='Mercury_SC_401371', context='sent1: the sun rising / setting occurs once per day sent2: the sun setting is a kind of event sent3: the sun rising is a kind of event', question='Which event occurs on a daily cycle?', answer='The Sun rises and sets.', hypothesis='the sun rising and setting is the event that occurs once per day', proof='sent1 & sent2 & sent3 -> hypothesis; ', full_text_proof=' [BECAUSE] the sun rising / setting occurs once per day [AND] the sun setting is a kind of event [AND] the sun rising is a kind of event [INFER] int1: the sun rising and setting is the event that occurs once per day', depth_of_proof=1, length_of_proof=1, meta={'question_text': 'Which event occurs on a daily cycle?', 'answer_text': 'The Sun rises and sets.', 'hypothesis_id': 'int1', 'triples': {'sent1': 'the sun rising / setting occurs once per day', 'sent2': 'the sun setting is a kind of event', 'sent3': 'the sun rising is a kind of event'}, 'distractors': [], 'distractors_relevanc

In [197]:
len(ptext)

22

In [54]:
# def mc_choose_answer(question, model, tokenizer=None):
#     if not tokenizer:
#         tokenizer = model.tok
    
#     input_str = mc_answer_prompt + f"\nQuestion: {question}\nAnswer:"
#     inputs = tokenizer(input_str, return_tensors="pt")
#     input_ids = inputs["input_ids"].cuda()
#     with torch.no_grad():
#         sequences = model.generate(input_ids = input_ids, max_new_tokens = 1)
    
#     return tokenizer.decode(sequences[0])[-1]


def last_token_logprobs(text, last_tokens, model):
    x = model.logprobs(text)
    logprobs = x['logprobs']
    t_idx = [i[-1] for i in model.tok(last_tokens)['input_ids']]

    return(logprobs[0, -1, t_idx])


def mc_answer_logprobs(question, model, answers = ['A','B','C','D']):

    input_str = mc_answer_prompt + f"\n\nQuestion: {question}\nAnswer: "

    return last_token_logprobs(input_str, answers, model)


mc_answer_logprobs('What color is the sky? (A) blue (B) red (C) orange (D) black', model)

NameError: name 'mc_answer_prompt' is not defined

Question answering is getting ~58% accuracy. For reference, the original GPT-3 with 32-shot examples got 65.8% ([Brown et al., 2020](https://arxiv.org/abs/2005.14165v4)). So that seems not-too-bad.

## generate_premises() function
~~This function will read the model's statement from the data set and provide two premises that would make the statement true.~~

UPDATE: This seems to work better if we include the original question and answer, which eliminates a point of failure and gives more context for the explanation / premise generation.

UPDATE 2: This is in the `entailma` library in this repo, but I've reproduced it here to make it easier to play around with as you/we tweak prompts.


## updates:


- Need a way to score whether the premises are actually any "good" -- i.e. do they lead the model to choose the targeted answer? The code below implements an IKE/ICE-style version of this. It seems to work ok?
- Need to add more examples to the prompt of premises supportin INCORRECT answers, as it struggles with this ATM [quick and dirty version done]

In [56]:
def completion_prob(preprompt, question, target_answer, model, answers = ['A','B','C','D']):
   if len(preprompt) == 0:
      prompt = mc_answer_prompt +  "\n\n" + preprompt + "Question:" + question + "\nAnswer: "
   else:
      prompt = mc_answer_prompt + '\n\n' + preprompt + '\nQuestion: ' + question  + '\nAnswer: '
   
   logprobs0 = last_token_logprobs(prompt, answers, model)
   prob = logprobs0[answers.index(target_answer)].exp() / logprobs0.exp().sum()

   return prob



# def score_premises(premises, question, target_answer, model, answers = ['A','B','C','D']):
#    '''Returns the odds-ratio of the target answer with vs without the premises in the premises in the context.'''
#    reg_answer_prompt = mc_answer_prompt +  "\n\nQuestion:" + question+"\nAnswer: "
#    logprobs0 = last_token_logprobs(reg_answer_prompt, answers, model)
#    prob0 = logprobs0[answers.index(target_answer)].exp() / logprobs0.exp().sum()
   

#    premise_str = "\n".join(premises)
#    augmented_answer_prompt = mc_answer_prompt + '\n\n' + premise_str + '\nQuestion: ' + question  + 'Answer: '
#    logprobs1 = last_token_logprobs(augmented_answer_prompt,  answers, model)
#    prob1 = logprobs1[answers.index(target_answer)].exp() / logprobs1.exp().sum()

#    return (prob1/(1-prob1)) / (prob0/(1-prob0))


def score_premises(premises, question, target_answer, model, base_prob = None, answers = ['A','B','C','D']):
   '''Returns the odds-ratio of the target answer with vs without the premises in the premises in the context.'''
   
   if not base_prob:
      base_prob = completion_prob("", question, target_answer, model, answers)

   premise_str = "\n".join(premises)
   prob1 = completion_prob(premise_str, question, target_answer, model, answers)

   return( (prob1/(1-prob1)) / (base_prob/(1-base_prob)))


print(score_premises(['The sky is red.', 'At sunset, the sun can be extremely red.'], 'What color is the sky? (A) blue (B) red (C) yellow (D) black', 'B', model))
print(score_premises(['Some things are red.', 'My favorite color is red.'], 'What color is the sky? (A) blue (B) red (C) yellow (D) black', 'B', model))
print(score_premises(['red red red red.', 'red red red red red.'], 'What color is the sky? (A) blue (B) red (C) yellow (D) black', 'B', model))

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


tensor(46.2965, device='cuda:0')
tensor(1.0877, device='cuda:0')
tensor(3.5249, device='cuda:0')


In [61]:
# 32-shot-ish new prompt that explicitly goes for a form of reasoing from answer to question (mix of test and dev so we should fix this later)
with open("entailma/entailer-dev-prompt-tandf2.txt", 'r') as file:
    premises_prompt = file.read()


import re
def answer_choice_list(choices):
    options = re.split(r'\s*\(\w\)\s*', choices)
    return( [option.strip() for option in options if option] )

def check_repeat_words(text1, text2, max_repeat_size = 4):
    # check if text2 includes a repetition of more than max_repeat_size in a row

    t1 = text1.lower().split()
    t2 = text2.lower().split()

    sublists = []
    for idx in range(len(t1) - max_repeat_size+1):
        s = t1[idx:idx+max_repeat_size+1]
        if len(s) > max_repeat_size:
            sublists.append(' '.join(s))
    
    sublists2 = []
    for idx in range(len(t2) - max_repeat_size+1):
        s = t2[idx:idx+max_repeat_size+1]
        if len(s) > max_repeat_size:
            sublists2.append(' '.join(s))
    
    valid = True

    for seq in sublists2:
        if seq in sublists:
            valid = False
            break

    return(valid)


def generate_premises(question, answer, model, num_prem = 1, batch_size = 4):
    options = answer_choice_list(question)[1:]
    ans_ind = ['A','B','C','D'].index(answer)
    choice = options[ans_ind]
    # print(choice)
    
    input_str = f"{premises_prompt}\n\nQuestion: {question}\nAnswer: {answer}\n"

    pipe = transformers.pipeline(
        "text-generation",
        model = model.model,
        tokenizer = model.tok,
        torch_dtype=torch.float16,
        # device = model.model.device
    )

    seq_list = []
    for i in range(-(-num_prem // batch_size)):

        sequences = pipe(
            input_str + choice,
            do_sample = True,
            top_p = .7,
            # penalty_alpha = 0.6, # avoids repetition of the question + answer (except doesn't)
            temperature = 0.7,
            max_new_tokens = 50,
            num_return_sequences = min(batch_size, num_prem - i*batch_size)
        )

        seq_list += sequences
    
    generated_texts = [s['generated_text'] for s in seq_list]
    
    premises = [t[len(input_str):-1] for t in generated_texts]
    premlist = [p.split("\n")[:2] for p in premises] 

    return premlist if len(premlist) > 1 else premlist[0]


def generate_best_premises(question, answer, model, num_prem=10, batch_size = 4):
    premises = generate_premises(question, answer, model, num_prem)
    premise_validity = [check_repeat_words(question, '\n'.join(p), 5) for p in premises]

    valid_premises = [i for (i, v) in zip(premises, premise_validity) if v]

    if len(valid_premises) > 0:

        base_completion_prob = completion_prob("", question, answer, model)
        scores =  [score_premises(p, question, answer, model, base_prob = base_completion_prob) for p in valid_premises]
        max_idx = scores.index(max(scores))
        
        return valid_premises[max_idx], scores[max_idx]
    
    else:
        return ["",""], -100.


row = df2.iloc[2]
print(row.complete_question)
out = generate_best_premises(row.complete_question, row.answer_key, model, num_prem = 16, batch_size = 8)
print(out)


A meadow vole just gave birth, and needs to feed herself so that she can produce milk for her babies. She searches for food in a field, and happily munches down on some (A) oil (B) deer (C) bugs (D) recycled plastic fruit
(['bugs are a food source for the meadow vole.', 'the meadow vole needs to eat to produce milk for her babies.'], tensor(2.6916, device='cuda:0'))


In [13]:
row.complete_question

'Some animals use a liquid coming from their skin to adjust to (A) cold (B) water (C) heat (D) humidity'

In [27]:
eb = 

['', 'a', 'b', 'c']

In [14]:
## this is not the best! but it does basically work, I think well-enough to deploy hopefuly

# few-shot prompt w/ a COT-like pattern
with open("rephrase-prompt-simple.txt", 'r') as file:
    reverse_prompt = file.read()

def capitalize_leading(string):
    slist = string.split()
    slist[0] = slist[0].capitalize()
    return " ".join(slist)

def simple_reverse(x, model):
    input_str = f"{reverse_prompt}\n\n[original] {x}\n[reversed]"

    pipe = transformers.pipeline(
        "text-generation",
        model = model.model,
        tokenizer = model.tok,
        torch_dtype=torch.float16
    )
        
    sequences = pipe(
            input_str,
            max_new_tokens = 100,
            num_return_sequences = 1
        )
    output = sequences[0]['generated_text'][len(input_str):]
    
    
    return capitalize_leading(output.split('\n\n')[0])


x = simple_reverse("Henry Ford invented the Model T", model)
x

'The Model T was invented by Henry Ford'

In [87]:
## I don't think this is very useful for evaluating "belief"
# def text_logprob(text, model, norm = None):
#     if not norm:
#         norm = 1
#     elif norm == "whitespace":
#         norm = len(text.split())
    
#     logprobs = model.obs_logprobs(text)
#     return [l.sum()/norm for l in logprobs] if type(logprobs)==list else logprobs.sum()/norm
    
# [text_logprob(t, model, norm = "whitespace") for t in ['Some animals sweat in the heat to keep cool.', 'Sweat is a liquid that evaporates from the skin, which cools the body.']]

False

In [84]:
False

False