<!-- Original Implementation by Gyubok Lee -->
<!-- Refined by Sunjun Kweon on 2024-01-15. -->
<!-- Refined by Woosog Chay on 2025-02-21. -->
<!-- Note: This Jupyter notebook is tailored to the unique requirements of the EHRSQL project. It includes specific modifications and additional adjustments to cater to the dataset and experiment objectives. -->

# OpenAI Model (GPT-4o-mini) Sample Code for MIMIC-IV: Single-Turn Text-to-SQL with Abstention on Electronic Health Records


<!-- ## Task Introduction
The goal of the task is to **develop a reliable text-to-SQL system on EHR**. Unlike standard text-to-SQL tasks, this system must handle all types of questions, including answerable and unanswerable ones with respect to the EHR database structure. For answerable questions, the system must accurately generate SQL queries. For unanswerable questions, the system must correctly identify them as such, thereby preventing incorrect SQL predictions for infeasible questions. The range of questions includes answerable queries about MIMIC-IV, covering topics such as patient demographics, vital signs, and specific disease survival rates ([EHRSQL](https://github.com/glee4810/EHRSQL)). Additionally, there are specially designed unanswerable questions intended to challenge the system. Successfully completing this task will result in the creation of a reliable question-answering system for EHRs, significantly improving the flexibility and efficiency of clinical knowledge exploration in hospitals. -->

## Steps of Baseline Code

- [x] Step 0: Prerequisites (OpenAI API)
- [x] Step 1: Clone the GitHub Repository and Install Dependencies
- [x] Step 2: Import Global Packages and Define File Paths
- [x] Step 3: Load Data and Prepare Datasets
- [x] Step 4: Build a Predictive Model Using ChatGPT
- [x] Step 5: Evaluation


## Step 0: Prerequisites (OpenAI API Key)



In [1]:
import getpass
from IPython.display import clear_output

clear_output()
# Please enter your API key
new_api_key = ''
while len(new_api_key) == 0:
    new_api_key = getpass.getpass("Please input your API key: ")
    clear_output()

When you submit your code for code verification, include your OpenAI API key with it, stored in the same JSON format as `sample_submission_chatgpt_api_key.json` (the file written at the start of Step 4).
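The key file is a small JSON object with a single `key` field, which is the format the notebook writes in Step 4. Below is a minimal sketch of writing and reading such a file. It assumes only that format. The helper names `save_api_key` and `load_api_key` are hypothetical, and the file lives in a temporary directory with a placeholder value instead of a real key.

```python
import json
import os
import tempfile

def save_api_key(path, key):
    """Write the key as {"key": ...}, mirroring sample_submission_chatgpt_api_key.json."""
    with open(path, "w") as f:
        json.dump({"key": key}, f, indent=4)

def load_api_key(path):
    """Return the stored key, raising if the file is missing or the key is empty."""
    if not os.path.isfile(path):
        raise FileNotFoundError(f"No API key file found at {path}")
    with open(path, "r") as f:
        key = json.load(f).get("key", "")
    if not key:
        raise ValueError("The API key file exists but the key is empty.")
    return key

# Usage with a throwaway directory and a placeholder string (not a real key)
with tempfile.TemporaryDirectory() as tmp:
    key_path = os.path.join(tmp, "api_key.json")
    save_api_key(key_path, "sk-placeholder")
    loaded = load_api_key(key_path)
print(loaded)
```

Failing early on an empty key gives a clear error message. Otherwise the first API request would fail later with a less obvious authentication error.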

## Step 1: Clone the GitHub Repository and Install Dependencies

Before you begin, make sure you are in the correct directory. The following cell moves to `/content` and removes any existing copy of the repository, so that the next step starts from a fresh clone:

In [2]:
%cd /content
!rm -rf ai612_project_1

/content


Now, clone the repository and install the required Python packages:

In [3]:
# Cloning the GitHub repository
!git clone -q https://github.com/benchay1999/ai612_project_1.git
%cd ai612_project_1

# Installing dependencies
!pip install -q tiktoken openai func_timeout

/content/ai612_project_1


Load the `autoreload` extension with `%load_ext` and enable it with `%autoreload 2`, so that edited modules are reloaded automatically before each cell runs:

In [4]:
%load_ext autoreload
%autoreload 2

## Step 2: Import Global Packages and Define File Paths

After setting up the repository and dependencies, the next step is to import packages that will be used globally throughout this notebook and to define the file paths to our datasets.

In [5]:
import os
import json
import pandas as pd
from tqdm import tqdm

# Directory paths for database, results and scoring program
DB_ID = 'mimic_iv'
BASE_DATA_DIR = 'data'
RESULT_DIR = 'results'
SCORING_DIR = 'scoring'

# File paths for the dataset and labels
TABLES_PATH = os.path.join('database', 'tables.json')               # JSON containing database schema
VALID_DATA_PATH = os.path.join(BASE_DATA_DIR, 'valid_data.json')    # JSON file for validation data
VALID_LABEL_PATH = os.path.join(BASE_DATA_DIR, 'valid_label.json')  # JSON file for validation labels (for evaluation)
DB_PATH = os.path.join('data', DB_ID, f'{DB_ID}.sqlite')            # Database path

# Load data
with open(VALID_DATA_PATH, 'r') as f:
    valid_data = json.load(f)

with open(VALID_LABEL_PATH, 'r') as f:
    valid_labels = json.load(f)

# Load SQL assumptions for MIMIC-IV
with open(os.path.join('database', 'mimic_iv_assumption.txt'), 'r') as f:
    assumptions = f.read()

In [6]:
print(assumptions)

- Use SQLite for SQL query generation.
- Use DENSE_RANK() when asked for ranking results, but retrieve only the relevant items, excluding their counts or ranks. When the question does not explicitly mention ranking, don't use DENSE_RANK().
- For the top N results, return only the relevant items, excluding their counts.
- Use DISTINCT in queries related to the cost of events, drug routes, or when counting or listing patients or hospital/ICU visits.
- When calculating the total cost, sum the patient’s diagnoses, procedures, lab events, and prescription costs within a single hospital admission only.
- Use DISTINCT to retrieve the cost of a single event (diagnosis, procedure, lab event, or prescription).
- For cost-related questions, use cost.event_type to specify the event type ('procedures_icd', 'labevents', 'prescriptions', 'diagnoses_icd') when retrieving costs for procedures, lab events, prescriptions, or diagnoses, respectively.
- Treat questions that start with "is it possible," "ca

In [7]:
# This function loads and processes a database schema from a JSON file.

def load_schema(DATASET_JSON):
    schema_df = pd.read_json(DATASET_JSON)
    schema_df = schema_df.drop(['column_names','table_names'], axis=1)
    schema = []
    f_keys = []
    p_keys = []
    for _, row in schema_df.iterrows():
        tables = row['table_names_original']
        col_names = row['column_names_original']
        col_types = row['column_types']
        foreign_keys = row['foreign_keys']
        primary_keys = row['primary_keys']
        # Each column is [table_index, column_name]; table_index -1 marks the "*" wildcard
        for (table_idx, col_name), col_type in zip(col_names, col_types):
            if table_idx > -1:
                schema.append([row['db_id'], tables[table_idx], col_name, col_type])
        for primary_key in primary_keys:
            table_idx, column = col_names[primary_key]
            p_keys.append([row['db_id'], tables[table_idx], column])
        for first, second in foreign_keys:
            first_index, first_column = col_names[first]
            second_index, second_column = col_names[second]
            f_keys.append([row['db_id'], tables[first_index], tables[second_index], first_column, second_column])
    db_schema = pd.DataFrame(schema, columns=['Database name', 'Table Name', 'Field Name', 'Type'])
    primary_key = pd.DataFrame(p_keys, columns=['Database name', 'Table Name', 'Primary Key'])
    foreign_key = pd.DataFrame(f_keys,
                        columns=['Database name', 'First Table Name', 'Second Table Name', 'First Table Foreign Key',
                                 'Second Table Foreign Key'])
    return db_schema, primary_key, foreign_key

# Generates a string representation of foreign key relationships in a MySQL-like format for a specific database.
def find_foreign_keys_MYSQL_like(foreign, db_id):
    df = foreign[foreign['Database name'] == db_id]
    output = "["
    for index, row in df.iterrows():
        output += row['First Table Name'] + '.' + row['First Table Foreign Key'] + " = " + row['Second Table Name'] + '.' + row['Second Table Foreign Key'] + ', '
    output = output[:-2] + "]"
    if len(output)==1:
        output = '[]'
    return output

# Creates a string representation of the fields (columns) in each table of a specific database, formatted in a MySQL-like syntax.
def find_fields_MYSQL_like(db_schema, db_id):
    df = db_schema[db_schema['Database name'] == db_id]
    df = df.groupby('Table Name')
    output = ""
    for name, group in df:
        output += "Table " +name+ ', columns = ['
        for index, row in group.iterrows():
            output += row["Field Name"]+', '
        output = output[:-2]
        output += "]\n"
    return output

# Generates a textual prompt describing the database schema (tables, columns, and foreign key relationships),
# then appends the SQL assumptions for MIMIC-IV.
def create_schema_prompt(db_id, db_schema, primary_key, foreign_key, is_lower=True):
    prompt = find_fields_MYSQL_like(db_schema, db_id)
    prompt += "Foreign_keys = " + find_foreign_keys_MYSQL_like(foreign_key, db_id)
    if is_lower:
        prompt = prompt.lower()
    prompt += "\nSQL Assumptions that you must follow:\n" + assumptions
    return prompt
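The three helpers above turn `tables.json` into one table line per table plus a foreign-key line, and then lowercase the result. To make that format concrete, here is a self-contained sketch that rebuilds the same layout in plain Python for a hypothetical two-table schema. The schema is made up for illustration, and `build_schema_prompt` is a hypothetical helper that does not append the SQL assumptions.

```python
# Hypothetical two-table schema in the tables.json layout that load_schema() reads
toy_db = {
    "db_id": "toy_db",
    "table_names_original": ["patients", "admissions"],
    "column_names_original": [[-1, "*"], [0, "subject_id"], [0, "gender"],
                              [1, "hadm_id"], [1, "subject_id"]],
    "column_types": ["text", "number", "text", "number", "number"],
    "primary_keys": [1, 3],
    "foreign_keys": [[4, 1]],
}

def build_schema_prompt(db, is_lower=True):
    tables = db["table_names_original"]
    columns = {}
    for table_idx, col_name in db["column_names_original"]:
        if table_idx > -1:  # skip the wildcard "*" entry
            columns.setdefault(tables[table_idx], []).append(col_name)
    # Tables sorted by name, as pandas groupby does in find_fields_MYSQL_like()
    lines = [f"Table {t}, columns = [{', '.join(cols)}]" for t, cols in sorted(columns.items())]
    fks = []
    for first, second in db["foreign_keys"]:
        t1, c1 = db["column_names_original"][first]
        t2, c2 = db["column_names_original"][second]
        fks.append(f"{tables[t1]}.{c1} = {tables[t2]}.{c2}")
    prompt = "\n".join(lines) + "\nForeign_keys = [" + ", ".join(fks) + "]"
    return prompt.lower() if is_lower else prompt

print(build_schema_prompt(toy_db))
```

This sketch prints `table admissions, columns = [hadm_id, subject_id]`, then the `patients` line, then `foreign_keys = [admissions.subject_id = patients.subject_id]`, which matches the shape of the MIMIC-IV prompt printed in Step 3.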

## Step 3: Load Data and Prepare Datasets

Now that the environment and paths are set up, the next step is to parse the database schema from `tables.json` and build the schema prompt (tables, columns, foreign keys, and the SQL assumptions) that will be given to the model. The validation data and labels were already loaded in Step 2.

In [8]:
from utils import read_json as read_data

db_schema, primary_key, foreign_key = load_schema(TABLES_PATH)

table_prompt = create_schema_prompt(DB_ID, db_schema, primary_key, foreign_key)

### Data Statistics

In [9]:
print("Valid data:", (len(valid_data['data']), len(valid_labels)))

Valid data: (20, 20)
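The counts match, but equal lengths do not guarantee that every question has a label. The following sketch compares the ids directly. It assumes, as the printed examples suggest, that each sample in `data` has an `id` and that the label file is a dictionary keyed by those same ids. `check_alignment` and the toy inputs are hypothetical.

```python
def check_alignment(data, labels):
    """Return ids present in only one of the two files (both lists should be empty)."""
    data_ids = {sample["id"] for sample in data["data"]}
    label_ids = set(labels)
    return sorted(data_ids - label_ids), sorted(label_ids - data_ids)

# Hypothetical miniature versions of valid_data.json and valid_label.json
toy_data = {"version": "toy", "data": [{"id": "a1", "question": "q1"},
                                       {"id": "b2", "question": "q2"}]}
toy_labels = {"a1": "SELECT 1", "c3": "null"}

missing_labels, extra_labels = check_alignment(toy_data, toy_labels)
print(missing_labels, extra_labels)
```

In the notebook, `check_alignment(valid_data, valid_labels)` should return two empty lists.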


### Data Format

Before building the model, it is a good idea to explore the dataset: check the top-level keys of the data file and inspect an example label to understand how questions and gold SQL queries are stored.



In [10]:
# Explore keys and data structure
print(valid_data.keys())
print(valid_labels[list(valid_labels.keys())[0]])

dict_keys(['version', 'data'])
SELECT AVG(labevents.valuenum) FROM labevents WHERE labevents.hadm_id IN ( SELECT admissions.hadm_id FROM admissions WHERE admissions.subject_id = 10021487 ) AND labevents.itemid IN ( SELECT d_labitems.itemid FROM d_labitems WHERE d_labitems.label = 'bilirubin, direct' ) AND strftime('%Y-%m',labevents.charttime) >= '2100-05' GROUP BY strftime('%Y-%m',labevents.charttime)
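The cell above prints the top-level keys and one gold label. To view the first few questions themselves, a short helper is enough. This sketch assumes each entry in `data` carries `id` and `question` fields, which are the fields the prompt-building code in Step 4 reads. `preview` and the toy data are hypothetical.

```python
from itertools import islice

def preview(data, n=3):
    """Return (id, question) pairs for the first n samples of a valid_data-style dict."""
    return [(s["id"], s["question"]) for s in islice(data["data"], n)]

toy_data = {"version": "toy",
            "data": [{"id": f"id{i}", "question": f"question {i}"} for i in range(5)]}

for sample_id, question in preview(toy_data, n=2):
    print(sample_id, "->", question)
```

Calling `preview(valid_data)` in the notebook would list the first three validation questions alongside their ids, which you can then look up in `valid_labels`.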


In [11]:
# Schema prompt for ChatGPT: database metadata (columns and foreign keys) followed by the SQL assumptions

print(table_prompt)

table admissions, columns = [row_id, subject_id, hadm_id, admittime, dischtime, admission_type, admission_location, discharge_location, insurance, language, marital_status, age]
table chartevents, columns = [row_id, subject_id, hadm_id, stay_id, itemid, charttime, valuenum, valueuom]
table cost, columns = [row_id, subject_id, hadm_id, event_type, event_id, chargetime, cost]
table d_icd_diagnoses, columns = [row_id, icd_code, long_title]
table d_icd_procedures, columns = [row_id, icd_code, long_title]
table d_items, columns = [row_id, itemid, label, abbreviation, linksto]
table d_labitems, columns = [row_id, itemid, label]
table diagnoses_icd, columns = [row_id, subject_id, hadm_id, icd_code, charttime]
table icustays, columns = [row_id, subject_id, hadm_id, stay_id, first_careunit, last_careunit, intime, outtime]
table inputevents, columns = [row_id, subject_id, hadm_id, stay_id, starttime, itemid, totalamount, totalamountuom]
table labevents, columns = [row_id, subject_id, hadm_id, it

## Step 4: Build a Predictive Model Using ChatGPT

In [12]:
# Save your API key to a JSON file; Model() reads the key back from this file
import json

api_path = 'sample_submission_chatgpt_api_key.json'
json_data = {}
json_data['key'] = new_api_key
with open(api_path, 'w') as file:
    json.dump(json_data, file, indent=4)

In [13]:
import os
import re
import json
import tiktoken

import openai
from openai import OpenAI

def post_process(answer):
    # Flatten the model output to one line and strip Markdown code fences
    answer = answer.replace('\n', ' ')
    answer = re.sub('[ ]+', ' ', answer)
    answer = answer.replace("```sql", "").replace("```", "").strip()
    return answer

class Model():
    def __init__(self):
        current_real_dir = os.getcwd()
        # current_real_dir = os.path.dirname(os.path.realpath(__file__))
        target_dir = os.path.join(current_real_dir, 'sample_submission_chatgpt_api_key.json')

        # Read the API key from the JSON file saved in the previous cell
        if not os.path.isfile(target_dir):
            raise Exception("Error: no API key file found.")
        with open(target_dir, 'r') as f:
            api_key = json.load(f)['key']
        if api_key == "":
            raise Exception("Error: the API key in the key file is empty.")

        # Build the client from the key in the file, so the file is the single source of the key
        self.client = OpenAI(api_key=api_key)

    def ask_chatgpt(self, prompt, model="gpt-4o-mini", temperature=0.6):
        # `prompt` is a list of chat messages ({"role": ..., "content": ...})
        response = self.client.chat.completions.create(
                    model=model,
                    temperature=temperature,
                    messages=prompt
                )
        return response.choices[0].message.content

    def generate(self, input_data):
        """
        Arguments:
            input_data: list of python dictionaries containing 'id' and 'input'
        Returns:
            labels: python dictionary containing sql prediction or 'null' values associated with ids
        """

        labels = {}

        for sample in tqdm(input_data):
            answer = self.ask_chatgpt(sample['input'])
            labels[sample["id"]] = post_process(answer)

        return labels
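`post_process` only flattens whitespace and removes Markdown code fences. It does not strip any explanatory prose the model may add around the SQL. Below is a self-contained copy of the function applied to a sample reply. The reply text is made up for illustration. The fence string is built from backticks so that the example itself stays readable as Markdown.

```python
import re

FENCE = "`" * 3  # a Markdown code fence

def post_process(answer):
    # Same steps as the notebook's post_process(): flatten newlines,
    # collapse repeated spaces, strip code fences, trim the ends.
    answer = answer.replace('\n', ' ')
    answer = re.sub('[ ]+', ' ', answer)
    answer = answer.replace(FENCE + "sql", "").replace(FENCE, "").strip()
    return answer

raw = FENCE + "sql\nSELECT admissions.admission_type\nFROM   admissions\n" + FENCE
print(post_process(raw))
```

An abstention such as `null` passes through unchanged. A reply like `Here is the SQL: SELECT 1`, however, keeps its leading prose, which is one reason the system prompt below asks for the query alone.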

In [14]:
myModel = Model()
data = valid_data["data"]

In [15]:
# System prompt for chatGPT
system_msg = "Given the following SQL tables and SQL assumptions you must follow, your job is to write queries given a user’s request.\n IMPORTANT: If you think you cannot predict the SQL accurately, you must answer with 'null'."
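The system prompt asks the model to answer `'null'` when it cannot produce the SQL reliably. A chat model may not reproduce that token verbatim, though. It might reply `NULL`, `'null'`, or `null;` instead. The following sketch normalizes such replies to the plain string `null`. It assumes that this string is what the scorer treats as an abstention, as the `generate` docstring suggests. `normalize_abstention` is a hypothetical helper, not part of the repository.

```python
def normalize_abstention(answer):
    """Map common spellings of an abstention to 'null'; leave SQL untouched."""
    # Strip whitespace, trailing semicolons and surrounding quotes before comparing
    token = answer.strip().rstrip(";").strip().strip("'\"").lower()
    return "null" if token == "null" else answer

for reply in ["null", "NULL", "'null'", " null; ", "SELECT 1"]:
    print(repr(reply), "->", repr(normalize_abstention(reply)))
```

Only a reply that consists solely of the abstention token is rewritten. A query such as `SELECT NULL` is left as-is.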

In [16]:
# Wrap each natural-language question in the user-message format
def user_question_wrapper(question):
    return '\n\n' + f"""NLQ: \"{question}\"\nSQL: """

input_data = []
for sample in data:
    sample_dict = {}
    sample_dict['id'] = sample['id']
    conversation = [{"role": "system", "content": system_msg + '\n\n' + table_prompt}]
    conversation.append({"role": "user", "content": user_question_wrapper(sample['question'])})
    sample_dict['input'] = conversation
    input_data.append(sample_dict)
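Each entry of `input_data` is therefore a two-message chat. The system message contains the instructions plus the schema prompt, and the user message contains the question. The following sketch shows that structure as a standalone function. `build_conversation` is a hypothetical name, and the example strings are placeholders.

```python
def build_conversation(system_msg, table_prompt, question):
    """Return the chat for one question: system (instructions + schema) and user (NLQ)."""
    return [
        {"role": "system", "content": system_msg + "\n\n" + table_prompt},
        {"role": "user", "content": f'\n\nNLQ: "{question}"\nSQL: '},
    ]

conv = build_conversation("Write SQL.", "table t, columns = [a]", "How many rows are in t?")
for message in conv:
    print(message["role"], repr(message["content"]))
```

Every question carries its own copy of the full schema, so each request is independent and no conversation history is shared between questions.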

In [17]:
# First message
print(conversation[0]['content'])

Given the following SQL tables and SQL assumptions you must follow, your job is to write queries given a user’s request.
 IMPORTANT: If you think you cannot predict the SQL accurately, you must answer with 'null'.

table admissions, columns = [row_id, subject_id, hadm_id, admittime, dischtime, admission_type, admission_location, discharge_location, insurance, language, marital_status, age]
table chartevents, columns = [row_id, subject_id, hadm_id, stay_id, itemid, charttime, valuenum, valueuom]
table cost, columns = [row_id, subject_id, hadm_id, event_type, event_id, chargetime, cost]
table d_icd_diagnoses, columns = [row_id, icd_code, long_title]
table d_icd_procedures, columns = [row_id, icd_code, long_title]
table d_items, columns = [row_id, itemid, label, abbreviation, linksto]
table d_labitems, columns = [row_id, itemid, label]
table diagnoses_icd, columns = [row_id, subject_id, hadm_id, icd_code, charttime]
table icustays, columns = [row_id, subject_id, hadm_id, stay_id, first_ca

In [18]:
# Second message

print(conversation[1]['content'])



NLQ: "How was patient 10006053 first admitted to the hospital in terms of admission type?"
SQL: 


In [19]:
# Generate SQL answers with ChatGPT
label_y = myModel.generate(input_data)

100%|██████████| 20/20 [00:26<00:00,  1.34s/it]


Below is what the predicted labels (SQL queries) look like. **This should be your submission.**

In [20]:
label_y

{'b9c136c1e1d19649caabdeb4': "SELECT AVG(labevents.valuenum) AS average_bilirubin FROM labevents JOIN d_labitems ON labevents.itemid = d_labitems.itemid WHERE labevents.hadm_id IN ( SELECT hadm_id FROM admissions WHERE subject_id = 10021487 ) AND d_labitems.label = 'Bilirubin, Direct' AND labevents.charttime >= datetime('2100-05-01') GROUP BY strftime('%Y-%m', labevents.charttime);",
 'b389e224ed07b11a553f0329': 'SELECT DISTINCT d_icd_diagnoses.long_title FROM diagnoses_icd JOIN d_icd_diagnoses ON diagnoses_icd.icd_code = d_icd_diagnoses.icd_code WHERE diagnoses_icd.subject_id = 10001217 ORDER BY diagnoses_icd.charttime LIMIT 1;',
 '0845eda9197d9666e0b3a017': "WITH ICU_Visits AS ( SELECT icustays.hadm_id, chartevents.charttime, chartevents.valuenum, DENSE_RANK() OVER (PARTITION BY icustays.hadm_id ORDER BY chartevents.charttime) AS visit_rank FROM icustays JOIN chartevents ON icustays.stay_id = chartevents.stay_id WHERE chartevents.itemid = (SELECT itemid FROM d_items WHERE label = 'ar
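Before saving, it can help to check that the prediction dictionary covers exactly the input ids and that it contains no empty strings. The following sketch is based on the `generate` docstring (ids mapped to SQL or `'null'`). `validate_submission` and the toy predictions are hypothetical.

```python
def validate_submission(predictions, input_ids):
    """Return a list of problems with an {id: SQL-or-'null'} dict (empty list means OK)."""
    problems = []
    missing = set(input_ids) - set(predictions)
    extra = set(predictions) - set(input_ids)
    if missing:
        problems.append(f"missing ids: {sorted(missing)}")
    if extra:
        problems.append(f"unexpected ids: {sorted(extra)}")
    for sample_id, sql in predictions.items():
        if not isinstance(sql, str) or not sql.strip():
            problems.append(f"empty or non-string prediction for {sample_id}")
    return problems

toy_predictions = {"id1": "SELECT 1", "id2": "null", "id9": ""}
for problem in validate_submission(toy_predictions, ["id1", "id2", "id3"]):
    print(problem)
```

In the notebook, `validate_submission(label_y, [s['id'] for s in input_data])` should return an empty list.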

In [21]:
from utils import write_json as write_label

# Save the predictions to a JSON file
os.makedirs(RESULT_DIR, exist_ok=True)
SCORING_OUTPUT_DIR = os.path.join(RESULT_DIR, '20240000.json') # The file to submit
write_label(SCORING_OUTPUT_DIR, label_y)

# Verify the file creation
print("Listing files in RESULT_DIR:")
!ls {RESULT_DIR}

Listing files in RESULT_DIR:
20240000.json
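`write_label` comes from the repository's `utils` module, and its implementation is not shown here. Listing the directory confirms that the file exists, but not that its contents are intact. The following sketch shows an equivalent save-and-verify step that uses only the standard `json` module. It assumes that the submission is a plain JSON object mapping ids to strings. `save_predictions` is a hypothetical helper.

```python
import json
import os
import tempfile

def save_predictions(path, predictions):
    """Write predictions as JSON and report whether the file reads back identically."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w") as f:
        json.dump(predictions, f, indent=4)
    # Read the file back to confirm it round-trips to the same dictionary
    with open(path, "r") as f:
        return json.load(f) == predictions

with tempfile.TemporaryDirectory() as tmp:
    ok = save_predictions(os.path.join(tmp, "results", "20240000.json"),
                          {"id1": "SELECT 1", "id2": "null"})
print(ok)
```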


## Step 5: Evaluation

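You can evaluate your predictions on the validation set using the following code:

*Note*: The risk for unanswerable questions is not computed here (it is reported as `None`) because the validation set contains no unanswerable questions; this metric is ignored when calculating the final score.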

In [22]:
from scoring.scorer import Scorer

# Load the validation data, gold labels, and the saved prediction file
with open(VALID_DATA_PATH, "r") as f:
    data = json.load(f)

with open(VALID_LABEL_PATH, "r") as f:
    gold_labels = json.load(f)

with open(SCORING_OUTPUT_DIR, "r") as f:
    predictions = json.load(f)

# Score the predictions read back from disk, i.e. exactly what will be submitted
scorer = Scorer(
    data=data,
    predictions=predictions,
    gold_labels=gold_labels,
    score_dir=RESULT_DIR
)
print()
print(scorer.get_scores())




100%|██████████| 20/20 [00:00<00:00, 83.70it/s]

No data for risk_notans. This happens when there is no `notans` questions in the evaluation dataset. This metric will be ignored when calculating the final score. This will not happen when evaluating on the test set.
Coverage for answerable questions (in %): 100.0 || 20/20
Risk for answerable questions (in %): 55.0 || 11/20
Risk for unanswerable questions (in %): None || 0/0
Final score: 72.5
{'cov_ans*100': 100.0, 'risk_ans*100': 55.0, 'risk_notans*100': None, 'final_score': 72.5}
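The scorer reports each metric as a percentage followed by the raw counts (for example `55.0 || 11/20`). The following sketch reproduces those percentages from the counts and does not reimplement `scoring/scorer.py`. The last line only observes that, for this run, the reported final score equals the mean of coverage and 100 minus risk. Whether the scorer actually defines the final score that way, especially when `risk_notans` is present, is an assumption to verify against the scorer source.

```python
def pct(numerator, denominator):
    """Percentage helper; returns None when there is nothing to measure (0/0)."""
    return None if denominator == 0 else 100.0 * numerator / denominator

coverage_ans = pct(20, 20)  # answered answerable questions: "20/20" in the output
risk_ans = pct(11, 20)      # incorrect answers: "11/20" in the output
risk_notans = pct(0, 0)     # no unanswerable questions in the validation set
print(coverage_ans, risk_ans, risk_notans)

# Observation only (assumed, not taken from the scorer source):
# this run satisfies final_score == (coverage + (100 - risk)) / 2
print((coverage_ans + (100 - risk_ans)) / 2)
```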



