In [1]:
# IMPORTS
import torch
from torch import nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from tokenizers import  Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
import glob
from os.path import exists
import os

import pickle5 as pickle
import wandb

import pandas as pd
import numpy as np

In [2]:
class ContrastiveLoss(nn.Module):
    def __init__(self, batch_size, device, temperature=0.5):
        super().__init__()
        self.batch_size = batch_size
        self.device = device
        self.register_buffer("temperature", torch.tensor(temperature))
        self.register_buffer("negatives_mask", (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool)).float())
            
    def forward(self, emb_i, emb_j):
        """
        emb_i and emb_j are batches of embeddings, where corresponding indices are pairs
        z_i, z_j as per SimCLR paper
        """
        z_i = F.normalize(emb_i, dim=1)
        z_j = F.normalize(emb_j, dim=1)

        representations = torch.cat([z_i, z_j], dim=0)
        
        similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)
        
        sim_ij = torch.diag(similarity_matrix, self.batch_size)
        sim_ji = torch.diag(similarity_matrix, -self.batch_size)
        positives = torch.cat([sim_ij, sim_ji], dim=0)
        
        nominator = torch.exp(positives / self.temperature)
        denominator = self.negatives_mask.to(self.device) * torch.exp(similarity_matrix / self.temperature)
    
        loss_partial = -torch.log(nominator / torch.sum(denominator, dim=1))
        loss = torch.sum(loss_partial) / (2 * self.batch_size)
        return loss

In [3]:
max_length = {'demographics':5, 'lab_tests':400, 'vitals':31, 'medications':255}
# max_length = {'demographics':5, 'lab_tests':400, 'vitals':200, 'medications':255}

class MyDataset(Dataset):

    def __init__(self, df, tokenizer, max_length, pred_window=2, observing_window=3):
        self.df = df
        self.tokenizer = tokenizer
        self.observing_window = observing_window
        self.pred_window = pred_window
        self.max_length = max_length
        self.max_length_diags = 35

        
    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):

        return self.make_matrices(idx)
    
    def tokenize(self, text, max_length): 
        
        max_length = max_length + 2
        tokenizer.enable_truncation(max_length=max_length)

        output = self.tokenizer.encode(text)

        # padding and truncation
        if len(output.ids) < max_length:
            len_missing_token = max_length - len(output.ids)
            padding_vec = [self.tokenizer.token_to_id('PAD') for _ in range(len_missing_token)]
            token_output = [*output.ids, *padding_vec]
        elif len(output.ids) > max_length:
            token_output = output.ids[:max_length]
        else:
            token_output = output.ids
        
        return token_output

    def make_matrices(self, idx):
        
        hadm_id = self.df.hadm_id.values[idx]
        diagnoses_info = self.df.previous_diagnoses.values[idx]
        demo_info = self.df.demographics_in_visit.values[idx][0]
        lab_info = self.df.lab_tests_in_visit.values[idx]
        med_info = self.df.medications_in_visit.values[idx]
        vitals_info = self.df.vitals_in_visit.values[idx]
        
        # aki_status = self.df.aki_status_in_visit.values[idx]
        days = self.df.days.values[idx]
        # print(idx)

        lab_info_list = []
        med_info_list = []
        vitals_info_list = []
        label = None

        for day in range(days[0], days[0] + self.observing_window + self.pred_window):
            # print('day', day)
            if day not in days:
                vitals_info_list.append(self.tokenize('', self.max_length['vitals']))
                lab_info_list.append(self.tokenize('', self.max_length['lab_tests']))
                med_info_list.append(self.tokenize('', self.max_length['medications']))

            else:
                i = days.index(day)
                
                # vitals
                if (str(vitals_info[i]) == 'nan') or (vitals_info[i] == np.nan):
                    vitals_info_list.append(self.tokenize('PAD', self.max_length['vitals']))
                else:
                    vitals_info_list.append(self.tokenize(vitals_info[i], self.max_length['vitals']))

                # lab results
                if (str(lab_info[i]) == 'nan') or (lab_info[i] == np.nan):
                    lab_info_list.append(self.tokenize('PAD', self.max_length['lab_tests']))
                else:
                    lab_info_list.append(self.tokenize(lab_info[i], self.max_length['lab_tests']))
                
                # medications
                if (str(med_info[i]) == 'nan') or (med_info[i] == np.nan):
                    med_info_list.append(self.tokenize('PAD', self.max_length['medications']))
                else:
                    med_info_list.append(self.tokenize(med_info[i], self.max_length['medications']))

        # diagnoses
        if (str(diagnoses_info) == 'nan') or (diagnoses_info == np.nan):
            diagnoses_info = self.tokenize('PAD', self.max_length_diags)
        else:
            diagnoses_info = self.tokenize(diagnoses_info, self.max_length_diags)

        # demographics
        if (str(demo_info) == 'nan') or (demo_info == np.nan):
            demo_info = self.tokenize('PAD', self.max_length_diags)
        else:
            demo_info = self.tokenize(demo_info, self.max_length['demographics'])

        #make tensors
        tensor_demo = torch.tensor(demo_info, dtype=torch.int64)
        tensor_diags = torch.tensor(diagnoses_info, dtype=torch.int64)
        tensor_vitals = torch.tensor(vitals_info_list, dtype=torch.int64)
        tensor_labs = torch.tensor(lab_info_list, dtype=torch.int64)
        tensor_meds = torch.tensor(med_info_list, dtype=torch.int64)
        # tensor_labels = torch.tensor(label, dtype=torch.float64)
    
        return tensor_demo, tensor_diags, tensor_vitals, tensor_labs, tensor_meds, hadm_id


