# Installation of packages

In [None]:
!pip install PyAthena

In [None]:
!pip install --upgrade gensim

# Import Libraries

In [None]:
from __future__ import print_function
import numpy as np
import string
import nltk
from nltk import word_tokenize
import datetime
import pandas as pd
import boto3
from botocore.client import ClientError
# below is used to print out pretty pandas dataframes
from IPython.display import display, HTML

from pyathena import connect
from pyathena.pandas.util import as_pandas
import time
nltk.download('punkt')

The following code connects to Amazon Athena, so that we can import data from it and query the Parquet tables with SQL. Athena query results are staged in an S3 bucket, which is created if it does not exist yet.

In [None]:
# Connect to Amazon Athena. Athena stages query results in a per-account, per-region
# S3 bucket, which is created here on first use.
s3 = boto3.resource('s3')
client = boto3.client("sts")
account_id = client.get_caller_identity()["Account"]
my_session = boto3.session.Session()
region = my_session.region_name
athena_query_results_bucket = 'aws-athena-query-results-'+account_id+'-'+region

try:
    s3.meta.client.head_bucket(Bucket=athena_query_results_bucket)
except ClientError as err:
    # head_bucket reports a missing bucket as 404. Any other error (e.g. 403 because the
    # name belongs to another account) cannot be fixed by creating the bucket, so re-raise.
    if err.response.get('Error', {}).get('Code') not in ('404', 'NoSuchBucket'):
        raise
    print('Creating bucket '+athena_query_results_bucket)
    # S3 rejects CreateBucket outside us-east-1 unless a LocationConstraint is given.
    if region == 'us-east-1':
        s3.create_bucket(Bucket=athena_query_results_bucket)
    else:
        s3.create_bucket(Bucket=athena_query_results_bucket,
                         CreateBucketConfiguration={'LocationConstraint': region})
cursor = connect(s3_staging_dir='s3://'+athena_query_results_bucket+'/athena/temp').cursor()

# Extract and process notes for Cohort Patients

## Select Notes from Cohort Patients within observation window:

First, we create a table containing the clinical notes of our cohort patients only. The table is restricted to notes written between admission and 48 hours before discharge; notes from the last 48 hours before discharge are not considered:

In [None]:
# Build the cohort-notes table: one row per note written between admission and 48 h before
# discharge. Notes that only carry a chartdate (no charttime) are compared at day granularity.
# The LEFT OUTER JOIN keeps cohort patients with no qualifying note (note columns are NULL).
# NOTE: SQL string literals must reach Athena quoted ('PARQUET', 'hour'). Writing ''PARQUET''
# inside a single-quoted Python string is adjacent-literal concatenation and drops the quotes,
# so those fragments use double-quoted Python strings.
query = ("CREATE TABLE default.diabetic_patients_notes WITH (format='PARQUET') AS "
         'select cohort.subject_id, nts.hadm_id, (CASE WHEN nts.charttime IS NOT NULL THEN nts.charttime ELSE nts.chartdate END) as charttime, cohort.mortality_flag, nts.text '
         'FROM  default.diabetic_patients_cohort cohort '
         'left outer join mimiciii.noteevents as nts '
         'on nts.subject_id = cohort.subject_id '
         'AND (CASE WHEN nts.charttime IS NOT NULL THEN nts.charttime ELSE nts.chartdate END) >= (CASE WHEN nts.charttime IS NOT NULL THEN cohort.admit_time ELSE date(cohort.admit_time) END) '
         "AND (CASE WHEN nts.charttime IS NOT NULL THEN nts.charttime ELSE nts.chartdate END) <= (CASE WHEN nts.charttime IS NOT NULL THEN date_add('hour',-48,cohort.discharge_time) "
         "ELSE date(date_add('hour',-48, cohort.discharge_time)) END) "
         'ORDER BY cohort.subject_id ASC, (CASE WHEN nts.charttime IS NOT NULL THEN nts.charttime ELSE nts.chartdate END) ASC')
cursor.execute(query)

## Aggregate Notes by distinct date

Clinical notes are spread across different dates, and a single date can have more than one note. We therefore aggregate the notes per date by concatenating all the notes of a single date into a single string column:

