In [1]:
# | output: false
# | echo: false

%load_ext autoreload
%autoreload 2

In [2]:
# | output: false
# | echo: false

import nest_asyncio

nest_asyncio.apply()

In [3]:
# | output: false
# | echo: false
import asyncio
import json
import os
from asyncio import Semaphore
from enum import Enum
from functools import partial
from textwrap import dedent
from typing import Callable, Dict, List, Literal

import google.generativeai as genai
import instructor
import numpy as np
import pandas as pd
from datasets import load_dataset
from dotenv import load_dotenv
from instructor import Mode
from langsmith import traceable
from pydantic import BaseModel, Field
from scipy import stats

np.random.seed(42)

load_dotenv()

MODEL_NAME = "gemini-1.5-flash-002"
USE_SAMPLE = True
SAMPLE_SIZE = 200
MAX_CONCURRENCY = 50


genai.configure(api_key=os.environ["GEMINI_API_KEY"])

client_tools = instructor.from_gemini(
    client=genai.GenerativeModel(
        model_name=MODEL_NAME,
    ),
    mode=Mode.GEMINI_TOOLS,
    use_async=True,
)

client_json = instructor.from_gemini(
    client=genai.GenerativeModel(
        model_name=MODEL_NAME,
    ),
    mode=Mode.GEMINI_JSON,
    use_async=True,
)

## Setup

In [4]:
# | output: false
# | echo: false


class PromptType(Enum):
    WITH_TOOL_CALLS = "with_so_tool_calls"
    WITH_STRICT_TOOL_CALLS = "with_so_strict_tool_calls"


class ClientConfig(BaseModel):
    name: str
    col_name: str
    score_col_name: str


CONFIGS = [
    ClientConfig(
        name=PromptType.WITH_TOOL_CALLS.value,
        col_name=f"response_{PromptType.WITH_TOOL_CALLS.value}",
        score_col_name=f"score_{PromptType.WITH_TOOL_CALLS.value}",
    ),
    ClientConfig(
        name=PromptType.WITH_STRICT_TOOL_CALLS.value,
        col_name=f"response_{PromptType.WITH_STRICT_TOOL_CALLS.value}",
        score_col_name=f"score_{PromptType.WITH_STRICT_TOOL_CALLS.value}",
    ),
]

In [5]:
# | output: false
# | echo: false


