In [None]:
from collections.abc import Callable
from typing import Tuple, Union
from tqdm import tqdm
import pandas as pd
import numpy as np
import tracemalloc
import os
import torch
import torch.nn as nn
from torch.optim import Adam
from torch_geometric.nn.conv import GATConv
from torch_geometric.nn.inits import glorot
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LogisticRegression
import random
import time

from Utils import multi_label_metric  

# Seed Python's `random` module from the current wall-clock time, so every run
# draws a different sequence. NOTE(review): only `random` is seeded here —
# torch/numpy RNGs are untouched, and runs are not reproducible by design.
random.seed(a=time.time())

# <span style="text-decoration: underline">G-BERT</span>

Note that the ATC and ICD vocabs are generated in Part 2 (`_get_atc_ontology` and `_get_icd_ontology`) and are used throughout the notebook [can be revisited if needed].

In [None]:
from vocab import Vocab

Once we have the vocab class, we can define a utility method that builds a multi-hot vector from a vocab and an incoming list of codes. It is a pure multi-hot vector in the sense that it does not contain special tokens.

In [None]:
def get_multi_hot(code_list: Union[list[str], list[list[str]]], vocab: "Vocab") -> torch.Tensor:
    """Encode code(s) as a multi-hot float tensor over ``vocab.word2idx``.

    :param code_list: either a flat list of codes (one sample) or a list of code
        lists (several samples).
    :param vocab: vocabulary whose ``word2idx`` maps each code to a column index.
    :return: float tensor of 0/1 values with shape ``(len(vocab.word2idx),)`` for a
        flat list, or ``(len(code_list), len(vocab.word2idx))`` for a list of lists.
        Codes missing from ``vocab.word2idx`` are ignored.
    :raises ValueError: if ``code_list`` is empty.
    """
    if len(code_list) == 0:
        raise ValueError("Should have codes to convert")

    word2idx = vocab.word2idx
    is_batch = isinstance(code_list[0], list)
    # Treat a flat list as a batch of one so a single loop handles both shapes.
    rows = code_list if is_batch else [code_list]
    zeros_and_ones = torch.zeros((len(rows), len(word2idx)), dtype=torch.float)
    for row, codes in enumerate(rows):
        for code in codes:
            idx = word2idx.get(code)
            # TODO: codes missing from the vocab are silently ignored; decide whether to log/count them.
            if idx is not None:
                zeros_and_ones[row, idx] = 1
    # A flat input yields a 1-D vector rather than a (1, vocab_size) matrix.
    return zeros_and_ones if is_batch else zeros_and_ones[0]

Other Global Variables:

In [None]:
from constants import *

# Run on the current CUDA device when CUDA is available, otherwise on CPU.
# NOTE: `torch.cuda.device(...)` is a context manager and has no effect unless used
# in a `with` block, so it is not called here.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


### <u>Part-1: Filtering MIMIC-III for Unique ATC/ICD</u>

<a href="https://mimic.mit.edu/docs/iii/">MIMIC-III</a> has PRESCRIPTIONS.csv which has NDC medication codes. So we would need mapping to ATC codes.

The mapping is done with the RxNorm API: each NDC is first converted to an RXCUI (getNDCStatus) and then to an ATC code (RxClass API). All of this is done in an R script here: https://github.com/fabkury/ndc_map

Steps for _atc_ in an ideal situation would be:
<ol>
    <li>Parse output of the R-Script mentioned above to create an NDC to ATC map</li>
    <li>Use that map to find out ATC of PRESCRIPTION.csv [<i><a href="https://github.com/jshang123/G-Bert">G-BERT author's implementation</a></i> has first visit only][can this qualify as an extension?]</li>
    <li>Store unique atc codes in atc4-vocab.csv</li>
</ol>

But since the Prescriptions.csv is pretty old and many NDC codes have changed we are using <i>ndc2rxnorm_mapping,csv</i> and <i>rxnorm2atc_level4.csv</i> which are obtained from <a href="https://github.com/sjy1203/GAMENet">GAMENET implementation</a>. Thus the new steps:

<ol>
<li>Load NDC codes from Prescriptions.csv(MIMIC-III)</li>
<li>Convert them to RXNorm codes using corresponding mapping file</li>
<li>Convert RXNorm codes to ATC codes</li>
</ol>

In [None]:
def merge_prescriptions_and_mapping_files() -> pd.DataFrame:
    """Attach ATC level-4 codes to MIMIC-III prescriptions via NDC -> RXCUI -> ATC4.

    Steps:
      1. Load PRESCRIPTIONS.csv, drop rows whose NDC is 0 and rows missing SUBJECT_ID,
         HADM_ID or NDC (the only columns used below).
      2. Map NDC -> RXCUI using the NDC_2_RXCUI_MAPPING file; prescriptions whose NDC
         has no RXCUI (absent from the mapping, or mapped to "") are dropped.
      3. Inner-join on RXCUI with the RXCUI_2_ATC csv to attach ATC4.

    :return: DataFrame with columns SUBJECT_ID, HADM_ID, RXCUI, ATC4.
    """
    import ast  # local import: only this function needs it

    # TODO: need to reduce the unique codes and validate the mapping files
    # step 1
    print("[merge_prescriptions_and_mapping_files] Begining of loading MIMIC-III PRESCRIPTIONS.csv")
    prescriptions_df = pd.read_csv(os.path.join(GLOBAL_DATA_PATH, PRESCRIPTIONS))
    # Drop NaNs only in the columns we use: a frame-wide dropna() would also discard
    # prescriptions that merely lack some unrelated optional field.
    medications = (
        prescriptions_df.loc[prescriptions_df["NDC"] != 0.0]
        .dropna(subset=["SUBJECT_ID", "HADM_ID", "NDC"])
        .astype({"NDC": "Int64"})
    )

    # step 2
    print("[merge_prescriptions_and_mapping_files] Begining of loading NDC-RXCUI mapping for merging prescriptions")
    with open(os.path.join(GLOBAL_DATA_PATH, NDC_2_RXCUI_MAPPING), 'r') as f:
        # The file is assumed to be a plain dict literal; literal_eval parses it without
        # executing arbitrary code (unlike eval).
        raw_ndc2rx_cui = ast.literal_eval(f.read())
    ndc2rx_cui: dict[int, str] = {}
    for key, value in raw_ndc2rx_cui.items():
        try:
            ndc2rx_cui[int(key)] = value
        except Exception:
            print("[merge_prescriptions_and_mapping_files]", key, value, " is causing some issue while loading NDC-RXCUI mapping")
    medications["RXCUI"] = medications["NDC"].map(ndc2rx_cui, na_action="ignore")
    # NDCs absent from the mapping become NaN and "" means "no RXCUI"; both must be
    # removed before the integer cast, which would otherwise raise.
    has_rx_cui = medications["RXCUI"].notna() & (medications["RXCUI"] != "")
    medications = medications.loc[has_rx_cui].copy()
    medications["RXCUI"] = medications["RXCUI"].astype("int64")

    # step 3:
    print("[merge_prescriptions_and_mapping_files] Begining of loading RXCUI-ATC mapping for merging prescriptions")
    rxcui2atc = pd.read_csv(os.path.join(GLOBAL_DATA_PATH, RXCUI_2_ATC))
    return pd.merge(medications[["SUBJECT_ID", "HADM_ID", "RXCUI"]], rxcui2atc[["RXCUI", "ATC4"]], on=["RXCUI"])

def create_unique_atc_csv_file(outputPath: str) -> pd.DataFrame:
    """Write the de-duplicated ATC4 codes of the merged prescriptions to ``outputPath``.

    The stored (and returned) frame has an ``index`` column (the code's first row
    position in the merged prescriptions) followed by ``ATC4``.
    """
    print("[create_unique_atc_csv_file] Starting...")
    merged_prescriptions = merge_prescriptions_and_mapping_files()
    print("[create_unique_atc_csv_file] Unique ATC codes in hand now")
    first_seen_atc4 = merged_prescriptions["ATC4"].drop_duplicates()
    unique_atc4 = first_seen_atc4.reset_index()
    print("[create_unique_atc_csv_file] Going to store ATC codes now")
    unique_atc4.to_csv(outputPath)
    return unique_atc4

Code to create atc vocab <br>
<span style="color: #055BA6">create_unique_atc_csv_file(os.path.join(GLOBAL_DATA_PATH, UNIQUE_ATC_CSV))</span>

Steps for _DX_ would be:
<ol>
    <li>Use DIAGNOSES_ICD.csv column "ICD9_CODE"</li>
    <li>Steps 2-3: Store unique ICD codes in unique-icd.csv. The <a href="https://arxiv.org/pdf/1906.00346.pdf">G-BERT</a> implementation on <a href="https://github.com/jshang123/G-Bert">github</a> processes and stores only the top 2000 diagnoses (by frequency). <span style="color: #FF540D">[Assuming] since this is not in the paper</span></li>
</ol>
Moreover, our analysis shows that only 3% of visits have more than 30% "[UNK]" tokens when we keep the top 2000. In other words, in 97% of visits at least 70% of the diagnoses fall within the top 2000 diagnoses.

<span style="color: #FF540D">[Issue] Author implementation has two vocabs but ideally BERT style pre-training should have one global vocab</span>

In [None]:
def create_unique_icd_csv_file(outputPath: str, top_k: int = 2000) -> pd.DataFrame:
    """Store the ``top_k`` most frequent ICD-9 diagnosis codes of DIAGNOSES_ICD.csv.

    The G-BERT reference implementation keeps only the 2000 most frequent
    diagnoses, hence the default.

    :param outputPath: destination CSV; it holds one ``ICD9_CODE`` column (plus the
        pandas index column), ordered by descending frequency.
    :param top_k: number of most frequent codes to keep.
    :return: the first five rows of the selected (ICD9_CODE, ICD9_CODE_COUNT) frame.
    """
    # step 1
    print("[create_unique_icd_csv_file] Step 1... start")
    # Read codes as strings so numeric-looking ICD-9 codes (e.g. "0389") keep their leading zeros.
    diagnosis_df = (
        pd.read_csv(os.path.join(GLOBAL_DATA_PATH, DIAGNOSIS_ICD), dtype={"ICD9_CODE": str})
        .dropna()
        .drop_duplicates()
    )

    # Step 2: to find top `top_k` by freq of occurrence
    print("[create_unique_icd_csv_file] Step 2... start")
    # groupby returns the codes sorted ascending, so the *stable* sort below breaks
    # frequency ties by code and the top-k cut-off is deterministic. The codes are
    # already unique here, so no further de-duplication is needed.
    icd_codes = diagnosis_df.groupby("ICD9_CODE").size().to_frame(name="ICD9_CODE_COUNT").reset_index()
    sorted_icd_codes = icd_codes.sort_values("ICD9_CODE_COUNT", ascending=False, kind="stable").iloc[:top_k]

    #step 3
    print("[create_unique_icd_csv_file] Step 3... start")
    sorted_icd_codes["ICD9_CODE"].reset_index(drop=True).astype("string").to_csv(outputPath)

    return sorted_icd_codes.head()

Code to create ICD vocab <br>
<span style="color: #055BA6">create_unique_icd_csv_file(os.path.join(GLOBAL_DATA_PATH, UNIQUE_ICD_CSV))</span>

### <u>Part-2: Defining Functions that can create Ontologies/trees(medication(ATC) and diseases(ICD9))</u>

A hierarchical list of ICD-9 codes is available here: https://icd.codes/icd9cm

Based on this list, below is a method that identifies the ancestor of a given ICD code. After that we can simply create a tree of ICD codes [refer to the second method in the code below], which we call an Ontology.

In [None]:
"""
ICD-9
"""

def level2_ancestors() ->  dict[str, str]:
    """Map every ICD-9 category (3 chars; 4 for E-codes, e.g. "401", "V45", "E849")
    to the level-2 range it belongs to (e.g. "401-405", "V40-V49", "E846-E849").

    Ranges are expanded in list order, so a later, narrower range overrides an
    earlier, wider one (e.g. "V01-V09" overrides "V01-V91"), and a single category
    such as "042" maps to itself. Categories not covered by any range are absent.
    """
    level2_ranges = ["001-009", "010-018", "020-027", "030-041", "042", "045-049", "050-059", "060-066", "070-079", "080-088",
                     "090-099", "100-104", "110-118", "120-129", "130-136", "137-139", "140-149", "150-159", "160-165",
                     "170-176",
                     "176", "179-189", "190-199", "200-208", "209", "210-229", "230-234", "235-238", "239", "240-246",
                     "249-259",
                     "260-269", "270-279", "280-289", "290-294", "295-299", "300-316", "317-319", "320-327", "330-337", "338",
                     "339", "340-349", "350-359", "360-379", "380-389", "390-392", "393-398", "401-405", "410-414", "415-417",
                     "420-429", "430-438", "440-449", "451-459", "460-466", "470-478", "480-488", "490-496", "500-508",
                     "510-519",
                     "520-529", "530-539", "540-543", "550-553", "555-558", "560-569", "570-579", "580-589", "590-599",
                     "600-608",
                     "610-611", "614-616", "617-629", "630-639", "640-649", "650-659", "660-669", "670-677", "678-679",
                     "680-686",
                     "690-698", "700-709", "710-719", "720-724", "725-729", "730-739", "740-759", "760-763", "764-779",
                     "780-789",
                     "790-796", "797-799", "800-804", "805-809", "810-819", "820-829", "830-839", "840-848", "850-854",
                     "860-869",
                     "870-879", "880-887", "890-897", "900-904", "905-909", "910-919", "920-924", "925-929", "930-939",
                     "940-949",
                     "950-957", "958-959", "960-979", "980-989", "990-995", "996-999", "V01-V91", "V01-V09", "V10-V19",
                     "V20-V29",
                     "V30-V39", "V40-V49", "V50-V59", "V60-V69", "V70-V82", "V83-V84", "V85", "V86", "V87", "V88", "V89",
                     "V90",
                     "V91", "E000-E899", "E000", "E001-E030", "E800-E807", "E810-E819", "E820-E825", "E826-E829", "E830-E838",
                     "E840-E845", "E846-E849", "E850-E858", "E860-E869", "E870-E876", "E878-E879", "E880-E888", "E890-E899",
                     "E900-E909", "E910-E915", "E916-E928", "E929", "E930-E949", "E950-E959", "E960-E969", "E970-E978",
                     "E980-E989", "E990-E999"]

    ancestors: dict[str, str] = {}
    for span in level2_ranges:
        bounds = span.split("-")
        if len(bounds) == 1:
            # A standalone category is its own level-2 ancestor.
            ancestors[span] = span
            continue
        # V-codes are zero-padded to 2 digits; E-codes and numeric codes to 3.
        prefix = span[0] if span[0] in ("V", "E") else ""
        width = 2 if prefix == "V" else 3
        start, stop = (int(bound[len(prefix):]) for bound in bounds)
        for number in range(start, stop + 1):
            ancestors[f"{prefix}{number:0{width}d}"] = span
    return ancestors


def build_icd9_tree(unique_codes: list[str]) -> Tuple[list[list[str]], Vocab, Vocab]:
    """Build the ICD-9 ontology paths and vocabularies for ``unique_codes``.

    Each code yields the leaf-to-root path
    ``[code, category, level-2 range, "icd9_root"]``, where ``category`` is the
    first 4 characters for E-codes and the first 3 otherwise (the slicing assumes
    dot-less codes such as "4019" or "E8497"), and the range comes from
    ``level2_ancestors()``.

    :param unique_codes: ICD-9 codes as strings.
    :return: ``(paths, icd9_vocab, icd9_leaf_vocab)`` — every path, a vocab fed with
        every node on every path, and a vocab fed with the leaf codes only.
    :raises KeyError: if a code's category is not covered by ``level2_ancestors()``.
    """
    paths = []
    icd9_vocab = Vocab()
    icd9_leaf_vocab = Vocab()

    root_node = "icd9_root"
    category_to_range = level2_ancestors()
    for code in unique_codes:
        icd9_leaf_vocab.add_sentence([code])
        category = code[:4] if code[0] == "E" else code[:3]
        level2_range = category_to_range[category]

        path = [code, category, level2_range, root_node]

        icd9_vocab.add_sentence(path)
        paths.append(path)

    return paths, icd9_vocab, icd9_leaf_vocab

Similarly for ATC codes. Now to understand ATC hierarchy these are the resources:
<ul>
    <li><a href="https://www.whocc.no/atc/structure_and_principles/">https://www.whocc.no/atc/structure_and_principles/</a></li>
    <li><a href="https://www.atccode.com/">https://www.atccode.com/</a></li>
</ul>

We use ATC since it is mentioned in the paper. Moreover, several papers consider ATC codes to have a good hierarchy.

In [None]:
"""
ATC
"""

def build_atc_tree(unique_codes: list[str]) -> Tuple[list[list[str]], Vocab, Vocab]:
    """Build the ATC ontology paths and vocabularies for ``unique_codes``.

    Each code yields the leaf-to-root path
    ``[code, code[:4], code[:3], code[:1], "atc_root"]``; for a 5-character
    level-4 code such as "N02BE" that is ``["N02BE", "N02B", "N02", "N", "atc_root"]``.

    :param unique_codes: ATC codes (presumably level 4, as read from the ATC4 column).
    :return: ``(paths, atc_vocab, atc_leaf_vocab)`` — every path, a vocab fed with
        every node on every path, and a vocab fed with the leaf codes only.
    """
    paths = []
    atc_vocab = Vocab()
    atc_leaf_vocab = Vocab()

    root_node = "atc_root"
    for code in unique_codes:
        atc_leaf_vocab.add_sentence([code])
        # Prefix truncations give the coarser ATC levels above the leaf.
        path = [code] + [code[:i] for i in [4, 3, 1]] + [root_node]

        atc_vocab.add_sentence(path)
        paths.append(path)

    return paths, atc_vocab, atc_leaf_vocab

Helper Functions to create globally available <i>"Ontology"</i> and vocab over Unique (ATC/ICD) codes

In [None]:
def _get_atc_ontology():
    """Return ``(paths, graph_vocab, leaf_vocab)`` for the ATC4 codes stored in UNIQUE_ATC_CSV."""
    atc_csv_path = os.path.join(GLOBAL_DATA_PATH, UNIQUE_ATC_CSV)
    atc_codes = pd.read_csv(atc_csv_path)["ATC4"].values.tolist()
    return build_atc_tree(atc_codes)

def _get_icd_ontology():
    """Return ``(paths, graph_vocab, leaf_vocab)`` for the ICD-9 codes stored in UNIQUE_ICD_CSV."""
    icd_csv_path = os.path.join(GLOBAL_DATA_PATH, UNIQUE_ICD_CSV)
    # Read codes as strings: if pandas inferred a numeric column, codes like "0389"
    # would lose their leading zero and build_icd9_tree could not slice them.
    unique_icd_codes = pd.read_csv(icd_csv_path, dtype={"ICD9_CODE": str})["ICD9_CODE"]
    return build_icd9_tree(unique_icd_codes.values.tolist())

# TODO: icd codes in my parse are 6984 but acc to paper it should have been 1997 ==> prof suggested that we should see stats esp based on freq
# TODO: atc codes in my parse are 413 but acc to paper it should have been 323
# Each ontology is a (leaf-to-root paths, graph vocab over all nodes, leaf-only vocab) triple.
(GLOBAL_ATC_GRAPH,
 GLOBAL_ATC_VOCAB,
 GLOBAL_ATC_LEAF_VOCAB) = _get_atc_ontology()
(GLOBAL_ICD_GRAPH,
 GLOBAL_ICD_VOCAB,
 GLOBAL_ICD_LEAF_VOCAB) = _get_icd_ontology()

We need to add special tokens to the vocabs. We add them at the end so that they don't interfere with masking and with the multi-hot to code-list conversion.

In [None]:
# Only the graph vocabs (used for token indices) get special tokens; the leaf vocabs
# used for multi-hot targets are left untouched.
# NOTE(review): re-running this cell presumably appends the tokens again — confirm
# whether Vocab.add_special_tokens is idempotent.
for _graph_vocab in (GLOBAL_ATC_VOCAB, GLOBAL_ICD_VOCAB):
    _graph_vocab.add_special_tokens()
del _graph_vocab

### <u>Part-3: Building Ontology Embedding over ICD and ATC (using Ontology from prev part and apply attention over it)</u>

For ontology embedding as described in <a href="https://arxiv.org/pdf/1906.00346.pdf">GBERT paper</a>, we would need several components:
<ol>
    <li>Utility to build stage 1 edges. That is edge from direct child to parent </li>
    <li>Utility to build stage 2 edges. That is, edges from each ancestor of a leaf node to the leaf node itself </li>
    <li>An attention-based message-passing graph neural network (GAT) </li>
</ol>

For more details about stage based edges, one should refer the main GBERT paper. Moreover, details about GAT can be found here https://arxiv.org/abs/1710.10903

<h5> <u>Step 1. and 2.</u></h5>
Assuming you have knowledge of what are stage-1/2 edges, we would go about how to create them now.

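We use PyTorch Geometric for the GNN. It provides a <a href="https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.message_passing.MessagePassing">Message Passing GNN</a> base class that takes a "flow" parameter, which defines how messages "flow" between nodes. With flow="source_to_target" (the default), messages flow from the neighbor (j-th node, the source) to the target (i-th node), and edge_index[0] holds the source nodes j while edge_index[1] holds the target nodes i. With flow="target_to_source" the roles of the two rows are swapped.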

Thus one has to be very careful about how "edge_index"(input to Pytorch Geometric based GNNs) is formed. First of all, pytorch-geometric based Modules require COO based edge-index input. [One can google about them, here is the summary along with other formats that might be used somewhere else: https://scipy-lectures.org/advanced/scipy_sparse/storage_schemes.html#summary]

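Now, coming back to step/stage 1: we create [[parent indices], [direct-child indices]], i.e. row 0 holds the receiving node i (parent) and row 1 the sending node j (direct child), because we want child-to-parent message flow in this stage. Note that under PyG's default flow="source_to_target", row 0 is read as the *source*; this [[i], [j]] layout therefore has to be used with flow="target_to_source", or have its rows flipped (edge_index.flip(0)) before being passed to a layer that uses the default flow.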

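For step/stage 2 we create [[leaf indices], [ancestor indices]], i.e. i = leaf and j = ancestor, because we want ancestor-to-leaf message flow in this stage. Here too, PyG's default flow="source_to_target" reads row 0 as the *source*, so this [[i], [j]] layout needs flow="target_to_source" or a row flip (edge_index.flip(0)) to give ancestor-to-leaf flow.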

In [None]:
def build_stage_one_edges(paths: list[list[str]], graph_voc: "Vocab") -> list[list[int]]:
    """Build stage-1 (direct child -> parent) edges of an ontology in COO format.

    :param paths: leaf-to-root paths of an ontology (icd/atc), one list of node names per leaf.
    :param graph_voc: Vocab whose ``word2idx`` maps every node name on ``paths`` to its index.
    :return: ``[rows, cols]`` with ``rows[k]`` a parent and ``cols[k]`` its direct child,
        de-duplicated and in first-seen order (more about COO here:
        https://scipy-lectures.org/advanced/scipy_sparse/coo_matrix.html). Reading
        ``(i, j) = (rows[k], cols[k])`` with messages travelling j -> i gives child -> parent flow.
        NOTE(review): PyTorch Geometric's default ``flow="source_to_target"`` reads row 0 as the
        *source*, so a consumer using that default must flip the rows (or use
        ``flow="target_to_source"``) to get child -> parent flow.
    """
    edges: dict[tuple[int, int], None] = {}  # insertion-ordered set of (parent, child)
    for path in paths:
        node_ids = [graph_voc.word2idx[node] for node in path]
        # Consecutive path entries are (child, direct parent).
        for child, parent in zip(node_ids, node_ids[1:]):
            edges[(parent, child)] = None
    rows = [parent for parent, _ in edges]
    cols = [child for _, child in edges]
    return [rows, cols]


def build_stage_two_edges(paths: list[list[str]], graph_voc: "Vocab") -> list[list[int]]:
    """Build stage-2 (every ancestor -> its leaf) edges of an ontology in COO format.

    :param paths: leaf-to-root paths of an ontology (icd/atc); ``path[0]`` is the leaf.
    :param graph_voc: Vocab whose ``word2idx`` maps every node name on ``paths`` to its index.
    :return: ``[rows, cols]`` with ``rows[k]`` a leaf and ``cols[k]`` one of its ancestors
        (up to and including the root), de-duplicated and in first-seen order (more about
        COO here: https://scipy-lectures.org/advanced/scipy_sparse/coo_matrix.html). Reading
        ``(i, j) = (rows[k], cols[k])`` with messages travelling j -> i gives ancestor -> leaf flow.
        NOTE(review): PyTorch Geometric's default ``flow="source_to_target"`` reads row 0 as the
        *source*, so a consumer using that default must flip the rows (or use
        ``flow="target_to_source"``) to get ancestor -> leaf flow.
    """
    edges: dict[tuple[int, int], None] = {}  # insertion-ordered set of (leaf, ancestor)
    for path in paths:
        node_ids = [graph_voc.word2idx[node] for node in path]
        leaf = node_ids[0]
        for ancestor in node_ids[1:]:
            edges[(leaf, ancestor)] = None
    rows = [leaf_id for leaf_id, _ in edges]
    cols = [ancestor_id for _, ancestor_id in edges]
    return [rows, cols]

<h5> <u>Step 3.</u></h5>
We would use Pytorch Geometric GAT Convolutional layer and then use that module for creating General Embedding module ==> Would be called <i>OntologyEmbedding</i> that can be used for ICD/ATC embeddings over a graph.

Defining General Embedding module ==> <i>OntologyEmbedding</i>. This class, as mentioned earlier, can be used to represent ICD/ATC embeddings in later stages where we implement BERT.

In [None]:
class OntologyEmbedding(nn.Module):
    """Graph-attention embedding of every node of an ontology (ICD-9 or ATC).

    A learnable initial embedding per node is refined by two passes of the same
    GATConv layer: stage 1 over direct child -> parent edges, then stage 2 over
    ancestor -> leaf edges.

    :param ontology_paths: leaf-to-root paths of the ontology (list of lists of node names).
    :param vocab: Vocab over all ontology nodes; its size fixes the number of embedding rows.
    :param emb_size: per-head GAT output size; ``GAT_CONV_HEADS * emb_size`` must equal
        ``GAT_CONV_IN_CHANNEL`` because the head outputs are concatenated.
    """
    def __init__(self, ontology_paths: list[list[str]], vocab: Vocab, emb_size=GAT_CONV_OUT_CHANNEL):
        super(OntologyEmbedding, self).__init__()

        stage_one_edges = build_stage_one_edges(ontology_paths, vocab)
        stage_two_edges = build_stage_two_edges(ontology_paths, vocab)

        # The builders emit [[receiver i], [sender j]] ([[parent], [child]] for stage 1,
        # [[leaf], [ancestor]] for stage 2). PyG's default flow="source_to_target" reads
        # row 0 as the *source*, so the rows are flipped to obtain the intended
        # child -> parent and ancestor -> leaf message flow.
        # Non-persistent buffers follow `.to(device)` together with the module without
        # adding new keys to the state_dict.
        self.register_buffer("edges1", torch.tensor(stage_one_edges, dtype=torch.long).flip(0), persistent=False)
        self.register_buffer("edges2", torch.tensor(stage_two_edges, dtype=torch.long).flip(0), persistent=False)
        self.vocab = vocab

        self.in_channels = GAT_CONV_IN_CHANNEL
        self.heads = GAT_CONV_HEADS
        assert self.in_channels == self.heads * emb_size
        self.g = GATConv(in_channels=self.in_channels, out_channels=emb_size, heads=self.heads,
                         dropout=GAT_CONV_DROPOUT, negative_slope=GAT_CONV_NEGATIVE_SLOP)

        num_nodes = len(vocab.word2idx)
        # TODO: Need to check that does special token gradient is happening
        self.initial_embedding = nn.Parameter(torch.empty((num_nodes, self.in_channels)))
        glorot(self.initial_embedding)

    def forward(self):
        """Return the refined embedding of every node, one row per ``vocab.word2idx`` entry."""
        stage1 = self.g(self.initial_embedding, self.edges1)

        # Heads are concatenated, so the output width equals the input width and the
        # same layer can be applied a second time.
        assert stage1.shape[-1] == self.g.in_channels

        return self.g(stage1, self.edges2)


### <u>Part-4: Defining GBERT Model</u>

We would need data of medication(ATC) and diagnosis(ICD) from visits for training. So first step towards building GBERT would be to parse MIMIC-III data for single visit patient's medication and diagnosis codes that would be used in pre-training.

Steps for creating single visit patient's pkl file:
<ol>
    <li>Load <a href="https://mimic.mit.edu/docs/iii/tables/diagnoses_icd/">DIAGNOSIS_ICD.csv</a> along with <a href="https://mimic.mit.edu/docs/iii/tables/admissions/">ADMISSIONS.csv</a>. For medications: Merge <a href="https://mimic.mit.edu/docs/iii/tables/prescriptions/">PRESCRIPTIONS.csv</a> and mapping files i.e. NDC-RXCUI mapping and RXCUI-ATC</li>
    <li>To get the single visit patients we should group by "SUBJECT_ID" and "HADM_ID" on ADMISSIONS.csv and then select "SUBJECT_ID" with only 1 "HADM_ID"</li>
    <li>We should select these "SUBJECT_ID"[which we got in step-2 ] and icd codes[which we got in previous step-1], from DIAGNOSIS_ICD.csv </li>
    <li>Similarly, We should select these "SUBJECT_ID"[which we got in step-2] and their Atc4 codes[which we got in step-1], from PRESCRIPTIONS.csv </li>
    <li>Next, pickling after merging both Diagnosis and Medication Data frames </li>
</ol>

In [None]:
def create_single_visit_pkl_file(outputPath: str) -> pd.DataFrame:
    """Pickle the diagnoses and medications of patients with exactly one admission.

    :param outputPath: destination of the pickled DataFrame.
    :return: DataFrame with columns SUBJECT_ID, HADM_ID, ICD9_CODE (list of distinct
        codes) and ATC4 (list of distinct codes); only admissions that have at least one
        diagnosis and one medication are kept (inner merge).
    """
    # step 1
    visit_record_key = ["SUBJECT_ID", "HADM_ID"]
    # Keep ICD-9 codes as strings so numeric-looking codes (e.g. "0389") keep their
    # leading zeros and match the vocab entries.
    diag_df = pd.read_csv(os.path.join(GLOBAL_DATA_PATH, DIAGNOSIS_ICD), dtype={"ICD9_CODE": str})
    print("[create_single_visit_pkl_file] Begining of Loading of PRESCRIPTIONS.csv")
    medications_df = merge_prescriptions_and_mapping_files()[visit_record_key + ["ATC4"]]
    print("[create_single_visit_pkl_file] Loaded PRESCRIPTIONS.csv")
    admission_df = pd.read_csv(os.path.join(GLOBAL_DATA_PATH, ADMISSIONS))[visit_record_key]

    # step 2
    sub_id_adm_count = admission_df.groupby("SUBJECT_ID")["HADM_ID"].nunique().reset_index()
    single_visit_sub_id = sub_id_adm_count.loc[sub_id_adm_count["HADM_ID"] == 1, "SUBJECT_ID"]

    # step 3
    # Select the relevant columns *before* drop_duplicates: otherwise the unique
    # ROW_ID/SEQ_NUM columns make every row distinct and repeated codes survive.
    single_visit_diag = (
        pd.merge(diag_df[visit_record_key + ["ICD9_CODE"]], single_visit_sub_id, on=["SUBJECT_ID"])
        .dropna()
        .drop_duplicates()
        .groupby(by=visit_record_key)["ICD9_CODE"]
        .apply(list)
        .reset_index()
    )
    # Columns of single_visit_diag would be ["SUBJECT_ID", "HADM_ID",list["ICD9_CODE"]]

    # step 4
    single_visit_meds = pd.merge(medications_df, single_visit_sub_id, on=["SUBJECT_ID"]).dropna().drop_duplicates()
    single_visit_meds = single_visit_meds.groupby(by=visit_record_key)["ATC4"].apply(list).reset_index()
    # row of single_visit_meds would look like: ["SUBJECT_ID", "HADM_ID", list["ATC4"]]

    # step 5
    final_df = pd.merge(single_visit_diag, single_visit_meds, on=visit_record_key)
    final_df.to_pickle(outputPath)
    return final_df

Code to create pkl file for single visit <br>
<span style="color: #055BA6">create_single_visit_pkl_file(os.path.join(GLOBAL_DATA_PATH, SINGLE_VISIT_PKL))</span>

In [None]:
class BERT(nn.Module):
    """Stack of ``num_layers`` standard Transformer encoder layers (no token or positional embeddings).

    :param d_model: feature size of every token, on input and output.
    :param d_ff: hidden size of each layer's feed-forward sub-block.
    :param num_layers: number of stacked encoder layers.
    :param heads: attention heads per layer.
    :param batch_first: True for (batch, seq, d_model) inputs, False for (seq, batch, d_model).
    """
    def __init__(self, d_model, d_ff, num_layers, heads, batch_first):
        super().__init__()
        # Attribute names are kept as-is because they are the state_dict keys.
        self.transformerEncoderLayer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=heads, dim_feedforward=d_ff, batch_first=batch_first
        )
        self.transformerEncoder = nn.TransformerEncoder(self.transformerEncoderLayer, num_layers=num_layers)

    def forward(self, x, mask) -> torch.Tensor:
        """Encode ``x``; ``mask`` is the (batch, seq) bool padding mask, True on positions to ignore."""
        encoded = self.transformerEncoder(x, src_key_padding_mask=mask)
        return encoded

Now that we have defined a basic BERT, we define GBERT. It uses BERT to get the "CLS" encoding that represents the input DX/RX codes; this encoding is later used to pre-train GBERT and to predict RX in the final training.

In [None]:
class GBERT(nn.Module):
    """G-BERT visit encoder applied to a diagnosis (DX) and a medication (RX) sequence.

    Both sequences share the same input projection, the same ``BERT`` stack and the
    same ``cls`` linear layer; for each, the encoder output at position 0 ("[CLS]")
    is projected and returned.
    """
    def __init__(self, embedding_size=100, d_model=300, d_ff=300, num_layers=2, heads=4, batch_first=True):
        super(GBERT, self).__init__()
        # Remembered so forward() reads the "[CLS]" position along the sequence axis.
        self.batch_first = batch_first
        self.ingest_embeddings = nn.Linear(embedding_size, d_model)
        self.bert = BERT(d_model, d_ff, num_layers, heads, batch_first)
        # This linear layer is to be applied on top of "CLS" of BERT output
        self.cls = nn.Linear(d_model, d_model)

    def _cls_encoding(self, encoded: torch.Tensor) -> torch.Tensor:
        """Return the position-0 ("[CLS]") vector of every sequence, shape (batch, d_model)."""
        # With batch_first=False the sequence axis is dim 0, not dim 1.
        return encoded[:, 0, :] if self.batch_first else encoded[0, :, :]

    def forward(self, dx: torch.Tensor, dx_mask: torch.Tensor, rx: torch.Tensor, rx_mask: torch.Tensor)->Tuple[torch.Tensor, torch.Tensor]:
        """
            dx shape: [-1, max_seq_len, GAT_CONV_HEADS*GAT_CONV_OUT_CHANNEL] (seq-first if batch_first=False)
            dx_mask: [-1, max_seq_len], True on padding
            rx shape: [-1, max_seq_len, GAT_CONV_HEADS*GAT_CONV_OUT_CHANNEL] (seq-first if batch_first=False)
            rx_mask: [-1, max_seq_len], True on padding
            :return: (vd, vm) — projected "[CLS]" encodings of dx and rx, each [-1, d_model]
        """
        # A linear projection is the whole input layer since we don't have positional encodings.
        vd = self.bert(self.ingest_embeddings(dx), dx_mask)
        vm = self.bert(self.ingest_embeddings(rx), rx_mask)
        return self.cls(self._cls_encoding(vd)), self.cls(self._cls_encoding(vm))


### <u>Part-5: Dataset and CollatedModel for Pre-training GBERT Models</u>

We define here data loading step which include:
<ol>
    <li>Creating Padded list of ATC/ICD codes and mask tensor for padded codes. Will have different for ATC4 and ICD9_CODE</li>
    <li> We also define multi-hot label that will be used as gold-standard or the target tensor</li>
    <li>[Masking Strategy]: We will be masking with 15% probability - randomly </li>
</ol>

After that we will work on pre-training.

<span style="color: #FF540D">[Assuming] that by "single visit" the paper means a single hospital admission. However, in the G-BERT git implementation it seems to also be restricted to the first 24 hrs. </span> Moreover, the paper has a max sequence length of 62; since we use the whole admission (and not just the first 24 hrs), that may be why SUBJECT_ID#96232 still has 100 ATC4 codes. We will need to revisit this, since it also determines the max_seq_length of BERT.
<span style="color: #FF540D">[Issue in author dataset] SUBJECT_ID 11, 86, 92 are not present in single visit. And no reasoning</span>
For now we are sticking with 100 ==> 101 (due to "[CLS]") for ATC4 and 40 for ICD9_CODE.

<span style="color: #FF540D">[Issue in paper]: 15% of codes might not mean masking with probability 15%</span>

In [None]:
class PreTrainingDataset(Dataset):
    """Single-visit samples for G-BERT pre-training.

    Each item is the 6-tuple
    ``(atc_ids, atc_multi_hot, atc_pad_mask, icd_ids, icd_multi_hot, icd_pad_mask)``:

    * ``*_ids``: LongTensor of length ``max_*_len`` holding the graph-vocab index of
      "[CLS]" followed by the visit's codes (each replaced by "[MASK]" with probability
      ``MLM_PROBABILITY``), zero-filled afterwards.
    * ``*_pad_mask``: BoolTensor of the same length, True on positions to ignore
      (PyTorch ``src_key_padding_mask`` convention).
    * ``*_multi_hot``: float multi-hot target over the *leaf* vocab, built from the
      visit's unmasked codes.

    Unseen ATC codes are reported and left as padding; unseen ICD codes become "[UNK]".
    """

    # Probability of replacing each code by "[MASK]".
    MLM_PROBABILITY = 0.15

    def __init__(self, df: pd.DataFrame, atc_vocab: Vocab, atc_leaf_vocab: Vocab, icd_vocab: Vocab,
                 icd_leaf_vocab: Vocab, max_atc_len: int, max_icd_len: int):
        super(PreTrainingDataset, self).__init__()
        self.df = df
        self.maxAtcLen = max_atc_len
        self.maxIcdLen = max_icd_len
        self.atcVocab = atc_vocab
        self.atcLeafVocab = atc_leaf_vocab
        self.icdVocab = icd_vocab
        self.icdLeafVocab = icd_leaf_vocab

    @classmethod
    def Build(cls):
        """Load SINGLE_VISIT_PKL and wrap it with the global vocabs and configured max lengths."""
        # Refer part-5 why max len -> 101, 40
        max_atc_len = PRETRAINING_MAX_ATC_LEN
        max_icd_len = PRETRAINING_MAX_ICD_LEN
        # NOTE: read_pickle unpickles arbitrary objects; only load the locally generated file.
        single_visit: pd.DataFrame = pd.read_pickle(os.path.join(GLOBAL_DATA_PATH, SINGLE_VISIT_PKL))

        # Position 0 is reserved for "[CLS]", so a visit with k codes needs k + 1 slots;
        # allowing k == max_len would make __getitem__ index past the end of its tensors.
        assert max_icd_len >= single_visit["ICD9_CODE"].apply(len).max() + 1, "[PreTrainingDataset::Build] Max Seq length is less"
        assert max_atc_len >= single_visit["ATC4"].apply(len).max() + 1, "[PreTrainingDataset::Build] Max Seq length is less"

        return cls(single_visit, GLOBAL_ATC_VOCAB, GLOBAL_ATC_LEAF_VOCAB, GLOBAL_ICD_VOCAB, GLOBAL_ICD_LEAF_VOCAB,
                   max_atc_len, max_icd_len)

    def __len__(self):
        return self.df.shape[0]

    def _encode_codes(self, codes: list[str], vocab: Vocab, max_len: int,
                      unk_token: Union[str, None] = None, unseen_label: str = ""):
        """Return ``(ids, pad_mask)`` for "[CLS]" + ``codes``, padded to ``max_len``.

        Each code is replaced by "[MASK]" with probability ``MLM_PROBABILITY``. A token
        missing from ``vocab`` becomes ``unk_token`` when one is given; otherwise it is
        reported (using ``unseen_label``) and its slot stays padding.
        """
        pad_mask = torch.ones(max_len, dtype=torch.bool, device=DEVICE)  # ones due to pytorch convention
        ids = torch.zeros(max_len, dtype=torch.long, device=DEVICE)
        # CLS:
        pad_mask[0] = False
        ids[0] = vocab.word2idx.get("[CLS]")
        mlm_mask = torch.rand(len(codes)) < self.MLM_PROBABILITY
        for pos, (code, masked) in enumerate(zip(codes, mlm_mask.tolist()), start=1):
            idx = vocab.word2idx.get("[MASK]") if masked else vocab.word2idx.get(code)
            if idx is None:
                if unk_token is None:
                    print(f"PreTrainingDataset: Some unseen {unseen_label} code", code)
                    continue
                # TODO: check if this is fine ==> one way can be to check if single visit pkl from author code has unk icds
                idx = vocab.word2idx.get(unk_token)
            pad_mask[pos] = False
            ids[pos] = idx
        return ids, pad_mask

    def __getitem__(self, index: int):
        """Encode visit ``index``; see the class docstring for the returned 6-tuple."""
        row = self.df.iloc[index]
        atc: list[str] = row["ATC4"]
        icd: list[str] = row["ICD9_CODE"]

        # ATC is encoded before ICD; both draw their MLM masks from torch's global RNG.
        atc_ids, atc_pad_mask = self._encode_codes(atc, self.atcVocab, self.maxAtcLen, unseen_label="ATC")
        icd_ids, icd_pad_mask = self._encode_codes(icd, self.icdVocab, self.maxIcdLen, unk_token="[UNK]")

        return atc_ids, get_multi_hot(atc, self.atcLeafVocab), atc_pad_mask, \
            icd_ids, get_multi_hot(icd, self.icdLeafVocab), icd_pad_mask


Pre-training of GBERT is inspired by NLP tasks such as NSP and MLM. However, to fit the medical domain, the authors changed them into the Self-prediction and Dual-prediction tasks (more details in the paper).

For above mentioned tasks we need 4 linear layers. 
<ul>
    <li>Layer 1: converts Vd(output of BERT) to Diagnosis labels(i.e. self prediction)</li>
    <li>Layer 2: converts Vd to Medication Labels(i.e. double prediction)</li>
    <li>Layer 3: converts Vm(output of BERT) to Diagnosis labels(i.e. double prediction)</li>
    <li>Layer 4: converts Vm to Medication Labels(i.e. self prediction)</li>
</ul>
We define a template for the above layers.

In [None]:
def pretraining_task_layer_template(d_model: int, pretraining_ff_model: int, vocab_size: int) -> nn.Sequential:
    """Build one pre-training prediction head.

    ``Linear(d_model, pretraining_ff_model) -> ReLU -> Linear(pretraining_ff_model, vocab_size)``;
    the head outputs raw logits (no output activation).
    """
    return nn.Sequential(
        nn.Linear(d_model, pretraining_ff_model),
        nn.ReLU(),
        nn.Linear(pretraining_ff_model, vocab_size),
    )

We now integrate all the tools built before (GBERT, OntologyEmbedding and these 4 layers) into one single model.
Pre-training is done on this final model and the loss is added and propagated backward. 

Notice that data loading gives a list of codes and mask tensor. We obtain Ontology embedding[PART-3] inside the collated Model. And pass that embedding as input to BERT.

In [None]:
class CollatedModelForPretraining(nn.Module):
    """Ontology embeddings + G-BERT + the four pre-training prediction heads.

    Heads in ``gbert_pretraining_prediction_layers``, in order:
      0: vd -> diagnosis labels   (self-prediction)
      1: vd -> medication labels  (double prediction)
      2: vm -> diagnosis labels   (double prediction)
      3: vm -> medication labels  (self-prediction)
    """

    def __init__(self):
        super(CollatedModelForPretraining, self).__init__()
        self.atc_ontology_embedding = OntologyEmbedding(GLOBAL_ATC_GRAPH, GLOBAL_ATC_VOCAB)
        self.icd_ontology_embedding = OntologyEmbedding(GLOBAL_ICD_GRAPH, GLOBAL_ICD_VOCAB)

        self.gbert = GBERT(embedding_size=GAT_CONV_HEADS * GAT_CONV_OUT_CHANNEL, d_model=BERT_IN_OUT,
                           d_ff=BERT_HIDDEN, num_layers=BERT_NUM_LAYERS, heads=BERT_NUM_HEAD_EACH_LAYER)

        diagnosis_vocab_size = len(GLOBAL_ICD_LEAF_VOCAB.idx2word)
        medication_vocab_size = len(GLOBAL_ATC_LEAF_VOCAB.idx2word)
        head_output_sizes = (diagnosis_vocab_size, medication_vocab_size,
                             diagnosis_vocab_size, medication_vocab_size)
        self.gbert_pretraining_prediction_layers = nn.ModuleList(
            pretraining_task_layer_template(BERT_IN_OUT, PRETRAINING_FF_HIDDEN, out_size)
            for out_size in head_output_sizes
        )

    def forward(self, atc: torch.Tensor, atc_mask: torch.Tensor, icd: torch.Tensor, icd_mask: torch.Tensor):
        """Return sigmoid probabilities (vd2dx, vd2rx, vm2dx, vm2rx) of the four heads.

        ``atc`` / ``icd`` are index tensors into the ontology-embedding tables. Per the
        original note, these tables have one row per vocabulary word (special tokens such as
        "[CLS]" included). The masks are forwarded to G-BERT as sequence-padding masks.
        """
        atc_ontology_emb = self.atc_ontology_embedding()[atc]
        icd_ontology_emb = self.icd_ontology_embedding()[icd]
        vd, vm = self.gbert(atc_ontology_emb, atc_mask, icd_ontology_emb, icd_mask)

        assert vd.shape[1] == BERT_IN_OUT and vm.shape[1] == BERT_IN_OUT, "Bert should return only cls embedding"

        # Heads 0/1 read vd and heads 2/3 read vm (see the class docstring).
        head_inputs = (vd, vd, vm, vm)
        return tuple(
            torch.sigmoid(head(features))
            for head, features in zip(self.gbert_pretraining_prediction_layers, head_inputs)
        )

### <u>Part-6: Pre-training Loop</u>

The cell below pre-trains the GBERT model together with the OntologyEmbedding model.

In [None]:
def pretraining_loop(load_from_disk: bool = False):
    """Pre-train CollatedModelForPretraining on PreTrainingDataset.

    Args:
        load_from_disk: if True, resume from the checkpoint at
            GLOBAL_MODELS_PATH/PRETRAINING_MODEL before training.

    Returns:
        The pre-trained model, switched to eval mode.

    The batch loss is the sum of four BCE terms: two self-prediction heads and two
    double-prediction heads. Jaccard and PR-AUC are reported for the vm -> medication
    head only.
    """
    pretrain_data_loader = DataLoader(PreTrainingDataset.Build(), batch_size=PRETRAINING_BATCH_SIZE, shuffle=True)
    pretraining_model = CollatedModelForPretraining().to(DEVICE)
    if load_from_disk:
        # map_location lets a checkpoint written on one device be loaded on another.
        checkpoint = torch.load(os.path.join(GLOBAL_MODELS_PATH, PRETRAINING_MODEL), map_location=DEVICE)
        pretraining_model.load_state_dict(checkpoint["model_state_dict"])
    pretraining_optimizer = Adam(pretraining_model.parameters(), lr=0.001)
    pretrain_criterion = nn.BCELoss()
    for epoch in range(PRETRAINING_EPOCH):
        pretraining_model.train()
        jaccard = []
        pr_auc = []
        tqdm_data_loader = tqdm(pretrain_data_loader)
        for data in tqdm_data_loader:
            pretraining_optimizer.zero_grad()
            atc_list, atc_labels, atc_mask, icd_list, icd_labels, icd_mask = data
            # get_multi_hot creates label tensors on the CPU. Align them with the model's
            # device so BCELoss never mixes devices (a no-op when DEVICE is the CPU).
            atc_labels = atc_labels.to(DEVICE)
            icd_labels = icd_labels.to(DEVICE)
            vd2dx, vd2rx, vm2dx, vm2rx = pretraining_model(atc_list, atc_mask, icd_list, icd_mask)
            loss = pretrain_criterion(vd2dx, icd_labels) + pretrain_criterion(vd2rx, atc_labels) + \
                pretrain_criterion(vm2dx, icd_labels) + pretrain_criterion(vm2rx, atc_labels)
            loss.backward()
            tqdm_data_loader.set_postfix({f"[pretraining_loop] epoch:{epoch}:: loss":str(loss.item())})

            # Metrics use detached predictions so they never touch the autograd graph.
            rx_probs = vm2rx.detach()
            y_pred = (rx_probs > 0.5).float()
            jaccard_items, pr_auc_item = multi_label_metric(atc_labels, y_pred, rx_probs)[:2]
            jaccard.append(jaccard_items)
            pr_auc.append(pr_auc_item)

            pretraining_optimizer.step()

        print("[pretraining_loop] Jaccard for this epoch:", sum(jaccard)/len(jaccard))
        print("[pretraining_loop] PR-AUC for this epoch:", sum(pr_auc)/len(pr_auc))
    pretraining_model.train(mode=False)
    return pretraining_model

Code to start pre-training <br>
<span style="color: #055BA6">pretrain_model = pretraining_loop()</span>

Saving the pre-training model <br>
<span style="color: #055BA6">
torch.save({ <br>
    "model_state_dict": pretrain_model.state_dict() <br>
    }, os.path.join(GLOBAL_MODELS_PATH, PRETRAINING_MODEL)) <br>
</span>

### <u>Part-7: Preparing Dataset for Training</u>

Since we are familiar with MIMIC-III by now, we go straight to the main points. The multi-visit pkl file needs these columns: "SUBJECT_ID", "HADM_ID", "ICD9_CODE", "ATC4". The multi-visit dataset contains the admissions of patients with >= 1 visit, together with their ICD9 codes and ATC4 codes.

Here, however, each "ICD9_CODE" and "ATC4" cell holds a list of codes for that admission.

<ol>
    <li>Read the MIMIC-III DIAGNOSES_ICD.csv file and group it into one list of ICD9_CODE per admission</li>
    <li>Read the MIMIC-III PRESCRIPTIONS.csv file (mapped to ATC4) and group it into one list of ATC4 codes per admission</li>
    <li>Merge the two data frames above on SUBJECT_ID and HADM_ID</li>
    <li>Pickle the merged data frame</li>
</ol>

In [None]:
def create_multi_visit_pkl_files(outputPath: str) -> pd.DataFrame:
    """Build and pickle the per-admission table of diagnosis and medication codes.

    Args:
        outputPath: destination path of the pickle.

    Returns:
        DataFrame with columns SUBJECT_ID, HADM_ID, ATC4 (list of ATC4 codes) and
        ICD9_CODE (list of ICD-9 codes). Only admissions that have both kinds of codes
        are kept (inner merge).
    """
    visit_id = ["SUBJECT_ID", "HADM_ID"]

    # step 1: one row per admission with its list of distinct ICD-9 codes.
    # Only the needed columns are read, so dropna()/drop_duplicates() act on
    # (SUBJECT_ID, HADM_ID, ICD9_CODE) alone. Otherwise a per-row id column would make
    # drop_duplicates() a no-op, and NaNs in unrelated columns would drop valid codes.
    diag_df = pd.read_csv(os.path.join(GLOBAL_DATA_PATH, DIAGNOSIS_ICD), usecols=visit_id + ["ICD9_CODE"])
    diag_df = diag_df.dropna().drop_duplicates()
    diag_df = diag_df.groupby(visit_id)["ICD9_CODE"].apply(list).reset_index()

    # step 2: same for medications (ATC4 codes from the prescriptions/mapping merge).
    medications = merge_prescriptions_and_mapping_files()[visit_id + ["ATC4"]].dropna().drop_duplicates()
    print("[create_multi_visit_pkl_files] Got medication DF")
    medications = medications.groupby(visit_id)["ATC4"].apply(list).reset_index()

    # step 3: keep admissions that have both medications and diagnoses (inner join).
    merged_entity = pd.merge(medications, diag_df, on=visit_id)
    print("[create_multi_visit_pkl_files] Merging of Medications and Diagnosis code done")

    # step 4
    merged_entity.to_pickle(outputPath)
    return merged_entity

Code to pkl multi-visit data <br>
<span style="color: #055BA6">create_multi_visit_pkl_files(os.path.join(GLOBAL_DATA_PATH, MULTI_VISIT_PKL))</span>

However, the training input needs all of a patient's previous visits for every visit, so we create another file, which we call MULTI_VISIT_TEMPORAL_PKL. Each row of this file contains "SUBJECT_ID", "HADM_ID", a list of lists of ICD codes and a list of lists of ATC codes.

Each list-of-lists column holds one code list per visit: the current visit first, followed by the patient's earlier visits. Note that only patients with >= 2 hospital admissions are considered.

Steps for creating this file:
<ol>
    <li>Load the MULTI_VISIT_PKL, ADMISSIONS.csv and SINGLE_VISIT_PKL files. Remove single-visit patients from the ADMISSIONS data frame</li>
    <li>Using ADMISSIONS.csv, find for every HADM_ID the HADM_IDs of the same SUBJECT_ID that were admitted earlier</li>
    <li>Merge MULTI_VISIT_PKL with the data frame obtained above. For every HADM_ID, this gives the ATC4 and ICD9_CODE codes of each of the patient's earlier admissions. We call the resulting data frame "multi_visit_temporal"</li>
    <li>For simplicity, split "multi_visit_temporal" into one DF for ICD9_CODE and one for ATC4, and group each per (SUBJECT_ID, HADM_ID) into a list of code lists (this can be improved). Merge these two DFs back into "multi_visit_temporal" (no ordering between ICD9_CODE and ATC4 is required)</li>
    <li>At this point each row only has the codes of visits earlier than its HADM_ID. Join with MULTI_VISIT_PKL again to get the codes of the current visit, and insert them at index 0 of both the ICD9_CODE and the ATC4 lists.</li>
    <li>Check for, and report as an error, any ICD9_CODE or ATC4 list with fewer than 2 entries</li>
    <li>Finally, add a loss factor T_1 = T-1: the number of visits of the patient that have at least one earlier visit. The training loss of a patient is scaled by $1/(T-1)$, as in equation (10) of the G-BERT paper.</li>
</ol>


In [None]:
def create_multi_visit_temporal_pkl_files(outputPath: str) -> pd.DataFrame:
    """Build and pickle the code history of every visit that has earlier visits.

    Args:
        outputPath: destination path of the pickle.

    Returns:
        DataFrame with columns SUBJECT_ID, HADM_ID, ICD9_CODE, ATC4 and T_1.
        ICD9_CODE / ATC4 are lists of per-visit code lists: entry 0 is the visit HADM_ID
        itself, and the remaining entries are the patient's earlier visits, most recent
        first. T_1 is the number of rows of the patient (T-1 of eq. (10) in the G-BERT paper).
    """
    visit_record_key = ["SUBJECT_ID", "HADM_ID", "ICD9_CODE", "ATC4"]
    visit_id = visit_record_key[:2]

    # Step 1: load the inputs and drop patients with a single admission.
    multi_visit = pd.read_pickle(os.path.join(GLOBAL_DATA_PATH, MULTI_VISIT_PKL))[visit_record_key[1:]]
    admission_df = pd.read_csv(os.path.join(GLOBAL_DATA_PATH, ADMISSIONS))[visit_id + ["ADMITTIME"]]
    single_visit_sub_id = pd.read_pickle(os.path.join(GLOBAL_DATA_PATH, SINGLE_VISIT_PKL))["SUBJECT_ID"]
    admission_df = admission_df.loc[~admission_df["SUBJECT_ID"].isin(single_visit_sub_id)]

    # Step 2: pair every admission (HADM_ID_x) with each strictly earlier admission
    # (HADM_ID_y) of the same patient.
    self_joined = pd.merge(admission_df, admission_df, on=["SUBJECT_ID"])
    earlier_pairs = self_joined.loc[self_joined["ADMITTIME_x"] > self_joined["ADMITTIME_y"]]
    earlier_pairs = earlier_pairs.rename(columns={"HADM_ID_y": "HADM_ID", "HADM_ID_x": "current_HADM_ID"})

    # Step 3: attach the codes of the earlier admission, then re-key rows by the current one.
    # Sorting makes the history order deterministic: most recent earlier visit first.
    history_rows = pd.merge(earlier_pairs, multi_visit, on=["HADM_ID"])
    history_rows = history_rows.drop(columns=["HADM_ID"]).rename(columns={"current_HADM_ID": "HADM_ID"})
    history_rows = history_rows.sort_values(by=visit_id + ["ADMITTIME_y"], ascending=[True, True, False])

    # Step 4: collapse to one row per (SUBJECT_ID, HADM_ID) holding lists of code lists.
    temporal_icd = history_rows.groupby(by=visit_id)["ICD9_CODE"].apply(list).reset_index()
    temporal_atc = history_rows.groupby(by=visit_id)["ATC4"].apply(list).reset_index()
    multi_visit_temporal = pd.merge(temporal_icd, temporal_atc, on=visit_id)

    # Step 5: prepend the current visit's own codes as entry 0. New lists are built rather
    # than mutated through DataFrame.apply side effects.
    multi_visit_temporal = pd.merge(multi_visit_temporal, multi_visit, on=["HADM_ID"], suffixes=("", "_curr"))
    for column in ("ICD9_CODE", "ATC4"):
        combined = [
            [current, *history]
            for current, history in zip(multi_visit_temporal[f"{column}_curr"], multi_visit_temporal[column])
        ]
        multi_visit_temporal[column] = pd.Series(combined, index=multi_visit_temporal.index, dtype=object)

    # Step 6: drop the helper columns. Every row must hold the current visit plus >= 1 earlier one.
    multi_visit_temporal = multi_visit_temporal.drop(columns=["ICD9_CODE_curr", "ATC4_curr"])
    assert multi_visit_temporal["ICD9_CODE"].map(len).gt(1).all() and \
        multi_visit_temporal["ATC4"].map(len).gt(1).all(), "[create_multi_visit_temporal_pkl_files] Some issue in data"

    # Step 7: T_1 = number of rows (visits that have history) of the patient.
    multi_visit_temporal["T_1"] = multi_visit_temporal.groupby("SUBJECT_ID")["SUBJECT_ID"].transform("count")

    multi_visit_temporal.to_pickle(outputPath)

    return multi_visit_temporal
    

Code to pkl multi-visit temporal data <br>
<span style="color: #055BA6">create_multi_visit_temporal_pkl_files(os.path.join(GLOBAL_DATA_PATH, MULTI_VISIT_TEMPORAL_PKL))</span>

### <u>Part-8: Dataset and CollatedModel for Fine-tuning (training)</u>
We build the training, validation and testing procedures step by step.

#### <u>Step 1:</u>
As in pre-training, we first define a dataset class for training.

Steps to create training data-set:
<ol>
    <li>Subclass PyTorch's Dataset, with a Build method that loads MULTI_VISIT_TEMPORAL_PKL</li>
    <li>__getitem__ returns, for one patient, the padded index sequences (and padding masks) of every visit for both the ICD9_CODE and the ATC4 codes</li>
</ol>

In [None]:
class TrainingDatasetWithoutMasking(Dataset):
    """Fine-tuning dataset with one item per patient (SUBJECT_ID).

    Each row of MULTI_VISIT_TEMPORAL_PKL is one visit together with its history:
    ``ICD9_CODE`` / ``ATC4`` hold a list of per-visit code lists whose entry 0 is the
    current visit. ``__getitem__`` returns, for every row of the requested patient, the
    padded index sequences, the padding masks and the ATC multi-hot labels, plus ``T_1``
    taken from the patient's first row.
    """

    def __init__(self, df: pd.DataFrame, atc_vocab: Vocab, atc_leaf_vocab: Vocab, icd_vocab: Vocab, icd_leaf_vocab: Vocab, max_atc_len: int, \
        max_icd_len: int):
        super(TrainingDatasetWithoutMasking, self).__init__()
        self.df = df  # needs an "index" column (dense patient id), as produced by Build
        self.maxAtcLen = max_atc_len  # padded length, including the [CLS] slot at position 0
        self.maxIcdLen = max_icd_len
        self.atcVocab = atc_vocab
        self.atcLeafVocab = atc_leaf_vocab
        self.icdVocab = icd_vocab
        self.icdLeafVocab = icd_leaf_vocab

    @classmethod
    def Build(cls):
        """Load MULTI_VISIT_TEMPORAL_PKL and wrap it with the global vocabularies.

        Raises:
            AssertionError: if some visit has too many codes to fit, together with the
                leading [CLS] token, into TRAINING_MAX_ICD_LEN / TRAINING_MAX_ATC_LEN.
        """
        max_atc_len = TRAINING_MAX_ATC_LEN
        max_icd_len = TRAINING_MAX_ICD_LEN
        multi_visit_temporal_df: pd.DataFrame = pd.read_pickle(os.path.join(GLOBAL_DATA_PATH, MULTI_VISIT_TEMPORAL_PKL))

        # __len__/__getitem__ address patients, so give every SUBJECT_ID a dense
        # 0..P-1 "index" (numbered in sorted SUBJECT_ID order).
        multi_visit_temporal_df = multi_visit_temporal_df.assign(
            index=multi_visit_temporal_df.groupby("SUBJECT_ID").ngroup())

        def longest_visit(column: str):
            return multi_visit_temporal_df[column].apply(lambda visits: max(len(codes) for codes in visits)).max()

        # Position 0 of every padded sequence holds [CLS], so a visit with k codes needs
        # k + 1 slots. With ">=", k == max_len would pass here and overflow in _process_.
        assert max_icd_len > longest_visit("ICD9_CODE") and max_atc_len > longest_visit("ATC4"), \
            "[TrainingDatasetWithoutMasking::Build] Max Seq length is less"

        return cls(multi_visit_temporal_df, GLOBAL_ATC_VOCAB, GLOBAL_ATC_LEAF_VOCAB, GLOBAL_ICD_VOCAB,
                   GLOBAL_ICD_LEAF_VOCAB, max_atc_len, max_icd_len)

    def __len__(self):
        # Number of patients; relies on the dense "index" column built in Build.
        return self.df["index"].max() + 1

    def _process_(self, atc:list[list[str]], icd:list[list[str]]):
        """Encode one row (current visit + history) into padded index/mask tensors.

        Returns:
            (atc_idx, atc_multi_hot, atc_mask, icd_idx, icd_multi_hot, icd_mask). The index
            and mask tensors have shape (number of visits, max_len). Column 0 is [CLS]; the
            mask is True on padding and 0 on real tokens. ATC codes missing from the
            vocabulary are left as padding; ICD codes missing from it become [UNK].
        """
        assert len(atc) == len(icd), "Should have equal number of temporal visit"

        atc_word2idx_mask = torch.ones((len(atc), self.maxAtcLen), dtype=torch.bool, device=DEVICE) # ones due to pytorch convention (True = padding)
        atc_word2idx_list = torch.zeros((len(atc), self.maxAtcLen), dtype=torch.long, device=DEVICE)
        # CLS:
        atc_word2idx_mask[:, 0] = 0
        atc_word2idx_list[:, 0] = self.atcVocab.word2idx.get("[CLS]")
        for i, atc_codes in enumerate(atc):
            for j, atc_code in enumerate(atc_codes):
                idx = self.atcVocab.word2idx.get(atc_code) # need this idx to be over all graph nodes
                if idx is None:
                    print("TrainingDatasetWithoutMasking: Some unseen ATC code", atc_code)
                else:
                    atc_word2idx_mask[i, j+1] = 0
                    atc_word2idx_list[i, j+1] = idx

        icd_word2idx_mask = torch.ones((len(icd), self.maxIcdLen), dtype=torch.bool, device=DEVICE)
        icd_word2idx_list = torch.zeros((len(icd), self.maxIcdLen), dtype=torch.long, device=DEVICE)
        # CLS:
        icd_word2idx_mask[:, 0] = 0
        icd_word2idx_list[:, 0] = self.icdVocab.word2idx.get("[CLS]")
        for i, icd_codes in enumerate(icd):
            for j, icd_code in enumerate(icd_codes):
                idx = self.icdVocab.word2idx.get(icd_code)
                icd_word2idx_mask[i, j+1] = 0
                if idx is None:
                    # We are going with [UNK] token since we need top 2k diagnosis only
                    icd_word2idx_list[i, j+1] = self.icdVocab.word2idx.get("[UNK]")
                else:
                    icd_word2idx_list[i, j+1] = idx

        return atc_word2idx_list, get_multi_hot(atc, self.atcLeafVocab), atc_word2idx_mask,\
            icd_word2idx_list, get_multi_hot(icd, self.icdLeafVocab), icd_word2idx_mask

    def __getitem__(self, index: int):
        """
            index: represents a patient
            0th index icd/atc always represent the most recent diagnosis/medication
            Returns per-row lists (atc_seq, atc_multi_hot, atc_mask, icd_seq, icd_mask) and T_1.
        """
        patient_data: pd.DataFrame = self.df.loc[self.df["index"] == index]
        T_1 = patient_data["T_1"].iloc[0]
        atc_padded_seq = []
        atc_seq_mask = []
        atc_multi_hot = []
        icd_padded_seq = []
        icd_seq_mask = []
        for _, row in patient_data.iterrows():
            atc_seq, atc_mh, atc_mask, icd_seq, _, icd_mask = self._process_(row["ATC4"], row["ICD9_CODE"])
            atc_padded_seq.append(atc_seq)
            atc_multi_hot.append(atc_mh)
            atc_seq_mask.append(atc_mask)

            icd_padded_seq.append(icd_seq)
            icd_seq_mask.append(icd_mask)

        return atc_padded_seq, atc_multi_hot, atc_seq_mask, icd_padded_seq, icd_seq_mask, T_1


        

#### <u>Step 2:</u>
In this step we define the prediction (transformation) function given by equation 9 of the paper.

In [None]:
def prediction_layer(d_model, vocab_size):
    """Build the fine-tuning head of equation 9 of the G-BERT paper.

    The input is the concatenation of three ``d_model``-sized vectors (hence
    ``3 * d_model`` features). The output is one sigmoid probability per label.
    """
    concatenated_features = 3 * d_model
    return nn.Sequential(
        nn.Linear(concatenated_features, vocab_size),
        nn.Sigmoid(),
    )

#### <u>Step 3:</u>
In this step we define a Collated Model just like we did in pre-training. 

In [None]:
class CollatedModelForTraining(nn.Module):
    """Fine-tuning model: pre-trained ontology embeddings + G-BERT + medication head.

    The ontology embeddings and G-BERT come from the pre-training checkpoint. The
    prediction head is either freshly initialised or, with ``load_prediction_layer=True``,
    restored from the fine-tuning checkpoint. This lets alternating pre-training /
    fine-tuning rounds keep the head while refreshing the encoder.
    """

    def __init__(self, load_prediction_layer: bool = False):
        super(CollatedModelForTraining, self).__init__()
        pretraining_model = CollatedModelForPretraining()
        # Load onto the CPU: the freshly built model lives there until the caller moves it.
        checkpoint = torch.load(os.path.join(GLOBAL_MODELS_PATH, PRETRAINING_MODEL), map_location="cpu")
        pretraining_model.load_state_dict(checkpoint["model_state_dict"])
        self.atc_ontology_embedding = pretraining_model.atc_ontology_embedding
        self.icd_ontology_embedding = pretraining_model.icd_ontology_embedding
        self.gbert = pretraining_model.gbert

        self.prediction_layer = prediction_layer(BERT_IN_OUT, len(GLOBAL_ATC_LEAF_VOCAB.idx2word))
        if load_prediction_layer:
            # Restore only the head's weights from the fine-tuning checkpoint. Building a
            # whole second model (and re-reading the pretraining checkpoint) is unnecessary.
            training_state = torch.load(os.path.join(GLOBAL_MODELS_PATH, TRAINING_MODEL),
                                        map_location="cpu")["model_state_dict"]
            prefix = "prediction_layer."
            head_state = {key[len(prefix):]: value for key, value in training_state.items() if key.startswith(prefix)}
            self.prediction_layer.load_state_dict(head_state)

    def forward(self, atc: torch.Tensor, atc_mask: torch.Tensor, icd: torch.Tensor, icd_mask: torch.Tensor):
        """Predict medication probabilities for the visit in row 0 of ``atc``/``icd``.

        ``atc`` / ``icd`` hold one padded index sequence per visit: row 0 is the current
        visit and rows 1.. are its history. The current visit's vd is concatenated with the
        mean vd and mean vm of the history and fed to the prediction head.
        """
        assert len(icd) == len(atc), "==> Input incorrect"
        icd_ontology_embeddings = self.icd_ontology_embedding()[icd]
        atc_ontology_embeddings = self.atc_ontology_embedding()[atc]
        assert len(icd) == len(icd_ontology_embeddings), "==> Issue in embeddings"
        assert len(atc) == len(atc_ontology_embeddings), "==> Issue in embeddings"

        vd, vm = self.gbert(atc_ontology_embeddings, atc_mask, icd_ontology_embeddings, icd_mask) # mask => seq_pad_mask
        assert len(icd) > 1, "Some issue in input"

        current_visit_vd = vd[0]
        # Mean over the history rows. This stays on vd's device, unlike a CPU torch.zeros buffer.
        historical_visit_vd = vd[1:len(icd)].mean(dim=0)
        historical_visit_vm = vm[1:len(icd)].mean(dim=0)
        assert historical_visit_vd.shape[0] == BERT_IN_OUT and historical_visit_vm.shape[0] == BERT_IN_OUT, "Bert should return only cls embedding"

        concat_visit = torch.cat((current_visit_vd, historical_visit_vd, historical_visit_vm), dim=0)

        return self.prediction_layer(concat_visit)

### <u>Part-9: Training Loop</u>

Below we fine-tune ("train") the GBERT model together with the OntologyEmbedding model. First we split the data into training, validation and test sets; following the paper, we use a 0.6:0.2:0.2 ratio.

In [None]:
def get_train_val_dataloader(seed: Union[int, None] = 42)-> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """Split the fine-tuning dataset 0.6 / 0.2 / 0.2 into train / validation / test loaders.

    Args:
        seed: seed for the split. With a fixed seed every call returns the same partition,
            so a test loader created in a later call (e.g. a separate evaluation cell)
            never contains items the model was trained on. Pass None to draw the split
            from torch's global RNG (the previous, non-reproducible behaviour).

    Returns:
        (train_loader, val_loader, test_loader). batch_size is 1 because each item is one
        patient with a variable number of visits. Only the training loader is shuffled.
    """
    orig_data_set = TrainingDatasetWithoutMasking.Build()
    num_items = len(orig_data_set)
    split_for_train = int(0.6 * num_items)
    split_for_validation = int(0.2 * num_items)
    split_sizes = [split_for_train, split_for_validation, num_items - split_for_train - split_for_validation]

    if seed is None:
        splits = random_split(orig_data_set, split_sizes)
    else:
        splits = random_split(orig_data_set, split_sizes, generator=torch.Generator().manual_seed(seed))
    train_data_set, val_data_set, test_data_set = splits

    # Batch size 1: each item is a list of per-visit tensors, which does not batch cleanly.
    train_data_loader = DataLoader(train_data_set, batch_size=1, shuffle=True)
    val_data_loader = DataLoader(val_data_set, batch_size=1)
    test_data_loader = DataLoader(test_data_set, batch_size=1)

    return train_data_loader, val_data_loader, test_data_loader

We define a training loop as we did for pre-training.

In [None]:
def trainingAndValidation_loop(data_loader: torch.utils.data.DataLoader, eval_mode: bool = False, \
    load_fine_tuning:bool = False):
    """Fine-tune CollatedModelForTraining, or evaluate it when ``eval_mode`` is True.

    Args:
        data_loader: loader from get_train_val_dataloader (batch_size=1, one patient per batch).
        eval_mode: if True, load the fine-tuned checkpoint (GLOBAL_MODELS_PATH/TRAINING_MODEL)
            and run one pass without gradient tracking and without updating weights.
        load_fine_tuning: forwarded to CollatedModelForTraining as ``load_prediction_layer``.

    Returns:
        The model, in eval mode.

    The per-patient loss is the summed BCE over the patient's T-1 visits divided by T-1
    (eq. (10) of the G-BERT paper). Jaccard and PR-AUC are averaged per patient, then
    per epoch.
    """
    model = CollatedModelForTraining(load_fine_tuning).to(DEVICE)
    if eval_mode:
        checkpoint = torch.load(os.path.join(GLOBAL_MODELS_PATH, TRAINING_MODEL), map_location=DEVICE)
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()
    else:
        model.train(mode=True)
        training_optimizer = Adam(model.parameters(), lr=0.001)
        train_criterion = nn.BCELoss()

    num_epochs = 1 if eval_mode else TRAINING_EPOCH
    # Evaluation needs no autograd graph; disabling it saves memory and time.
    with torch.set_grad_enabled(not eval_mode):
        for epoch in range(num_epochs):
            epoch_jaccard_list = []
            epoch_pr_auc_list = []
            epoch_loss = []
            tqdm_iterator = tqdm(data_loader)
            for patient_data in tqdm_iterator:
                atc_padded_list, atc_multi_hot_labels, atc_seq_mask, icd_padded_list, icd_seq_mask, t_1 = patient_data
                # With batch_size=1, t_1 arrives as a one-element tensor; use a plain int.
                num_visits = int(t_1)
                if not eval_mode:
                    training_optimizer.zero_grad()
                patient_loss = None
                jaccard_list = []
                pr_auc_list = []

                for data_count in range(num_visits):
                    atc_list = atc_padded_list[data_count].squeeze(dim=0)
                    # get_multi_hot creates labels on the CPU; keep them on the model's device.
                    atc_labels = atc_multi_hot_labels[data_count].squeeze(dim=0).to(DEVICE)
                    atc_mask = atc_seq_mask[data_count].squeeze(dim=0)

                    icd_list = icd_padded_list[data_count].squeeze(dim=0)
                    icd_mask = icd_seq_mask[data_count].squeeze(dim=0)

                    predicted_rx = model(atc_list, atc_mask, icd_list, icd_mask)
                    # Row 0 of the labels is the current visit, the one being predicted.
                    current_rx = atc_labels[0]

                    if not eval_mode:
                        visit_loss = train_criterion(predicted_rx, current_rx)
                        patient_loss = visit_loss if patient_loss is None else patient_loss + visit_loss

                    rx_probs = predicted_rx.detach()
                    y_pred = (rx_probs > 0.5).float()
                    jaccard_item, pr_auc_item = multi_label_metric(current_rx.unsqueeze(dim=0), y_pred.unsqueeze(dim=0), rx_probs.unsqueeze(dim=0))[:2]
                    jaccard_list.append(jaccard_item)
                    pr_auc_list.append(pr_auc_item)

                if not eval_mode:
                    # Scale by 1/(T-1) as in eq. (10) of the G-BERT paper.
                    patient_loss = patient_loss / num_visits
                    patient_loss.backward()
                    training_optimizer.step()
                    epoch_loss.append(patient_loss.item())
                    tqdm_iterator.set_postfix({f"[trainingAndValidation_loop] epoch:{epoch}:: patient loss":str(patient_loss.item())})

                epoch_jaccard_list.append(sum(jaccard_list)/len(jaccard_list))
                epoch_pr_auc_list.append(sum(pr_auc_list)/len(pr_auc_list))

            if not eval_mode:
                print("[trainingAndValidation_loop]: Average loss for this epoch", sum(epoch_loss)/len(epoch_loss))
            print("[trainingAndValidation_loop]: Average jaccard for this epoch", sum(epoch_jaccard_list)/len(epoch_jaccard_list))
            print("[trainingAndValidation_loop]: Average PR-AUC for this epoch", sum(epoch_pr_auc_list)/len(epoch_pr_auc_list))
    if not eval_mode:
        model.train(mode=False)
    return model

Code to start training <br>
<span style="color: #055BA6">
train_loader, val_loader, test_loader = get_train_val_dataloader() <br>
model = trainingAndValidation_loop(train_loader)
</span>

Saving the training model <br>
<span style="color: #055BA6">
torch.save({ <br>
    "model_state_dict": model.state_dict() <br>
    }, os.path.join(GLOBAL_MODELS_PATH,TRAINING_MODEL)) <br>
</span>

### <u>Part-10: Testing (and validation) Loop</u>

Almost all of the training-loop logic stays the same. The main difference is that, rather than training a new model, we load the fine-tuned model from disk and do not update it. To do this we simply pass the flag <u><i>eval_mode: bool</i></u> to <u><i>trainingAndValidation_loop</i></u>.

Validation works the same way: set <u><i>eval_mode: bool</i></u> to True and pass the validation data loader instead.

Code to start evaluation <br>
<span style="color: #055BA6">
train_loader, val_loader, test_loader = get_train_val_dataloader() <br>
model = trainingAndValidation_loop(test_loader, eval_mode = True)
</span>

### <u>Part-11: Pre-Training, Training and Validation as per Original Paper setup </u>
The authors alternate between pre-training and fine-tuning (training), 5 epochs at a time, and repeat this 15 times to stabilize the training procedure.

In [None]:
def main():
    """Run the alternating pre-training / fine-tuning schedule, then evaluate on the test split.

    As committed, only part of the schedule is active. Fine-tuning resumes from the saved
    checkpoints (``train(True)``), and the resulting checkpoint is evaluated on the test loader.
    """
    def measure_peak_memory(label: str, run: Callable[[], nn.Module]) -> nn.Module:
        """Call ``run()`` with tracemalloc enabled, print ``label`` and the peak traced size.

        NOTE(review): tracemalloc only tracks allocations made through Python's allocators;
        tensor storage allocated by PyTorch is presumably not counted. Confirm before
        relying on these numbers.
        """
        tracemalloc.start()
        try:
            result = run()
            _, peek = tracemalloc.get_traced_memory()
            print(label, peek)
        finally:
            # Stop tracing even if run() raises, so its overhead does not leak into later cells.
            tracemalloc.stop()
        return result

    def pre_train(load_from_disk:bool = False, iteration:int = 0):
        print(f"=================Iteration: {iteration}==============================")
        pretrain_model = measure_peak_memory("Peek memory during pre-training",
                                             lambda: pretraining_loop(load_from_disk))
        torch.save({
            "model_state_dict": pretrain_model.state_dict()
            }, os.path.join(GLOBAL_MODELS_PATH, PRETRAINING_MODEL))

    train_loader, _, test_loader = get_train_val_dataloader()
    def train(load_fine_tuning:bool = False):
        train_model = measure_peak_memory("Peek memory during training",
                                          lambda: trainingAndValidation_loop(train_loader, load_fine_tuning=load_fine_tuning))
        torch.save({
            "model_state_dict": train_model.state_dict()
            }, os.path.join(GLOBAL_MODELS_PATH,TRAINING_MODEL))

    # Full schedule: pre_train() and train() once, then alternate both per iteration.
    # NOTE(review): the loop start of 8 looks like a manual resume point, and
    # TRINING_ITERATION may be a misspelling of the intended constant. Confirm both.
    # pre_train()
    # train()
    # for iteration in range(8, TRINING_ITERATION):
        # pre_train(True, iteration)
        # train(True)
    train(True)

    print(f"=================Evaluation==============================")
    trainingAndValidation_loop(test_loader, load_fine_tuning=True, eval_mode=True)

In [None]:
# Entry point of Part-11: run the schedule configured inside main().
main()

### <u>Part-xx: Baselines</u>

#### <u>[I] Logistic Regression:</u>
According to the G-BERT paper, the input to logistic regression is the multi-hot vector of each visit. We assume that the paper does not sum the multi-hot vectors of previous visits, i.e. it predicts from the current visit's multi-hot vector only.

##### <u>Step 1: Data input:</u>
<ol>
    <li>We will use MULTI_VISIT_PKL and get ICD9_CODE and ATC4 code for each visit</li>
    <li>We will convert visit to multi-hot vector and finally store them in numpy matrix</li>
    <li>Ratio of breakdown is assumed to be 0.8 for training and 0.2 for test</li>
</ol>

In [None]:
def data_loader_for_logistic_regression(train_ratio: float = 0.8)-> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load per-visit multi-hot features and labels for the logistic-regression baseline.

    Each admission in MULTI_VISIT_PKL becomes one sample. X is the multi-hot vector of its
    ICD-9 codes over GLOBAL_ICD_LEAF_VOCAB; Y is the multi-hot vector of its ATC4 codes over
    GLOBAL_ATC_LEAF_VOCAB.

    Args:
        train_ratio: fraction of rows used for training. Rows are taken in file order,
            without shuffling; the rest form the test split.

    Returns:
        (X_train, Y_train, X_test, Y_test) as uint8 arrays.

    Raises:
        ValueError: if ``train_ratio`` is not strictly between 0 and 1.
    """
    if not 0.0 < train_ratio < 1.0:
        raise ValueError(f"train_ratio must be in (0, 1), got {train_ratio}")
    df = pd.read_pickle(os.path.join(GLOBAL_DATA_PATH, MULTI_VISIT_PKL))
    # Build the matrices directly instead of overwriting the frame's list columns.
    X = np.array([get_multi_hot(codes, GLOBAL_ICD_LEAF_VOCAB).numpy() for codes in df["ICD9_CODE"]], dtype=np.uint8)
    Y = np.array([get_multi_hot(codes, GLOBAL_ATC_LEAF_VOCAB).numpy() for codes in df["ATC4"]], dtype=np.uint8)
    index = int(train_ratio * X.shape[0])
    return X[:index], Y[:index], X[index:], Y[index:]

Code to load logistic regression data<br>
<span style="color: #055BA6">X_train, Y_train, X_test, Y_test = data_loader_for_logistic_regression()</span>

##### <u>Step 2: Logistic Regression:</u>
<ol>
    <li>We use sklearn's OneVsRestClassifier to get binary-relevance (one-vs-rest) logistic regression with L2 regularization</li>
    <li>The regularization strength C is tuned with a grid search (GridSearchCV)</li>
</ol>

In [None]:
def run_logistic_regression():
    """Fit and evaluate the one-vs-rest logistic-regression baseline.

    An L2-regularised LogisticRegression is wrapped in OneVsRestClassifier (one binary
    model per label column of Y). Its inverse regularisation strength ``C`` is tuned with
    GridSearchCV over 100 evenly spaced values in [2e-5, 1]. The output of
    multi_label_metric on the held-out split is printed.
    """
    X_train, Y_train, X_test, Y_test = data_loader_for_logistic_regression()

    search_space = dict(
        estimator__penalty=["l2"],
        estimator__C=np.linspace(0.00002, 1, 100),
    )
    base_model = OneVsRestClassifier(LogisticRegression())
    search = GridSearchCV(base_model, search_space, verbose=1)
    search.fit(X_train, Y_train)

    hard_predictions = search.predict(X_test)
    soft_predictions = search.predict_proba(X_test)
    metrics = multi_label_metric(
        torch.from_numpy(Y_test),
        torch.from_numpy(hard_predictions),
        torch.from_numpy(soft_predictions),
    )
    print(metrics)


In [31]:
# Fit the logistic-regression baseline and print its metrics on the held-out split.
run_logistic_regression()

Code to run the logistic regression baseline<br>
<span style="color: #055BA6">run_logistic_regression()</span>