In [None]:
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
from torch.utils.data.dataset import Dataset
import os
import torch
import torch.nn as nn
import pytorch_pretrained_bert as Bert
import sklearn.metrics as skm
import math
from torch.utils.data.dataset import Dataset
import random
import numpy as np
import torch
import time
import transformers

In [None]:
# Data loading: read the diagnoses table and the ICD-9 code -> short-title map,
# then build three per-patient structures that share one positional index.
diagnoses_file_path = r'mimic_data/diagnoses_icd.csv.gz'
map_file_path = r'TransformerEHR/data/physionet.org/files/mimiciii-demo/1.4/D_ICD_DIAGNOSES.csv'


diagnoses_df = pd.read_csv(diagnoses_file_path)
print(diagnoses_df.columns)
map_df = pd.read_csv(map_file_path)

# The structures below are indexed by position, not by subject_id. groupby emits
# its groups in sorted key order, so patient_ids and the per-patient admission ids
# are sorted as well; otherwise patient_ids[i], visits[i] and patient_visits[i]
# could describe different patients / admissions.

# Sorted list of patient ids that have at least one diagnosis row.
patient_ids = sorted(diagnoses_df['subject_id'].unique().tolist())

# 2-level list: for each patient, the sorted distinct hadm_id of every visit.
visits = diagnoses_df.groupby('subject_id')['hadm_id'].apply(lambda x: sorted(set(x))).tolist()

# 3-level list: for each patient, for each visit (in hadm_id order), its ICD codes.
patient_visits = (
    diagnoses_df.groupby(['subject_id', 'hadm_id'])['icd_code'].apply(list).groupby(level=0).apply(list).tolist()
)

# dict of {icd9_code : short_title}
# Not every code present in the diagnoses table is present in D_ICD_DIAGNOSES.csv,
# so lookups fall back to the raw code when no title is known.
icd9_to_title = pd.Series(map_df['short_title'].values, index=map_df['icd9_code']).to_dict()

# Show one example patient; clamp the index so a small extract does not raise IndexError.
example_idx = min(53, len(patient_ids) - 1)
print("Patient ID:", patient_ids[example_idx])
print("num of visits for patient: " , len(visits[example_idx]))
for visit, (visit_id, visit_codes) in enumerate(zip(visits[example_idx], patient_visits[example_idx])):
    print(f"\t{visit}-th visit id:", visit_id)
    print(f"\t{visit}-th visit diagnosis codes:", visit_codes)
    print(f"\t{visit}-th visit diagnosis short titles:",
          [icd9_to_title.get(label, label) for label in visit_codes])


In [None]:
# Descriptive statistics of ICD-9 vs ICD-10 usage in the diagnoses table.

# Total rows per ICD version
count_icd_version_10 = (diagnoses_df['icd_version'] == 10).sum()
count_icd_version_9 = (diagnoses_df['icd_version'] == 9).sum()

print("Number of rows with icd_version = 10:", count_icd_version_10)
print("Number of rows with icd_version = 9:", count_icd_version_9)

# Number of distinct codes per ICD version
unique_icd9_codes = diagnoses_df[diagnoses_df['icd_version'] == 9]['icd_code'].nunique()
unique_icd10_codes = diagnoses_df[diagnoses_df['icd_version'] == 10]['icd_code'].nunique()

print("Number of unique ICD-9 codes:", unique_icd9_codes)
print("Number of unique ICD-10 codes:", unique_icd10_codes)

# Number of patients with at least one ICD-10 code.
# The filter must select version 10; selecting 9 here reported the ICD-9 count.
icd10_df = diagnoses_df[diagnoses_df['icd_version'] == 10]
unique_patients_with_icd10 = icd10_df['subject_id'].unique()
num_patients_with_icd10 = len(unique_patients_with_icd10)
print("Number of patients with at least one ICD-10 code:", num_patients_with_icd10)

