In [None]:
# Import modules
import os
from dotenv import load_dotenv
from genai.model import Credentials, Model
from genai.schemas import GenerateParams

import pandas as pd
import time

In [None]:
# Pull the GENAI_* settings out of the environment (after merging in the .env file)
load_dotenv()
api_key = os.environ.get("GENAI_KEY")       # None when the variable is not set
api_endpoint = os.environ.get("GENAI_API")  # None when the variable is not set

In [None]:
# Credentials built from the API key and endpoint read from the environment
creds = Credentials(api_key, api_endpoint)

# Generation settings; note that max_new_tokens must be <= 1536
params = GenerateParams(
    decoding_method="greedy",
    min_new_tokens=1,
    max_new_tokens=1500,
    stop_sequences=["\n\n\nInput:"],
    repetition_penalty=2,
)

# Model handle; the bare name on the last line echoes it as the cell output
model = Model("tiiuae/falcon-40b", params=params, credentials=creds)
model

In [None]:
# Load the question set; the bare name below renders the frame as the cell output
data = pd.read_csv(filepath_or_buffer="./questions/tgrt_questions.csv")
data

In [2]:
# Prompt formatter function for when initial output is not included
def prompt_formatter(instr, input_text):
    """Build a prompt: instructions, a blank line, an ``Input:`` section, then an open ``Response:``.

    Args:
        instr: instruction text placed at the top of the prompt.
        input_text: text placed on the line after the ``Input:`` header.

    Returns:
        The assembled prompt string, ending in ``Response:``.
    """
    pieces = [instr, '\n\n', 'Input:\n', input_text, '\n', 'Response:']
    return ''.join(pieces)


# Read aligner principles + in-context examples.
# The context manager closes the file handle as soon as the text has been read,
# instead of leaving it open until the garbage collector gets to it.
with open("./prompts/factuality_aligner_principles_and_examples.txt", "r") as principles_file:
    corrector_principles = principles_file.read()
print(corrector_principles)


In [None]:
batch_size = 256

j = 0 # running batch counter, used in the output file names
start_index = 0
end_index = 152997 # last question index number
index_tracker = None

# Markers used to slice the model output and the prompt into fields.
# They never change, so they are defined once instead of once per response.
sub_response = "\nCorrector"

sub1_problems = "response):"
sub2_problems = "\nCorrector:\n"

sub_corrector = "\nCorrector:"

sub1_question = "\nInput:"
sub2_question = "\nResponse:"


def _split_generation(result, prompt):
    """Slice one stripped model output and its prompt into the four saved fields.

    Args:
        result: stripped generated text of one response.
        prompt: the input text the response was generated from.

    Returns:
        A tuple ``(question, initial_response, problems, corrected_response)``,
        or ``None`` when a marker needed for slicing is missing from ``result``
        or ``prompt``.
    """
    # "\nCorrector:\n" being present implies "\nCorrector" and "\nCorrector:" are present too,
    # so checking these two markers covers all four output markers.
    if sub1_problems not in result or sub2_problems not in result:
        return None
    # Treat a prompt without the expected sections as a bad response instead of
    # letting rindex() raise and abort the whole run.
    if sub1_question not in prompt or sub2_question not in prompt:
        return None

    # Everything before the first "\nCorrector" is the initial response.
    res_response = result[:result.index(sub_response)].strip()

    # Slices start right after each marker. .strip() removes the separating
    # whitespace, so no extra "+ 1" offset is applied (it would drop a real
    # character whenever the marker is not followed by whitespace).
    problems_start = result.index(sub1_problems) + len(sub1_problems)
    res_problems = result[problems_start:result.index(sub2_problems)].strip()

    corrector_start = result.index(sub_corrector) + len(sub_corrector)
    res_corrector = result[corrector_start:].strip()

    # rindex: take the LAST "Input:" / "Response:" pair of the prompt, i.e. the
    # one appended by the prompt formatter (the instruction text before it may
    # contain the same markers).
    question_start = prompt.rindex(sub1_question) + len(sub1_question)
    res_question = prompt[question_start:prompt.rindex(sub2_question)].strip()

    return res_question, res_response, res_problems, res_corrector


# Make sure the output folders exist up front, so a finished batch is not lost
# to a failing to_csv() call.
os.makedirs('./data/bad_responses', exist_ok=True)

start_time = time.time()

# end_index is the last question index (inclusive), hence the "+ 1" for range().
for index in range(start_index, end_index + 1, batch_size):
    data_batch = data[index:index+batch_size]
    if data_batch.empty:
        # Ran past the end of the data: nothing left to generate.
        break

    # x[0] contains 'question'
    prompts_no_output = [prompt_formatter(corrector_principles, str(x[0]).strip()) for x in data_batch.to_numpy()]

    questions = []
    responses = []
    problems = []
    corrected_responses = []
    bad_response_indeces = []
    none_response_indeces = []

    # ordered=True: responses must come back in prompt order, otherwise
    # index + i would not identify the row a bad/None response belongs to.
    for i, response in enumerate(model.generate_async(prompts_no_output, ordered=True)):

        index_tracker = index + i

        if response is None:
            none_response_indeces.append(index_tracker)
            print("*******************************************************")
            print("None response at index: ", index_tracker)
            print("*******************************************************")
            continue

        parsed = _split_generation(response.generated_text.strip(), response.input_text)
        if parsed is None:
            bad_response_indeces.append(index_tracker)
            print("*******************************************************")
            print("Bad response at index: ", index_tracker)
            print("*******************************************************")
            continue

        # adding extracted strings to their respective lists
        res_question, res_response, res_problems, res_corrector = parsed
        questions.append(res_question)
        responses.append(res_response)
        problems.append(res_problems)
        corrected_responses.append(res_corrector)
        print("=====================================================================")
        print("End of response for index: ", index_tracker)
        print("=====================================================================")

    # NOTE(review): the column name 'corrected_reponse' is misspelled; it is kept
    # as-is because consumers of the generated CSV files may depend on it.
    batch_df = pd.DataFrame({
        'input': questions,
        'initial_response': responses,
        'response_problems': problems,
        'corrected_reponse': corrected_responses,
    })
    batch_df.to_csv('./data/generated_data'+str(j)+'_'+str(index_tracker)+'.csv', index=False)

    # Indices of failed rows are saved per batch; files are only written when non-empty.
    if len(bad_response_indeces)>0:
        bad_df = pd.DataFrame(bad_response_indeces, columns=['bad_response_index'])
        bad_df.to_csv('./data/bad_responses/bad'+str(j)+'.csv', index=False)
    if len(none_response_indeces)>0:
        none_df = pd.DataFrame(none_response_indeces, columns=['none_response_index'])
        none_df.to_csv('./data/bad_responses/none'+str(j)+'.csv', index=False)

    # Elapsed time is measured from the start of the whole run, not of this batch.
    end_time = time.time()
    print("#################################################################################################")
    print("Generated data batch saved at index: ", index_tracker, " Time taken: ", end_time-start_time)
    print("#################################################################################################")
    j = j+1
    