### Setting up

In [1]:
# ! pip install anthropic

In [2]:
import json
# Load local settings (including the Anthropic API key) from the shared config file.
with open('../../config.json') as config_file:
    config = json.load(config_file)

In [3]:
CH_HOST = 'http://localhost:8123' # default ClickHouse HTTP address
import requests
import pandas as pd
import tqdm

def get_clickhouse_data(query, host = CH_HOST, connection_timeout = 1500):
    """Run `query` against ClickHouse over HTTP and return the raw response text.

    Args:
        query: SQL text, sent as the `query` URL parameter of a POST request.
        host: ClickHouse HTTP endpoint (defaults to CH_HOST).
        connection_timeout: requests timeout in seconds.

    Returns:
        The response body on HTTP 200. Otherwise the error body, prefixed with
        'Database returned the following error:' and a newline. Failures are
        reported as text rather than raised, so they can be stored in the results.
    """
    response = requests.post(host, params = {'query': query}, timeout = connection_timeout)
    if response.status_code == 200:
        return response.text
    # get_clickhouse_euristics detects failures by matching this prefix.
    return 'Database returned the following error:\n' + response.text

In [4]:
import os
# Put the key in the environment so that anthropic.Anthropic() (created without arguments) can use it.
os.environ.update({'ANTHROPIC_API_KEY': config['ANTHROPIC_API_KEY']})

In [5]:
import anthropic
# The client is created without explicit arguments; the API key is expected in the
# ANTHROPIC_API_KEY environment variable set in the previous cell.
client = anthropic.Anthropic()

### Load evals

In [6]:
with open('./data/flight_data_qa_pairs.json', 'r') as f:
    qa_pairs = json.load(f)

# Golden eval set: the reference SQL becomes `answer`, and ids are 1-based row numbers.
qa_pairs_df = (
    pd.DataFrame(qa_pairs)
    .rename(columns={'sql': 'answer'})
    .pipe(lambda frame: frame.assign(id=frame.index + 1))
    .loc[:, ['id', 'question', 'answer']]
)
qa_pairs_df.head()

