In [1]:
# | output: false
# | echo: false

%load_ext autoreload
%autoreload 2

In [2]:
# | output: false
# | echo: false

import nest_asyncio

# Applied before the `asyncio.run(...)` cells further down. Presumably required
# because the notebook kernel already runs an event loop -- confirm against the
# nest_asyncio documentation.
nest_asyncio.apply()

In [3]:
# | output: false
# | echo: false
import asyncio
import json
import os
import random

import google.generativeai as genai
import numpy as np
from dotenv import load_dotenv

# Seed every RNG used by the notebook. The test generator draws with the
# standard-library `random` module, so seeding NumPy alone would leave the
# sampled fruits and counts different on every run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Load environment variables (including the API key read below) from a .env file.
load_dotenv()

MODEL_NAME = "gemini-1.5-flash"  # model used by both experiments
NUMBER_OF_TESTS = 50  # number of randomized test cases to generate
MAX_CONCURRENCY = 20  # upper bound on simultaneous in-flight requests

# Raises KeyError if GEMINI_API_KEY is not set.
genai.configure(api_key=os.environ["GEMINI_API_KEY"])

In [4]:
# Candidate field names for the generated schemas. Every entry must be unique:
# `random.sample` picks by position, so a repeated name could be drawn twice for
# the same test and collapse into a single key of the schema `properties` dict.
fruits = [
    "apples",
    "oranges",
    "bananas",
    "pineapples",
    "coconuts",
    "pears",
    "peaches",
    "grapes",
    "watermelons",
    "strawberries",
    "blueberries",
    "raspberries",
    "blackberries",
    "cherries",
    "plums",
    "apricots",
    "kiwis",
    "mangos",
    "papayas",
    "pomegranates",
    "tangerines",
]


def generate_test():
    """Build one randomized test case for the key-ordering experiment.

    Draws three names from ``fruits`` and three integers from 1-99, then
    derives a response schema and the prompts from them.

    Returns:
        A ``(system_prompt, properties, user_question)`` tuple. ``properties``
        maps field names to ``genai.protos.Schema`` objects, inserted in the
        order ``reasoning``, ``fruit_counts``, then the sampled fruits.
    """
    # Fruits are drawn first, then the counts.
    fruit_a, fruit_b, fruit_c = random.sample(fruits, 3)
    count_a, count_b, count_c = random.sample(range(1, 100), 3)

    # One free-text field followed by four integer fields; insertion order is
    # the "original" key order that responses are later compared against.
    properties = {"reasoning": genai.protos.Schema(type=genai.protos.Type.STRING)}
    for field_name in ("fruit_counts", fruit_a, fruit_b, fruit_c):
        properties[field_name] = genai.protos.Schema(type=genai.protos.Type.INTEGER)

    # The same schema rendered as a Pydantic-style class, embedded in the prompt.
    class_str = f"\n\nclass Response(BaseModel):\n    reasoning: str\n    fruit_counts: int\n    {fruit_a}: int\n    {fruit_b}: int\n    {fruit_c}: int"
    system_prompt = f"You're a helpful assistant. Begin by reasoning step-by-step, then produce the answer according to the schema. {class_str}"
    user_question = f"Given that I have {count_a} {fruit_a}, {count_b} {fruit_b}, and {count_c} {fruit_c}, represent these counts according to the schema."
    return system_prompt, properties, user_question

In [5]:
# Generate the test cases once, up front: `run_tests` iterates this module-level
# list, so both experiments are evaluated on the same inputs.
tests = [generate_test() for _ in range(NUMBER_OF_TESTS)]

In [6]:
async def generate_content_structured_output(test: tuple):
    """Query the model with a JSON response schema built from the test case.

    Args:
        test: A ``(system_prompt, properties, user_question)`` tuple as
            produced by ``generate_test``.

    Returns:
        The ``text`` attribute of the model response.
    """
    system_prompt, properties, user_question = test[0], test[1], test[2]

    # Every declared property is marked as required.
    response_schema = genai.protos.Schema(
        type=genai.protos.Type.OBJECT,
        properties=properties,
        required=list(properties.keys()),
    )
    generation_config = genai.GenerationConfig(
        response_mime_type="application/json",
        response_schema=response_schema,
    )
    model = genai.GenerativeModel(
        model_name=MODEL_NAME,
        generation_config=generation_config,
        system_instruction=system_prompt,
    )

    reply = await model.generate_content_async(user_question)
    return reply.text