# Number of patients with both ICD-9 and ICD-10 codes
grouped = diagnoses_df.groupby('subject_id')['icd_version'].agg(set)
patients_with_both = grouped[grouped.apply(lambda x: {9, 10}.issubset(x))]
print("Number of patients with both ICD-9 and ICD-10 codes:", len(patients_with_both))

# Number of patients with ONLY ICD-10 codes: the per-patient version set must be exactly {10}.
patient_versions = diagnoses_df.groupby('subject_id')['icd_version'].unique()
patients_with_only_icd10 = patient_versions[patient_versions.apply(lambda x: set(x) == {10})]
num_patients_only_icd10 = len(patients_with_only_icd10)
print("Number of patients with only ICD-10 codes:", num_patients_only_icd10)

In [None]:
# Total number of usable patients: ICD-9 only, with more than three visits.
patient_versions = diagnoses_df.groupby('subject_id')['icd_version'].unique()

patients_with_only_icd9 = patient_versions[patient_versions.apply(lambda x: set(x) == {9})].index

icd9_patients_df = diagnoses_df[diagnoses_df['subject_id'].isin(patients_with_only_icd9)]
# A visit is one distinct hadm_id. Counting rows with .size() would count
# diagnosis codes instead, and a single admission usually carries several codes.
visit_counts = icd9_patients_df.groupby('subject_id')['hadm_id'].nunique()

patients_more_than_three_visits = visit_counts[visit_counts > 3]
num_patients = len(patients_more_than_three_visits)
print("Number of patients with only ICD-9 codes and more than 3 visits:", num_patients)

In [None]:
# Preparing the feature / label splits.
# For each retained patient the first half of the visits becomes the feature
# (a list of per-visit code lists) and the codes of the remaining visits,
# flattened into one list, become the label.

# Patients with at least one ICD-10 row are dropped entirely.
patients_with_icd10 = diagnoses_df[diagnoses_df['icd_version'] == 10]['subject_id'].unique()
icd9_only_df = diagnoses_df[~diagnoses_df['subject_id'].isin(patients_with_icd10)]

# Keep only ICD-9 rows of the remaining patients.
icd9_only_df = icd9_only_df[icd9_only_df['icd_version'] == 9]
visit_counts = icd9_only_df.groupby('subject_id')['hadm_id'].nunique()

patients_more_than_three_visits = visit_counts[visit_counts > 3].index

# Final DataFrame of patients with only ICD-9 codes and more than three visits.
final_df = icd9_only_df[icd9_only_df['subject_id'].isin(patients_more_than_three_visits)]

# Series indexed by subject_id; each value is a list of visits, each visit a list of codes.
# NOTE(review): visits come out in groupby key order, i.e. sorted by hadm_id. The
# first-half / second-half split below is only a past/future split if hadm_id order
# matches admission time order -- confirm, or order by admission time instead.
patient_visits = final_df.groupby(['subject_id', 'hadm_id'])['icd_code'].apply(list).reset_index()
patient_visits = patient_visits.groupby('subject_id')['icd_code'].apply(list)

features = []
labels = []
splits = []

# The loop variable is deliberately not called `visits`, so the module-level
# `visits` list built in the data-loading cell is not overwritten.
for visit_codes in patient_visits:
    if len(visit_codes) <= 3:
        continue
    split_index = len(visit_codes) // 2
    splits.append(split_index)
    features.append(list(visit_codes[:split_index]))
    labels.append([code for sublist in visit_codes[split_index:] for code in sublist])

# Output for one set of features and labels
if features and labels:
    print("Split Index: ", splits[:10])
    print("Example Features: ", features[0])
    print("Example Labels: " , labels[0])

