# Text-to-SQL-on-FHIR Evaluation
This Jupyter notebook evaluates Text-to-SQL-on-FHIR.
For the experiments behind the development of the [querygen](querygen)
package, including indexing of the MIMIC-IV FHIR dataset,
see the [code_embedding notebook](code_embedding.ipynb).
All evaluations in this notebook are done against the MIMIC-IV dataset.

The MIMIC-IV FHIR dataset was converted to Parquet files, both the
"wide schema" and "flat views", using the
[fhir-data-pipes conversion pipeline](https://github.com/google/fhir-data-pipes/blob/f01623433d38d1a72bb0eb4231512a5f57ad3b04/docker/compose-controller-spark-sql-single.yaml#L53).
The queries/experiments are all against the
[flat views](https://github.com/google/fhir-data-pipes/tree/master/docker/config/views).
The query engine in this notebook is a local Spark engine run as a
[docker container](https://github.com/google/fhir-data-pipes/blob/f01623433d38d1a72bb0eb4231512a5f57ad3b04/docker/compose-controller-spark-sql-single.yaml#L80).
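
The flat views turn nested FHIR resources into plain columns. As a rough illustration, here is a minimal Python sketch that flattens a single Observation (parsed JSON) into rows using the column names queried later in this notebook (`id`, `patient_id`, `encounter_id`, `obs_date`, `val_quantity`, `val_quantity_unit`, `code_code`). The field mapping, the function name `flatten_observation`, and the choice of one row per `code.coding` entry are assumptions for illustration; the linked view definitions are authoritative.

```python
from typing import Optional


def flatten_observation(obs: dict) -> list:
    """Flatten one FHIR Observation (parsed JSON) into flat-view-like rows.

    Assumption: one output row per code.coding entry; column names mirror the
    ones queried in this notebook, not necessarily the real view definition.
    """

    def ref_id(field: str) -> Optional[str]:
        # FHIR references look like "Patient/<id>"; keep only the id part.
        ref = (obs.get(field) or {}).get("reference")
        if ref and "/" in ref:
            return ref.split("/", 1)[1]
        return ref

    quantity = obs.get("valueQuantity") or {}
    # Keep one row even when there is no coding, similar to an OUTER explode.
    codings = (obs.get("code") or {}).get("coding") or [{}]
    return [
        {
            "id": obs.get("id"),
            "patient_id": ref_id("subject"),
            "encounter_id": ref_id("encounter"),
            "obs_date": obs.get("effectiveDateTime"),
            "val_quantity": quantity.get("value"),
            "val_quantity_unit": quantity.get("unit"),
            "code_code": coding.get("code"),
        }
        for coding in codings
    ]


# Made-up example resource (IDs are invented; the code system URL is MIMIC's).
example = {
    "resourceType": "Observation",
    "id": "obs-1",
    "subject": {"reference": "Patient/p-1"},
    "encounter": {"reference": "Encounter/e-1"},
    "effectiveDateTime": "2174-10-02T15:57:00-04:00",
    "code": {"coding": [{
        "system": "http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-labitems",
        "code": "50817",
        "display": "Oxygen Saturation",
    }]},
    "valueQuantity": {"value": 90.0, "unit": "%"},
}
print(flatten_observation(example))
```

Once resources look like these rows, questions such as "count SpO2 readings per encounter" become simple `WHERE`/`GROUP BY` queries without any array handling.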

For a summary of results in this notebook, see
[this 2025 FHIR DevDay presentation](https://bit.ly/text-to-sql-on-fhir-DD25).

In [20]:
import pandas as pd
from sqlalchemy import dialects, engine

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 200)
pd.set_option('display.max_colwidth', None)

dialects.registry.register("hive", "pyhive.sqlalchemy_hive", "HiveDialect")

spark_query_engine = engine.create_engine("hive://localhost:10001/default")
#query_engine = engine.create_engine("hive://spark-thriftserver:10000/default")

pd.read_sql_query(
    sql="""
    SELECT COUNT(*) FROM patient_mimic_flat;
    """,
    con=spark_query_engine,
)

|   | count(1) |
|---|---|
| 0 | 299712 |


In [21]:
pd.read_sql_query(
    sql="""
    SELECT COUNT(*) FROM observation_mimic_flat;
    """,
    con=spark_query_engine,
)

|   | count(1) |
|---|---|
| 0 | 461098908 |


In [22]:
pd.read_sql_query(
    sql="""
    SELECT COUNT(*) FROM encounter_mimic_flat;
    """,
    con=spark_query_engine,
)

|   | count(1) |
|---|---|
| 0 | 1687781 |


# Using the `querygen` package
This package assembles all the prompt pieces into a set of
reusable Python modules.
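
To make "prompt pieces" concrete, here is a hypothetical sketch of how such pieces might be combined into one prompt. The piece types (table descriptions, per-column sample values, close concepts) are inferred from the module names used below (`db_description`, `column_sampler`, `terminology_indexer`) and from the `CLOSE CONCEPTS` output; the function name `assemble_prompt`, the wording, and the ordering are invented and do not reflect querygen's real template. The `--- BEGIN`/`--- END` instruction is a guess based on the markers visible in the model response further down.

```python
def assemble_prompt(question: str,
                    table_descriptions: dict,
                    column_samples: dict,
                    close_concepts: list) -> str:
    """Hypothetical prompt assembly; not querygen's actual template."""
    parts = ["Write a Spark SQL query over the following tables."]
    for table, description in table_descriptions.items():
        parts.append(f"Table {table}: {description}")
    for (table, column), values in column_samples.items():
        # Sample values help the model guess formats (codes, units, dates).
        parts.append(f"Sample values of {table}.{column}: {values}")
    if close_concepts:
        parts.append("Relevant codes as (table, column, code, display):")
        parts.extend(str(concept) for concept in close_concepts)
    parts.append(f"Question: {question}")
    parts.append("Put the SQL between '--- BEGIN' and '--- END' lines.")
    return "\n".join(parts)


print(assemble_prompt(
    "What is the total number of patients?",
    {"patient_mimic_flat": "one row per patient"},
    {("patient_mimic_flat", "gender"): ["female", "male"]},
    [],
))
```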

In [25]:
# You need a GCP project to query the Gemini model through Vertex AI.
%env GOOGLE_CLOUD_PROJECT=bashir-genai
%env GOOGLE_CLOUD_LOCATION=us-central1
%env GOOGLE_GENAI_USE_VERTEXAI=True
import os
print(os.environ['GOOGLE_CLOUD_PROJECT'])

env: GOOGLE_CLOUD_PROJECT=bashir-genai
env: GOOGLE_CLOUD_LOCATION=us-central1
env: GOOGLE_GENAI_USE_VERTEXAI=True
bashir-genai


In [16]:
# This is to avoid reloading the Stanza pipeline every time:
import stanza
nlp_ner = stanza.Pipeline(
                lang='en', processors='tokenize,ner',
                package={"ner": ["ncbi_disease", "i2b2", "radiology",
                                 "ontonotes-ww-multi_charlm"]})

2025-07-25 16:19:40 INFO: Checking for updates to resources.json in case models have been updated.  Note: this behavior can be turned off with download_method=None or download_method=DownloadMethod.REUSE_RESOURCES


2025-07-25 16:19:40 INFO: Downloaded file to /usr/local/google/home/bashir/stanza_resources/resources.json
2025-07-25 16:19:41 INFO: Loading these models for language: en (English):
| Processor | Package                                               |
---------------------------------------------------------------------
| tokenize  | combined                                              |
| mwt       | combined                                              |
| ner       | ncbi_disease;i2b2;radiology;ontonotes-ww-multi_charlm |

2025-07-25 16:19:41 INFO: Using device: cpu
2025-07-25 16:19:41 INFO: Loading: tokenize
2025-07-25 16:19:41 INFO: Loading: mwt
2025-07-25 16:19:41 INFO: Loading: ner
2025-07-25 16:19:48 INFO: Done loading processors!


In [17]:
from querygen import sqlgen, column_sampler, db_description, util
import querygen
querygen.PRINT_CLOSE_CONCEPTS = True
querygen.PRINT_FINAL_PROMPT = False
# querygen.PRINT_FINAL_PROMPT = True
querygen.PRINT_MODEL_RESPONSE = False

# Reload the modules so the latest external edits take effect without restarting the kernel.
from importlib import reload
sqlgen = reload(sqlgen)
column_sampler = reload(column_sampler)
db_description = reload(db_description)
util = reload(util)

query_instance = sqlgen.SqlGen(
    target_db_url='hive://localhost:10001/default',
    pg_vector_db_url='postgresql://postgres:admin@localhost:5438/codevec',
    nlp_ner=nlp_ner
)

In [18]:
print(query_instance.generate_sql('What is the total number of patients?', table_suffix='_mimic_flat'))

CLOSE CONCEPTS: 

```sql
--- BEGIN
SELECT COUNT(*)  -- Counting the number of rows will give the total number of patients
FROM patient_mimic_flat;  -- The patient information is stored in the patient_mimic_flat table
--- END
```
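
The response wraps the SQL between `--- BEGIN` and `--- END` lines, so the query can be pulled out with a regular expression before it is sent to Spark. A minimal sketch, assuming that marker format; `extract_sql` is a hypothetical helper, not part of querygen's documented API.

```python
import re
from typing import Optional

# Matches everything between a '--- BEGIN' line and the next '--- END' line.
_SQL_BLOCK = re.compile(r"---\s*BEGIN[^\n]*\n(.*?)\n\s*---\s*END", re.DOTALL)


def extract_sql(response: str) -> Optional[str]:
    """Return the SQL between the markers, or None if they are missing."""
    match = _SQL_BLOCK.search(response)
    return match.group(1).strip() if match else None


response = "Here is the query:\n--- BEGIN\nSELECT COUNT(*)\nFROM patient_mimic_flat;\n--- END\n"
print(extract_sql(response))
```

Returning `None` when the markers are missing lets the caller decide whether to retry generation instead of sending free text to the query engine.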


# More column samples (10)
Previous experiments indicated that using 10 or more sample values from each
column improves SQL generation.
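
As an illustration of what "samples from each column" can mean, here is a minimal in-memory sketch that picks up to `num_samples` distinct, non-null values from one column. The function name `sample_column_values` and the random-sampling strategy are assumptions; this notebook does not show how `column_sampler` actually selects values.

```python
import random
from typing import Any, Iterable, List, Optional


def sample_column_values(values: Iterable[Any], num_samples: int = 10,
                         seed: Optional[int] = None) -> List[Any]:
    """Return up to num_samples distinct, non-null (hashable) values from a column."""
    # Deduplicate and drop nulls; sort by repr so results are reproducible.
    distinct = sorted({v for v in values if v is not None}, key=repr)
    if len(distinct) <= num_samples:
        return distinct
    return random.Random(seed).sample(distinct, num_samples)


print(sample_column_values([None, "%", "%", "mmHg"]))
print(sample_column_values(range(1000), num_samples=10, seed=0))
```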

In [3]:
## Retry with more samples
from querygen import sqlgen, column_sampler, db_description, util, terminology_indexer
import querygen
querygen.PRINT_CLOSE_CONCEPTS = True
querygen.PRINT_FINAL_PROMPT = False
#querygen.PRINT_FINAL_PROMPT = True
querygen.PRINT_MODEL_RESPONSE = False

# Reload the modules so the latest external edits take effect without restarting the kernel.
from importlib import reload
sqlgen = reload(sqlgen)
column_sampler = reload(column_sampler)
db_description = reload(db_description)
util = reload(util)
terminology_indexer = reload(terminology_indexer)

# Note: the defaults are CLOSE_CONCEPT_DISTANCE_THRESHOLD = 0.7 and MAX_CLOSE_CONCEPTS = 20.
query_instance_10_07_20 = sqlgen.SqlGen(
    target_db_url='hive://localhost:10001/default',
    pg_vector_db_url='postgresql://postgres:admin@localhost:5438/codevec',
    nlp_ner=nlp_ner,
    num_column_samples=10
)

term_indexer = terminology_indexer.TerminologyIndexer(
    target_db_url='hive://localhost:10001/default',
    pg_vector_db_url='postgresql://postgres:admin@localhost:5438/codevec')
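
The close-concept threshold (0.7) and maximum count (20) bound which terminology concepts are added to the prompt. A minimal sketch of that selection step, assuming each candidate comes with a distance to the query phrase where smaller means closer (as with pgvector's `<=>` cosine-distance operator) and that candidates at or above the threshold are discarded. `select_close_concepts` is a hypothetical helper that borrows the `close_threshold`/`max_close` parameter names of `SqlGen`; it is not querygen's code, and the distances below are made up.

```python
from typing import List, Sequence, Tuple

Concept = Tuple[str, str, str, str]  # (table_name, code_column_name, code, display)


def select_close_concepts(candidates: Sequence[Tuple[Concept, float]],
                          close_threshold: float = 0.7,
                          max_close: int = 20) -> List[Concept]:
    """Keep candidates closer than close_threshold, nearest first, at most max_close."""
    within = [(c, d) for c, d in candidates if d < close_threshold]
    within.sort(key=lambda pair: pair[1])
    return [concept for concept, _ in within[:max_close]]


candidates = [
    (("observation_mimic_flat", "code_code", "2708-6",
      "Oxygen saturation in Arterial blood"), 0.21),
    (("observation_mimic_flat", "code_code", "220227",
      "Arterial O2 Saturation"), 0.18),
    (("procedure_mimic_flat", "code_code", "8963",
      "Pulmonary artery pressure monitoring"), 0.74),
]
print(select_close_concepts(candidates))
```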

# Oxygen levels and supplemental oxygen
For some background, see [this paper](https://jamanetwork.com/journals/jamainternalmedicine/fullarticle/2794196).

In [27]:
# Retrying the previous query with the vector DB for all codes (not partial)
query_instance_10_07_20.iterative_gen_sql(
    query='''
    For patients that have arterial hemoglobin oxygen saturation level less than 90%,
    count the number that received supplemental oxygen versus those who didn't.
    Group the results by oxygen saturation level deciles, e.g., between 80% to 90%, 70% to 80%, and so on.
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "arterial hemoglobin oxygen saturation level" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('observation_mimic_flat', 'code_code', '2708-6', 'Oxygen saturation in Arterial blood'), ('observation_mimic_flat', 'code_code', '220227', 'Arterial O2 Saturation'), ('observation_mimic_flat', 'code_code', '50817', 'Oxygen Saturation'), ('observation_mimic_flat', 'code_code', '220224', 'Arterial O2 pressure'), ('observation_mimic_flat', 'code_code', '220277', 'O2 saturation pulseoxymetry'), ('observation_mimic_flat', 'code_code', '226063', 'Venous O2 Pressure'), ('observation_mimic_flat', 'code_code', '226541', 'ScvO2 Central Venous O2% Sat'), ('observation_mimic_flat', 'code_code', '51645', 'Hemoglobin, Calculated'), ('observation_mimic_flat', 'code_code', '228232', 'PAR-Oxygen saturation'), ('procedure_mimic_flat', 'code_code', '8965', 'M
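
The decile grouping in this question is ambiguous when a patient has several readings. A minimal sketch of one interpretation, assuming each patient is placed in the 10-point bucket of their lowest saturation reading and that "received supplemental oxygen" is a per-patient flag. The function names and inputs are hypothetical.

```python
import math
from collections import Counter
from typing import Iterable, Set, Tuple


def decile_label(saturation: float) -> str:
    """Map a saturation percentage to its 10-point bucket, e.g. 84.5 -> '80-90%'."""
    low = int(math.floor(saturation / 10)) * 10
    return f"{low}-{low + 10}%"


def count_patients_by_decile(saturations: Iterable[Tuple[str, float]],
                             oxygen_patients: Set[str],
                             cutoff: float = 90.0) -> Counter:
    """Count patients per (bucket, received_oxygen) using each patient's lowest reading."""
    lowest = {}
    for patient_id, value in saturations:
        lowest[patient_id] = min(value, lowest.get(patient_id, value))
    return Counter(
        (decile_label(value), patient_id in oxygen_patients)
        for patient_id, value in lowest.items()
        if value < cutoff  # only patients with a reading below 90%
    )


readings = [("p1", 95.0), ("p1", 85.0), ("p2", 72.0), ("p3", 97.0)]
print(count_patients_by_decile(readings, oxygen_patients={"p1"}))
```

Choosing the lowest reading is only one option; using the first reading in an encounter, for instance, would put some patients in different buckets, which is one reason the generated SQL can vary between rounds.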

In [28]:
# Simplifying the query for debugging (removing grouping):
query_instance_10_07_20.iterative_gen_sql(
    query='''
    For patients that have arterial hemoglobin oxygen saturation level less than 90%,
    count the number that received supplemental oxygen versus those who didn't.
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "arterial hemoglobin oxygen saturation level" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('observation_mimic_flat', 'code_code', '2708-6', 'Oxygen saturation in Arterial blood'), ('observation_mimic_flat', 'code_code', '220227', 'Arterial O2 Saturation'), ('observation_mimic_flat', 'code_code', '50817', 'Oxygen Saturation'), ('observation_mimic_flat', 'code_code', '220224', 'Arterial O2 pressure'), ('observation_mimic_flat', 'code_code', '220277', 'O2 saturation pulseoxymetry'), ('observation_mimic_flat', 'code_code', '226063', 'Venous O2 Pressure'), ('observation_mimic_flat', 'code_code', '226541', 'ScvO2 Central Venous O2% Sat'), ('observation_mimic_flat', 'code_code', '51645', 'Hemoglobin, Calculated'), ('observation_mimic_flat', 'code_code', '228232', 'PAR-Oxygen saturation'), ('procedure_mimic_flat', 'code_code', '8965', 'M

In [29]:
# Making the NLQ more specific by adding encounter constraint:
query_instance_10_07_20.iterative_gen_sql(
    query='''
    For patients that have arterial hemoglobin oxygen saturation level less than 90% in an encounter,
    count the number that received supplemental oxygen in the same encounter versus those who didn't.
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "arterial hemoglobin oxygen saturation level" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('observation_mimic_flat', 'code_code', '2708-6', 'Oxygen saturation in Arterial blood'), ('observation_mimic_flat', 'code_code', '220227', 'Arterial O2 Saturation'), ('observation_mimic_flat', 'code_code', '50817', 'Oxygen Saturation'), ('observation_mimic_flat', 'code_code', '220224', 'Arterial O2 pressure'), ('observation_mimic_flat', 'code_code', '220277', 'O2 saturation pulseoxymetry'), ('observation_mimic_flat', 'code_code', '226063', 'Venous O2 Pressure'), ('observation_mimic_flat', 'code_code', '226541', 'ScvO2 Central Venous O2% Sat'), ('observation_mimic_flat', 'code_code', '51645', 'Hemoglobin, Calculated'), ('observation_mimic_flat', 'code_code', '228232', 'PAR-Oxygen saturation'), ('procedure_mimic_flat', 'code_code', '8965', 'M

In [30]:
# Repeating the same cell with more rounds:
# Making the NLQ more specific by adding encounter constraint:
query_instance_10_07_20.iterative_gen_sql(
    query='''
    For patients that have arterial hemoglobin oxygen saturation level less than 90% in an encounter,
    count the number that received supplemental oxygen in the same encounter versus those who didn't.
    ''',
    table_suffix='_mimic_flat',
    num_rounds=10
)

CLOSE CONCEPTS: 

            For "arterial hemoglobin oxygen saturation level" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('observation_mimic_flat', 'code_code', '2708-6', 'Oxygen saturation in Arterial blood'), ('observation_mimic_flat', 'code_code', '220227', 'Arterial O2 Saturation'), ('observation_mimic_flat', 'code_code', '50817', 'Oxygen Saturation'), ('observation_mimic_flat', 'code_code', '220224', 'Arterial O2 pressure'), ('observation_mimic_flat', 'code_code', '220277', 'O2 saturation pulseoxymetry'), ('observation_mimic_flat', 'code_code', '226063', 'Venous O2 Pressure'), ('observation_mimic_flat', 'code_code', '226541', 'ScvO2 Central Venous O2% Sat'), ('observation_mimic_flat', 'code_code', '51645', 'Hemoglobin, Calculated'), ('observation_mimic_flat', 'code_code', '228232', 'PAR-Oxygen saturation'), ('procedure_mimic_flat', 'code_code', '8965', 'M

## Difference between SpO2 and blood oxygen saturation
Here we try to distinguish oxygen saturation measured by a near-infrared pulse oximeter (SpO2)
from oxygen saturation measured by an Arterial Blood Gas (ABG) test.

In [16]:
query_instance_10_07_20.iterative_gen_sql(
    query='''
    For patients that had their blood oxygen level measured both
    through pulse oximeter (SpO2) and Arterial Blood Gas (ABG) test
    in the same encounter, count the number of those that the
    difference between the two values was at most 2% vs those
    that it was greater than 2%.
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "their blood oxygen level" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            []

            
            For "pulse oximeter" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('procedure_mimic_flat', 'code_code', '7538', 'Fetal pulse oximetry'), ('observation_mimic_flat', 'code_code', '220277', 'O2 saturation pulseoxymetry'), ('observation_mimic_flat', 'code_code', '50821', 'pO2'), ('observation_mimic_flat', 'code_code', '223770', 'O2 Saturation Pulseoxymetry Alarm - Low'), ('observation_mimic_flat', 'code_code', '223769', 'O2 Saturation Pulseoxymetry Alarm - High'), ('procedure_mimic_flat', 'code_code', '8963', 'Pulmonary artery pressure monitoring'), ('observation_mimic_flat', 'code_code', '50832', 'pO2, Body Fluid'

In [33]:
# A different wording:
query_instance_10_07_20.iterative_gen_sql(
    query='''
    For encounters in which the blood oxygen level of
    the patient is measured through both pulse oximeter (SpO2)
    and Arterial Blood Gas (ABG) test, find the average
    for each method and count the number of those that
    the difference between those two average values
    is at most 2% vs those that it is greater than 2%.
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "the blood oxygen level" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('observation_mimic_flat', 'code_code', '2708-6', 'Oxygen saturation in Arterial blood'), ('observation_mimic_flat', 'code_code', '220224', 'Arterial O2 pressure'), ('observation_mimic_flat', 'code_code', '50817', 'Oxygen Saturation'), ('observation_mimic_flat', 'code_code', '220227', 'Arterial O2 Saturation'), ('observation_mimic_flat', 'code_code', '50821', 'pO2'), ('observation_mimic_flat', 'code_code', '226063', 'Venous O2 Pressure'), ('observation_mimic_flat', 'code_code', '220277', 'O2 saturation pulseoxymetry'), ('observation_mimic_flat', 'code_code', '50816', 'Oxygen'), ('condition_mimic_flat', 'code_code', 'R0902', 'Hypoxemia'), ('condition_mimic_flat', 'code_code', '79902', 'Hypoxemia'), ('observation_mimic_flat', 'code_code', '50823', 'Required O2'), 

### Debugging concept codes

In [47]:
pd.set_option('display.max_colwidth', None)
obs_oxygen = pd.read_sql_query(
    sql="""
    SELECT OCC.`system`, OCC.code, OCC.display, COUNT(*) AS num_obs
    FROM observation_mimic AS O LATERAL VIEW explode(code.coding) AS OCC
    WHERE OCC.code = '220277' OR OCC.code = '220227' OR OCC.code = '50817'
    GROUP BY OCC.`system`, OCC.code, OCC.display
    """,
    con=spark_query_engine,
)
obs_oxygen

|   | system | code | display | num_obs |
|---|---|---|---|---|
| 0 | http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-chartevents-d-items | 220277 | O2 saturation pulseoxymetry | 6324341 |
| 1 | http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-chartevents-d-items | 220227 | Arterial O2 Saturation | 87394 |
| 2 | http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-labitems | 50817 | Oxygen Saturation | 176225 |
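
On the wide schema, `code.coding` is an array, and `LATERAL VIEW explode` emits one row per coding before the `GROUP BY`. The same counting logic in plain Python, as a sketch over parsed Observation JSON; the sample resources are made up, while the code system URLs are the ones in the output above. Note that the three codes come from two MIMIC code systems: 220277 and 220227 are chart-event items, while 50817 is a lab item.

```python
from collections import Counter
from typing import Iterable

CHART = "http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-chartevents-d-items"
LAB = "http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-labitems"


def count_codings(observations: Iterable[dict], wanted_codes: set) -> Counter:
    """Count (system, code, display) over all codings, like explode + GROUP BY."""
    counts = Counter()
    for obs in observations:
        # One iteration per coding; an empty array yields nothing, as with a non-OUTER explode.
        for coding in (obs.get("code") or {}).get("coding") or []:
            if coding.get("code") in wanted_codes:
                key = (coding.get("system"), coding.get("code"), coding.get("display"))
                counts[key] += 1
    return counts


observations = [
    {"code": {"coding": [{"system": CHART, "code": "220277",
                          "display": "O2 saturation pulseoxymetry"}]}},
    {"code": {"coding": [{"system": CHART, "code": "220277",
                          "display": "O2 saturation pulseoxymetry"}]}},
    {"code": {"coding": [{"system": LAB, "code": "50817",
                          "display": "Oxygen Saturation"}]}},
    {"code": {"coding": []}},
]
print(count_codings(observations, {"220277", "220227", "50817"}))
```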


In [63]:
from querygen import sqlgen, column_sampler, db_description, util, terminology_indexer
import querygen
querygen.PRINT_CLOSE_CONCEPTS = True
querygen.PRINT_FINAL_PROMPT = False
#querygen.PRINT_FINAL_PROMPT = True
querygen.PRINT_MODEL_RESPONSE = False

# Reload the modules so the latest external edits take effect without restarting the kernel.
from importlib import reload
column_sampler = reload(column_sampler)
db_description = reload(db_description)
util = reload(util)
terminology_indexer = reload(terminology_indexer)
sqlgen = reload(sqlgen)

query_instance_10_07_20 = sqlgen.SqlGen(
    target_db_url='hive://localhost:10001/default',
    pg_vector_db_url='postgresql://postgres:admin@localhost:5438/codevec',
    nlp_ner=nlp_ner,
    num_column_samples=10,
    close_threshold=0.7,
    max_close=20
)

term_indexer = terminology_indexer.TerminologyIndexer(
    target_db_url='hive://localhost:10001/default',
    pg_vector_db_url='postgresql://postgres:admin@localhost:5438/codevec')

In [78]:
query_instance_10_07_20.iterative_gen_sql(
    query='''
    Count the number of patients that within an hour, their blood oxygen level is
    measured through both pulse oximeter (SpO2) and Arterial Blood Gas (ABG) test
    and the difference of those two measurements is more than than 2%.
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "their blood oxygen level" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            []

            
            For "pulse oximeter" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('procedure_mimic_flat', 'code_code', '7538', 'Fetal pulse oximetry'), ('observation_mimic_flat', 'code_code', '220277', 'O2 saturation pulseoxymetry'), ('observation_mimic_flat', 'code_code', '50821', 'pO2'), ('observation_mimic_flat', 'code_code', '223770', 'O2 Saturation Pulseoxymetry Alarm - Low'), ('observation_mimic_flat', 'code_code', '223769', 'O2 Saturation Pulseoxymetry Alarm - High'), ('procedure_mimic_flat', 'code_code', '8963', 'Pulmonary artery pressure monitoring'), ('observation_mimic_flat', 'code_code', '50832', 'pO2, Body Fluid'

In [31]:
query_instance_10_07_20.iterative_gen_sql(
    query='''
    Count the number of patients for whom the blood oxygen level
    is measured through both pulse oximeter (SpO2) and
    Arterial Blood Gas (ABG) test within an hour, and the
    difference of those two measurements is more than than 2%.
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "the blood oxygen level" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('observation_mimic_flat', 'code_code', '2708-6', 'Oxygen saturation in Arterial blood'), ('observation_mimic_flat', 'code_code', '220224', 'Arterial O2 pressure'), ('observation_mimic_flat', 'code_code', '50817', 'Oxygen Saturation'), ('observation_mimic_flat', 'code_code', '220227', 'Arterial O2 Saturation'), ('observation_mimic_flat', 'code_code', '50821', 'pO2'), ('observation_mimic_flat', 'code_code', '226063', 'Venous O2 Pressure'), ('observation_mimic_flat', 'code_code', '220277', 'O2 saturation pulseoxymetry'), ('observation_mimic_flat', 'code_code', '50816', 'Oxygen'), ('condition_mimic_flat', 'code_code', 'R0902', 'Hypoxemia'), ('condition_mimic_flat', 'code_code', '79902', 'Hypoxemia'), ('observation_mimic_flat', 'code_code', '50823', 'Required O2'), 

In [77]:
query_instance_10_07_20.iterative_gen_sql(
    query='''
    Find encounters in which the blood oxygen level of
    the patient is measured through both pulse oximeter (SpO2)
    and Arterial Blood Gas (ABG) test; then find the average
    value for each method and count the number of encounters
    that the difference between those two average values
    is at most 2% vs those that it is greater than 2%.
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "the blood oxygen level" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('observation_mimic_flat', 'code_code', '2708-6', 'Oxygen saturation in Arterial blood'), ('observation_mimic_flat', 'code_code', '220224', 'Arterial O2 pressure'), ('observation_mimic_flat', 'code_code', '50817', 'Oxygen Saturation'), ('observation_mimic_flat', 'code_code', '220227', 'Arterial O2 Saturation'), ('observation_mimic_flat', 'code_code', '50821', 'pO2'), ('observation_mimic_flat', 'code_code', '226063', 'Venous O2 Pressure'), ('observation_mimic_flat', 'code_code', '220277', 'O2 saturation pulseoxymetry'), ('observation_mimic_flat', 'code_code', '50816', 'Oxygen'), ('condition_mimic_flat', 'code_code', 'R0902', 'Hypoxemia'), ('condition_mimic_flat', 'code_code', '79902', 'Hypoxemia'), ('observation_mimic_flat', 'code_code', '50823', 'Required O2'), 

### Debugging codes
Some cell outputs in this section have been removed to avoid exposing individual IDs.

In [48]:
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 200)
pd.set_option('display.max_colwidth', None)

In [None]:
pd.read_sql_query(
    sql="""
    SELECT * FROM observation_mimic AS O
    LATERAL VIEW OUTER explode(code.coding) AS OCC
    WHERE OCC.code = '220227' -- '50817'
    LIMIT 2;
    """,
    con=spark_query_engine,
)

In [None]:
pd.read_sql_query(
    sql="""
    SELECT * FROM observation_mimic AS O
    LATERAL VIEW OUTER explode(code.coding) AS OCC
    WHERE OCC.code = '50817'
    LIMIT 2;
    """,
    con=spark_query_engine,
)

In [None]:
pd.read_sql_query(
    sql="""
    SELECT * FROM observation_mimic AS O
    LATERAL VIEW OUTER explode(code.coding) AS OCC
    WHERE OCC.code = '2708-6'
    LIMIT 2;
    """,
    con=spark_query_engine,
)

In [41]:
oxygen_obs = pd.read_sql_query(
    sql="""
    SELECT id, encounter_id, obs_date, val_quantity, val_quantity_unit, code_code
    FROM Observation_mimic_flat AS O
    WHERE O.code_code IN ('50817', '220227', '2708-6')
    AND O.patient_id = '78dd8dd9-31cb-5dd9-b703-11f37937ad14'
    ORDER BY O.obs_date;
    """,
    con=spark_query_engine,
)
oxygen_obs.head(100)

|   | id | encounter_id | obs_date | val_quantity | val_quantity_unit | code_code |
|---|---|---|---|---|---|---|
| 0 | 92c384cf-5fc2-5854-b01e-8308c07fbd04 |  | 2174-09-22T09:37:00-04:00 | 96.0 | % | 50817 |
| 1 | 5dd913f5-4a1a-5380-a31c-bd69703e4e72 | fcc61b5a-510e-5102-bff9-96738a23cf33 | 2174-10-02T15:57:00-04:00 | 90.0 | % | 50817 |
| 2 | 63107731-5ea7-5d20-98b0-872c464e75fa | fcc61b5a-510e-5102-bff9-96738a23cf33 | 2174-10-02T21:54:00-04:00 | 97.0 | % | 50817 |
| 3 | 32cbf905-6cab-5893-bc00-a4696c238be9 | fcc61b5a-510e-5102-bff9-96738a23cf33 | 2174-10-03T04:43:00-04:00 | 74.0 | % | 50817 |
| 4 | ac07af90-08e7-595b-8b3e-eb8affbb38ff | fcc61b5a-510e-5102-bff9-96738a23cf33 | 2174-10-03T10:24:00-04:00 | 95.0 | % | 50817 |
| 5 | 80f8be9c-d6d7-5a66-9b94-06b5ca87f8c1 | fcc61b5a-510e-5102-bff9-96738a23cf33 | 2174-10-03T10:29:00-04:00 | 57.0 | % | 50817 |
| 6 | 3167b91f-de08-5bbc-94f5-b8e752debe3c | 24a28b35-8155-5542-a477-3192ec7aa4ea | 2178-10-12T00:57:00-04:00 | 95.0 | % | 2708-6 |
| 7 | 023b3d04-38cd-5b43-84b5-4337a165d5d6 | 24a28b35-8155-5542-a477-3192ec7aa4ea | 2178-10-12T04:38:00-04:00 | 98.0 | % | 2708-6 |
| 8 | 4dd6ee12-9885-5086-b9c6-2fa822122cd2 | 7656e003-9a07-5664-a390-419c6fb25398 | 2179-10-12T10:39:00-04:00 | 98.0 | % | 2708-6 |
| 9 | 4b2d7801-69a2-591e-86ae-1f885629c547 | 7656e003-9a07-5664-a390-419c6fb25398 | 2179-10-12T11:14:00-04:00 | 99.0 | % | 2708-6 |


In [100]:
oxygen_obs = pd.read_sql_query(
    sql="""
    SELECT id, encounter_id, obs_date, val_quantity, val_quantity_unit, code_code
    FROM Observation_mimic_flat AS O
    WHERE O.code_code IN ('50817', '220227')
    AND O.patient_id = 'ef70d460-d857-5ce4-ad5d-e09d79a1a006'
    ORDER BY O.obs_date;
    """,
    con=spark_query_engine,
)
oxygen_obs.head(40)

|   | id | encounter_id | obs_date | val_quantity | val_quantity_unit | code_code |
|---|---|---|---|---|---|---|
| 0 | 691cabfe-f489-54c9-b30f-f479c535ecb6 |  | 2145-08-26T18:49:00-04:00 | 94.0 | % | 50817 |


In [10]:
pd.read_sql_query(
    sql="""
    SELECT code_code AS code, AVG(YEAR(obs_date)) AS avg_year, COUNT(*) AS num_obs
    FROM observation_mimic_flat AS O
    WHERE code_code IN ('50817', '220227')
    GROUP BY code_code
    ORDER BY code;
    """,
    con=spark_query_engine,
)

|   | code | avg_year | num_obs |
|---|---|---|---|
| 0 | 220227 | 2153.366913 | 87394 |
| 1 | 50817 | 2153.599058 | 176225 |


In [12]:
# Original query with 220227
pd.read_sql_query(
    sql="""
    WITH spo2_obs AS (
    SELECT 
        encounter_id,
        AVG(val_quantity) AS avg_spo2
    FROM observation_mimic_flat
    WHERE code_code = '220277' -- O2 saturation pulseoxymetry is used as a SpO2 indicator
    AND val_quantity IS NOT NULL
    GROUP BY encounter_id
), abg_obs AS (
    SELECT 
        encounter_id,
        AVG(val_quantity) AS avg_abg
    FROM observation_mimic_flat
    WHERE code_code = '220227' -- Arterial O2 Saturation is used as an ABG indicator
    AND val_quantity IS NOT NULL
    GROUP BY encounter_id
), combined_obs AS (
    SELECT 
        s.encounter_id,
        s.avg_spo2,
        a.avg_abg
    FROM spo2_obs s
    JOIN abg_obs a ON s.encounter_id = a.encounter_id
), diff_obs AS (
    SELECT 
        encounter_id,
        ABS(avg_spo2 - avg_abg) AS abs_diff,
        (ABS(avg_spo2 - avg_abg) / ((avg_spo2 + avg_abg)/2))*100 as perc_diff -- percentage difference calculation
    FROM combined_obs
)
SELECT 
    SUM(CASE WHEN perc_diff <= 2 THEN 1 ELSE 0 END) AS within_2_percent,  -- Count encounters within 2pct difference
    SUM(CASE WHEN perc_diff > 2 THEN 1 ELSE 0 END) AS greater_than_2_percent  -- Count encounters with greater than 2pct difference
FROM diff_obs;
    """,
    con=spark_query_engine,
)

|   | within_2_percent | greater_than_2_percent |
|---|---|---|
| 0 | 12064 | 5319 |
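
The generated query measures the gap as a relative percentage difference, `ABS(a - b) / ((a + b) / 2) * 100`. Because both SpO2 and arterial saturation are already percentages, "2%" in the question may instead mean 2 percentage points, which is what the queries later in this section use (`ABS(spo2.val_quantity - abg.val_quantity) > 2`). The two readings disagree near the boundary, as this small sketch shows; the function names are ours.

```python
def point_difference(spo2: float, sao2: float) -> float:
    """Absolute difference in percentage points."""
    return abs(spo2 - sao2)


def relative_percent_difference(spo2: float, sao2: float) -> float:
    """Difference relative to the mean of the two values, in percent (as in perc_diff)."""
    return abs(spo2 - sao2) / ((spo2 + sao2) / 2) * 100


spo2, sao2 = 97.0, 95.0
print(point_difference(spo2, sao2))                        # 2.0 -> "at most 2" by points
print(round(relative_percent_difference(spo2, sao2), 3))   # 2.083 -> "greater than 2" relatively
```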


In [13]:
# Same query with 50817 instead of 220227
pd.read_sql_query(
    sql="""
    WITH spo2_obs AS (
    SELECT 
        encounter_id,
        AVG(val_quantity) AS avg_spo2
    FROM observation_mimic_flat
    WHERE code_code = '220277' -- O2 saturation pulseoxymetry is used as a SpO2 indicator
    AND val_quantity IS NOT NULL
    GROUP BY encounter_id
), abg_obs AS (
    SELECT 
        encounter_id,
        AVG(val_quantity) AS avg_abg
    FROM observation_mimic_flat
    WHERE code_code = '50817' -- Oxygen Saturation (lab item 50817) is used as an ABG indicator
    AND val_quantity IS NOT NULL
    GROUP BY encounter_id
), combined_obs AS (
    SELECT 
        s.encounter_id,
        s.avg_spo2,
        a.avg_abg
    FROM spo2_obs s
    JOIN abg_obs a ON s.encounter_id = a.encounter_id
), diff_obs AS (
    SELECT 
        encounter_id,
        ABS(avg_spo2 - avg_abg) AS abs_diff,
        (ABS(avg_spo2 - avg_abg) / ((avg_spo2 + avg_abg)/2))*100 as perc_diff -- percentage difference calculation
    FROM combined_obs
)
SELECT 
    SUM(CASE WHEN perc_diff <= 2 THEN 1 ELSE 0 END) AS within_2_percent,  -- Count encounters within 2pct difference
    SUM(CASE WHEN perc_diff > 2 THEN 1 ELSE 0 END) AS greater_than_2_percent  -- Count encounters with greater than 2pct difference
FROM diff_obs;
    """,
    con=spark_query_engine,
)

|   | within_2_percent | greater_than_2_percent |
|---|---|---|
| 0 |  |  |
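
Both columns came back empty (NULL). Each `CASE` expression yields 0 or 1 for every row, so the sums can only be NULL if `diff_obs` has no rows at all, which means the inner join on `encounter_id` matched no encounter having both a 220277 and a 50817 reading. Why the encounter IDs do not overlap is not established here. The sketch below reproduces the effect with Python's built-in `sqlite3` on made-up data and shows that `COALESCE` turns the empty result into 0.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE spo2_obs (encounter_id TEXT, avg_spo2 REAL);
    CREATE TABLE abg_obs (encounter_id TEXT, avg_abg REAL);
    INSERT INTO spo2_obs VALUES ('enc-a', 96.0);
    INSERT INTO abg_obs VALUES ('enc-b', 95.0);  -- no encounter in common
""")
row = conn.execute("""
    SELECT
      SUM(CASE WHEN ABS(s.avg_spo2 - a.avg_abg) <= 2 THEN 1 ELSE 0 END),
      COALESCE(SUM(CASE WHEN ABS(s.avg_spo2 - a.avg_abg) <= 2 THEN 1 ELSE 0 END), 0)
    FROM spo2_obs s
    JOIN abg_obs a ON s.encounter_id = a.encounter_id
""").fetchone()
print(row)  # (None, 0): SUM over zero joined rows is NULL
conn.close()
```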


In [None]:
oxygen_obs = pd.read_sql_query(
    sql="""
    SELECT id, encounter_id, obs_date, val_quantity, val_quantity_unit, code_code
    FROM Observation_mimic_flat AS O
    WHERE O.code_code IN ('50817', '220227', '220277')
    AND O.patient_id = 'XXX'
    ORDER BY O.obs_date;
    """,
    con=spark_query_engine,
)
oxygen_obs.head(30)

In [19]:
pd.read_sql_query(
    sql="""
SELECT COUNT(DISTINCT spo2.patient_id)
FROM observation_mimic_flat spo2
JOIN observation_mimic_flat abg ON spo2.patient_id = abg.patient_id -- Join SpO2 and ABG observations on patient ID
WHERE spo2.code_code = '220277'  -- Filter for SpO2 observations
AND abg.code_code = '220227'  -- Filter for ABG observations (Arterial O2 Saturation)
AND ABS(spo2.val_quantity - abg.val_quantity) > 2  -- Calculate the absolute difference of SpO2 and ABG values and filter for differences > 2
AND CAST(substr(spo2.obs_date, 1, 19) AS timestamp) >= CAST(substr(abg.obs_date, 1, 19) AS timestamp)  -- Ensure spo2 observation time is after ABG observation time
AND CAST(substr(spo2.obs_date, 1, 19) AS timestamp) <=  CAST(substr(abg.obs_date, 1, 19) AS timestamp) + INTERVAL '1' HOUR;  -- Ensure that the SpO2 and ABG measurements are within 1 hour of each other
    """,
    con=spark_query_engine,
)

|   | count(DISTINCT patient_id) |
|---|---|
| 0 | 8762 |


In [52]:
# Note: the time constraint is incorrect; it only matches SpO2 readings taken up to an hour after the ABG reading.
pd.read_sql_query(
    sql="""
SELECT COUNT(DISTINCT spo2.patient_id)
FROM observation_mimic_flat spo2
JOIN observation_mimic_flat abg ON spo2.patient_id = abg.patient_id -- Join SpO2 and ABG observations on patient ID
WHERE spo2.code_code = '220277'  -- Filter for SpO2 observations
AND abg.code_code = '50817'  -- Filter for ABG observations (Arterial O2 Saturation)
AND ABS(spo2.val_quantity - abg.val_quantity) > 2  -- Calculate the absolute difference of SpO2 and ABG values and filter for differences > 2
AND CAST(substr(spo2.obs_date, 1, 19) AS timestamp) >= CAST(substr(abg.obs_date, 1, 19) AS timestamp)  -- Ensure spo2 observation time is after ABG observation time
AND CAST(substr(spo2.obs_date, 1, 19) AS timestamp) <=  CAST(substr(abg.obs_date, 1, 19) AS timestamp) + INTERVAL '1' HOUR;  -- Ensure that the SpO2 and ABG measurements are within 1 hour of each other
    """,
    con=spark_query_engine,
)

|   | count(DISTINCT patient_id) |
|---|---|
| 0 | 14136 |
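
The flagged constraint is one-sided: it keeps only SpO2 readings taken at or up to an hour after the ABG reading, whereas "within an hour" is symmetric, as in the corrected queries below. These queries also drop the UTC offset with `substr(obs_date, 1, 19)`, which is only safe if all timestamps share the same offset (the rows shown earlier all end in `-04:00`). A small sketch of the two windows using offset-aware timestamps; the function names are ours.

```python
from datetime import datetime, timedelta

WINDOW = timedelta(hours=1)


def within_one_sided(spo2_time: str, abg_time: str) -> bool:
    """SpO2 at or after the ABG, by at most an hour (the flagged constraint)."""
    spo2, abg = datetime.fromisoformat(spo2_time), datetime.fromisoformat(abg_time)
    return abg <= spo2 <= abg + WINDOW


def within_symmetric(spo2_time: str, abg_time: str) -> bool:
    """SpO2 within an hour before or after the ABG."""
    spo2, abg = datetime.fromisoformat(spo2_time), datetime.fromisoformat(abg_time)
    return abs(spo2 - abg) <= WINDOW


abg_time = "2174-10-02T15:57:00-04:00"
spo2_before = "2174-10-02T15:30:00-04:00"  # 27 minutes before the ABG
print(within_one_sided(spo2_before, abg_time))   # False
print(within_symmetric(spo2_before, abg_time))   # True
# The same instant written with a different offset still compares correctly here.
print(within_symmetric("2174-10-02T19:57:00+00:00", abg_time))  # True
```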


In [21]:
# Side note: one of the suggested queries has an encounter constraint in it (why?):
pd.read_sql_query(
    sql="""
SELECT COUNT(DISTINCT spo2.patient_id)
FROM observation_mimic_flat spo2
JOIN observation_mimic_flat abg ON spo2.patient_id = abg.patient_id -- Join SpO2 and ABG observations on patient ID
AND spo2.encounter_id = abg.encounter_id -- Ensure both readings are taken during the same encounter
WHERE spo2.code_code = '220277'  -- Filter for SpO2 observations
AND abg.code_code = '220227'  -- Filter for ABG observations (Arterial O2 Saturation)
AND ABS(spo2.val_quantity - abg.val_quantity) > 2  -- Calculate the absolute difference of SpO2 and ABG values and filter for differences > 2
AND CAST(substr(spo2.obs_date, 1, 19) AS timestamp) >= CAST(substr(abg.obs_date, 1, 19) AS timestamp)  -- Ensure spo2 observation time is after ABG observation time
AND CAST(substr(spo2.obs_date, 1, 19) AS timestamp) <=  CAST(substr(abg.obs_date, 1, 19) AS timestamp) + INTERVAL '1' HOUR;  -- Ensure that the SpO2 and ABG measurements are within 1 hour of each other
    """,
    con=spark_query_engine,
)

|   | count(DISTINCT patient_id) |
|---|---|
| 0 | 8756 |


In [53]:
# Correct query with 2708-6
pd.read_sql_query(
    sql="""
SELECT COUNT(DISTINCT spo2.patient_id)
FROM observation_mimic_flat spo2
JOIN observation_mimic_flat abg ON spo2.patient_id = abg.patient_id  -- Join SpO2 and ABG observations on patient_id
WHERE spo2.code_code = '220277'  -- Filter for SpO2 observations using pulse oximetry
AND abg.code_code = '2708-6'  -- Filter for ABG observations
AND ABS(spo2.val_quantity - abg.val_quantity) > 2  -- Filter where the absolute difference is greater than 2
AND CAST(spo2.obs_date AS TIMESTAMP) BETWEEN CAST(abg.obs_date AS TIMESTAMP) - INTERVAL '1' HOUR AND CAST(abg.obs_date AS TIMESTAMP) + INTERVAL '1' HOUR;  -- Filter for observations within 1 hour
    """,
    con=spark_query_engine,
)

|   | count(DISTINCT patient_id) |
|---|---|
| 0 | 3406 |


In [54]:
# Correct query with 220227
pd.read_sql_query(
    sql="""
SELECT COUNT(DISTINCT spo2.patient_id)
FROM observation_mimic_flat spo2
JOIN observation_mimic_flat abg ON spo2.patient_id = abg.patient_id  -- Join SpO2 and ABG observations on patient_id
WHERE spo2.code_code = '220277'  -- Filter for SpO2 observations using pulse oximetry
AND abg.code_code = '220227'  -- Filter for ABG observations
AND ABS(spo2.val_quantity - abg.val_quantity) > 2  -- Filter where the absolute difference is greater than 2
AND CAST(spo2.obs_date AS TIMESTAMP) BETWEEN CAST(abg.obs_date AS TIMESTAMP) - INTERVAL '1' HOUR AND CAST(abg.obs_date AS TIMESTAMP) + INTERVAL '1' HOUR;  -- Filter for observations within 1 hour
    """,
    con=spark_query_engine,
)

Unnamed: 0,count(DISTINCT patient_id)
0,10225


In [55]:
# Correct query with 50817
pd.read_sql_query(
    sql="""
SELECT COUNT(DISTINCT spo2.patient_id)
FROM observation_mimic_flat spo2
JOIN observation_mimic_flat abg ON spo2.patient_id = abg.patient_id  -- Join SpO2 and ABG observations on patient_id
WHERE spo2.code_code = '220277'  -- Filter for SpO2 observations using pulse oximetry
AND abg.code_code = '50817'  -- Filter for ABG observations
AND ABS(spo2.val_quantity - abg.val_quantity) > 2  -- Filter where the absolute difference is greater than 2
AND CAST(spo2.obs_date AS TIMESTAMP) BETWEEN CAST(abg.obs_date AS TIMESTAMP) - INTERVAL '1' HOUR AND CAST(abg.obs_date AS TIMESTAMP) + INTERVAL '1' HOUR;  -- Filter for observations within 1 hour
    """,
    con=spark_query_engine,
)

Unnamed: 0,count(DISTINCT patient_id)
0,15114


In [32]:
query_instance_10_07_20.iterative_gen_sql(
    query='''
    For encounters in which the blood oxygen level of the patient is measured through both
    pulse oximeter (SpO2) and Arterial Blood Gas (ABG) test, find the average of for each
    method and count the number of those that the difference between those two average values
    is at most 2% vs those that it is greater than 2%.
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "the blood oxygen level" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('observation_mimic_flat', 'code_code', '2708-6', 'Oxygen saturation in Arterial blood'), ('observation_mimic_flat', 'code_code', '220224', 'Arterial O2 pressure'), ('observation_mimic_flat', 'code_code', '50817', 'Oxygen Saturation'), ('observation_mimic_flat', 'code_code', '220227', 'Arterial O2 Saturation'), ('observation_mimic_flat', 'code_code', '50821', 'pO2'), ('observation_mimic_flat', 'code_code', '226063', 'Venous O2 Pressure'), ('observation_mimic_flat', 'code_code', '220277', 'O2 saturation pulseoxymetry'), ('observation_mimic_flat', 'code_code', '50816', 'Oxygen'), ('condition_mimic_flat', 'code_code', 'R0902', 'Hypoxemia'), ('condition_mimic_flat', 'code_code', '79902', 'Hypoxemia'), ('observation_mimic_flat', 'code_code', '50823', 'Required O2'), 

In [33]:
pd.read_sql_query(
    sql="""
SELECT COUNT(*)
FROM observation_mimic_flat o
WHERE code_code = '2708-6'
    """,
    con=spark_query_engine,
)

Unnamed: 0,count(1)
0,1989697


In [34]:
pd.read_sql_query(
    sql="""
SELECT COUNT(*)
FROM observation_mimic_flat o
WHERE code_code = '220227'
    """,
    con=spark_query_engine,
)

Unnamed: 0,count(1)
0,87394


In [36]:
pd.read_sql_query(
    sql="""
SELECT COUNT(*)
FROM observation_mimic o
    """,
    con=spark_query_engine,
)

Unnamed: 0,count(1)
0,461098908


# Eval set
Source repo: [github.com/glee4810/ehrsql-2024](https://github.com/glee4810/ehrsql-2024)


In [7]:
from querygen import sqlgen, column_sampler, db_description, util, terminology_indexer
import querygen
querygen.PRINT_CLOSE_CONCEPTS = True
querygen.PRINT_FINAL_PROMPT = False
#querygen.PRINT_FINAL_PROMPT = True
querygen.PRINT_MODEL_RESPONSE = False

# Reload the modules to pick up the latest external changes without restarting the kernel.
from importlib import reload
column_sampler = reload(column_sampler)
db_description = reload(db_description)
util = reload(util)
terminology_indexer = reload(terminology_indexer)
sqlgen = reload(sqlgen)

query_instance_10_07_20 = sqlgen.SqlGen(
    target_db_url='hive://localhost:10001/default',
    pg_vector_db_url='postgresql://postgres:admin@localhost:5438/codevec',
    nlp_ner=nlp_ner,
    num_column_samples=10,
    close_threshold=0.7,
    max_close=20
)

term_indexer = terminology_indexer.TerminologyIndexer(
    target_db_url='hive://localhost:10001/default',
    pg_vector_db_url='postgresql://postgres:admin@localhost:5438/codevec')

## Population level queries
The queries below are from the [EHRSQL 2024 dataset](https://github.com/glee4810/ehrsql-2024/tree/master/data/mimic_iv/test)
and are about the whole population.

### Most frequent lab tests
**Analysis**: The general structure of the query and the answer are partially correct;
however, there are some issues/observations:
- One issue is that in the FHIR version of MIMIC,
various tables like `labevents`, `microbiologyevents`, `chartevents`,
`datetimeevents`, `vitalsign`, etc. are all represented as `Observation`.
This makes it harder to write the correct SQL compared to the original tables.
The only column in the flat observation tables that helps differentiate them is `code_sys`.
However, the values of this column were not properly sampled when only 10K rows
were checked. Here is what we got in the generated prompt:
```
Here are 10 sample values from column code_sys of table observation_mimic_full_flat_custom
sorted by their frequencies over 10000 rows: 

0  http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-chartevents-d-items  10000
```
By enabling full column sampling (see subsection below), we get these values:
```
Here are 10 sample values from column code_sys of table observation_mimic_flat
sorted by their frequencies over ALL rows: 

0      http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-chartevents-d-items  313645032
1               http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-labitems  118171367
2                                                          http://loinc.org   14352966
3                  http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-items   11347966
4        http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-microbiology-test    2184371
5  http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-microbiology-antibiotic    1107278
6    http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-microbiology-organism     289928
```
*Side note*: The `labevents` are also included in the `identifier.system` of the
original Observation resources, but that column is not included in our flat views,
e.g., `observation_mimic_flat`.

- Another issue is that multiple codes have the same `display` value, e.g., there are
three codes with `display` being `Safety Measures` (all from `chartevents`), so the query
should have grouped by `code` and `system` too.

See: https://physionet.org/content/mimic-iv-fhir/2.1/
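Both pitfalls can be sketched in plain Python. This is a hypothetical illustration: the rows, the codes `c1`, `c2`, `c3`, `l1` and the counts are made up and only stand in for `observation_mimic_flat`. Counting by `display` alone merges distinct codes, whereas filtering on the lab-items `code_sys` and keying on `(code_sys, code_code, code_display)` keeps them apart, mirroring a `WHERE code_sys = ...` filter plus a `GROUP BY` on all three columns.

```python
from collections import Counter

CHART = "http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-chartevents-d-items"
LAB = "http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-labitems"

# Hypothetical rows standing in for observation_mimic_flat:
# (code_sys, code_code, code_display). Codes and counts are made up.
rows = [
    (CHART, "c1", "Safety Measures"),
    (CHART, "c2", "Safety Measures"),
    (CHART, "c2", "Safety Measures"),
    (CHART, "c2", "Safety Measures"),
    (LAB, "l1", "Glucose"),
    (LAB, "l1", "Glucose"),
    (CHART, "c3", "Glucose"),  # a chart item sharing a display with a lab item
]

def top_by_display(rows, k):
    """Mimics GROUP BY code_display: different codes sharing a display are merged."""
    return Counter(display for _, _, display in rows).most_common(k)

def top_lab_tests(rows, k):
    """Mimics WHERE code_sys = LAB GROUP BY code_sys, code_code, code_display."""
    counts = Counter((sys, code, display) for sys, code, display in rows if sys == LAB)
    return counts.most_common(k)

print(top_by_display(rows, 2))  # [('Safety Measures', 4), ('Glucose', 3)]
print(top_lab_tests(rows, 2))
```

In the display-only count, `Glucose` also absorbs the chart item `c3`, and `Safety Measures` hides that it comes from two different codes.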

In [14]:
# Question ID: ce4170ef70617d0f660eb8c8
# Expected answer: [['chloride'], ['creatinine'], ['hematocrit'], ['sodium']]
query_instance_10_07_20.iterative_gen_sql(
    query='''
    Can you show me the top four most frequent lab tests given?
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "four" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('observation_mimic_flat', 'code_code', '50891', 'C4')]

            
            For "frequent lab tests" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            []

            
Iteration 0 SQL: 

SELECT code_display
FROM observation_mimic_flat
GROUP BY code_display
ORDER BY count(*) DESC
LIMIT 4; -- show the top four most frequent lab tests given

Iteration 0 SQL execution output: 
START_OUTPUT
('Safety Measures',)
('Heart Rate',)
('Respiratory Rate',)
('O2 saturation pulseoxymetry',)
END_OUTPUT
Iteration 1 SQL: 

SELECT code_display
FROM observation_mimic_flat
GROUP BY code_display
ORDER BY count(*) DESC
LIMIT 4; -- show the top four most frequent lab tests given

I

In [21]:
pd.read_sql_query(
    sql="""
    SELECT code_display, COUNT(*) AS num_obs
    FROM observation_mimic_flat
    GROUP BY code_display
    ORDER BY count(*) DESC
    LIMIT 10; -- show the ten most frequent code_display values
    """,
    con=spark_query_engine,
)

Unnamed: 0,code_display,num_obs
0,Safety Measures,9159444
1,Heart Rate,6460839
2,Respiratory Rate,6393762
3,O2 saturation pulseoxymetry,6324341
4,Heart Rhythm,5907987
5,Ectopy Type 1,5272404
6,Non Invasive Blood Pressure systolic,4067219
7,Non Invasive Blood Pressure diastolic,4066337
8,Non Invasive Blood Pressure mean,4064537
9,Hemoglobin,3729307


In [22]:
pd.read_sql_query(
    sql="""
    SELECT code_code, code_sys, code_display, COUNT(*) AS num_obs
    FROM observation_mimic_flat
    GROUP BY code_code, code_sys, code_display
    ORDER BY COUNT(*) DESC
    LIMIT 20;
    """,
    con=spark_query_engine,
)

Unnamed: 0,code_code,code_sys,code_display,num_obs
0,227969,http://mimic.mit.edu/fhir/mimic/CodeSystem/mim...,Safety Measures,8770670
1,220045,http://mimic.mit.edu/fhir/mimic/CodeSystem/mim...,Heart Rate,6460839
2,220210,http://mimic.mit.edu/fhir/mimic/CodeSystem/mim...,Respiratory Rate,6393762
3,220277,http://mimic.mit.edu/fhir/mimic/CodeSystem/mim...,O2 saturation pulseoxymetry,6324341
4,220048,http://mimic.mit.edu/fhir/mimic/CodeSystem/mim...,Heart Rhythm,5907987
5,224650,http://mimic.mit.edu/fhir/mimic/CodeSystem/mim...,Ectopy Type 1,5272404
6,220179,http://mimic.mit.edu/fhir/mimic/CodeSystem/mim...,Non Invasive Blood Pressure systolic,4067219
7,220180,http://mimic.mit.edu/fhir/mimic/CodeSystem/mim...,Non Invasive Blood Pressure diastolic,4066337
8,220181,http://mimic.mit.edu/fhir/mimic/CodeSystem/mim...,Non Invasive Blood Pressure mean,4064537
9,227958,http://mimic.mit.edu/fhir/mimic/CodeSystem/mim...,Less Restrictive Measures,3333656


#### With full column sampling
With full column sampling we get the correct `code_sys` constraint and the results:
```
START_OUTPUT
('Glucose',)
('Hematocrit',)
('Hemoglobin',)
('Creatinine',)
END_OUTPUT
```
I think the answer provided in the EHRSQL dataset is incorrect. Here is the SQL the
model generates:
```sql
SELECT code_display
FROM observation_mimic_flat
WHERE code_sys = 'http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-labitems' -- To filter for lab tests
GROUP BY code_display
ORDER BY COUNT(*) DESC
LIMIT 4; -- To get the top four most frequent lab tests
```

*Side note*: In the first run below, about 2 minutes are spent on full column sampling.
The result is cached, and the impact can be seen in the repeated cell that follows,
which takes only ~27 seconds.
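The caching effect can be sketched as follows. This is a hypothetical illustration, not the `querygen` implementation: `full_column_frequencies` and `TABLES` are made-up names, and the in-memory list stands in for a `GROUP BY ... ORDER BY COUNT(*) DESC` scan over all rows in Spark. Only the first call pays for the full scan; later calls with the same `(table, column)` are served from the cache.

```python
import functools
from collections import Counter

# Hypothetical in-memory "table"; a real implementation would query Spark instead.
TABLES = {
    ("observation_mimic_flat", "code_sys"):
        ["chartevents"] * 3 + ["labitems"] * 2 + ["loinc"],
}
scan_count = 0  # how many full scans were actually performed

@functools.lru_cache(maxsize=None)
def full_column_frequencies(table, column, k=10):
    """Return the k most frequent values of a column over ALL rows (cached)."""
    global scan_count
    scan_count += 1
    # Return a tuple so callers cannot mutate the cached result.
    return tuple(Counter(TABLES[(table, column)]).most_common(k))

first = full_column_frequencies("observation_mimic_flat", "code_sys")   # full scan
second = full_column_frequencies("observation_mimic_flat", "code_sys")  # cache hit
print(first)
print("full scans:", scan_count)  # full scans: 1
```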

In [30]:
# With full column sampling
#
# Question ID: ce4170ef70617d0f660eb8c8
# Expected answer: [['chloride'], ['creatinine'], ['hematocrit'], ['sodium']]
query_instance_10_07_20.iterative_gen_sql(
    query='''
    Can you show me the top four most frequent lab tests given?
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "four" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('observation_mimic_flat', 'code_code', '50891', 'C4')]

            
            For "frequent lab tests" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            []

            
Iteration 0 SQL: 

SELECT code_display
FROM observation_mimic_flat
WHERE code_sys = 'http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-labitems' -- To filter for lab tests
GROUP BY code_display
ORDER BY COUNT(*) DESC
LIMIT 4; -- To get the top four most frequent lab tests

Iteration 0 SQL execution output: 
START_OUTPUT
('Glucose',)
('Hematocrit',)
('Hemoglobin',)
('Creatinine',)
END_OUTPUT
Iteration 1 SQL: 

SELECT code_display
FROM observation_mimic_flat
WHERE code_sys = 'http://mimic.mit.e

In [31]:
# With full column sampling; repeating to show the performance effect of full
# sampling as it is cached in this second run.
#
# Question ID: ce4170ef70617d0f660eb8c8
# Expected answer: [['chloride'], ['creatinine'], ['hematocrit'], ['sodium']]
query_instance_10_07_20.iterative_gen_sql(
    query='''
    Can you show me the top four most frequent lab tests given?
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "four" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('observation_mimic_flat', 'code_code', '50891', 'C4')]

            
            For "frequent lab tests" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            []

            
Iteration 0 SQL: 

SELECT code_display
FROM observation_mimic_flat
WHERE code_sys = 'http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-labitems' -- To filter for lab tests
GROUP BY code_display
ORDER BY count(*) DESC
LIMIT 4; -- To get the top four most frequent lab tests

Iteration 0 SQL execution output: 
START_OUTPUT
('Glucose',)
('Hematocrit',)
('Hemoglobin',)
('Creatinine',)
END_OUTPUT
Iteration 1 SQL: 

SELECT code_display
FROM observation_mimic_flat
WHERE code_sys = 'http://mimic.mit.e

### Most frequent microbiology tests after cardiac arrest
**Analysis**: Similar to the above example, adding full column sampling
fixes the queries and results. A related point is that the queries run
much faster, probably because of the following constraint:
```
code_sys = 'http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-microbiology-test'
```

It is worth noting that the "expected answer" is probably wrong too, as it has only
one element, *'blood culture, routine'*, instead of four. Our queries return:
```
START_OUTPUT
('Blood Culture, Routine',)
('GRAM STAIN',)
('URINE CULTURE',)
('RESPIRATORY CULTURE',)
END_OUTPUT
```

Finally, note that among the codes picked for cardiac arrest, `I462` is the
target one, as its description exactly matches the question, i.e.,
_"Cardiac arrest due to underlying cardiac condition"_ (which makes the question
a little bit unrealistic, BTW). But the LLM sometimes picks other related codes too,
e.g.,
- `I469`: _"Cardiac arrest, cause unspecified"_
- `I468`: _"Cardiac arrest due to other underlying condition"_

In [26]:
# Question ID: e56ea93964e30564ef0d8599
# Expected answer: [['blood culture, routine']]
query_instance_10_07_20.iterative_gen_sql(
    query='''
    What is the four most frequently given microbiology tests since 2100 for
    patients who were previously diagnosed with cardiac arrest due to
    underlying cardiac condition in the same hospital encounter?
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "four" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('observation_mimic_flat', 'code_code', '50891', 'C4')]

            
            For "microbiology tests" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            []

            
            For "2100" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            []

            
            For "cardiac arrest" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('condition_mimic_flat', 'code_code', '4275', 'Cardiac arrest'), ('procedure_mimic_flat'

In [32]:
# With full column sampling.
#
# Question ID: e56ea93964e30564ef0d8599
# Expected answer: [['blood culture, routine']]
query_instance_10_07_20.iterative_gen_sql(
    query='''
    What is the four most frequently given microbiology tests since 2100 for
    patients who were previously diagnosed with cardiac arrest due to
    underlying cardiac condition in the same hospital encounter?
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "four" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('observation_mimic_flat', 'code_code', '50891', 'C4')]

            
            For "microbiology tests" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            []

            
            For "2100" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            []

            
            For "cardiac arrest" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('condition_mimic_flat', 'code_code', '4275', 'Cardiac arrest'), ('procedure_mimic_flat'

### Most frequent procedures
The generated query seems to be correct, but the time at which the procedure
was performed is in the `performed.dateTime` field of `Procedure` FHIR resources,
which is not exposed in the flat views. Instead, the model uses `period_start`,
which is a good choice, but it is not set for most procedures.

In [4]:
# Question ID: 783af9c867b3757480fa6805
# Expected answer: [['arterial catheterization'], ['central venous catheter placement with guidance'],
#  ['continuous invasive mechanical ventilation for less than 96 consecutive hours'],
#  ['enteral infusion of concentrated nutritional substances'], ['extracorporeal circulation auxiliary to open heart surgery'], ...
# The expected answer has more elements while it is clearly asking for only five!
query_instance_10_07_20.iterative_gen_sql(
    query='''
    Can you tell me the top five most common procedures this year?
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "five" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('observation_mimic_flat', 'code_code', '51767', 'EE5')]

            
            For "this year" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            []

            
Iteration 0 SQL: 

SELECT code_display, COUNT(*) AS procedure_count
FROM procedure_mimic_flat
WHERE SUBSTR(period_start, 1, 4) = SUBSTR(CURRENT_DATE(), 1, 4) -- Filter by the current year, assuming period_start contains the year
GROUP BY code_display
ORDER BY procedure_count DESC
LIMIT 5; -- Limit the result to the top 5 procedures

Iteration 0 SQL execution output: 
START_OUTPUT

END_OUTPUT
Iteration 1 SQL: 

SELECT code_display, COUNT(*) AS procedure_count
FROM procedure_mimic_flat
WHERE SUBSTR(peri

In [5]:
query_instance_10_07_20.iterative_gen_sql(
    query='''
    Can you tell me the top five most common procedures in year 2100?
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "five" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('observation_mimic_flat', 'code_code', '51767', 'EE5')]

            
            For "year 2100" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            []

            
Iteration 0 SQL: 

SELECT code_display, COUNT(*) AS procedure_count
FROM procedure_mimic_flat
WHERE SUBSTR(period_start, 1, 4) = '2100' -- Filter for procedures in the year 2100, using the start date.
GROUP BY code_display
ORDER BY procedure_count DESC
LIMIT 5; -- Limit to the top 5 most common procedures.

Iteration 0 SQL execution output: 
START_OUTPUT

END_OUTPUT
Iteration 1 SQL: 

SELECT code_display, COUNT(*) AS procedure_count
FROM procedure_mimic_flat
WHERE SUBSTR(period_start, 1, 4) = '2100' -

In [9]:
# The problem is that most procedures do not have the period fields:
pd.read_sql_query(
    sql="""
    SELECT SUBSTR(period_start, 1, 4) AS year, COUNT(*) AS num_proc
    FROM procedure_mimic_flat
    GROUP BY year
    ORDER BY num_proc DESC
    LIMIT 5; -- Limit to the 5 most frequent years.
    """,
    con=spark_query_engine,
)

Unnamed: 0,year,num_proc
0,,2658883
1,2149.0,10080
2,2130.0,9689
3,2183.0,9602
4,2163.0,9502


In [78]:
# Instead it is the `performed.dateTime` field that should be used but
# it is *not* exposed in the flat view.
pd.read_sql_query(
    sql="""
    SELECT SUBSTR(performed.dateTime, 1, 4) AS year, COUNT(*) AS num_proc
    FROM procedure_mimic
    GROUP BY year
    ORDER BY num_proc DESC
    LIMIT 10;
    """,
    con=spark_query_engine,
)

Unnamed: 0,year,num_proc
0,,696092
1,2176.0,34951
2,2145.0,34662
3,2181.0,34652
4,2141.0,34529
5,2144.0,34528
6,2146.0,34394
7,2163.0,34307
8,2177.0,34261
9,2180.0,34231


#### After adding `performed.dateTime` to the `procedure_mimic_flat`
*Analysis:* Note the `COALESCE(period_start, performed_time)` expression in the
generated SQL queries, which falls back to `performed_time` when `period_start` is null.
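The effect of the `COALESCE` fallback on the year filter can be sketched in Python. The procedures below are hypothetical (displays borrowed from the outputs, timestamps made up), and the sketch assumes missing values are real NULLs (`None`); SQL `COALESCE` would not skip empty strings.

```python
from collections import Counter
from typing import Optional

def coalesce(*values: Optional[str]) -> Optional[str]:
    """Return the first non-None argument, like SQL COALESCE."""
    for v in values:
        if v is not None:
            return v
    return None

def procedure_year(period_start: Optional[str], performed_time: Optional[str]) -> Optional[str]:
    """Mimics SUBSTR(COALESCE(period_start, performed_time), 1, 4); NULL stays NULL."""
    ts = coalesce(period_start, performed_time)
    return ts[:4] if ts is not None else None

# Hypothetical procedures: (code_display, period_start, performed_time).
procs = [
    ("Taking patient vital signs assessment (procedure)", None, "2150-03-01T10:00:00"),
    ("20 Gauge", "2150-03-02T08:00:00", None),
    ("Chest X-Ray", "2150-05-11T09:30:00", None),
    ("EKG", None, "2149-12-31T23:00:00"),
    ("Extubation", None, None),
]

# WHERE SUBSTR(period_start, 1, 4) = '2150'
only_period = Counter(d for d, ps, _ in procs if ps is not None and ps[:4] == "2150")
# WHERE SUBSTR(COALESCE(period_start, performed_time), 1, 4) = '2150'
with_coalesce = Counter(d for d, ps, pt in procs if procedure_year(ps, pt) == "2150")

print(sorted(only_period))    # ['20 Gauge', 'Chest X-Ray']
print(sorted(with_coalesce))  # ['20 Gauge', 'Chest X-Ray', 'Taking patient vital signs assessment (procedure)']
```

This matches the pattern in the outputs: procedures that only carry `performed_time` (such as the vital-signs assessments) show up only once the fallback is used.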

In [77]:
# After adding performed.dateTime to `procedure_mimic_flat`
pd.read_sql_query(
    sql="""
    SELECT SUBSTR(performed_time, 1, 4) AS year, COUNT(*) AS num_proc
    FROM procedure_mimic_flat
    GROUP BY year
    ORDER BY num_proc DESC
    LIMIT 10; -- Limit to the 10 most frequent years.
    """,
    con=spark_query_engine,
)

Unnamed: 0,year,num_proc
0,,696092
1,2176.0,34951
2,2145.0,34662
3,2181.0,34652
4,2141.0,34529
5,2144.0,34528
6,2146.0,34394
7,2163.0,34307
8,2177.0,34261
9,2180.0,34231


In [79]:
query_instance_10_07_30 = sqlgen.SqlGen(
    target_db_url='hive://localhost:10001/default',
    pg_vector_db_url='postgresql://postgres:admin@localhost:5438/codevec',
    nlp_ner=nlp_ner,
    num_column_samples=10,
    close_threshold=0.7,
    max_close=30
)

In [81]:
query_instance_10_07_30.iterative_gen_sql(
    query='''
    Can you tell me the top five most common procedures in year 2150?
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "five" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('observation_mimic_flat', 'code_code', '51767', 'EE5')]

            
            For "year 2150" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            []

            
Iteration 0 SQL: 

SELECT code_display, COUNT(*) AS procedure_count
FROM procedure_mimic_flat
WHERE SUBSTR(period_start, 1, 4) = '2150' -- Filter procedures that started in the year 2150
GROUP BY code_display
ORDER BY procedure_count DESC
LIMIT 5; -- Limit the result to the top 5 procedures

Iteration 0 SQL execution output: 
START_OUTPUT
('20 Gauge', 1008)
('Chest X-Ray', 851)
('18 Gauge', 698)
('EKG', 440)
('Invasive Ventilation', 371)
END_OUTPUT
Iteration 1 SQL: 

SELECT code_display, COUNT(*) AS p

In [85]:
# Checking queries
pd.read_sql_query(
    sql="""
SELECT code_display, COUNT(*) AS procedure_count
FROM procedure_mimic_flat
WHERE SUBSTR(COALESCE(period_start, performed_time), 1, 4) = '2150'  -- Consider both period_start and performed_time, coalesce to handle nulls
GROUP BY code_display
ORDER BY procedure_count DESC
LIMIT 5;  -- Limit the results to the top 5 procedures
    """,
    con=spark_query_engine,
)

Unnamed: 0,code_display,procedure_count
0,Taking patient vital signs assessment (procedure),18745
1,Triage: emergency center (procedure),5039
2,20 Gauge,1008
3,Chest X-Ray,851
4,18 Gauge,698


In [96]:
# without the date constraint
pd.read_sql_query(
    sql="""
SELECT code_display, COUNT(*) AS procedure_count
FROM procedure_mimic_flat
-- WHERE SUBSTR(COALESCE(period_start, performed_time), 1, 4) = '2150'  -- Consider both period_start and performed_time, coalesce to handle nulls
GROUP BY code_display
ORDER BY procedure_count DESC
LIMIT 10;  -- Limit the results to the top 10 procedures
    """,
    con=spark_query_engine,
)

Unnamed: 0,code_display,procedure_count
0,Taking patient vital signs assessment (procedure),1564610
1,Triage: emergency center (procedure),425087
2,20 Gauge,86558
3,Chest X-Ray,69120
4,18 Gauge,57127
5,EKG,36727
6,Arterial Line,33173
7,Invasive Ventilation,30710
8,Blood Cultured,22444
9,Extubation,22336


In [101]:
# with system
pd.read_sql_query(
    sql="""
SELECT code_sys, code_code, code_display, COUNT(*) AS procedure_count
FROM procedure_mimic_flat
GROUP BY code_sys, code_code, code_display
ORDER BY procedure_count DESC
LIMIT 20;  -- Limit the results to the top 20 (system, code, display) groups
    """,
    con=spark_query_engine,
)

Unnamed: 0,code_sys,code_code,code_display,procedure_count
0,http://snomed.info/sct,410188000,Taking patient vital signs assessment (procedure),1564610
1,http://snomed.info/sct,386478007,Triage: emergency center (procedure),425087
2,http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-items,224275,20 Gauge,86558
3,http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-items,225459,Chest X-Ray,69120
4,http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-items,224277,18 Gauge,57127
5,http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-items,225402,EKG,36727
6,http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-items,225752,Arterial Line,33173
7,http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-items,225792,Invasive Ventilation,30710
8,http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-items,225401,Blood Cultured,22444
9,http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-d-items,227194,Extubation,22336


In [98]:
# Testing target procedures
pd.read_sql_query(
    sql="""
SELECT code_sys, code_display, COUNT(*) AS procedure_count
FROM procedure_mimic_flat
WHERE LOWER(code_display) = 'arterial catheterization'
GROUP BY code_sys, code_display
ORDER BY procedure_count DESC
LIMIT 10;  -- Limit the results to at most 10 (system, display) groups
    """,
    con=spark_query_engine,
)

Unnamed: 0,code_sys,code_display,procedure_count
0,http://mimic.mit.edu/fhir/mimic/CodeSystem/mimic-procedure-icd9,Arterial catheterization,3473


### Number of Prednisone prescriptions
**Analysis:** In this case too, the general structure of the final query is correct, with
the following caveats:
- Unlike other codes, for medication codes the `display` field is empty and the `code`
  field instead holds a textual representation of the medication name. This means that
  our embedding-based similarity search is useless for these codes.
- The model correctly recognizes that it should look at the `code` text directly, but the
  additional challenge is that `Prednisone` is spelled as `PredniSONE` (probably "tall man"
  lettering, used to distinguish look-alike drug names such as predniSONE and prednisoLONE)
  and the model does not recognize it, as it never sees this spelling among the example values.
- The year `2100` is probably a bad choice because of the date shifting in MIMIC; as the
  investigations below show, no Prednisone requests are linked to encounters starting in 2100,
  while picking a "good year" like 2150 works.
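The case-sensitivity issue can be reproduced with a small sketch that treats `LIKE '%...%'` as a plain substring test (a simplification: real `LIKE` patterns also support `%` and `_` wildcards inside the pattern, which this helper ignores). The first three names are spellings that appear in the notebook's `LOWER` check; `PrednisoLONE` and `Acetaminophen` are added only for illustration. Spark's `LIKE` is case-sensitive, which is consistent with the 1 vs. 5147 counts.

```python
def like_contains(value, needle, lower_value=False):
    """Mimic `value LIKE '%needle%'` for a plain needle (no % or _ inside it).

    With lower_value=True it mimics `LOWER(value) LIKE '%needle%'`, so the
    needle itself is expected to be lowercase, as in the notebook's query.
    """
    if value is None:  # a NULL never satisfies LIKE in a WHERE clause
        return False
    if lower_value:
        value = value.lower()
    return needle in value

names = ["PredniSONE", "Prednisone", "predniSONE", "PrednisoLONE", "Acetaminophen"]

case_sensitive = [n for n in names if like_contains(n, "Prednisone")]
with_lower = [n for n in names if like_contains(n, "prednisone", lower_value=True)]

print(case_sensitive)  # ['Prednisone']
print(with_lower)      # ['PredniSONE', 'Prednisone', 'predniSONE']
```

Note that `LOWER` does not make `prednisolone` match, so the look-alike drug is still excluded.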

In [104]:
# Question ID: d1d8f2df3141cf92ddc53228
# Expected answer: [['4']]
query_instance_10_07_20.iterative_gen_sql(
    query='''
    How many prednisone prescription cases were there in 2100?
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "prednisone prescription" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            []

            
            For "2100" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            []

            
Iteration 0 SQL: 

SELECT COUNT(DISTINCT mr.id)
FROM medication_request_mimic_flat mr
WHERE mr.medication_code LIKE '%Prednisone%' -- To find Prednisone prescriptions
AND SUBSTR(mr.id, 1, 4) = '2100'; -- There is no date field, so we're checking the first four characters of id assuming they might represent the year

Iteration 0 SQL execution output: 
START_OUTPUT
(0,)
END_OUTPUT
Iteration 1 SQL: 

SELECT COUNT(DISTINCT mr.id)
FROM medication_request_mimic_flat mr
WHERE mr.medication_code LIKE '%Prednisone%' -- To find Prednisone prescription

In [16]:
pd.read_sql_query(
    sql="""
    SELECT med.medication_code, COUNT(*) AS num_med
FROM medication_request_mimic_flat AS med
GROUP BY med.medication_code
ORDER BY num_med DESC
LIMIT 20;
    """,
    con=spark_query_engine,
)

Unnamed: 0,medication_code,num_med
0,,13533220
1,Sodium Chloride 0.9% Flush,194908
2,IV therapy,150815
3,Insulin,122879
4,Acetaminophen,70808
5,Heparin,56742
6,HYDROmorphone (Dilaudid),56245
7,Docusate Sodium,54624
8,Influenza Vaccine Quadrivalent,41424
9,Senna,32762


In [41]:
# One issue is case sensitivity (note LOWER):
from sqlalchemy import text

with spark_query_engine.connect() as connection:
    result = connection.execute(text(
        """
        SELECT COUNT(DISTINCT med.id)
FROM medication_request_mimic_flat AS med
WHERE LOWER(med.medication_code) LIKE '%prednisone%' -- Filter for medication requests containing 'Prednisone'
        """))
    for row in result:
        print("num with LOWER: ", row)
        
    result = connection.execute(text(
        """
        SELECT COUNT(DISTINCT med.id)
FROM medication_request_mimic_flat AS med
WHERE med.medication_code LIKE '%Prednisone%' -- Filter for medication requests containing 'Prednisone'
        """))
    for row in result:
        print("num without LOWER: ", row)

    result = connection.execute(text(
        """
        SELECT med.medication_code AS med_code, COUNT(*) AS num_med
FROM medication_request_mimic_flat AS med
WHERE LOWER(med.medication_code) LIKE '%prednisone%' -- Filter for medication requests containing 'Prednisone'
GROUP BY med_code
ORDER BY num_med DESC
LIMIT 10;
        """))
    for row in result:
        print("name and count: ", row)

num with LOWER:  (5147,)
num without LOWER:  (1,)
name and count:  ('PredniSONE', 5145)
name and count:  ('Prednisone', 1)
name and count:  ('predniSONE', 1)


In [38]:
# Another issue is the date:
with spark_query_engine.connect() as connection:
    result = connection.execute(text(
        """
        SELECT COUNT(DISTINCT med.id)
FROM medication_request_mimic_flat AS med
JOIN encounter_mimic_flat AS enc ON med.encounter_id = enc.id -- Join the tables to link medication requests to encounters
WHERE LOWER(med.medication_code) LIKE '%prednisone%' -- Filter for medication requests containing 'Prednisone'
AND SUBSTR(enc.period_start, 1, 4) = '2100'; -- Filter for encounters where the period_start year is 2100
        """))
    for row in result:
        print("num_med in 2100: ", row)

    result = connection.execute(text(
        """
        SELECT SUBSTR(enc.period_start, 1, 4) AS med_year, COUNT(DISTINCT med.id) AS num_med
FROM medication_request_mimic_flat AS med
JOIN encounter_mimic_flat AS enc ON med.encounter_id = enc.id -- Join the tables to link medication requests to encounters
WHERE LOWER(med.medication_code) LIKE '%prednisone%' -- Filter for medication requests containing 'Prednisone'
GROUP BY med_year
ORDER BY num_med DESC
LIMIT 10;
        """))
    for row in result:
        print("row: ", row)        

num_med in 2100:  (0,)
row:  ('2141', 91)
row:  ('2147', 86)
row:  ('2165', 83)
row:  ('2173', 82)
row:  ('2131', 81)
row:  ('2151', 80)
row:  ('2150', 78)
row:  ('2170', 75)
row:  ('2133', 74)
row:  ('2164', 74)


In [39]:
# Same query with LOWER fix and a different year:

with spark_query_engine.connect() as connection:
    result = connection.execute(text(
        """
        SELECT COUNT(DISTINCT med.id)
FROM medication_request_mimic_flat AS med
JOIN encounter_mimic_flat AS enc ON med.encounter_id = enc.id -- Join the tables to link medication requests to encounters
WHERE LOWER(med.medication_code) LIKE '%prednisone%' -- Filter for medication requests containing 'Prednisone'
AND SUBSTR(enc.period_start, 1, 4) = '2150'; -- Filter for encounters where the period_start year is 2150
        """))
    for row in result:
        print("Number of Prednisone meds in 2150: ", row)

Number of Prednisone meds in 2150:  (78,)


In [105]:
# Question ID: d1d8f2df3141cf92ddc53228
# Expected answer: [['4']]
query_instance_10_07_20.iterative_gen_sql(
    query='''
    How many prednisone prescription cases were there in 2150?
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "prednisone prescription" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            []

            
            For "2150" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            []

            
Iteration 0 SQL: 

SELECT COUNT(*)
FROM medication_request_mimic_flat
WHERE medication_code LIKE '%Prednisone%' -- To find prednisone prescriptions
AND encounter_id IN (SELECT id FROM encounter_mimic_flat WHERE SUBSTRING(period_start, 1, 4) = '2150'); -- To filter encounters that happened in the year 2150

Iteration 0 SQL execution output: 
START_OUTPUT
(0,)
END_OUTPUT
Iteration 1 SQL: 

SELECT COUNT(DISTINCT T1.id)
FROM medication_request_mimic_flat AS T1
INNER JOIN encounter_mimic_flat AS T2 ON T1.encounter_id = T2.id
WHERE T1.medication_c

### Mammary-coronary artery bypass after arthropathy

Query in the evaluation set (written against the original MIMIC-IV schema, in SQLite dialect):
```sql
SELECT COUNT(DISTINCT T1.subject_id)
FROM (
    SELECT admissions.subject_id, diagnoses_icd.charttime
    FROM diagnoses_icd
    JOIN admissions ON diagnoses_icd.hadm_id = admissions.hadm_id
    WHERE diagnoses_icd.icd_code = (
        SELECT d_icd_diagnoses.icd_code
        FROM d_icd_diagnoses
        WHERE d_icd_diagnoses.long_title = 'arthropathy, unspecified, lower leg'
    )
    AND strftime('%Y', diagnoses_icd.charttime) = '2100'
) AS T1
JOIN (
    SELECT admissions.subject_id, procedures_icd.charttime
    FROM procedures_icd
    JOIN admissions ON procedures_icd.hadm_id = admissions.hadm_id
    WHERE procedures_icd.icd_code = (
        SELECT d_icd_procedures.icd_code
        FROM d_icd_procedures
        WHERE d_icd_procedures.long_title = 'single internal mammary-coronary artery bypass'
    )
    AND strftime('%Y', procedures_icd.charttime) = '2100'
) AS T2
ON T1.subject_id = T2.subject_id
WHERE T1.charttime < T2.charttime
  AND datetime(T1.charttime, 'start of month') = datetime(T2.charttime, 'start of month')
```
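The core of this query is the "same month, after" condition. A diagnosis must come strictly before the procedure (`T1.charttime < T2.charttime`), both must fall in the same month (SQLite's `datetime(..., 'start of month')`), and both must be in the target year. Below is a minimal pure-Python sketch of that logic over in-memory `(subject_id, timestamp)` pairs. It can be handy for hand-checking a generated query on a small sample. The function name and record layout are hypothetical.

```python
from datetime import datetime


def count_same_month_after(diagnoses, procedures, year):
    """Count distinct subjects with a diagnosis strictly before a procedure in
    the same calendar month, with both events in `year`.

    diagnoses, procedures: iterables of (subject_id, datetime) pairs
    (hypothetical layout, standing in for the T1 and T2 subqueries).
    """
    # Diagnosis times per subject, restricted to the target year (like T1).
    diag_by_subject = {}
    for subject_id, ts in diagnoses:
        if ts.year == year:
            diag_by_subject.setdefault(subject_id, []).append(ts)

    matched = set()
    for subject_id, proc_ts in procedures:
        if proc_ts.year != year:  # like the year filter on T2
            continue
        for diag_ts in diag_by_subject.get(subject_id, []):
            same_month = (diag_ts.year, diag_ts.month) == (proc_ts.year, proc_ts.month)
            if diag_ts < proc_ts and same_month:  # T1.charttime < T2.charttime
                matched.add(subject_id)
                break
    return len(matched)  # COUNT(DISTINCT T1.subject_id)


diagnoses = [(1, datetime(2150, 3, 2)), (2, datetime(2150, 3, 20)), (3, datetime(2150, 5, 1))]
procedures = [(1, datetime(2150, 3, 10)), (2, datetime(2150, 3, 5)), (3, datetime(2150, 6, 1))]
print(count_same_month_after(diagnoses, procedures, 2150))  # 1
```

Subject 1 qualifies. Subject 2 had the procedure before the diagnosis, and subject 3's events fall in different months.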

In [43]:
# Question ID: 15a83f8dbbf6d2243721a87b
# Expected answer: [['1']]
query_instance_10_07_20.iterative_gen_sql(
    query='''
    How many patients underwent single internal mammary-coronary artery bypass during the same
    month after the diagnosis with arthropathy, unspecified, lower leg, in 2150?
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "single internal mammary-coronary artery bypass" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('procedure_mimic_flat', 'code_code', '3615', 'Single internal mammary-coronary artery bypass'), ('procedure_mimic_flat', 'code_code', '3616', 'Double internal mammary-coronary artery bypass'), ('procedure_mimic_flat', 'code_code', '02100A8', 'Bypass Coronary Artery, One Artery from Right Internal Mammary with Autologous Arterial Tissue, Open Approach'), ('procedure_mimic_flat', 'code_code', '02100Z8', 'Bypass Coronary Artery, One Artery from Right Internal Mammary, Open Approach'), ('procedure_mimic_flat', 'code_code', '02100A9', 'Bypass Coronary Artery, One Artery from Left Internal Mammary with Autologous Arterial Tissue, Open Approach'), ('procedure_mimic_flat', 'code_code', '02100Z9', 'Bypass Coronary Artery, One Artery from Left Int

In [62]:
# The problem is that most procedures do not have the period fields:
pd.read_sql_query(
    sql="""
    SELECT SUBSTR(c.onset_datetime, 1, 4) AS cond_year, COUNT(DISTINCT p.patient_id) AS num_patients
FROM condition_mimic_flat c
JOIN procedure_mimic_flat p ON c.patient_id = p.patient_id
WHERE 
  -- SUBSTR(c.onset_datetime, 1, 4) = '2150'  -- diagnoses should be in 2150
c.code_code = '71686'  -- NOTE: in ICD-9-CM, 71686 is 'Other specified arthropathy, lower leg'; 'Arthropathy, unspecified, lower leg' is 71696
-- AND SUBSTR(c.onset_datetime, 6, 2) = SUBSTR(p.period_start, 6, 2)  -- ensure the procedures and diagnosis are within the same month
AND p.code_code IN ('3615', '02100A8', '02100Z8', '02100A9', '02100Z9', '0210098', '0210099', '3611')  -- procedure is single internal mammary-coronary artery bypass
GROUP BY cond_year
ORDER BY num_patients DESC
LIMIT 10;
    """,
    con=spark_query_engine,
)

Unnamed: 0,cond_year,num_patients


In [63]:
pd.read_sql_query(
    sql="""
    SELECT SUBSTR(c.onset_datetime, 1, 4) AS cond_year, COUNT(*) AS num_cond
FROM condition_mimic_flat c
WHERE c.code_code = '71686'  -- ICD-9-CM 'Other specified arthropathy, lower leg' (the 'unspecified' code is 71696)
GROUP BY cond_year
ORDER BY num_cond DESC
LIMIT 10;
    """,
    con=spark_query_engine,
)

Unnamed: 0,cond_year,num_cond
0,,2


#### Tweaking closeness thresholds

In [64]:
query_instance_10_06_30 = sqlgen.SqlGen(
    target_db_url='hive://localhost:10001/default',
    pg_vector_db_url='postgresql://postgres:admin@localhost:5438/codevec',
    nlp_ner=nlp_ner,
    num_column_samples=10,
    close_threshold=0.6,
    max_close=30
)
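The experiments below vary `close_threshold` and `max_close`. We have not checked here exactly how `sqlgen.SqlGen` applies them. One plausible reading, assumed for this sketch only, works as follows. A candidate code is kept when the cosine similarity between its embedding and the extracted concept's embedding is at least `close_threshold`. The kept candidates are sorted best first and capped at `max_close`. If querygen compares a distance instead, the comparison flips (pgvector's `<=>` operator, for instance, returns cosine distance). `close_concepts` and `cosine_similarity` are hypothetical names.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def close_concepts(query_vec, candidates, close_threshold, max_close):
    """Hypothetical filter: keep candidates with similarity >= close_threshold,
    best first, truncated to max_close entries.

    candidates: list of (code, display, embedding) tuples.
    """
    scored = [(cosine_similarity(query_vec, emb), code, display)
              for code, display, emb in candidates]
    kept = [item for item in scored if item[0] >= close_threshold]
    kept.sort(key=lambda item: item[0], reverse=True)
    return [(code, display) for _, code, display in kept[:max_close]]


candidates = [
    ("A", "exact", [1.0, 0.0]),
    ("B", "near", [0.8, 0.6]),  # similarity ~0.8 to [1, 0]
    ("C", "far", [0.0, 1.0]),   # similarity 0.0 to [1, 0]
]
print(close_concepts([1.0, 0.0], candidates, close_threshold=0.7, max_close=30))
print(close_concepts([1.0, 0.0], candidates, close_threshold=0.9, max_close=30))
```

Under this reading, raising the threshold from 0.6 to 0.7 can only shrink the candidate list, while `max_close` bounds how many codes end up in the prompt.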

In [65]:
query_instance_10_06_30.iterative_gen_sql(
    query='''
    How many patients underwent single internal mammary-coronary artery bypass during the same
    month after the diagnosis with arthropathy, unspecified, lower leg, in 2150?
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "single internal mammary-coronary artery bypass" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('procedure_mimic_flat', 'code_code', '3615', 'Single internal mammary-coronary artery bypass'), ('procedure_mimic_flat', 'code_code', '3616', 'Double internal mammary-coronary artery bypass'), ('procedure_mimic_flat', 'code_code', '02100A8', 'Bypass Coronary Artery, One Artery from Right Internal Mammary with Autologous Arterial Tissue, Open Approach'), ('procedure_mimic_flat', 'code_code', '02100Z8', 'Bypass Coronary Artery, One Artery from Right Internal Mammary, Open Approach'), ('procedure_mimic_flat', 'code_code', '02100A9', 'Bypass Coronary Artery, One Artery from Left Internal Mammary with Autologous Arterial Tissue, Open Approach'), ('procedure_mimic_flat', 'code_code', '02100Z9', 'Bypass Coronary Artery, One Artery from Left Int

In [66]:
query_instance_10_07_30 = sqlgen.SqlGen(
    target_db_url='hive://localhost:10001/default',
    pg_vector_db_url='postgresql://postgres:admin@localhost:5438/codevec',
    nlp_ner=nlp_ner,
    num_column_samples=10,
    close_threshold=0.7,
    max_close=30
)

In [67]:
query_instance_10_07_30.iterative_gen_sql(
    query='''
    How many patients underwent single internal mammary-coronary artery bypass during the same
    month after the diagnosis with arthropathy, unspecified, lower leg, in 2150?
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "single internal mammary-coronary artery bypass" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('procedure_mimic_flat', 'code_code', '3615', 'Single internal mammary-coronary artery bypass'), ('procedure_mimic_flat', 'code_code', '3616', 'Double internal mammary-coronary artery bypass'), ('procedure_mimic_flat', 'code_code', '02100A8', 'Bypass Coronary Artery, One Artery from Right Internal Mammary with Autologous Arterial Tissue, Open Approach'), ('procedure_mimic_flat', 'code_code', '02100Z8', 'Bypass Coronary Artery, One Artery from Right Internal Mammary, Open Approach'), ('procedure_mimic_flat', 'code_code', '02100A9', 'Bypass Coronary Artery, One Artery from Left Internal Mammary with Autologous Arterial Tissue, Open Approach'), ('procedure_mimic_flat', 'code_code', '02100Z9', 'Bypass Coronary Artery, One Artery from Left Int

In [68]:
pd.read_sql_query(
    sql="""
    SELECT SUBSTR(c.onset_datetime, 1, 4) AS cond_year, COUNT(*) AS num_cond
FROM condition_mimic_flat c
WHERE c.code_code = '71696'
GROUP BY cond_year
ORDER BY num_cond DESC
LIMIT 10;
    """,
    con=spark_query_engine,
)

Unnamed: 0,cond_year,num_cond
0,,314


In [72]:
# This is the same issue: procedure dates are stored in `performed` (as a dateTime), not in `period`, so `period_start` is empty.
pd.read_sql_query(
    sql="""
    SELECT SUBSTR(pr.period_start, 1, 7) AS month, COUNT(DISTINCT patient_id) AS num_patients
FROM procedure_mimic_flat pr
WHERE pr.code_code IN ('3615', '3616', '02100A8', '02100Z8', '02100A9', '02100Z9', '0210098', '0210099',
  '02110Z8', '02110A8', '3611', '02104Z9', '02110A9', '02110Z9', '02120Z8', '02120A8', '02130Z8', '02120A9',
  '02120Z9', '0211099', '0213099', '0210499', '0212099', '3612', '03U007Z', '03U107Z', '3613',
  '03L10ZZ', '03B10ZZ', '3617')  -- Single internal mammary-coronary artery bypass
-- AND SUBSTR(pr.period_start, 1, 7) = SUBSTR(c.onset_datetime, 1, 7)  -- Procedure and condition in the same month
GROUP BY month
ORDER BY num_patients DESC
LIMIT 20
;
    """,
    con=spark_query_engine,
)

Unnamed: 0,month,num_patients
0,,5487
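In FHIR R4, `Procedure.performed[x]` is a choice element. Its JSON form is one of `performedDateTime`, `performedPeriod`, `performedString`, `performedAge` or `performedRange`. Suppose the MIMIC procedures carry `performedDateTime` while the view's `period_start` comes from the Period form. Then every month above would be empty, which matches the output. Below is a minimal sketch over raw Procedure JSON that prefers `performedDateTime` and falls back to `performedPeriod.start`. The helper names are hypothetical, and this is not how fhir-data-pipes builds its flat views.

```python
from typing import Optional


def procedure_start(resource: dict) -> Optional[str]:
    """Return the procedure's start time as an ISO string, preferring
    performedDateTime and falling back to performedPeriod.start."""
    if resource.get("performedDateTime"):
        return resource["performedDateTime"]
    period = resource.get("performedPeriod") or {}
    return period.get("start")


def procedure_month(resource: dict) -> Optional[str]:
    """'YYYY-MM' prefix of the start time, like SUBSTR(..., 1, 7) in the queries."""
    start = procedure_start(resource)
    return start[:7] if start else None


proc_a = {"resourceType": "Procedure", "performedDateTime": "2150-03-10T12:00:00-04:00"}
proc_b = {"resourceType": "Procedure", "performedPeriod": {"start": "2150-03-10", "end": "2150-03-11"}}
proc_c = {"resourceType": "Procedure"}
print([procedure_start(p) for p in (proc_a, proc_b, proc_c)])
print([procedure_month(p) for p in (proc_a, proc_b, proc_c)])
```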


#### After adding `performed.dateTime` to `procedure_mimic_flat`

In [82]:
query_instance_10_07_30.iterative_gen_sql(
    query='''
    How many patients underwent single internal mammary-coronary artery bypass during the same
    month after the diagnosis with arthropathy, unspecified, lower leg, in 2150?
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "single internal mammary-coronary artery bypass" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('procedure_mimic_flat', 'code_code', '3615', 'Single internal mammary-coronary artery bypass'), ('procedure_mimic_flat', 'code_code', '3616', 'Double internal mammary-coronary artery bypass'), ('procedure_mimic_flat', 'code_code', '02100A8', 'Bypass Coronary Artery, One Artery from Right Internal Mammary with Autologous Arterial Tissue, Open Approach'), ('procedure_mimic_flat', 'code_code', '02100Z8', 'Bypass Coronary Artery, One Artery from Right Internal Mammary, Open Approach'), ('procedure_mimic_flat', 'code_code', '02100A9', 'Bypass Coronary Artery, One Artery from Left Internal Mammary with Autologous Arterial Tissue, Open Approach'), ('procedure_mimic_flat', 'code_code', '02100Z9', 'Bypass Coronary Artery, One Artery from Left Int

In [13]:
pd.read_sql_query(
    sql="""
    SELECT SUBSTR(c.onset_datetime, 1, 4) AS year, COUNT(DISTINCT c.patient_id)
FROM condition_mimic_flat c
JOIN procedure_mimic_flat pr ON c.patient_id = pr.patient_id
WHERE c.code_code = '71696'
-- AND SUBSTR(c.onset_datetime, 1, 4) = '2150'
-- AND SUBSTR(pr.period_start, 1, 7) = SUBSTR(c.onset_datetime, 1, 7)
AND pr.code_code IN ('3615', '02100A8', '02100Z8', '02100A9', '02100Z9', '0210098', '0210099', '3611', '02104Z9', '0210499')
GROUP BY year
ORDER BY year
;
    """,
    con=spark_query_engine,
)

Unnamed: 0,year,count(DISTINCT patient_id)
0,,15


In [14]:
pd.read_sql_query(
    sql="""
    SELECT SUBSTR(c.onset_datetime, 1, 4) AS year, COUNT(*)
FROM condition_mimic_flat c
-- WHERE c.code_code = '71696'
-- AND SUBSTR(c.onset_datetime, 1, 4) = '2150'
-- AND SUBSTR(pr.period_start, 1, 7) = SUBSTR(c.onset_datetime, 1, 7)
-- AND pr.code_code IN ('3615', '02100A8', '02100Z8', '02100A9', '02100Z9', '0210098', '0210099', '3611', '02104Z9', '0210499')
GROUP BY year
ORDER BY year
;
    """,
    con=spark_query_engine,
)

Unnamed: 0,year,count(1)
0,,5655376
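Grouping by `SUBSTR(..., 1, 4)` reveals missing dates only indirectly, as a single NULL group. A cheaper check relies on a standard SQL rule: `COUNT(column)` skips NULLs, while `COUNT(*)` counts every row. The sketch below builds that query. `date_coverage_sql` is a hypothetical helper, and it interpolates identifiers directly, so pass only trusted table and column names. If a column holds empty strings rather than NULLs, those rows would still be counted as present.

```python
import sqlite3


def date_coverage_sql(table: str, column: str) -> str:
    """Query reporting total rows and rows where `column` is non-NULL."""
    return (
        f"SELECT COUNT(*) AS total_rows, COUNT({column}) AS non_null_rows "
        f"FROM {table}"
    )


print(date_coverage_sql("condition_mimic_flat", "onset_datetime"))

# Quick local check of the COUNT(*) vs COUNT(column) semantics with SQLite.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE demo (onset_datetime TEXT)")
conn.executemany("INSERT INTO demo VALUES (?)", [(None,), ("2150-03-02",), (None,)])
print(conn.execute(date_coverage_sql("demo", "onset_datetime")).fetchone())  # (3, 1)
```

Against the notebook's engine, the generated string can be passed to `pd.read_sql_query(sql=..., con=spark_query_engine)` like the other queries here.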


In [26]:
pd.read_sql_query(
    sql="""
    SELECT id, onset
FROM condition_mimic c
LIMIT 5
;
    """,
    con=spark_query_engine,
)

Unnamed: 0,id,onset
0,283c04d1-a092-58da-8118-2374dd7e37fc,
1,916c70a6-56a3-55c4-8a54-a5bd3f67dd69,
2,789d8b3e-d873-584c-8e76-fa43b46f4696,
3,ff0fbb16-eb00-5d58-9570-1e25d5463ffe,
4,c0878bb9-f070-5dc4-9003-884b03ca2534,


# Appendix: Full prompt
This section includes some examples that show the full prompt. They are moved
here to reduce clutter.

In [64]:
query_instance_10_07_20.iterative_gen_sql(
    query='''
    For encounters in which the blood oxygen level of the patient is measured through both
    pulse oximeter (SpO2) and Arterial Blood Gas (ABG) test, find the average of for each
    method and count the number of those that the difference between those two average values
    is at most 2% vs those that it is greater than 2%.
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "the blood oxygen level" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('observation_mimic_flat', 'code_code', '2708-6', 'Oxygen saturation in Arterial blood'), ('observation_mimic_flat', 'code_code', '220224', 'Arterial O2 pressure'), ('observation_mimic_flat', 'code_code', '50817', 'Oxygen Saturation'), ('observation_mimic_flat', 'code_code', '220227', 'Arterial O2 Saturation'), ('observation_mimic_flat', 'code_code', '50821', 'pO2'), ('observation_mimic_flat', 'code_code', '226063', 'Venous O2 Pressure'), ('observation_mimic_flat', 'code_code', '220277', 'O2 saturation pulseoxymetry'), ('observation_mimic_flat', 'code_code', '50816', 'Oxygen'), ('condition_mimic_flat', 'code_code', 'R0902', 'Hypoxemia'), ('condition_mimic_flat', 'code_code', '79902', 'Hypoxemia'), ('observation_mimic_flat', 'code_code', '50823', 'Required O2'), 
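The sketch below makes the intended semantics of this question concrete, independent of any generated SQL. It is pure Python over `(encounter_id, code, value)` tuples. It assumes, from the close concepts listed, that SpO2 is code `220277` ('O2 saturation pulseoxymetry') and that the ABG value is `220227` ('Arterial O2 Saturation'). It also reads "2%" as 2 percentage points. The function name is hypothetical.

```python
from collections import defaultdict
from statistics import mean

SPO2_CODE = "220277"  # 'O2 saturation pulseoxymetry' (assumed to be SpO2)
ABG_CODE = "220227"   # 'Arterial O2 Saturation' (assumed to be the ABG value)


def spo2_abg_agreement(observations, tolerance=2.0):
    """observations: iterable of (encounter_id, code, value) tuples.

    For encounters with both SpO2 and ABG values, compare the per-method
    averages and count encounters whose absolute difference is within or
    beyond `tolerance` percentage points.
    """
    values = defaultdict(lambda: defaultdict(list))
    for encounter_id, code, value in observations:
        if code in (SPO2_CODE, ABG_CODE):
            values[encounter_id][code].append(value)

    within = beyond = 0
    for by_code in values.values():
        if SPO2_CODE in by_code and ABG_CODE in by_code:  # both methods measured
            diff = abs(mean(by_code[SPO2_CODE]) - mean(by_code[ABG_CODE]))
            if diff <= tolerance:
                within += 1
            else:
                beyond += 1
    return {"within_tolerance": within, "beyond_tolerance": beyond}


obs = [
    ("e1", SPO2_CODE, 97.0), ("e1", SPO2_CODE, 99.0), ("e1", ABG_CODE, 97.5),
    ("e2", SPO2_CODE, 95.0), ("e2", ABG_CODE, 91.0),
    ("e3", SPO2_CODE, 98.0),  # no ABG value, so e3 is excluded
]
print(spo2_abg_agreement(obs))  # {'within_tolerance': 1, 'beyond_tolerance': 1}
```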

In [67]:
query_instance_10_07_20.iterative_gen_sql(
    query='''
    Count the number of patients that within an hour, their blood oxygen level is
    measured through both pulse oximeter (SpO2) and Arterial Blood Gas (ABG) test
    and the difference of those two measurements is more than 2%.
    ''',
    table_suffix='_mimic_flat',
    num_rounds=5
)

CLOSE CONCEPTS: 

            For "their blood oxygen level" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            []

            
            For "pulse oximeter" this is the list of relevant codes and their description, the format
            of list elements is ('table_name', 'code_column_name', 'code', 'display'):
            [('procedure_mimic_flat', 'code_code', '7538', 'Fetal pulse oximetry'), ('observation_mimic_flat', 'code_code', '220277', 'O2 saturation pulseoxymetry'), ('observation_mimic_flat', 'code_code', '50821', 'pO2'), ('observation_mimic_flat', 'code_code', '223770', 'O2 Saturation Pulseoxymetry Alarm - Low'), ('observation_mimic_flat', 'code_code', '223769', 'O2 Saturation Pulseoxymetry Alarm - High'), ('procedure_mimic_flat', 'code_code', '8963', 'Pulmonary artery pressure monitoring'), ('observation_mimic_flat', 'code_code', '50832', 'pO2, Body Fluid'