In [65]:
from groq import Groq
import os
import sys
import anthropic
import ollama
import random
import pandas as pd
from tqdm import tqdm
from google.generativeai.types import RequestOptions
from google.api_core import retry
from typing import List, Tuple
import json
from openai import OpenAI
import datetime
import openai
import time

# Make the parent of the current working directory importable.
# The membership check keeps sys.path free of duplicates when this cell is re-run.
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)

if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from concurrent.futures import ThreadPoolExecutor, TimeoutError

# Prompts

### Zero Shot - Vanilla CoT

In [66]:
# Instruction text for the zero-shot chain-of-thought prompt. It asks the model
# to reason step by step and to wrap its final answer in curly brackets.
# The leading and trailing newlines are part of the prompt text.
zero_shot_vanilla_cot = """
Think through your answer step by step. Put the concise form of your final answer in curly brackets e.g. {A}, {True} or {False}.
"""

In [67]:
def get_prompt(prompt_type: str, few_shot_prompt: str, question: str) -> str:
    """Assemble the prompt text sent to the model for one question.

    Args:
        prompt_type: Name of the template to use. Only
            ``"zero_shot_vanilla_cot"`` is defined here.
        few_shot_prompt: Accepted for interface compatibility; the template
            defined here does not use it.
        question: Question text placed at the start of the prompt.

    Returns:
        The question, a newline, and then the template's instruction text.

    Raises:
        KeyError: If ``prompt_type`` is not a known template name.
    """
    templates = {}
    # Built eagerly (before the lookup), exactly like a dict literal would be.
    templates["zero_shot_vanilla_cot"] = "{}\n{}".format(question, zero_shot_vanilla_cot)
    return templates[prompt_type]

def save_results(save_path: str, ids: List[str], questions: List[str], answers: List[str], append: bool = False):
    """Write id/question/answer rows to a CSV file.

    Args:
        save_path: Destination CSV path.
        ids: Values for the ``id`` column.
        questions: Values for the ``question`` column.
        answers: Values for the ``answer`` column.
        append: When True and ``save_path`` already exists, rows are appended
            without a header. Otherwise the file is (re)written with a header.
    """
    frame = pd.DataFrame({'id': ids, 'question': questions, 'answer': answers})

    appending = append and os.path.exists(save_path)
    if not appending:
        # Fresh file (or explicit overwrite): include the header row.
        frame.to_csv(save_path, index=False)
        return

    frame.to_csv(save_path, mode='a', index=False, header=False)