In [None]:
from transformers import BertTokenizer
class ICDCodeDataset(Dataset):
    """Dataset that tokenizes ICD code sequences for features and labels.

    Args:
        features: one entry per sample; either a flat list of codes or a list of
            visits where each visit is itself a list of codes.
        labels: one entry per sample, in the same accepted shapes as ``features``.
        tokenizer: object exposing ``encode_plus`` with the keyword arguments
            used in ``_encode`` (e.g. a ``BertTokenizer``).
        max_len: length every encoded sequence is padded / truncated to.
    """

    def __init__(self, features, labels, tokenizer, max_len=512):
        self.features = features
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.features)

    @staticmethod
    def _codes_to_text(codes):
        """Return the codes as one space-separated string, flattening one level of nesting.

        Features arrive as a list of visits (each a list of codes), so joining
        them directly raises TypeError; labels arrive as a flat list of codes.
        """
        flat_codes = []
        for item in codes:
            if isinstance(item, (list, tuple)):
                flat_codes.extend(str(code) for code in item)
            else:
                flat_codes.append(str(item))
        return ' '.join(flat_codes)

    def _encode(self, codes):
        """Tokenize one code sequence to fixed-length ``pt`` tensors."""
        return self.tokenizer.encode_plus(
            self._codes_to_text(codes),
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

    def __getitem__(self, index):
        """Return ``input_ids`` / ``attention_mask`` for the features and ``labels`` token ids."""
        feature_encoded = self._encode(self.features[index])
        labels_encoded = self._encode(self.labels[index])

        # The tokenizer output is flattened so a DataLoader can stack the samples.
        return {
            'input_ids': feature_encoded['input_ids'].flatten(),
            'attention_mask': feature_encoded['attention_mask'].flatten(),
            'labels': labels_encoded['input_ids'].flatten()
        }

# Load the pretrained 'bert-base-uncased' tokenizer.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Wrap the prepared feature / label lists in the tokenizing dataset (default max_len).
dataset = ICDCodeDataset(features=features, labels=labels, tokenizer=tokenizer)

In [None]:
from torch.utils.data import random_split

# Train / test split: 80% of the samples for training, the remainder for testing.
total_samples = len(dataset)
train_size = int(0.8 * total_samples)
test_size = total_samples - train_size

# A seeded generator keeps the partition identical across kernel restarts, so
# train and test metrics stay comparable between runs.
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(42),
)
print("FULL Dataset: {}".format(len(dataset)))
print("TRAIN Dataset: {}".format(len(train_dataset)))
print("TEST Dataset: {}".format(len(test_dataset)))

In [None]:
# Batch iterators over the two splits: training batches are reshuffled each
# epoch, test batches keep a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False)

In [None]:
# Run-level settings.
global_params = dict(
    batch_size=64,
    gradient_accumulation_steps=1,
    device='cuda:0',
    output_dir='',  # output dir
    best_name='',  # output model name
    save_model=True,
    max_len_seq=100,
    max_age=110,
    month=1,
    age_symbol=None,
    min_visit=5,
)

# Switches for the optional embedding terms.
feature_dict = dict(age=False, seg=False, posi=True)

# Optimizer settings.
optim_config = dict(lr=3e-5, warmup_proportion=0.1, weight_decay=0.01)

# Transformer architecture settings.
# The segment / age vocabulary sizes are intentionally not set here
# ('seg_vocab_size', 'age_vocab_size').
model_config = dict(
    vocab_size=1042,  # number of disease + symbols for word embedding
    hidden_size=300,  # word / seg embedding hidden size; the embeddings we get are 300 in length
    max_position_embedding=global_params['max_len_seq'],  # maximum number of tokens
    hidden_dropout_prob=0.2,  # dropout rate
    num_hidden_layers=6,  # number of multi-head attention layers required
    num_attention_heads=12,  # number of attention heads
    attention_probs_dropout_prob=0.22,  # multi-head attention dropout rate
    intermediate_size=512,  # size of the "intermediate" layer in the transformer encoder
    hidden_act='gelu',  # activation in the encoder and pooler: 'gelu', 'relu', 'swish' are supported
    initializer_range=0.02,  # parameter weight initializer range
)