class LLMEvaluator:
    def __init__(
        self,
        configs: List[ClientConfig],
        create_prompt_fn: Callable,
        response_model: BaseModel,
        concurrency: int = MAX_CONCURRENCY,
    ):
        self.configs = configs
        self.create_prompt_fn = create_prompt_fn
        self.response_model = response_model
        self.concurrency = concurrency

        self.error_counts: Dict[str, int] = {config.name: 0 for config in self.configs}
        self.key_order_errors: Dict[str, int] = {
            config.name: 0 for config in self.configs
        }
        self.key_missing_errors: Dict[str, int] = {
            config.name: 0 for config in self.configs
        }

    @traceable(run_type="prompt")
    def create_prompt(
        self,
        question: str,
        prompt_type: str,
    ) -> List[dict]:
        return self.create_prompt_fn(
            question=question,
            prompt_type=prompt_type,
            response_model=self.response_model,
        )

    @traceable(run_type="parser")
    def parse_response(
        self,
        response: BaseModel,
        prompt_type: str,
    ) -> str | int:
        if prompt_type == PromptType.WITH_TOOL_CALLS.value:
            raw_response = str(
                response._raw_response.candidates[0]
                .content.parts[0]
                .function_call.args.pb
            )
        elif prompt_type == PromptType.WITH_STRICT_TOOL_CALLS.value:
            raw_response = response._raw_response.candidates[0].content.parts[0].text

        reasoning_index = raw_response.find("reasoning")
        answer_index = raw_response.find("answer")

        if reasoning_index == -1 or answer_index == -1:
            self.key_missing_errors[prompt_type] += 1
        elif reasoning_index > answer_index:
            self.key_order_errors[prompt_type] += 1

        return response.final_answer

    @traceable(run_type="chain")
    async def call_llm(
        self,
        config: ClientConfig,
        question: str,
    ) -> BaseModel:
        if config.name == PromptType.WITH_TOOL_CALLS.value:
            response = await client_tools.chat.completions.create(
                messages=self.create_prompt(question=question, prompt_type=config.name),
                response_model=self.response_model,
            )
            return response
        elif config.name == PromptType.WITH_STRICT_TOOL_CALLS.value:
            response = await client_json.chat.completions.create(
                messages=self.create_prompt(question=question, prompt_type=config.name),
                response_model=self.response_model,
            )
            return response
        else:
            raise ValueError(f"Invalid config name: {config.name}")

    async def process_question(
        self,
        question: str,
        config: ClientConfig,
        semaphore: Semaphore,
        max_attempts: int = 3,
    ) -> str | int | None:
        async with semaphore:
            for _ in range(max_attempts):
                try:
                    answer = await self.call_llm(
                        config=config,
                        question=question,
                    )
                    parsed_answer = self.parse_response(answer, config.name)
                    if not parsed_answer:
                        self.error_counts[config.name] += 1
                    return parsed_answer
                except Exception:
                    self.error_counts[config.name] += 1
                    print(f"{config.name}, {question[:10]}: Retrying...")
                    await asyncio.sleep(1)
                    continue
            print(
                f"{config.name}, {question[:10]}: Failed to process question after {max_attempts} attempts. Set answer to null."
            )
        return None

    @traceable(run_type="chain")
    async def process_questions(
        self,
        run_name: str,
        questions: List[dict],
        config: ClientConfig,
    ) -> List[str | int | None]:
        semaphore = Semaphore(self.concurrency)
        tasks = [
            self.process_question(
                question=question["question"],
                config=config,
                semaphore=semaphore,
            )
            for question in questions
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        return results

    def generate_outputs(self, questions: List[dict]) -> pd.DataFrame:
        df = pd.DataFrame(
            {
                "id": [i for i in range(len(questions))],
                "question": [question["question"] for question in questions],
                "answer": [question["answer"] for question in questions],
            }
        )
        for config in self.configs:
            responses = asyncio.run(
                self.process_questions(
                    run_name=config.name,
                    questions=questions,
                    config=config,
                )
            )
            df[config.col_name] = responses
        return df

    def evaluate_outputs(self, df: pd.DataFrame) -> pd.DataFrame:
        df_copy = df.copy()
        for config in self.configs:
            df_copy[config.score_col_name] = (
                df_copy["answer"] == df_copy[config.col_name]
            ) * 1
        return df_copy

    def calculate_confidence_intervals(
        self, df: pd.DataFrame, conf_level: float = 0.95
    ) -> None:
        print(
            f"Calculating confidence intervals ({conf_level}) with {len(df)} observations:"
        )
        for config in self.configs:
            score_col = config.score_col_name
            scores = df[score_col]

            if len(scores) == 0:
                print(f"No scores available for {score_col}")
                continue

            mean_score = scores.mean()
            se_score = scores.std() / np.sqrt(len(scores))

            z_score = stats.norm.ppf((1 + conf_level) / 2)
            margin_error = z_score * se_score
            ci = [
                max(0.0, mean_score - margin_error),
                min(1.0, mean_score + margin_error),
            ]
            print(
                f"{score_col} - Mean: {mean_score * 100:.2f}% CI: {ci[0] * 100:.2f}% - {ci[1] * 100:.2f}%"
            )
        print()

    def run_paired_t_test(self, df: pd.DataFrame) -> None:
        scores = {}

        for config in self.configs:
            score_col = config.score_col_name
            scores[score_col] = df[score_col] * 1

        for score_col_1, score_col_2 in [
            ("score_without_so", "score_with_so_tool_calls"),
            ("score_without_so", "score_with_so_strict_tool_calls"),
        ]:
            if score_col_1 in scores and score_col_2 in scores:
                t_stat, p_value = stats.ttest_rel(
                    scores[score_col_1], scores[score_col_2]
                )
                print(f"{score_col_1} vs {score_col_2}")
                print(f"t-statistic: {t_stat}, p-value: {p_value}")

    def report_error_counts(self) -> None:
        print("Error counts:")
        for config in self.configs:
            name = config.name
            errors = self.error_counts.get(name, 0)
            key_order = self.key_order_errors.get(name, 0)
            print(f"- {name}: {errors} processing errors, {key_order} key order errors")
        print()

## GSM8K

### Setup

In [6]:
# | output: false
# | echo: false


class ResponseGSM8K(BaseModel):
    reasoning: str = Field(description="step by step reasoning about the answer")
    answer: int = Field(description="final answer")


def create_prompt_gsm8k(
    prompt_type: str,
    question: str,
    response_model: ResponseGSM8K | None = None,
    zero_shot: bool = False,
) -> str:
    system_prompt = dedent("""
    You are an expert in solving grade school math tasks. You will be presented with a grade-school math word problem and be asked to solve it.

    First, provide your step by step reasoning in the "reasoning" field. Then, in the "answer" field, provide an integer that corresponds to the correct answer to the question. Don't include any other text in the "answer" field. 
    """)

    examples = [
        (
            "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?",
            "There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6.",
            6,
        ),
        (
            "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?",
            "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5.",
            5,
        ),
        (
            "Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?",
            "Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39.",
            39,
        ),
    ]

    if not zero_shot:
        system_prompt += "\nExamples:" if examples else ""
        for i, (example_q, example_reason, example_ans) in enumerate(examples):
            system_prompt += f"\n\n**{i+1}**\nQuestion: {example_q}"
            if prompt_type in [
                PromptType.WITH_TOOL_CALLS.value,
                PromptType.WITH_STRICT_TOOL_CALLS.value,
            ]:
                response = (
                    f'{{"reasoning": "{example_reason}", "answer": {example_ans}}}'
                )
            else:
                response = f"{example_reason}\nANSWER: {example_ans}"
            system_prompt += f"\nAssistant Response:\n{response}"

    messages = [
        {
            "role": "system",
            "content": system_prompt,
        },
        {
            "role": "user",
            "content": f"Question: {question}",
        },
    ]

    return messages


create_prompt_gsm8k_zero_shot = partial(create_prompt_gsm8k, zero_shot=True)

In [7]:
# | output: false
# | echo: false

dataset = load_dataset("gsm8k", "main")
evals = [
    {
        "question": d["question"],
        "answer": int(d["answer"].split("#### ")[1].replace(",", "").strip()),
    }
    for d in dataset["test"]
]

if USE_SAMPLE:
    evals = evals[:SAMPLE_SIZE]

### Zero-shot

In [8]:
# | output: false
# | echo: false

evaluator = LLMEvaluator(
    configs=CONFIGS,
    create_prompt_fn=create_prompt_gsm8k_zero_shot,
    response_model=ResponseGSM8K,
)

df = evaluator.generate_outputs(evals)
df_results = evaluator.evaluate_outputs(df)

evaluator.report_error_counts()
evaluator.calculate_confidence_intervals(df_results)
evaluator.run_paired_t_test(df_results)

Error counts:
- with_so_tool_calls: 1 processing errors, 0 key order errors
- with_so_strict_tool_calls: 2 processing errors, 0 key order errors

Calculating confidence intervals (0.95) with 200 observations:
score_with_so_tool_calls - Mean: 39.50% CI: 32.71% - 46.29%
score_with_so_strict_tool_calls - Mean: 95.00% CI: 91.97% - 98.03%

