In [96]:
from groq import Groq
import os
import sys
import anthropic
import ollama
import random
import pandas as pd
from tqdm import tqdm
from google.generativeai.types import RequestOptions
from google.api_core import retry
from typing import List, Tuple
import json
from openai import OpenAI
import datetime
import openai
import time
import re

current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)

if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from concurrent.futures import ThreadPoolExecutor, TimeoutError

# prompts

### Zero Shot - Vanilla CoT

In [97]:
zero_shot_vanilla_cot = """
Think through your answer step by step. Put the concise form of your final answer in curly brackets e.g. {A}, {True} or {3.0}.
"""

In [98]:
import re

def extract_last_sentence(text):
    # Split the text at the point where multiple-choice options start (e.g., "A)", "B)", etc.)
    parts = re.split(r'\n[A-Z]\)', text, flags=re.MULTILINE)
    
    if parts:
        # The first part contains everything before the options
        pre_options = parts[0]
        
        # Use a regex to find all sentences ending with ., ?, or !
        sentences = re.findall(r'[^.!?]*[.!?]', pre_options, re.DOTALL)
        
        if sentences:
            # Return the last sentence after stripping any extra spaces
            return sentences[-1].strip()
    
    # If splitting or finding sentences fails, return the original text
    return text

In [99]:
# import re

# def extract_last_sentence(text):
#     # Split the text into sentences, excluding option markers like "A)", "B)", etc.
#     sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|!)(?=\s[A-Z]\))', text)
    
#     # Return the last sentence after stripping any extra spaces.
#     return sentences[-1].strip() if sentences else text

In [100]:
def get_prompt(prompt_type: str, few_shot_prompt: str, question: str) -> str:
    # print(question)
    last_sentence = extract_last_sentence(question)
    
    # instruction = f"I want you to answer this question but your explanation should contain references referring back to the information in the question. To do that, first, re-generate the question with proper tags for key phrases, the key phrases that are most relevant to answering the question {last_sentence} and then generate your answers. The output format is as follow:\n\"
    
        # CHANGE FOR MATH
    instruction = f"I want you to answer this question but your explanation should contain references referring back to the information in the question. To do that, first, re-generate the question with proper tags for key phrases, the key phrases that are most relevant to answering the question and then generate your answers. The output format is as follow:\n\
                Reformatted Question: \
                    Answer:"
                    
    prompts = {
        "zero_shot_vanilla_cot": f"{question}\n{zero_shot_vanilla_cot}",
        "echo_fewshot": f"{few_shot_prompt}\n\nQuestion:{question}\nRepeat the question and then think through your answer step by step. Put the concise form of your final answer in curly brackets e.g. {{A}}, {{True}} or {{3.0}}.",
        "gcot": f"{few_shot_prompt}\n{question}\n{instruction}",
        "cot": f"{few_shot_prompt}\n{question}\nPlease generate your explanation first, then generate the final answer in the bracket as follow:\nAnswer: {{}}"
    }
    return prompts[prompt_type]

def save_results(save_path: str, ids: List[str], questions: List[str], answers: List[str], append: bool = False):
    df = pd.DataFrame({'id': ids, 'question': questions, 'answer': answers})
    if append and os.path.exists(save_path):
        df.to_csv(save_path, mode='a', index=False, header=False)
    else:
        df.to_csv(save_path, index=False)

def read_jsonl_file(filepath: str) -> List[dict]:
    data = []
    with open(filepath, 'r') as file:
        for line in file:
            json_obj = json.loads(line)
            data.append(json_obj)
    return data

def load_already_answered_ids(save_path: str) -> set:
    if os.path.exists(save_path):
        df = pd.read_csv(save_path)
        answered_ids = set(df['id'].tolist())
        # print(f"Loaded {len(answered_ids)} already answered IDs from: {save_path}")
        print(f"Already answered IDs: {answered_ids}")
        return answered_ids
    else:
        print(f"No existing save file found at: {save_path}. Starting fresh.")
        return set()