class BertConfig(Bert.modeling.BertConfig):
    """Adapter that builds the library ``BertConfig`` from a plain settings dict.

    Args:
        config: mapping with the model settings. ``'hidden_size'`` is read by
            plain indexing and must be present; every other key is read with
            ``.get`` and is passed on as ``None`` when absent.
    """

    def __init__(self, config):
        base_kwargs = {
            'vocab_size_or_config_json_file': config.get('vocab_size'),
            'hidden_size': config['hidden_size'],
            'num_hidden_layers': config.get('num_hidden_layers'),
            'num_attention_heads': config.get('num_attention_heads'),
            'intermediate_size': config.get('intermediate_size'),
            'hidden_act': config.get('hidden_act'),
            'hidden_dropout_prob': config.get('hidden_dropout_prob'),
            'attention_probs_dropout_prob': config.get('attention_probs_dropout_prob'),
            # the settings dict spells this key without the trailing "s"
            'max_position_embeddings': config.get('max_position_embedding'),
            'initializer_range': config.get('initializer_range'),
        }
        super().__init__(**base_kwargs)

        # Extra vocabulary sizes kept on the config object (None when not supplied).
        for extra_key in ('seg_vocab_size', 'age_vocab_size'):
            setattr(self, extra_key, config.get(extra_key))


In [None]:
from embeddings.get_embeddings import ICD9Embeddings

class BertEmbeddings(nn.Module):
    """Encoder input: ICD-9 code embeddings plus an optional positional term.

    The sum is passed through layer normalisation and dropout.

    Args:
        config: object exposing ``max_position_embeddings``, ``hidden_size`` and
            ``hidden_dropout_prob``.
        feature_dict: mapping whose ``'posi'`` entry switches the positional
            term on or off.
    """

    def __init__(self, config, feature_dict):
        super(BertEmbeddings, self).__init__()
        self.feature_dict = feature_dict

        # TODO maybe load these as part of the Dataset so we don't have to do extra lookups
        self.icd9_embeddings = ICD9Embeddings("./embeddings/ic9_embeddings.txt")

        # from_pretrained is a classmethod that builds its own module, so it is called
        # on the class; constructing nn.Embedding(...) first only allocated and randomly
        # initialised a table that was thrown away immediately.
        # NOTE(review): BertModel runs self.apply(self.init_bert_weights) after creating
        # this module; if that initialiser rewrites nn.Embedding weights, this fixed
        # table is overwritten -- confirm against the library.
        self.posi_embeddings = nn.Embedding.from_pretrained(
            embeddings=self._init_posi_embedding(config.max_position_embeddings, config.hidden_size)
        )

        self.LayerNorm = Bert.modeling.BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, word_ids, posi_ids=None,):
        """Embed ``word_ids`` and, if enabled, add the positional embeddings.

        Args:
            word_ids: ids handed to the ICD-9 embedding lookup.
            posi_ids: position ids with the same shape as ``word_ids``; when
                omitted, every token uses position 0.

        Returns:
            The normalised, dropout-regularised embeddings.
        """
        # TODO how to combine the embeddings of all the words
        embeddings = self.icd9_embeddings(word_ids)

        # TODO how to pass in positional embeddings?
        # The lookup is only done when the feature is enabled.
        if self.feature_dict['posi']:
            if posi_ids is None:
                posi_ids = torch.zeros_like(word_ids)
            embeddings = embeddings + self.posi_embeddings(posi_ids)

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def _init_posi_embedding(self, max_position_embedding, hidden_size):
        """Build the fixed sine / cosine position table.

        Even columns hold ``sin`` and odd columns ``cos`` of
        ``pos / 10000 ** (2 * idx / hidden_size)``, where ``idx`` is the column index.

        Returns:
            A float32 tensor of shape ``(max_position_embedding, hidden_size)``.
        """
        # NOTE(review): the exponent uses the column index itself (2 * idx / hidden_size).
        # The usual sinusoidal encoding uses 2 * i / d with column = 2 * i, i.e. half
        # this exponent -- confirm which one is intended before changing it.
        def even_code(pos, idx):
            return np.sin(pos/(10000**(2*idx/hidden_size)))

        def odd_code(pos, idx):
            return np.cos(pos/(10000**(2*idx/hidden_size)))

        # initialize position embedding table
        lookup_table = np.zeros((max_position_embedding, hidden_size), dtype=np.float32)

        # fill each row: sine on even columns, cosine on odd columns
        for pos in range(max_position_embedding):
            for idx in np.arange(0, hidden_size, step=2):
                lookup_table[pos, idx] = even_code(pos, idx)
            for idx in np.arange(1, hidden_size, step=2):
                lookup_table[pos, idx] = odd_code(pos, idx)

        return torch.tensor(lookup_table)
    

