# **Project Links**

### [Project GitHub Repository](https://github.com/maxwelllwang/cs598EHRTransformer/tree/main)

### [Project Video](https://drive.google.com/file/d/1v-igXoba1TicqjT4ozxfhXb0kaP0A9W2/view?usp=sharing)

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
%cd /content/gdrive/MyDrive/CS 598 Project/cs598ehrupload
# %cd /content/gdrive/MyDrive/cs598ehrupload
! ls

# change these to the path of your data
diagnoses_file_path = r'/content/gdrive/MyDrive/CS 598 Project/cs598ehrupload/mimic_data/diagnoses_icd.csv.gz'
map_file_path = r'/content/gdrive/MyDrive/CS 598 Project/cs598ehrupload/mimic_data/D_ICD_DIAGNOSES.csv'
icd9_embedding_txt = r"/content/gdrive/MyDrive/CS 598 Project/cs598ehrupload/embeddings/ic9_embeddings.txt"

# **Introduction**

## *Background:*
Current pretraining objectives in predictive EHR-based models are limited to predicting a fraction of the ICD codes within a patient’s visit, when in reality patients usually have multiple, often highly correlated diseases. In addition, current models cannot accurately predict the timeline of correlated diagnoses, which can lead to missed opportunities in preventative care. Predictive tasks on healthcare data are challenging because the data is high-dimensional, longitudinal, and often incomplete. State-of-the-art methods for similar healthcare problems are transformer-based deep learning models trained on extensive datasets and fine-tuned for specific tasks. Despite their success, these models are often fine-tuned to predict a limited set of outcomes, thus overlooking the interconnected nature of various health conditions.

## *Original Paper: TransformEHR*
The paper presents "TransformEHR," a novel generative encoder-decoder model that leverages the transformer architecture and is specifically designed to predict future patient outcomes from longitudinal EHRs. The authors used techniques such as visit masking and time embedding to achieve results that outperform other state-of-the-art models. For example, when testing their encoder-decoder model against an encoder-only model, the authors achieved an “improvement of 95%CI: 0.74%–1.16%, p < 0.001 in AUROC across all diseases/outcomes tested.” While the paper reported strong results on a variety of both common and uncommon diseases, the authors noted that their work was related to predictive model studies focused on intentional self-harm. TransformEHR performed exceptionally well within this subspace and was estimated to reduce the incremental cost-effectiveness ratio by $109k per quality-adjusted life-year.


# **Scope of Reproducibility**

While trying to recreate the TransformEHR model as described in the original paper, we found that much of the code was missing from the project repo, specifically in the helper module `dataset.py`. This file was meant to contain the DataCollator functions and tokenizers used by the model. Because the model requires its input in a specific format, not having these functions, and not knowing how they were implemented or what input format they expected, made it nearly impossible to get the model running.

As a result, we decided to base our model on BEHRT, another transformer model used to predict EHR codes. We also integrated it with pretrained ICD code embeddings.


## *BEHRT Overview*

BEHRT is a transformer model inspired by BERT. Like TransformEHR, it aims to predict a patient's future diseases/outcomes from the EHR data of their past visits. The model treats each diagnosis as a "word," each visit as a "sentence," and the entire medical history as a "document," and it uses multi-head self-attention and masked language modeling. BEHRT combines disease embeddings, positional encodings, age, and visit segment information, and uses a deep bidirectional representation to make predictions about a patient's medical journey.
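
To make the word/sentence/document analogy concrete, here is a minimal sketch of how a patient history could be flattened into a single token sequence. The toy history and the helper name `flatten_history` are hypothetical and only for illustration; the `[CLS]` and `[SEP]` markers follow the special tokens used later in this notebook.

```python
# Hypothetical toy history: two visits, each a list of truncated ICD-9 codes.
history = [["401", "250"], ["428"]]


def flatten_history(visits, cls="[CLS]", sep="[SEP]"):
    """Flatten visits into one sequence: [CLS] visit1 [SEP] visit2 [SEP] ..."""
    tokens = [cls]
    for visit in visits:
        tokens.extend(visit)   # each diagnosis code is one "word"
        tokens.append(sep)     # [SEP] closes the visit, i.e. the "sentence"
    return tokens


print(flatten_history(history))
# ['[CLS]', '401', '250', '[SEP]', '428', '[SEP]']
```

The whole list plays the role of the "document" that the transformer attends over.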

# **Methodology**

## *Environment*

Python version: 3.10

### Dependencies/Packages:

In [None]:
!pip install pytorch-pretrained-bert

# Standard library
import json
import math
import os
import random
import time

# Third-party
import numpy as np
import pandas as pd
import sklearn.metrics as skm
import torch
import torch.nn as nn
import transformers
import pytorch_pretrained_bert as Bert
from torch.utils.data import DataLoader, Dataset

## *Data*

The data we use for this model comes from the **[MIMIC-III dataset](https://physionet.org/content/mimiciii/1.4/)**, which is available through PhysioNet to anyone who has completed the required training and been granted access. Specifically, we use the files *diagnoses_icd.csv.gz* and *D_ICD_DIAGNOSES.csv*.


Please note that the filepaths for variables `diagnoses_file_path` (which will read from *diagnoses_icd.csv.gz*) and `map_file_path` (which will read from *D_ICD_DIAGNOSES.csv*) may need to be updated depending on where they are located locally.


### Preprocessing

In [None]:
# Load the data
# Alternative locations (uncomment and adjust as needed):
# diagnoses_file_path = r'/content/gdrive/MyDrive/cs598ehrupload/mimic_data/diagnoses_icd.csv.gz'
# map_file_path = r'/content/gdrive/MyDrive/cs598ehrupload/mimic_data/D_ICD_DIAGNOSES.csv'

# Sample data:
# diagnoses_file_path = r'/content/gdrive/MyDrive/CS 598 Project/cs598ehrupload/sample_data/DIAGNOSES_ICD.csv'
# map_file_path = r'/content/gdrive/MyDrive/CS 598 Project/cs598ehrupload/sample_data/D_ICD_DIAGNOSES.csv'

diagnoses_df = pd.read_csv(diagnoses_file_path)
print(diagnoses_df.columns)
map_df = pd.read_csv(map_file_path)

icd_code_col_name = 'icd_code'   # NOTE: column name 'icd9_code' in DIAGNOSES_ICD.csv from MIMIC demo dataset is called 'icd_code' in full dataset


# List of patient IDs that have at least one diagnosis.
# Sorted so that index i lines up with the groupby results below (groupby sorts by subject_id).
patient_ids = sorted(diagnoses_df['subject_id'].unique().tolist())

# 2D list where each nested list holds the hadm_id of each visit for one patient.
# Sorted so that it lines up with patient_visits, whose groupby sorts by hadm_id.
visits = diagnoses_df.groupby('subject_id')['hadm_id'].apply(lambda x: sorted(set(x))).tolist()

#3d array contains a list of visits with respective ICD9 code per visit
# NOTE: column name 'icd9_code' in DIAGNOSES_ICD.csv from MIMIC demo dataset is called 'icd_code' in full dataset
patient_visits = (
    diagnoses_df.groupby(['subject_id', 'hadm_id'])[icd_code_col_name].apply(list).groupby(level=0).apply(list).tolist()
)

#dict of {icd9_code : short_title}
#not all icd9_codes which are present in DIAGNOSES_ICD.csv are present in D_ICD_DIAGNOSES.csv, so not all codes will have a title
icd9_to_title = pd.Series(map_df['short_title'].values, index=map_df['icd9_code']).to_dict()

print("Patient ID:", patient_ids[53])
print("num of visits for patient: " , len(visits[53]))
for visit in range(len(visits[53])):
    print(f"\t{visit}-th visit id:", visits[53][visit])
    print(f"\t{visit}-th visit diagnosis codes:", patient_visits[53][visit])
    print(f"\t{visit}-th visit diagnosis short titles:",
          [icd9_to_title.get(label, label) for label in patient_visits[53][visit]])


The following code outputs some descriptive stats on the complete dataset to determine how many patients have ICD-9 vs ICD-10 codes as part of their visit. This is not applicable to the demo dataset because that only contains ICD-9 codes.

In [None]:
#Descriptive Statistics

#Total rows with icd 9/10
count_icd_version_10 = (diagnoses_df['icd_version'] == 10).sum()
count_icd_version_9 = (diagnoses_df['icd_version'] == 9).sum()

print("Number of rows with icd_version = 10:", count_icd_version_10)
print("Number of rows with icd_version = 9:", count_icd_version_9)

#Num of unique ICD 9/10 codes
unique_icd9_codes = diagnoses_df[diagnoses_df['icd_version'] == 9]['icd_code'].nunique()
unique_icd10_codes = diagnoses_df[diagnoses_df['icd_version'] == 10]['icd_code'].nunique()

print("Number of unique ICD-9 codes:", unique_icd9_codes)
print("Number of unique ICD-10 codes:", unique_icd10_codes)

# Number of patients with at least one ICD-9 code
icd9_df = diagnoses_df[diagnoses_df['icd_version'] == 9]
unique_patients_with_icd9 = icd9_df['subject_id'].unique()
num_patients_with_icd9 = len(unique_patients_with_icd9)
print("Number of patients with at least one ICD-9 code:", num_patients_with_icd9)

#num patients with both ICD 9 / 10 codes
grouped = diagnoses_df.groupby('subject_id')['icd_version'].agg(set)
patients_with_both = grouped[grouped.apply(lambda x: {9, 10}.issubset(x))]
print("Number of patients with both ICD-9 and ICD-10 codes:", len(patients_with_both))

# Number of patients with ONLY ICD-9 codes
patient_versions = diagnoses_df.groupby('subject_id')['icd_version'].unique()
patients_with_only_icd9 = patient_versions[patient_versions.apply(lambda x: set(x) == {9})]
num_patients_only_icd9 = len(patients_with_only_icd9)
print("Number of patients with only ICD-9 codes:", num_patients_only_icd9)

We only want to consider patients with 4 or more visits so that we can split their visits into label and feature data later on.

In [None]:
# Total num of usable patients
num_unique_patients = diagnoses_df['subject_id'].nunique()
print("Number of unique patients:", num_unique_patients)


patient_versions = diagnoses_df.groupby('subject_id')['icd_version'].unique()
patients_with_only_icd9 = patient_versions[patient_versions.apply(lambda x: set(x) == {9})].index
icd9_patients_df = diagnoses_df[diagnoses_df['subject_id'].isin(patients_with_only_icd9)]

# visit_counts = diagnoses_df.groupby('subject_id')['hadm_id'].nunique()    # for demo data
visit_counts = icd9_patients_df.groupby('subject_id')['hadm_id'].nunique()  # for main data
patients_more_than_three_visits = visit_counts[visit_counts > 3].index
num_patients = len(patients_more_than_three_visits)
print("Number of patients with only ICD-9 codes and more than 3 visits:", num_patients)

### Vocab Dicts

We create a token vocabulary to map each ICD code to a unique index. We also add special tokens as padding and separators.
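
As a small worked example of this mapping, the sketch below builds a toy vocabulary from a handful of hypothetical full ICD-9 codes and encodes a visit with it. The names `toy_full_codes`, `toy_token2idx`, and `encode_visit` are illustrative assumptions, not part of the notebook's pipeline; the truncation to three characters and the special-token order mirror the cell that follows.

```python
# Special tokens come first so that [PAD] is index 0 and [SEP] is index 2.
toy_special_tokens = ["[PAD]", "[CLS]", "[SEP]", "[UNK]", "[MASK]"]

# Hypothetical full ICD-9 codes; two of them share the prefix "250".
toy_full_codes = ["4019", "25000", "4280", "25001"]

# Truncate to the first three characters and de-duplicate.
toy_truncated = sorted({code[:3] for code in toy_full_codes})
toy_token2idx = {tok: i for i, tok in enumerate(toy_special_tokens + toy_truncated)}


def encode_visit(codes, vocab):
    """Map each code to its index, falling back to [UNK] for unseen prefixes."""
    return [vocab.get(code[:3], vocab["[UNK]"]) for code in codes]


print(toy_token2idx)
# {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[UNK]': 3, '[MASK]': 4, '250': 5, '401': 6, '428': 7}
print(encode_visit(["4019", "9999"], toy_token2idx))
# [6, 3]
```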

In [None]:
#Creates Token Vocabulary

import json
# TODO i need full vocabulary
truncated_codes = {str(code)[:3] for code in map_df['icd9_code']}
sorted_truncated_codes = sorted(truncated_codes)  # Sort codes

# Define special tokens with a specific order
special_tokens = ['[PAD]', '[CLS]', '[SEP]', '[UNK]', '[MASK]']

# Create dictionary mapping each code to a unique index, starting with special tokens
token2idx = {token: idx for idx, token in enumerate(special_tokens + sorted_truncated_codes)}

# Print the number of unique codes to verify
print("Number of unique truncated codes:", len(token2idx) - len(special_tokens))

# Print token to index mapping
#print("Token to Index Mapping:", token2idx)

# Save the token2idx dictionary to a JSON file for later use
with open('token2idx.json', 'w') as f:
    json.dump(token2idx, f)

# for labels, just make the full dict with all 1042 codes

Similarly, we create a vocab dictionary for labels, excluding the special tokens, except for 'UNK', or unknown, which refers to an unknown or missing ICD code.
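
The sketch below illustrates the idea on a toy vocabulary: the label vocabulary is re-indexed from 0 after the special tokens are dropped, and a set of diagnosis codes becomes a multi-hot vector over it. All names here are hypothetical. One assumption differs from the notebook: this sketch maps an unseen code to the `[UNK]` slot, whereas the notebook's `encode_labels` prints a warning and skips it.

```python
toy_token2idx = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "[UNK]": 3, "[MASK]": 4,
                 "250": 5, "401": 6, "428": 7}

# Drop the tokens that can never be a prediction target, then re-index from 0.
dropped = {"[PAD]", "[CLS]", "[SEP]", "[MASK]"}
kept = [tok for tok in toy_token2idx if tok not in dropped]
toy_label_vocab = {tok: i for i, tok in enumerate(kept)}


def multi_hot(codes, label_vocab):
    """Return a 0/1 vector with a 1 at the index of every code present."""
    vec = [0] * len(label_vocab)
    for code in codes:
        # Assumption for this sketch: unseen codes fall into the [UNK] slot.
        vec[label_vocab.get(code, label_vocab["[UNK]"])] = 1
    return vec


print(toy_label_vocab)                                   # {'[UNK]': 0, '250': 1, '401': 2, '428': 3}
print(multi_hot(["401", "428", "999"], toy_label_vocab))  # [1, 0, 1, 1]
```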

In [None]:
# Create the label vocab.
# Same as the token vocab but without [SEP], [CLS], [PAD], and [MASK] ([UNK] is kept).
# format_label_vocab re-indexes the remaining tokens from 0 so the label indices are contiguous.


def format_label_vocab(token2idx):
    token2idx = token2idx.copy()
    del token2idx['[PAD]']
    del token2idx['[SEP]']
    del token2idx['[CLS]']
    del token2idx['[MASK]']
    token = list(token2idx.keys())
    labelVocab = {}
    for i,x in enumerate(token):
        labelVocab[x] = i
    return labelVocab

labelVocab = format_label_vocab(token2idx)

# Print the new dictionary to verify
print(len(labelVocab))

### Label and Feature Splits

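The first statements in the next cell filter out patients with any ICD-10 codes. This step is not applicable to the demo dataset, which contains only ICD-9 codes; for the demo data, use the commented-out line `icd9_only_df = diagnoses_df` instead.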
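
To illustrate the split itself, here is a minimal sketch on a hypothetical four-visit patient. It follows the same rule as the cell below (the split index is drawn from the last two positions, so either the last visit or the last two visits become labels). The helper name `split_history` and the seeded `random.Random` instance are assumptions added here for a repeatable illustration.

```python
import random


def split_history(visits, rng):
    """Hold out the last one or two visits as labels; keep the rest as features."""
    split_index = rng.randint(len(visits) - 2, len(visits) - 1)
    feature_visits = visits[:split_index]
    # Labels are the truncated codes of every visit at or after the split.
    label_codes = sorted({code[:3] for visit in visits[split_index:] for code in visit})
    return feature_visits, label_codes


# Hypothetical patient with four visits of full ICD-9 codes.
toy_visits = [["4019"], ["25000", "4019"], ["4280"], ["5849", "4280"]]

rng = random.Random(0)  # seeded only so the illustration is repeatable
feature_visits, label_codes = split_history(toy_visits, rng)
print(len(feature_visits), "feature visits; label codes:", label_codes)
```

Only the visits before the split index should reach the model as input, since the later visits are what it is asked to predict.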

In [None]:
#preparing the label and feature splits

patients_with_icd10 = diagnoses_df[diagnoses_df['icd_version'] == 10]['subject_id'].unique()

icd9_only_df = diagnoses_df[~diagnoses_df['subject_id'].isin(patients_with_icd10)]

#Finds patients with only icd9 codes
icd9_only_df = icd9_only_df[icd9_only_df['icd_version'] == 9]

# icd9_only_df = diagnoses_df   # for demo data
visit_counts = icd9_only_df.groupby('subject_id')['hadm_id'].nunique()

patients_more_than_three_visits = visit_counts[visit_counts > 3].index

#final DataFrame of patients with only ICD-9 codes and more than three visits
final_df = icd9_only_df[icd9_only_df['subject_id'].isin(patients_more_than_three_visits)]

patient_visits = final_df.groupby(['subject_id', 'hadm_id'])[icd_code_col_name].apply(list).reset_index()
patient_visits = patient_visits.groupby('subject_id')[icd_code_col_name].apply(list)

#extract all ICD9 codes but only take first 3 digits
map_df['truncated_icd9'] = map_df['icd9_code'].apply(lambda x: str(x)[:3])

unique_truncated_codes = sorted(map_df['truncated_icd9'].unique())
code_to_index = {code: idx for idx, code in enumerate(unique_truncated_codes)}

def encode_labels(label_codes, labelVocab):
    # print(len(label_codes), len(labelVocab))
    multi_hot = [0] * len(labelVocab)
    #print(len)

    for code in label_codes:
        if code in labelVocab:
            index = labelVocab[code]
            multi_hot[index] = 1
        else:
            print(f"Warning: Code {code} not found in labelVocab.")

    return multi_hot



features = []
labels = {}
feature_history = {}   # subject_id -> visits before the split (model input only)

for subject_id, visits in patient_visits.items():
    if len(visits) > 3:
        # split_index = len(visits) // 2    # splits visits down middle

        # Randomly hold out either the last two visits or only the last visit as labels
        split_index = random.randint(len(visits)-2, len(visits)-1)

        # Initialize the feature list with '[CLS]'
        feature_visits = ['[CLS]']

        # Append codes and '[SEP]' after each visit up to the split index
        for sublist in visits[:split_index]:
            feature_visits.extend(sublist)
            feature_visits.append('[SEP]')

        # Append the feature list to the features
        features.append(feature_visits)
        feature_history[subject_id] = visits[:split_index]

        # Gather truncated label codes from the visits at and after the split index
        label_codes = [code[:3] for sublist in visits[split_index:] for code in sublist]

        # Encode the labels and store using subject_id as key
        labels[subject_id] = encode_labels(label_codes, labelVocab)

# Keep only the pre-split visits as model input so that the visits used as
# labels are not also fed to the NextVisit dataset
patient_visits = feature_history

#output for one set of features and labels
if features and labels:
    print("Example Features: ", features[0])
    print("length Labels: " , len(labels))
    print("example label", len(labels[next(iter(patient_visits))]))

### Custom Dataset

This custom dataset NextVisit prepares input sequences with appropriate tokens, masks, and labels for the model.
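
As a worked example of what the dataset produces for one patient, the sketch below computes the position indices, segment ids, and attention mask for a tiny padded sequence. The helper names `visit_positions` and `visit_segments` are hypothetical stand-ins written for this illustration; the ids `2` for `[SEP]` and `0` for `[PAD]` assume the special-token order used in `token2idx`.

```python
SEP, PAD = 2, 0  # assumed to match token2idx['[SEP]'] and token2idx['[PAD]']


def visit_positions(token_ids, sep=SEP):
    """Visit index of each token; the counter advances after every [SEP]."""
    positions, visit = [], 0
    for tok in token_ids:
        positions.append(visit)
        if tok == sep:
            visit += 1
    return positions


def visit_segments(token_ids, sep=SEP):
    """Alternating 0/1 segment id that flips after every [SEP]."""
    segments, flag = [], 0
    for tok in token_ids:
        segments.append(flag)
        if tok == sep:
            flag = 1 - flag
    return segments


# [CLS] code code [SEP] code [SEP] [PAD] [PAD]
toy_sequence = [1, 5, 6, 2, 7, 2, 0, 0]

print(visit_positions(toy_sequence))            # [0, 0, 0, 0, 1, 1, 2, 2]
print(visit_segments(toy_sequence))             # [0, 0, 0, 0, 1, 1, 0, 0]
print([int(t != PAD) for t in toy_sequence])    # [1, 1, 1, 1, 1, 1, 0, 0]
```

Note that padding tokens still receive position and segment values; the attention mask is what tells the model to ignore them.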

In [None]:
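#custom dataset helpers

def index_seg(tokens, symbol=2):
    """Return alternating 0/1 segment ids that flip after each separator.

    `symbol` is the index of the separator token (token2idx['[SEP]'] == 2).
    The separator itself keeps the segment id of the visit it closes.
    """
    flag = 0
    seg = []

    for token in tokens:
        seg.append(flag)
        if token == symbol:
            flag = 1 - flag
    return seg


def position_idx(tokens, symbol=2):
    """Return the visit index of each token, incrementing after each separator."""
    pos = []
    flag = 0

    for token in tokens:
        pos.append(flag)
        if token == symbol:
            flag += 1
    return pos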

In [None]:
#custom dataset

class NextVisit(Dataset):
    def __init__(self, token2idx, labels, patient_visits, max_len):
        self.token2idx = token2idx
        self.labels = labels
        self.patient_visits = patient_visits
        # Build the id list once instead of on every __getitem__ call
        self.patient_ids = list(patient_visits.keys())
        self.max_len = max_len

    def __len__(self):
        return len(self.patient_ids)

    def __getitem__(self, index):
        # Returns: code indices, position indices, attention mask, segment ids, label
        # Retrieve patient data by index
        patient_id = self.patient_ids[index]
        codes = self.patient_visits[patient_id]

        # Initialize sequence with [CLS] token
        sequence = [self.token2idx['[CLS]']]

        # Add each code to the sequence and append [SEP] after each visit
        for visit in codes:
            # Codes are truncated to their first 3 characters to match the keys of token2idx
            sequence.extend([self.token2idx.get(code[:3], self.token2idx['[UNK]']) for code in visit])
            sequence.append(self.token2idx['[SEP]'])



        # Cut or pad the sequence to the maximum length
        if len(sequence) > self.max_len:
            sequence = sequence[:self.max_len]
        else:
            sequence.extend([self.token2idx['[PAD]']] * (self.max_len - len(sequence)))

        position_indices = position_idx(sequence)
        segment = index_seg(sequence)


        # Create a mask for the sequence
        mask = [1 if token != self.token2idx['[PAD]'] else 0 for token in sequence]

        # Prepare the labels
        label = torch.tensor(self.labels[patient_id], dtype=torch.float)

        return torch.tensor(sequence, dtype=torch.long), torch.tensor(position_indices, dtype=torch.long), torch.tensor(mask, dtype=torch.long), torch.tensor(segment, dtype=torch.long), label

In [None]:
# Testing the preprocessing steps on a single patient

max_len = 512

patient_id = list(patient_visits.keys())[0]
codes = patient_visits[patient_id]

# Initialize sequence with [CLS] token
sequence = [token2idx['[CLS]']]


# Add each code to the sequence and append [SEP] after each visit
for visit in codes:
    sequence.extend([token2idx.get(code[:3], token2idx['[UNK]']) for code in visit])
    sequence.append(token2idx['[SEP]'])

print(len(sequence))
print(sequence)


# Cut or pad the sequence to the maximum length
if len(sequence) > max_len:
    sequence = sequence[:max_len]
else:
    sequence.extend([token2idx['[PAD]']] * (max_len - len(sequence)))

position_indices = position_idx(sequence)
segment = index_seg(sequence)


mask = [1 if token != token2idx['[PAD]'] else 0 for token in sequence]

label = torch.tensor(labels[patient_id], dtype=torch.float)

print(len(sequence))
print(sequence)

print(len(position_indices))
print("position: ", position_indices)

print(len(segment))
print("segment: ", segment)

print(len(mask))
print("mask: ", mask)

# print(len(labels[0]))
print("labels: " , label)

'''
A bunch of 3 ([UNK]) because patient_visits includes full ICD Codes but token2idx has truncated version
Need to do version matching before feeding into model.
'''

## *Model*

### Original Paper: TransformEHR
* Citation: Yang, Z., Mitra, A., Liu, W. et al. TransformEHR: transformer-based encoder-decoder generative model to enhance prediction of disease outcomes using electronic health records. Nat Commun 14, 7857 (2023). https://doi.org/10.1038/s41467-023-43715-z
* Repo: https://github.com/whaleloops/TransformEHR/tree/main


### BEHRT
* Citation: Li, Y., Rao, S., Solares, J.R.A. et al. BEHRT: Transformer for Electronic Health Records. Sci Rep 10, 7155 (2020). https://doi.org/10.1038/s41598-020-62922-y
* Repo: https://github.com/deepmedicine/BEHRT/tree/master

### Model Description:

This section contains the model definition (implemented as classes) and its configuration.

* Model architecture: The model uses 4 hidden layers, 12 attention heads, an intermediate layer size of 512, and a hidden size of 300, which matches the 300-dimensional ICD-9 embeddings (see `model_config`).
* The `hidden_act` setting selects the activation function. We used relu and gelu; the configuration in this notebook is set to relu.
* Weight decay = 0.02 (defined in `optim_config`)
* Model classes:
  * BertEmbeddings
  * BertModel
  * BertForMultiLabelPrediction
* Our model utilizes an ICD9Embeddings class for pretrained ICD embeddings.
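
The `BertForMultiLabelPrediction` class in the implementation code scores every label independently and trains with `nn.MultiLabelSoftMarginLoss`. As a rough illustration of what that loss computes for a single patient, here is a pure-Python sketch: the mean over labels of the binary cross-entropy on `sigmoid(logit)`. The function names and toy numbers are hypothetical, and this is a sketch of the formula, not the PyTorch implementation.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def multilabel_soft_margin(logits, targets):
    """Mean over labels of -[y*log(s) + (1-y)*log(1-s)], with s = sigmoid(logit)."""
    terms = [
        -(y * math.log(sigmoid(x)) + (1 - y) * math.log(1.0 - sigmoid(x)))
        for x, y in zip(logits, targets)
    ]
    return sum(terms) / len(terms)


# Uninformative logits: every label gets probability 0.5, so the loss is log(2).
print(round(multilabel_soft_margin([0.0, 0.0, 0.0], [1, 0, 1]), 4))  # 0.6931

# Confident and correct logits give a much smaller loss.
print(multilabel_soft_margin([4.0, -4.0], [1, 0]) < 0.05)            # True
```

Because each label contributes its own term, a patient can be assigned several diagnoses at once, which is what the multi-hot labels require.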

## **To load a trained model, go to the section "Loading presaved/trained model".**



### Implementation Code

Setting up parameters and configurations to use in model.

In [None]:
from transformers import BertConfig, BertPreTrainedModel , BertModel
from transformers.models.bert.modeling_bert import BertEncoder, BertPooler, BertEmbeddings

global_params = {
    'batch_size': 64,
    'gradient_accumulation_steps': 1,
    'device': 'cuda',
    'output_dir': '',  # output dir
    'best_name': '', # output model name
    'save_model': True,
    'max_len_seq': 512,
    'max_age': 110,
    'month': 1,
    'age_symbol': None,
    'min_visit': 5
}

feature_dict = {
    'age': False,
    'seg': False,
    'posi': True
}


optim_config = {
    'lr': 3e-5,
    'warmup_proportion': 0.1,
    'weight_decay': 0.02
}
model_config = {
    'vocab_size': len(labelVocab), # number of disease + symbols for word embedding 1047
    'hidden_size': 300, # word embedding and seg embedding hidden size
    #'seg_vocab_size': 2, # number of vocab for seg embedding
    #'age_vocab_size': len(ageVocab.keys()), # number of vocab for age embedding
    'max_position_embedding': global_params['max_len_seq'], # maximum number of tokens
    'hidden_dropout_prob': 0.3, # dropout rate
    'num_hidden_layers': 4, # number of multi-head attention layers required
    'num_attention_heads': 12, # number of attention heads
    'attention_probs_dropout_prob': 0.45, # multi-head attention dropout rate
    'intermediate_size': 512, # the size of the "intermediate" layer in the transformer encoder
    'hidden_act': 'relu', # The non-linear activation function in the encoder and the pooler "gelu", 'relu', 'swish' are supported
    'initializer_range': 0.02, # parameter weight initializer range
}

class BertConfig(BertConfig):
    def __init__(self, config):
        super(BertConfig, self).__init__(
            vocab_size_or_config_json_file=config.get('vocab_size'),
            hidden_size=config['hidden_size'],
            num_hidden_layers=config.get('num_hidden_layers'),
            num_attention_heads=config.get('num_attention_heads'),
            intermediate_size=config.get('intermediate_size'),
            hidden_act=config.get('hidden_act'),
            hidden_dropout_prob=config.get('hidden_dropout_prob'),
            attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),
            max_position_embeddings = config.get('max_position_embedding'),
            initializer_range=config.get('initializer_range'),
        )
        #self.seg_vocab_size = config.get('seg_vocab_size')
        #self.age_vocab_size = config.get('age_vocab_size')

In [None]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset

full_dataset = NextVisit(token2idx, labels, patient_visits, max_len)

train_idx, test_idx = train_test_split(range(len(full_dataset)), test_size=0.2, random_state=42)

train_dataset = Subset(full_dataset, train_idx)
test_dataset = Subset(full_dataset, test_idx)

train_loader = DataLoader(train_dataset, batch_size=global_params['batch_size'], shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=global_params['batch_size'], shuffle=False)

In [None]:


class ICD9Embeddings:
    def __init__(self, filename="./ic9_embeddings.txt"):
        self.embedding_filename = filename
        # self.icd9_to_embeddings = self._read_embeddings()
        self.icd9_to_embeddings = self._read_embeddings_trunc()
        self.embedding_size = 300

    # this reads the 5 digit code correctly
    def _read_embeddings(self):
        icd9_to_embeddings = {}
        with open(self.embedding_filename, "r") as infile:
            data = infile.readlines()
            for row in data:
                eles = row.strip().split(" ")
                name = eles[0]
                embedding = eles[1:]
                code = name[4:]

                code = code.replace(".", "")
                if len(code) > 5 or len(code) < 3:
                    print("code is bad")

                icd9_to_embeddings[code] = torch.tensor(
                    [float(i) for i in embedding], dtype=torch.float32
                )
        return icd9_to_embeddings

    # Keeps one embedding per 3-character prefix. Rows are not matched greedily: a later row
    # overwrites an earlier one, so the last ICD-9 code with a given prefix wins.
    def _read_embeddings_trunc(self):
        icd9_to_embeddings = {}
        codes_lost = 0
        with open(self.embedding_filename, "r") as infile:
            data = infile.readlines()
            for row in data:
                eles = row.strip().split(" ")
                name = eles[0]
                embedding = eles[1:]
                code = name[4:]

                code = code.replace(".", "")
                if len(code) > 5 or len(code) < 3:
                    print("code is bad")
                trunc_code = code[:3]
                # print(trunc_code)
                if trunc_code in icd9_to_embeddings:
                    codes_lost += 1

                icd9_to_embeddings[trunc_code] = torch.tensor(
                    [float(i) for i in embedding], dtype=torch.float32
                )
        # print('codes_lost', codes_lost)
        return icd9_to_embeddings

    def get(self, code):
        if code in self.icd9_to_embeddings:
            return self.icd9_to_embeddings[code]
        else:
            return None

    def get_idx_to_embedding(self, token2idx):
        idx2embedding = {}
        for code, idx in token2idx.items():
            if code in self.icd9_to_embeddings:
                idx2embedding[idx] = self.icd9_to_embeddings[code]
            else:
                idx2embedding[idx] = torch.zeros(
                    self.embedding_size, dtype=torch.float32
                )
                # print("code is not in icd9 embeddings", code)

        return idx2embedding


In [None]:

class BertEmbeddings(nn.Module):
    def __init__(self, config, feature_dict):
        super(BertEmbeddings, self).__init__()
        self.feature_dict = feature_dict


        # TODO maybe load these as part of the Dataset so we don't have to do extra lookups

        # self.icd9_embeddings = ICD9Embeddings("./embeddings/ic9_embeddings.txt")
        self.icd9_embeddings = ICD9Embeddings(icd9_embedding_txt)
        self.idx2embedding = self.icd9_embeddings.get_idx_to_embedding(token2idx)


        icd9_codes = list(self.idx2embedding.keys())
        # print('length of icd9 codes', len(icd9_codes))

        embeddings_matrix = torch.stack([self.idx2embedding[code] for code in icd9_codes], dim=0)
        # print('embeddings matrix shape', embeddings_matrix.shape)

        # Initialize embeddings for CLS and SEP tokens
        additional_embeddings = torch.randn(2, config.hidden_size)

        # Combine precomputed and additional embeddings
        full_embeddings_matrix = torch.cat([embeddings_matrix, additional_embeddings], dim=0)
        # self.word_embeddings = nn.Embedding.from_pretrained(full_embeddings_matrix, freeze=False)
        # self.word_embeddings = nn.Embedding.from_pretrained(full_embeddings_matrix, freeze=True)
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)      # from scratch



        # from_pretrained is a classmethod, so no throwaway nn.Embedding instance is needed
        self.posi_embeddings = nn.Embedding.from_pretrained(
            embeddings=self._init_posi_embedding(config.max_position_embeddings, config.hidden_size))


        self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, word_ids, posi_ids=None, segment_ids=None):
        # print('embeddings forward')
        if segment_ids is None:
            pass

        if posi_ids is None:
            posi_ids = torch.zeros_like(word_ids)
        # print('max word id', word_ids.max().item())
        # print('embeddings size', self.word_embeddings.weight.shape)
        # print(word_ids)
        # print('word_ids shape', word_ids.shape)

        embeddings = self.word_embeddings(word_ids)
        # print('embeddings',embeddings)

        # posi_embeddings = self.posi_embeddings(posi_ids)

        # if self.feature_dict['posi']:
        #     embeddings = embeddings + posi_embeddings

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def _init_posi_embedding(self, max_position_embedding, hidden_size):
        def even_code(pos, idx):
            return np.sin(pos/(10000**(2*idx/hidden_size)))

        def odd_code(pos, idx):
            return np.cos(pos/(10000**(2*idx/hidden_size)))

        # initialize position embedding table
        lookup_table = np.zeros((max_position_embedding, hidden_size), dtype=np.float32)

        # reset table parameters with hard encoding
        # set even dimension
        for pos in range(max_position_embedding):
            for idx in np.arange(0, hidden_size, step=2):
                lookup_table[pos, idx] = even_code(pos, idx)
        # set odd dimension
        for pos in range(max_position_embedding):
            for idx in np.arange(1, hidden_size, step=2):
                lookup_table[pos, idx] = odd_code(pos, idx)

        return torch.tensor(lookup_table)


# this should be fairly hands off we should just need to adjust the config parameters
class BertModel(BertPreTrainedModel):
    def __init__(self, config, feature_dict):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config, feature_dict)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.init_weights()

    def forward(self, input_ids, posi_ids=None, attention_mask=None,segment_ids=None,output_all_encoded_layers=True):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if posi_ids is None:
            posi_ids = torch.zeros_like(input_ids)
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, posi_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output


# this should be fairly hands off we should just need to adjust the config parameters
class BertForMultiLabelPrediction(BertPreTrainedModel):
    def __init__(self, config, num_labels, feature_dict):
        super(BertForMultiLabelPrediction, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config, feature_dict)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.init_weights()

    def forward(self, input_ids, posi_ids=None, attention_mask=None, segment_ids=None, labels=None):
        _, pooled_output = self.bert(input_ids, posi_ids, attention_mask,segment_ids,
                                     output_all_encoded_layers=False)

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if labels is not None:
            loss_fct = nn.MultiLabelSoftMarginLoss()

            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))
            return loss, logits
        else:
            return logits

In [None]:
conf = BertConfig(model_config)
model = BertForMultiLabelPrediction(conf, model_config['vocab_size'], feature_dict)

In [None]:
model = model.to(global_params['device'])
# model = model.to("cuda")
optimizer = torch.optim.Adam(params=model.parameters(), lr=optim_config['lr'])

## *Evaluation Metric Functions*
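
The `accuracy` helper defined in the cell below thresholds the sigmoid outputs at 0.5 and averages matches over every (patient, label) entry. The pure-Python sketch below, on hypothetical numbers, shows one property to keep in mind when reading that metric: with sparse multi-hot labels, predicting no diagnosis at all already scores highly. The function name `elementwise_accuracy` is an illustrative assumption.

```python
def elementwise_accuracy(probabilities, labels, threshold=0.5):
    """Fraction of (sample, label) entries whose thresholded prediction matches."""
    matches = [
        int((p >= threshold) == bool(y))
        for prob_row, label_row in zip(probabilities, labels)
        for p, y in zip(prob_row, label_row)
    ]
    return sum(matches) / len(matches)


# One hypothetical patient with 10 possible labels, only one of which is positive.
toy_labels = [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]

# A model that assigns a low probability to every label.
all_negative = [[0.1] * 10]

print(elementwise_accuracy(all_negative, toy_labels))  # 0.9
```

This is why the ranking-based metrics (average precision and AUROC) are reported alongside accuracy.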

In [None]:
import sklearn.metrics

def precision(logits, label):
    # Sample-averaged average precision computed on sigmoid probabilities
    output = torch.sigmoid(logits)
    label, output = label.cpu(), output.detach().cpu()
    tempprc = sklearn.metrics.average_precision_score(label.numpy(), output.numpy(), average='samples')
    return tempprc, output, label

def auroc_test(logits, label):
    # Sample-averaged AUROC computed on sigmoid probabilities
    output = torch.sigmoid(logits)
    label, output = label.cpu(), output.detach().cpu()
    auroc = sklearn.metrics.roc_auc_score(label.numpy(), output.numpy(), average='samples')
    return auroc

def accuracy(logits, label):
    # Convert logits to per-label probabilities and move everything to the CPU
    probabilities = torch.sigmoid(logits).detach().cpu()
    label = label.cpu()

    # Apply a threshold to convert probabilities to binary predictions
    threshold = 0.5
    predictions = (probabilities >= threshold).float()

    # Element-wise accuracy: fraction of (sample, label) entries predicted correctly
    correct_predictions = (predictions == label).float()
    return correct_predictions.mean().item()

def print_result(logits, label):
    # Print the five highest-scoring predicted codes next to the true codes for one sample
    output = torch.sigmoid(logits).detach().cpu()
    label = label.cpu()

    # Label indices sorted from highest to lowest predicted probability
    predictions = np.argsort(output.numpy())[::-1]

    print('predicted diseases', predictions[:5])
    # np.where returns a tuple of index arrays; take the first array before slicing
    print('actual diseases', np.where(label.numpy() == 1)[0][:5])

def precision_test(logits, label):
    # Expects logits and labels that are already on the CPU (see evaluation())
    sig = nn.Sigmoid()
    output = sig(logits)
    tempprc = sklearn.metrics.average_precision_score(label.numpy(), output.numpy(), average='samples')
    roc = sklearn.metrics.roc_auc_score(label.numpy(), output.numpy(), average='samples')
    return tempprc, roc, output, label
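
The `average='samples'` option used above computes the average precision separately for each patient (ranking all codes by predicted probability) and then averages across patients. The sketch below illustrates that idea in plain Python, under the simplifying assumptions that scores have no ties and every sample has at least one positive label. The function names and toy numbers are illustrative, and scikit-learn's own implementation handles ties and edge cases differently.

```python
def average_precision(scores, labels):
    """AP for one sample: mean of precision@k taken at the rank k of each true label.

    Assumes no tied scores and at least one positive label.
    """
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    hits, precisions = 0, []
    for rank, (_, is_positive) in enumerate(ranked, start=1):
        if is_positive:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions)

def samples_average_precision(batch_scores, batch_labels):
    """Average the per-sample AP over all samples in the batch."""
    per_sample = [average_precision(s, l) for s, l in zip(batch_scores, batch_labels)]
    return sum(per_sample) / len(per_sample)

# Toy batch: 2 patients, 4 candidate codes (values are made up)
batch_scores = [[0.9, 0.2, 0.6, 0.1], [0.3, 0.8, 0.5, 0.4]]
batch_labels = [[1, 0, 0, 1], [0, 1, 1, 0]]
print(samples_average_precision(batch_scores, batch_labels))
```

In the first toy sample one true code is ranked first and the other last, giving an AP of 0.75; the second sample ranks both true codes on top, giving 1.0.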

## *Training*



**Hyperparameters**
are set in variables `global_params`, `model_config`, and `optim_config`, such as:
* Batch Size: 64
* Hidden Size: 300
* Number of Hidden Layers: 6
* Dropout Rate: 0.2
  * finetuned to 0.5

**Computational Requirements:**
* Hardware type: GPU T4
* Avg runtime per epoch: 2 min.
* GPU units used: 40
* Number of training epochs: 25 or 50


### Training Code

In [None]:
def train(e):
    model.train()
    tr_loss = 0
    temp_loss = 0
    nb_tr_examples, nb_tr_steps = 0, 0
    cnt = 0
    for step, batch in enumerate(train_loader):
        cnt +=1
        input_ids, posi_ids, attMask, segment_ids, targets = batch

        input_ids = input_ids.to(global_params['device'])
        posi_ids = posi_ids.to(global_params['device'])
        segment_ids = segment_ids.to(global_params['device'])
        attMask = attMask.to(global_params['device'])
        targets = targets.to(global_params['device'])

        loss, logits = model(input_ids, posi_ids,attention_mask=attMask, segment_ids=segment_ids, labels=targets)

        if global_params['gradient_accumulation_steps'] >1:
            loss = loss/global_params['gradient_accumulation_steps']
        loss.backward()

        temp_loss += loss.item()
        tr_loss += loss.item()
        nb_tr_examples += input_ids.size(0)
        nb_tr_steps += 1

        if step % 50 ==0:
            prec, a, b = precision(logits, targets)
            acc = accuracy(logits,targets)
            auroc = auroc_test(logits,targets)
            print_result(logits[2], targets[2])

            # Average the loss over the batches seen since the last log
            # (1 batch at step 0, 50 batches afterwards)
            avg_loss = temp_loss / (1 if step == 0 else 50)
            print("epoch: {}\t| Cnt: {}\t| Loss: {}\t| precision: {}\t| Accuracy: {}\t| AUROC:{} ".format(e, cnt, avg_loss, prec, acc, auroc))
            temp_loss = 0

        if (step + 1) % global_params['gradient_accumulation_steps'] == 0:
            optimizer.step()
            optimizer.zero_grad()
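
The training loop divides the loss by `gradient_accumulation_steps` before calling `backward()` and only steps the optimizer every few batches. The sketch below shows why that scaling matters, using a hypothetical one-parameter linear model with hand-written gradients (no PyTorch): summing the scaled micro-batch gradients gives the same gradient as one large batch, assuming equally sized micro-batches. All names and numbers are illustrative.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def accumulated_grad(w, xs, ys, micro_batch_size):
    """Sum of micro-batch gradients, each scaled by 1 / accumulation_steps."""
    chunks = [
        (xs[i:i + micro_batch_size], ys[i:i + micro_batch_size])
        for i in range(0, len(xs), micro_batch_size)
    ]
    steps = len(chunks)
    grad = 0.0
    for chunk_xs, chunk_ys in chunks:
        # Mirrors `loss = loss / gradient_accumulation_steps` followed by backward()
        grad += mse_grad(w, chunk_xs, chunk_ys) / steps
    return grad

xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.1, 5.9, 8.2]
full_batch = mse_grad(0.5, xs, ys)
accumulated = accumulated_grad(0.5, xs, ys, micro_batch_size=2)
print(full_batch, accumulated)
```

Without the division, the accumulated gradient would be `steps` times larger, which would act like an unintended increase in the learning rate.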

## *Evaluation*

### Metric Description

Our model measures the following metrics:
* Precision (sample-averaged average precision)
* Accuracy
* AUROC
* Evaluation Loss

The functions that calculate precision, accuracy, and AUROC are defined above, while the loss is accumulated in the `evaluation()` function below.
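
The metric functions above also report a sample-averaged AUROC. For a single patient, AUROC can be read as the probability that a randomly chosen true code receives a higher score than a randomly chosen absent code, with ties counted as half. The following plain-Python sketch illustrates that reading on made-up numbers; the function names are illustrative, and each sample is assumed to contain at least one positive and one negative label.

```python
def auroc(scores, labels):
    """Pairwise-ranking AUROC for one sample (ties count as half a win)."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def samples_auroc(batch_scores, batch_labels):
    """Average the per-sample AUROC over all samples in the batch."""
    per_sample = [auroc(s, l) for s, l in zip(batch_scores, batch_labels)]
    return sum(per_sample) / len(per_sample)

# Toy batch: 2 patients, 4 candidate codes (values are made up)
toy_scores = [[0.9, 0.2, 0.6, 0.1], [0.3, 0.8, 0.5, 0.4]]
toy_labels = [[1, 0, 0, 1], [0, 1, 1, 0]]
print(samples_auroc(toy_scores, toy_labels))
```

Because AUROC depends only on the ranking of positives against negatives, it can stay high even when many absent codes exist, which is worth keeping in mind when comparing it with precision.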

### Evaluation Code

In [None]:
def evaluation():
    model.eval()
    y = []
    y_label = []
    eval_loss = 0
    for step, batch in enumerate(test_loader):
        input_ids, posi_ids, attMask, segment_ids, targets = batch

        input_ids = input_ids.to(global_params['device'])
        posi_ids = posi_ids.to(global_params['device'])
        segment_ids = segment_ids.to(global_params['device'])
        attMask = attMask.to(global_params['device'])
        targets = targets.to(global_params['device'])

        # Disable gradient tracking during evaluation
        with torch.no_grad():
            loss, logits = model(input_ids, posi_ids,attention_mask=attMask, segment_ids=segment_ids, labels=targets)
        logits = logits.cpu()
        targets = targets.cpu()

        # Sum of the per-batch losses over the test set
        eval_loss += loss.item()

        y_label.append(targets)
        y.append(logits)

    y_label = torch.cat(y_label, dim=0)
    y = torch.cat(y, dim=0)

    # Sample-averaged precision and AUROC over the full test set
    aps, roc, output, label = precision_test(y, y_label)
    return aps, roc, eval_loss

# *Loading presaved/trained model*

In [None]:
checkpoint_path = '/content/gdrive/MyDrive/598ehrupload/saved_models/best_model_unfrozen_embeddings_finetune.pt'
checkpoint = torch.load(checkpoint_path,map_location=torch.device('cpu'))
model.load_state_dict(checkpoint)
print("Done")

# Model Output

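The following code trains the model for 50 epochs, evaluates it on the test set after each epoch, and saves a checkpoint whenever the average precision improves.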

In [None]:
import os

# If using Google Drive
base_dir = '/content/gdrive/MyDrive/598ehrupload/saved_models'
# If using local Colab space
# base_dir = '/content/ModelDirectory'

global_params['output_dir'] = base_dir
global_params['best_name'] = 'best_model_unfrozen_embeddings_finetune.pt'

def create_folder(path):
    if not os.path.exists(path):
        os.makedirs(path)

best_pre = 0.0
for e in range(50):
    print("starting training...")
    train(e)
    aps, roc, test_loss = evaluation()
    if aps >best_pre:
        # Save a trained model
        print("** ** * Saving fine - tuned model ** ** * ")
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
        output_model_file = os.path.join(global_params['output_dir'],global_params['best_name'])
        create_folder(global_params['output_dir'])

        torch.save(model_to_save.state_dict(), output_model_file)
        best_pre = aps
    print('aps : {}'.format(aps))

# **Results**


* Table of results (*These metrics measure the ability of the model to predict subsequent diagnoses based on a patient's diagnosis history*)

| Embedding configuration | Precision | AUROC |
| --- | --- | --- |
| Pretrained unfrozen embeddings | 0.27581077047200264 | 0.8643956994104 |
| Pretrained frozen embeddings | 0.22871352401267167 | 0.8595027934194563 |
| Randomly initialized embeddings | 0.2644318730070032 | 0.8797517785185457 |
| No positional embeddings | 0.2730285026690921 | 0.8678234320505926 |

* Discuss with respect to the hypothesis and results from the original paper
   * We did not see a statistically significant difference from adding positional embeddings, nor could we reproduce the level of precision reported in the paper, partly because the authors had access to 8.1 million patients while we had only 11,000. Following their basic architecture, we still achieved very high accuracy scores, although the high AUROC score is a better indication of performance. However, analyzing the actual outputs of the model paints a strange picture of overfitting and poor data quality. The model begins to learn the frequency of diseases rather than the correlation between features and predictions, and it is very keen to predict common diseases. For example, in our labeled dataset later visits commonly contain "Facial nerve disorders", so the model predicts this disease in particular quite often. This seems to point to class imbalance and the model underfitting the data, or simply to not having enough data to begin with.
* Experiments beyond the original paper
   * To allow the model to converge faster with less training data, we attempted to use pretrained ICD-9 code embeddings, but this made a negligible difference in performance. We also tested the effect of freezing the embedding weights, since the weights are pretrained and the relationships between these diseases should already be known. In an ablation that omits loading the pretrained embedding weights, the model converges more slowly, with precision remaining low until the 3rd epoch, but it reaches the same level of precision and AUROC after 10 epochs.
* Ablation Study
   * The main ablation we replicated from the paper is the exclusion of positional embeddings. This change does not seem to affect the performance of the model much (see the experiment section). Similarly, although pretrained embeddings make the model converge faster, the model is ultimately able to learn the embeddings on its own even from random initialization (see "Experiments beyond the original paper").
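
The gap between very high accuracy and low precision noted above is what one would expect when labels are sparse: the element-wise accuracy computed by `accuracy()` also rewards every correctly predicted absent code. The sketch below uses made-up numbers (a hypothetical vocabulary of 1,000 codes with 5 true next-visit codes) to show that a model predicting no diagnoses at all would still score 99.5% on that measure.

```python
def elementwise_accuracy(predictions, labels):
    """Fraction of (patient, code) cells where the prediction equals the label."""
    cells = [
        p == y
        for row_p, row_y in zip(predictions, labels)
        for p, y in zip(row_p, row_y)
    ]
    return sum(cells) / len(cells)

num_codes = 1000                      # hypothetical vocabulary size
true_codes = {3, 57, 210, 640, 999}   # hypothetical next-visit diagnoses
labels = [[1 if code in true_codes else 0 for code in range(num_codes)]]

# A degenerate "model" that never predicts any diagnosis
predict_nothing = [[0] * num_codes]
print(elementwise_accuracy(predict_nothing, labels))
```

This is why ranking-based metrics such as average precision and AUROC are more informative than accuracy for this task.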


# **Discussion**

As previously mentioned, due to the missing code from our original paper, TransformEHR, we found that it was *not reproducible*. Since we did not have access to the data collator functions or the tokenizers that the authors used to prepare the data, we were not able to determine the exact format needed to run the model and reproduce results. We would recommend that the authors of this paper upload the `datasets.py` file which includes these missing functions that would allow the model to run.

The other paper we followed in order to achieve similar goals, BEHRT, was more reproducible, as it contained all the code necessary to run the model, including the custom dataset NextVisit we modeled ours off of.

# **References**

* BEHRT: Transformer for Electronic Health Records
  * Paper: https://www.nature.com/articles/s41598-020-62922-y
  * GitHub Repo: https://github.com/deepmedicine/BEHRT/tree/master
* LLM Embeddings for ICD-10 Data
  * https://github.com/whaleloops/TransformEHR/tree/main
* TransformEHR
  * Paper: https://doi.org/10.1038/s41467-023-43715-z
  * GitHub Repo: https://github.com/whaleloops/TransformEHR/tree/main




