# Install Packages

In [None]:
# Install the runtime dependencies for the pipeline.
# Only vllm is pinned (0.8.2); every other package installs its latest release,
# so a fresh run may pick up different versions than the ones originally used.
# TODO: pin the remaining packages once a known-good set is confirmed.
!pip install openai
!pip install langchain
!pip install langchain_community
!pip install bitsandbytes
!pip install datasets
!pip install langgraph
!pip install langchain_google_genai
!pip install langchain_huggingface
!pip install vllm==0.8.2

# Connect To Google Drive

In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)


# Definitions: API Keys and Imports

In [None]:
import os
# API credentials, exported as environment variables before the project modules
# below are imported (presumably those modules read them — confirm in src/ and tools/).
# Edit the values here, or export the variables before starting the notebook.
# A value left as the placeholder never overwrites a key that is already set in
# the environment; an edited value always takes effect.
# Never commit real keys to the notebook.
_PLACEHOLDER_KEY = 'put your key here'
_API_KEYS = {
    'search_key': 'put your key here',
    'hf_token': 'put your key here',
    'openai_key': 'put your key here',
    'gemini_api_key': 'put your key here',
}
for _key_name, _key_value in _API_KEYS.items():
    if _key_value != _PLACEHOLDER_KEY or _key_name not in os.environ:
        os.environ[_key_name] = _key_value

_placeholder_keys = [name for name in _API_KEYS if os.environ[name] == _PLACEHOLDER_KEY]
if _placeholder_keys:
    print('WARNING: these API keys are still set to the placeholder:', ', '.join(_placeholder_keys))
from src.DataLoaders.DataLoader import DataLoader
from src.DataLoaders.Arc import Arc
from src.DataLoaders.PubHealth import PubHealth
from src.DataLoaders.PopQA import PopQA
from src.RetrievalEvaluators.MultiGranularRetrievalEvaluator import MultiGranularRetrievalEvaluator
from tools.llm_tool import llm
from graph import workflow_compiler
import torch
from src.Helpers.Process import Process
from src.Helpers.Utils import get_generator

# Initialize the Process

In [None]:

# Choose the benchmark to evaluate: 'arc', 'pubhealth' or 'popqa'.
DATASET = 'arc'
DATALOADERS = {'arc': Arc, 'pubhealth': PubHealth, 'popqa': PopQA}
# Checkpoint files passed to MultiGranularRetrievalEvaluator.load_pretrained_model
# (the same pair is used for every dataset).
PLRE_LOAD_PATH = 'test.pt'
SLRE_LOAD_PATH = 'test2.pt'

if DATASET not in DATALOADERS:
    raise ValueError(f"Unknown DATASET {DATASET!r}; expected one of {sorted(DATALOADERS)}")

app = workflow_compiler()
# Build the dataloader first, then the evaluator (same order as the original tuple assignment).
dataloader = DATALOADERS[DATASET]()
retrieval_evaluator = MultiGranularRetrievalEvaluator().load_pretrained_model(
    plre_load_path=PLRE_LOAD_PATH,
    slre_load_path=SLRE_LOAD_PATH,
)
generator = get_generator(dataloader, retrieval_evaluator, llm, app)

# Start the selfrag-llama2-7b Process


In [9]:
# Resume from a checkpoint: set load_generations=True to reuse the first
# `checkpoint` generations stored in `pickle_file_name` instead of regenerating them.
load_generations = False
checkpoint = 925
# Make sure this filename matches the dataset being evaluated
# (the default here is a PopQA checkpoint).
pickle_file_name = 'outputs/autoencoder_results/popqa_temp_925.pickle'
import pickle

if load_generations:
    # pickle.load can execute arbitrary code: only load files this project produced.
    with open(pickle_file_name, 'rb') as handle:
        res = pickle.load(handle)
    saved_generations = res['generations']
    # Slicing a shorter list would silently return fewer items while the
    # checkpoint still claims `checkpoint` samples are done, misaligning the resume.
    if len(saved_generations) < checkpoint:
        raise ValueError(
            f"{pickle_file_name} holds only {len(saved_generations)} generations; "
            f"cannot resume from checkpoint {checkpoint}"
        )
    dataloader.generations = saved_generations[:checkpoint]
    # assumes generation_checkpoint is the index generation resumes from — TODO confirm in DataLoader
    dataloader.generation_checkpoint = checkpoint

