In [2]:
import torch
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from datasets import load_dataset
from tqdm.notebook import tqdm
from transformers.utils import logging
from transformers import GenerationConfig
import json
from peft import LoraConfig, get_peft_model, TaskType,PeftModel
# Base FLAN-T5-XL loaded in 8-bit; the critique PEFT adapter is attached to this `model` later.
model_name = 'google/flan-t5-xl'

model = AutoModelForSeq2SeqLM.from_pretrained(model_name, load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Raise the tokenizer's length limit; presumably so long context + reasoning prompts
# (tokenized later without truncation) do not trigger over-length warnings — TODO confirm.
tokenizer.model_max_length = 2048

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
from peft import PeftModel

# Critique adapter checkpoint, loaded on top of the 8-bit FLAN-T5 base `model`.
critque_checkpoint = "trained_model/flan_t5_multi_task_orca1,3/checkpoint-3800"
critque_model = PeftModel.from_pretrained(model, critque_checkpoint)

In [4]:
from huggingface_hub import login

# Never hard-code Hub credentials: the token previously pasted here is exposed in the
# notebook's history and must be revoked. Read it from the environment instead; when
# HF_TOKEN is unset, login(token=None) falls back to an interactive prompt.
login(token=os.environ.get("HF_TOKEN"))

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/cyhung/.cache/huggingface/token
Login successful


In [5]:
# Use a pipeline as a high-level helper
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
# Llama-2 is a gated model. Take the token from the environment (never hard-code it);
# None makes from_pretrained fall back to the token cached by login().
access_token = os.environ.get("HF_TOKEN")


generator_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
            token=access_token)
generator = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
            load_in_8bit=True,
            torch_dtype=torch.float16,
            device_map="auto",
            token=access_token)

2023-12-27 19:57:15.122800: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-27 19:57:15.122831: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-27 19:57:15.122844: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-27 19:57:15.126029: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [11]:
import wikienv, wrappers
import re
import requests
from transformers import GenerationConfig

# Wikipedia environment queried by Retriever, wrapped in LoggingWrapper
# (presumably records each step — confirm in wrappers.py).
# For FEVER evaluation, wrap the base env first: wrappers.FeverWrapper(base, split="dev").
env = wrappers.LoggingWrapper(wikienv.WikiEnv())



class Retriever:
    """Wikipedia lookup helper driven through a step-based environment.

    Issues ``search[<term>]`` actions and turns the returned page sentences
    into one context string capped at 300 words.
    """

    def __init__(self, env=None, max_attempts=10):
        """
        Args:
            env: Object exposing ``step(action)``. ``None`` (the default) uses
                the module-level ``env`` at call time, so ``Retriever()``
                behaves as before.
            max_attempts: How many times ``step`` tries an action that times out.
        """
        self.env = env
        self.max_attempts = max_attempts

    def __call__(self, term, top_sentence=12):
        return self.search_term(term, top_sentence)

    def step(self, env, action):
        """Return ``env.step(action)``, retrying on ``requests`` timeouts.

        Raises:
            requests.exceptions.Timeout: if every attempt times out. Re-raising
            surfaces the real cause; returning ``None`` would only crash the
            caller later with an unrelated TypeError.
        """
        attempts = max(1, self.max_attempts)
        for attempt in range(attempts):
            try:
                return env.step(action)
            except requests.exceptions.Timeout:
                if attempt == attempts - 1:
                    raise

    def search_term(self, term, top_sentence=12, count=0):
        """Search for ``term`` and return up to 300 words of page context.

        If the environment answers with a string listing similar titles, the
        first similar title is searched instead. ``count`` tracks these
        redirections, and at most two are followed.

        Args:
            term: Search term.
            top_sentence: Number of leading page sentences to join.
            count: Redirections already followed (internal).

        Returns:
            The context string, or ``''`` when nothing usable was found.
        """
        if count >= 2:
            print("Nothing found")
            return ''
        action = f'Search[{term}]'
        print(action)
        environment = self.env if self.env is not None else env
        # The environment is sent a lower-case verb ("search[...]").
        res = self.step(environment, action[0].lower() + action[1:])[0]  ## There might be some unknown unicode decoding error here

        if isinstance(res, str):  ## Could not find but found similar term
            match = re.search(r"Similar: \[(.*?)\]", res)
            print(res)
            if match:
                # Titles are listed as quoted strings: ['A', 'B', ...]
                elements = re.findall(r"'(.*?)'", match.group(1))
                if elements:
                    first_element = elements[0]
                    print(f"Searching the most similar term {first_element} instead")
                    return self.search_term(first_element, top_sentence, count + 1)

        if isinstance(res, list):
            if not res or res[0].startswith('There were no results matching the query'):  ## Nothing found
                print("Nothing found")
                return ''
            # Turn non-breaking spaces into regular spaces so words such as
            # "137\xa0kilometres" stay separated instead of being glued together.
            sentence = ' '.join(res[:top_sentence]).replace('\xa0', ' ')
            return ' '.join(sentence.split(' ')[:300])  ## Take max 300 words
        print("Unknown error", res)
        return ''


