In [135]:
# | output: false
# | echo: false

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [136]:
# | output: false
# | echo: false

import nest_asyncio

nest_asyncio.apply()

In [137]:
# | output: false
# | echo: false
import asyncio
import json
import os
import random
from typing import Literal

import google.generativeai as genai
import numpy as np
from dotenv import load_dotenv

np.random.seed(42)

load_dotenv()

MODEL_NAME = "gemini-1.5-pro-002"
NUMBER_OF_TESTS = 10
MAX_CONCURRENCY = 20

genai.configure(api_key=os.environ["GEMINI_API_KEY"])

## Setup

In [138]:
fruits = [
    "apples",
    "oranges",
    "bananas",
    "coconuts",
    "pears",
    "peaches",
    "grapes",
    "watermelons",
    "strawberries",
    "blueberries",
    "raspberries",
    "blackberries",
    "cherries",
    "plums",
    "apricots",
    "kiwis",
    "mangos",
    "papayas",
    "pineapples",
    "pomegranates",
    "tangerines",
]

assert len(fruits) == len(set(fruits))


def generate_test(
    json_schema_representation: Literal["exclude", "class", "json_schema"],
):
    random.seed(42)
    fruits_sample = random.sample(fruits, 5)
    numbers_sample = random.sample(range(1, 100), 5)
    properties = {
        fruits_sample[0]: genai.protos.Schema(type=genai.protos.Type.INTEGER),
        fruits_sample[1]: genai.protos.Schema(type=genai.protos.Type.INTEGER),
        fruits_sample[2]: genai.protos.Schema(type=genai.protos.Type.INTEGER),
        fruits_sample[3]: genai.protos.Schema(type=genai.protos.Type.INTEGER),
        fruits_sample[4]: genai.protos.Schema(type=genai.protos.Type.INTEGER),
    }
    system_prompt = "You're a helpful assistant. Please produce the answer according to the JSON schema."
    if json_schema_representation == "exclude":
        system_prompt += f"{json_schema_representation}"
    if json_schema_representation == "class":
        class_str = f"\n\nclass Response(BaseModel):\n    {fruits_sample[0]}: int\n    {fruits_sample[1]}: int\n    {fruits_sample[2]}: int\n    {fruits_sample[3]}: int\n    {fruits_sample[4]}: int"
        system_prompt += f"{class_str}"
    if json_schema_representation == "json_schema":
        json_schema = {
            "type": "object",
            "properties": {
                fruits_sample[0]: {"type": "integer"},
                fruits_sample[1]: {"type": "integer"},
                fruits_sample[2]: {"type": "integer"},
                fruits_sample[3]: {"type": "integer"},
                fruits_sample[4]: {"type": "integer"},
            },
            "required": list(properties.keys()),
        }
        system_prompt += f"\n\n{json.dumps(json_schema)}"
    user_question = f"Given that I have {numbers_sample[0]} {fruits_sample[0]}, {numbers_sample[1]} {fruits_sample[1]}, {numbers_sample[2]} {fruits_sample[2]}, {numbers_sample[3]} {fruits_sample[3]}, {numbers_sample[4]} {fruits_sample[4]}, represent these counts according to the schema."
    return system_prompt, properties, user_question

In [139]:
async def generate_content_structured_output(test: tuple):
    model_structured_output = genai.GenerativeModel(
        model_name=MODEL_NAME,
        generation_config=genai.GenerationConfig(
            response_mime_type="application/json",
            response_schema=genai.protos.Schema(
                type=genai.protos.Type.OBJECT,
                properties=test[1],
                required=list(test[1].keys()),
            ),
        ),
        system_instruction=test[0],
    )
    response = await model_structured_output.generate_content_async(test[2])
    return response.text


async def generate_content_function_call(test: tuple):
    model_function_call = genai.GenerativeModel(
        model_name=MODEL_NAME,
        generation_config=genai.GenerationConfig(
            response_mime_type="text/plain",
        ),
        tools=[
            genai.protos.Tool(
                function_declarations=[
                    genai.protos.FunctionDeclaration(
                        name="Response",
                        description="Correctly extracted `Response` with all the required parameters",
                        parameters=genai.protos.Schema(
                            type=genai.protos.Type.OBJECT,
                            properties=test[1],
                            required=list(test[1].keys()),
                        ),
                    )
                ],
            )
        ],
        tool_config={"function_calling_config": "ANY"},
        system_instruction=test[0],
    )
    response = await model_function_call.generate_content_async(test[2])
    for part in response.parts:
        if fn := part.function_call:
            return json.dumps(dict(fn.args))
    return None