In [None]:

# Run the full evaluation with autograd disabled (torch.no_grad), then report accuracy.
with torch.no_grad():
    process = Process(dataloader)
    process.start(
        generator,
        load_sample_data_only=False,
        template='self-rag',
    )
print("The final accuracy is: ", process.accuracy)

In [None]:
# Call the dataloader's own statistics() report for the run above
# (implemented in src.DataLoaders; its exact output is not visible here).
dataloader.statistics()

In [None]:
def count_outcomes_by_retrieval_source(generations, references, number_of_selected):
    """Tally correct/wrong generations by which retrieval source each sample used.

    Each sample i falls into exactly one bucket, checked in priority order:
      * 'high'      -- number_of_selected['highs'][i] > 0
      * 'medium'    -- otherwise, number_of_selected['mediums'][i] > 0
      * 'websearch' -- otherwise (neither highs nor mediums were selected)
    A sample is 'correct' when generations[i] == references[i], else 'wrong'.

    Args:
        generations: model outputs, one per sample.
        references: expected outputs, indexed in parallel with `generations`.
        number_of_selected: mapping with 'highs' and 'mediums' sequences,
            indexed in parallel with `generations`.

    Returns:
        dict mapping (bucket, outcome) -> count; every combination is present.

    Raises:
        IndexError: if `references` or a selection sequence is shorter than `generations`.
    """
    counts = {
        (bucket, outcome): 0
        for bucket in ('high', 'medium', 'websearch')
        for outcome in ('wrong', 'correct')
    }
    highs = number_of_selected['highs']
    mediums = number_of_selected['mediums']
    for i, generation in enumerate(generations):
        if highs[i] > 0:
            bucket = 'high'
        elif mediums[i] > 0:
            bucket = 'medium'
        else:
            bucket = 'websearch'
        outcome = 'wrong' if generation != references[i] else 'correct'
        counts[(bucket, outcome)] += 1
    return counts


outcome_counts = count_outcomes_by_retrieval_source(
    dataloader.generations,
    dataloader.output_test_data,
    retrieval_evaluator.number_of_selected,
)
# Keep the individual counters as notebook globals for any later use.
high_wrong_counter = outcome_counts[('high', 'wrong')]
medium_wrong_counter = outcome_counts[('medium', 'wrong')]
websearch_wrong_counter = outcome_counts[('websearch', 'wrong')]
high_correct_counter = outcome_counts[('high', 'correct')]
medium_correct_counter = outcome_counts[('medium', 'correct')]
websearch_correct_counter = outcome_counts[('websearch', 'correct')]
for outcome in ('wrong', 'correct'):
    for bucket in ('high', 'medium', 'websearch'):
        print(f"number of {bucket}_{outcome}: ", outcome_counts[(bucket, outcome)])

In [None]:
def save_model_outputs(self, save_path: str = 'outputs/autoencoder_results/arc.pickle') -> bool:
    """Pickle a dataloader's test inputs and generations to ``save_path``.

    Despite the ``self`` parameter name this is a free function: pass the
    dataloader explicitly, e.g. ``save_model_outputs(dataloader)``.

    Args:
        self: object exposing ``input_test_data`` and ``generations``.
        save_path: destination file. The default is the ARC output path, so pass
            another path when evaluating a different dataset. Missing parent
            directories are created.

    Returns:
        True once the file has been written.
    """
    import os
    import pickle  # local import so this cell doesn't rely on an earlier cell's import

    # Read the attributes before opening the file so a missing attribute
    # doesn't leave an empty/truncated pickle behind.
    payload = {'input_test_data': self.input_test_data, 'generations': self.generations}
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(save_path, 'wb') as handle:
        pickle.dump(payload, handle)
    return True

In [14]:
# Persist the test inputs and generations. This path is the ARC one; change it
# when evaluating another dataset. The trailing ';' suppresses the cell's repr output.
save_model_outputs(dataloader, save_path='outputs/autoencoder_results/arc.pickle');