def initialize_save_file(save_path: str):
    if not os.path.exists(save_path):
        # Create an empty DataFrame with headers and save
        df = pd.DataFrame(columns=['id', 'question', 'answer'])
        df.to_csv(save_path, index=False)
        print(f"Initialized new save file with headers at: {save_path}")

def load_data_size_specific(data_path: str, sample_size: int = 0, random_seed: int = 0):
    random.seed(random_seed)

    if data_path.endswith('.jsonl'):
        data = read_jsonl_file(data_path)
    elif data_path.endswith('.json'):
        with open(data_path, 'r') as file:
            data = json.load(file)
    
    question_length = 0
    eligible_data = [x for x in data if len(x["question"]) >= question_length]
    
    if sample_size > 0 and sample_size < len(eligible_data):
        sampled_data = random.sample(eligible_data, sample_size)
    else:
        sampled_data = eligible_data
    
    # ids = [x["id"] for x in sampled_data]
    ids = [x["unique_id"] for x in sampled_data]
    questions = [x["question"] for x in sampled_data]
    
    return ids, questions

        
def get_longest_questions_and_ids(data_path, sample_size=200):
    # data = read_jsonl_file(data_path)
    if data_path.endswith('.json'):
        with open(data_path, 'r') as file:
            data = json.load(file)
    full_ids = [x["id"] for x in data]
    full_questions = [x["question"] for x in data]

    # Combine questions and IDs into a list of tuples
    full_questions_ids = list(zip(full_questions, full_ids))
    
    # Sort the tuples by the length of the questions
    sorted_full_questions_ids = sorted(full_questions_ids, key=lambda x: len(x[0]), reverse=True)
    
    # Select the shortest questions and their IDs
    sampled_data = random.sample(data, sample_size)
    longest_ = sorted_full_questions_ids[:min(sample_size, len(sorted_full_questions_ids))]

    # Separate them back into two lists
    longest_questions, longest_ids = zip(*longest_)

    # Convert to lists if necessary
    longest_questions = list(longest_questions)
    longest_ids = list(longest_ids)
    
    return longest_ids, longest_questions

def get_shortest_questions_and_ids(data_path, sample_size=200):
    if data_path.endswith('.json'):
        with open(data_path, 'r') as file:
            data = json.load(file)
    else:
        data = read_jsonl_file(data_path)
    full_ids = [x["id"] for x in data]
    full_questions = [x["question"] for x in data]

    # Combine questions and IDs into a list of tuples
    full_questions_ids = list(zip(full_questions, full_ids))
    
    # Sort the tuples by the length of the questions in ascending order
    sorted_full_questions_ids = sorted(full_questions_ids, key=lambda x: len(x[0]))
    
    # Select the shortest questions and their IDs
    sampled_data = random.sample(data, sample_size)
    shortest_ = sorted_full_questions_ids[:min(sample_size, len(sorted_full_questions_ids))]

    # Separate them back into two lists
    shortest_questions, shortest_ids = zip(*shortest_)

    # Convert to lists if necessary
    shortest_questions = list(shortest_questions)
    shortest_ids = list(shortest_ids)
    
    return shortest_ids, shortest_questions

    
def get_random_questions_and_ids(data_path, sample_size=200):

    data = read_jsonl_file(data_path)

    longest_questions, longest_ids = get_longest_questions_and_ids(data_path, sample_size)
    
    result_ids = []
    result_questions = []
    # Select the shortest questions and their IDs
    sampled_data = random.sample(data, sample_size*2)
    totalQuestions = 0
    for x in sampled_data:
        if totalQuestions >= sample_size:
            break
        if x["question"] not in longest_questions:
            result_ids.append(x["id"])
            result_questions.append(x["question"])
            totalQuestions += 1
    
    return result_ids, result_questions

In [101]:
def query_4o(prompt: str) -> str:
    client = OpenAI()

    completion = client.chat.completions.create(
        model="gpt-4o-2024-08-06",
        messages=[
            {
                "role": "user",
                "content": f"{prompt}"
            }
        ],
        temperature=0
    )

    return completion.choices[0].message.content