# Both models sample with the same low-temperature top-p / top-k settings. Each still
# gets its own GenerationConfig object, so either can be tuned independently.
_shared_sampling_kwargs = dict(
    temperature=0.1,
    do_sample=True,
    top_p=0.75,
    top_k=40,
    max_new_tokens=512,
)
generation_config_critque = GenerationConfig(**_shared_sampling_kwargs)
generation_config_generator = GenerationConfig(**_shared_sampling_kwargs)

class Critque:
    """Wrapper around the critique seq2seq model (PEFT-adapted FLAN-T5 here).

    ``critque_step`` generates text for a prompt. If that text contains a
    bracketed term (e.g. ``SEARCH[term]``), the term is looked up with
    ``retriever`` and ``"Context: ..."`` is returned. Otherwise the generated
    text is returned unchanged as feedback.
    """

    def __init__(self,
                 critque_model,
                 retriever,
                 t5_tokenizer,
                 generation_config=generation_config_critque,
                 device='cuda',
                 max_steps=3):
        """
        Args:
            critque_model: Seq2seq model exposing ``generate``.
            retriever: Callable mapping a search term to a context string
                (``''`` when nothing was found).
            t5_tokenizer: Tokenizer matching ``critque_model``.
            generation_config: Decoding settings passed to ``generate``.
            device: Device the input ids are moved to.
            max_steps: Stored on the instance; not read by this class.
        """
        self.critque_model = critque_model
        self.retriever = retriever
        self.t5_tokenizer = t5_tokenizer
        self.generation_config = generation_config
        self.max_steps = max_steps
        self.device = device
        # Becomes True once a retrieval returns non-empty context. Initialised
        # here so reading it before any retrieval does not raise AttributeError.
        self.retrieval_flag = False

    def _extract_t5_for_search(self, sentence):
        """Return the text inside the first ``[...]`` in ``sentence``, or ``''``.

        Expected model output looks like ``"I need to ... SEARCH[TERM]"``.
        NOTE(review): any bracketed text matches, not only ``SEARCH[...]``.
        """
        match = re.search(r'\[([^\]]+)\]', sentence)
        return match.group(1) if match else ''

    def critque_step(self, query):
        """Run the critique model on ``query``.

        Returns:
            ``"Context: <retrieved text>"`` (or ``"Context: No context"`` when
            the lookup came back empty) if the model requested retrieval.
            Otherwise, the decoded model output, which is used as feedback.
        """
        inputs = self.t5_tokenizer(query, return_tensors='pt')
        input_ids = inputs['input_ids'].to(self.device)

        with torch.no_grad():
            output = self.critque_model.generate(input_ids=input_ids,
                                                 do_sample=True,
                                                 generation_config=self.generation_config,
                                                 return_dict_in_generate=True)
        s = self.t5_tokenizer.decode(output.sequences[0], skip_special_tokens=True)

        # Empty string means no retrieval requested: the output is feedback.
        search = self._extract_t5_for_search(s)
        if not search:
            return s

        context = self.retriever(search)
        if context:
            self.retrieval_flag = True
            return f"Context: {context}"
        return "Context: No context"

class Generator:
    """Answer generator wrapping a chat causal LM (Llama-2-7b-chat in this notebook)."""

    def __init__(self,
                 generator,
                 generator_tokenizer,
                 generation_config=generation_config_generator,
                 device='cuda',
                 max_steps=3,
                 ):
        """
        Args:
            generator: Causal LM exposing ``generate``.
            generator_tokenizer: Tokenizer providing ``apply_chat_template``.
            generation_config: Decoding settings passed to ``generate``.
            device: Device the prompt ids are moved to.
            max_steps: Accepted for interface compatibility; unused.
        """
        self.generator = generator
        self.generator_tokenizer = generator_tokenizer
        # Store the argument, not the module-level default, so callers can
        # actually override the decoding settings.
        self.generation_config = generation_config
        self.device = device

    def qa_step(self, query):
        """Answer ``query`` (which may embed context, reasoning and feedback).

        Returns:
            Only the newly generated completion, decoded without special tokens.
        """
        chat_template = [
            {"role": "system", "content": "You are an instruction following question answering model and your goal is to answer question as truthfully as you can. You will make use of the context as much as possible and make use of the hint if it is provided. Give a short and concise answer to every query."},
            {"role": "user", "content": f"{query}"}]

        input_ids = self.generator_tokenizer.apply_chat_template(
            chat_template, return_tensors='pt').to(self.device)
        input_len = len(input_ids[0])
        with torch.no_grad():
            output = self.generator.generate(input_ids=input_ids,
                                             do_sample=True,
                                             generation_config=self.generation_config,
                                             return_dict_in_generate=True)
        # The returned sequence echoes the prompt; keep only the new tokens.
        generated_output = output.sequences[0][input_len:]
        return self.generator_tokenizer.decode(generated_output, skip_special_tokens=True)

