Installation of packages

In [2]:
!pip install PyAthena

Collecting PyAthena
  Using cached PyAthena-2.2.0-py3-none-any.whl (37 kB)
Collecting tenacity>=4.1.0
  Using cached tenacity-7.0.0-py2.py3-none-any.whl (23 kB)
Installing collected packages: tenacity, PyAthena
Successfully installed PyAthena-2.2.0 tenacity-7.0.0


In [3]:
!pip install --upgrade gensim

Collecting gensim
  Using cached gensim-4.0.1-cp36-cp36m-manylinux1_x86_64.whl (23.9 MB)
Installing collected packages: gensim
Successfully installed gensim-4.0.1


Import Libraries

In [4]:
from __future__ import print_function
import numpy as np
import string
import nltk
from nltk import word_tokenize
import datetime
import pandas as pd
import boto3
from botocore.client import ClientError
# below is used to print out pretty pandas dataframes
from IPython.display import display, HTML

from pyathena import connect
from pyathena.pandas.util import as_pandas
import time
nltk.download('punkt')

[nltk_data] Downloading package punkt to /home/ec2-user/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

The following code is to enable us to import data from athena services, be able to query parquet files through SQL queries

In [6]:
s3 = boto3.resource('s3')
client = boto3.client("sts")
account_id = client.get_caller_identity()["Account"]
my_session = boto3.session.Session()
region = my_session.region_name
athena_query_results_bucket = 'aws-athena-query-results-'+account_id+'-'+region

try:
    s3.meta.client.head_bucket(Bucket=athena_query_results_bucket)
except ClientError:
    bucket = s3.create_bucket(Bucket=athena_query_results_bucket)
    print('Creating bucket '+athena_query_results_bucket)
cursor = connect(s3_staging_dir='s3://'+athena_query_results_bucket+'/athena/temp').cursor()

We know that not all patients have the same number of visit dates, therefore, we need to find what is the maximum number of visit dates for any given patient

The following function is to pre-process the notes, get rid of numbers, punctuation and tokenize the words.

In [8]:
def preprocess_dataset(df):    
    ''' Preprocess the text data. And return a list of clinical notes. '''
    clinical_notes = []
    
    df.notes_agg = df.notes_agg.fillna(' ')  # remove NA
    df.notes_agg = df.notes_agg.str.replace('\n',' ')  # remove newline
    df.notes_agg = df.notes_agg.str.replace('\r',' ')
    """
    TODO: 1. remove punc;
          2. remove numbers.
          
    HINT: consider using `string.punctuation`, `str.maketrans`, and `str.translate`.
    """
    df.notes_agg = df.notes_agg.str.translate(str.maketrans('', '', string.punctuation)) # remove punctuation
    df.notes_agg = df.notes_agg.str.translate(str.maketrans('', '', '1234567890')) # remove numbers
    
    df.notes_agg = df.notes_agg.str.lower()  # convert to lower case
    
    # tokenize
    for note in df.notes_agg.values:
        note_tokenized = word_tokenize(note)
        clinical_notes.append(note_tokenized)

    return clinical_notes

To process our notes, we first need to load our pre-trained word2vec model:

In [9]:
from gensim.models import Word2Vec
from gensim.models import KeyedVectors


#pubMedWord2VecModel = KeyedVectors.load_word2vec_format('PubMed-w2v.bin', binary=True)
word2vec_model = KeyedVectors.load('note_vectors.kv')



Now we can start processing the notes for each patient:

