In [None]:
#| hide

%load_ext autoreload
%autoreload 2

# Pydantic AI / Ollama - Evaluate

> Here, we evaluate **open source** models. We have to be careful with structured outputs now. We move to Pydantic AI because of the [tool & output format issue](https://github.com/openai/openai-agents-python/issues/1778). Pydantic AI does not support the CGD that Ollama provides, but at least it enforces the output format (not perfectly, but far better). We can also use tools together with structured output.

- skip_showdoc: true
- skip_exec: true

## Pydantic AI

[Github tool & output format issue](https://github.com/openai/openai-agents-python/issues/1778)

[Pydantic AI also converts the output format to a tool](https://github.com/pydantic/pydantic-ai/issues/242)


[Force specific tool choice](https://community.openai.com/t/is-it-possible-to-force-agent-to-use-tools/1316526/2)

In [None]:
#| default_exp experiments_pydantic_ollama

In [None]:
#| hide
#| export

import json
from claimdb.configuration import config
from agents import OpenAIChatCompletionsModel, AsyncOpenAI, Agent, function_tool, Runner
from claimdb.experiments import *

Unclosed client session
client_session: <aiohttp.client.ClientSession object>


In [None]:
#| export
from pydantic_ai import Agent, Tool, UsageLimits, ModelRetry
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.ollama import OllamaProvider

### Test Single Claim

In [None]:
from claimdb.experiments import _tool_cache

In [None]:
# Model tags available for the single-claim test below.
models = [
    'gemini-3-flash-preview',
    'qwen3:1.7b',
    'qwen3:4b',
    'qwen3:8b',
    'qwen3:14b',
    'qwen3:32b',
    'qwen3-coder:30b',
    'ministral-3:3b',
    'ministral-3:8b',
    'ministral-3:14b',
    'mistral-nemo:12b',
    'mistral-small:22b',
    'magistral:24b',
    'devstral:24b',
    'devstral-small-2:24b',
    'nemotron-3-nano',
    'llama3.1:8b',
    'llama3.2:3b',
    'cogito:14b',
    'cogito:32b',
    'qwq:32b',
    'gpt-oss:20b',
]

# The last entry of the list is the model that gets used.
name = models[-1]


In [None]:
# Chat model for the selected tag, served through the Ollama provider at localhost:11434.
ollama_model = OpenAIChatModel(
    model_name=name,
    provider=OllamaProvider(base_url='http://localhost:11434/v1'),  
)

In [None]:
# Parse the public test split (one JSON object per line) into a list of claims.
all_claims = []
with open(config.final_benchmark_dir / 'test-public.jsonl', "r") as f:
    for raw_line in f:
        all_claims.append(json.loads(raw_line))

# Pick a single claim to experiment with.
claim = all_claims[20]

`retries`: The default number of retries to allow for tool calls and output validation, before raising an error. For model request retries, see the [retries](https://ai.pydantic.dev/api/agent/#pydantic_ai.agent.Agent) documentation.

`output_retries`: The maximum number of retries to allow for output validation, defaults to `retries`.

In [None]:
# Raw toolbox tool for this claim's database; its name and description are
# taken from the wrapped entry in `tool_cache`.
bare_tool = _tool_cache[claim['db_name']]
tool_name = tool_cache[claim['db_name']].name
tool_description = tool_cache[claim['db_name']].description
# Wrapper that never raises: any exception from the underlying tool is returned
# to the caller as its string representation.
def safe_tool(sql_query: str):
    try:
        return bare_tool(sql_query)
    except Exception as e:
        # Return the error text instead of raising, so that a failing query
        # does not make us hit `max_retries`.
        return str(e)
# Register the wrapper as a Pydantic AI tool under the original name/description.
tool = Tool(safe_tool, takes_ctx=False, name=tool_name, description=tool_description)

In [None]:
# Fact-checking agent for the single-claim test. The SQL tool built in the
# previous cell is registered here so the agent is configured the same way as
# in `run_all_pydantic`; without it the agent has no way to query the database.
agent = Agent(
    ollama_model,
    name="Fact-Checker",
    instructions=FACT_CHECKER_PROMPT_3SHOT,
    output_type=ClaimVerdict,
    output_retries=20,
    tools=[tool]
)


In [None]:
# Prompt for the agent: the claim text followed by its extra information.
inp = f"Claim: {claim['claim']}\nExtra Information: {claim['extra_info']}"
# Alternative toy prompt, kept for manual experiments:
#inp = f"use the generate final answer tool to answer. What is 20x20"

print(inp)

In [None]:
# Run the agent on the single claim with a tool-call limit of 20.
result = await agent.run(inp, usage_limits=UsageLimits(tool_calls_limit=20))

In [None]:
# Round-trip the message history through JSON so it becomes plain Python data.
all_messages = json.loads(result.all_messages_json().decode())

response = result.output

# Merge the output's fields with the message history into one log record.
log_dict = dict(result.output) | {'all_messages': all_messages}

In [None]:
len(result.all_messages())

In [None]:
result.all_messages()

In [None]:
result.usage()

### Run All Claims

In [None]:
#| export 

# Model tags that can be evaluated in the full run.
models = [
    'gemini-3-flash-preview',
    'qwen3:1.7b',
    'qwen3:4b',
    'qwen3:8b',
    'qwen3:14b',
    'qwen3:32b',
    'qwen3-coder:30b',
    'ministral-3:3b',
    'ministral-3:8b',
    'ministral-3:14b',
    'mistral-nemo:12b',
    'mistral-small:22b',
    'magistral:24b',
    'devstral:24b',
    'devstral-small-2:24b',
    'nemotron-3-nano',
    'llama3.1:8b',
    'llama3.2:3b',
    'cogito:14b',
    'cogito:32b',
    'qwq:32b',
    'gpt-oss:20b',
]

# The last entry of the list is the model under evaluation.
name = models[-1]

# One results file per model; `touch()` makes sure it exists so it can be
# read back later even on a first run.
results_path = config.experiments_dir_pub / f"{name}.jsonl"
results_path.touch()

# Toolbox port and which quarters of the benchmark to run.
# Each line overwrites the previous one, so only the LAST assignment takes
# effect; reorder or comment out lines to select a different configuration.
port, test_quarters = "5000", [1, 2, 3, 4]
port, test_quarters = "5000", [1]
port, test_quarters = "5001", [2]
port, test_quarters = "5002", [3]
port, test_quarters = "5003", [4]

In [None]:
#| export
from toolbox_core import ToolboxSyncClient

In [None]:
#| export
# Connect to the toolbox server on the selected port.
toolbox_client = ToolboxSyncClient(f"http://127.0.0.1:{port}")
# `_tool_cache` maps a database name to the raw toolbox tool,
# `tool_cache` maps it to the same tool wrapped with `function_tool`.
_tool_cache = dict()
tool_cache = dict()

# NOTE(review): `db_names` is not defined in this notebook; presumably it comes
# from the `claimdb.experiments` star import — confirm.
for db_name in db_names:
    tool = toolbox_client.load_tool(f"{db_name}_execute_sql")
    _tool_cache[db_name] = tool
    tool_cache[db_name] = function_tool(tool)

In [None]:
#| export
# Index the filtered BIRD examples (one JSON object per line) by their `bird_id`.
# When an id occurs more than once, the last occurrence in the file wins.
with open(config.bird_dir / 'train_dev_filtered.jsonl', 'r') as bird_file:
    bird_id_to_example_dict = {
        record['bird_id']: record
        for record in map(json.loads, bird_file)
    }

len(bird_id_to_example_dict)

In [None]:
#| export
# Claim ids that already have a line in the results file. Stored as a set so
# the membership test in the loop below is O(1) per claim instead of a linear
# scan over everything logged so far.
with open(results_path, 'r') as f:
    already_tested = {json.loads(line)['claim_id'] for line in f}

# Claims from the public test split that still need to be evaluated, in file order.
benchmark = []
with open(config.final_benchmark_dir / 'test-public.jsonl') as f:
    for line in f:
        parsed_claim = json.loads(line)
        if parsed_claim['claim_id'] in already_tested:
            continue
        benchmark.append(parsed_claim)

len(benchmark)

In [None]:

#| export
# Cut the remaining claims into four contiguous chunks of size `n`; the fourth
# chunk also takes the leftover claims when the total is not divisible by four.
# NOTE(review): `benchmark` only holds claims that are not yet in the results
# file, so the chunk boundaries shift whenever a run is resumed — confirm this
# is fine when several quarters are processed in parallel.
n = len(benchmark) // 4
quarters = [benchmark[k * n:(k + 1) * n] for k in range(3)] + [benchmark[3 * n:]]
q1, q2, q3, q4 = quarters

# Concatenate the chunks selected by the 1-based indices in `test_quarters`.
test = [item for idx in test_quarters for item in quarters[idx - 1]]

In [None]:
#| export
# Summary of the run configuration: selected quarters, number of claims left
# to evaluate, toolbox port, and model name.
print(f"Quarters: {test_quarters}")
print(f"#Exps Left: {len(test)}")
print(f"toolbox port: {port}")
print(name)

In [None]:
#| export
async def run_all_pydantic():
    """Evaluate every claim in `test` and append one JSON line per claim to `results_path`.

    For each claim, an agent is built around the SQL tool of the claim's
    database and run with a tool-call limit of 20. On success the fields of the
    agent output plus the full message history are logged. If anything raises,
    a record with empty `verdict`/`justification` and the error text is logged
    instead, so a single failing claim does not stop the loop.
    """
    # The model does not depend on the claim, so build it once for the whole run.
    ollama_model = OpenAIChatModel(
        model_name=name,
        provider=OllamaProvider(base_url='http://localhost:11434/v1'),
    )

    for claim in test:
        claim_id = claim['claim_id']
        claim_db = claim['db_name']

        bare_tool = _tool_cache[claim_db]
        tool_name = tool_cache[claim_db].name
        tool_description = tool_cache[claim_db].description

        # Wrapper that never raises: errors from the underlying tool are handed
        # back as text. Kept without a docstring on purpose, since the function
        # is passed to `Tool` with an explicit name and description.
        def safe_tool(sql_query: str):
            try:
                return bare_tool(sql_query)
            except Exception as e:
                # do not raise, otherwise we hit `max_retries`
                return str(e)
        tool = Tool(safe_tool, takes_ctx=False, name=tool_name, description=tool_description)

        agent = Agent(
            ollama_model,
            name="Fact-Checker",
            instructions=FACT_CHECKER_PROMPT_3SHOT,
            output_type=ClaimVerdict,
            output_retries=20,
            tools=[tool]
        )

        inp = f"Claim: {claim['claim']}\nExtra Information: {claim['extra_info']}"

        try:
            result = await agent.run(inp, usage_limits=UsageLimits(tool_calls_limit=20))

            all_messages = json.loads(result.all_messages_json().decode())
            log_dict = dict(result.output) | {'all_messages': all_messages}
        except Exception as e:
            error = str(e)
            log_dict = {'verdict': '', 'justification': '', 'error': error}

        results_dict = {'claim_id': claim_id} | log_dict
        # Open, write and close per claim: the handle is released deterministically
        # and each result reaches the file before the next claim starts.
        with results_path.open('a') as results_file:
            results_file.write(json.dumps(results_dict) + '\n')
        

In [None]:
#| export
import asyncio

In [None]:
await run_all_pydantic()

In [None]:
#| export
# Entry point for the exported module: runs the whole evaluation to completion.
# NOTE(review): per the asyncio docs, `asyncio.run()` cannot be called while an
# event loop is already running in the same thread (as inside a notebook
# kernel); the non-exported cell above uses top-level `await` for that case.
asyncio.run(run_all_pydantic())

In [None]:
#| export
toolbox_client.close()

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()