# This should be fairly hands-off; we should only need to adjust the config parameters.
class BertModel(Bert.modeling.BertPreTrainedModel):
    """Embedding layer, transformer encoder and pooler wired together.

    Args:
        config: model configuration handed to the library encoder and pooler.
        feature_dict: feature switches forwarded to ``BertEmbeddings``.
    """

    def __init__(self, config, feature_dict):
        super().__init__(config)
        self.embeddings = BertEmbeddings(config, feature_dict)
        self.encoder = Bert.modeling.BertEncoder(config=config)
        self.pooler = Bert.modeling.BertPooler(config)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, posi_ids=None, attention_mask=None, output_all_encoded_layers=True):
        """Run the encoder stack.

        Returns:
            ``(encoded_layers, pooled_output)``; ``encoded_layers`` is the full
            encoder output when ``output_all_encoded_layers`` is true, otherwise
            only its last element.
        """
        # Defaults: attend to every token, and put every token at position 0.
        attention_mask = torch.ones_like(input_ids) if attention_mask is None else attention_mask
        posi_ids = torch.zeros_like(input_ids) if posi_ids is None else posi_ids

        # Turn the 0/1 mask into an additive one: 0 where attended, -10000 where masked.
        param_dtype = next(self.parameters()).dtype  # fp16 compatibility
        additive_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(dtype=param_dtype)
        additive_mask = (1.0 - additive_mask) * -10000.0

        embedded = self.embeddings(input_ids, posi_ids)
        all_layers = self.encoder(embedded,
                                  additive_mask,
                                  output_all_encoded_layers=output_all_encoded_layers)
        last_layer = all_layers[-1]
        pooled_output = self.pooler(last_layer)

        if output_all_encoded_layers:
            return all_layers, pooled_output
        return last_layer, pooled_output


# This should be fairly hands-off; we should only need to adjust the config parameters.
class BertForMultiLabelPrediction(Bert.modeling.BertPreTrainedModel):
    """``BertModel`` followed by dropout and a linear multi-label classification head.

    Args:
        config: model configuration (provides ``hidden_size`` and ``hidden_dropout_prob``).
        num_labels: width of the classifier output.
        feature_dict: feature switches forwarded to ``BertModel``.
    """

    def __init__(self, config, num_labels, feature_dict):
        super(BertForMultiLabelPrediction, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config, feature_dict)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, attention_mask=None, labels=None):
        """Return ``(loss, logits)`` when ``labels`` is given, otherwise ``logits``.

        ``age_ids`` and ``seg_ids`` are accepted so existing call sites keep
        working, but they are not used: ``BertModel.forward`` has no such inputs.
        """
        # Keyword arguments are required here. Passing age_ids / seg_ids positionally
        # shifted every argument and overflowed BertModel.forward's signature, which
        # raised TypeError on every call.
        _, pooled_output = self.bert(input_ids, posi_ids=posi_ids, attention_mask=attention_mask,
                                     output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if labels is not None:
            # NOTE(review): labels are reshaped to (-1, num_labels), i.e. one target value
            # per class. ICDCodeDataset yields token ids of length max_len instead --
            # confirm the labels are converted to that layout before they reach this loss.
            loss_fct = nn.MultiLabelSoftMarginLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))
            return loss, logits
        else:
            return logits

### SETUP MODEL

In [None]:
# Build the library config from the settings dict, then the classifier; the
# number of output labels is taken from the 'vocab_size' setting.
conf = BertConfig(config=model_config)
model = BertForMultiLabelPrediction(
    config=conf,
    num_labels=model_config['vocab_size'],
    feature_dict=feature_dict,
)