In [1]:
!pip install torch transformers bitsandbytes accelerate sqlparse

Collecting bitsandbytes
  Downloading bitsandbytes-0.43.0-py3-none-manylinux_2_24_x86_64.whl.metadata (1.8 kB)
Downloading bitsandbytes-0.43.0-py3-none-manylinux_2_24_x86_64.whl (102.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m102.2/102.2 MB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.43.0


In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [3]:
torch.cuda.is_available()

True

In [4]:
available_memory = torch.cuda.get_device_properties(0).total_memory

In [5]:
print(available_memory)

17066885120


In [6]:
import locale
locale.getpreferredencoding = lambda: "UTF-8"

In [7]:

# Cloning the GitHub repository
!git clone -q https://github.com/glee4810/ehrsql-2024.git
%cd ehrsql-2024

/kaggle/working/ehrsql-2024


In [8]:
import os
import json
import pandas as pd

# Directory paths for database, results and scoring program
DB_ID = 'mimic_iv'
BASE_DATA_DIR = 'data/mimic_iv'
RESULT_DIR = 'sample_result_submission/'
SCORE_PROGRAM_DIR = 'scoring_program/'

# File paths for the dataset and labels
TABLES_PATH = os.path.join('data', DB_ID, 'tables.json')               # JSON containing database schema
TRAIN_DATA_PATH = os.path.join(BASE_DATA_DIR, 'train', 'data.json')    # JSON file with natural language questions for training data
TRAIN_LABEL_PATH = os.path.join(BASE_DATA_DIR, 'train', 'label.json')  # JSON file with corresponding SQL queries for training data
VALID_DATA_PATH = os.path.join(BASE_DATA_DIR, 'test', 'data.json')    # JSON file for validation data
DB_PATH = os.path.join('data', DB_ID, f'{DB_ID}.sqlite')               # Database path

In [9]:
!wget https://physionet.org/static/published-projects/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2.zip
!unzip mimic-iv-clinical-database-demo-2.2
!gunzip -r mimic-iv-clinical-database-demo-2.2

--2024-03-26 07:22:18--  https://physionet.org/static/published-projects/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2.zip
Resolving physionet.org (physionet.org)... 18.18.42.54
Connecting to physionet.org (physionet.org)|18.18.42.54|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 16189661 (15M) [application/zip]
Saving to: 'mimic-iv-clinical-database-demo-2.2.zip'


2024-03-26 07:23:39 (198 KB/s) - 'mimic-iv-clinical-database-demo-2.2.zip' saved [16189661/16189661]

Archive:  mimic-iv-clinical-database-demo-2.2.zip
  inflating: mimic-iv-clinical-database-demo-2.2/LICENSE.txt  
  inflating: mimic-iv-clinical-database-demo-2.2/README.txt  
  inflating: mimic-iv-clinical-database-demo-2.2/SHA256SUMS.txt  
  inflating: mimic-iv-clinical-database-demo-2.2/demo_subject_id.csv  
 extracting: mimic-iv-clinical-database-demo-2.2/hosp/admissions.csv.gz  
  inflating: mimic-iv-clinical-database-demo-2.2/hosp/d_hcpcs.csv.gz  
  inflating: mimic-iv-clinical-database-d

In [10]:
%cd preprocess
!bash preprocess.sh
%cd ..

/kaggle/working/ehrsql-2024/preprocess
timeshift is True
start_year: 2100
time_span: 0
current_time: 2100-12-31 23:59:00
Processing patients, admissions, icustays, transfers
Cannot take a larger sample than population when 'replace=False
Use all available patients instead.
num_cur_patient: 4
num_non_cur_patient: 90
num_patient: 94
patients, admissions, icustays, transfers processed (took 0.2297 secs)
Processing dictionary tables (d_icd_diagnoses, d_icd_procedures, d_labitems, d_items)
d_icd_diagnoses, d_icd_procedures, d_labitems, d_items processed (took 2.5758 secs)
Processing diagnoses_icd table
diagnoses_icd processed (took 0.1616 secs)
Processing procedures_icd table
procedures_icd processed (took 0.0604 secs)
Processing labevents table
labevents processed (took 2.7756 secs)
Processing prescriptions table
prescriptions processed (took 0.8669 secs)
Processing COST table
cost processed (took 0.6457 secs)
Processing chartevents table
chartevents processed (took 4.2733 secs)
Processing

In [11]:
from utils.data_io import read_json as read_data

train_data = read_data(TRAIN_DATA_PATH)
train_label = read_data(TRAIN_LABEL_PATH)

valid_data = read_data(VALID_DATA_PATH)

In [12]:
print("Train data:", (len(train_data['data']), len(train_label)))
print("Valid data:", len(valid_data['data']))

Train data: (5124, 5124)
Valid data: 1167


In [13]:
# Explore keys and data structure
print(train_data.keys())
print(train_data['version'])
print(train_data['data'][0]['id'])
print(train_data['data'][0]['question'])

# Explore the label structure
# print(train_label.keys())
# print(train_label[list(train_label.keys())[0]])

dict_keys(['version', 'data'])
mimic_iv_demo_1.0_train
3b9849548e56c59f768d5447
Tell me the minimum respiratory rate in patient 10021118 in the first ICU visit.


Model Building Starts

In [14]:
!pip install accelerate



In [15]:
model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# if available_memory > 17e9:
#     # if you have atleast 16GB of GPU memory, run load the model in float16
#     model = AutoModelForCausalLM.from_pretrained(
#         model_name,
#         trust_remote_code=True,
#         torch_dtype=torch.float16,
#         device_map="auto",
#         use_cache=True,
#     )
# else:
    # else, load in 8 bits – this is a bit slower
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    # torch_dtype=torch.float16,
    load_in_4bit=True,
    device_map="auto",
    use_cache=True,
)