In [4]:
class EHR_PRETRAINING(nn.Module):
    def __init__(self, max_length, vocab_size, device, pred_window=2, observing_window=3,  H=128, embedding_size=200, drop=0.6):
        super(EHR_PRETRAINING, self).__init__()

        self.observing_window = observing_window
        self.pred_window = pred_window
        self.H = H
        self.max_length = max_length
        self.max_length_diags = 30
        self.embedding_size = embedding_size
        self.vocab_size = vocab_size
        self.device = device
        self.drop = drop

        # self.embedding = pretrained_model
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)

        self.lstm_day = nn.LSTM(input_size=embedding_size,
                            hidden_size=self.H,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=True)

        self.fc_day = nn.Linear(self.max_length * 2 * self.H, 2048)

        self.fc_adm = nn.Linear(2048*self.observing_window +  self.max_length_diags * 2 * self.H, 2048)

        self.lstm_adm = nn.LSTM(input_size=2048,
                            hidden_size=self.H,
                            num_layers=2,
                            batch_first=True,
                            bidirectional=False)

        self.drop = nn.Dropout(p=drop)
        self.inner_drop = nn.Dropout(p=0.5)

        # self.fc_2 = nn.Linear(self.H*2, 2)
        self.projection = nn.Sequential(
            nn.ReLU(),
            nn.Linear(in_features=self.H, out_features=256)
        )

    def forward(self, tensor_day, tensor_diagnoses):

        batch_size = tensor_day.size()[0]

        full_output = torch.tensor([]).to(device=self.device)
        out_emb_diags = self.embedding(tensor_diagnoses.squeeze(1))
        out_lstm_diags, _ = self.lstm_day(out_emb_diags)
        full_output = out_lstm_diags.reshape(batch_size, self.max_length_diags * 2 * self.H)
        

        for d in range(self.observing_window):
            # embedding layer applied to all tensors [16,400,200]
            out_emb = self.embedding(tensor_day[:, d, :].squeeze(1))
            # print('out_emb', out_emb.size())

            # lstm layer applied to embedded tensors
            output_lstm_day= self.inner_drop(self.fc_day(\
                                    self.lstm_day(out_emb)[0]\
                                        .reshape(batch_size, self.max_length * 2 * self.H)))

            # print('output_lstm_day', output_lstm_day.size())                   
            # concatenate for all * days
            full_output = torch.cat([full_output, output_lstm_day], dim=1) # [16, 768]

        # print('full_output size: ', full_output.size(), '\n')
        output = self.fc_adm(full_output)
        # print('output after fc_adm size: ', output.size(), '\n')
        output_vector, _ = self.lstm_adm(output)

        # the fisrt transformation
        output_vector_X = self.drop(output_vector)
        projection_X = self.projection(output_vector_X)
        # the second transformation
        output_vector_Y = self.drop(output_vector)
        projection_Y = self.projection(output_vector_Y)

        return output_vector_X, projection_X, output_vector_Y, projection_Y

In [5]:
#paths
CURR_PATH = os.getcwd()
PKL_PATH = CURR_PATH+'/pickles/'
DF_PATH = CURR_PATH +'/dataframes/'
TXT_DIR_TRAIN = CURR_PATH + '/txt_files/train'
destination_folder = '/l/users/svetlana.maslenkova/models' + '/pretraining/fc1_fixed/'

In [6]:
# tokenizer = Tokenizer.from_file('/home/svetlana.maslenkova/LSTM/aki_prediction/tokenizer.json')
# print(f' Vocab size is {tokenizer.get_vocab_size()}')