In [None]:
# Aggregate the notes per patient and calendar date: each note gets a ' || ' suffix and the
# day's notes are joined into one string (the '|' characters are later stripped as punctuation
# by preprocess_dataset).
# NOTE: the original single-quoted Python string left a bare `||` outside any literal
# (a SyntaxError) and dropped the quotes around PARQUET. Double-quoted Python strings carry
# the SQL quotes through intact.
query = ("CREATE TABLE default.diabetic_patients_notes_agg WITH (format='PARQUET') AS "
         "select nts.subject_id,  date(charttime) as chart_date, array_join(array_agg( nts.text || ' || '), '') AS notes_agg "
         'from default.diabetic_patients_notes nts '
         'group by nts.subject_id, date(charttime) '
         'order by nts.subject_id asc, date(charttime) asc ')
cursor.execute(query)

## Generate Notes Embedding for each distinct Date

The following function is to pre-process the notes, get rid of numbers, punctuation and tokenize the words.

In [None]:
def preprocess_dataset(df):
    '''Clean and tokenize the aggregated clinical notes of ``df``.

    Each ``notes_agg`` value is cleaned as follows:
    - missing values become a single space;
    - newlines and carriage returns become spaces;
    - ASCII punctuation (``string.punctuation``) and the digits 0-9 are removed;
    - the text is lower-cased.
    The result is then tokenized with NLTK's ``word_tokenize``.

    Args:
        df: DataFrame with a ``notes_agg`` text column. It is NOT modified
            (the previous version overwrote the caller's column in place).

    Returns:
        A list with one token list per row of ``df``, in row order.
    '''
    # One translation table for both removals; the two character sets are disjoint,
    # so this matches removing punctuation first and digits second.
    strip_table = str.maketrans('', '', string.punctuation + '1234567890')

    notes = (df.notes_agg
             .fillna(' ')
             .str.replace('\n', ' ', regex=False)
             .str.replace('\r', ' ', regex=False)
             .str.translate(strip_table)
             .str.lower())

    return [word_tokenize(note) for note in notes.values]

To process our notes, we first need to load our pre-trained word2vec model:

In [None]:
from gensim.models import Word2Vec
from gensim.models import KeyedVectors


# Load the pre-trained word vectors used to embed the notes. Only KeyedVectors lookups
# (`get_vector`) are used by process_dataset; Word2Vec is not used elsewhere in this notebook.
# Alternative source (not used): PubMed word2vec vectors in the original binary format.
#pubMedWord2VecModel = KeyedVectors.load_word2vec_format('PubMed-w2v.bin', binary=True)
word2vec_model = KeyedVectors.load('note_vectors.kv')

The following function processes a dataset (train or test) and creates the objects that we will later use in the notes network:

