In [1]:
import groq
import os
import sys
import anthropic
import ollama
import random
import pandas as pd
from tqdm import tqdm
from google.generativeai.types import RequestOptions
from google.api_core import retry
from typing import List, Tuple
import json
from openai import OpenAI
import datetime

# Make the directory one level above the working directory importable, so
# project modules that live next to this notebook's folder can be imported.
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)

# Register the path only once, even if this cell is executed repeatedly.
if sys.path.count(parent_dir) == 0:
    sys.path.append(parent_dir)

from concurrent.futures import ThreadPoolExecutor, TimeoutError

# prompts

### Zero Shot - Vanilla CoT

In [2]:
# Instruction text for the zero-shot chain-of-thought prompt. `get_prompt`
# places it after the question; it asks the model to reason step by step and
# to wrap the final answer in curly brackets.
zero_shot_vanilla_cot = """
Think through your answer step by step. Put the concise form of your final answer in curly brackets e.g. {A}, {True} or {False}.
"""

In [6]:
def get_prompt(prompt_type: str, few_shot_prompt: str, question: str) -> str:
    """Build the full prompt text for ``question`` under a prompting strategy.

    Args:
        prompt_type: Name of the strategy to use. ``"zero_shot_vanilla_cot"``
            is the only strategy defined in this function.
        few_shot_prompt: Accepted as part of the signature but not used by
            any strategy defined in this function.
        question: Question text placed at the start of the prompt.

    Returns:
        The question, a newline, then the instruction text of the strategy.

    Raises:
        KeyError: If ``prompt_type`` is not one of the defined strategies.
    """
    rendered_by_type = dict(
        zero_shot_vanilla_cot="{}\n{}".format(question, zero_shot_vanilla_cot),
    )
    return rendered_by_type[prompt_type]

def save_results(save_path: str, ids: List[str], questions: List[str], answers: List[str], append: bool = False):
    """Write question/answer rows to a CSV file with columns id, question, answer.

    Args:
        save_path: Destination CSV path.
        ids: Identifier for each row.
        questions: Question text for each row.
        answers: Answer text for each row.
        append: When True and ``save_path`` already holds data, the rows are
            appended without repeating the header. Otherwise the file is
            (re)written from scratch with a header row.

    The three lists are turned into DataFrame columns, so they must have the
    same length.
    """
    df = pd.DataFrame({'id': ids, 'question': questions, 'answer': answers})

    # A file that exists but is empty (0 bytes) has no header yet. Appending
    # header-less rows to it would produce a CSV whose first data row is
    # parsed as the header, so treat that case like a fresh file.
    has_existing_content = os.path.exists(save_path) and os.path.getsize(save_path) > 0

    if append and has_existing_content:
        df.to_csv(save_path, mode='a', index=False, header=False)
    else:
        df.to_csv(save_path, index=False)

