## Loading Dataset

In [1]:
from datasets import load_dataset
import pandas as pd

# Fetch the validation split of the HotpotQA "distractor" configuration.
# `cache_dir` points the download cache at a fixed local directory.
# NOTE(review): this is a machine-specific absolute path — confirm it exists
# wherever the notebook is re-run.
dataset = load_dataset(
    "hotpot_qa",
    "distractor",
    split="validation",
    cache_dir="/mnt/d/datasets/hotpot_qa",
)

# Materialise the split as a DataFrame and keep a 300-row random subset;
# the fixed seed makes the same rows come back on every run.
hf_df = pd.DataFrame(dataset).sample(n=300, random_state=42)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import os
from dotenv import load_dotenv, find_dotenv

# Load variables from the .env file located by find_dotenv() into the process
# environment.
# NOTE(review): this call sits before the project/LLM imports below —
# presumably so API keys are already set when those modules are imported and
# the clients are built; confirm before reordering.
load_dotenv(find_dotenv())

from agents.prompt_and_workflow_orchestration.orchestration import OrchestrationAgent
import pandas as pd
from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.callbacks import UsageMetadataCallbackHandler


# Model used by every LLM role in this experiment; also keys the output file.
MODEL_NAME = "gemini-2.0-flash"
EXPERIMENT_NAME = f"hotpot_qa_cot_{MODEL_NAME}"

# Results go to <parent of the working directory>/data/generated/<experiment>.parquet.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
generated_data_path = os.path.join(
    project_root, "data", "generated", f"{EXPERIMENT_NAME}.parquet"
)

# Alternative planner backed by a local Ollama model:
# planner_llm = ChatOllama(model="qwen3:8b", temperature=0.6)

# One Gemini client per role; the clients differ only in sampling temperature
# (planner 0.7, high 0.8, medium 0.5, low 0.2).
planner_llm, high_temp_llm, medium_temp_llm, low_temp_llm = (
    ChatGoogleGenerativeAI(model=MODEL_NAME, temperature=temperature)
    for temperature in (0.7, 0.8, 0.5, 0.2)
)

# Wire the four clients into the orchestration agent under test.
agent = OrchestrationAgent(
    planner_llm=planner_llm,
    high_temp_llm=high_temp_llm,
    medium_temp_llm=medium_temp_llm,
    low_temp_llm=low_temp_llm,
)

E0000 00:00:1760067705.783884   26615 alts_credentials.cc:93] ALTS creds ignored. Not running on GCP and untrusted ALTS is not enabled.
E0000 00:00:1760067705.790827   26615 alts_credentials.cc:93] ALTS creds ignored. Not running on GCP and untrusted ALTS is not enabled.
E0000 00:00:1760067705.792177   26615 alts_credentials.cc:93] ALTS creds ignored. Not running on GCP and untrusted ALTS is not enabled.
E0000 00:00:1760067705.793620   26615 alts_credentials.cc:93] ALTS creds ignored. Not running on GCP and untrusted ALTS is not enabled.


## Helper Functions