tokenizer_config.json:   0%|          | 0.00/1.84k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/515 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/691 [00:00<?, ?B/s]

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/3.59G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

In [16]:
prompt = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Instructions
- If you cannot answer the question with the available database schema, return 'Humne kiya'
- Remember that revenue is price multiplied by quantity
- Remember that cost is supply_price multiplied by quantity

### Database Schema
This query will run on a database whose schema is represented in this string:

DROP TABLE IF EXISTS patients;
CREATE TABLE patients 
(
    row_id INT NOT NULL PRIMARY KEY,
    subject_id INT NOT NULL UNIQUE,
    gender VARCHAR(5) NOT NULL,
    dob TIMESTAMP(0) NOT NULL,
    dod TIMESTAMP(0)
);

DROP TABLE IF EXISTS admissions;
CREATE TABLE admissions
(
    row_id INT NOT NULL PRIMARY KEY,
    subject_id INT NOT NULL,
    hadm_id INT NOT NULL UNIQUE,
    admittime TIMESTAMP(0) NOT NULL,
    dischtime TIMESTAMP(0),
    admission_type VARCHAR(50) NOT NULL,
    admission_location VARCHAR(50) NOT NULL,
    discharge_location VARCHAR(50),
    insurance VARCHAR(255) NOT NULL,
    language VARCHAR(10),
    marital_status VARCHAR(50),
    age INT NOT NULL,
    FOREIGN KEY(subject_id) REFERENCES patients(subject_id)
);

DROP TABLE IF EXISTS d_icd_diagnoses;
CREATE TABLE d_icd_diagnoses
(
    row_id INT NOT NULL PRIMARY KEY,
    icd_code VARCHAR(10) NOT NULL UNIQUE,
    long_title VARCHAR(255) NOT NULL
);

DROP TABLE IF EXISTS d_icd_procedures;
CREATE TABLE d_icd_procedures 
(
    row_id INT NOT NULL PRIMARY KEY,
    icd_code VARCHAR(10) NOT NULL UNIQUE,
    long_title VARCHAR(255) NOT NULL
);

DROP TABLE IF EXISTS d_labitems;
CREATE TABLE d_labitems 
(
    row_id INT NOT NULL PRIMARY KEY,
    itemid INT NOT NULL UNIQUE,
    label VARCHAR(200)
);

DROP TABLE IF EXISTS d_items;
CREATE TABLE d_items 
(
    row_id INT NOT NULL PRIMARY KEY,
    itemid INT NOT NULL UNIQUE,
    label VARCHAR(200) NOT NULL,
    abbreviation VARCHAR(200) NOT NULL,
    linksto VARCHAR(50) NOT NULL
);

DROP TABLE IF EXISTS diagnoses_icd;
CREATE TABLE diagnoses_icd
(
    row_id INT NOT NULL PRIMARY KEY,
    subject_id INT NOT NULL,
    hadm_id INT NOT NULL,
    icd_code VARCHAR(10) NOT NULL,
    charttime TIMESTAMP(0) NOT NULL,
    FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id),
    FOREIGN KEY(icd_code) REFERENCES d_icd_diagnoses(icd_code)
);

