# BrainTrust Text2SQL Tutorial
<a target="_blank" href="https://colab.research.google.com/github/braintrustdata/braintrust-examples/blob/main/text-to-sql/py/BrainTrust-Text2SQL-Tutorial.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

Welcome to [BrainTrust](https://www.braintrustdata.com/)! This tutorial will teach you the basics of working with BrainTrust, including creating a project, running experiments, and analyzing their results. By the time you finish this tutorial, you should be ready to run your own experiments. This notebook accompanies the [Creating your first project](https://www.braintrustdata.com/docs/getting-started/first-project) guide in the docs.

Before starting, please make sure that you have a BrainTrust account. If you do not, please [sign up](https://www.braintrustdata.com) or [get in touch](mailto:info@braintrustdata.com).

## Setting up the environment

Before we get started, enter your OpenAI API key and your BrainTrust API key here:

In [None]:
# NOTE: Replace YOUR_OPENAI_API_KEY with your OpenAI API key and YOUR_BRAINTRUST_API_KEY with your BrainTrust API key. Do not put them in quotes.
%env OPENAI_API_KEY=YOUR_OPENAI_API_KEY
%env BRAINTRUST_API_KEY=YOUR_BRAINTRUST_API_KEY

The next few commands will install some libraries and include some helper code for the text2sql application. Feel free to copy/paste/tweak/reuse this code in your own tools.

In [None]:
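# The helper code below uses openai.Completion and openai.error, which were removed in openai 1.0
!pip install braintrust duckdb datasets "openai<1" pyarrow python-Levenshtein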

In [None]:
# Import libraries + define helper functions

import duckdb
from datasets import load_dataset
import json
from Levenshtein import distance
import openai
import os
import pyarrow as pa
import time

openai.api_key = os.getenv("OPENAI_API_KEY")
NUM_TEST_EXAMPLES = 10

# Define some helper functions
def get_table(table):
    rows = [
        {h: row[i] for (i, h) in enumerate(table["header"])} for row in table["rows"]
    ]

    return pa.Table.from_pylist(rows)

AGG_OPS = [None, "MAX", "MIN", "COUNT", "SUM", "AVG"]
COND_OPS = [" ILIKE ", ">", "<"]  # , "OP"]


def esc_fn(s):
    return f'''"{s.replace('"', '""')}"'''


def esc_value(s):
    if isinstance(s, str):
        return s.replace("'", "''")
    else:
        return s

def codegen_query(query):
    header = query["table"]["header"]

    projection = f"{esc_fn(header[query['sql']['sel']])}"

    agg_op = AGG_OPS[query["sql"]["agg"]]
    if agg_op is not None:
        projection = f"{agg_op}({projection})"

    conds = query["sql"]["conds"]

    filters = " and ".join(
        [
            f"""{esc_fn(header[field])}{COND_OPS[cond]}'{esc_value(value)}'"""
            for (field, cond, value) in zip(
                conds["column_index"], conds["operator_index"], conds["condition"]
            )
        ]
    )

    if filters:
        filters = f" WHERE {filters}"

    return f'SELECT {projection} FROM "table"{filters}'

OPENAI_CACHE = None
def openai_req(Completion=openai.Completion, **kwargs):
    global OPENAI_CACHE
    if OPENAI_CACHE is None:
        os.makedirs("data", exist_ok=True)
        OPENAI_CACHE = duckdb.connect(database="data/oai_cache.duckdb")
        OPENAI_CACHE.query(
            "CREATE TABLE IF NOT EXISTS cache (params text, response text)"
        )

    param_key = json.dumps(kwargs)
    resp = OPENAI_CACHE.execute(
        """SELECT response FROM "cache" WHERE params=?""", [param_key]
    ).fetchone()
    if resp and resp[0]:
        return json.loads(resp[0])

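    # Retry up to 5 times on rate limits, then fail instead of caching an empty response
    for _ in range(5):
        try:
            resp = Completion.create(**kwargs).to_dict()
            break
        except openai.error.RateLimitError:
            print("Rate limited... Sleeping for 30 seconds")
            time.sleep(30)
    else:
        raise RuntimeError("OpenAI request was rate limited 5 times in a row")

    OPENAI_CACHE.execute(
        """INSERT INTO "cache" VALUES (?, ?)""", [param_key, json.dumps(resp)]
    )

    return resp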

def green(s):
  return "\x1b[32m" + s + "\x1b[0m"

def run_query(sql, table_record):
    table = get_table(table_record)  # noqa
    rel_from_arrow = duckdb.arrow(table)

    result = rel_from_arrow.query("table", sql).fetchone()
    if result and len(result) > 0:
        return result[0]
    return None

def score(r1, r2):
    if r1 is None and r2 is None:
        return 1
    if r1 is None or r2 is None:
        return 0

    r1, r2 = str(r1), str(r2)

    total_len = max(len(r1), len(r2))
    if total_len == 0:
        # Two empty strings are identical; avoid dividing by zero
        return 1
    return 1 - distance(r1, r2) / total_len
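
To make the scoring helper more concrete, here is a minimal standalone sketch of the same idea: a similarity of 1 minus the edit distance divided by the longer string's length. The names `edit_distance` and `similarity` are illustrative and not part of the tutorial's code; `edit_distance` is a plain-Python stand-in for `Levenshtein.distance` so the snippet runs without extra dependencies.

```python
def edit_distance(a: str, b: str) -> int:
    """Dynamic-programming Levenshtein distance (stand-in for Levenshtein.distance)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            cost = 0 if ca == cb else 1
            # deletion, insertion, substitution
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + cost))
        prev = curr
    return prev[-1]


def similarity(r1, r2) -> float:
    """Same idea as score(): 1.0 for identical values, lower as they diverge."""
    if r1 is None and r2 is None:
        return 1.0
    if r1 is None or r2 is None:
        return 0.0
    s1, s2 = str(r1), str(r2)
    longest = max(len(s1), len(s2))
    if longest == 0:
        # Guard used in this sketch: two empty strings count as identical
        return 1.0
    return 1 - edit_distance(s1, s2) / longest


print(similarity("SELECT a", "SELECT b"))  # one substitution over 8 characters
print(similarity(42, "42"))                # values are compared as strings
print(similarity(None, "42"))              # a missing answer scores 0
```

Because values are converted to strings before comparison, a numeric answer and its string form count as a perfect match, which suits comparing query results of mixed types.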

## Exploring the data

In this section, we'll take a look at the dataset and the ground-truth text/SQL pairs to better understand the problem and data.

In [None]:
# Initialize data from WikiSQL
data = list(load_dataset("wikisql")["test"])
idx = 0

In [None]:
data[idx]["question"]

In [None]:
table = get_table(data[idx]['table'])
duckdb.arrow(table).query("table", 'SELECT * FROM "table"')

In [None]:
gt_sql = codegen_query(data[idx])
print(gt_sql)

In [None]:
duckdb.arrow(table).query("table", gt_sql)
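
If the structure of a record is unclear, the sketch below decodes one by hand. The record itself is hypothetical (made-up question and rows); only the key names (`table`, `header`, `rows`, `sql`, `sel`, `agg`, `conds`, and so on) are taken from the helper functions above, and the two lookup lists are copied from them.

```python
# Hypothetical record in the shape the helper functions expect.
record = {
    "question": "Which circuit hosted round 3?",
    "table": {
        "header": ["Round", "Circuit"],
        "rows": [["1", "Jerez"], ["2", "Assen"], ["3", "Monza"]],
    },
    "sql": {
        "sel": 1,  # index into header: project "Circuit"
        "agg": 0,  # index into AGG_OPS: 0 means no aggregate
        "conds": {
            "column_index": [0],    # filter on "Round"
            "operator_index": [0],  # index into COND_OPS: 0 is ILIKE
            "condition": ["3"],
        },
    },
}

# Copied from the helper code above.
AGG_OPS = [None, "MAX", "MIN", "COUNT", "SUM", "AVG"]
COND_OPS = [" ILIKE ", ">", "<"]

header = record["table"]["header"]
sql = record["sql"]

print("selected column:", header[sql["sel"]])
print("aggregate:", AGG_OPS[sql["agg"]])
for col, op, value in zip(
    sql["conds"]["column_index"],
    sql["conds"]["operator_index"],
    sql["conds"]["condition"],
):
    print("condition:", header[col], COND_OPS[op].strip(), value)
```

The ground truth is stored as indices rather than SQL text, which is why `codegen_query` is needed to turn each record into a runnable query.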

## Running your first experiment

In this section, we'll create our first experiment and analyze the results in BrainTrust.

In [None]:
# First attempt: provide the question and columns
def text2sql(query):
    table = query["table"]
    meta = "\n".join(f'"{h}"' for h in table["header"])

    prompt = f"""
Print a SQL query (over a table named "table" - the table is quoted with double quotes) that answers the question below.

You have the following columns:
{meta}

The format should be
Question: the question to ask
SQL: the SQL to generate

Question: {query['question']}
SQL: """
    resp = openai_req(model="text-davinci-003", prompt=prompt, max_tokens=1024)

    return (
        prompt,
        resp,
        resp["choices"][0]["text"].rstrip(";")
        if len(resp["choices"]) > 0
        else None,
    )

prompt, resp, _ = text2sql(data[idx])
print(prompt + green(resp['choices'][0]['text']))

output_sql = resp['choices'][0]['text'].rstrip(";")
try:
  duckdb.arrow(table).query("table", output_sql)
except Exception as e:
  print(e)

Now that we've tested the prompt on an example, let's run it on several test examples and log the results in BrainTrust.

First, we'll initialize a new experiment in a **project** named `text2sql-tutorial`. A project allows you to group together **experiments** that contain similar inputs and outputs, and compare results across them. In this tutorial, we'll ultimately create two experiments.

When you run `braintrust.init`, BrainTrust will automatically create the project if it does not exist. If an experiment with the name you selected already exists in that project, BrainTrust will automatically create a new one with a distinct name (e.g. `with-columns-001`). Experiments are meant to be short-lived, one-time analyses.

_NOTE: If you did not specify a valid BRAINTRUST_API_KEY, you may be prompted to enter a token here for authentication._

In [None]:
import braintrust
bt = braintrust.init(project="text2sql-tutorial", experiment="with-columns")

The function below runs the experiment in a loop (over `NUM_TEST_EXAMPLES` test cases) and logs data to BrainTrust with the `bt.log` command. Each of these arguments can contain arbitrary JSON data, which you can later slice and dice to understand how changes in your prompts and models affect results for different subsets of data. Here's a quick explanation of each argument:

* `inputs`: the arguments that uniquely define a test case. Later on, BrainTrust will use the `inputs` to know whether two test cases are the same between experiments, so they should not contain experiment-specific state. A simple rule of thumb is that if you run the same experiment twice, the `inputs` should be identical.
* `output`: the output of your application, including post-processing, that allows you to determine whether the result is correct or not. For example, in the text2sql app, the `output` is the _result_ of the SQL query generated by the model, not the query itself, because there may be multiple valid SQL queries to answer a single question.
* `expected`: the ground truth value that you'd compare to `output` to determine if your `output` value is correct or not. BrainTrust currently does not compare `output` to `expected` for you, since there are so many different ways to do that correctly. Instead, these values are just used to help you navigate your experiments while digging into analyses. However, we may later use these values to re-score outputs or fine-tune your models.
* `scores`: one or more numeric values (between 0 and 1) that tell you how accurate the outputs are compared to what you expect. In this example, the `answer` score is the definitive measure of correctness, but the `query` score helps you measure how far off the queries are. For example, if there's a small syntax error, you might have an `answer` score of 0 but a high `query` score. You can use these scores to help you sort, filter, and compare experiments.
* `metadata`: a JSON dictionary of additional data about the test example, model outputs, or just about anything else that's relevant, that you can use to help find and analyze examples later. In this example, the `id` is particularly helpful because it allows us to quickly test examples in the notebook (later on below) by setting `idx`. `output_sql` allows us to look at failure cases and understand how far off the SQL was from what we expected. `prompt` makes it easy to compare the actual prompts between examples and experiments. And so on.
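
As a rough illustration of the arguments described above, here is what the data for a single test case might look like as a plain Python dictionary. All values are hypothetical; only the field names come from this tutorial. The snippet does not call BrainTrust, it just checks the two properties the list mentions: everything is JSON data, and scores lie between 0 and 1.

```python
import json

# Hypothetical values for one test case; only the field names come from the tutorial.
record = {
    "inputs": {"question": "How many rounds were held at Assen?"},
    "output": 2,    # result of running the generated SQL
    "expected": 2,  # result of running the ground truth SQL
    "scores": {"answer": 1.0, "query": 0.8},
    "metadata": {
        "prompt": "Print a SQL query ...",
        "gt_sql": "SELECT COUNT(\"Round\") FROM \"table\" WHERE \"Circuit\" ILIKE 'Assen'",
        "output_sql": "SELECT COUNT(*) FROM \"table\" WHERE \"Circuit\" = 'Assen'",
        "id": 0,
    },
}

# Every field should survive a JSON round trip unchanged.
assert json.loads(json.dumps(record)) == record

# Every score should lie in [0, 1].
assert all(0 <= s <= 1 for s in record["scores"].values())

print(json.dumps(record["scores"]))
```

Note that `inputs` holds only the question, so the same test case lines up across experiments even when the prompt or model changes.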

In [None]:
def run_experiment(text2sql_fn):
  for i in range(NUM_TEST_EXAMPLES):
      # Overwrite the progress line in place instead of printing a new line per example
      print(f"{i+1}/{NUM_TEST_EXAMPLES}", end="\r")
      query = data[i]
      gt_query = codegen_query(query)
      gt_answer = run_query(gt_query, query["table"])

      prompt, _, sql = text2sql_fn(query)
      try:
          answer = run_query(sql, query["table"])
      except Exception as e:
          answer = f"FAILED: {e}"

      bt.log(
          inputs={"question": query["question"]},
          output=answer,
          expected=gt_answer,
          scores={
              "answer": score(gt_answer, answer),
              "query": score(gt_query, sql),
          },
          metadata={
              "prompt": prompt,
              "gt_sql": gt_query,
              "output_sql": sql,
              "id": i,
          },
      )
  print(bt.summarize())

run_experiment(text2sql)

Take a look at the failures. Feel free to explore individual examples, filter down to low `answer` scores, etc. You should notice that `id=4` is one of the failures. Let's debug it and see if we can improve the prompt.

## Debugging a failure

In [None]:
idx=4

Let's start by looking at the ground truth:

In [None]:
print(data[idx]["question"])

table = get_table(data[idx]['table'])
print(duckdb.arrow(table).query("table", 'SELECT * FROM "table"'))

gt_sql = codegen_query(data[idx])
print(gt_sql)

print(duckdb.arrow(table).query("table", gt_sql))

And then what the model spits out:

In [None]:
prompt, resp, _ = text2sql(data[idx])
print(prompt + green(resp['choices'][0]['text']))

output_sql = resp['choices'][0]['text'].rstrip(";")
try:
  duckdb.arrow(table).query("table", output_sql)
except Exception as e:
  print(e)

Hmm, if only the model knew that `'Assen'` is a `Circuit`, not a `Round`. Let's provide some sample data for each column:
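
To preview what those per-column samples look like once they are formatted for the prompt, here is a minimal sketch over a hypothetical two-column table. `SAMPLE_ROWS` is an illustrative name, set to 2 here to keep the output short; the tutorial's version uses the first 10 rows.

```python
# Hypothetical table in the same shape as a WikiSQL table record.
table = {
    "header": ["Round", "Circuit"],
    "rows": [["1", "Jerez"], ["2", "Assen"], ["3", "Monza"]],
}

SAMPLE_ROWS = 2  # illustrative; the tutorial's prompt samples the first 10 rows

# Turn each row into a {column name: value} dictionary.
rows = [dict(zip(table["header"], row)) for row in table["rows"]]

# One line per column: the quoted column name followed by a few sample values.
meta = "\n".join(
    f'"{h}": {[row[h] for row in rows[:SAMPLE_ROWS]]}' for h in table["header"]
)
print(meta)
```

With values like these in the prompt, the model can see which column a literal such as `'Assen'` belongs to. The full version, which builds the same block from the real table and sends it to the model, follows.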

In [None]:
# Second attempt: provide the question, columns, and sample data
def text2sql_data(query):
    table = query["table"]
    rows = [
        {h: row[i] for (i, h) in enumerate(table["header"])}
        for row in table["rows"]
    ]
    meta = "\n".join(f'"{h}": {[row[h] for row in rows[:10]]}' for h in table["header"])

    prompt = f"""
Print a SQL query (over a table named "table" - the table is quoted with double quotes) that answers the question below.

You have the following columns (each with some sample data). The column
names may have spaces in them which you should escape with double quotes:

{meta}

The format should be
Question: the question to ask
SQL: the SQL to generate

Question: {query['question']}
SQL: """
    resp = openai_req(model="text-davinci-003", prompt=prompt, max_tokens=1024)

    return (
        prompt,
        resp,
        resp["choices"][0]["text"].rstrip(";")
        if len(resp["choices"]) > 0
        else None,
    )

prompt, resp, _ = text2sql_data(data[idx])
print(prompt + green(resp['choices'][0]['text']))

output_sql = resp['choices'][0]['text'].rstrip(";")
duckdb.arrow(table).query("table", output_sql)

Ok great! Now let's re-run the loop with this new version of the code.

In [None]:
bt = braintrust.init(project="text2sql-tutorial", experiment="with-data")
run_experiment(text2sql_data)

## Wrapping up

Congrats 🎉! You've run your first couple of experiments. Now return to the tutorial docs to proceed to the next step, where we'll analyze the experiments.