In [1]:
# | output: false
# | echo: false

%load_ext autoreload
%autoreload 2

In [2]:
# | output: false
# | echo: false

import nest_asyncio

nest_asyncio.apply()

In [3]:
# | output: false
# | echo: false
import asyncio
import json
import os
from asyncio import Semaphore
from enum import Enum
from functools import partial
from textwrap import dedent
from typing import Callable, Dict, List, Literal

import google.generativeai as genai
import numpy as np
import pandas as pd
from datasets import load_dataset
from dotenv import load_dotenv
from langsmith import traceable
from pydantic import BaseModel, Field
from scipy import stats

np.random.seed(42)

load_dotenv()

MODEL_NAME = "gemini-1.5-flash-002"
USE_SAMPLE = False
SAMPLE_SIZE = 5
MAX_CONCURRENCY = 20

genai.configure(api_key=os.environ["GEMINI_API_KEY"])

## Setup

In [4]:
# | output: false
# | echo: false


class PromptType(Enum):
    WITHOUT_STRUCTURED_OUTPUT = "without_so"
    WITH_FUNCTION_CALL = "with_so_function_call"
    WITH_RESPONSE_SCHEMA = "with_so_response_schema"
    WITH_MIME_TYPE = "with_so_mime_type"


class ClientConfig(BaseModel):
    name: str
    col_name: str
    score_col_name: str


CONFIGS = [
    ClientConfig(
        name=PromptType.WITHOUT_STRUCTURED_OUTPUT.value,
        col_name=f"response_{PromptType.WITHOUT_STRUCTURED_OUTPUT.value}",
        score_col_name=f"score_{PromptType.WITHOUT_STRUCTURED_OUTPUT.value}",
    ),
    # ClientConfig(
    #     name=PromptType.WITH_FUNCTION_CALL.value,
    #     col_name=f"response_{PromptType.WITH_FUNCTION_CALL.value}",
    #     score_col_name=f"score_{PromptType.WITH_FUNCTION_CALL.value}",
    # ),
    ClientConfig(
        name=PromptType.WITH_RESPONSE_SCHEMA.value,
        col_name=f"response_{PromptType.WITH_RESPONSE_SCHEMA.value}",
        score_col_name=f"score_{PromptType.WITH_RESPONSE_SCHEMA.value}",
    ),
    ClientConfig(
        name=PromptType.WITH_MIME_TYPE.value,
        col_name=f"response_{PromptType.WITH_MIME_TYPE.value}",
        score_col_name=f"score_{PromptType.WITH_MIME_TYPE.value}",
    ),
]

In [5]:
# | output: false
# | echo: false