Unnamed: 0,id,question,answer
0,1,How many flights were in 2024?,select count() from flight_data where year = 2...
1,2,From which airport were the most flights depar...,"select origin, count() as flight_count from fl..."
2,3,What route on average took the longest time in...,"select concat(origin, ' - ', dest) as route, a..."
3,4,"In 2024, what percentage of time all airplanes...",select (sum(air_time) / sum(actual_elapsed_tim...
4,5,Which airline had the highest average departur...,"select op_unique_carrier, avg(dep_delay) as av..."


In [7]:
# Execute every golden query once; its output is the reference the judge compares against.
qa_pairs_df['answer_output'] = qa_pairs_df['answer'].map(get_clickhouse_data)

### Direct answer generation

In [8]:
# System prompt for SQL generation. It is also reused as the system prompt for both
# reflection steps below. It asks for ClickHouse SQL that specifies the
# TabSeparatedWithNames output format, which the heuristic check later looks for,
# and it embeds the flight_data schema.
# NOTE(review): the "Data Type" column lists pandas dtypes (object, float64,
# datetime64[ns]) rather than ClickHouse types. Confirm this is intentional.
base_sql_system_prompt = '''
You are a senior SQL developer and your task is to help generate a SQL query based on user requirements. 
You are working with ClickHouse database. Specify the format (Tab Separated With Names) in the SQL query output to ensure that column names are included in the output.
Do not use count(*) in your queries since it's a bad practice with columnar databases, prefer using count().
Ensure that the query is syntactically correct and optimized for performance, taking into account ClickHouse specific features (i.e. that ClickHouse is a columnar database and supports functions like ARRAY JOIN, SAMPLE, etc.).
Return only the SQL query without any additional explanations or comments.

You will be working with flight_data table which has the following schema:

Column Name | Data Type | Null % | Example Value | Description
--- | --- | --- | --- | ---
year | Int64 | 0.0 | 2024 | Year of flight
month | Int64 | 0.0 | 1 | Month of flight (1–12)
day_of_month | Int64 | 0.0 | 1 | Day of the month
day_of_week | Int64 | 0.0 | 1 | Day of week (1=Monday … 7=Sunday)
fl_date | datetime64[ns] | 0.0 | 2024-01-01 00:00:00 | Flight date (YYYY-MM-DD)
op_unique_carrier | object | 0.0 | 9E | Unique carrier code
op_carrier_fl_num | float64 | 0.0 | 4814.0 | Flight number for reporting airline
origin | object | 0.0 | JFK | Origin airport code
origin_city_name | object | 0.0 | "New York, NY" | Origin city name
origin_state_nm | object | 0.0 | New York | Origin state name
dest | object | 0.0 | DTW | Destination airport code
dest_city_name | object | 0.0 | "Detroit, MI" | Destination city name
dest_state_nm | object | 0.0 | Michigan | Destination state name
crs_dep_time | Int64 | 0.0 | 1252 | Scheduled departure time (local, hhmm)
dep_time | float64 | 1.31 | 1247.0 | Actual departure time (local, hhmm)
dep_delay | float64 | 1.31 | -5.0 | Departure delay in minutes (negative if early)
taxi_out | float64 | 1.35 | 31.0 | Taxi out time in minutes
wheels_off | float64 | 1.35 | 1318.0 | Wheels-off time (local, hhmm)
wheels_on | float64 | 1.38 | 1442.0 | Wheels-on time (local, hhmm)
taxi_in | float64 | 1.38 | 7.0 | Taxi in time in minutes
crs_arr_time | Int64 | 0.0 | 1508 | Scheduled arrival time (local, hhmm)
arr_time | float64 | 1.38 | 1449.0 | Actual arrival time (local, hhmm)
arr_delay | float64 | 1.61 | -19.0 | Arrival delay in minutes (negative if early)
cancelled | int64 | 0.0 | 0 | Cancelled flight indicator (0=No, 1=Yes)
cancellation_code | object | 98.64 | B | Reason for cancellation (if cancelled)
diverted | int64 | 0.0 | 0 | Diverted flight indicator (0=No, 1=Yes)
crs_elapsed_time | float64 | 0.0 | 136.0 | Scheduled elapsed time in minutes
actual_elapsed_time | float64 | 1.61 | 122.0 | Actual elapsed time in minutes
air_time | float64 | 1.61 | 84.0 | Flight time in minutes
distance | float64 | 0.0 | 509.0 | Distance between origin and destination (miles)
carrier_delay | int64 | 0.0 | 0 | Carrier-related delay in minutes
weather_delay | int64 | 0.0 | 0 | Weather-related delay in minutes
nas_delay | int64 | 0.0 | 0 | National Air System delay in minutes
security_delay | int64 | 0.0 | 0 | Security delay in minutes
late_aircraft_delay | int64 | 0.0 | 0 | Late aircraft delay in minutes
'''

def generate_direct_sql(rec):
    message = client.messages.create(
        model="claude-3-5-haiku-latest",
        max_tokens = 8192,
        system=base_sql_system_prompt,
        messages=[
            {'role': 'user', 'content': rec['question']}
        ]
    )
    sql  = message.content[0].text
    if sql.endswith('```'):
        sql = sql[:-3]
    if sql.startswith('```sql'):
        sql = sql[6:]
    return sql

In [9]:
# One model call per eval question; tqdm reports progress.
tmp = [
    {'id': rec['id'], 'llm_direct_sql': generate_direct_sql(rec)}
    for rec in tqdm.tqdm(qa_pairs_df.to_dict('records'))
]

llm_direct_df = pd.DataFrame(tmp)

100%|██████████| 20/20 [00:44<00:00,  2.25s/it]


In [10]:
# Attach each generated query to its golden question/answer row.
direct_result_df = pd.merge(qa_pairs_df, llm_direct_df, on='id')

In [11]:
# Run every generated query so its output can be compared against the golden output.
direct_result_df['llm_direct_output'] = direct_result_df['llm_direct_sql'].map(get_clickhouse_data)

In [12]:
def get_clickhouse_euristics(sql, output):
    """Run cheap rule-based checks on a generated query and its database output.

    Args:
        sql: generated SQL text.
        output: text returned by get_clickhouse_data for that SQL.

    Returns:
        'No problems detected', or the detected problems joined with ' + '
        (any of 'No format specified in SQL', 'SQL execution error').
    """
    import re

    problems = []
    # Allow any whitespace/case between FORMAT and the format name, and accept the
    # TSVWithNames alias. "...WithNamesAndTypes" also counts, since it includes names.
    if not re.search(r'\bformat\s+(?:tabseparatedwithnames|tsvwithnames)', sql, flags=re.IGNORECASE):
        problems.append('No format specified in SQL')
    # get_clickhouse_data prefixes failed responses with this marker.
    if 'Database returned the following error' in output:
        problems.append('SQL execution error')
    return ' + '.join(problems) if problems else 'No problems detected'

In [13]:
# Score each (generated SQL, output) pair with the rule-based checks.
direct_result_df['llm_direct_sql_quality_heuristics'] = [
    get_clickhouse_euristics(sql_text, sql_output)
    for sql_text, sql_output in zip(direct_result_df['llm_direct_sql'], direct_result_df['llm_direct_output'])
]

In [14]:
# Distribution of heuristic outcomes for one-shot generation.
direct_result_df['llm_direct_sql_quality_heuristics'].value_counts()

llm_direct_sql_quality_heuristics
No problems detected                                16
No format specified in SQL + SQL execution error     2
No format specified in SQL                           1
SQL execution error                                  1
Name: count, dtype: int64

In [15]:
# Show full SQL text and query outputs in DataFrame cells instead of truncating them.
pd.options.display.max_colwidth = 5000

In [180]:
# direct_result_df.head().T

### LLM judge

In [17]:
# System prompt for the LLM judge, which compares golden and generated query outputs.
# The JSON example uses double quotes because the reply is parsed with json.loads;
# single-quoted keys are not valid JSON.
llm_judge_system_prompt = '''
You are a senior analyst and your task is to compare two SQL query results and determine if they are equivalent. 
Focus only on the data returned by the queries, ignoring any formatting differences. 
Take into account the initial user query and information needed to answer it. For example, if user asked for the average distance, and both queries return the same average value but in one of them there's also a count of records, you should consider them equivalent, since both provide the same requested information.

Answer with a JSON of the following structure:
{
  "reasoning": "<your reasoning here, 1-3 sentences on why you think they are equivalent or not>",
  "equivalence": <true|false>
}
Ensure that ONLY JSON is in the output, and that it is valid JSON (double-quoted keys and strings). 

You will be working with flight_data table which has the following schema:
Column Name | Data Type | Null % | Example Value | Description
--- | --- | --- | --- | ---
year | Int64 | 0.0 | 2024 | Year of flight
month | Int64 | 0.0 | 1 | Month of flight (1–12)
day_of_month | Int64 | 0.0 | 1 | Day of the month
day_of_week | Int64 | 0.0 | 1 | Day of week (1=Monday … 7=Sunday)
fl_date | datetime64[ns] | 0.0 | 2024-01-01 00:00:00 | Flight date (YYYY-MM-DD)
op_unique_carrier | object | 0.0 | 9E | Unique carrier code
op_carrier_fl_num | float64 | 0.0 | 4814.0 | Flight number for reporting airline
origin | object | 0.0 | JFK | Origin airport code
origin_city_name | object | 0.0 | "New York, NY" | Origin city name
origin_state_nm | object | 0.0 | New York | Origin state name
dest | object | 0.0 | DTW | Destination airport code
dest_city_name | object | 0.0 | "Detroit, MI" | Destination city name
dest_state_nm | object | 0.0 | Michigan | Destination state name
crs_dep_time | Int64 | 0.0 | 1252 | Scheduled departure time (local, hhmm)
dep_time | float64 | 1.31 | 1247.0 | Actual departure time (local, hhmm)
dep_delay | float64 | 1.31 | -5.0 | Departure delay in minutes (negative if early)
taxi_out | float64 | 1.35 | 31.0 | Taxi out time in minutes
wheels_off | float64 | 1.35 | 1318.0 | Wheels-off time (local, hhmm)
wheels_on | float64 | 1.38 | 1442.0 | Wheels-on time (local, hhmm)
taxi_in | float64 | 1.38 | 7.0 | Taxi in time in minutes
crs_arr_time | Int64 | 0.0 | 1508 | Scheduled arrival time (local, hhmm)
arr_time | float64 | 1.38 | 1449.0 | Actual arrival time (local, hhmm)
arr_delay | float64 | 1.61 | -19.0 | Arrival delay in minutes (negative if early)
cancelled | int64 | 0.0 | 0 | Cancelled flight indicator (0=No, 1=Yes)
cancellation_code | object | 98.64 | B | Reason for cancellation (if cancelled)
diverted | int64 | 0.0 | 0 | Diverted flight indicator (0=No, 1=Yes)
crs_elapsed_time | float64 | 0.0 | 136.0 | Scheduled elapsed time in minutes
actual_elapsed_time | float64 | 1.61 | 122.0 | Actual elapsed time in minutes
air_time | float64 | 1.61 | 84.0 | Flight time in minutes
distance | float64 | 0.0 | 509.0 | Distance between origin and destination (miles)
carrier_delay | int64 | 0.0 | 0 | Carrier-related delay in minutes
weather_delay | int64 | 0.0 | 0 | Weather-related delay in minutes
nas_delay | int64 | 0.0 | 0 | National Air System delay in minutes
security_delay | int64 | 0.0 | 0 | Security delay in minutes
late_aircraft_delay | int64 | 0.0 | 0 | Late aircraft delay in minutes
'''

# User prompt for the judge, filled via str.format(): the original question, then
# the golden SQL and output (sql1/result1), then the candidate SQL and output (sql2/result2).
llm_judge_user_prompt_template = '''
Here is the initial user query:
{user_query}

Here is the SQL query generated by the first analyst: 
SQL: 
{sql1} 

Database output: 
{result1}

Here is the SQL query generated by the second analyst:
SQL:
{sql2}

Database output:
{result2}
'''

def llm_judge(rec, field_to_check):
    user_prompt = llm_judge_user_prompt_template.format(
        user_query = rec['question'],
        sql1 = rec['answer'],
        result1 = rec['answer_output'],
        sql2 = rec[field_to_check + '_sql'],
        result2 = rec[field_to_check + '_output']
    )
    message = client.messages.create(
        model="claude-sonnet-4-5",
        max_tokens = 8192,
        temperature=0.1,
        system=llm_judge_system_prompt,
        messages=[
            {'role': 'user', 'content': user_prompt}
        ]
    )
    data  = message.content[0].text
    
    # Strip markdown code blocks
    data = data.strip()
    if data.startswith('```json'):
        data = data[7:]
    elif data.startswith('```'):
        data = data[3:]
    if data.endswith('```'):
        data = data[:-3]
    
    data = data.strip()
    
    # Debug: print the data if it's problematic
    if not data:
        raise ValueError(f"Empty response from API")
    
    return json.loads(data)

In [18]:
tmp = []
for rec in tqdm.tqdm(direct_result_df.to_dict('records')):
    try:
        judgment = llm_judge(rec, 'llm_direct')
    except Exception as e:
        # Unscored records are skipped; they drop out at the inner merge further down.
        print(f"Error processing record {rec['id']}: {e}")
    else:
        tmp.append(dict(
            id=rec['id'],
            llm_judge_reasoning=judgment['reasoning'],
            llm_judge_equivalence=judgment['equivalence'],
        ))

100%|██████████| 20/20 [01:28<00:00,  4.40s/it]


In [19]:
# One row per successfully judged record.
judge_df = pd.DataFrame.from_records(tmp)

In [20]:
# How many direct answers the judge considered equivalent to the golden answer.
judge_df.llm_judge_equivalence.value_counts()

llm_judge_equivalence
True     15
False     5
Name: count, dtype: int64

In [21]:
# Inner join: records the judge failed on are dropped here.
direct_result_df = pd.merge(direct_result_df, judge_df, on='id')

In [22]:
# ! mkdir reflection_results

In [23]:
def get_final_result(heuristics, equivalence): 
    """Combine the heuristic label with the judge verdict into one outcome label.

    Equivalent answers, and answers whose SQL failed to execute, keep the
    heuristic label unchanged. Otherwise 'Wrong answer provided' is appended
    (or used on its own when the heuristics found no problems).
    """
    keep_heuristics = equivalence or 'SQL execution error' in heuristics
    if keep_heuristics:
        return heuristics
    return (
        'Wrong answer provided'
        if heuristics == 'No problems detected'
        else heuristics + ' + Wrong answer provided'
    )

direct_result_df['llm_final_result'] = [
    get_final_result(heuristic_label, is_equivalent)
    for heuristic_label, is_equivalent in zip(
        direct_result_df['llm_direct_sql_quality_heuristics'],
        direct_result_df['llm_judge_equivalence'],
    )
]

In [24]:
# Final outcome distribution for one-shot generation.
direct_result_df.groupby(['llm_final_result'], as_index=False).size()

Unnamed: 0,llm_final_result,size
0,No format specified in SQL,1
1,No format specified in SQL + SQL execution error,2
2,No problems detected,14
3,SQL execution error,1
4,Wrong answer provided,2


In [29]:
# Create the output folder on a fresh run; the `mkdir` cell above is commented out.
os.makedirs('./reflection_results', exist_ok=True)
direct_result_df.to_csv('./reflection_results/direct_results.csv', index=False)

In [30]:
# Copy the results table to the system clipboard for manual review.
# NOTE(review): presumably needs a clipboard backend, so it may fail on headless machines.
direct_result_df.to_clipboard(index=False)

### Simple reflection 

In [62]:
# Start from the direct-generation results. Drop the judge columns (they are
# recomputed for the refined SQL) and keep the direct outcome under a new name.
simple_reflection_results_df = (
    direct_result_df
    .copy()
    .drop(columns=['llm_judge_reasoning', 'llm_judge_equivalence'])
    .rename(columns={'llm_final_result': 'llm_direct_final_result'})
)

In [64]:
# User prompt for simple reflection: the model reviews another analyst's SQL without
# seeing its database output. Doubled braces are literal braces in str.format().
# The JSON example uses double quotes because the reply is parsed with json.loads.
simple_reflection_user_prompt_template = '''
Your task is to assess the SQL query generated by another analyst and propose improvements if necessary.
Check whether the query is syntactically correct and optimized for performance. 
Pay attention to nuances in data (especially time stamps types, whether to use total elapsed time or time in the air, etc).
Ensure that the query answers the initial user question accurately. 
As the result return the following JSON: 
{{
  "reasoning": "<your reasoning here, 2-4 sentences on why you made changes or not>",
  "refined_sql": "<the improved SQL query here>"
}}
Ensure that ONLY JSON is in the output and nothing else. Ensure that the output JSON is valid (double-quoted keys and strings). 

Here is the initial user query:
{user_query}

Here is the SQL query generated by another analyst: 
{sql} 
'''

def simple_reflection(rec) -> str:
    user_prompt = simple_reflection_user_prompt_template.format(
        user_query=rec['question'],
        sql=rec['llm_direct_sql']
    )
    message = client.messages.create(
        model="claude-3-5-haiku-latest",
        max_tokens = 8192,
        system=base_sql_system_prompt,
        messages=[
            {'role': 'user', 'content': user_prompt}
        ]
    )

    data  = message.content[0].text
    # Strip markdown code blocks
    data = data.strip()
    if data.startswith('```json'):
        data = data[7:]
    elif data.startswith('```'):
        data = data[3:]
    if data.endswith('```'):
        data = data[:-3]
    
    data = data.strip()
    
    # Debug: print the data if it's problematic
    if not data:
        raise ValueError(f"Empty response from API")
    
    return json.loads(data.replace('\n', ' '))

In [65]:
tmp = []

for rec in tqdm.tqdm(simple_reflection_results_df.to_dict('records')):
    try:
        res = simple_reflection(rec)
        tmp.append(
            {
                'id': rec['id'],
                'llm_refined_sql': res['refined_sql'],
                'llm_refined_reasoning': res['reasoning']
            }
        )
    except Exception as e:
        # Fail fast here: a missing refinement would silently shrink the comparison
        # at the inner merge. A bare `raise` re-raises with the original traceback.
        print(f"Error processing record {rec['id']}: {e}")
        raise

100%|██████████| 20/20 [01:14<00:00,  3.72s/it]


In [66]:
simple_reflection_results_df = pd.merge(simple_reflection_results_df, pd.DataFrame(tmp), on='id')

In [67]:
len(simple_reflection_results_df)

20

In [68]:
# Execute each refined query.
simple_reflection_results_df['llm_refined_output'] = simple_reflection_results_df['llm_refined_sql'].apply(get_clickhouse_data)

In [69]:
simple_reflection_results_df['llm_refined_sql_quality_heuristics'] = [
    get_clickhouse_euristics(sql_text, sql_output)
    for sql_text, sql_output in zip(
        simple_reflection_results_df['llm_refined_sql'],
        simple_reflection_results_df['llm_refined_output'],
    )
]

In [70]:
tmp = []
for rec in tqdm.tqdm(simple_reflection_results_df.to_dict('records')):
    try:
        judgment = llm_judge(rec, 'llm_refined')
    except Exception as e:
        # Unscored records are skipped and drop out at the inner merge in the next cell.
        print(f"Error processing record {rec['id']}: {e}")
    else:
        tmp.append(dict(
            id=rec['id'],
            llm_judge_reasoning=judgment['reasoning'],
            llm_judge_equivalence=judgment['equivalence'],
        ))

judge_df = pd.DataFrame(tmp)

100%|██████████| 20/20 [01:28<00:00,  4.41s/it]


In [71]:
simple_reflection_results_df = pd.merge(simple_reflection_results_df, judge_df, on='id')
simple_reflection_results_df.shape

(20, 14)

In [72]:
simple_reflection_results_df['llm_final_result'] = [
    get_final_result(heuristic_label, is_equivalent)
    for heuristic_label, is_equivalent in zip(
        simple_reflection_results_df['llm_refined_sql_quality_heuristics'],
        simple_reflection_results_df['llm_judge_equivalence'],
    )
]

In [73]:
# Final outcome distribution after simple reflection.
simple_reflection_results_df.groupby(['llm_final_result'], as_index=False).size()

Unnamed: 0,llm_final_result,size
0,No format specified in SQL + SQL execution error,2
1,No problems detected,14
2,SQL execution error,2
3,Wrong answer provided,2


In [77]:
# Transition table: direct-generation outcome vs outcome after simple reflection, most frequent first.
simple_reflection_results_df.groupby(['llm_direct_final_result', 'llm_final_result'], as_index=False).size()\
    .sort_values('size', ascending=False)

Unnamed: 0,llm_direct_final_result,llm_final_result,size
2,No problems detected,No problems detected,12
1,No format specified in SQL + SQL execution error,No format specified in SQL + SQL execution error,2
0,No format specified in SQL,No problems detected,1
3,No problems detected,SQL execution error,1
4,No problems detected,Wrong answer provided,1
5,SQL execution error,SQL execution error,1
6,Wrong answer provided,No problems detected,1
7,Wrong answer provided,Wrong answer provided,1


In [74]:
# Ensure the output folder exists on a fresh run; the `mkdir` cell is commented out.
os.makedirs('./reflection_results', exist_ok=True)
simple_reflection_results_df.to_csv('./reflection_results/simple_reflection_results.csv', index=False)

In [75]:
# Copy the simple-reflection results to the clipboard for manual review.
# NOTE(review): presumably needs a clipboard backend.
simple_reflection_results_df.to_clipboard(index = False)

### Reflection with feedback

In [132]:
# Same starting point as simple reflection: direct results without the judge columns,
# with the direct outcome kept under a new name.
feedback_reflection_results_df = (
    direct_result_df
    .copy()
    .drop(columns=['llm_judge_reasoning', 'llm_judge_equivalence'])
    .rename(columns={'llm_final_result': 'llm_direct_final_result'})
)

In [181]:
# feedback_reflection_results_df.head().T

In [134]:
# User prompt for reflection with feedback. It has the same review instructions as the
# simple variant, plus the query's database output and the formatting-check verdict.
# Doubled braces are literal braces in str.format(). The JSON example uses double
# quotes because the reply is parsed with json.loads.
feedback_reflection_user_prompt_template = '''
Your task is to assess the SQL query generated by another analyst and propose improvements if necessary.
Check whether the query is syntactically correct and optimized for performance. 
Pay attention to nuances in data (especially time stamps types, whether to use total elapsed time or time in the air, etc).
Ensure that the query answers the initial user question accurately. 

As the result return the following JSON: 
{{
  "reasoning": "<your reasoning here, 2-4 sentences on why you made changes or not>",
  "refined_sql": "<the improved SQL query here>"
}}
Ensure that ONLY JSON is in the output and nothing else. Ensure that the output JSON is valid (double-quoted keys and strings). 


Here is the initial user query:
{user_query}

Here is the SQL query generated by another analyst: 
{sql} 

Here is the database output of this query: 
{output}

We ran an automatic check on the SQL query to check whether it has formatting issues. Here's the output: 
{formatting}
'''

def feedback_reflection(rec) -> str:
    if 'No format specified in SQL' in rec['llm_direct_sql_quality_heuristics']:
        formatting = 'SQL missing formatting. Specify "format TabSeparatedWithNames" to ensure that column names are also returned'
    else: 
        formatting = 'Formatting is correct'
    user_prompt = feedback_reflection_user_prompt_template.format(
        user_query=rec['question'],
        sql=rec['llm_direct_sql'],
        output=rec['llm_direct_output'],
        formatting=formatting
    )
    message = client.messages.create(
        model="claude-3-5-haiku-latest",
        max_tokens = 8192,
        system=base_sql_system_prompt,
        messages=[
            {'role': 'user', 'content': user_prompt}
        ]
    )
    data  = message.content[0].text
    # Strip markdown code blocks
    data = data.strip()
    if data.startswith('```json'):
        data = data[7:]
    elif data.startswith('```'):
        data = data[3:]
    if data.endswith('```'):
        data = data[:-3]
    
    data = data.strip()
    
    # Debug: print the data if it's problematic
    if not data:
        raise ValueError(f"Empty response from API")
    
    return json.loads(data.replace('\n', ' '))

In [135]:
tmp = []
for rec in tqdm.tqdm(feedback_reflection_results_df.to_dict('records')):
    try:
        res = feedback_reflection(rec)
    except Exception as e:
        # Records the model could not refine are skipped and drop out at the inner merge.
        print(f"Error processing record {rec['id']}: {e}")
    else:
        tmp.append(dict(
            id=rec['id'],
            llm_refined_sql=res['refined_sql'],
            llm_refined_reasoning=res['reasoning'],
        ))

100%|██████████| 20/20 [01:17<00:00,  3.88s/it]


In [136]:
feedback_reflection_results_df = pd.merge(feedback_reflection_results_df, pd.DataFrame(tmp), on='id')

In [137]:
len(feedback_reflection_results_df)

20

In [138]:
# Execute each refined query.
feedback_reflection_results_df['llm_refined_output'] = feedback_reflection_results_df['llm_refined_sql'].apply(get_clickhouse_data)

In [139]:
feedback_reflection_results_df['llm_refined_sql_quality_heuristics'] = [
    get_clickhouse_euristics(sql_text, sql_output)
    for sql_text, sql_output in zip(
        feedback_reflection_results_df['llm_refined_sql'],
        feedback_reflection_results_df['llm_refined_output'],
    )
]

In [140]:
tmp = []
for rec in tqdm.tqdm(feedback_reflection_results_df.to_dict('records')):
    try:
        judgment = llm_judge(rec, 'llm_refined')
    except Exception as e:
        # Unscored records are skipped and drop out at the inner merge in the next cell.
        print(f"Error processing record {rec['id']}: {e}")
    else:
        tmp.append(dict(
            id=rec['id'],
            llm_judge_reasoning=judgment['reasoning'],
            llm_judge_equivalence=judgment['equivalence'],
        ))

judge_df = pd.DataFrame(tmp)

100%|██████████| 20/20 [01:33<00:00,  4.65s/it]


In [141]:
feedback_reflection_results_df = pd.merge(feedback_reflection_results_df, judge_df, on='id')
feedback_reflection_results_df.shape

(20, 14)

In [142]:
feedback_reflection_results_df['llm_final_result'] = [
    get_final_result(heuristic_label, is_equivalent)
    for heuristic_label, is_equivalent in zip(
        feedback_reflection_results_df['llm_refined_sql_quality_heuristics'],
        feedback_reflection_results_df['llm_judge_equivalence'],
    )
]

In [154]:
# manual labelling corrections
# feedback_reflection_results_df['llm_final_result'] = list(map(
#     lambda id, result: result if id != 15 else 'No problems detected',
#     feedback_reflection_results_df.id,
#     feedback_reflection_results_df.llm_final_result
# ))

In [149]:
# Final outcome distribution after reflection with feedback.
feedback_reflection_results_df.groupby(['llm_final_result'], as_index=False).size()

Unnamed: 0,llm_final_result,size
0,No problems detected,17
1,Wrong answer provided,3


In [150]:
# feedback_reflection_results_df.groupby(['llm_direct_final_result', 'llm_final_result'], as_index=False).size()\
#     .sort_values('size', ascending = False)

In [151]:
# Copy the feedback-reflection results to the clipboard for manual review.
# NOTE(review): presumably needs a clipboard backend. Unlike the other runs, these results are not saved to CSV.
feedback_reflection_results_df.to_clipboard(index = False)

### Compare results

In [155]:
# Table of outcome counts per approach (rows: outcome label, columns: approach).
cmp_df = pd.DataFrame()

In [156]:
# Count outcomes per approach. pd.concat with the default outer join keeps every
# outcome label that appears in ANY approach. Assigning Series into an existing
# frame would align on the first column's index and silently drop labels that
# direct generation never produced.
cmp_df = pd.concat(
    {
        'direct generation': direct_result_df.groupby('llm_final_result').id.count(),
        'simple reflection': simple_reflection_results_df.groupby('llm_final_result').id.count(),
        'feedback reflection': feedback_reflection_results_df.groupby('llm_final_result').id.count(),
    },
    axis=1,
)

In [157]:
# An outcome missing for an approach means zero questions.
cmp_df = cmp_df.fillna(0)

In [158]:
cmp_df = cmp_df.sort_values('direct generation', ascending=False)

In [159]:
# Outcome counts per approach, sorted by the direct-generation counts.
cmp_df

Unnamed: 0_level_0,direct generation,simple reflection,feedback reflection
llm_final_result,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
No problems detected,14,14.0,17.0
No format specified in SQL + SQL execution error,2,2.0,0.0
Wrong answer provided,2,2.0,3.0
No format specified in SQL,1,0.0,0.0
SQL execution error,1,2.0,0.0


In [None]:
import plotly

def blend_colours(colour1, colour2, weight):
    """
    Blend two colors in hexadecimal format.

    Args:
        colour1: First color in hex format, with or without '#' (e.g., '#4C78A8');
            3-digit shorthand such as '#FFF' is expanded.
        colour2: Second color in the same format (e.g., '#FF5733').
        weight: Weight of first color (0 to 1), where 0 = all colour2, 1 = all colour1

    Returns:
        Blended color as '#' plus 6 uppercase hex digits (e.g., '#A34D6D')

    Raises:
        ValueError: if weight is outside [0, 1] or a colour is not 3 or 6 hex digits.
    """
    if not 0 <= weight <= 1:
        raise ValueError(f'weight must be between 0 and 1, got {weight!r}')

    def to_rgb(colour):
        digits = colour.lstrip('#')
        if len(digits) == 3:
            digits = ''.join(ch * 2 for ch in digits)
        if len(digits) != 6:
            raise ValueError(f'expected a 3- or 6-digit hex colour, got {colour!r}')
        return [int(digits[i:i + 2], 16) for i in (0, 2, 4)]

    blended = [
        # round(), not int(): truncation turns float error such as
        # 100 * 0.29 == 28.999999999999996 into 28 instead of 29.
        round(a * weight + b * (1 - weight))
        for a, b in zip(to_rgb(colour1), to_rgb(colour2))
    ]
    return '#' + ''.join(f'{channel:02X}' for channel in blended)

In [None]:
import plotly.express as px 
import plotly.io as pio
pio.templates.default = 'simple_white'

# Horizontal bars: one row per approach, coloured by outcome label.
accuracy_outcome_order = [
    'No problems detected',
    'No format specified in SQL',
    'Wrong answer provided',
    'SQL execution error',
    'No format specified in SQL + SQL execution error',
]
accuracy_colour_map = {
    'No problems detected': plotly.colors.qualitative.T10[4],
    'No format specified in SQL': plotly.colors.qualitative.T10[3],
    'Wrong answer provided': plotly.colors.qualitative.T10[1],
    'SQL execution error': plotly.colors.qualitative.T10[2],
    'No format specified in SQL + SQL execution error': plotly.colors.qualitative.G10[8],
}
px.bar(
    cmp_df.T,
    orientation='h',
    title='<b>Text-to-SQL:</b> answer accuracy',
    labels={'value': 'number of questions', 'index': '', 'llm_final_result': 'accuracy'},
    text_auto='d',
    category_orders={
        'llm_final_result': accuracy_outcome_order,
        'index': ['direct generation', 'simple reflection', 'feedback reflection'],
    },
    color_discrete_map=accuracy_colour_map,
)

In [185]:
import plotly.express as px 
import plotly.io as pio
pio.templates.default = 'simple_white'

# Same chart restricted to direct generation vs simple reflection.
accuracy_outcome_order = [
    'No problems detected',
    'No format specified in SQL',
    'Wrong answer provided',
    'SQL execution error',
    'No format specified in SQL + SQL execution error',
]
accuracy_colour_map = {
    'No problems detected': plotly.colors.qualitative.T10[4],
    'No format specified in SQL': plotly.colors.qualitative.T10[3],
    'Wrong answer provided': plotly.colors.qualitative.T10[1],
    'SQL execution error': plotly.colors.qualitative.T10[2],
    'No format specified in SQL + SQL execution error': plotly.colors.qualitative.G10[8],
}
px.bar(
    cmp_df[['direct generation', 'simple reflection']].T,
    orientation='h',
    title='<b>Text-to-SQL:</b> answer accuracy',
    labels={'value': 'number of questions', 'index': '', 'llm_final_result': 'accuracy'},
    text_auto='d',
    category_orders={
        'llm_final_result': accuracy_outcome_order,
        'index': ['direct generation', 'simple reflection', 'feedback reflection'],
    },
    color_discrete_map=accuracy_colour_map,
)

In [188]:
import plotly.express as px 
import plotly.io as pio
pio.templates.default = 'simple_white'

# Same chart for direct generation alone, in a shorter figure.
accuracy_outcome_order = [
    'No problems detected',
    'No format specified in SQL',
    'Wrong answer provided',
    'SQL execution error',
    'No format specified in SQL + SQL execution error',
]
accuracy_colour_map = {
    'No problems detected': plotly.colors.qualitative.T10[4],
    'No format specified in SQL': plotly.colors.qualitative.T10[3],
    'Wrong answer provided': plotly.colors.qualitative.T10[1],
    'SQL execution error': plotly.colors.qualitative.T10[2],
    'No format specified in SQL + SQL execution error': plotly.colors.qualitative.G10[8],
}
px.bar(
    cmp_df[['direct generation']].T,
    orientation='h',
    title='<b>Text-to-SQL:</b> answer accuracy',
    labels={'value': 'number of questions', 'index': '', 'llm_final_result': 'accuracy'},
    text_auto='d',
    category_orders={
        'llm_final_result': accuracy_outcome_order,
        'index': ['direct generation', 'simple reflection', 'feedback reflection'],
    },
    color_discrete_map=accuracy_colour_map,
    height=350,
)