The following is a function to process the datasets (train or tests and create the neccesary objects that later we will use on the notes network:

In [116]:
def process_dataset(cohort_patients_df, patients_date_notes_df, patients_max_visits):
    #patients_notes_fetures = torch.zeros(len(cohort_patients_df), patients_max_visits,200, dtype = torch.float)
    #Patients Subject_ID
    patient_subject_id = []
    
    #sparse vector values
    patients_notes_fetures = []
    #sparse vector indexes
    index_0 = []
    index_1 = []
    #Patients last note_date
    zero_notes = np.array(patients_max_visits.number_dates.values) - 1 == -1
    patients_notes_last_date = np.array(patients_max_visits.number_dates.values) - 1
    patients_notes_last_date[zero_notes] = 0
    patients_notes_last_date = np.expand_dims(patients_notes_last_date, axis=1)
    missing_words = 0
    patient_mortality = np.array(cohort_patients_df.mortality_flag.values)
    

    for patient_idx, patient in enumerate(cohort_patients_df.subject_id.values):
        if patient_idx % 100 == 0 and patient_idx >0:
            print('Processing patient ' + str(patient_idx))
        patient_subject_id.append(patient)
        patient_date_notes_list = preprocess_dataset(patients_date_notes_df[patients_date_notes_df.subject_id == patient].copy())
        #print(patient_date_notes_list)
        for date_idx, note in enumerate(patient_date_notes_list):
            patient_date_note = np.zeros(200, dtype =  np.float64)
            #print(note)
            for note_word in note:
                #if note_word not in vocab:
                    #vocab.append(note_word)
                try:
                    patient_date_note = patient_date_note + word2vec_model.get_vector(note_word)
                except:
                    missing_words = missing_words + 1
                    #if note_word not in missing_words:
                        #missing_words =  missing_words+1
            #if patient_idx == 0:
                #print(patient_date_note)
            index_0.append(patient_idx)
            index_1.append(date_idx)
            patients_notes_fetures.append(patient_date_note)
            

    return patient_subject_id, patients_notes_fetures, index_0, index_1, patients_notes_last_date, patient_mortality


The following two functions are to help us save and re-load all the processed notes into/from file objects

In [66]:
def save_notes_dataset_objects(patient_subject_id, patients_notes_fetures, index_0, index_1, patients_notes_last_date, patient_mortality, prefix = ''):
    
    np.save(prefix + 'subject_id.npy', patient_subject_id, allow_pickle=True)
    np.save(prefix + 'patients_notes_fetures.npy', patients_notes_fetures, allow_pickle=True)
    np.save(prefix + 'index_0.npy', index_0, allow_pickle=True)
    np.save(prefix + 'index_1.npy', index_1, allow_pickle=True)
    np.save(prefix + 'patients_notes_last_date.npy', patients_notes_last_date, allow_pickle=True)
    np.save(prefix + 'patient_mortality.npy', patient_mortality, allow_pickle=True)



In [32]:
def load_notes_dataset_object(prefix = ''):
    
    patient_subject_id = np.load(prefix + 'subject_id.npy', allow_pickle=True).tolist()
    patients_notes_fetures = np.load(prefix + 'patients_notes_fetures.npy', allow_pickle=True)
    index_0 = np.load(prefix + 'index_0.npy', allow_pickle=True)
    index_1 = np.load(prefix + 'index_1.npy', allow_pickle=True)
    patients_notes_last_date = np.load(prefix + 'patients_notes_last_date.npy', allow_pickle=True)
    patient_mortality = np.load(prefix + 'patient_mortality.npy', allow_pickle=True)
    return patient_subject_id, patients_notes_fetures, index_0, index_1, patients_notes_last_date, patient_mortality
    

We have already pre-processed our cohort sets using SQL queries, the following is the query to fetch the list of all test cohort patients and then process the notes for each patient, we first do it for the train set, then for the test set:

In [45]:
query = 'select cohort.new_subject_id as subject_id, cohort.mortality_flag from default.train_cohort2 cohort order by cohort.new_subject_id'
cursor.execute(query)
cohort_patients_df = as_pandas(cursor)

In [47]:
cohort_patients_df

Unnamed: 0,subject_id,mortality_flag
0,0,0
1,1,0
2,2,0
3,3,0
4,4,0
...,...,...
13785,13785,1
13786,13786,1
13787,13787,1
13788,13788,1


In [48]:
t0 = time.time()
query = ('select cohort.subject_id, nts.chart_date, nts.notes_agg from default.diabetic_patients_notes_agg nts '
         'join (select cohort.new_subject_id as subject_id, cohort.subject_id as old_subject_id from default.train_cohort2 cohort order by cohort.new_subject_id) cohort '
         'on nts.subject_id = cohort.old_subject_id '
         'order by cohort.subject_id asc, nts.chart_date asc;')
cursor.execute(query)
patients_date_notes_df = as_pandas(cursor)
print('Total number of train patients_dates: ' +str(len(patients_date_notes_df)))
t1 = time.time()
processing_time = t1-t0
print('Training set notes Agg query time: ' + str(processing_time))

Total number of train patients_dates: 237636
Training set notes Agg query time: 380.5437870025635


In [86]:
query = ('select cohort.new_subject_id as subject_id, count(nts.chart_date) as number_dates '
         'from default.diabetic_patients_notes_agg nts join default.train_cohort2 cohort '
         'on nts.subject_id = cohort.subject_id group by cohort.new_subject_id order by cohort.new_subject_id;')
cursor.execute(query)
patients_max_visits = as_pandas(cursor)


[ 5 16 10 ... 10 50  6]
13790


In [84]:
train_patients_max_visits

Unnamed: 0,subject_id,number_dates
0,0,6
1,1,17
2,2,11
3,3,6
4,4,80
...,...,...
13785,13785,9
13786,13786,4
13787,13787,11
13788,13788,51


In [None]:
t0 = time.time()
train_subject_id, train_patients_notes_fetures, train_index_0, train_index_1, train_patients_notes_last_date, train_patient_mortality = process_dataset(cohort_patients_df, patients_date_notes_df, patients_max_visits)
t1 = time.time()
processing_time = t1-t0
print('Training set processing time: ' + str(processing_time))

Training set processing time: 3004.7571506500244


In [107]:
zero_notes = np.array(train_patients_max_visits.number_dates.values) - 1 == -1
train_patients_notes_last_date = np.array(train_patients_max_visits.number_dates.values) - 1
train_patients_notes_last_date[zero_notes] = 0
len(train_patients_notes_last_date)


13790

In [106]:
train_patients_notes_last_date = np.expand_dims(train_patients_notes_last_date, axis=1)
save_notes_dataset_objects(train_subject_id, train_patients_notes_fetures, train_index_0, train_index_1, train_patients_notes_last_date, train_patient_mortality, prefix = 'train_')


We now process the test set:

In [111]:
query = 'select cohort.subject_id as subject_id, cohort.mortality_flag from default.test_cohort cohort order by cohort.subject_id'
cursor.execute(query)
cohort_patients_df = as_pandas(cursor)

In [112]:
cohort_patients_df

Unnamed: 0,subject_id,mortality_flag
0,21,1
1,28,0
2,59,0
3,130,0
4,188,1
...,...,...
1960,99573,0
1961,99714,0
1962,99776,0
1963,99893,0


In [113]:
t0 = time.time()
query = ('select cohort.subject_id, nts.chart_date, nts.notes_agg from default.diabetic_patients_notes_agg nts '
         'join (select cohort.subject_id from default.test_cohort cohort order by cohort.subject_id) cohort '
         'on nts.subject_id = cohort.subject_id '
         'order by cohort.subject_id asc, nts.chart_date asc;')
cursor.execute(query)
patients_date_notes_df = as_pandas(cursor)
print('Total number of train patients_dates: ' +str(len(patients_date_notes_df)))
t1 = time.time()
processing_time = t1-t0
print('Test set notes Agg query time: ' + str(processing_time))

Total number of train patients_dates: 26839
Test set notes Agg query time: 50.07832431793213


In [114]:
query = ('select cohort.subject_id, count(nts.chart_date) as number_dates '
         'from default.diabetic_patients_notes_agg nts join default.test_cohort cohort '
         'on nts.subject_id = cohort.subject_id group by cohort.subject_id order by cohort.subject_id;')
cursor.execute(query)
patients_max_visits = as_pandas(cursor)

In [115]:
t0 = time.time()
test_subject_id, test_patients_notes_fetures, test_index_0, test_index_1, test_patients_notes_last_date, test_patient_mortality = process_dataset(cohort_patients_df, patients_date_notes_df,patients_max_visits)
t1 = time.time()
processing_time = t1-t0
print('Training set processing time: ' + str(processing_time))

Processing patient 0
Processing patient 1000
Training set processing time: 336.23525953292847


In [110]:
zero_notes = np.array(test_patients_max_visits.number_dates.values) - 1 == -1
test_patients_notes_last_date = np.array(test_patients_max_visits.number_dates.values) - 1
test_patients_notes_last_date[zero_notes] = 0
len(test_patients_notes_last_date)

NameError: name 'test_patients_max_visits' is not defined

In [117]:
save_notes_dataset_objects(test_subject_id, test_patients_notes_fetures, test_index_0, test_index_1, test_patients_notes_last_date, test_patient_mortality, prefix = 'test_')

In [None]:
After we have done all the processing, we do some tests:

In [59]:
train_subject_id_cp, train_patients_notes_fetures_cp, train_index_0_cp, train_index_1_cp, train_patients_notes_last_date, train_patient_mortality = load_notes_dataset_object(prefix = 'train_') 

In [91]:
#print(train_subject_id)
#print(train_subject_id_cp)
print(train_patients_notes_last_date)

[[ 5]
 [16]
 [10]
 ...
 [10]
 [50]
 [ 6]]


In [61]:
print(train_index_1_cp[0:40])

[ 0  1  2  3  4  5  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16  0
  1  2  3  4  5  6  7  8  9 10  0  1  2  3  4  5]


In [62]:
#train_patients_notes_fetures[0]
#train_patients_notes_fetures_cp[0]
#train_index_1_cp
index = [train_index_0_cp, train_index_1_cp]

In [63]:
import torch
s = torch.sparse_coo_tensor(index, train_patients_notes_fetures_cp, (len(train_subject_id_cp),patients_max_visits,200))

In [92]:
s[17].to_dense()[0]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.], dtype=torch.float64)