def query_llama(prompt):
    client = openai.OpenAI(
        api_key=os.environ.get("SAMBANOVA_API_KEY"),
        base_url="https://api.sambanova.ai/v1",
    )

    response = client.chat.completions.create(
        model='Meta-Llama-3.1-8B-Instruct',
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
        temperature=0.6, # Meta default
        top_p = 0.9 # Meta default
    )
    time.sleep(2)  # Pause execution for 2 seconds
    return response.choices[0].message.content

def query_llama_70b(prompt):
    client = openai.OpenAI(
        api_key=os.environ.get("SAMBANOVA_API_KEY"),
        base_url="https://api.sambanova.ai/v1",
    )

    response = client.chat.completions.create(
        model='Meta-Llama-3.1-70B-Instruct',
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
        temperature=0.6, # Meta default
        top_p = 0.9 # Meta default
    )
    time.sleep(2)  # Pause execution for 2 seconds
    return response.choices[0].message.content

def query_llama_405b(prompt):
    client = openai.OpenAI(
        api_key=os.environ.get("SAMBANOVA_API_KEY"),
        base_url="https://api.sambanova.ai/v1",
    )

    response = client.chat.completions.create(
        model='Meta-Llama-3.1-405B-Instruct',
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
        temperature=0.6, # Meta default
        top_p = 0.9 # Meta default
    )
    time.sleep(7)  # Pause execution for 2 seconds
    return response.choices[0].message.content

In [102]:
def query_llm(llm_model: str, ids: List[str], questions: List[str], few_shot_prompt: str, prompt_type: str, save_path: str, already_answered_ids: set) -> Tuple[List[str], List[str], List[str]]:
    answers = []
    ids_can_be_answered = []
    questions_can_be_answered = []
    
    for id, q in tqdm(zip(ids, questions), total=len(ids)):
        # print(q)
        # print(f"Processing ID: {id}")
        if id in already_answered_ids:
            print(f"Skipping: {id}", end=' ')
            continue
        # if id == 1146: # weird ID that breaks llama
        #     continue
        
        # print(few_shot_prompt)
        prompt = get_prompt(prompt_type, few_shot_prompt, q)
        # print(prompt)
        try:
            # print(prompt)
            if llm_model == 'gemini':
                answer = query_gemini(prompt, id)
            elif llm_model == 'claude':
                answer = query_claude(prompt)
            elif llm_model == '4o':
                # answer = query_4o_multiturn(prompt)
                if prompt_type == 'multi_convo':
                    fact_prompt = get_prompt(prompt_type="fact_prompt", few_shot_prompt="", question=q)
                    
                    answer_prompt = get_prompt(prompt_type="answer_prompt_data", few_shot_prompt="", question=q)
                    answer = query_4o_multiconvo(fact_prompt=fact_prompt, answer_prompt=answer_prompt, extracted_question=q)
                else:
                    answer = query_4o(prompt)
                
            elif llm_model == 'llama3.18b':
                answer = query_llama(prompt)
            elif llm_model == 'llama3.170b':
                answer = query_llama_70b(prompt)
            elif llm_model == 'llama3.1405b':
                answer = query_llama_405b(prompt)
            else:
                raise ValueError(f"Unsupported LLM model: {llm_model}")
            # print(f"Answer for ID {id}: {answer}")
            
            answers.append(answer)
            questions_can_be_answered.append(q)
            ids_can_be_answered.append(id)

            # Save after each answer
            save_results(save_path, [id], [q], [answer], append=True)
        except Exception as e:
            print(f"Error processing question {id}: {str(e)}")
            continue
    
    return ids_can_be_answered, questions_can_be_answered, answers

# Driver

In [103]:
json_datasets = ['logical_deduction_seven_objects','reasoning_about_colored_objects', 'squad']
jsonl_datasets = ['GSM8K', 'date', 'GSM_Plus', 'MultiArith', 'ASDiv', 'SVAMP', 'AQUA', 'p_GSM8K', 'StrategyQA', 'commonsenseQA','SPARTQA']
all_datasets = jsonl_datasets + json_datasets