class CritqueGen:
    """Critique-guided question answering loop.

    ``multi_step`` first asks the critique for context, then alternates
    generator answers and critique feedback. It stops when the critique
    returns ``'FEEDBACK: No HINT'`` or the step budget is exhausted.
    """

    def __init__(self, critque, generator, max_step=3):
        """
        Args:
            critque: Object with ``critque_step(text) -> str``.
            generator: Object with ``qa_step(prompt) -> str``.
            max_step: Total step budget; the initial context step counts as one,
                and at least one answer round always runs.
        """
        self.critque = critque
        self.generator = generator
        self.max_steps = max_step
        self.traj = []

    def inital_step(self, query):
        """Ask the critique for context and record ``{Context, Query}`` in ``traj``."""
        context = self.critque.critque_step(query)
        self.traj.append({"Context": context, "Query": query})

    def feedback_step(self, do_feedback=True):
        """Run one answer-then-critique round.

        Args:
            do_feedback: Unused; kept (now optional) for backward compatibility.

        Returns:
            True when the critique accepts the answer with ``'FEEDBACK: No HINT'``.
        """
        context, query = self.traj[0]['Context'], self.traj[0]['Query']

        if len(self.traj) == 1:  ## First step, no feedback has been provided
            prompt = f"{context}\n{query}"
        else:  ## The previous reasoning step and feedback is the last element of the trajectory list
            prev_step = self.traj[-1]
            prompt = f"{context}\n{query}\nReasoning:{prev_step['Reasoning']}\n{prev_step['Feedback']}"
            print(prompt)
        reasoning = self.generator.qa_step(prompt)

        context_with_query_reasoning = f"{context}\n{query}\nReasoning:{reasoning}"
        feedback = self.critque.critque_step(context_with_query_reasoning)
        self.traj.append({"Reasoning": reasoning, "Feedback": feedback})
        # Generated text can carry stray whitespace/newlines around the marker.
        return feedback.strip() == 'FEEDBACK: No HINT'

    def multi_step(self, query):
        """Run the full loop for ``query`` and return the final answer.

        Returns:
            The accepted answer, or the last answer produced if the critique
            never accepted one within the step budget.
        """
        self.traj = []
        self.inital_step(query)

        ans = None
        for _ in range(max(self.max_steps - 1, 1)):
            end = self.feedback_step()
            ans = self.traj[-1]['Reasoning']
            if end:
                break
        print(ans)
        return ans
                
    
        
        

    
        


In [7]:
# Shared Wikipedia retriever; it issues its searches through the module-level `env` created above.
ret = Retriever()

In [8]:
# Wire the pipeline: Llama-2 answers, the FLAN-T5 critique reviews answers and triggers retrieval.
qa_model = Generator(generator=generator, generator_tokenizer=generator_tokenizer)
critque = Critque(critque_model=critque_model, retriever=ret, t5_tokenizer=tokenizer)
critque_gen = CritqueGen(critque=critque, generator=qa_model)

In [20]:
# Example question for the critique-guided QA loop.
query = "What is the population of Singapore"

In [21]:
# Context step, then answer/critique rounds; the returned answer is displayed as the cell output.
critque_gen.multi_step(query)

Search[Singapore]


' According to the context, the population of Singapore is approximately 5.68 million people, as of 2020.'

In [22]:
# Inspect the trajectory: one {Context, Query} entry, then one {Reasoning, Feedback} entry per round.
critque_gen.traj

[{'Context': "Context: Singapore (/ˈsɪŋ(ɡ)əpɔːr/ ⓘ SING-(g)ə-por), officially the Republic of Singapore, is an island country and city-state in maritime Southeast Asia. It is located about one degree of latitude (137 kilometres or 85 miles) north of the equator, off the southern tip of the Malay Peninsula, bordering the Strait of Malacca to the west, the Singapore Strait to the south, the South China Sea to the east, and the Straits of Johor to the north. The country's territory comprises one main island, 63 satellite islands and islets, and one outlying islet; the combined area of these has increased by approximately 25% since the country's independence as a result of extensive land reclamation projects. It has the second highest population density of any country in the world, although there are numerous green and recreational spaces as a result of urban planning. With a multicultural population and in recognition of the cultural identities of the major ethnic groups within the nation