# AI Search Evals

This guide demonstrates how we developed Braintrust's AI-powered search bar, using Braintrust's own evaluation workflow along the way. If you've used Braintrust before, you may be familiar with the project page, which serves as a home base for collections of eval experiments:

![Braintrust Project Page](./assets/project-page-sql.png)

To find a particular experiment, you can type filter and sort queries into the search bar using standard SQL syntax. But SQL can be finicky: it's easy to make syntax errors such as mixing up single and double quotes, getting the JSON extraction syntax wrong, or simply making typos. Users would rather type an intuitive search like `experiments run on git commit 2a43fd1` or `score under 0.5` and have the corresponding SQL query generated automatically. Let's build this with AI, with help from Braintrust's eval framework.

We'll start by installing some packages and setting up our OpenAI client.

In [1]:
%pip install -U Levenshtein autoevals braintrust chevron duckdb openai pydantic

Note: you may need to restart the kernel to use updated packages.


In [2]:
# Set your API key here
%env BRAINTRUST_API_KEY=

env: BRAINTRUST_API_KEY=


In [3]:
%env BRAINTRUST_APP_URL=https://www.braintrustdata.com
%env BRAINTRUST_ORG_NAME=braintrustdata.com

env: BRAINTRUST_APP_URL=https://www.braintrustdata.com
env: BRAINTRUST_ORG_NAME=braintrustdata.com


In [4]:
import os
import braintrust
import openai

PROJECT_NAME = "AI Search Cookbook"

# Your Braintrust API key can be used to access any AI provider via the
# Braintrust proxy, once you've added the provider API keys to your account.
openai_opts = dict(
    base_url="https://braintrustproxy.com/v1",
    api_key=os.environ["BRAINTRUST_API_KEY"],
)
client = braintrust.wrap_openai(openai.AsyncOpenAI(default_headers={"x-bt-use-cache": "always"}, **openai_opts))
dataset = braintrust.init_dataset(PROJECT_NAME, "AI Search Cookbook", use_output=False)

## Define the output format

When we ask GPT to translate a search query, we have to account for multiple output options: (1) a SQL filter, (2) a SQL sort, (3) both of the above, or (4) an unsuccessful translation (e.g. for a nonsensical user input). We'll use [function calling](https://platform.openai.com/docs/guides/function-calling) to robustly handle each distinct scenario, with the following output format:

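- `match`: Whether the model chose the `MATCH` function, meaning it could not (or should not) translate the search into SQL, so the search should be treated as a simple substring match.
- `filter`: A SQL `WHERE` clause.
- `sort`: A SQL `ORDER BY` clause.
- `explanation`: The model's explanation for the choices above; this is useful for debugging and evaluation.
- `error`: Set by our own parsing code (not by the model) when the response doesn't contain exactly one valid function call.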

In [5]:
import dataclasses
from typing import Optional
from pydantic import BaseModel, Field, create_model


@dataclasses.dataclass
class FunctionCallOutput:
    match: bool = False
    filter: Optional[str] = None
    sort: Optional[str] = None
    explanation: Optional[str] = None
    error: Optional[str] = None


class Match(BaseModel):
    explanation: str = Field(..., description="Explanation of why I called the MATCH function")


class SQL(BaseModel):
    filter: Optional[str] = Field(..., description="SQL filter clause")
    sort: Optional[str] = Field(..., description="SQL sort clause")
    explanation: str = Field(
        ..., description="Explanation of why I called the SQL function and how I chose the filter and/or sort clauses"
    )
    

def function_choices():
    return [
        {
            "name": "MATCH",
            "description": "Interpret the query as a simple substring match",
            "parameters": Match.model_json_schema(),
        },
        {
            "name": "SQL",
            "description": "Interpret the query as a SQL filter and/or sort clause",
            "parameters": SQL.model_json_schema(),
        },
    ]
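To make the output format concrete, here is a minimal, standalone sketch of how a single tool call maps onto `FunctionCallOutput`. It mirrors the `format_output` helper used later in this notebook, but takes the function name and JSON-encoded arguments directly instead of an OpenAI completion object. `parse_tool_call` is a hypothetical name, the dataclass is redefined only to keep the block self-contained, and the sample arguments are illustrative.

```python
import dataclasses
import json
from typing import Optional


@dataclasses.dataclass
class FunctionCallOutput:
    match: bool = False
    filter: Optional[str] = None
    sort: Optional[str] = None
    explanation: Optional[str] = None
    error: Optional[str] = None


def parse_tool_call(function_name: str, arguments_json: str) -> FunctionCallOutput:
    """Hypothetical helper: convert one tool call into a FunctionCallOutput.

    Choosing MATCH sets match=True (fall back to substring search); choosing
    SQL leaves match=False and carries the filter/sort clauses through.
    """
    try:
        arguments = json.loads(arguments_json)
        # Unknown or duplicate keys raise TypeError, which we report as an error.
        return FunctionCallOutput(match=function_name == "MATCH", **arguments)
    except (json.JSONDecodeError, TypeError) as e:
        return FunctionCallOutput(error=str(e))


sql_call = parse_tool_call(
    "SQL",
    json.dumps({"filter": "avg_sql_score < 0.5", "sort": None, "explanation": "Filter on score"}),
)
match_call = parse_tool_call("MATCH", json.dumps({"explanation": "Plain text search"}))
bad_call = parse_tool_call("SQL", "not json")
```

Folding parse failures into the `error` field, rather than raising, lets a malformed model response still produce a scorable output.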

## Prepare prompts for evaluation in Braintrust

Let's evaluate two different prompts: a shorter prompt with a brief explanation of the problem statement and a description of the experiment schema, and a longer prompt that also includes a set of example cases to guide the model. There's nothing special about either of these prompts, and that's OK: we can iterate on and improve them once we use Braintrust to drill down into the results.

In [6]:
SHORT_PROMPT_FILE = "./assets/short_prompt.tmpl"
LONG_PROMPT_FILE = "./assets/long_prompt.tmpl"

with open(SHORT_PROMPT_FILE) as f:
    short_prompt = f.read()
    
with open(LONG_PROMPT_FILE) as f:
    long_prompt = f.read()

One detail worth mentioning: each prompt contains a stub for dynamic insertion of the data schema. This is motivated by the need to handle semantic searches like `more than 40 examples` or `score < 0.5` that don't directly reference a column in the base table. We need to tell the model how the data is structured and what each field actually _means_. We'll construct a descriptive schema using [pydantic](https://docs.pydantic.dev/) and insert it into each prompt to give the model this information.

In [7]:
import json
import chevron
from typing import Any, Callable, Dict, List


class ExperimentGitState(BaseModel):
    commit: str = Field(
        ...,
        description="Git commit hash. Any prefix of this hash at least 7 characters long should be considered an exact match, so use a substring filter rather than string equality to check the commit, e.g. `(source->>'commit') ILIKE '{COMMIT}%'`",
    )
    branch: str = Field(..., description="Git branch name")
    tag: Optional[str] = Field(..., description="Git commit tag")
    commit_time: int = Field(..., description="Git commit timestamp")
    author_name: str = Field(..., description="Author of git commit")
    author_email: str = Field(..., description="Email address of git commit author")
    commit_message: str = Field(..., description="Git commit message")
    dirty: Optional[bool] = Field(
        ...,
        description="Whether the git state was dirty when the experiment was run. If false, the git state was clean",
    )


class Experiment(BaseModel):
    id: str = Field(..., description="Experiment ID, unique")
    name: str = Field(..., description="Name of the experiment")
    last_updated: int = Field(
        ...,
        description="Timestamp marking when the experiment was last updated. If the query deals with some notion of relative time, like age or recency, refer to this timestamp and, if appropriate, compare it to the current time `get_current_time()` by adding or subtracting an interval.",
    )
    creator: Dict[str, str] = Field(..., description="Information about the experiment creator")
    source: ExperimentGitState = Field(..., description="Git state that the experiment was run on")
    metadata: Dict[str, Any] = Field(
        ...,
        description="Custom metadata provided by the user. Ignore this field unless the query mentions metadata or refers to a metadata key specifically",
    )


def build_experiment_schema(score_fields: List[str]):
    ExperimentWithScoreFields = create_model(
        "Experiment",
        __base__=Experiment,
        **{field: (Optional[float], ...) for field in score_fields},
    )
    return json.dumps(ExperimentWithScoreFields.model_json_schema())


def render_prompt(prompt: str, score_fields: List[str]) -> str:
    schema = build_experiment_schema(score_fields)
    return chevron.render(prompt.strip(), {"schema": schema}, warn=True)
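The key idea in `build_experiment_schema` is that the schema is rebuilt per project: every score field becomes an extra nullable float property. The stdlib-only sketch below imitates that with a hand-written dict instead of pydantic, so its JSON differs from what `model_json_schema()` actually emits. The trimmed base properties, the `build_schema_sketch`/`render_sketch` names, the template text, and the plain `str.replace` on a `{{schema}}` placeholder (standing in for chevron's mustache rendering) are all assumptions for illustration.

```python
import json
from typing import Any, Dict, List

# Simplified stand-in for a few of the pydantic Experiment model's properties.
BASE_PROPERTIES: Dict[str, Any] = {
    "id": {"type": "string", "description": "Experiment ID, unique"},
    "name": {"type": "string", "description": "Name of the experiment"},
    "last_updated": {"type": "integer", "description": "Timestamp of last update"},
}


def build_schema_sketch(score_fields: List[str]) -> str:
    """Add one required-but-nullable float property per score field, then serialize."""
    properties = dict(BASE_PROPERTIES)  # shallow copy so the base stays untouched
    for field in score_fields:
        properties[field] = {"anyOf": [{"type": "number"}, {"type": "null"}]}
    schema = {
        "title": "Experiment",
        "type": "object",
        "properties": properties,
        "required": list(properties),
    }
    return json.dumps(schema)


def render_sketch(template: str, score_fields: List[str]) -> str:
    # Plain string substitution in place of chevron's mustache rendering.
    return template.strip().replace("{{schema}}", build_schema_sketch(score_fields))


# Hypothetical prompt template; the real ones live in ./assets/*.tmpl.
prompt = render_sketch("Schema:\n{{schema}}\n", ["avg_sql_score", "avg_factuality_score"])
parsed = json.loads(prompt.split("\n", 1)[1])
```

Because the score fields are injected at render time, the same prompt template can be reused across projects whose experiments track different scores.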

Our prompts are ready! Before we run our evals, we just need to load some sample data and define our scoring functions.

## Load sample data

Let's load our examples. Each example case contains `input` (the search query) and `expected` (function call output). 

In [8]:
import json


@dataclasses.dataclass
class Example:
    input: str
    expected: FunctionCallOutput
    

EXAMPLES_FILE = "./assets/examples.json"
with open(EXAMPLES_FILE) as f:
    examples_json = json.load(f)
    
templates = [Example(input=e["input"], expected=FunctionCallOutput(**e["expected"])) for e in examples_json["examples"]]

# Each example contains a few dynamic fields that depend on the experiments
# we're searching over. For simplicity, we'll hard-code these fields here.
SCORE_FIELDS = ['avg_sql_score', 'avg_factuality_score']


def render_example(example: Example, args: Dict[str, Any]) -> Example:
    render_optional = lambda template: chevron.render(template, args, warn=True) if template is not None else None
    return Example(
        input=render_optional(example.input),
        expected=FunctionCallOutput(
            match=example.expected.match,
            filter=render_optional(example.expected.filter),
            sort=render_optional(example.expected.sort),
            explanation=render_optional(example.expected.explanation),
        ),
    )


examples = [render_example(t, {"score_fields": SCORE_FIELDS}) for t in templates]
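For reference, here is a sketch of the shape `examples.json` appears to have, inferred only from the loader above: a top-level `examples` list whose items carry `input` and `expected`. The sample entry reuses the first record from the dataset output shown in this section; the exact set of keys in the real file is an assumption. One consequence of `FunctionCallOutput(**e["expected"])` worth knowing is that an unexpected key raises `TypeError`, so each `expected` object may only use the dataclass field names. The dataclasses are redefined to keep the block standalone.

```python
import dataclasses
import json
from typing import Optional


@dataclasses.dataclass
class FunctionCallOutput:
    match: bool = False
    filter: Optional[str] = None
    sort: Optional[str] = None
    explanation: Optional[str] = None
    error: Optional[str] = None


@dataclasses.dataclass
class Example:
    input: str
    expected: FunctionCallOutput


# A one-entry file in the shape the loader expects (entry taken from the dataset output).
EXAMPLES_JSON = """
{
  "examples": [
    {
      "input": "metadata.env.veggies includes kale or spinach",
      "expected": {
        "match": false,
        "filter": "(metadata->'env'->>'veggies') ILIKE '%kale%' OR (metadata->'env'->>'veggies') ILIKE '%spinach%'",
        "sort": null
      }
    }
  ]
}
"""

examples_json = json.loads(EXAMPLES_JSON)
loaded = [
    Example(input=e["input"], expected=FunctionCallOutput(**e["expected"]))
    for e in examples_json["examples"]
]

# An unknown key in "expected" is rejected by the dataclass constructor.
try:
    FunctionCallOutput(**{"match": False, "where": "x > 1"})
    rejected = False
except TypeError:
    rejected = True
```

Omitted keys such as `explanation` simply fall back to the dataclass defaults.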

Insert our examples into a Braintrust dataset so we can introspect and reuse the data later.

In [9]:
for example in examples:
    # Convert the dataclass to a plain dict so it serializes cleanly as JSON.
    dataset.insert(input=example.input, expected=dataclasses.asdict(example.expected))
dataset.flush()

records = list(dataset)
print(f"Generated {len(records)} records. Here are the first 2...")
for record in records[:2]:
    print(record)

Generated 45 records. Here are the first 2...
{'id': '101323f7-da03-4f63-8bb4-e1ef1a80c522', 'span_id': '0375d6ab-998b-4c39-9f32-fd8c65e4e9c1', 'root_span_id': '0375d6ab-998b-4c39-9f32-fd8c65e4e9c1', '_xact_id': 1000192548352904752, 'created': '2024-02-19T03:48:32.415527Z', 'project_id': '61ce386b-1dac-4027-980f-2f3baf32c9f4', 'dataset_id': '8a114409-177e-4b3a-ae70-69b050f668ef', 'input': 'metadata.env.veggies includes kale or spinach', 'expected': {'sort': None, 'error': None, 'match': False, 'filter': "(metadata->'env'->>'veggies') ILIKE '%kale%' OR (metadata->'env'->>'veggies') ILIKE '%spinach%'", 'explanation': None}, 'metadata': None}

## Define scoring functions

How do we score our outputs against the ground truth queries? We can't rely on an exact text match, since the same query can be written correctly in SQL in many ways. Instead, we'll use two approximate scoring methods: (1) `SQLScorer`, which roundtrips each clause through DuckDB's `json_serialize_sql` to normalize it before comparing, and otherwise falls back to partial credit based on whether the clause runs and how close its text is to the expected clause; and (2) `AutoScorer`, which delegates the grading to `gpt-4` via the `Sql` scorer from autoevals.
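The cascade inside `score_clause` (exact match, then normalized match, then a validity gate, then edit-distance partial credit) can be sketched without DuckDB. In this stdlib-only version, a hypothetical `normalize` function that collapses whitespace and lowercases stands in for the `json_serialize_sql` roundtrip, a caller-supplied `is_valid` callable stands in for `_test_clause`, and a small dynamic-programming `levenshtein` replaces the `Levenshtein` package. The ordering and the partial-credit formula `1 - distance / max_len` follow `score_clause`; the names are illustrative.

```python
from typing import Callable


def levenshtein(a: str, b: str) -> int:
    """Classic edit distance (insertions, deletions, substitutions)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,               # deletion
                curr[j - 1] + 1,           # insertion
                prev[j - 1] + (ca != cb),  # substitution
            ))
        prev = curr
    return prev[-1]


def normalize(clause: str) -> str:
    # Crude stand-in for DuckDB's roundtrip; lowercasing would wrongly equate
    # string literals that differ only in case, which the real roundtrip avoids.
    return " ".join(clause.lower().split())


def score_clause_sketch(output: str, expected: str, is_valid: Callable[[str], bool]) -> float:
    if output == expected:
        return 1.0  # exact text match
    if normalize(output) == normalize(expected):
        return 1.0  # equivalent after normalization
    if not is_valid(output):
        return 0.0  # invalid SQL gets no credit
    # Partial credit shrinks as the edit distance grows (max_len > 0 here,
    # since two empty strings would already have matched exactly).
    max_len = max(len(output), len(expected))
    return 1 - levenshtein(output, expected) / max_len


def always_valid(clause: str) -> bool:
    return True
```

For example, `avg_sql_score < 0.6` against an expected `avg_sql_score < 0.5` is one edit apart over 19 characters, so it earns `1 - 1/19` rather than zero.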

In [None]:
import duckdb
from braintrust import current_span, traced
from autoevals import Score, Scorer, Sql
from Levenshtein import distance


EXPERIMENTS_TABLE = "./assets/experiments.parquet"
SUMMARY_TABLE = "./assets/experiments_summary.parquet"

duckdb.sql(f"DROP TABLE IF EXISTS experiments; CREATE TABLE experiments AS SELECT * FROM '{EXPERIMENTS_TABLE}'")
duckdb.sql(f"DROP TABLE IF EXISTS experiments_summary; CREATE TABLE experiments_summary AS SELECT * FROM '{SUMMARY_TABLE}'")


def _test_clause(*, filter=None, sort=None) -> bool:
    clause = f"""
        SELECT
          experiments.id AS id,
          experiments.name,
          experiments_summary.last_updated,
          experiments.user AS creator,
          experiments.repo_info AS source,
          experiments_summary.* EXCLUDE (experiment_id, last_updated),
        FROM experiments
        LEFT JOIN experiments_summary ON experiments.id = experiments_summary.experiment_id
        {'WHERE ' + filter if filter else ''}
        {'ORDER BY ' + sort if sort else ''}
    """
    current_span().log(metadata=dict(test_clause=clause))
    try:
        duckdb.sql(clause).fetchall()
        return True
    except Exception:
        return False


def _single_quote(s):
    return f"""'{s.replace("'", "''")}'"""


def _roundtrip_filter(s):
    return duckdb.sql(f"""
        SELECT json_deserialize_sql(json_serialize_sql({_single_quote(f"SELECT 1 WHERE {s}")}))
    """).fetchall()[0][0]


def _roundtrip_sort(s):
    return duckdb.sql(f"""
        SELECT json_deserialize_sql(json_serialize_sql({_single_quote(f"SELECT 1 ORDER BY {s}")}))
    """).fetchall()[0][0]


def score_clause(
    output: Optional[str], expected: Optional[str], roundtrip: Callable[[str], str], test_clause: Callable[[str], bool]
) -> float:
    exact_match = 1 if output == expected else 0
    current_span().log(scores=dict(exact_match=exact_match))
    if exact_match:
        return 1

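    try:
        roundtrip_match = 1 if roundtrip(output) == roundtrip(expected) else 0
    except Exception:
        # Unparseable SQL can't be normalized; fall through to the validity check.
        roundtrip_match = 0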
    current_span().log(scores=dict(roundtrip_match=roundtrip_match))
    if roundtrip_match:
        return 1

    # If the queries aren't equivalent after roundtripping, it's not immediately clear
    # whether they are semantically equivalent. Let's at least check that the generated
    # clause is valid SQL by running the `test_clause` function defined above, which
    # runs a test query against our sample data.
    valid_clause_score = 1 if test_clause(output) else 0
    current_span().log(scores=dict(valid_clause=valid_clause_score))
    if valid_clause_score == 0:
        return 0

    max_len = max(len(clause) for clause in [output, expected])
    if max_len == 0:
        current_span().log(metadata=dict(error="Bad example: empty clause"))
        return 0
    return 1 - (distance(output, expected) / max_len)

    
class SQLScorer(Scorer):
    """SQLScorer uses DuckDB's `json_serialize_sql` function to determine whether
    the model's chosen filter/sort clause(s) are equivalent to the expected
    outputs. If not, we assign partial credit to each clause depending on
    (1) whether the clause is valid SQL, as determined by running it against
    the actual data and seeing if it errors, and (2) a distance-wise comparison
    to the expected text.
    """
    
    def _run_eval_sync(
        self,
        output,
        expected=None,
        **kwargs,
    ):
        if expected is None:
            raise ValueError("SQLScorer requires an expected value")

        name = "SQLScorer"
        expected = FunctionCallOutput(**expected)
        
        function_choice_score = 1 if output.match == expected.match else 0
        current_span().log(scores=dict(function_choice=function_choice_score))
        if function_choice_score == 0:
            return Score(name=name, score=0)
        if expected.match:
            return Score(name=name, score=1)

        filter_score = None
        if output.filter and expected.filter:
            with current_span().start_span("SimpleFilter") as span:
                filter_score = score_clause(output.filter, expected.filter, _roundtrip_filter, lambda s: _test_clause(filter=s))
        elif output.filter or expected.filter:
            filter_score = 0
        current_span().log(scores=dict(filter=filter_score))

        sort_score = None
        if output.sort and expected.sort:
            with current_span().start_span("SimpleSort") as span:
                sort_score = score_clause(output.sort, expected.sort, _roundtrip_sort, lambda s: _test_clause(sort=s))
        elif output.sort or expected.sort:
            sort_score = 0
        current_span().log(scores=dict(sort=sort_score))

        scores = [s for s in [filter_score, sort_score] if s is not None]
        if len(scores) == 0:
            return Score(name=name, score=0, error="Bad example: no filter or sort for SQL function call")
        return Score(name=name, score=sum(scores) / len(scores))


@traced("auto_score_filter")
def auto_score_filter(openai_opts, **kwargs):
    return Sql(**openai_opts)(**kwargs)


@traced("auto_score_sort")
def auto_score_sort(openai_opts, **kwargs):
    return Sql(**openai_opts)(**kwargs)


class AutoScorer(Scorer):
    """AutoScorer uses the `Sql` scorer from the autoevals library to auto-score
    the model's chosen filter/sort clause(s) against the expected outputs
    using an LLM.
    """
    
    def __init__(self, **openai_opts):
        self.openai_opts = openai_opts

    def _run_eval_sync(
        self,
        output,
        expected=None,
        **kwargs,
    ):
        if expected is None:
            raise ValueError("AutoScorer requires an expected value")
        input = kwargs.get("input")
        if input is None or not isinstance(input, str):
            raise ValueError("AutoScorer requires an input value of type str")

        name = "AutoScorer"
        expected = FunctionCallOutput(**expected)

        function_choice_score = 1 if output.match == expected.match else 0
        current_span().log(scores=dict(function_choice=function_choice_score))
        if function_choice_score == 0:
            return Score(name=name, score=0)
        if expected.match:
            return Score(name=name, score=1)

        filter_score = None
        if output.filter and expected.filter:
            result = auto_score_filter(
                openai_opts=self.openai_opts,
                input=input,
                output=output.filter,
                expected=expected.filter,
            )
            filter_score = result.score or 0
        elif output.filter or expected.filter:
            filter_score = 0
        current_span().log(scores=dict(filter=filter_score))

        sort_score = None
        if output.sort and expected.sort:
            result = auto_score_sort(
                openai_opts=self.openai_opts,
                input=input,
                output=output.sort,
                expected=expected.sort,
            )
            sort_score = result.score or 0
        elif output.sort or expected.sort:
            sort_score = 0
        current_span().log(scores=dict(sort=sort_score))

        scores = [s for s in [filter_score, sort_score] if s is not None]
        if len(scores) == 0:
            return Score(name=name, score=0, error="Bad example: no filter or sort for SQL function call")
        return Score(name=name, score=sum(scores) / len(scores))
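`_test_clause` judges validity by simply executing the candidate clause against real data and catching any exception. The same pattern can be reproduced with Python's built-in `sqlite3` module, as sketched below. This swaps in a different engine purely for illustration: SQLite's dialect differs from DuckDB's (for instance, `ILIKE` and `SELECT * EXCLUDE (...)` are not available), and the toy `experiments` table, its rows, and the `make_db`/`clause_is_valid` names are hypothetical.

```python
import sqlite3
from typing import Optional


def make_db() -> sqlite3.Connection:
    # Hypothetical toy table; the notebook uses DuckDB tables loaded from Parquet.
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE experiments (id TEXT, name TEXT, avg_sql_score REAL)")
    conn.executemany(
        "INSERT INTO experiments VALUES (?, ?, ?)",
        [("e1", "baseline", 0.4), ("e2", "long prompt", 0.8)],
    )
    return conn


def clause_is_valid(conn: sqlite3.Connection, *, filter: Optional[str] = None, sort: Optional[str] = None) -> bool:
    """Return True if the WHERE/ORDER BY clauses execute without error."""
    query = "SELECT * FROM experiments"
    if filter:
        query += f" WHERE {filter}"
    if sort:
        query += f" ORDER BY {sort}"
    try:
        conn.execute(query).fetchall()
        return True
    except sqlite3.Error:
        # Syntax errors and unknown columns both surface as sqlite3 errors.
        return False


conn = make_db()
```

A clause that runs is not necessarily a correct translation, which is why `score_clause` uses this check only to gate partial credit rather than to award full marks.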

## Run the evals!

We'll use the Braintrust `Eval` framework to set up our experiments according to the prompts, dataset, and scoring functions defined above.

In [None]:
def build_completion_kwargs(
    *,
    query: str,
    model: str,
    prompt: str,
    score_fields: List[str],
    **kwargs,
):
    # Inject the JSON schema into the prompt to assist the model.
    system_message = render_prompt(prompt, score_fields)
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": f"Query: {query}"},
    ]
    tools = [{"type": "function", "function": f} for f in function_choices()]
    return dict(
        model=model,
        temperature=0,
        messages=messages,
        tools=tools,
    )


def format_output(completion):
    try:
        tool_calls = completion.choices[0].message.tool_calls
        if tool_calls is None or len(tool_calls) != 1:
            raise ValueError(f"Expected exactly one tool call, got: {completion.choices[0].message}")
        tool_call = tool_calls[0]
        if tool_call.type != "function":
            raise Exception("Tool call is not a function")
        function_name = tool_call.function.name
        arguments = json.loads(tool_call.function.arguments)
        return FunctionCallOutput(match=function_name == "MATCH", **arguments)
    except Exception as e:
        return FunctionCallOutput(error=str(e))


GRADER = "gpt-4"  # Used by AutoScorer to grade the model outputs


async def run_eval(experiment_name, prompt, model, score_fields=SCORE_FIELDS):
    async def task(input):
        completion_kwargs = build_completion_kwargs(
            query=input,
            model=model,
            prompt=prompt,
            score_fields=score_fields,
        )
        return format_output(await client.chat.completions.create(**completion_kwargs))

    await braintrust.Eval(
        name=PROJECT_NAME,
        experiment_name=experiment_name,
        data=dataset,
        task=task,
        scores=[SQLScorer(), AutoScorer(**openai_opts, model=GRADER)],
    )
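Conceptually, `Eval` runs the task on every dataset record and applies each scorer to the resulting output. The toy harness below sketches that loop with `asyncio` to show how the pieces plug together. `mini_eval`, the canned `fake_task`, the `function_choice` scorer, and the toy records and labels are all hypothetical; this is not Braintrust's implementation, which also handles tracing, logging, and uploading the experiment.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List

Record = Dict[str, Any]
ScoreFn = Callable[[Any, Any, Any], float]


async def mini_eval(
    data: List[Record],
    task: Callable[[Any], Awaitable[Any]],
    scorers: Dict[str, ScoreFn],
) -> Dict[str, float]:
    """Run the task on each record and return the mean score per scorer."""
    totals = {name: 0.0 for name in scorers}
    for record in data:
        output = await task(record["input"])
        for name, scorer in scorers.items():
            totals[name] += scorer(record["input"], output, record["expected"])
    return {name: total / len(data) for name, total in totals.items()}


async def fake_task(query: str) -> Dict[str, Any]:
    # Canned "model": treat anything mentioning "score" as SQL, else MATCH.
    if "score" in query:
        return {"match": False, "filter": "avg_sql_score < 0.5"}
    return {"match": True}


def function_choice(input: Any, output: Any, expected: Any) -> float:
    # Mirrors the first check in SQLScorer/AutoScorer: was the right function picked?
    return 1.0 if output["match"] == expected["match"] else 0.0


data = [
    {"input": "score under 0.5", "expected": {"match": False}},
    {"input": "my favorite experiment", "expected": {"match": True}},
    {"input": "experiments run on git commit 2a43fd1", "expected": {"match": False}},
]
# In a notebook with a running event loop, use `await mini_eval(...)` instead.
results = asyncio.run(mini_eval(data, fake_task, {"function_choice": function_choice}))
```

The canned task misroutes the git-commit query to `MATCH`, so `function_choice` averages to 2/3; in the real evals, `SQLScorer` and `AutoScorer` play this role and return a `Score` per example.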

Let's run it! We'll use `gpt-3.5-turbo` for both evals.

In [None]:
await run_eval("Short Prompt 2.0", short_prompt, "gpt-3.5-turbo")

In [None]:
await run_eval("Long Prompt 2.0", long_prompt, "gpt-3.5-turbo")

## View the results in Braintrust

The evals will generate a link to the experiment page. Click into an experiment to view the results!

If you've just been following along, you can [check out some sample results here](). Type some searches into the search bar to see AI search in action. :)