In [40]:
masks = torch.from_numpy(train_patients_notes_last_date).long()
last_visit = masks.expand(-1,505).unsqueeze(1)
masks.dtype

torch.int64

Original dataset

In [118]:
query = 'select cohort.subject_id as subject_id, cohort.mortality_flag from default.diabetic_patients_cohort cohort order by cohort.subject_id'
cursor.execute(query)
cohort_patients_df = as_pandas(cursor)
cohort_patients_df

Unnamed: 0,subject_id,mortality_flag
0,13,0
1,18,0
2,20,0
3,21,1
4,24,0
...,...,...
9817,99955,1
9818,99957,0
9819,99991,0
9820,99995,0


In [119]:
t0 = time.time()
query = ('select nts.subject_id, nts.chart_date, nts.notes_agg from default.diabetic_patients_notes_agg nts '
         'order by nts.subject_id asc, nts.chart_date asc;')
cursor.execute(query)
patients_date_notes_df = as_pandas(cursor)
print('Total number of train patients_dates: ' +str(len(patients_date_notes_df)))
t1 = time.time()
processing_time = t1-t0
print('Original set notes Agg query time: ' + str(processing_time))

Total number of train patients_dates: 133419
Training set notes Agg query time: 202.77840876579285


In [121]:
query = ('select nts.subject_id as subject_id, count(nts.chart_date) as number_dates '
         'from default.diabetic_patients_notes_agg nts '
         'group by nts.subject_id order by nts.subject_id;')