DROP TABLE IF EXISTS procedures_icd;
CREATE TABLE procedures_icd
(
    row_id INT NOT NULL PRIMARY KEY,
    subject_id INT NOT NULL,
    hadm_id INT NOT NULL,
    icd_code VARCHAR(10) NOT NULL,
    charttime TIMESTAMP(0) NOT NULL,
    FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id),
    FOREIGN KEY(icd_code) REFERENCES d_icd_procedures(icd_code)
);

DROP TABLE IF EXISTS labevents;
CREATE TABLE labevents
(
    row_id INT NOT NULL PRIMARY KEY,
    subject_id INT NOT NULL,
    hadm_id INT NOT NULL,
    itemid INT NOT NULL,
    charttime TIMESTAMP(0),
    valuenum DOUBLE PRECISION,
    valueuom VARCHAR(20),
    FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id),
    FOREIGN KEY(itemid) REFERENCES d_labitems(itemid)
);

DROP TABLE IF EXISTS prescriptions;
CREATE TABLE prescriptions
(
    row_id INT NOT NULL PRIMARY KEY,
    subject_id INT NOT NULL,
    hadm_id INT NOT NULL,
    starttime TIMESTAMP(0) NOT NULL,
    stoptime TIMESTAMP(0),
    drug VARCHAR(255) NOT NULL,
    dose_val_rx VARCHAR(100) NOT NULL,
    dose_unit_rx VARCHAR(50) NOT NULL,
    route VARCHAR(50) NOT NULL,
    FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id)
);

DROP TABLE IF EXISTS cost;
CREATE TABLE cost
(
    row_id INT NOT NULL PRIMARY KEY,
    subject_id INT NOT NULL,
    hadm_id INT NOT NULL,
    event_type VARCHAR(20) NOT NULL,
    event_id INT NOT NULL,
    chargetime TIMESTAMP(0) NOT NULL,
    cost DOUBLE PRECISION NOT NULL,
    FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id),
    FOREIGN KEY(event_id) REFERENCES diagnoses_icd(row_id),
    FOREIGN KEY(event_id) REFERENCES procedures_icd(row_id),
    FOREIGN KEY(event_id) REFERENCES labevents(row_id),
    FOREIGN KEY(event_id) REFERENCES prescriptions(row_id)  
);

DROP TABLE IF EXISTS chartevents;
CREATE TABLE chartevents
(
    row_id INT NOT NULL PRIMARY KEY,
    subject_id INT NOT NULL,
    hadm_id INT NOT NULL,
    stay_id INT NOT NULL,
    itemid INT NOT NULL,
    charttime TIMESTAMP(0) NOT NULL,
    valuenum DOUBLE PRECISION,
    valueuom VARCHAR(50),
    FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id),
    FOREIGN KEY(stay_id) REFERENCES icustays(stay_id),
    FOREIGN KEY(itemid) REFERENCES d_items(itemid)
);

DROP TABLE IF EXISTS inputevents;
CREATE TABLE inputevents
(
    row_id INT NOT NULL PRIMARY KEY,
    subject_id INT NOT NULL,
    hadm_id INT NOT NULL,
    stay_id INT NOT NULL,
    starttime TIMESTAMP(0) NOT NULL,
    itemid INT NOT NULL,
    amount DOUBLE PRECISION,
    FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id),
    FOREIGN KEY(stay_id) REFERENCES icustays(stay_id),
    FOREIGN KEY(itemid) REFERENCES d_items(itemid)
);

DROP TABLE IF EXISTS outputevents;
CREATE TABLE outputevents
(
    row_id INT NOT NULL PRIMARY KEY,
    subject_id INT NOT NULL,
    hadm_id INT NOT NULL,
    stay_id INT NOT NULL,
    charttime TIMESTAMP(0) NOT NULL,
    itemid INT NOT NULL,
    value DOUBLE PRECISION,
    FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id),
    FOREIGN KEY(stay_id) REFERENCES icustays(stay_id),
    FOREIGN KEY(itemid) REFERENCES d_items(itemid)
);