async def generate_content_function_call(test: tuple):
    """Query the model through a forced ``Response`` function call.

    Args:
        test: A ``(system_prompt, properties, user_question)`` tuple as
            produced by ``generate_test``.

    Returns:
        The arguments of the first function call found in the response,
        serialized as a JSON string, or ``None`` when no response part
        carries a function call.
    """
    system_prompt, properties, user_question = test[0], test[1], test[2]

    # Declare a single tool whose parameters mirror the test schema, with
    # every property marked as required.
    response_declaration = genai.protos.FunctionDeclaration(
        name="Response",
        description="Correctly extracted `Response` with all the required parameters",
        parameters=genai.protos.Schema(
            type=genai.protos.Type.OBJECT,
            properties=properties,
            required=list(properties.keys()),
        ),
    )
    response_tool = genai.protos.Tool(function_declarations=[response_declaration])

    model = genai.GenerativeModel(
        model_name=MODEL_NAME,
        generation_config=genai.GenerationConfig(response_mime_type="text/plain"),
        tools=[response_tool],
        tool_config={"function_calling_config": "ANY"},
        system_instruction=system_prompt,
    )

    reply = await model.generate_content_async(user_question)
    for call in (part.function_call for part in reply.parts):
        if call:
            return json.dumps(dict(call.args))
    return None


def check_if_keys_sorted(output: str, properties: dict):
    """Tell whether the JSON output lists its keys in the schema's order.

    Args:
        output: JSON text encoding an object.
        properties: Mapping whose key order is the expected order.

    Returns:
        ``True`` when the key sequence of the parsed output equals the key
        sequence of ``properties`` exactly, ``False`` otherwise.
    """
    expected_order = list(properties.keys())
    actual_order = list(json.loads(output).keys())
    return actual_order == expected_order


def check_if_keys_sorted_alphabetically(output: str):
    """Tell whether the JSON output lists its keys in ascending sorted order.

    Args:
        output: JSON text encoding an object.

    Returns:
        ``True`` when the key sequence of the parsed output already equals
        its ``sorted`` version, ``False`` otherwise.
    """
    observed_keys = list(json.loads(output).keys())
    return observed_keys == sorted(observed_keys)


async def run_single_test(test, call_fn, semaphore):
    """Run one test case through ``call_fn`` and evaluate the key order.

    Args:
        test: A ``(system_prompt, properties, user_question)`` tuple.
        call_fn: Coroutine function that takes ``test`` and returns the model
            output as a JSON string, or ``None`` when no output was produced.
        semaphore: Async context manager limiting concurrent model calls.

    Returns:
        A ``(output, keys_match, alpha_sorted)`` tuple. When ``call_fn``
        returns ``None`` both flags are ``False``.
    """
    # Only the model call needs to be throttled; release the slot before
    # doing the local checks.
    async with semaphore:
        output = await call_fn(test)

    # A missing output cannot be parsed as JSON. Count it as a failed check
    # instead of raising, which would abort the whole gathered batch.
    if output is None:
        return output, False, False

    keys_match = check_if_keys_sorted(output, test[1])
    alpha_sorted = check_if_keys_sorted_alphabetically(output)
    return output, keys_match, alpha_sorted


async def run_tests(call_fn):
    """Run every case in ``tests`` through ``call_fn`` concurrently.

    Args:
        call_fn: Coroutine function forwarded to ``run_single_test``.

    Returns:
        The list produced by ``asyncio.gather``: one ``run_single_test``
        result per entry of ``tests``, in the same order.
    """
    # A single semaphore shared by all tasks caps in-flight calls at
    # MAX_CONCURRENCY.
    limiter = asyncio.BoundedSemaphore(MAX_CONCURRENCY)
    pending = [run_single_test(case, call_fn, limiter) for case in tests]
    return await asyncio.gather(*pending)

In [7]:
# Run the structured-output experiment over every generated test case.
# Each result is an (output, keys_match, alpha_sorted) tuple.
test_results_structured_output = asyncio.run(
    run_tests(generate_content_structured_output)
)

In [8]:
# Run the function-calling experiment over the same generated test cases.
# Each result is an (output, keys_match, alpha_sorted) tuple.
test_results_function_call = asyncio.run(run_tests(generate_content_function_call))

In [11]:
# Tally, per approach, how many responses kept the schema's key order
# (index 1 of each result tuple). The two result lists are walked pairwise.
correct_keys_count_structured_output = sum(
    1
    for structured_result, _ in zip(test_results_structured_output, test_results_function_call)
    if structured_result[1]
)
correct_keys_count_function_call = sum(
    1
    for _, function_call_result in zip(test_results_structured_output, test_results_function_call)
    if function_call_result[1]
)

print(f"Keys in correct order (structured output): {correct_keys_count_structured_output}")
print(f"Keys in correct order (function call): {correct_keys_count_function_call}")

Keys in correct order (structured output): 0
Keys in correct order (function call): 0


In [12]:
# Tally, per approach, how many responses list their keys alphabetically
# (index 2 of each result tuple). The two result lists are walked pairwise.
correct_keys_count_alphabetical_structured_output = sum(
    1
    for structured_result, _ in zip(test_results_structured_output, test_results_function_call)
    if structured_result[2]
)
correct_keys_count_alphabetical_function_call = sum(
    1
    for _, function_call_result in zip(test_results_structured_output, test_results_function_call)
    if function_call_result[2]
)

print(f"Keys in alphabetical order (structured output): {correct_keys_count_alphabetical_structured_output}")
print(f"Keys in alphabetical order (function call): {correct_keys_count_alphabetical_function_call}")

Keys in alphabetical order (structured output): 50
Keys in alphabetical order (function call): 2