cursor.execute(query)
patients_max_visits = as_pandas(cursor)

In [122]:
t0 = time.time()
orig_subject_id, orig_patients_notes_fetures, orig_index_0, orig_index_1, orig_patients_notes_last_date, orig_patient_mortality = process_dataset(cohort_patients_df, patients_date_notes_df,patients_max_visits)
t1 = time.time()
processing_time = t1-t0
print('Original set processing time: ' + str(processing_time))

Processing patient 100
Processing patient 200
Processing patient 300
Processing patient 400
Processing patient 500
Processing patient 600
Processing patient 700
Processing patient 800
Processing patient 900
Processing patient 1000
Processing patient 1100
Processing patient 1200
Processing patient 1300
Processing patient 1400
Processing patient 1500
Processing patient 1600
Processing patient 1700
Processing patient 1800
Processing patient 1900
Processing patient 2000
Processing patient 2100
Processing patient 2200
Processing patient 2300
Processing patient 2400
Processing patient 2500
Processing patient 2600
Processing patient 2700
Processing patient 2800
Processing patient 2900
Processing patient 3000
Processing patient 3100
Processing patient 3200
Processing patient 3300
Processing patient 3400
Processing patient 3500
Processing patient 3600
Processing patient 3700
Processing patient 3800
Processing patient 3900
Processing patient 4000
Processing patient 4100
Processing patient 4200
P

In [127]:
save_notes_dataset_objects(orig_subject_id, orig_patients_notes_fetures, orig_index_0, orig_index_1, orig_patients_notes_last_date, orig_patient_mortality, prefix = 'orig_')

In [125]:
orig_patients_notes_last_date

array([[3],
       [0],
       [2],
       ...,
       [5],
       [1],
       [1]])

In [128]:
orig_patients_notes_fetures

[array([ 2.82588533e+02,  2.51790093e+02, -1.30849318e+02, -6.82312020e+01,
         2.95573529e+01, -2.85534904e+02, -5.27128601e+01, -1.56830389e+02,
         2.66608213e+01, -1.45361312e+02,  3.95318070e+01,  1.60528343e+01,
         1.93247389e+02,  8.69812649e+01, -1.40968018e+01, -1.26073747e+01,
        -7.05547264e+01,  2.75077861e+02,  8.69922538e+01,  1.08964065e+01,
         1.39814938e+02,  1.33056688e+02, -1.40511765e+02, -2.21243607e+00,
        -6.03700907e+01,  4.42106261e+01, -1.38543426e+02,  3.22604817e+01,
         2.09642685e+02,  5.02811901e+01,  2.35179993e+01, -2.13560599e+02,
         1.06956149e+02,  7.83611964e+01, -1.24569827e+02,  3.32446340e+01,
         8.53153487e+01,  8.15852437e+01, -8.54707402e+01, -7.08519448e+01,
        -7.94765557e+01, -1.68935036e+02,  1.49808986e+02, -1.04235997e+02,
        -2.01530040e+02, -6.79787709e+01, -1.49193000e+02,  8.56204633e+01,
         1.45902096e+02, -1.65252902e+02, -8.86541227e+01, -8.17867555e+01,
         1.8