In [None]:
def process_dataset(cohort_patients_df, patients_date_notes_df, patients_max_visits, model=None):
    '''Embed every patient's notes, one vector per distinct note date.

    Each date's tokens (from ``preprocess_dataset``) are embedded as the sum of
    their word vectors. Tokens missing from the vocabulary are skipped and counted.

    Args:
        cohort_patients_df: DataFrame with ``subject_id`` and ``mortality_flag``.
            It defines the patient order of every output.
        patients_date_notes_df: DataFrame with ``subject_id`` and ``notes_agg``,
            one row per patient and note date. Rows are expected to be in date
            order within each patient; ``index_1`` is the row position.
        patients_max_visits: DataFrame with ``subject_id`` and ``number_dates``
            (number of note dates per patient). Patients absent from it are
            treated as having zero dates.
        model: word-vector model exposing ``get_vector`` (and ideally
            ``vector_size``). Defaults to the notebook-level ``word2vec_model``.

    Returns:
        Tuple ``(patient_subject_id, patients_notes_fetures, index_0, index_1,
        patients_notes_last_date, patient_mortality)``, where:
        - ``patients_notes_fetures`` is the list of per-date vectors whose sparse
          coordinates are ``(index_0[k], index_1[k])`` = (patient position,
          date position);
        - ``patients_notes_last_date`` is an ``(n_patients, 1)`` array with the
          zero-based position of each patient's last date (0 when the patient
          has no dates);
        - ``patient_mortality`` holds the target labels.
    '''
    if model is None:
        model = word2vec_model
    # Use the model's real dimensionality; fall back to the historical hard-coded 200.
    vector_size = getattr(model, 'vector_size', 200)

    patient_subject_id = []
    patients_notes_fetures = []  # sparse vector values
    index_0 = []                 # sparse index: patient position
    index_1 = []                 # sparse index: date position within the patient
    missing_words = 0

    # Align the date counts with the cohort by subject_id rather than by row position,
    # so a patient missing from one query cannot silently shift every later patient.
    number_dates = (cohort_patients_df[['subject_id']]
                    .merge(patients_max_visits[['subject_id', 'number_dates']],
                           on='subject_id', how='left')
                    .number_dates.fillna(0)
                    .to_numpy(dtype=np.int64))
    # The query returns a count, but positions are zero-based, so the last date is
    # count - 1. Patients with no notes are kept (for consistency with the events
    # network) and get last date 0.
    patients_notes_last_date = np.expand_dims(np.maximum(number_dates - 1, 0), axis=1)
    # Target label
    patient_mortality = np.array(cohort_patients_df.mortality_flag.values)

    # Group once instead of re-filtering the whole notes frame for every patient;
    # groupby keeps the original row order inside each group.
    notes_by_patient = {sid: grp for sid, grp in
                        patients_date_notes_df.groupby('subject_id', sort=False)}
    no_notes = patients_date_notes_df.iloc[0:0]

    for patient_idx, patient in enumerate(cohort_patients_df.subject_id.values):
        if patient_idx % 100 == 0 and patient_idx > 0:
            print('Processing patient ' + str(patient_idx))
        patient_subject_id.append(patient)
        patient_notes = notes_by_patient.get(patient, no_notes)
        patient_date_notes_list = preprocess_dataset(patient_notes.copy())
        for date_idx, note in enumerate(patient_date_notes_list):
            patient_date_note = np.zeros(vector_size, dtype=np.float64)
            for note_word in note:
                try:
                    patient_date_note += model.get_vector(note_word)
                except KeyError:
                    # Out-of-vocabulary token: skip it. Only KeyError is caught, so
                    # interrupts and dimension mismatches are no longer hidden.
                    missing_words += 1
            index_0.append(patient_idx)
            index_1.append(date_idx)
            patients_notes_fetures.append(patient_date_note)

    print('Out-of-vocabulary tokens skipped: ' + str(missing_words))
    return patient_subject_id, patients_notes_fetures, index_0, index_1, patients_notes_last_date, patient_mortality


The following two functions are to help us save and re-load all the processed notes into/from file objects

In [None]:
def save_notes_dataset_objects(patient_subject_id, patients_notes_fetures, index_0, index_1, patients_notes_last_date, patient_mortality, prefix = ''):
    '''Write the six outputs of ``process_dataset`` to ``<prefix><name>.npy`` files.

    Each object goes through ``np.save`` with ``allow_pickle=True``, so object
    arrays are accepted as well as plain numeric ones.
    '''
    named_objects = (
        ('subject_id', patient_subject_id),
        ('patients_notes_fetures', patients_notes_fetures),
        ('index_0', index_0),
        ('index_1', index_1),
        ('patients_notes_last_date', patients_notes_last_date),
        ('patient_mortality', patient_mortality),
    )
    for file_stem, payload in named_objects:
        np.save(prefix + file_stem + '.npy', payload, allow_pickle=True)



In [None]:
def load_notes_dataset_object(prefix = ''):
    '''Read back the six ``.npy`` files written by ``save_notes_dataset_objects``.

    The objects are returned in ``process_dataset`` order. The subject ids are
    converted to a Python list; everything else stays a NumPy array.
    '''
    def _read(file_stem):
        return np.load(prefix + file_stem + '.npy', allow_pickle=True)

    return (
        _read('subject_id').tolist(),
        _read('patients_notes_fetures'),
        _read('index_0'),
        _read('index_1'),
        _read('patients_notes_last_date'),
        _read('patient_mortality'),
    )
    

## Balanced Cohort, Train Set

We have already built our cohort sets with SQL queries. The following query fetches the list of all cohort patients, and we then process the notes of each patient. We do this first for the train set and then for the test set:

