<!-- Original Implementation by Gyubok Lee -->
<!-- Refined by Seongsu Bae on 2024-01-14 -->
<!-- Note: This Jupyter notebook is specifically designed for the EHRSQL project. It features custom modifications and enhancements to cater to the unique dataset and experiment objectives -->

# Local Model Sample Code (T5) for EHRSQL: Reliable Text-to-SQL Modeling on Electronic Health Records

<p align="left" float="left">
  <img src="https://github.com/glee4810/ehrsql-2024/raw/master/image/logo.png" height="100" />
</p>

Welcome to the T5-based Local-Model Baseline Code for the EHRSQL task, a component of the Clinical NLP 2024 Workshop. This Jupyter notebook serves as a comprehensive guide to developing a robust Text-to-SQL model for Electronic Health Records (EHRs).

## Steps in This Jupyter Notebook
- [x] Step 1: Clone the GitHub Repository and Install Dependencies
- [x] Step 2: Import Global Packages and Define File Paths
- [x] Step 3: Load Data and Prepare Datasets
- [x] Step 4: Construct a Text-to-SQL Baseline Model
- [x] Step 5: Initial Model Evaluation on All Queries
- [x] Step 6: Model Evaluation Considering Unanswerable Questions
- [x] Step 7: Test data inference
- [x] Step 8: Submission

## Getting Started

Begin your journey with the EHRSQL task by following these structured steps (from Step 1 to Step 8). Each section is designed to guide you smoothly through the process, from setup to submission. We're eager to see the innovative solutions you'll bring to the field of Text-to-SQL modeling for electronic health records.

## Step 1: Clone the GitHub Repository and Install Dependencies

Before you begin, make sure you are in the correct directory. To start from a clean state, the following cell moves to `/content` and removes any existing `ehrsql-2024` directory:

In [3]:
%cd /content
!rm -rf ehrsql-2024

/content


Now, clone the repository and install the required Python packages:

In [4]:
# Cloning the GitHub repository
!git clone -q https://github.com/glee4810/ehrsql-2024.git
%cd ehrsql-2024

# Installing dependencies
!pip install -q transformers
!pip install -q sentencepiece

/content/ehrsql-2024


Use the `autoreload` extension (loaded with the `%load_ext` magic command) so that edited modules are reloaded automatically before each cell is executed:

In [5]:
%load_ext autoreload
%autoreload 2

## Step 2: Import Global Packages and Define File Paths

Now that the repository and dependencies are set up, let's import the necessary global packages and define file paths for our datasets.

In [6]:
import os
import json
import numpy as np
import pandas as pd
from collections import Counter

Directory Paths:
- `BASE_DATA_DIR`: The primary directory for all dataset files (`data/mimic_iv`).
- `RESULT_DIR`: Directory where prediction results are stored for submission.
- `SCORE_PROGRAM_DIR`: Directory containing the scoring program.

Dataset File Paths:
- `TABLES_PATH`: Structure (schema) of the database tables (JSON format).
- `TRAIN_DATA_PATH`: Natural language questions for training (JSON format).
- `TRAIN_LABEL_PATH`: Corresponding SQL queries used as training labels (JSON format).
- `VALID_DATA_PATH`: Validation dataset with natural language questions (JSON format).
- `DB_PATH`: SQLite database file for `DB_ID` (`mimic_iv.sqlite`).

In [7]:
# Directory paths for database, results and scoring program
DB_ID = 'mimic_iv'
BASE_DATA_DIR = 'data/mimic_iv'
RESULT_DIR = 'sample_result_submission/'
SCORE_PROGRAM_DIR = 'scoring_program/'

# File paths for the dataset and labels
TABLES_PATH = os.path.join('data', DB_ID, 'tables.json')               # JSON containing database schema
TRAIN_DATA_PATH = os.path.join(BASE_DATA_DIR, 'train', 'data.json')    # JSON file with natural language questions for training data
TRAIN_LABEL_PATH = os.path.join(BASE_DATA_DIR, 'train', 'label.json')  # JSON file with corresponding SQL queries for training data
VALID_DATA_PATH = os.path.join(BASE_DATA_DIR, 'valid', 'data.json')    # JSON file for validation data
DB_PATH = os.path.join('data', DB_ID, f'{DB_ID}.sqlite')               # Database path
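Because later cells read these files directly, a quick existence check can catch a wrong working directory early. The helper below is a minimal sketch; `find_missing_paths` is a hypothetical name, not part of the EHRSQL repository. `DB_PATH` is left out of the usage example because the database may only exist after the preprocessing in Step 3.

```python
import os


def find_missing_paths(paths):
    """Return the names of entries in `paths` (name -> path) that do not exist on disk."""
    return [name for name, path in paths.items() if not os.path.exists(path)]


# Hypothetical usage with the paths defined above:
# missing = find_missing_paths({
#     "TABLES_PATH": TABLES_PATH,
#     "TRAIN_DATA_PATH": TRAIN_DATA_PATH,
#     "TRAIN_LABEL_PATH": TRAIN_LABEL_PATH,
#     "VALID_DATA_PATH": VALID_DATA_PATH,
# })
# print(missing or "All dataset files found.")
```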

## Step 3: Load Data and Prepare Datasets

Now that we have our environment and paths set up, the next step is to load the data and prepare it for our model. This involves preprocessing the MIMIC-IV database, reading the data from JSON files, splitting it into training and validation sets, and then initializing our dataset object.

### Download and Preprocess MIMIC-IV Database Demo

In [8]:
!wget https://physionet.org/static/published-projects/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2.zip
!unzip mimic-iv-clinical-database-demo-2.2
!gunzip -r mimic-iv-clinical-database-demo-2.2

--2025-06-30 11:29:19--  https://physionet.org/static/published-projects/mimic-iv-demo/mimic-iv-clinical-database-demo-2.2.zip
Resolving physionet.org (physionet.org)... 18.18.42.54
Connecting to physionet.org (physionet.org)|18.18.42.54|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 16189661 (15M) [application/zip]
Saving to: ‘mimic-iv-clinical-database-demo-2.2.zip’


2025-06-30 11:29:52 (480 KB/s) - ‘mimic-iv-clinical-database-demo-2.2.zip’ saved [16189661/16189661]

Archive:  mimic-iv-clinical-database-demo-2.2.zip
  inflating: mimic-iv-clinical-database-demo-2.2/LICENSE.txt  
  inflating: mimic-iv-clinical-database-demo-2.2/README.txt  
  inflating: mimic-iv-clinical-database-demo-2.2/SHA256SUMS.txt  
  inflating: mimic-iv-clinical-database-demo-2.2/demo_subject_id.csv  
 extracting: mimic-iv-clinical-database-demo-2.2/hosp/admissions.csv.gz  
  inflating: mimic-iv-clinical-database-demo-2.2/hosp/d_hcpcs.csv.gz  
  inflating: mimic-iv-clinical-database-d
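The archive ships a `SHA256SUMS.txt` file, as the listing above shows. If you want to confirm the download is intact, a checksum pass has to run after `unzip` but before `gunzip -r`, because decompressing replaces the `.csv.gz` files the checksums refer to. The sketch below assumes each line of `SHA256SUMS.txt` follows the usual `sha256sum` layout (`<hex digest> <relative path>`). The helper names are hypothetical.

```python
import hashlib
import os


def sha256_of(path, chunk_size=1 << 20):
    """Compute the SHA-256 hex digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256sums(root, sums_file="SHA256SUMS.txt"):
    """Return relative paths listed in `sums_file` that are missing or have a different digest."""
    mismatched = []
    with open(os.path.join(root, sums_file)) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            expected, rel_path = line.split(maxsplit=1)
            rel_path = rel_path.lstrip("*")  # sha256sum marks binary-mode entries with '*'
            full_path = os.path.join(root, rel_path)
            if not os.path.exists(full_path) or sha256_of(full_path) != expected:
                mismatched.append(rel_path)
    return mismatched


# Hypothetical usage, between `unzip` and `gunzip -r`:
# print(verify_sha256sums("mimic-iv-clinical-database-demo-2.2"))  # an empty list means all files match
```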

In [9]:
%cd preprocess
!bash preprocess.sh
%cd ..

/content/ehrsql-2024/preprocess
timeshift is True
start_year: 2100
time_span: 0
current_time: 2100-12-31 23:59:00
Processing patients, admissions, icustays, transfers
Cannot take a larger sample than population when 'replace=False
Use all available patients instead.
num_cur_patient: 4
num_non_cur_patient: 90
num_patient: 94
patients, admissions, icustays, transfers processed (took 0.1326 secs)
Processing dictionary tables (d_icd_diagnoses, d_icd_procedures, d_labitems, d_items)
d_icd_diagnoses, d_icd_procedures, d_labitems, d_items processed (took 1.5603 secs)
Processing diagnoses_icd table
diagnoses_icd processed (took 0.0974 secs)
Processing procedures_icd table
procedures_icd processed (took 0.0389 secs)
Processing labevents table
labevents processed (took 1.4311 secs)
Processing prescriptions table
prescriptions processed (took 0.4612 secs)
Processing COST table
cost processed (took 0.3343 secs)
Processing chartevents table
chartevents processed (took 2.5671 secs)
Processing inpute

### Load Data from JSON

In [10]:
from utils.data_io import read_json as read_data
from utils.data_io import write_json as write_data

# Load train and validation sets
train_data = read_data(TRAIN_DATA_PATH)
train_label = read_data(TRAIN_LABEL_PATH)
valid_data = read_data(VALID_DATA_PATH)

# Quick summary of the dataset
print(f"Train data: {len(train_data['data'])} entries, Train labels: {len(train_label)} entries")
print(f"Valid data: {len(valid_data['data'])} entries")

Train data: 5124 entries, Train labels: 5124 entries
Valid data: 1163 entries


### Explore the Dataset

Before proceeding with the model, it is a good idea to explore the dataset: check the keys of the loaded objects and view the first entries to understand how the data is structured.

In [11]:
# Explore keys and data structure
print(train_data.keys())
print(train_data['version'])
print(train_data['data'][0])

# Explore the label structure
print(train_label.keys())
print(train_label[list(train_label.keys())[0]])

dict_keys(['version', 'data'])
train_v1.1.0
{'id': '86c0624f79d2e1a57bd72381', 'question': 'What are the consumption methods of ampicillin sodium?'}
dict_keys(['86c0624f79d2e1a57bd72381', '8410ba33ca60f450bef1f5d8', '657a58e806b8805fa2a4714d', 'db32ddeefc60ef2e085d547a', '192595f3a72877272b046e55', 'def6d375a2be3ced2d907c6d', 'bc72b33b40a0a574955d98e6', 'e28ceb0ac6e10e4e5f14b4dd', 'd162e9cc060eadbc00511ac2', 'a229a64f2bf74f9fc9dfa626', 'dac5c6aaac6b57cdc328694e', '56fd4c77a08d2c69466f3daf', 'b9bcd099d5f42ea75aca129e', 'b64c62eb05aaad4071af83a7', '368dc6fa47ea7885905c48d5', '1f76d064d52e32276438fbaf', '8073dc91a374c05758eb7cd2', '9d3038ee311cade2b5621bce', '9906646cfa86ce749ac4d08e', 'df7a774b0aacbddb0a1facfe', '9298a838a1b1fa2910b4a9eb', '975a9d2f0afd09663ee67165', '3eb6d37c0b82ceb5b2df547e', '612076c066409c1c13b808e5', '3da43478d4e899fc1b0f0fe3', '21e946effbb48dafa5e842d3', '7c25dcee711f95e8760994eb', 'eec30a374f5a614dddb5c128', '01a258cef08c0eb27a9ef878', 'efd5c9c44ecab92af5e496fa', 
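Labels map each question ID to a SQL string, and unanswerable questions carry the literal label `'null'`. The split below relies on the same convention. Here is a small sketch for summarizing that balance, shown on toy labels. `summarize_labels` is a hypothetical helper; in the notebook you would pass `train_label` instead.

```python
from collections import Counter


def summarize_labels(labels):
    """Count answerable vs. unanswerable entries in an {id: sql} label dict."""
    counts = Counter(
        "unanswerable" if sql == "null" else "answerable" for sql in labels.values()
    )
    total = sum(counts.values())
    ratio = counts["unanswerable"] / total if total else 0.0
    return counts, ratio


# Toy labels in the same format as label.json (IDs and SQL are illustrative only)
toy_labels = {
    "q1": "SELECT DISTINCT prescriptions.route FROM prescriptions WHERE prescriptions.drug = 'ampicillin sodium'",
    "q2": "null",
    "q3": "null",
    "q4": "SELECT DISTINCT prescriptions.route FROM prescriptions WHERE prescriptions.drug = 'heparin'",
}
counts, ratio = summarize_labels(toy_labels)
print(counts["answerable"], counts["unanswerable"], ratio)  # 2 2 0.5
```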

### Data Preprocessing

For the **development phase** (until the release of the test data on Monday, February 26, 2024), we will use the validation data as our evaluation dataset. Therefore, we need to reorganize our data as follows:
- Original Validation Data as Test Set: The existing validation data will now serve as our test set. This set will be used for final evaluations.
- Split Original Training Data: We will split the original training data into new training and validation sets. This allows us to develop and tune our model effectively during the development phase.

For the **final testing phase**, starting from the test data release on February 26, 2024, participants should use the original test dataset for final evaluations and submissions.


In [12]:
from sklearn.model_selection import train_test_split

# Define stratification criteria for consistent distribution between answerable and unanswerable questions
stratify = ['unans' if train_label[id_]=='null' else 'ans' for id_ in list(train_label.keys())]

# Split the original training data into new training and validation sets, while maintaining the distribution
new_train_keys, new_valid_keys = train_test_split(
    list(train_label.keys()),
    train_size=0.8,
    random_state=42,
    stratify=stratify
)

# Initialize containers for the new training and validation sets
new_train_data = []
new_train_label = {}
new_valid_data = []
new_valid_label = {}

# Sort each sample into the new training or validation set as determined by the split
# (sets make each membership check O(1) instead of scanning a list for every sample)
new_train_key_set = set(new_train_keys)
new_valid_key_set = set(new_valid_keys)
for sample in train_data['data']:
    if sample['id'] in new_train_key_set:
        new_train_data.append(sample)
        new_train_label[sample['id']] = train_label[sample['id']]
    elif sample['id'] in new_valid_key_set:
        new_valid_data.append(sample)
        new_valid_label[sample['id']] = train_label[sample['id']]
    else:
        # If a sample is neither in the train nor valid keys, raise an error
        raise ValueError(f"Error: Sample with ID {sample['id']} has an invalid split.")

# Structure the new datasets in a JSON-compatible format
new_train_data = {'version': f'{DB_ID}_sample', 'data': new_train_data}
new_valid_data = {'version': f'{DB_ID}_sample', 'data': new_valid_data}

# Display the size of the new training and validation sets for verification
print(f"New Train data: {len(new_train_data['data'])} entries, New Train labels: {len(new_train_label)} entries, Unanswerable: {sum(value == 'null' for value in new_train_label.values())}")
print(f"New Valid data: {len(new_valid_data['data'])} entries, New Valid labels: {len(new_valid_label)} entries, Unanswerable: {sum(value == 'null' for value in new_valid_label.values())}")

New Train data: 4099 entries, New Train labels: 4099 entries, Unanswerable: 360
New Valid data: 1025 entries, New Valid labels: 1025 entries, Unanswerable: 90
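The split above keeps the unanswerable share nearly identical in both parts: 360 / 4099 ≈ 8.78% and 90 / 1025 ≈ 8.78%. The pure-Python split below gives some intuition for what `stratify=` does. It is a simplified hypothetical helper, not scikit-learn's actual algorithm. Each stratum is shuffled and cut at the same train fraction, so the class proportions carry over to both sides.

```python
import random
from collections import defaultdict


def stratified_split(keys, strata, train_size=0.8, seed=42):
    """Split `keys` so each stratum keeps (roughly) the same train/valid ratio."""
    groups = defaultdict(list)
    for key, stratum in zip(keys, strata):
        groups[stratum].append(key)
    rng = random.Random(seed)
    train, valid = [], []
    for stratum in sorted(groups):  # fixed order keeps the split reproducible
        members = groups[stratum][:]
        rng.shuffle(members)
        cut = round(len(members) * train_size)
        train.extend(members[:cut])
        valid.extend(members[cut:])
    return train, valid


def unanswerable_ratio(keys, labels):
    """Fraction of `keys` whose label is the string 'null'."""
    return sum(labels[k] == "null" for k in keys) / len(keys)


# Toy data: 50 questions, every fifth one unanswerable (20%)
keys = [f"q{i}" for i in range(50)]
toy_labels = {k: ("null" if i % 5 == 0 else "SELECT ...") for i, k in enumerate(keys)}
strata = ["unans" if toy_labels[k] == "null" else "ans" for k in keys]

train_keys, valid_keys = stratified_split(keys, strata)
print(len(train_keys), len(valid_keys))  # 40 10
print(unanswerable_ratio(train_keys, toy_labels), unanswerable_ratio(valid_keys, toy_labels))  # 0.2 0.2
```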


In [13]:
# Directories for the new split data (the original validation set becomes the test set)
NEW_TRAIN_DIR = os.path.join(BASE_DATA_DIR, '__train')
NEW_VALID_DIR = os.path.join(BASE_DATA_DIR, '__valid')
NEW_TEST_DIR = os.path.join(BASE_DATA_DIR, 'valid')

# Save the new datasets to JSON files for later use
write_data(os.path.join(NEW_TRAIN_DIR, "data.json"), new_train_data)
write_data(os.path.join(NEW_TRAIN_DIR, "label.json"), new_train_label)
write_data(os.path.join(NEW_VALID_DIR, "data.json"), new_valid_data)
write_data(os.path.join(NEW_VALID_DIR, "label.json"), new_valid_label)
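Before building datasets from these directories, it can help to reload the saved files and confirm that every question in `data.json` has exactly one entry in `label.json`. This is a minimal sketch that assumes the file layout shown earlier: `{"version": ..., "data": [{"id": ..., "question": ...}, ...]}` for questions and an `{id: sql}` dictionary for labels. `check_split_files` is a hypothetical helper.

```python
import json
import os


def check_split_files(split_dir):
    """Reload data.json/label.json from `split_dir` and verify their IDs match one-to-one.

    Returns the number of questions; raises ValueError on duplicates or mismatches.
    """
    with open(os.path.join(split_dir, "data.json")) as f:
        data = json.load(f)["data"]
    with open(os.path.join(split_dir, "label.json")) as f:
        labels = json.load(f)
    ids = [sample["id"] for sample in data]
    if len(ids) != len(set(ids)):
        raise ValueError(f"Duplicate question IDs in {split_dir}/data.json")
    if set(ids) != set(labels):
        raise ValueError(f"IDs in data.json and label.json differ in {split_dir}")
    return len(ids)


# Hypothetical usage with the directories defined above:
# print(check_split_files(NEW_TRAIN_DIR), check_split_files(NEW_VALID_DIR))
# assert set(new_train_label).isdisjoint(new_valid_label)  # no leakage between the splits
```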

### Preparing the `Dataset` Class

The next step is to create a `Dataset` class that will be used by PyTorch to feed data to our model. This class handles the tokenization of our questions and SQL labels, and prepares the data in a format that our T5 model can consume.

In [14]:
import json
import random
import torch
from torch.utils.data import Dataset


def encode_file(tokenizer, text, max_length, truncation=True, padding=True, return_tensors="pt"):
    """
    Tokenizes the text and returns tensors.
    """
    return tokenizer(
        text,
        max_length=max_length,
        truncation=truncation,
        padding=padding,
        return_tensors=return_tensors,
    )


class T5Dataset(Dataset):
    """
    A dataset class for the T5 model, handling the conversion of natural language questions to SQL queries.
    """
    def __init__(
        self,
        tokenizer,
        data_dir,
        is_test=False,
        max_source_length=256, # natural language question
        max_target_length=512, # SQL
        db_id='mimiciii', # NOTE: `mimic_iv` will be used for codabench
        tables_file=None,
        exclude_unans=False, # exclude unanswerable questions b/c they have no valid sql.
        random_seed=0,
        append_schema_info=False,
    ):

        super().__init__()
        self.tokenizer = tokenizer
        self.db_id = db_id
        self.is_test = is_test # this option does not include target label
        self.random = random.Random(random_seed) # initialized for schema shuffling
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

        # Load data from JSON files
        with open(f'{data_dir}/data.json') as json_file:
            data = json.load(json_file)["data"]

        label = {}
        if not self.is_test:
            with open(f'{data_dir}/label.json') as json_file:
                label = json.load(json_file)

        self.db_json = None
        if tables_file:
            with open(tables_file) as f:
                self.db_json = json.load(f)

        # Process and encode the samples from the loaded data
        ids = []
        questions = []
        labels = []
        for sample in data:

            # id
            if exclude_unans:
                if sample["id"] in label and label[sample["id"]] == "null":
                    continue
            ids.append(sample['id'])

            # question
            question = self.preprocess_sample(sample, append_schema_info)
            questions.append(question)

            # label
            if not self.is_test:
                labels.append(label[sample["id"]])

        self.ids = ids
        question_encoded = encode_file(tokenizer, questions, max_length=self.max_source_length)
        self.source_ids, self.source_mask = question_encoded['input_ids'], question_encoded['attention_mask']
        if not self.is_test:
            label_encoded = encode_file(tokenizer, labels, max_length=self.max_target_length)
            self.target_ids = label_encoded['input_ids']

    def __len__(self):
        return len(self.source_ids)

    def __getitem__(self, index):
        if self.is_test:
            return {
                "id": self.ids[index],
                "source_ids": self.source_ids[index],
                "source_mask": self.source_mask[index]
            }
        else:
            return {
                "id": self.ids[index],
                "source_ids": self.source_ids[index],
                "source_mask": self.source_mask[index],
                "target_ids": self.target_ids[index]
            }

    def preprocess_sample(self, sample, append_schema_info=False):
        """
        Processes a single data sample, adding schema description to the question.
        """
        question = sample["question"]

        if append_schema_info:
            if self.db_json:
                tables_json = [db for db in self.db_json if db["db_id"] == self.db_id][0]
                schema_description = self.get_schema_description(tables_json)
                question += f" {schema_description}"
            return question
        else:
            return question

    def get_schema_description(self, tables_json, shuffle_schema=False):
        """
        Generates a textual description of the database schema.
        """
        table_names = tables_json["table_names_original"]
        # Shuffle table *indices* (not the names list in place) so each column stays
        # attached to its own table and the loaded schema JSON is not mutated.
        table_order = list(range(len(table_names)))
        if shuffle_schema:
            self.random.shuffle(table_order)

        columns = [
            (column_name[0], column_name[1].lower(), column_type.lower())
            for column_name, column_type in zip(tables_json["column_names_original"], tables_json["column_types"])
        ]

        schema_description = [""]
        for table_index in table_order:
            table_name = table_names[table_index]
            table_columns = [column[1] for column in columns if column[0] == table_index]
            if shuffle_schema:
                self.random.shuffle(table_columns)
            column_desc = " , ".join(table_columns)
            schema_description.append(f"{table_name.lower()} : {column_desc}")

        return " | ".join(schema_description)

    def collate_fn(self, batch, return_tensors='pt', padding=True, truncation=True):
        """
        Collate function for the DataLoader.
        """
        ids = [x["id"] for x in batch]
        input_ids = torch.stack([x["source_ids"] for x in batch]) # BS x SL
        masks = torch.stack([x["source_mask"] for x in batch]) # BS x SL
        pad_token_id = self.tokenizer.pad_token_id
        source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks)

        if self.is_test:
            return {
                "source_ids": source_ids,
                "source_mask": source_mask,
                "id": ids,
            }
        else:
            target_ids = torch.stack([x["target_ids"] for x in batch]) # BS x SL
            target_ids = trim_batch(target_ids, pad_token_id)
            return {
                "source_ids": source_ids,
                "source_mask": source_mask,
                "target_ids": target_ids,
                "id": ids,
            }

def trim_batch(input_ids, pad_token_id, attention_mask=None):
    """
    Trims padding from batches of tokenized text.
    """
    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
    if attention_mask is None:
        return input_ids[:, keep_column_mask]
    else:
        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
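To see what `append_schema_info=True` appends to a question, the standalone function below reproduces the serialization in `get_schema_description` on a toy `tables.json` entry. The toy schema is made up. The leading `[-1, "*"]` column is an assumption based on the Spider-style `column_names_original` layout that the class's `column[0] == table_index` filter implies. When shuffling, the function permutes table indices so each column stays with its own table.

```python
import random


def describe_schema(tables_json, rng=None):
    """Serialize a tables.json entry as ' | table : col , col | table : ...'.

    Mirrors T5Dataset.get_schema_description; pass a random.Random as `rng` to shuffle.
    """
    table_names = tables_json["table_names_original"]
    columns = [(table_idx, name.lower()) for table_idx, name in tables_json["column_names_original"]]
    order = list(range(len(table_names)))
    if rng is not None:
        rng.shuffle(order)
    parts = [""]  # the leading empty entry yields a leading ' | ' separator, as in the class
    for table_idx in order:
        cols = [name for idx, name in columns if idx == table_idx]  # skips the [-1, "*"] entry
        if rng is not None:
            rng.shuffle(cols)
        parts.append(f"{table_names[table_idx].lower()} : {' , '.join(cols)}")
    return " | ".join(parts)


# Toy schema entry (illustrative only)
toy_tables_json = {
    "db_id": "mimic_iv",
    "table_names_original": ["patients", "prescriptions"],
    "column_names_original": [[-1, "*"], [0, "subject_id"], [0, "gender"],
                              [1, "subject_id"], [1, "drug"], [1, "route"]],
    "column_types": ["text", "number", "text", "number", "text", "text"],
}
print(describe_schema(toy_tables_json))
# " | patients : subject_id , gender | prescriptions : subject_id , drug , route"
```

`preprocess_sample` appends this string after a space, so the model input reads `<question>  | patients : ...`.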

Below is a simple guide on how to use the `T5Dataset` class:

In [15]:
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained('t5-base')
train_dataset = T5Dataset(
    tokenizer=tokenizer,
    data_dir=NEW_TRAIN_DIR,
    tables_file=TABLES_PATH,
    db_id=DB_ID,  # NOTE: `mimic_iv` will be used for codabench
    append_schema_info=False,
)

sample_idx = 1
decoded_sample_src = tokenizer.decode(train_dataset[sample_idx]['source_ids'])
decoded_sample_trg = tokenizer.decode(train_dataset[sample_idx]['target_ids'])
print('\n')
print(f"source ids: {decoded_sample_src}")
print(f"target ids: {decoded_sample_trg}")

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565




source ids: How is olanzapine (disintegrating tablet) typically consumed?</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
target ids: SELECT DISTINCT prescriptions.route FROM prescriptions WHERE prescriptions.drug = 'olanzapine (disintegrating tablet)'</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><
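The decoded sample above ends in long runs of `<pad>` because `encode_file` calls the tokenizer with `padding=True` on the whole list of questions at once. As a result, every example is padded to the longest sequence in the dataset, up to `max_length`. `collate_fn` then calls `trim_batch`, which drops token positions that are padding in every row of the batch. The function below is a pure-Python analogue using a hypothetical helper and toy token IDs. `0` is T5's pad ID and `1` its `</s>` ID.

```python
def trim_padding_columns(batch, pad_id=0):
    """Drop token positions that are padding in every sequence of the batch.

    A column is kept if any row has a non-pad token there, matching the
    `input_ids.ne(pad_token_id).any(dim=0)` mask used by `trim_batch`.
    """
    keep = [any(row[j] != pad_id for row in batch) for j in range(len(batch[0]))]
    return [[tok for tok, k in zip(row, keep) if k] for row in batch]


batch = [
    [8, 3, 1, 0, 0, 0],  # three real tokens (the last is </s>) + padding
    [5, 1, 0, 0, 0, 0],  # two real tokens + padding
]
print(trim_padding_columns(batch))  # [[8, 3, 1], [5, 1, 0]]
```

Trimming per batch keeps each batch only as wide as its longest member, which saves compute compared with always running on the dataset-wide padded length.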

## Step 4: Construct a Text-to-SQL Baseline Model

In this step, we set up and train a T5 model to translate natural language queries into SQL statements. The process involves several key stages including argument parsing, model initialization, data preparation, and the actual training.
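During training, label positions equal to the pad ID are overwritten with `-100`, the default `ignore_index` of PyTorch's cross-entropy loss, so padding does not contribute to the loss. The sketch below mimics that behavior in plain Python on a toy 3-token vocabulary: masked positions are skipped and the loss is averaged over the remaining tokens. It illustrates the masking idea and is not the actual `T5ForConditionalGeneration` loss code.

```python
import math

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_padding(labels, pad_id=0):
    """Replace pad tokens with IGNORE_INDEX, as done to `labels` in the training loop."""
    return [[IGNORE_INDEX if tok == pad_id else tok for tok in row] for row in labels]


def mean_token_nll(log_probs, labels):
    """Mean negative log-likelihood over label positions that are not IGNORE_INDEX.

    log_probs[b][t][v]: log-probability of vocabulary item v at decoding step t of example b.
    """
    total, count = 0.0, 0
    for example_lp, example_labels in zip(log_probs, labels):
        for step_lp, tok in zip(example_lp, example_labels):
            if tok == IGNORE_INDEX:
                continue  # padded position: contributes nothing to the loss
            total -= step_lp[tok]
            count += 1
    return total / count


labels = mask_padding([[2, 1, 0]])  # toy target: token 2, then </s> (id 1), then one pad
log_probs = [[
    [math.log(0.25), math.log(0.25), math.log(0.5)],  # step 0: p(token 2) = 0.5
    [math.log(0.5), math.log(0.25), math.log(0.25)],  # step 1: p(</s>) = 0.25
    [math.log(0.9), math.log(0.05), math.log(0.05)],  # step 2: ignored (pad)
]]
print(labels)                                       # [[2, 1, -100]]
print(round(mean_token_nll(log_probs, labels), 4))  # 1.0397
```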

In [16]:
import os
import random
import argparse
from tqdm import tqdm
import numpy as np
import pandas as pd
import gc

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
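`set_optim` in the next cell computes the scheduler length as `t_total = (num_examples // (train_batch_size * n_gpu)) * train_epochs // gradient_accumulation_steps`. It uses `get_linear_schedule_with_warmup`, which ramps the learning rate linearly from 0 over `warmup_steps` and then decays it linearly to 0 at `t_total`. The plain-Python sketch below reproduces that arithmetic so you can sanity-check a configuration. The helper names are hypothetical. Assuming all 4,099 examples of the new training split are kept, with batch size 8, one GPU, 100 epochs, and no accumulation, `t_total` comes out to 51,200 steps.

```python
def total_training_steps(num_examples, batch_size, n_gpu, epochs, grad_accum_steps):
    """Same arithmetic as `t_total` in `set_optim` (floor division throughout)."""
    return (num_examples // (batch_size * max(1, n_gpu))) * epochs // grad_accum_steps


def linear_warmup_lr(step, base_lr, warmup_steps, total_steps):
    """Learning rate after `step` scheduler steps, following the linear warmup/decay shape."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


t_total = total_training_steps(4099, batch_size=8, n_gpu=1, epochs=100, grad_accum_steps=1)
print(t_total)                                    # 51200
print(linear_warmup_lr(0, 1e-4, 0, t_total))      # 0.0001
print(linear_warmup_lr(25600, 1e-4, 0, t_total))  # 5e-05
```

Because the floor division counts only full batches, a `DataLoader` that keeps the final partial batch (PyTorch's default, `drop_last=False`) yields 513 batches per epoch here. In that case the last 100 optimizer steps would run at a learning rate of 0.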

In [17]:
def add_default_args(parser):
    """
    Define and set default arguments for the script.
    """
    parser.add_argument("--db_id", type=str, default="mimiciii", help="database name")  # NOTE: `mimic_iv` will be used for codabench
    parser.add_argument("--train_data_dir", type=str, help="train data path")
    parser.add_argument("--valid_data_dir", type=str, help="valid data path")
    parser.add_argument("--test_data_dir", type=str, help="test data path")
    parser.add_argument("--tables_file", type=str, help="table schema path")

    parser.add_argument("--output_dir", type=str, default="outputs", help="output directory")
    parser.add_argument("--output_file", type=str, default="prediction_raw.json", help="output file name")

    # basic parameters
    parser.add_argument("--exp_name", type=str, default=None, help="name of the experiment")
    parser.add_argument("--model_name", type=str, default=None)
    parser.add_argument("--save_checkpoint_path", type=str, default=None)
    parser.add_argument("--load_checkpoint_path", type=str, default=None)

    # training parameters
    parser.add_argument("--train_batch_size", type=int, default=8)
    parser.add_argument("--valid_batch_size", type=int, default=4)
    parser.add_argument("--test_batch_size", type=int, default=4)
    parser.add_argument("--max_source_length", type=int, default=512)
    parser.add_argument("--max_target_length", type=int, default=512)
    parser.add_argument("--train_epochs", type=int, default=100)
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="learning rate")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--warmup_steps", type=int, default=0)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)
    parser.add_argument("--weight_decay", type=float, default=0.1)
    parser.add_argument("--adam_epsilon", type=float, default=1e-8)

    parser.add_argument("--report_every_step", type=int, default=1000)
    parser.add_argument("--eval_every_step", type=int, default=-1000)
    parser.add_argument("--save_every_epoch", type=bool, default=False)
    parser.add_argument("--bf16", type=bool, default=False)
    parser.add_argument("--seed", type=int, default=0)

    # generation parameters
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--num_samples", type=int, default=1)
    return parser


def set_seed(args):
    """
    Ensure reproducibility by setting the seed for random number generation.
    """
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)  # seed the CPU generator even when no GPU is available
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)  # seed every GPU (multi-GPU setups)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def save_model(model, optimizer, scheduler, step, best_metric, args, name="last"):
    """
    Save model checkpoints during or after training.
    """
    os.makedirs(args.save_model_path, exist_ok=True)

    save_file_path = os.path.join(args.save_model_path, f"checkpoint_{name}.pth.tar")
    state_dict = {
        "step": step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
        "args": args,
        "best_metric": best_metric,
    }
    torch.save(state_dict, save_file_path)
    print(f"Model checkpoint '{name}' saved successfully to {save_file_path}.")


def load_model(model, load_model_path, args, reset_optim=False):
    """
    Load a saved model checkpoint.
    """
    checkpoint = torch.load(load_model_path, weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])

    prev_args = checkpoint["args"]
    args = update_args(new_args=args, prev_args=prev_args)

    step = checkpoint["step"]
    best_metric = checkpoint["best_metric"]
    if not reset_optim:
        optimizer, scheduler = set_optim(model, args)
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    else:
        optimizer, scheduler = set_optim(model, args)

    return model, optimizer, scheduler, args, step, best_metric


def update_args(new_args, prev_args):
    """
    Update training arguments with the values saved in the checkpoint.
    """
    for arg in vars(prev_args):
        if arg not in new_args:
            setattr(new_args, arg, getattr(prev_args, arg))
    return new_args


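def set_optim(model, args):
    """
    Initialize the optimizer and learning rate scheduler for the model.
    NOTE: `t_total` is computed from the global `train_loader`, which must exist before this is called.
    """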
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, eps=args.adam_epsilon)
    t_total = (len(train_loader.dataset) // (args.train_batch_size * max(1, args.n_gpu))) * args.train_epochs // args.gradient_accumulation_steps
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
    return optimizer, scheduler


def train(tokenizer, model, train_loader, optimizer, step=0, valid_loader=None, best_metric=-1, scheduler=None, args=None):
    """
    Conduct the training process for a given model.
    """
    train_loss_list = []
    batch_idx = 0

    if best_metric == -1:
        best_metric = np.inf

    # Main training loop
    for epoch in range(1, args.train_epochs + 1):
        model.train()  # Set the model to training mode

        for batch in train_loader:
            # Extract and send batch data to the specified device
            source_ids = batch["source_ids"].to(args.device)
            attention_mask = batch["source_mask"].to(args.device)
            labels = batch["target_ids"].to(args.device)

            # Set padded label ids (T5 pad id = 0) to -100 so they are ignored in the loss calculation
            labels[labels[:,:]==tokenizer.pad_token_id] = -100
            labels = labels.to(args.device)

            # Forward pass and calculate loss
            loss = model(input_ids=source_ids, attention_mask=attention_mask, labels=labels)[0]
            # Normalize loss to account for gradient accumulation
            loss = torch.mean(loss) / args.gradient_accumulation_steps
            loss.backward()

            # Gradient accumulation: update parameters only after every `gradient_accumulation_steps` batches
            if (batch_idx + 1) % args.gradient_accumulation_steps == 0:
                # Clip gradients to avoid exploding gradient problem
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()  # Update model parameters
                if scheduler:
                    scheduler.step()  # Update learning rate
                model.zero_grad()  # Reset gradients
                step += 1

            train_loss_list.append(loss.item())

            # Get the current learning rate from scheduler or optimizer
            lr = scheduler.get_last_lr()[0] if scheduler else optimizer.param_groups[0]["lr"]

            # Log training progress
            if batch_idx % (args.report_every_step * args.gradient_accumulation_steps) == 0:
                log = f"epoch: {epoch} (step: {step}) | "
                log += f"train loss: {sum(train_loss_list)/len(train_loss_list):.6f} | "
                log += f"lr: {lr:.6f}"
                print(log)
                train_loss_list = []

            # Validation step
            if valid_loader and batch_idx % (args.eval_every_step * args.gradient_accumulation_steps) == 0:
                model.eval()  # Set the model to evaluation mode
                valid_loss_list = []
                with torch.no_grad():
                    for i, batch in enumerate(valid_loader):
                        ids = batch["source_ids"].to(args.device)
                        mask = batch["source_mask"].to(args.device)
                        labels = batch["target_ids"].to(args.device)

                        labels[labels[:,:]==tokenizer.pad_token_id] = -100
                        labels = labels.to(args.device)

                        valid_loss = model(input_ids=ids, attention_mask=mask, labels=labels)[0]
                        valid_loss_list.append(valid_loss.item())

                    # Calculate average validation loss
                    valid_loss = sum(valid_loss_list) / len(valid_loss_list)

                    log = f"epoch: {epoch} (step: {step})"
                    log += f" | valid_loss: {valid_loss:.6f}"
                    print(log)

                    if best_metric > valid_loss:
                        best_metric = valid_loss
                        save_model(model, optimizer, scheduler, step, best_metric, args, name="best")

                    model.train()  # Set the model back to training mode

            batch_idx += 1

            # Clear CUDA cache if it's a good time
            if batch_idx % (args.eval_every_step * args.gradient_accumulation_steps) == 0:
                torch.cuda.empty_cache()
                gc.collect()  # Trigger Python garbage collection

        # Save a checkpoint at the end of each epoch if specified in args
        if args.save_every_epoch:
            save_model(model, optimizer, scheduler, epoch, best_metric, args, name=f"{epoch}")


def generate_sql(tokenizer, model, eval_loader, args):
    # Set the model to evaluation mode. This turns off certain layers like dropout.
    model.eval()

    # Disable gradient calculations for efficiency, as they are not needed in evaluation.
    with torch.no_grad():
        out_eval = []

        # Iterate over batches of data in the evaluation dataset.
        for batch in tqdm(eval_loader):
            # Extract relevant data from the batch.
            ids = batch["id"]
            source_ids = batch["source_ids"].to(args.device)
            attention_mask = batch["source_mask"].to(args.device)

            # Generate predictions using the model.
            generation_output = model.generate(
                input_ids=source_ids,
                attention_mask=attention_mask,  # ignore padded source positions
                max_length=args.max_target_length,
                num_beams=args.num_beams,
                return_dict_in_generate=True,
                output_scores=True,
            )

            # Move the generated sequences to the CPU if using CUDA.
            preds = generation_output["sequences"].cpu() if args.device == "cuda" else generation_output["sequences"]

            # Process logits and calculate probabilities and entropies.
            logits = torch.stack(generation_output["scores"], dim=1)[:: int(args.num_beams / args.num_samples)]
            logits = logits.cpu() if args.device == "cuda" else logits
            probs = torch.softmax(logits, dim=2).float()
            log_probs = torch.log_softmax(logits, dim=2).float()
            entropies = (torch.sum(probs * log_probs, axis=2) * (-1)).numpy()

            # Determine whether the batch is unlabeled test data or labeled (e.g., validation) data.
            is_test = True
            if "target_ids" in batch:
                is_test = False
                reals = batch["target_ids"]

            # Initialize lists to store predictions and per-token entropies.
            pred_list = []
            entropy_list = []

            # Process each prediction in the batch.
            for idx in range(len(preds)):
                pred = preds[idx]
                pred_tensor = preds[idx][1:]
                entropy_truncated = entropies[idx].tolist()

                # Truncate the entropy list at the first end-of-sequence token, if present
                # (decoding with skip_special_tokens=True already drops EOS and padding from the text).
                if tokenizer.eos_token_id in pred_tensor:
                    pred_eos_idx = torch.nonzero(pred_tensor == tokenizer.eos_token_id)[0].item()
                    entropy_truncated = entropy_truncated[: pred_eos_idx + 1]

                pred_list.append(pred)
                entropy_list.append(entropy_truncated)

            # Construct the output results for each prediction.
            for idx in range(len(preds)):
                result = {
                    "id": ids[idx],
                    "question": tokenizer.decode(source_ids[idx], skip_special_tokens=True),
                    "pred": tokenizer.decode(pred_list[idx], skip_special_tokens=True),
                    "entropy": entropy_list[idx],
                }

                # Include the gold SQL if the batch is labeled (e.g., validation data).
                if not is_test:
                    result["real"] = tokenizer.decode(reals[idx], skip_special_tokens=True)

                out_eval.append(result)

            # Clear cache after processing each batch
            torch.cuda.empty_cache()
            gc.collect()

        return out_eval
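`generate_sql` scores each predicted token by the entropy of the model's output distribution, $H = -\sum_v p_v \log p_v$, and Step 6 later uses the maximum of these per-token entropies as an uncertainty signal. The NumPy sketch below reproduces that computation outside PyTorch on toy logits. The function names are hypothetical, and `token_ids` is assumed to be aligned with the entropies (i.e., the decoder start token already dropped, as `preds[idx][1:]` does above).

```python
import numpy as np


def token_entropies(logits):
    """Per-step entropy (in nats) of softmax(logits); logits has shape (batch, steps, vocab)."""
    # Numerically stable log-softmax over the vocabulary axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    probs = np.exp(log_probs)
    return -(probs * log_probs).sum(axis=-1)  # shape (batch, steps)


def max_entropy_until_eos(entropies, token_ids, eos_id):
    """Keep entropies up to and including the first EOS token, then return their maximum."""
    token_ids = list(token_ids)
    end = token_ids.index(eos_id) + 1 if eos_id in token_ids else len(token_ids)
    return float(np.max(entropies[:end]))


# Toy batch: one sequence, 3 decoding steps, vocabulary of 4 tokens
toy_logits = np.zeros((1, 3, 4))       # uniform steps -> entropy ln(4)
toy_logits[0, 1] = [10.0, 0, 0, 0]     # step 1: very confident -> entropy near 0
ent = token_entropies(toy_logits)
print(ent.round(4))
print(max_entropy_until_eos(np.array([0.2, 0.9, 1.5]), [5, 1, 7], eos_id=1))  # 0.9: the 1.5 after EOS is ignored
```

Truncating at EOS mirrors `entropy_truncated[: pred_eos_idx + 1]` above, so the steps generated after a shorter sequence has already finished do not inflate its maximum entropy.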

### Argument Parsing

First, we define and parse the necessary command-line arguments. These arguments configure the model, specify the paths for training, validation, and test data, and set various training parameters.

In [18]:
# Define and parse command line arguments for model configuration
ARGS_STR = f"""
--exp_name=t5-baseline \
--model_name=t5-base \
--train_data_dir={NEW_TRAIN_DIR} \
--valid_data_dir={NEW_VALID_DIR} \
--test_data_dir={NEW_TEST_DIR} \
--tables_file={TABLES_PATH} \
--train_epochs=10 \
--train_batch_size=4 \
--gradient_accumulation_steps=1 \
--learning_rate=1e-3 \
--report_every_step=10 \
--eval_every_step=10 \
--bf16=1
"""

# Parse arguments
parser = argparse.ArgumentParser()
parser = add_default_args(parser)
args = parser.parse_args(ARGS_STR.split())

# Configure CUDA settings
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Set random seed for reproducibility
set_seed(args)

# Determine device for training and set model save path
args.device = "cuda" if torch.cuda.is_available() else "cpu"
args.n_gpu = torch.cuda.device_count()
args.save_model_path = os.path.join(args.output_dir, args.exp_name)
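`ARGS_STR` is turned into an argv-style list before parsing. Inside the (non-raw) triple-quoted string, each backslash at the end of a line is a line continuation, so the options end up on a single line separated by spaces, and `str.split()` breaks them apart. The sketch below shows this with a small hypothetical parser (`demo_parser` and its defaults are made up); the real options and defaults come from `add_default_args`, which is not shown in this section.

```python
import argparse

# Hypothetical subset of the notebook's options
DEMO_ARGS_STR = """
--exp_name=t5-baseline \
--train_batch_size=4 \
--learning_rate=1e-3 \
--bf16=1
"""

demo_parser = argparse.ArgumentParser()
demo_parser.add_argument("--exp_name", type=str, default="debug")
demo_parser.add_argument("--train_batch_size", type=int, default=8)
demo_parser.add_argument("--learning_rate", type=float, default=1e-4)
demo_parser.add_argument("--bf16", type=int, default=0)

# split() with no argument splits on any whitespace, yielding an argv-style list
demo_argv = DEMO_ARGS_STR.split()
demo_args = demo_parser.parse_args(demo_argv)
print(demo_argv)
print(demo_args.train_batch_size, demo_args.learning_rate, demo_args.bf16)
```

Because `split()` splits on any whitespace, a path containing spaces would be broken into several tokens; `shlex.split` from the standard library honors shell-style quoting if that becomes an issue.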

### Model Initialization

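Here, we initialize the T5 model and tokenizer. The model is moved to the selected device (and optionally cast to bfloat16), and the tokenizer is extended with the SQL comparison operators `<`, `<=`, and `<>`. T5's SentencePiece vocabulary has no piece for `<`, so without these additions such operators would be encoded as unknown tokens and could not be reproduced in generated SQL. Because adding tokens changes the vocabulary size, the model's token embeddings are then resized to `len(tokenizer)`.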

In [23]:
# Initialize T5 model and set device
model = T5ForConditionalGeneration.from_pretrained(args.model_name)
model = model.to(args.device)

# Convert model to bfloat16 precision if required
if args.bf16:
    print("bfloat16 precision will be used")
    model = model.to(torch.bfloat16)

# Initialize tokenizer with additional SQL tokens
add_tokens = ["<", "<=", "<>"]
tokenizer = T5Tokenizer.from_pretrained(args.model_name)
tokenizer.add_tokens(add_tokens)

# Resize model token embeddings
model.resize_token_embeddings(len(tokenizer))

bfloat16 precision will be used


Embedding(32103, 768)
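`resize_token_embeddings(len(tokenizer))` sets the number of rows in the embedding matrix to the tokenizer's size. For `t5-base` this actually shrinks the matrix: the checkpoint's config uses `vocab_size=32128`, while `len(tokenizer)` is 32100 before `add_tokens` and 32103 after, as the printed `Embedding(32103, 768)` shows. The three new ids therefore reuse rows that already existed. The NumPy sketch below uses a hypothetical helper (not the Transformers implementation) to illustrate both directions: trailing rows are dropped when shrinking, and new rows are appended when growing. How Transformers initializes appended rows may differ between versions.

```python
import numpy as np


def resize_embeddings(weight, new_vocab_size, seed=0):
    """Return a (new_vocab_size, hidden) copy of `weight`, truncating or appending rows as needed."""
    old_vocab_size, hidden = weight.shape
    if new_vocab_size <= old_vocab_size:
        # Shrinking: keep the first rows, whose ids are still valid in the tokenizer
        return weight[:new_vocab_size].copy()
    # Growing: append small random rows (an assumption; real libraries may initialize differently)
    rng = np.random.default_rng(seed)
    extra = rng.normal(0.0, 0.02, size=(new_vocab_size - old_vocab_size, hidden))
    return np.vstack([weight, extra])


# Toy analogue of t5-base: 8 embedding rows, but the tokenizer now has only 6 ids
weight = np.arange(8 * 4, dtype=float).reshape(8, 4)
shrunk = resize_embeddings(weight, 6)
grown = resize_embeddings(weight, 10)
print(shrunk.shape, grown.shape)
```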

### Data Preparation

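We prepare the datasets for training, validation, and testing. This involves loading the data from the specified paths and converting it into a format compatible with the T5 model. Unanswerable questions are excluded from the training set (`exclude_unans=True`), so the model learns only to generate SQL; they are kept in the validation set so that the abstention threshold in Step 6 can be tuned on them. The test set is loaded with `is_test=True` because it has no gold SQL.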

In [24]:
# Define parameters for dataset preparation
dataset_kwargs = dict(
    db_id=args.db_id,
    max_source_length=args.max_source_length,
    max_target_length=args.max_target_length,
    tables_file=args.tables_file,
)

# Initialize datasets for different phases
train_dataset = T5Dataset(tokenizer, args.train_data_dir, is_test=False, exclude_unans=True, **dataset_kwargs)
valid_dataset = T5Dataset(tokenizer, args.valid_data_dir, is_test=False, exclude_unans=False, **dataset_kwargs)
test_dataset = T5Dataset(tokenizer, args.test_data_dir, is_test=True, exclude_unans=False, **dataset_kwargs)

# Create DataLoader instances for batch processing
train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size, collate_fn=train_dataset.collate_fn, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=args.valid_batch_size, collate_fn=valid_dataset.collate_fn, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size, collate_fn=test_dataset.collate_fn, shuffle=False)
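Each `collate_fn` turns a list of samples into padded arrays; `generate_sql` reads the keys `id`, `source_ids`, `source_mask`, and (for labeled data) `target_ids` from every batch. `T5Dataset.collate_fn` is defined earlier in the notebook and not shown here, so the sketch below is a hypothetical stand-in: it uses NumPy arrays instead of PyTorch tensors and assumes each sample carries already-tokenized `source` and optional `target` id lists. T5's padding token id is 0.

```python
import numpy as np

PAD_ID = 0  # T5's pad token id


def pad_batch(sequences, pad_id=PAD_ID):
    """Right-pad variable-length id lists; return (ids, attention_mask) arrays."""
    max_len = max(len(seq) for seq in sequences)
    ids = np.full((len(sequences), max_len), pad_id, dtype=np.int64)
    mask = np.zeros((len(sequences), max_len), dtype=np.int64)
    for row, seq in enumerate(sequences):
        ids[row, : len(seq)] = seq
        mask[row, : len(seq)] = 1  # 1 = real token, 0 = padding
    return ids, mask


def collate_sketch(samples):
    """Hypothetical collate function producing the batch keys that generate_sql expects."""
    batch = {"id": [s["id"] for s in samples]}
    batch["source_ids"], batch["source_mask"] = pad_batch([s["source"] for s in samples])
    if all("target" in s for s in samples):  # labeled data only
        batch["target_ids"], _ = pad_batch([s["target"] for s in samples])
    return batch


b = collate_sketch([
    {"id": "a", "source": [5, 6, 1], "target": [7, 1]},
    {"id": "b", "source": [8, 1], "target": [9, 9, 1]},
])
print(b["source_ids"], b["source_mask"], sep="\n")
```

A training collate function may additionally replace target padding with -100, the default `ignore_index` of PyTorch's cross-entropy loss, so that padded positions do not contribute to the loss; whether `T5Dataset` does this is not visible here.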

### Optimizer and Scheduler

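Next, we either resume from an existing checkpoint (when `--load_checkpoint_path` is set), which restores the model, optimizer, scheduler, step counter, and best validation loss, or create a fresh optimizer and learning rate scheduler with `set_optim`.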

In [25]:
# Load existing model or initialize optimizer and scheduler
if args.load_checkpoint_path:
    model, optimizer, scheduler, args, step, best_metric = load_model(model, args.load_checkpoint_path, args)
else:
    # best_metric tracks the lowest validation loss so far; any finite loss improves on +inf
    step, best_metric = 0, float("inf")
    optimizer, scheduler = set_optim(model, args)
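`set_optim` is defined earlier in the notebook, and its schedule is not shown in this section. The training log reports a learning rate that falls slowly from 0.001000 (step 1) to 0.000996 (step 41), which is consistent with, for example, a linear decay over many thousands of steps. The pure-Python sketch below assumes such a linear schedule with optional warmup (similar in shape to Transformers' `get_linear_schedule_with_warmup`); `linear_schedule_lr`, `total_steps`, and `warmup_steps` are hypothetical.

```python
def linear_schedule_lr(step, base_lr, total_steps, warmup_steps=0):
    """Learning rate after `step` optimizer updates: linear warmup, then linear decay to 0."""
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr * step / warmup_steps
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


# Hypothetical run: base LR 1e-3 (as in --learning_rate), 10,000 optimizer steps, no warmup
for demo_step in (0, 10, 40, 5000, 10000):
    print(demo_step, f"{linear_schedule_lr(demo_step, 1e-3, 10_000):.6f}")
```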

### Training the Model

Finally, we train the model on the dataset. The model learns to generate SQL queries from natural-language questions through iterative forward and backward passes, loss computation, and parameter updates. The validation loss is checked every `eval_every_step` steps, and a `best` checkpoint is saved whenever it improves.

In [26]:
# Start the training process
train(
    tokenizer=tokenizer,
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    step=step,
    best_metric=best_metric,
    args=args,
)

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


epoch: 1 (step: 1) | train loss: 13.437500 | lr: 0.001000
epoch: 1 (step: 1) | valid_loss: 4.786539
Model checkpoint 'best' saved successfully to outputs/t5-baseline/checkpoint_best.pth.tar.
epoch: 1 (step: 11) | train loss: 3.004688 | lr: 0.000999
epoch: 1 (step: 11) | valid_loss: 2.121398
Model checkpoint 'best' saved successfully to outputs/t5-baseline/checkpoint_best.pth.tar.
epoch: 1 (step: 21) | train loss: 1.575000 | lr: 0.000998
epoch: 1 (step: 21) | valid_loss: 1.594746
Model checkpoint 'best' saved successfully to outputs/t5-baseline/checkpoint_best.pth.tar.
epoch: 1 (step: 31) | train loss: 1.060547 | lr: 0.000997
epoch: 1 (step: 31) | valid_loss: 1.409940
Model checkpoint 'best' saved successfully to outputs/t5-baseline/checkpoint_best.pth.tar.
epoch: 1 (step: 41) | train loss: 0.858594 | lr: 0.000996
epoch: 1 (step: 41) | valid_loss: 1.262920
Model checkpoint 'best' saved successfully to outputs/t5-baseline/checkpoint_best.pth.tar.
epoch: 1 (step: 51) | train loss: 0.74257
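The checkpointing rule inside `train` (`if best_metric > valid_loss`) keeps the checkpoint with the lowest validation loss seen so far, so the running best has to start at a value that any real loss beats, such as `float("inf")`. The small sketch below (the function name is hypothetical) replays that rule on the five validation losses printed in the log above; each one improves on the previous, which matches the five `Model checkpoint 'best' saved` lines.

```python
def best_checkpoint_evals(valid_losses):
    """Indices of the evaluations at which a new 'best' checkpoint would be saved (lower loss is better)."""
    best_metric = float("inf")  # any finite loss improves on +inf, so the first evaluation always saves
    saved = []
    for i, valid_loss in enumerate(valid_losses):
        if best_metric > valid_loss:  # same comparison as in train()
            best_metric = valid_loss
            saved.append(i)
    return saved


# Validation losses printed at steps 1, 11, 21, 31, and 41 in the log above
logged = [4.786539, 2.121398, 1.594746, 1.409940, 1.262920]
print(best_checkpoint_evals(logged))           # [0, 1, 2, 3, 4]: every evaluation improved
print(best_checkpoint_evals([1.0, 1.2, 0.9]))  # [0, 2]: the worse second evaluation is skipped
```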

## Step 5: Initial Model Evaluation on All Queries

In this step, we evaluate the model's performance across all queries, using the Reliability Score (RS) as our evaluation metric. This provides a baseline reliability score without filtering out unanswerable queries: the model answers every question, including those it should abstain on.

In [27]:
from scoring_program.scoring_utils import execute_all, reliability_score, penalize
from scoring_program.postprocessing import post_process_sql

# Load the best-performing model checkpoint
model, optimizer, scheduler, args, step, best_metric = load_model(
    model,
    os.path.join(args.save_model_path, 'checkpoint_best.pth.tar'),
    args,
)

# Perform inference on the validation set
valid_eval = generate_sql(tokenizer, model, valid_loader, args)

# Post-process SQL queries for evaluation
label = {sample['id']: post_process_sql(sample['real']) for sample in valid_eval}
label_y = {sample['id']: post_process_sql(sample['pred']) for sample in valid_eval}
id2maxent = {sample['id']: max(sample['entropy']) for sample in valid_eval}  # Max token entropy per question; used for abstention in Step 6, not here

# Calculate the Reliability Score (RS) across all queries
real_dict = {id_: post_process_sql(label[id_]) for id_ in label}
pred_dict = {id_: post_process_sql(label_y[id_]) for id_ in label_y}
assert set(real_dict) == set(pred_dict), "IDs do not match"

real_result = execute_all(real_dict, db_path=DB_PATH, tag='real')
pred_result = execute_all(pred_dict, db_path=DB_PATH, tag='pred')

scores, score_dict = reliability_score(real_result, pred_result, return_dict=True)
accuracy0 = penalize(scores, penalty=0)
accuracy5 = penalize(scores, penalty=5)
accuracy10 = penalize(scores, penalty=10)
accuracyN = penalize(scores, penalty=len(scores))

print(f"RS without filtering unanswerable queries: Accuracy0: {accuracy0}, Accuracy5: {accuracy5}, Accuracy10: {accuracy10}, AccuracyN: {accuracyN}")

100%|██████████| 257/257 [24:19<00:00,  5.68s/it]


RS without filtering unanswerable queries: Accuracy0: 0.8097560975609757, Accuracy5: -0.14146341463414633, Accuracy10: -1.0926829268292684, AccuracyN: -194.19024390243902
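`reliability_score` and `penalize` come from the competition's `scoring_program`, whose code is not shown here. The printed values are consistent with the following reading: each question scores 1 (correct SQL, or abstention on an unanswerable question), 0 (abstention on an answerable question), or -1 (wrong SQL, or an answer to an unanswerable question), and `penalize(scores, penalty=c)` averages the scores after multiplying each -1 by c. Under that assumption, 830 scores of 1 and 195 scores of -1 out of 1,025 validation questions reproduce all four numbers above. `penalize_sketch` is a hypothetical reconstruction for illustration, not the official scorer.

```python
def penalize_sketch(scores, penalty):
    """Mean per-question score after scaling every negative score by `penalty` (hypothetical)."""
    return sum(s * penalty if s < 0 else s for s in scores) / len(scores)


# 830 correct and 195 incorrect answers, no abstentions (as in Step 5)
toy_scores = [1] * 830 + [-1] * 195
for c in (0, 5, 10, len(toy_scores)):
    print(f"penalty={c}: {penalize_sketch(toy_scores, c):.6f}")
```

The example also shows why AccuracyN is so harsh: with a penalty equal to the number of questions, a single wrong answer cancels out every correct one, so abstaining on uncertain questions matters far more than raw accuracy.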


## Step 6: Model Evaluation Considering Unanswerable Questions

This step refines the evaluation process by considering unanswerable questions ($Q_{una}$). We apply a threshold based on maximum entropy to filter out uncertain predictions. The RS is then recalculated to assess the model's performance more accurately in scenarios where abstaining from answering difficult questions is preferable. Because the threshold is tuned on the same validation set that is then scored, the filtered RS is an optimistic estimate of performance on unseen data.

Compared with Step 5, this step shows how abstaining on uncertain predictions changes the RS, highlighting the importance of reliability in text-to-SQL modeling when dealing with complex or uncertain queries.

In [28]:
def get_threshold(id2maxent, score_dict):
    """
    Determine the max-entropy threshold that maximizes the cumulative score of the kept predictions.
    Predictions are sorted by max entropy (most confident first); keeping the first idx+1 of them
    corresponds to abstaining whenever maxent > sorted_values[idx].
    """
    ids = list(id2maxent)
    values = np.array([id2maxent[key] for key in ids])
    scores = np.array([score_dict[key] for key in ids])

    sorted_indices = np.argsort(values)
    sorted_values = values[sorted_indices]
    cum_scores = np.cumsum(scores[sorted_indices])  # cum_scores[idx] = score of keeping the first idx+1

    max_score, threshold = 0, -1  # threshold -1 means abstain on everything
    for idx, cum_score in enumerate(cum_scores):
        if cum_score > max_score:
            max_score, threshold = cum_score, sorted_values[idx]

    return threshold  # We abstain if maxent is greater than this threshold.
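The threshold search sorts questions from most to least confident (ascending max entropy) and picks the cut-off that maximizes the summed score of the questions that are kept; every question above the cut-off is abstained on and does not count toward that sum. The vectorized NumPy sketch below applies the same idea to five made-up questions; `choose_threshold` and the toy values are hypothetical, and ties in max entropy are not handled specially.

```python
import numpy as np


def choose_threshold(max_entropies, scores):
    """Return the max-entropy cut-off that maximizes the cumulative score of kept predictions.

    Returns -1.0 (abstain on everything) when no cut-off yields a positive total.
    """
    values = np.asarray(max_entropies, dtype=float)
    order = np.argsort(values, kind="stable")             # most confident first
    cum_scores = np.cumsum(np.asarray(scores, dtype=float)[order])
    best_idx = int(np.argmax(cum_scores))                  # first position of the maximum
    if cum_scores[best_idx] <= 0:
        return -1.0
    return float(values[order][best_idx])


toy_maxent = [0.1, 0.5, 0.3, 0.9, 0.7]
toy_item_scores = [1, -1, 1, -1, 1]  # 1 = correct, -1 = incorrect
threshold_toy = choose_threshold(toy_maxent, toy_item_scores)
kept = [e <= threshold_toy for e in toy_maxent]
print(threshold_toy, kept)  # 0.3 [True, False, True, False, False]
```

Taking the earliest maximum favors the smaller threshold, i.e., abstaining more often when two cut-offs tie; the loop in `get_threshold` behaves the same way because it only updates on a strict improvement.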

In [29]:
# Calculate threshold for filtering unanswerable queries
threshold = get_threshold(id2maxent, score_dict)
print(f"Threshold for filtering: {threshold}")

# Apply threshold to filter out uncertain predictions
label_y = {sample['id']: 'null' if threshold < max(sample['entropy']) else post_process_sql(sample['pred']) for sample in valid_eval}

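# Recalculate RS with filtered predictions
real_dict = {id_: post_process_sql(label[id_]) for id_ in label}
pred_dict = {id_: post_process_sql(label_y[id_]) for id_ in label_y}

# reliability_score compares execution results (as in Step 5), so execute the filtered predictions first.
# The gold queries are unchanged, so real_result from Step 5 is reused.
pred_result = execute_all(pred_dict, db_path=DB_PATH, tag='pred')
scores_filtered = reliability_score(real_result, pred_result)

accuracy0_filtered = penalize(scores_filtered, penalty=0)
accuracy5_filtered = penalize(scores_filtered, penalty=5)
accuracy10_filtered = penalize(scores_filtered, penalty=10)
accuracyN_filtered = penalize(scores_filtered, penalty=len(scores_filtered))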

# Output the refined RS scores with abstention
print(f"RS with filtered unanswerable queries: Accuracy0: {accuracy0_filtered}, Accuracy5: {accuracy5_filtered}, Accuracy10: {accuracy10_filtered}, AccuracyN: {accuracyN_filtered}")


## Step 7: Test data inference

Now, we conduct inference using the original validation set as our test data. Our model generates SQL predictions, applying the previously defined entropy threshold to filter out uncertain responses. Predictions with high uncertainty are marked `null`, indicating the model's strategic abstention from potentially incorrect outputs.

In [30]:
# Conduct inference on the test set (For now, we use original validation set as test data)
test_eval = generate_sql(tokenizer, model, test_loader, args)

# Apply the threshold to uncertain predictions
label_y = {sample['id']: 'null' if threshold < max(sample['entropy']) else post_process_sql(sample['pred']) for sample in test_eval}

100%|██████████| 291/291 [28:09<00:00,  5.80s/it]


We save these predictions to a JSON file in the results directory, creating the directory if necessary. Listing the directory then confirms that `prediction.json` was written.

In [31]:
# Workaround for Colab runtimes where `!` shell commands fail with "A UTF-8 locale is required"
import locale; locale.getpreferredencoding = lambda: "UTF-8"  # if necessary

In [32]:
from utils.data_io import write_json as write_label

# Save the filtered predictions to a JSON file
os.makedirs(RESULT_DIR, exist_ok=True)
SCORING_OUTPUT_PATH = os.path.join(RESULT_DIR, 'prediction.json')
write_label(SCORING_OUTPUT_PATH, label_y)

# Verify the file creation
print("Listing files in RESULT_DIR:")
!ls {RESULT_DIR}

Listing files in RESULT_DIR:
prediction.json
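Listing the directory only shows that the file exists. A slightly stronger check, sketched below with the standard `json` module, re-reads the file and counts abstentions. It assumes that `write_json` (from the repository's `utils.data_io`, not shown here) writes a plain JSON object mapping each question id to a SQL string or `"null"`; `save_and_check_predictions` is a hypothetical helper that uses `json.dump` in its place.

```python
import json
import os
import tempfile


def save_and_check_predictions(predictions, result_dir):
    """Write an {id: SQL or 'null'} mapping to result_dir/prediction.json, re-read it, and summarize it."""
    os.makedirs(result_dir, exist_ok=True)
    path = os.path.join(result_dir, "prediction.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(predictions, f)
    with open(path, encoding="utf-8") as f:
        loaded = json.load(f)
    if loaded != predictions:
        raise ValueError("prediction.json does not round-trip")
    abstained = sum(1 for sql in loaded.values() if sql == "null")
    return {"path": path, "total": len(loaded), "abstained": abstained}


# Toy predictions: one SQL answer and one abstention
with tempfile.TemporaryDirectory() as tmp_dir:
    summary = save_and_check_predictions({"q1": "select 1", "q2": "null"}, tmp_dir)
print(summary["total"], summary["abstained"])
```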


## Step 8: Submission

In this final step, we'll prepare and submit our results to the Codabench competition.

In [None]:
# Change to the directory containing the prediction file
%cd {RESULT_DIR}

# Compress the prediction.json file into a ZIP archive
!zip predictions.zip prediction.json

/content/ehrsql-2024/sample_result_submission
  adding: prediction.json (deflated 54%)


- Submission File: Ensure that the `predictions.zip` file contains only the `prediction.json` file. This ZIP archive is the required format for submission to Codabench.

- Submitting on Codabench: Navigate to the Codabench competition page and go to the **My Submissions** tab. Upload the `predictions.zip` file following the provided instructions. Make sure to adhere to any guidelines or submission requirements detailed on the competition page.
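The same packaging can be done from Python with the standard `zipfile` module, which also makes it easy to confirm the archive's contents before uploading. One practical difference: `zip predictions.zip prediction.json` updates an existing archive in place and keeps any other members it already holds, whereas opening `ZipFile` in `"w"` mode always starts a fresh archive. `make_submission_zip` below is a hypothetical helper.

```python
import json
import os
import tempfile
import zipfile


def make_submission_zip(prediction_path, zip_path):
    """Create zip_path containing only prediction.json at the archive root, and verify it."""
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        # arcname drops the directory part so the file sits at the root of the archive
        zf.write(prediction_path, arcname=os.path.basename(prediction_path))
    with zipfile.ZipFile(zip_path) as zf:
        names = zf.namelist()
    if names != ["prediction.json"]:
        raise ValueError(f"unexpected archive contents: {names}")
    return names


# Toy run in a temporary directory
with tempfile.TemporaryDirectory() as tmp_dir:
    pred_path = os.path.join(tmp_dir, "prediction.json")
    with open(pred_path, "w", encoding="utf-8") as f:
        json.dump({"q1": "null"}, f)
    members = make_submission_zip(pred_path, os.path.join(tmp_dir, "predictions.zip"))
print(members)
```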