In [104]:
def run_model(llm_model, prompt_type, few_shot_txt, sample_size, project_root, identifier, isRandom = False, isLongest = False, isShortest = False):
    # for dataset in all_datasets:
    for dataset in ['GSM_Symbolic']:
        print(f"------Processing dataset: {dataset}-------")
        if few_shot_txt:
            # fewshot_prompt_path = os.path.join(project_root, "prompt", dataset, few_shot_txt)
            fewshot_prompt_path = '/Users/log/Github/textual_grounding/prompt/GSM8K/fewshot_design_1_v4.txt'
        # print(fewshot_prompt_path)
        # continue 
        if prompt_type == 'zero_shot_vanilla_cot':
            save_dir = os.path.join(project_root, 'logan/results/final/VanillaCoT', dataset, f'{llm_model}')
        elif prompt_type == 'gcot':
            save_dir = os.path.join(project_root, 'logan/results/final/GCoT', dataset, f'{llm_model}')
        elif prompt_type == 'cot':
            save_dir = os.path.join(project_root, 'logan/results/final/fewshot_CoT', dataset, f'{llm_model}')
        else:
            save_dir = os.path.join(project_root, 'logan/results', dataset, f'{llm_model}')
        os.makedirs(save_dir, exist_ok=True)  # Ensure the directory exists
        save_path = os.path.join(save_dir, f'{prompt_type}_{identifier}_{few_shot_txt}_{dataset}_{llm_model}.csv')

        if dataset in json_datasets:
            data_path = os.path.join(project_root, 'data', dataset, 'test.json')
        else:
            # data_path = os.path.join(project_root, 'data', dataset, 'test.jsonl')
            data_path = '/Users/log/Github/textual_grounding/data/gsm_symbolic/main_formatted_test.jsonl'


        if isRandom:
            ids, questions = get_random_questions_and_ids(data_path, sample_size=sample_size)
        elif isLongest:
            ids, questions = get_longest_questions_and_ids(data_path, sample_size=sample_size)
        elif isShortest:
            ids, questions = get_shortest_questions_and_ids(data_path, sample_size=sample_size)
        else:
            ids, questions = load_data_size_specific(data_path, sample_size=sample_size)
        print(sorted(ids))

        if few_shot_txt:
            with open(fewshot_prompt_path, 'r') as file:
                few_shot_prompt = file.read()
        else:
            few_shot_prompt = ""

        initialize_save_file(save_path)
        already_answered_ids = load_already_answered_ids(save_path)

        ids_answered, questions_answered, answers = query_llm(
            llm_model=llm_model,
            ids=ids,
            questions=questions,
            few_shot_prompt=few_shot_prompt,
            prompt_type=prompt_type,
            save_path=save_path,
            already_answered_ids=already_answered_ids
        )

        print(f"Processing complete for {dataset}. {len(ids_answered)} new answers saved to {save_path}.")

In [None]:
project_root = '/Users/log/Github/textual_grounding/'
prompt_type = 'gcot'
few_shot_txt = '_'
sample_size = 5000
isLongest = False 
isRandom = False
isShortest = False
identifier = 'main_test'

# models = ['llama3.18b', 'llama3.170b', 'llama3.1405b']
models = ['llama3.18b']
for model in models:
    run_model(model, prompt_type, few_shot_txt, sample_size, project_root, 
              identifier=identifier, 
              isLongest=isLongest, 
              isRandom=isRandom,
              isShortest=isShortest)

------Processing dataset: GSM_Symbolic-------
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212

  0%|          | 0/5000 [00:00<?, ?it/s]