In [None]:
# Balanced training cohort: new_subject_id (returned as subject_id) plus its mortality label.
query = ('select cohort.new_subject_id as subject_id, cohort.mortality_flag '
         'from default.train_cohort2 cohort '
         'order by cohort.new_subject_id')
cursor.execute(query)
cohort_patients_df = as_pandas(cursor)

In [None]:
# Inspect the training cohort loaded above (columns: subject_id, mortality_flag).
cohort_patients_df

In [None]:
# Not all patients have the same number of note (visit) dates, so we count the distinct
# note dates of each patient; process_dataset uses this count to locate each patient's last date.

In [None]:
# Number of note dates per training patient; count() skips NULL chart_date values.
query = ('select cohort.new_subject_id as subject_id, '
         'count(nts.chart_date) as number_dates '
         'from default.diabetic_patients_notes_agg nts '
         'join default.train_cohort2 cohort '
         'on nts.subject_id = cohort.subject_id '
         'group by cohort.new_subject_id '
         'order by cohort.new_subject_id;')
cursor.execute(query)
patients_max_visits = as_pandas(cursor)


In [None]:
# Fetch every (patient, date) note aggregate for the training cohort, keyed by
# new_subject_id (returned as subject_id), ordered by patient then date; time the query.
t0 = time.time()
query = ('select cohort.subject_id, nts.chart_date, nts.notes_agg '
         'from default.diabetic_patients_notes_agg nts '
         'join (select cohort.new_subject_id as subject_id, cohort.subject_id as old_subject_id '
         'from default.train_cohort2 cohort order by cohort.new_subject_id) cohort '
         'on nts.subject_id = cohort.old_subject_id '
         'order by cohort.subject_id asc, nts.chart_date asc;')
cursor.execute(query)
patients_date_notes_df = as_pandas(cursor)
print('Total number of train patients_dates: ' + str(len(patients_date_notes_df)))
t1 = time.time()
processing_time = t1 - t0
print('Training set notes Agg query time: ' + str(processing_time))

In [None]:
# Embed the training notes and report the elapsed time.
t0 = time.time()
(train_subject_id, train_patients_notes_fetures, train_index_0, train_index_1,
 train_patients_notes_last_date, train_patient_mortality) = process_dataset(
    cohort_patients_df, patients_date_notes_df, patients_max_visits)
t1 = time.time()
processing_time = t1 - t0
print('Training set processing time: ' + str(processing_time))

In [None]:
# process_dataset already returns last dates with shape (n_patients, 1). Only add the axis
# when it is missing, so the train set is saved with the same shape as the test/orig sets
# instead of (n_patients, 1, 1).
if np.ndim(train_patients_notes_last_date) == 1:
    train_patients_notes_last_date = np.expand_dims(train_patients_notes_last_date, axis=1)
save_notes_dataset_objects(train_subject_id, train_patients_notes_fetures, train_index_0, train_index_1, train_patients_notes_last_date, train_patient_mortality, prefix = 'train_')


## Balanced Cohort, test set

We now process the test set:

In [None]:
# Test cohort: subject_id plus its mortality label.
query = ('select cohort.subject_id as subject_id, cohort.mortality_flag '
         'from default.test_cohort cohort '
         'order by cohort.subject_id')
cursor.execute(query)
cohort_patients_df = as_pandas(cursor)

In [None]:
# Inspect the test cohort loaded above (columns: subject_id, mortality_flag).
cohort_patients_df

In [None]:
# Fetch every (patient, date) note aggregate for the test cohort, ordered by patient then
# date, and time the query.
t0 = time.time()
query = ('select cohort.subject_id, nts.chart_date, nts.notes_agg from default.diabetic_patients_notes_agg nts '
         'join (select cohort.subject_id from default.test_cohort cohort order by cohort.subject_id) cohort '
         'on nts.subject_id = cohort.subject_id '
         'order by cohort.subject_id asc, nts.chart_date asc;')
cursor.execute(query)
patients_date_notes_df = as_pandas(cursor)
# This cell counts the TEST set (the message previously said "train").
print('Total number of test patients_dates: ' +str(len(patients_date_notes_df)))
t1 = time.time()
processing_time = t1-t0
print('Test set notes Agg query time: ' + str(processing_time))