DROP TABLE IF EXISTS microbiologyevents;
CREATE TABLE microbiologyevents
(
    row_id INT NOT NULL PRIMARY KEY,
    subject_id INT NOT NULL,
    hadm_id INT NOT NULL,
    charttime TIMESTAMP(0) NOT NULL,
    spec_type_desc VARCHAR(100),
    test_name VARCHAR(100),
    org_name VARCHAR(100),
    FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id)
);

DROP TABLE IF EXISTS icustays;
CREATE TABLE icustays
(
    row_id INT NOT NULL PRIMARY KEY,
    subject_id INT NOT NULL,
    hadm_id INT NOT NULL,
    stay_id INT NOT NULL UNIQUE,
    first_careunit VARCHAR(20) NOT NULL,
    last_careunit VARCHAR(20) NOT NULL,
    intime TIMESTAMP(0) NOT NULL,
    outtime TIMESTAMP(0),
    FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id)
);

DROP TABLE IF EXISTS transfers;
CREATE TABLE transfers
(
    row_id INT NOT NULL PRIMARY KEY,
    subject_id INT NOT NULL,
    hadm_id INT NOT NULL,
    transfer_id INT NOT NULL,
    eventtype VARCHAR(20) NOT NULL,
    careunit VARCHAR(20),
    intime TIMESTAMP(0) NOT NULL,
    outtime TIMESTAMP(0),
    FOREIGN KEY(hadm_id) REFERENCES admissions(hadm_id)
);


### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""

In [17]:
import sqlparse

def generate_query(question):
    updated_prompt = prompt.format(question=question)
    inputs = tokenizer(updated_prompt, return_tensors="pt").to("cuda")
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )

    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    # empty cache so that you do generate more results w/o memory crashing
    # particularly important on Colab – memory management is much more straightforward
    # when running on an inference service
    return sqlparse.format(outputs[0].split("[SQL]")[-1], reindent=True)

In [18]:
questions=[]
ids = []
length = len(valid_data['data'])
print(length)
for i in range (length):
    questions.append(valid_data['data'][i]['question'])
    ids.append(valid_data['data'][i]['id'])
#questions


1167


In [19]:
sqls = []
i = 0

# Open a text file in write mode
with open("queries.txt", "w") as file:
    file.write('{')
    for question in questions:
        sql = generate_query(question)
        if 'Humne kiya' in sql:
            sql = 'null'
        sql = sql.replace('\n', '')
        sql = sql.replace(';','')
        sqls.append(sql)
        id=ids[i]
        print(id+":",end="")
        # Write the id and SQL query to the file within double quotes
        file.write(f'"{id}":"{sql}",\n')
        
        print(sql)
        i += 1
        print("Count is ", i)
    file.write('}')
    file.close()


2024-03-26 07:25:56.953115: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-26 07:25:56.953224: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-26 07:25:57.085603: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


282f008dd8dfb8f4a1dd6999:SELECT m.org_nameFROM microbiologyevents mJOIN  (SELECT s.subject_id,          MIN(m.charttime) AS min_charttime   FROM microbiologyevents m   JOIN admissions s ON m.hadm_id = s.hadm_id   WHERE s.subject_id = 10027602   GROUP BY s.subject_id) AS min_chart ON m.hadm_id = min_chart.subject_idWHERE m.test_name = 'Urine test'  AND m.spec_type_desc = 'Microorganism'
Count is  1
47fd000ef0b1033a8aabfac8:SELECT p.drug,       COUNT(*) AS frequencyFROM prescriptions pWHERE EXTRACT(YEAR              FROM p.starttime) >= 2010GROUP BY p.drugORDER BY frequency DESCLIMIT 5
Count is  2
7f07c59357750ac2b84e5221:SELECT m.spec_type_desc,       COUNT(*) AS frequencyFROM microbiologyevents mJOIN  (SELECT subject_id   FROM admissions   WHERE charttime >= (CURRENT_DATE - interval '2 months')     AND (icd_code = '757072'          OR icd_code = '757073')   GROUP BY subject_id) AS p ON m.subject_id = p.subject_idGROUP BY m.spec_type_descORDER BY frequency DESCLIMIT 5
Count is  3
c83990

In [20]:
# sqls = []
# i=0
# for question in questions:
#     sql = generate_query(question)
#     if 'Humne kiya' in sql:
#       sql = 'null'
#     sqls.append(sql.replace('\n', ''))
#     print(sql)
#     i+=1
#     print("Count is ",i)