In [3]:
def calculate_cost(model_name: str, input_tokens: int, output_tokens: int) -> float:
    """
    Estimate the USD cost of Gemini usage from token counts.

    Rates are USD per one million tokens, from public pricing as of mid-2025.

    Args:
        model_name: Model identifier (case-insensitive). Surrounding
            whitespace, a leading "models/" and a trailing revision alias
            ("-001"-style three-digit suffix or "-latest") are ignored, so
            "models/gemini-2.0-flash-001" is priced as "gemini-2.0-flash".
        input_tokens: Number of input (prompt) tokens.
        output_tokens: Number of output (generated) tokens.

    Returns:
        The estimated cost in USD. For a model that is not in the pricing
        table a warning is printed and 0.0 is returned.

    Note:
        For gemini-2.5-pro the higher tier is chosen when ``input_tokens``
        exceeds 200,000.
        NOTE(review): the published tier is defined per prompt; if callers
        pass totals summed over several requests, the tier choice may be
        off — confirm against how usage is aggregated.
    """
    # Flat-rate models, keyed by canonical (normalised) model name.
    pricing = {
        "gemini-2.0-flash": {"input": 0.10, "output": 0.40},
        "gemini-2.5-flash-lite": {"input": 0.10, "output": 0.40},
    }
    # gemini-2.5-pro is tiered on input size, so it is handled separately.
    pro_standard_rates = {"input": 1.25, "output": 10.00}
    pro_long_context_rates = {"input": 2.50, "output": 15.00}

    # Normalise the name so equivalent spellings of one model share a price:
    # case, padding, the API resource prefix, and revision aliases.
    m = model_name.strip().lower()
    if m.startswith("models/"):
        m = m[len("models/"):]
    base, sep, suffix = m.rpartition("-")
    if sep and (suffix == "latest" or (len(suffix) == 3 and suffix.isdigit())):
        m = base

    if m.startswith("gemini-2.5-pro"):
        # Any gemini-2.5-pro variant; tier depends on the input token count.
        rates = pro_long_context_rates if input_tokens > 200_000 else pro_standard_rates
    elif m in pricing:
        rates = pricing[m]
    else:
        # Unknown model: report it and fall back to zero cost rather than fail.
        print(f"Warning: Unknown model '{model_name}'. Using zero cost.")
        rates = {"input": 0.0, "output": 0.0}

    return (input_tokens / 1_000_000) * rates["input"] + (output_tokens / 1_000_000) * rates["output"]


## Generating the Evaluation Dataset


In [4]:
from contextlib import redirect_stdout
from io import StringIO
import time
from tqdm import tqdm

generated_dataset = []
hf_df = hf_df.reset_index(drop=True)

for i, item in tqdm(hf_df.iterrows(), total=len(hf_df), desc="Processing"):
    question = item["question"]
    answer = item["answer"]
    context = item["context"]["sentences"]

    start_time = time.time()

    callback = UsageMetadataCallbackHandler()

    try:
        # Silence the agent's output
        with redirect_stdout(StringIO()):
            response = await agent.generate_response_async(query=question, context=context, callback=callback)

        end_time = time.time()
        latency = end_time - start_time

        input_tokens = callback.usage_metadata[MODEL_NAME]["input_tokens"]
        output_tokens = callback.usage_metadata[MODEL_NAME]["output_tokens"]
        total_tokens = callback.usage_metadata[MODEL_NAME]["total_tokens"]
        cost = calculate_cost(MODEL_NAME, input_tokens, output_tokens)

        generated_dataset.append(
            {
                "user_input": question,
                "contexts": [str(item) for item in context],
                "response": response["content"].strip(),
                "ground_truth": answer,
                "workflow_plan": response["workflow_plan"],
                "planner_reasoning": response["planner_reasoning"],
                "custom_prompts": response["custom_prompts"],
                "latency": latency,
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "total_tokens": total_tokens,
                "cost": cost,
            }
        )
    
    except Exception as e:
        # Catch any other unexpected errors and continue
        print(f"Unexpected error for item {i}: {e}")

        # Add entry with error information
        generated_dataset.append(
            {
                "user_input": question,
                "contexts": [str(item) for item in context],
                "response": f"ERROR: {str(e)}",
                "ground_truth": answer,
                "workflow_plan": None,
                "planner_reasoning": None,
                "custom_prompts": None,
                "latency": 0,
                "input_tokens": 0,
                "output_tokens": 0,
                "total_tokens": 0,
                "cost": 0,
            }
        )


df = pd.DataFrame(generated_dataset)
df.to_parquet(generated_data_path, index=False)

Processing:   0%|          | 0/300 [00:00<?, ?it/s]E0000 00:00:1760067705.839411   26615 alts_credentials.cc:93] ALTS creds ignored. Not running on GCP and untrusted ALTS is not enabled.
Processing:   3%|▎         | 10/300 [03:42<2:07:10, 26.31s/it]

Unexpected error for item 9: 'NoneType' object has no attribute 'reasoning'


Processing:  17%|█▋        | 50/300 [18:38<1:38:11, 23.57s/it]

Unexpected error for item 49: 'NoneType' object has no attribute 'reasoning'


Processing: 100%|██████████| 300/300 [1:51:54<00:00, 22.38s/it]  