def read_jsonl_file(filepath: str) -> List[dict]:
    """Parse a JSON Lines file into a list, one parsed object per non-blank line.

    Args:
        filepath: Path to the ``.jsonl`` file.

    Returns:
        The parsed JSON values in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    # JSON text is UTF-8; be explicit so the result does not depend on the
    # platform's default encoding.
    with open(filepath, 'r', encoding='utf-8') as file:
        for line in file:
            # Blank lines (typically a trailing newline at the end of the
            # file) carry no record and would otherwise fail to parse.
            if not line.strip():
                continue
            records.append(json.loads(line))
    return records

def load_data_size_specific(data_path: str, sample_size: int = 0, random_seed: int = 0, min_question_length: int = 0) -> Tuple[List, List]:
    """Load questions from a ``.jsonl`` or ``.json`` file, optionally sub-sampled.

    Every record must provide an ``"id"`` and a ``"question"`` entry.

    Args:
        data_path: Path ending in ``.jsonl`` (one JSON object per line) or
            ``.json`` (a single JSON document that iterates as records).
        sample_size: Number of records to draw at random. Values ``<= 0`` or
            ``>=`` the number of eligible records return every eligible
            record in file order.
        random_seed: Seed applied to the global ``random`` module before
            sampling, so the same seed yields the same sample.
        min_question_length: Records whose question is shorter than this many
            characters are dropped. The default of 0 keeps everything.

    Returns:
        ``(ids, questions)`` as two parallel lists.

    Raises:
        ValueError: If ``data_path`` has neither a ``.jsonl`` nor a ``.json``
            extension.
    """
    # Seeds the module-level generator, so later `random` calls made
    # elsewhere in the session are affected as well.
    random.seed(random_seed)

    if data_path.endswith('.jsonl'):
        data = read_jsonl_file(data_path)
    elif data_path.endswith('.json'):
        with open(data_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
    else:
        # Without this branch `data` would be unbound and the failure would
        # surface as a confusing UnboundLocalError further down.
        raise ValueError(f"Unsupported data file extension (expected .json or .jsonl): {data_path}")

    eligible_data = [x for x in data if len(x["question"]) >= min_question_length]

    if 0 < sample_size < len(eligible_data):
        sampled_data = random.sample(eligible_data, sample_size)
    else:
        sampled_data = eligible_data

    ids = [x["id"] for x in sampled_data]
    questions = [x["question"] for x in sampled_data]

    return ids, questions

def load_already_answered_ids(save_path: str) -> set:
    """Return the set of question IDs already recorded in the results CSV.

    Args:
        save_path: Path of the results CSV; its ``id`` column is read.

    Returns:
        The IDs found in the file, converted to ``int``. An empty set is
        returned when the file does not exist or contains no data at all.
    """
    if not os.path.exists(save_path):
        print(f"No existing save file found at: {save_path}. Starting fresh.")
        return set()

    try:
        df = pd.read_csv(save_path)
    except pd.errors.EmptyDataError:
        # The file exists but has no header and no rows (e.g. a zero-byte
        # file); there is nothing to resume from.
        print(f"Save file at: {save_path} is empty. Starting fresh.")
        return set()

    # IDs are normalised to int, so membership tests against this set only
    # match integer IDs.
    answered_ids = set(df['id'].astype(int).tolist())
    print(f"Already answered IDs: {answered_ids}")
    return answered_ids

def initialize_save_file(save_path: str):
    """Create the results CSV containing only its header row, if it is missing.

    An existing file at ``save_path`` is left untouched.

    Args:
        save_path: Path of the results CSV to create.
    """
    if os.path.exists(save_path):
        return

    # A frame with columns but no rows serialises to just the header line.
    header_only = pd.DataFrame(columns=['id', 'question', 'answer'])
    header_only.to_csv(save_path, index=False)
    print(f"Initialized new save file with headers at: {save_path}")

In [4]:
def query_4o(prompt: str) -> str:
    """Send a single-turn user message to ``gpt-4o-2024-08-06`` and return the reply.

    A new ``OpenAI`` client is built on every call with no explicit
    arguments (presumably it picks up credentials from the environment;
    verify against the deployment). The request uses ``temperature=0``.

    Args:
        prompt: Text sent as the content of the only (user) message.

    Returns:
        The message content of the first choice in the response.
    """
    openai_client = OpenAI()

    user_turn = {"role": "user", "content": f"{prompt}"}
    response = openai_client.chat.completions.create(
        model="gpt-4o-2024-08-06",
        messages=[user_turn],
        temperature=0,
    )

    first_choice = response.choices[0]
    return first_choice.message.content

def query_llama(prompt: str, timeout_duration=800) -> str:
    """Send a single-turn user message to ``llama-3.1-8b-instant`` via Groq.

    Args:
        prompt: Text sent as the content of the only (user) message.
        timeout_duration: Request timeout in seconds, passed to the Groq
            client.

    Returns:
        The message content of the first choice in the response.
    """
    # Only the `groq` module is imported in this notebook (not the `Groq`
    # name), so the client class is referenced through the module to avoid a
    # NameError.
    client = groq.Groq(
        api_key=os.environ.get("GROQ_API_KEY"),
        # Previously accepted but ignored; hand it to the client so the
        # caller-supplied limit actually applies to the request.
        timeout=timeout_duration,
    )

    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": prompt,
            }
        ],
        model="llama-3.1-8b-instant",
    )
    return chat_completion.choices[0].message.content

In [5]:
# Backend names that query_llm can dispatch to.
_SUPPORTED_LLM_MODELS = ('gemini', 'claude', '4o', 'llama3.1')


def _ask_model(llm_model: str, prompt_type: str, prompt: str, question: str, question_id) -> str:
    """Route one question to the backend named by ``llm_model`` and return its answer."""
    use_multi_convo = prompt_type == 'multi_convo'

    if llm_model == 'gemini':
        return query_gemini(prompt, question_id)

    if llm_model == 'claude':
        return query_claude(prompt)

    if llm_model == '4o':
        if not use_multi_convo:
            return query_4o(prompt)
        fact_prompt = get_prompt(prompt_type="fact_prompt", few_shot_prompt="", question=question)
        answer_prompt = get_prompt(prompt_type="answer_prompt_data", few_shot_prompt="", question=question)
        return query_4o_multiconvo(fact_prompt=fact_prompt, answer_prompt=answer_prompt, extracted_question=question)

    if llm_model == 'llama3.1':
        if not use_multi_convo:
            return query_llama(prompt)
        fact_prompt = get_prompt(prompt_type="fact_prompt", few_shot_prompt="", question=question)
        answer_prompt = get_prompt(prompt_type="answer_prompt_key_fact", few_shot_prompt="", question=question)
        return query_llama_multiconvo(fact_prompt=fact_prompt, answer_prompt=answer_prompt, extracted_question=question)

    raise ValueError(f"Unsupported LLM model: {llm_model}")


def query_llm(llm_model: str, ids: List[str], questions: List[str], few_shot_prompt: str, prompt_type: str, save_path: str, already_answered_ids: set) -> Tuple[List[str], List[str], List[str]]:
    """Ask ``llm_model`` every not-yet-answered question and persist each answer.

    Each successful answer is appended to ``save_path`` immediately, so an
    interrupted run keeps the answers obtained so far. A failure on one
    question is printed and that question is skipped; the run continues.

    Args:
        llm_model: One of ``'gemini'``, ``'claude'``, ``'4o'``, ``'llama3.1'``.
        ids: Question identifiers, parallel to ``questions``.
        questions: Question texts.
        few_shot_prompt: Forwarded to ``get_prompt``.
        prompt_type: Prompting strategy forwarded to ``get_prompt``. The value
            ``'multi_convo'`` selects the two-prompt conversation flow for the
            ``'4o'`` and ``'llama3.1'`` backends.
        save_path: CSV file that receives one appended row per answer.
        already_answered_ids: IDs to skip. Membership is type-sensitive:
            ``load_already_answered_ids`` produces ints, so IDs in ``ids``
            presumably need to be ints to match (TODO confirm).

    Returns:
        ``(ids, questions, answers)`` for the questions answered in this call.

    Raises:
        ValueError: If ``llm_model`` is not a supported backend. This is
            checked once up front, rather than being caught and printed for
            every question.

    The ``query_gemini``, ``query_claude``, ``query_4o_multiconvo`` and
    ``query_llama_multiconvo`` helpers are not defined in this part of the
    notebook and are assumed to be provided by other cells.
    """
    if llm_model not in _SUPPORTED_LLM_MODELS:
        raise ValueError(f"Unsupported LLM model: {llm_model}")

    answers = []
    ids_can_be_answered = []
    questions_can_be_answered = []

    for question_id, question in tqdm(zip(ids, questions), total=len(ids)):
        if question_id in already_answered_ids:
            print(f"Skipping already answered ID: {question_id}")
            continue
        # Weird ID that breaks llama. The skip applies to every backend, not
        # only llama.
        if question_id == 1146:
            continue

        prompt = get_prompt(prompt_type, few_shot_prompt, question)
        try:
            answer = _ask_model(llm_model, prompt_type, prompt, question, question_id)

            answers.append(answer)
            questions_can_be_answered.append(question)
            ids_can_be_answered.append(question_id)

            # Save after each answer so progress survives an interruption.
            save_results(save_path, [question_id], [question], [answer], append=True)
        except Exception as e:
            # Best effort: report the failure and move on to the next question.
            print(f"Error processing question {question_id}: {str(e)}")
            continue

    return ids_can_be_answered, questions_can_be_answered, answers