with open('/home/svetlana.maslenkova/LSTM/dataframes/pid_train_df_pretraining.pkl', 'rb') as f:
    pid_train_df = pickle.load(f)


In [818]:
pid_train_df.head(1)

Unnamed: 0,subject_id,hadm_id,demographics_in_visit,lab_tests_in_visit,medications_in_visit,vitals_in_visit,days_in_visit,aki_status_in_visit,previous_diagnoses,days
9,16679562,20001395,"[hispanic latino m 73, hispanic latino m 73, h...",[hematology blood hematocrit 51.2 %; hematol...,[influenza vaccine quadrivalent 0.5 ml ; bis...,[temp heartrate 80.0 resprate 16.0 o2sa...,[hispanic latino m 73$temp heartrate 80.0 ...,"[0, 0, 0, 1, 1, 0, 0, 1, 0]",,"[0, 1, 2, 3, 4, 5, 6, 7, 8]"


In [7]:
i=0
for row in pid_train_df.medications_in_visit:
    for day in row.values:
        print(day)    
        i+= 1
        if i>0:break

AttributeError: 'list' object has no attribute 'values'

In [None]:
i = 0
LIST_HADMS = []
list_diags = []
# for _, row in pid_val_df.iterrows():
#     # try:
#     #     list_diags.append(row.previous_diagnoses[0].replace('PAD ', '').replace('PAD', ''))
#     #     LIST_HADMS.append(row.hadm_id)
#     # except:
#     #     print(row.previous_diagnoses[0])
#     #     print(np.isnan(row.previous_diagnoses[0]))
#     #     break
#     try:
#         if isinstance(row.previous_diagnoses[0], float):
#             list_diags.append('')      
#         else:
#             list_diags.append(row.previous_diagnoses[0].replace('PAD ', '').replace('PAD', ''))  
#         LIST_HADMS.append(row.hadm_id) 
#     except:
#         print(row.previous_diagnoses[0])
#         print(type(row.previous_diagnoses[0]))
#         raise
#         break

#     i+=1
#     # if i>0:break

pid_train_df['previous_diags'] = [val.replace('PAD ', '').replace('PAD', '') if not isinstance(val, float) else '' for val in pid_train_df['previous_diags']]

In [695]:
from tokenizers.pre_tokenizers import Digits, Punctuation, Whitespace
from tokenizers.normalizers import Lowercase, Replace
from tokenizers import pre_tokenizers, normalizers
from tokenizers.processors import BertProcessing

print('Training tokenizer...')
os.environ["TOKENIZERS_PARALLELISM"] = "true"
tokenizer = Tokenizer(BPE(unk_token="UNK"))
tokenizer.normalizer = normalizers.Sequence([Lowercase()])
# tokenizer.pre_tokenizer = pre_tokenizers.Sequence([Whitespace(), Digits(individual_digits=False), Punctuation( behavior = 'removed')])
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([Whitespace(), Punctuation(behavior = 'isolated')])

trainer = BpeTrainer(special_tokens=["<s>", "</s>", "PAD", "UNK", "$"], min_frequency=10)

files = glob.glob('/home/svetlana.maslenkova/LSTM/aki_prediction/txt_files/train'+'/*')
tokenizer.train(files, trainer)
tokenizer.post_processor = BertProcessing(
        ("</s>", tokenizer.token_to_id("</s>")),
        ("<s>", tokenizer.token_to_id("<s>")), 
        )
os.environ["TOKENIZERS_PARALLELISM"] = "false"
print(f'Vocab size is {tokenizer.get_vocab_size()}')

Training tokenizer...



Vocab size is 22569


In [696]:
vocab_size = tokenizer.get_vocab_size()
print(vocab_size)
device='cpu'
frac=1
BATCH_SIZE=16
LR=0.00001

22569


In [831]:
train_dataset = MyDataset(pid_train_df.sample(frac=frac), tokenizer=tokenizer, max_length=max_length)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [172]:
model = EHR_PRETRAINING(max_length=400, vocab_size=vocab_size, device=device).to(device)
optimizer = optim.Adam(model.parameters(), lr=LR)

Tokenization

In [771]:
tensor_demo, tensor_diags, tensor_vitals, tensor_labs, tensor_meds, idx = next(iter(train_loader))
print(idx)
# for idx, (tensor_demo, tensor_diags, tensor_vitals, tensor_labs, tensor_meds, idx) in enumerate(train_dataset):
#     print(idx)

tensor([22553884, 25509314, 22279579, 28187167, 24420446, 25905565, 20764576,
        29770896, 23424563, 24709146, 27972173, 28551210, 26626469, 26969230,
        29925932, 25792248])


In [811]:
i = 1
tokenizer.decode(tensor_vitals[i][0].cpu().detach().numpy())

'temp 97 . 7 heartrate 53 . 0 resprate 18 . 0 o2sat 10 . 0 sbp 129 . 0 dbp 99 . 0 rhythm pain'

In [832]:
tokenizer.enable_truncation(200)
out = tokenizer.encode('influenza vaccine quadrivalent  0.5  ml ; bisacodyl  10  mg ; senna  8.6  mg ; heparin  5000  unit ; ipratropium-albuterol neb  1  neb ; albuterol 0.083% neb soln  1  neb ; folic acid  1  mg ; multivitamins  1  tab ; nicotine patch  7  mg day ; thiamine  100  mg ; vitamin d  1000  unit ; potassium chloride replacement (critical care and oncology)   40  meq ; potassium chloride replacement (critical care and oncology)   60  meq ; potassium chloride replacement (critical care and oncology)   80  meq ; bag  1  bag ; magnesium sulfate  4  gm ; 0.9% sodium chloride  100  ml ; calcium gluconate  2  g ; 0.9% sodium chloride  250  ml ; calcium gluconate  4  g ; amlodipine  5  mg ; aspirin  81  mg ; chlorthalidone  25  mg ; nortriptyline  10  mg ; omeprazole  20  mg ; tamsulosin  0.4  mg ; azithromycin  500  mg ; prednisone  60  mg ; phytonadione  5  mg ; sodium chloride 0.9%  flush  3-10  ml ; ')
print(tokenizer.decode(out.ids))
print(len(out.ids))

influenza vaccine quadrivalent 0 . 5 ml ; bisacodyl 10 mg ; senna 8 . 6 mg ; heparin 5000 unit ; ipratropium - albuterol neb 1 neb ; albuterol 0 . 083 % neb soln 1 neb ; folic acid 1 mg ; multivitamins 1 tab ; nicotine patch 7 mg day ; thiamine 100 mg ; vitamin d 1000 unit ; potassium chloride replacement ( critical care and oncology ) 40 meq ; potassium chloride replacement ( critical care and oncology ) 60 meq ; potassium chloride replacement ( critical care and oncology ) 80 meq ; bag 1 bag ; magnesium sulfate 4 gm ; 0 . 9 % sodium chloride 100 ml ; calcium gluconate 2 g ; 0 . 9 % sodium chloride 250 ml ; calcium gluconate 4 g ; amlodipine 5 mg ; aspirin 81 mg ; chlorthalidone 25 mg ; nortriptyline 10 mg ; omeprazole 20 mg ; tamsulosin 0 . 4 mg ; azithromycin 500 mg ; prednisone 60 mg ; phytonadione 5 mg ; sodium chloride 0 . 9 % flush 3 - 10 ml ;
188


In [812]:
for p in tensor_vitals[i]:
    print(p.cpu().detach().numpy(), "===>", tokenizer.decode(p.cpu().detach().numpy()))

[  0 478 565  17  26 487 960  17  19 489 280  17  19 488 117  17  19 485
 988  17  19 483 326  17  19 475 479   1   2   2   2   2   2   2   2   2
   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2
   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2
   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2
   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2
   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2
   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2
   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2
   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2
   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2   2
   2   2   2   2] ===> temp 97 . 7 heartrate 53 . 0 resprate 18 . 0 o2sat 10 . 0 sbp 129 . 0 dbp 99 . 0 rhythm pain
[0 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2

In [749]:
day = pid_train_df[pid_train_df.hadm_id==idx[i].item()].previous_diagnoses.values[0]
day

''

In [733]:
output = tokenizer.encode(day)

In [734]:
tokenizer.decode(output.ids)

'de119 db20 de785 di10 dr338'

In [722]:
print(len(output.ids))
print(output.ids)

7
[0, 4277, 849, 3387, 828, 3507, 1]


In [723]:
for id_ in output.ids:
    print(id_, tokenizer.id_to_token(id_))

0 <s>
4277 d42843
849 d2449
3387 d5939
828 d4280
3507 d4239
1 </s>


In [752]:
print(len(tensor_vitals[i].cpu().detach().numpy()))

5


In [753]:
for id_ in tensor_vitals[i].cpu().detach().numpy():
    print(id_, tokenizer.id_to_token(id_))

TypeError: only integer scalar arrays can be converted to a scalar index

In [481]:
tokenizer.id_to_token(24)

'7'