def read_jsonl_file(filepath: str) -> List[dict]:
    """Parse a JSON Lines file into a list of objects.

    Args:
        filepath: Path to a file holding one JSON document per line.

    Returns:
        The parsed objects, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(filepath, 'r', encoding='utf-8') as file:
        for line in file:
            # Blank or whitespace-only lines (e.g. a trailing newline at the
            # end of the file) carry no record and would make json.loads fail.
            if not line.strip():
                continue
            records.append(json.loads(line))
    return records

def load_data_size_specific(data_path: str, sample_size: int = 0, random_seed: int = 0):
    """Load a dataset file and optionally draw a seeded random sample from it.

    Args:
        data_path: Path ending in ``.jsonl`` (one JSON object per line) or
            ``.json`` (a single JSON document that iterates over records).
        sample_size: Number of records to sample. When it is 0 (or not
            strictly between 0 and the number of records), all records are
            returned in file order.
        random_seed: Seed applied to the global ``random`` module before
            sampling, so the same seed yields the same sample.

    Returns:
        A tuple ``(ids, questions)`` of two parallel lists taken from each
        record's ``"id"`` and ``"question"`` fields.

    Raises:
        ValueError: If ``data_path`` has neither a ``.json`` nor a ``.jsonl``
            extension.
        KeyError: If a record lacks ``"question"`` or ``"id"``.
    """
    random.seed(random_seed)

    if data_path.endswith('.jsonl'):
        data = read_jsonl_file(data_path)
    elif data_path.endswith('.json'):
        with open(data_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError on `data`; fail early with a clear message instead.
        raise ValueError(
            f"Unsupported data file extension (expected .json or .jsonl): {data_path}"
        )

    # Minimum question length filter; 0 keeps every record.
    min_question_length = 0
    eligible_data = [x for x in data if len(x["question"]) >= min_question_length]

    if 0 < sample_size < len(eligible_data):
        sampled_data = random.sample(eligible_data, sample_size)
    else:
        sampled_data = eligible_data

    ids = [x["id"] for x in sampled_data]
    questions = [x["question"] for x in sampled_data]

    return ids, questions

def load_already_answered_ids(save_path: str) -> set:
    """Return the set of IDs already recorded in the results CSV.

    Args:
        save_path: Path to the results CSV; its ``id`` column is read.

    Returns:
        The distinct values of the ``id`` column, or an empty set when the
        file does not exist. A short status line is printed in both cases.
    """
    if not os.path.exists(save_path):
        print(f"No existing save file found at: {save_path}. Starting fresh.")
        return set()

    # NOTE(review): the IDs carry whatever dtype pandas infers from the CSV;
    # confirm it matches the ID type of the loaded dataset.
    saved = pd.read_csv(save_path)
    answered_ids = set(saved['id'].tolist())
    print(f"Already answered IDs: {answered_ids}")
    return answered_ids

def initialize_save_file(save_path: str):
    """Create the results CSV with only a header row if it does not exist yet.

    Args:
        save_path: Path of the results CSV. An existing file is left untouched.
    """
    if os.path.exists(save_path):
        return

    # A DataFrame with no rows serialises to just the header line.
    header_only = pd.DataFrame(columns=['id', 'question', 'answer'])
    header_only.to_csv(save_path, index=False)
    print(f"Initialized new save file with headers at: {save_path}")

In [68]:
def query_4o(prompt: str) -> str:
    """Send a single-turn user prompt to ``gpt-4o-2024-08-06`` and return the reply.

    Args:
        prompt: Text sent as the sole user message.

    Returns:
        The message content of the first choice in the completion.
    """
    client = OpenAI()

    user_message = {"role": "user", "content": f"{prompt}"}
    response = client.chat.completions.create(
        model="gpt-4o-2024-08-06",
        messages=[user_message],
        temperature=0,
    )

    first_choice = response.choices[0]
    return first_choice.message.content

def _query_sambanova(prompt, model):
    """Send one single-turn chat request to the SambaNova endpoint.

    Args:
        prompt: Text sent as the sole user message.
        model: SambaNova model identifier.

    Returns:
        The message content of the first choice in the response.

    Raises:
        RuntimeError: If the ``SAMBANOVA_API_KEY`` environment variable is
            unset or empty.
    """
    api_key = os.environ.get("SAMBANOVA_API_KEY")
    if not api_key:
        # With api_key=None the OpenAI client falls back to OPENAI_API_KEY,
        # which would then be sent to the SambaNova endpoint. Refuse instead.
        raise RuntimeError(
            "SAMBANOVA_API_KEY environment variable is not set; "
            "cannot query the SambaNova API."
        )

    client = openai.OpenAI(
        api_key=api_key,
        base_url="https://api.sambanova.ai/v1",
    )

    response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": prompt,
            }
        ],
        temperature=0.6,  # Meta default
        top_p=0.9,  # Meta default
    )
    time.sleep(2)  # Pause execution for 2 seconds after each request
    return response.choices[0].message.content


def query_llama(prompt):
    """Query ``Meta-Llama-3.1-8B-Instruct`` via SambaNova and return the reply text.

    Raises:
        RuntimeError: If ``SAMBANOVA_API_KEY`` is unset or empty.
    """
    return _query_sambanova(prompt, 'Meta-Llama-3.1-8B-Instruct')


def query_llama_70b(prompt):
    """Query ``Meta-Llama-3.1-70B-Instruct`` via SambaNova and return the reply text.

    Raises:
        RuntimeError: If ``SAMBANOVA_API_KEY`` is unset or empty.
    """
    return _query_sambanova(prompt, 'Meta-Llama-3.1-70B-Instruct')

In [69]:
def query_llm(llm_model: str, ids: List[str], questions: List[str], few_shot_prompt: str, prompt_type: str, save_path: str, already_answered_ids: set) -> Tuple[List[str], List[str], List[str]]:
    """Query one model for every unanswered question, saving each answer as it arrives.

    Args:
        llm_model: One of ``'gemini'``, ``'claude'``, ``'4o'``,
            ``'llama3.18b'`` or ``'llama3.170b'``.
        ids: Question IDs, parallel to ``questions``.
        questions: Question texts.
        few_shot_prompt: Few-shot text forwarded to ``get_prompt``.
        prompt_type: Template name forwarded to ``get_prompt``.
        save_path: CSV file to which each new answer is appended immediately.
        already_answered_ids: IDs to skip (already present in ``save_path``).

    Returns:
        ``(ids, questions, answers)`` for the questions answered in this call.

    Raises:
        ValueError: If ``llm_model`` is not a supported model name. This is
            checked before any question is processed, so a mistyped model
            name fails once instead of being printed as an error per question.
    """
    supported_models = {'gemini', 'claude', '4o', 'llama3.18b', 'llama3.170b'}
    if llm_model not in supported_models:
        raise ValueError(f"Unsupported LLM model: {llm_model}")

    skipped_id = 1146  # weird ID that breaks llama

    answers = []
    ids_can_be_answered = []
    questions_can_be_answered = []

    for qid, q in tqdm(zip(ids, questions), total=len(ids)):
        if qid in already_answered_ids:
            print(f"Skipping already answered ID: {qid}")
            continue
        if qid == skipped_id:
            continue

        prompt = get_prompt(prompt_type, few_shot_prompt, q)
        try:
            # NOTE(review): query_gemini, query_claude and query_4o_multiconvo
            # are not defined in the visible part of this notebook; confirm
            # they exist before using those models. They are only looked up
            # when the matching branch runs.
            if llm_model == 'gemini':
                answer = query_gemini(prompt, qid)
            elif llm_model == 'claude':
                answer = query_claude(prompt)
            elif llm_model == '4o':
                if prompt_type == 'multi_convo':
                    fact_prompt = get_prompt(prompt_type="fact_prompt", few_shot_prompt="", question=q)
                    answer_prompt = get_prompt(prompt_type="answer_prompt_data", few_shot_prompt="", question=q)
                    answer = query_4o_multiconvo(fact_prompt=fact_prompt, answer_prompt=answer_prompt, extracted_question=q)
                else:
                    answer = query_4o(prompt)
            elif llm_model == 'llama3.18b':
                answer = query_llama(prompt)
            elif llm_model == 'llama3.170b':
                answer = query_llama_70b(prompt)

            answers.append(answer)
            questions_can_be_answered.append(q)
            ids_can_be_answered.append(qid)

            # Save after each answer so an interrupted run can resume.
            save_results(save_path, [qid], [q], [answer], append=True)
        except Exception as e:
            # Best effort: report the failure and move on to the next question.
            print(f"Error processing question {qid}: {str(e)}")
            continue

    return ids_can_be_answered, questions_can_be_answered, answers

# Driver

In [72]:
# Datasets whose test split is read from `test.json`.
json_datasets = {
    'logical_deduction_seven_objects',
    'reasoning_about_colored_objects',
    'wikimultihopQA',
}

# Datasets whose test split is read from `test.jsonl`.
jsonl_datasets = {
    'MultiArith',
    'ASDiv',
    'SVAMP',
    'AQUA',
    'p_GSM8K',
    'StrategyQA',
    'commonsenseQA',
    'GSM8K',
    'SPARTQA',
}

In [None]:
def run_model(llm_model, prompt_type, few_shot_txt, sample_size,
              project_root='/Users/log/Github/textual_grounding/', datasets=None):
    """Run one model over a collection of datasets and save answers per dataset.

    For each dataset this loads (and optionally samples) the test questions,
    skips IDs already present in the dataset's results CSV, queries the model
    for the rest and appends each answer to that CSV.

    Args:
        llm_model: Model name forwarded to ``query_llm``.
        prompt_type: Prompt template name forwarded to ``query_llm``.
        few_shot_txt: File name of the few-shot prompt under
            ``<project_root>/prompt/<dataset>/``, or a falsy value for none.
        sample_size: Number of questions to sample per dataset (0 = all).
        project_root: Root directory holding ``data/``, ``prompt/`` and the
            results tree. Defaults to the original hard-coded location.
        datasets: Iterable of dataset names to process. Defaults to
            ``jsonl_datasets``.
    """
    if datasets is None:
        datasets = jsonl_datasets

    for dataset in datasets:
        save_dir = os.path.join(project_root, 'logan/results/final/VanillaCoT', dataset, f'{llm_model}')
        os.makedirs(save_dir, exist_ok=True)  # Ensure the directory exists
        save_path = os.path.join(save_dir, f'{prompt_type}_{few_shot_txt}_{dataset}_{llm_model}.csv')

        # The file extension of the test split depends on the dataset.
        if dataset in json_datasets:
            data_path = os.path.join(project_root, 'data', dataset, 'test.json')
        else:
            data_path = os.path.join(project_root, 'data', dataset, 'test.jsonl')

        ids, questions = load_data_size_specific(data_path, sample_size=sample_size)

        if few_shot_txt:
            fewshot_prompt_path = os.path.join(project_root, "prompt", dataset, few_shot_txt)
            with open(fewshot_prompt_path, 'r') as file:
                few_shot_prompt = file.read()
        else:
            few_shot_prompt = ""

        # Create the CSV on first use, then resume from whatever it contains.
        initialize_save_file(save_path)
        already_answered_ids = load_already_answered_ids(save_path)

        ids_answered, questions_answered, answers = query_llm(
            llm_model=llm_model,
            ids=ids,
            questions=questions,
            few_shot_prompt=few_shot_prompt,
            prompt_type=prompt_type,
            save_path=save_path,
            already_answered_ids=already_answered_ids
        )

        print(f"Processing complete for {dataset}. {len(ids_answered)} new answers saved to {save_path}.")

In [None]:
# Run configuration for this notebook execution.
llm_model = 'llama3.170b'
prompt_type = 'zero_shot_vanilla_cot'
few_shot_txt = None  # no few-shot prompt file
sample_size = 400

# The driver function is named `run_model`; calling `run_models` raised NameError.
run_model(llm_model, prompt_type, few_shot_txt, sample_size)