async def generate_content_function_call_with_mime_type(test: tuple):
    model_function_call = genai.GenerativeModel(
        model_name=MODEL_NAME,
        generation_config=genai.GenerationConfig(
            response_mime_type="application/json",
        ),
        system_instruction=test[0],
    )
    response = await model_function_call.generate_content_async(test[2])
    return response.text


def check_if_keys_sorted(output: str, properties: dict):
    order_original_keys = [k for k in properties.keys()]
    output_dict = json.loads(output)
    order_output_keys = [k for k in output_dict.keys()]
    return order_output_keys == order_original_keys


def check_if_keys_sorted_alphabetically(output: str):
    output_dict = json.loads(output)
    output_original_order = [k for k in output_dict.keys()]
    output_sorted_order = [k for k in sorted(output_dict.keys())]
    return output_original_order == output_sorted_order


async def run_single_test(test, call_fn, semaphore):
    async with semaphore:
        output = await call_fn(test)
        keys_match = check_if_keys_sorted(output, test[1])
        alpha_sorted = check_if_keys_sorted_alphabetically(output)
        return output, keys_match, alpha_sorted


async def run_tests(call_fn, tests):
    semaphore = asyncio.BoundedSemaphore(MAX_CONCURRENCY)
    tasks = []
    for test in tests:
        tasks.append(run_single_test(test, call_fn, semaphore))
    return await asyncio.gather(*tasks)

## Check if keys are in the correct order or alphabetical order

In [140]:
tests_class = [generate_test("class") for _ in range(NUMBER_OF_TESTS)]
tests_none = [generate_test("exclude") for _ in range(NUMBER_OF_TESTS)]
tests_json_schema = [generate_test("json_schema") for _ in range(NUMBER_OF_TESTS)]

In [141]:
test_results_structured_output = asyncio.run(
    run_tests(generate_content_structured_output, tests_none)
)

In [142]:
test_results_function_call = asyncio.run(
    run_tests(generate_content_function_call, tests_class)
)

In [143]:
test_results_function_call_with_mime_type = asyncio.run(
    run_tests(generate_content_function_call_with_mime_type, tests_class)
)

### Keys in correct order

In [144]:
def pct_correct_keys(
    test_results_structured_output,
    test_results_function_call,
    test_results_function_call_with_mime_type,
):
    correct_keys_count_structured_output = 0
    correct_keys_count_function_call = 0
    correct_keys_count_function_call_with_mime_type = 0

    for test in zip(
        test_results_structured_output,
        test_results_function_call,
        test_results_function_call_with_mime_type,
    ):
        if test[0][1]:
            correct_keys_count_structured_output += 1
        if test[1][1]:
            correct_keys_count_function_call += 1
        if test[2][1]:
            correct_keys_count_function_call_with_mime_type += 1
    print(
        f"Keys in correct order (structured output): {correct_keys_count_structured_output / NUMBER_OF_TESTS:.2%}"
    )
    print(
        f"Keys in correct order (function call): {correct_keys_count_function_call / NUMBER_OF_TESTS:.2%}"
    )
    if test_results_function_call_with_mime_type:
        print(
            f"Keys in correct order (function call with mime type): {correct_keys_count_function_call_with_mime_type / NUMBER_OF_TESTS:.2%}"
        )


pct_correct_keys(
    test_results_structured_output,
    test_results_function_call,
    test_results_function_call_with_mime_type,
)

Keys in correct order (structured output): 0.00%
Keys in correct order (function call): 0.00%
Keys in correct order (function call with mime type): 100.00%


### Keys in alphabetical order