class LLMEvaluator:
    def __init__(
        self,
        configs: List[ClientConfig],
        create_prompt_fn: Callable,
        parse_response_fn: Callable,
        response_model: BaseModel,
        concurrency: int = MAX_CONCURRENCY,
    ):
        self.configs = configs
        self.create_prompt_fn = create_prompt_fn
        self.parse_response_fn = parse_response_fn
        self.response_model = response_model
        self.concurrency = concurrency

        self.error_counts: Dict[str, int] = {config.name: 0 for config in self.configs}
        self.key_order_errors: Dict[str, int] = {
            config.name: 0 for config in self.configs
        }
        self.key_missing_errors: Dict[str, int] = {
            config.name: 0 for config in self.configs
        }

    def _create_tool_schema(self, config_name: str) -> List[genai.protos.Tool]:
        model_schema = self.response_model.model_json_schema()
        properties = {}
        for key, value in model_schema.get("properties", {}).items():
            properties[key] = genai.protos.Schema(
                type=genai.protos.Type.STRING
                if value.get("type") == "string"
                else genai.protos.Type.INTEGER
                if value.get("type") == "integer"
                else genai.protos.Type.NUMBER,
            )

        if config_name == PromptType.WITH_RESPONSE_SCHEMA.value:
            return genai.protos.Schema(
                type=genai.protos.Type.OBJECT,
                properties=properties,
                required=list(properties.keys()),
            )
        elif config_name == PromptType.WITH_FUNCTION_CALL.value:
            return genai.protos.Tool(
                function_declarations=[
                    genai.protos.FunctionDeclaration(
                        name=model_schema["title"],
                        description=f"Correctly extracted `{model_schema['title']}` with all the required parameters",
                        parameters=genai.protos.Schema(
                            type=genai.protos.Type.OBJECT,
                            properties=properties,
                            required=list(properties.keys()),
                        ),
                    )
                ],
            )
        else:
            raise ValueError(f"Invalid config name: {config_name}")

    @traceable(run_type="prompt")
    def create_prompt(
        self,
        prompt_type: str,
    ) -> str:
        return self.create_prompt_fn(
            prompt_type=prompt_type,
            response_model=self.response_model,
        )

    @traceable(run_type="parser")
    def parse_response(
        self,
        response: str,
        prompt_type: str,
    ) -> str | int:
        if prompt_type in [
            PromptType.WITH_FUNCTION_CALL.value,
            PromptType.WITH_RESPONSE_SCHEMA.value,
            PromptType.WITH_MIME_TYPE.value,
        ]:
            reasoning_index = response.find("reasoning")
            answer_index = response.find("solution")

            if reasoning_index == -1 or answer_index == -1:
                self.key_missing_errors[prompt_type] += 1
            elif reasoning_index > answer_index:
                self.key_order_errors[prompt_type] += 1

        return self.parse_response_fn(response, prompt_type)

    @traceable(run_type="llm")
    async def call_gemini(
        self,
        question: str,
        **kwargs,
    ) -> str:
        model = genai.GenerativeModel(**kwargs)
        response = await model.generate_content_async(question)
        return response

    @traceable(run_type="chain")
    async def call_llm(
        self,
        config: ClientConfig,
        question: str,
    ) -> str:
        system_prompt = self.create_prompt(prompt_type=config.name)
        if config.name == PromptType.WITHOUT_STRUCTURED_OUTPUT.value:
            response = await self.call_gemini(
                question=question,
                model_name=MODEL_NAME,
                generation_config=genai.GenerationConfig(
                    response_mime_type="text/plain"
                ),
                system_instruction=system_prompt,
            )
            return response.text
        elif config.name == PromptType.WITH_MIME_TYPE.value:
            response = await self.call_gemini(
                question=question,
                model_name=MODEL_NAME,
                generation_config=genai.GenerationConfig(
                    response_mime_type="application/json",
                ),
                system_instruction=system_prompt,
            )
            return response.text
        elif config.name == PromptType.WITH_FUNCTION_CALL.value:
            response = await self.call_gemini(
                question=question,
                model_name=MODEL_NAME,
                generation_config=genai.GenerationConfig(
                    response_mime_type="text/plain",
                ),
                tools=[self._create_tool_schema(config.name)],
                tool_config={"function_calling_config": "ANY"},
                system_instruction=system_prompt,
            )
            for part in response.parts:
                if fn := part.function_call:
                    return json.dumps(dict(fn.args))
            return None
        elif config.name == PromptType.WITH_RESPONSE_SCHEMA.value:
            response = await self.call_gemini(
                question=question,
                model_name=MODEL_NAME,
                generation_config=genai.GenerationConfig(
                    response_mime_type="application/json",
                    response_schema=self._create_tool_schema(config.name),
                ),
                system_instruction=system_prompt,
            )
            return response.text
        else:
            raise ValueError(f"Invalid config name: {config.name}")

    async def process_question(
        self,
        question: str,
        config: ClientConfig,
        semaphore: Semaphore,
        max_attempts: int = 3,
    ) -> str | int | None:
        async with semaphore:
            for _ in range(max_attempts):
                try:
                    answer = await self.call_llm(
                        config=config,
                        question=question,
                    )
                    parsed_answer = self.parse_response(answer, config.name)
                    if not parsed_answer:
                        self.error_counts[config.name] += 1
                    return parsed_answer
                except Exception:
                    self.error_counts[config.name] += 1
                    print(f"{config.name}, {question[:10]}: Retrying...")
                    await asyncio.sleep(1)
                    continue
            print(
                f"{config.name}, {question[:10]}: Failed to process question after {max_attempts} attempts. Set answer to null."
            )
        return None

    @traceable(run_type="chain")
    async def process_questions(
        self,
        run_name: str,
        questions: List[dict],
        config: ClientConfig,
    ) -> List[str | int | None]:
        semaphore = Semaphore(self.concurrency)
        tasks = [
            self.process_question(
                question=question["question"],
                config=config,
                semaphore=semaphore,
            )
            for question in questions
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        return results

    def generate_outputs(self, questions: List[dict]) -> pd.DataFrame:
        df = pd.DataFrame(
            {
                "id": [i for i in range(len(questions))],
                "question": [question["question"] for question in questions],
                "answer": [question["answer"] for question in questions],
            }
        )
        for config in self.configs:
            responses = asyncio.run(
                self.process_questions(
                    run_name=config.name,
                    questions=questions,
                    config=config,
                )
            )
            df[config.col_name] = responses
        return df

    def evaluate_outputs(self, df: pd.DataFrame) -> pd.DataFrame:
        df_copy = df.copy()
        for config in self.configs:
            df_copy[config.score_col_name] = (
                df_copy["answer"] == df_copy[config.col_name]
            ) * 1
        return df_copy

    def calculate_confidence_intervals(
        self, df: pd.DataFrame, conf_level: float = 0.95
    ) -> None:
        print(
            f"Calculating confidence intervals ({conf_level}) with {len(df)} observations:"
        )
        for config in self.configs:
            score_col = config.score_col_name
            scores = df[score_col]

            if len(scores) == 0:
                print(f"No scores available for {score_col}")
                continue

            mean_score = scores.mean()
            se_score = scores.std() / np.sqrt(len(scores))

            z_score = stats.norm.ppf((1 + conf_level) / 2)
            margin_error = z_score * se_score
            ci = [
                max(0.0, mean_score - margin_error),
                min(1.0, mean_score + margin_error),
            ]
            print(
                f"{score_col} - Mean: {mean_score * 100:.2f}% CI: {ci[0] * 100:.2f}% - {ci[1] * 100:.2f}%"
            )
        print()

    def run_paired_t_test(self, df: pd.DataFrame) -> None:
        scores = {}

        for config in self.configs:
            score_col = config.score_col_name
            scores[score_col] = df[score_col] * 1

        for score_col_1, score_col_2 in [
            ("score_without_so", "score_with_so_mime_type"),
            ("score_without_so", "score_with_so_response_schema"),
            ("score_without_so", "score_with_so_function_call"),
        ]:
            if score_col_1 in scores and score_col_2 in scores:
                t_stat, p_value = stats.ttest_rel(
                    scores[score_col_1], scores[score_col_2]
                )
                print(f"{score_col_1} vs {score_col_2}")
                print(f"t-statistic: {t_stat}, p-value: {p_value}")

    def report_error_counts(self) -> None:
        print("Error counts:")
        for config in self.configs:
            name = config.name
            errors = self.error_counts.get(name, 0)
            key_order = self.key_order_errors.get(name, 0)
            print(f"- {name}: {errors} processing errors, {key_order} key order errors")
        print()

## GSM8K

### Setup

In [6]:
# | output: false
# | echo: false


class ResponseGSM8K(BaseModel):
    reasoning: str = Field(description="step by step reasoning about the answer")
    solution: int = Field(description="final answer")


def create_prompt_gsm8k(
    prompt_type: str,
    response_model: ResponseGSM8K | None = None,
    zero_shot: bool = False,
) -> str:
    if prompt_type in [
        PromptType.WITH_FUNCTION_CALL.value,
        PromptType.WITH_MIME_TYPE.value,
        PromptType.WITH_RESPONSE_SCHEMA.value,
    ]:
        system_prompt = dedent("""
        You are an expert in solving grade school math tasks. You will be presented with a grade-school math word problem and be asked to solve it.

        You will always respond with JSON matching the following schema:

        class Response(BaseModel):
            reasoning: str = Field(description="step by step reasoning about the answer")
            solution: int = Field(description="final answer")

        First, provide your step by step reasoning in the "reasoning" field. Then, in the "solution" field, provide an integer that corresponds to the correct answer to the question. Don't include any other text in the "solution" field. 
        """)
    else:
        system_prompt = dedent("""
        You are an expert in solving grade school math tasks. You will be presented with a grade-school math word problem and be asked to solve it.
        
        You will always respond in the following format:
        
        <str, reasoning about the answer>
        SOLUTION: <int, final answer>
        
        First, provide your step by step reasoning. Then, in SOLUTION, provide an integer that corresponds to the correct answer to the question. Don't include any other text in SOLUTION.
        """)

    examples = [
        (
            "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?",
            "There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6.",
            6,
        ),
        (
            "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?",
            "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5.",
            5,
        ),
        (
            "Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?",
            "Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39.",
            39,
        ),
    ]

    if not zero_shot:
        system_prompt += "\nExamples:" if examples else ""
        for i, (example_q, example_reason, example_ans) in enumerate(examples):
            system_prompt += f"\n\n**{i+1}**\nQuestion: {example_q}"
            if prompt_type in [
                PromptType.WITH_FUNCTION_CALL.value,
                PromptType.WITH_RESPONSE_SCHEMA.value,
                PromptType.WITH_MIME_TYPE.value,
            ]:
                response = (
                    f'{{"reasoning": "{example_reason}", "solution": {example_ans}}}'
                )
            else:
                response = f"{example_reason}\nSOLUTION: {example_ans}"
            system_prompt += f"\nAssistant Response:\n{response}"

    return system_prompt


create_prompt_gsm8k_zero_shot = partial(create_prompt_gsm8k, zero_shot=True)


def parse_response_gsm8k(response: str, prompt_type: str) -> int | None:
    if prompt_type in [
        PromptType.WITH_FUNCTION_CALL.value,
        PromptType.WITH_RESPONSE_SCHEMA.value,
        PromptType.WITH_MIME_TYPE.value,
    ]:
        return ResponseGSM8K.model_validate_json(response).solution
    else:
        cleaned_response = (
            response.split("\nSOLUTION:")[1].replace(",", "").rstrip(".").strip()
        )
        return int(cleaned_response)

In [7]:
# | output: false
# | echo: false

dataset = load_dataset("gsm8k", "main")
evals = [
    {
        "question": d["question"],
        "answer": int(d["answer"].split("#### ")[1].replace(",", "").strip()),
    }
    for d in dataset["test"]
]

if USE_SAMPLE:
    evals = evals[:SAMPLE_SIZE]

### Zero-shot

In [8]:
# | output: false
# | echo: false

evaluator = LLMEvaluator(
    configs=CONFIGS,
    create_prompt_fn=create_prompt_gsm8k_zero_shot,
    parse_response_fn=parse_response_gsm8k,
    response_model=ResponseGSM8K,
)

df = evaluator.generate_outputs(evals)
df_results = evaluator.evaluate_outputs(df)

evaluator.report_error_counts()
evaluator.calculate_confidence_intervals(df_results)
evaluator.run_paired_t_test(df_results)

without_so, Valerie ea: Retrying...
without_so, Valerie ea: Retrying...
without_so, Valerie ea: Retrying...
without_so, Valerie ea: Failed to process question after 3 attempts. Set answer to null.
with_so_response_schema, Valerie ea: Retrying...
with_so_mime_type, Valerie ea: Retrying...
Error counts:
- without_so: 7 processing errors, 0 key order errors
- with_so_response_schema: 2 processing errors, 0 key order errors
- with_so_mime_type: 3 processing errors, 0 key order errors

Calculating confidence intervals (0.95) with 1319 observations:
score_without_so - Mean: 93.71% CI: 92.40% - 95.02%
score_with_so_response_schema - Mean: 93.03% CI: 91.65% - 94.40%
score_with_so_mime_type - Mean: 93.78% CI: 92.48% - 95.09%

score_without_so vs score_with_so_mime_type
t-statistic: -0.13997595763623139, p-value: 0.8887003733154177
score_without_so vs score_with_so_response_schema
t-statistic: 1.116416868942675, p-value: 0.26444722405955823


### Few-shot

In [9]:
# | output: false
# | echo: false

# Few-shot
evaluator = LLMEvaluator(
    configs=CONFIGS,
    create_prompt_fn=create_prompt_gsm8k,
    parse_response_fn=parse_response_gsm8k,
    response_model=ResponseGSM8K,
)

df = evaluator.generate_outputs(evals)
df_results = evaluator.evaluate_outputs(df)

evaluator.report_error_counts()
evaluator.calculate_confidence_intervals(df_results)
evaluator.run_paired_t_test(df_results)

without_so, Valerie ea: Retrying...
Error counts:
- without_so: 4 processing errors, 0 key order errors
- with_so_response_schema: 1 processing errors, 0 key order errors
- with_so_mime_type: 2 processing errors, 0 key order errors

Calculating confidence intervals (0.95) with 1319 observations:
score_without_so - Mean: 94.84% CI: 93.65% - 96.04%
score_with_so_response_schema - Mean: 93.63% CI: 92.31% - 94.95%
score_with_so_mime_type - Mean: 94.16% CI: 92.90% - 95.43%

score_without_so vs score_with_so_mime_type
t-statistic: 1.372947491826186, p-value: 0.1700022107400027
score_without_so vs score_with_so_response_schema
t-statistic: 2.2662866197986777, p-value: 0.023595394810378166


## Last Letter

### Setup

In [10]:
# | output: false
# | echo: false


class ResponseLastLetter(BaseModel):
    reasoning: str = Field(description="step by step reasoning about the answer")
    solution: str = Field(description="final answer")


def create_prompt_last_letter(
    prompt_type: str,
    response_model: ResponseLastLetter | None = None,
    zero_shot: bool = False,
):
    if prompt_type in [
        PromptType.WITH_FUNCTION_CALL.value,
        PromptType.WITH_MIME_TYPE.value,
        PromptType.WITH_RESPONSE_SCHEMA.value,
    ]:
        system_prompt = dedent("""
        You are an expert in solving simple word puzzles using reasoning steps. Your specific task is going to be to take a list of 4 names and reason about the last letter of each. Then, you will concatenate the last letters into a word. 
          
        You will always respond with JSON matching the following schema:
        
        class Response(BaseModel):
            reasoning: str = Field(description="step by step reasoning about the answer")
            solution: str = Field(description="final answer")

        First, provide your step by step reasoning in the "reasoning" field. Then, in the "solution" field, provide the final answer. Don't include any other text in the "solution" field."""
        )
    else:
        system_prompt = dedent("""
        You are an expert in solving simple word puzzles using reasoning steps. Your specific task is going to be to take a list of 4 names and reason about the last letter of each. Then, you will concatenate the last letters into a word. 
        
        You will always respond in the following format:
        
        <str, reasoning about the answer>
        SOLUTION: <str, final answer>
        
        First, provide your step by step reasoning. Then, in SOLUTION, provide the final answer. Don't include any other text in SOLUTION.
        """)

    fewshot_examples = [
        (
            "Ian Peter Bernard Stephen",
            "The last letter of 'Ian' is 'N'. The last letter of 'Peter' is 'R'. The last letter of 'Bernard' is 'D'. The last letter of 'Stephen' is 'N'. Concatenating them is 'NRDN'.",
            "NRDN",
        ),
        (
            "Javier Dylan Christopher Joseph",
            "The last letter of 'Javier' is 'R'. The last letter of 'Dylan' is 'N'. The last letter of 'Christopher' is 'R'. The last letter of 'Joseph' is 'H'. Concatenating them is 'RNRH'.",
            "RNRH",
        ),
        (
            "Anthony Elizabeth Carlos Jesus",
            "The last letter of 'Anthony' is 'Y'. The last letter of 'Elizabeth' is 'H'. The last letter of 'Carlos' is 'S'. The last letter of 'Jesus' is 'S'. Concatenating them is 'YHSS'.",
            "YHSS",
        ),
    ]

    if not zero_shot:
        system_prompt += "\nExamples:" if fewshot_examples else ""
        for i, (example_q, example_reason, example_ans) in enumerate(fewshot_examples):
            system_prompt += f"\n\n**{i+1}**\nQuestion: Take the last letters of the words in '{example_q}' and concatenate them."
            if prompt_type in [
                PromptType.WITH_FUNCTION_CALL.value,
                PromptType.WITH_RESPONSE_SCHEMA.value,
                PromptType.WITH_MIME_TYPE.value,
            ]:
                response = (
                    f'{{"reasoning": "{example_reason}", "answer": "{example_ans}"}}'
                )
            else:
                response = f"{example_reason}\nSOLUTION: {example_ans}"
            system_prompt += f"\nAssistant Response:\n{response}"

    return system_prompt


create_prompt_last_letter_zero_shot = partial(create_prompt_last_letter, zero_shot=True)


def parse_response_last_letter(response: str, prompt_type: str) -> str | None:
    if prompt_type in [
        PromptType.WITH_FUNCTION_CALL.value,
        PromptType.WITH_RESPONSE_SCHEMA.value,
        PromptType.WITH_MIME_TYPE.value,
    ]:
        return ResponseLastLetter.model_validate_json(response).solution.lower()
    else:
        return response.split("\nSOLUTION:")[1].rstrip(".").strip().lower()

In [11]:
# | output: false
# | echo: false

dataset = load_dataset("ChilleD/LastLetterConcat")
evals = [
    {"question": d["question"], "answer": d["answer"].lower()} for d in dataset["test"]
]

if USE_SAMPLE:
    evals = evals[:SAMPLE_SIZE]

### Zero-shot

In [12]:
# | output: false
# | echo: false

evaluator = LLMEvaluator(
    configs=CONFIGS,
    create_prompt_fn=create_prompt_last_letter_zero_shot,
    parse_response_fn=parse_response_last_letter,
    response_model=ResponseLastLetter,
)

df = evaluator.generate_outputs(evals)
df_results = evaluator.evaluate_outputs(df)

evaluator.report_error_counts()
evaluator.calculate_confidence_intervals(df_results)
evaluator.run_paired_t_test(df_results)

Error counts:
- without_so: 0 processing errors, 0 key order errors
- with_so_response_schema: 0 processing errors, 0 key order errors
- with_so_mime_type: 0 processing errors, 0 key order errors

Calculating confidence intervals (0.95) with 150 observations:
score_without_so - Mean: 82.67% CI: 76.59% - 88.74%
score_with_so_response_schema - Mean: 81.33% CI: 75.08% - 87.59%
score_with_so_mime_type - Mean: 80.00% CI: 73.58% - 86.42%

score_without_so vs score_with_so_mime_type
t-statistic: 0.893827507934847, p-value: 0.3728556534698112
score_without_so vs score_with_so_response_schema
t-statistic: 0.391123270505612, p-value: 0.6962648403881377


### Few-shot

In [13]:
# | output: false
# | echo: false

# Few-shot
evaluator = LLMEvaluator(
    configs=CONFIGS,
    create_prompt_fn=create_prompt_last_letter,
    parse_response_fn=parse_response_last_letter,
    response_model=ResponseLastLetter,
)

df = evaluator.generate_outputs(evals)
df_results = evaluator.evaluate_outputs(df)

evaluator.report_error_counts()
evaluator.calculate_confidence_intervals(df_results)
evaluator.run_paired_t_test(df_results)

Error counts:
- without_so: 0 processing errors, 0 key order errors
- with_so_response_schema: 0 processing errors, 0 key order errors
- with_so_mime_type: 0 processing errors, 0 key order errors

Calculating confidence intervals (0.95) with 150 observations:
score_without_so - Mean: 80.00% CI: 73.58% - 86.42%
score_with_so_response_schema - Mean: 80.67% CI: 74.33% - 87.01%
score_with_so_mime_type - Mean: 82.00% CI: 75.83% - 88.17%

score_without_so vs score_with_so_mime_type
t-statistic: -0.6534019417029638, p-value: 0.5145042174802537
score_without_so vs score_with_so_response_schema
t-statistic: -0.1993588014488551, p-value: 0.8422538839198896


## Shuffled Objects

### Setup

In [14]:
# | output: false
# | echo: false


class ResponseShuffledObjects(BaseModel):
    reasoning: str = Field(description="reasoning about the answer")
    solution: Literal["A", "B", "C", "D", "E"] = Field(description="final answer")


def create_prompt_shuffled_objects(
    prompt_type: str,
    response_model: ResponseShuffledObjects | None = None,
    zero_shot: bool = False,
):
    if prompt_type in [
        PromptType.WITH_FUNCTION_CALL.value,
        PromptType.WITH_MIME_TYPE.value,
        PromptType.WITH_RESPONSE_SCHEMA.value,
    ]:
        system_prompt = dedent("""
        You are an expert in performing common sense tasks involving the ordering of a sequence of events.
        Each question will present you with a sequence of events involving 5 people (switching objects, partners, positions, etc.). Your task is to determine the correct answer from the options provided.
          
        You will always respond with JSON matching the following schema:
        
        class Response(BaseModel):
            reasoning: str = Field(description="reasoning about the answer")
            solution: Literal["A", "B", "C", "D", "E"] = Field(description="final answer")

        First, provide your reasoning in the "reasoning" field. Then, in the "solution" field, provide only the single letter representing the correct choice you are presented with. Don't include any other text in the "solution" field.
        """)
    else:
        system_prompt = dedent("""
        You are an expert in performing common sense tasks involving the ordering of a sequence of events.
        Each question will present you with a sequence of events involving 5 people (switching objects, partners, positions, etc.). Your task is to determine the correct answer from the options provided.
        
        You will always respond in the following format:
        
        <str, reasoning about the answer>
        SOLUTION: <str, final answer>
        
        First, provide your step by step reasoning. Then, in SOLUTION, provide only the single letter representing the correct choice you are presented with. Don't include any other text in SOLUTION.
        """)

    fewshot_examples = [
        (
            "Alice, Bob, Claire, Dave, and Eve are dancers at a square dance. At the start of a song, they each have a partner: Alice is dancing with Patrick, Bob is dancing with Sam, Claire is dancing with Jamie, Dave is dancing with Lola, and Eve is dancing with Melissa.\nThroughout the song, the dancers often trade partners. First, Dave and Eve switch partners. Then, Dave and Alice switch partners. Then, Eve and Alice switch partners. Then, Claire and Bob switch partners. Finally, Dave and Alice switch partners. At the end of the dance, Alice is dancing with\nOptions:\n(A) Patrick\n(B) Sam\n(C) Jamie\n(D) Lola\n(E) Melissa",
            "Dave and Eve switch partners, so Dave's partner is now Melissa and Eve's partner is now Lola. Then Dave and Alice switch partners so Dave's partner is now Patrick and Alice's partner is now Melissa. Then Eve and Alice switch partners so Eve's partner is now Melissa and Alice's partner is now Lola. Then Claire and Bob switch patners so Claire's partner is now Sam, and Bob's partner is now Jamie. Finally, Dave and Alice switch partners so Dave's new partner is Lola, and Alice's new partner is Patrick. Alice is dance in with Patrick, choice A.",
            "A",
        ),
        (
            "Alice, Bob, Claire, Dave, and Eve are dancers at a square dance. At the start of a song, they each have a partner: Alice is dancing with Ophelia, Bob is dancing with Jamie, Claire is dancing with Melissa, Dave is dancing with Rodrigo, and Eve is dancing with Patrick.\nThroughout the song, the dancers often trade partners. First, Claire and Bob switch partners. Then, Claire and Eve switch partners. Then, Claire and Bob switch partners. Then, Eve and Dave switch partners. Finally, Claire and Alice switch partners. At the end of the dance, Alice is dancing with\nOptions:\n(A) Ophelia\n(B) Jamie\n(C) Melissa\n(D) Rodrigo\n(E) Patrick",
            "Claire and Bob switch partners, so Claire's partner is now Jamie and Bob's partner is now Melissa. Then, Claire and Eve switch partners, so Claire's partner becomes Patrick and Eve's partner becomes Jamie. Next, Claire and Bob switch partners again, making Claire's partner Melissa and Bob's partner Patrick. After that, Eve and Dave switch partners, resulting in Eve's partner being Rodrigo and Dave's partner being Jamie. Finally, Claire and Alice switch partners, so Claire's partner is now Ophelia and Alice's partner becomes Melissa. Alice is dancing with Melissa, which is choice C.",
            "C",
        ),
        (
            "Alice, Bob, Claire, Dave, and Eve are friends and avid readers who occasionally trade books. At the start of the semester, they each buy one new book: Alice gets Catch-22, Bob gets Hound of the Baskervilles, Claire gets Frankenstein, Dave gets The Pearl, and Eve gets The Fellowship of the Ring.\nAs the semester proceeds, they start trading around the new books. First, Eve and Alice swap books. Then, Alice and Claire swap books. Then, Alice and Bob swap books. Then, Dave and Alice swap books. Finally, Dave and Claire swap books. At the end of the semester, Dave has\nOptions:\n(A) Catch-22\n(B) Hound of the Baskervilles\n(C) Frankenstein\n(D) The Pearl\n(E) The Fellowship of the Ring",
            "First, Eve and Alice swap, so Alice gets The Fellowship of the Ring and Eve gets Catch-22. Next, Alice and Claire swap, giving Claire The Fellowship of the Ring and Alice Frankenstein. Then, Alice and Bob swap, resulting in Bob holding Frankenstein and Alice having Hound of the Baskervilles. Dave and Alice then swap, so Dave takes Hound of the Baskervilles and Alice receives The Pearl. Finally, Dave and Claire swap books, which means Dave takes The Fellowship of the Ring from Claire. Therefore, at the end of all the swaps, Dave possesses The Fellowship of the Ring, making option E the correct answer.",
            "E",
        ),
    ]

    if not zero_shot:
        system_prompt += "\nExamples:" if fewshot_examples else ""
        for i, (example_q, example_reason, example_ans) in enumerate(fewshot_examples):
            system_prompt += f"\n\n**{i+1}**\nQuestion: {example_q}"

            if prompt_type in [
                PromptType.WITH_FUNCTION_CALL.value,
                PromptType.WITH_RESPONSE_SCHEMA.value,
                PromptType.WITH_MIME_TYPE.value,
            ]:
                response = (
                    f'{{"reasoning": "{example_reason}", "answer": "{example_ans}"}}'
                )
            else:
                response = f"{example_reason}\nSOLUTION: {example_ans}"
            system_prompt += f"\nAssistant Response:\n{response}"

    return system_prompt


create_prompt_shuffled_objects_zero_shot = partial(
    create_prompt_shuffled_objects, zero_shot=True
)


def parse_response_shuffled_objects(response: str, prompt_type: str) -> str:
    if prompt_type in [
        PromptType.WITH_FUNCTION_CALL.value,
        PromptType.WITH_RESPONSE_SCHEMA.value,
        PromptType.WITH_MIME_TYPE.value,
    ]:
        return ResponseShuffledObjects.model_validate_json(response).solution
    else:
        return response.split("\nSOLUTION:")[1].rstrip(".").strip()

In [15]:
# | output: false
# | echo: false

dataset = load_dataset(
    "openeval/BIG-Bench-Hard", data_files="tracking_shuffled_objects_five_objects.json"
)
evals = [
    {
        "question": d["input"],
        "answer": d["target"].replace("(", "").replace(")", "").strip(),
    }
    for d in dataset["train"]["examples"][0][4:]  # first 3 are few-shot examples
]

if USE_SAMPLE:
    evals = evals[:SAMPLE_SIZE]

Repo card metadata block was not found. Setting CardData to empty.


### Zero-shot

In [16]:
# | output: false
# | echo: false

evaluator = LLMEvaluator(
    configs=CONFIGS,
    create_prompt_fn=create_prompt_shuffled_objects_zero_shot,
    parse_response_fn=parse_response_shuffled_objects,
    response_model=ResponseShuffledObjects,
)

df = evaluator.generate_outputs(evals)
df_results = evaluator.evaluate_outputs(df)

evaluator.report_error_counts()
evaluator.calculate_confidence_intervals(df_results)
evaluator.run_paired_t_test(df_results)

Error counts:
- without_so: 0 processing errors, 0 key order errors
- with_so_response_schema: 0 processing errors, 0 key order errors
- with_so_mime_type: 0 processing errors, 0 key order errors

Calculating confidence intervals (0.95) with 246 observations:
score_without_so - Mean: 97.15% CI: 95.07% - 99.24%
score_with_so_response_schema - Mean: 86.18% CI: 81.86% - 90.50%
score_with_so_mime_type - Mean: 92.28% CI: 88.93% - 95.62%

score_without_so vs score_with_so_mime_type
t-statistic: 2.5878637375977847, p-value: 0.010234013573805245
score_without_so vs score_with_so_response_schema
t-statistic: 4.760539166490588, p-value: 3.3043823490024686e-06


### Few-shot

In [17]:
# | output: false
# | echo: false

evaluator = LLMEvaluator(
    configs=CONFIGS,
    create_prompt_fn=create_prompt_shuffled_objects,
    parse_response_fn=parse_response_shuffled_objects,
    response_model=ResponseShuffledObjects,
)

df = evaluator.generate_outputs(evals)
df_results = evaluator.evaluate_outputs(df)

evaluator.report_error_counts()
evaluator.calculate_confidence_intervals(df_results)
evaluator.run_paired_t_test(df_results)

Error counts:
- without_so: 0 processing errors, 0 key order errors
- with_so_response_schema: 0 processing errors, 0 key order errors
- with_so_mime_type: 0 processing errors, 0 key order errors

Calculating confidence intervals (0.95) with 246 observations:
score_without_so - Mean: 92.68% CI: 89.42% - 95.94%
score_with_so_response_schema - Mean: 84.96% CI: 80.48% - 89.44%
score_with_so_mime_type - Mean: 98.37% CI: 96.79% - 99.96%

score_without_so vs score_with_so_mime_type
t-statistic: -3.368513159744578, p-value: 0.0008776240357116037
score_without_so vs score_with_so_response_schema
t-statistic: 3.0157257074237465, p-value: 0.00283346865859629