Skipping: 0 Skipping: 1 Skipping: 2 Skipping: 3 Skipping: 4 Skipping: 5 Skipping: 6 Skipping: 7 Skipping: 8 Skipping: 9 Skipping: 10 Skipping: 11 Skipping: 12 Skipping: 13 Skipping: 14 Skipping: 15 Skipping: 16 Skipping: 17 Skipping: 18 Skipping: 19 Skipping: 20 Skipping: 21 Skipping: 22 Skipping: 23 Skipping: 24 Skipping: 25 Skipping: 26 Skipping: 27 Skipping: 28 Skipping: 29 Skipping: 30 Skipping: 31 Skipping: 32 Skipping: 33 Skipping: 34 Skipping: 35 Skipping: 36 Skipping: 37 Skipping: 38 Skipping: 39 Skipping: 40 Skipping: 41 Skipping: 42 Skipping: 43 Skipping: 44 Skipping: 45 Skipping: 46 Skipping: 47 Skipping: 48 Skipping: 49 Skipping: 50 Skipping: 51 Skipping: 52 Skipping: 53 Skipping: 54 Skipping: 55 Skipping: 56 Skipping: 57 Skipping: 58 Skipping: 59 Skipping: 60 Skipping: 61 Skipping: 62 Skipping: 63 Skipping: 64 Skipping: 65 Skipping: 66 Skipping: 67 Skipping: 68 Skipping: 69 Skipping: 70 Skipping: 71 Skipping: 72 Skipping: 73 Skipping: 74 Skipping: 75 Skipping: 76 Skipping:

 77%|███████▋  | 3826/5000 [3:51:14<1:21:16,  4.15s/it] 

## Batch

In [None]:
import os
import json
from pathlib import Path
from agents.batch_api_agents import prepare_batch_input, batch_api_agent
from openai import OpenAI
import importlib
import agents.batch_api_agents as batch_agents

# Reload the module to ensure the latest changes are loaded
importlib.reload(batch_agents)

llm_model = 'claude-3-5-sonnet-20240620'
project_root = '/Users/log/Github/textual_grounding/'
prompt_type = 'zero_shot_vanilla_cot'
few_shot_txt = None
sample_size = 2
json_datasets = ['logical_deduction_seven_objects','reasoning_about_colored_objects']
jsonl_datasets = ['GSM_Plus', 'MultiArith', 'SVAMP', 'p_GSM8K', 'StrategyQA', 'commonsenseQA','SPARTQA']
all_datasets = jsonl_datasets + json_datasets
all_datasets = ['MultiArith']
for dataset in all_datasets:
    if dataset in json_datasets:
        data_path = os.path.join(project_root, 'data', dataset, 'test.json')
    else:
        data_path = os.path.join(project_root, 'data', dataset, 'test.jsonl')
        
    save_dir = os.path.join(project_root, 'logan/results/final/VanillaCoT', dataset, f'{llm_model}')
    batch_dir = os.path.join(project_root, 'logan/batch_files/VanillaCoT', dataset, f'{llm_model}')
    os.makedirs(save_dir, exist_ok=True)  # Ensure the directory exists
    os.makedirs(batch_dir, exist_ok=True)  # Ensure the directory exists

    batch_results_path = os.path.join(save_dir, f'{prompt_type}_{few_shot_txt}_{dataset}_{llm_model}.jsonl')   
    batch_output_file = os.path.join(batch_dir, f'{prompt_type}_{few_shot_txt}_{dataset}_{llm_model}.jsonl')   

    ids, questions = load_data_size_specific(data_path, sample_size=sample_size)
    if few_shot_txt:
        with open(fewshot_prompt_path, 'r') as file:
            few_shot_prompt = file.read()
    else:
        few_shot_prompt = ""
    prompts = []
    for question in questions:
        prompt = get_prompt(prompt_type, few_shot_prompt, question)
        # print(prompt)
        prompts.append(prompt)

    # tasks = batch_agents.prepare_batch_input(
    #     llm_model=llm_model,
    #     ids=ids,
    #     prompts=prompts,
    #     batch_output_file=batch_output_file
    # )

    #Execute the batch processing with GPT-4 and save results
    batch_agents.batch_api_agent(
        llm_model=llm_model,
        ids=ids,
        prompts=prompts,
        batch_output_file=batch_output_file,
        batch_results_file=batch_results_path
    )
