In [66]:
import pandas as pd
import altair as alt
from queries import QUERIES

In [88]:
df = pd.read_csv("LLM_performance_results_1.csv")
# remove two tasks with buggy initial evaluation code
df = df[~df['query_name'].isin(['Carrying Values Forward', 'Aggregating Counts at Event Times'])]
# add in working evaluations
df = pd.concat([df, pd.read_csv("LLM_performance_results_2.csv"), pd.read_csv("LLM_performance_results.csv")])
df['result'] = df['result'].replace({'valid': 'Correct', 'invalid': 'Incorrect', 'error': 'Invalid'})
df

Unnamed: 0,query_name,iteration,method,result,error,query
0,Attributes,1,TempoQL,Correct,,{Anchor Age} + (({Admit Time} - {Anchor Year})...
1,Attributes,1,SQL,Correct,,"SELECT\n icu.stay_id,\n pat.anchor_age + DAT..."
2,Events,1,TempoQL,Correct,,"{\n scope = chartevents;\n name in (\n ""R..."
3,Events,1,SQL,Incorrect,,WITH rr_items AS (\n SELECT\n itemid\n FR...
4,String Operations,1,TempoQL,Correct,,{Diagnosis; scope = Diagnosis} startswith /(40...
...,...,...,...,...,...,...
3,Carrying Values Forward,9,SQL,Invalid,Reason: 400 No matching signature for function...,WITH\n hourly_grid AS (\n SELECT\n st...
4,Aggregating Counts at Event Times,10,TempoQL,Correct,,(\n count(\n start({Cardioversion/Defibril...
5,Aggregating Counts at Event Times,10,SQL,Incorrect,,"-- Assumptions:\n-- 1. ""Heart Rhythm"" events a..."
6,Carrying Values Forward,10,TempoQL,Correct,,(\n first {O2 Delivery Device(s); scope = cha...


In [89]:
chart = alt.Chart(df)
text = chart.mark_text(
  align='right',
  baseline='middle',
  color='white',
  dx=-6  # Nudges text to right so it doesn't overlap bar
).encode(
  x=alt.X('count()', stack='normalize'),
  y=alt.Y('method:N', sort=['TempoQL', 'SQL']),
  detail='result:N',
  text=alt.Text('frac:Q', format='.0%'),
  order=alt.Order(
      'result',
      sort='descending'
    )
).transform_joinaggregate(
    total='count()',
    groupby=['method']  
).transform_joinaggregate(
    count='count()',
    groupby=['method', 'result']  
).transform_calculate(
    frac=alt.datum.count / alt.datum.total
)
(chart.mark_bar().encode(
    x=alt.X("count()", title="Proportion of Query Trials").stack('normalize'),
    y=alt.Y('method:N', sort=['TempoQL', 'SQL'], title=None),
    color=alt.Color('result:N', sort=['Correct', 'Incorrect', 'Invalid']),
    order=alt.Order(
      'result',
      sort='ascending'
    )
)).properties(title="LLM Query Authoring Quality")

In [91]:
(chart.mark_bar().encode(
    x=alt.X("count()", title="Proportion of Query Trials").stack('normalize'),
    yOffset=alt.YOffset('method:N', sort=['TempoQL', 'SQL'], scale=alt.Scale(padding=0.05)),
    y=alt.Y('query_name:N', sort=[q["name"] for q in QUERIES], title=None),
    color=alt.Color('result:N', sort=['Correct', 'Incorrect', 'Invalid']),
    order=alt.Order(
      'result',
      sort='ascending'
    )
) + chart.mark_text(align='left', color='white').encode(
  x=alt.value(6),
  yOffset=alt.YOffset('method:N', sort=['TempoQL', 'SQL'], scale=alt.Scale(padding=0.05)),
    y=alt.Y('query_name:N', sort=[q["name"] for q in QUERIES], title=None),
    text='method:N'
    ).transform_aggregate(total='count()', groupby=['method', 'query_name'])).properties(title="LLM Query Authoring Quality")

In [92]:
from queries import QUERIES, SQL_PREFIX
import re

lengths = []
for query in QUERIES:
    sql_query = query["sql"].replace(SQL_PREFIX, "")
    lengths.append({
        "tql_chars": len(re.sub(r"[\s\r\n]", "", query["tempoql"])), 
        "sql_chars": len(re.sub(r"[\s\r\n]", "", sql_query)),
        "tql_words": len(re.split(r"[\s\r\n]+", re.sub(r"[^A-Za-z0-9\s\r\n]", "", query["tempoql"]))),
        "sql_words": len(re.split(r"[\s\r\n]+", re.sub(r"[^A-Za-z0-9\s\r\n]", "", sql_query))),
    })
lengths = pd.DataFrame(lengths)

In [93]:
(lengths['sql_words'] / lengths['tql_words']).mean()

5.95628073140945

In [94]:
(lengths['sql_chars'] / lengths['tql_chars']).mean()

7.822606119697116