In [145]:
def pct_alphabetical_keys(
    test_results_structured_output,
    test_results_function_call,
    test_results_function_call_with_mime_type,
):
    correct_keys_count_alphabetical_structured_output = 0
    correct_keys_count_alphabetical_function_call = 0
    correct_keys_count_alphabetical_function_call_with_mime_type = 0

    for test in zip(
        test_results_structured_output,
        test_results_function_call,
        test_results_function_call_with_mime_type,
    ):
        if test[0][2]:
            correct_keys_count_alphabetical_structured_output += 1
        if test[1][2]:
            correct_keys_count_alphabetical_function_call += 1
        if test[2][2]:
            correct_keys_count_alphabetical_function_call_with_mime_type += 1

    print(
        f"Keys in alphabetical order (structured output): {correct_keys_count_alphabetical_structured_output / NUMBER_OF_TESTS:.2%}"
    )
    print(
        f"Keys in alphabetical order (function call): {correct_keys_count_alphabetical_function_call / NUMBER_OF_TESTS:.2%}"
    )
    if test_results_function_call_with_mime_type:
        print(
            f"Keys in alphabetical order (function call with mime type): {correct_keys_count_alphabetical_function_call_with_mime_type / NUMBER_OF_TESTS:.2%}"
        )


pct_alphabetical_keys(
    test_results_structured_output,
    test_results_function_call,
    test_results_function_call_with_mime_type,
)

Keys in alphabetical order (structured output): 100.00%
Keys in alphabetical order (function call): 0.00%
Keys in alphabetical order (function call with mime type): 0.00%


## Ask to keep the keys in the correct order in prompt

In [146]:
def generate_test_ask_to_keep_order(
    json_schema_representation: Literal["exclude", "class", "json_schema"],
):
    random.seed(42)
    fruits_sample = random.sample(fruits, 5)
    numbers_sample = random.sample(range(1, 100), 5)
    properties = {
        fruits_sample[0]: genai.protos.Schema(type=genai.protos.Type.INTEGER),
        fruits_sample[1]: genai.protos.Schema(type=genai.protos.Type.INTEGER),
        fruits_sample[2]: genai.protos.Schema(type=genai.protos.Type.INTEGER),
        fruits_sample[3]: genai.protos.Schema(type=genai.protos.Type.INTEGER),
        fruits_sample[4]: genai.protos.Schema(type=genai.protos.Type.INTEGER),
    }
    system_prompt = "You're a helpful assistant. Please produce the answer according to the JSON schema. Make sure to keep the keys in the same order as the schema."
    if json_schema_representation == "exclude":
        system_prompt += f"{json_schema_representation}"
    if json_schema_representation == "class":
        class_str = f"\n\nclass Response(BaseModel):\n    {fruits_sample[0]}: int\n    {fruits_sample[1]}: int\n    {fruits_sample[2]}: int\n    {fruits_sample[3]}: int\n    {fruits_sample[4]}: int"
        system_prompt += f"{class_str}"
    if json_schema_representation == "json_schema":
        json_schema = {
            "type": "object",
            "properties": {
                fruits_sample[0]: {"type": "integer"},
                fruits_sample[1]: {"type": "integer"},
                fruits_sample[2]: {"type": "integer"},
                fruits_sample[3]: {"type": "integer"},
                fruits_sample[4]: {"type": "integer"},
            },
            "required": list(properties.keys()),
        }
        system_prompt += f"\n\n{json.dumps(json_schema)}"
    user_question = f"Given that I have {numbers_sample[0]} {fruits_sample[0]}, {numbers_sample[1]} {fruits_sample[1]}, {numbers_sample[2]} {fruits_sample[2]}, {numbers_sample[3]} {fruits_sample[3]}, {numbers_sample[4]} {fruits_sample[4]}, represent these counts according to the schema."
    return system_prompt, properties, user_question

In [147]:
tests_class = [generate_test_ask_to_keep_order("class") for _ in range(NUMBER_OF_TESTS)]
tests_none = [generate_test_ask_to_keep_order("exclude") for _ in range(NUMBER_OF_TESTS)]

In [148]:
test_results_structured_output_ask_to_keep_order = asyncio.run(
    run_tests(generate_content_structured_output, tests_none)
)
test_results_function_call_ask_to_keep_order = asyncio.run(
    run_tests(generate_content_function_call, tests_class)
)

In [149]:
pct_correct_keys(
    test_results_structured_output_ask_to_keep_order,
    test_results_function_call_ask_to_keep_order,
    []
)

Keys in correct order (structured output): 0.00%
Keys in correct order (function call): 0.00%


In [150]:
pct_alphabetical_keys(
    test_results_structured_output_ask_to_keep_order,
    test_results_function_call_ask_to_keep_order,
    [],
)


Keys in alphabetical order (structured output): 0.00%
Keys in alphabetical order (function call): 0.00%