In [None]:
# Number of note dates per test patient; count() skips NULL chart_date values.
query = ('select cohort.subject_id, '
         'count(nts.chart_date) as number_dates '
         'from default.diabetic_patients_notes_agg nts '
         'join default.test_cohort cohort '
         'on nts.subject_id = cohort.subject_id '
         'group by cohort.subject_id '
         'order by cohort.subject_id;')
cursor.execute(query)
patients_max_visits = as_pandas(cursor)

In [None]:
# Embed the test notes and report the elapsed time.
t0 = time.time()
test_subject_id, test_patients_notes_fetures, test_index_0, test_index_1, test_patients_notes_last_date, test_patient_mortality = process_dataset(cohort_patients_df, patients_date_notes_df,patients_max_visits)
t1 = time.time()
processing_time = t1-t0
# This cell processes the TEST set (the message previously said "Training set").
print('Test set processing time: ' + str(processing_time))

In [None]:
# Persist the embedded test set under the 'test_' file prefix.
save_notes_dataset_objects(
    test_subject_id, test_patients_notes_fetures, test_index_0, test_index_1,
    test_patients_notes_last_date, test_patient_mortality,
    prefix='test_',
)

In [None]:
# Sanity check: reload the saved training arrays under *_cp names (previously the last two
# overwrote the in-memory train results) and confirm the save/load round trip preserved them.
(train_subject_id_cp, train_patients_notes_fetures_cp, train_index_0_cp, train_index_1_cp,
 train_patients_notes_last_date_cp, train_patient_mortality_cp) = load_notes_dataset_object(prefix = 'train_')
assert np.array_equal(np.asarray(train_subject_id_cp), np.asarray(train_subject_id))
assert np.array_equal(train_patient_mortality_cp, train_patient_mortality)
assert np.array_equal(train_patients_notes_last_date_cp, train_patients_notes_last_date)

## Unbalanced Cohort

Now we process the unbalanced cohort. There is no need to create train and test sets here, as this split will be done later in the main network notebook.

In [None]:
# Full (unbalanced) diabetic cohort: subject_id plus its mortality label.
query = ('select cohort.subject_id as subject_id, cohort.mortality_flag '
         'from default.diabetic_patients_cohort cohort '
         'order by cohort.subject_id')
cursor.execute(query)
cohort_patients_df = as_pandas(cursor)
cohort_patients_df

In [None]:
# Fetch every (patient, date) note aggregate for the whole original cohort, ordered by
# patient then date, and time the query.
t0 = time.time()
query = ('select nts.subject_id, nts.chart_date, nts.notes_agg from default.diabetic_patients_notes_agg nts '
         'order by nts.subject_id asc, nts.chart_date asc;')
cursor.execute(query)
patients_date_notes_df = as_pandas(cursor)
# This cell counts the ORIGINAL (unbalanced) set (the message previously said "train").
print('Total number of original set patients_dates: ' +str(len(patients_date_notes_df)))
t1 = time.time()
processing_time = t1-t0
print('Original set notes Agg query time: ' + str(processing_time))

In [None]:
# Number of note dates per patient of the original cohort; count() skips NULL chart_date values.
query = ('select nts.subject_id as subject_id, '
         'count(nts.chart_date) as number_dates '
         'from default.diabetic_patients_notes_agg nts '
         'group by nts.subject_id '
         'order by nts.subject_id;')
cursor.execute(query)
patients_max_visits = as_pandas(cursor)

In [None]:
# Embed the original-cohort notes and report the elapsed time.
t0 = time.time()
(orig_subject_id, orig_patients_notes_fetures, orig_index_0, orig_index_1,
 orig_patients_notes_last_date, orig_patient_mortality) = process_dataset(
    cohort_patients_df, patients_date_notes_df, patients_max_visits)
t1 = time.time()
processing_time = t1 - t0
print('Original set processing time: ' + str(processing_time))

In [None]:
# Persist the embedded original cohort under the 'orig_' file prefix.
save_notes_dataset_objects(
    orig_subject_id, orig_patients_notes_fetures, orig_index_0, orig_index_1,
    orig_patients_notes_last_date, orig_patient_mortality,
    prefix='orig_',
)