In [0]:
import torch
from torch.utils import data

import torch.nn as nn
import torch.nn.functional as F

import pandas as pd
%matplotlib inline
from matplotlib import pyplot as plt

import numpy as np
import pickle

from google.colab import auth

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

In [0]:
# Mirror PIC v1.0.0 from PhysioNet into ./physionet.org/files/picdb/1.0.0/,
# which is the directory the table-loading cell reads from.
# wget flags: -r recursive, -N only re-fetch files newer than the local copy,
# -c resume partially downloaded files, -np never ascend to the parent directory.
# --ask-password prompts for the PhysioNet password interactively.
# NOTE(review): the PhysioNet username is hardcoded; consider taking it from an
# environment variable and interpolating it, e.g. `--user {PHYSIONET_USER}`.
!wget -r -N -c -np --user kyleliu --ask-password https://physionet.org/files/picdb/1.0.0/

In [0]:
# Read Data into DF: one DataFrame per PIC table, loaded from the gzipped CSVs
# under PIC_DIR (the directory the PhysioNet download writes to). Change PIC_DIR
# if the files live elsewhere.
PIC_DIR = 'physionet.org/files/picdb/1.0.0'


def _read_pic_table(table_name, data_dir=None):
  """Load ``<data_dir>/<table_name>.csv.gz`` into a DataFrame.

  Args:
    table_name: PIC table file stem, e.g. ``'ADMISSIONS'``.
    data_dir: directory holding the files; defaults to ``PIC_DIR``.
  """
  directory = PIC_DIR if data_dir is None else data_dir
  return pd.read_csv(f'{directory}/{table_name}.csv.gz', compression='gzip')


admissions = _read_pic_table('ADMISSIONS')
chartevents = _read_pic_table('CHARTEVENTS')
diagnoses_icd = _read_pic_table('DIAGNOSES_ICD')
d_icd_diagnoses = _read_pic_table('D_ICD_DIAGNOSES')
d_items = _read_pic_table('D_ITEMS')
d_labitems = _read_pic_table('D_LABITEMS')
emr_symptoms = _read_pic_table('EMR_SYMPTOMS')
icu_stays = _read_pic_table('ICUSTAYS')
input_events = _read_pic_table('INPUTEVENTS')
lab_events = _read_pic_table('LABEVENTS')
patients = _read_pic_table('PATIENTS')
prescriptions = _read_pic_table('PRESCRIPTIONS')
surgery_vital_signs = _read_pic_table('SURGERY_VITAL_SIGNS')

In [0]:
# Lookup dictionaries (ID -> human-readable value) built from the dictionary tables,
# so labels/codes can be resolved with a plain dict access later on.


def _column_map(df, key_col, value_col):
  """Return a dict mapping each row's ``df[key_col]`` to its ``df[value_col]``.

  Rows are consumed in table order, so a later duplicate key overwrites an
  earlier one (same as assigning row by row). Iterating the two columns
  directly avoids the per-row Series that ``DataFrame.iterrows`` builds.
  """
  return dict(zip(df[key_col], df[value_col]))


# ITEMID -> LABEL from d_items.
item_dict = _column_map(d_items, 'ITEMID', 'LABEL')

# ITEMID -> LABEL from d_labitems.
lab_item_dict = _column_map(d_labitems, 'ITEMID', 'LABEL')

# ICD10_CODE_CN -> ICD10_CODE from d_icd_diagnoses.
ICD_CN_TO_ICD = _column_map(d_icd_diagnoses, 'ICD10_CODE_CN', 'ICD10_CODE')


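Here we keep only the first (earliest `ADMITTIME`) admission of each patient, and drop rows belonging to any later admission from the admission-level tables.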

In [0]:
# Clean: include only the first admission of each patient.
#
# Rows are ordered by the ADMITTIME *string*; that equals chronological order
# only if every timestamp uses the same zero-padded 'YYYY-MM-DD HH:MM:SS'
# layout (NOTE(review): confirm for PIC). A stable sort (mergesort) makes the
# tie-break deterministic: if one patient has several admissions with the same
# ADMITTIME, the one that appears first in the admissions table is kept.
admissions = admissions.sort_values(by=['ADMITTIME'], kind='mergesort')

# After sorting, the first row per SUBJECT_ID is that patient's earliest admission.
first_admissions = admissions.drop_duplicates(subset='SUBJECT_ID', keep='first')

# HADM_IDs to retain across all admission-level tables (used by remove_admits).
admits_to_keep = first_admissions['HADM_ID'].tolist()
# Every SUBJECT_ID that has at least one admission.
seen_patients = set(first_admissions['SUBJECT_ID'])

In [0]:
def remove_admits(df, keep=None):
  """Return a copy of ``df`` restricted to rows whose HADM_ID is in ``keep``.

  Despite the name, this *keeps* the listed admissions and drops all others.

  Args:
    df: table with an ``HADM_ID`` column.
    keep: iterable of HADM_IDs to retain; defaults to the module-level
      ``admits_to_keep`` (looked up at call time).

  Returns:
    A new, independent DataFrame. The explicit ``.copy()`` lets callers add
    columns (e.g. ``HOURS_IN``) without pandas' SettingWithCopyWarning about
    writing to a slice of the original table.
  """
  if keep is None:
    keep = admits_to_keep
  return df[df['HADM_ID'].isin(keep)].copy()

# Apply the first-admission filter to each admission-level table.
(admissions, chartevents, diagnoses_icd, emr_symptoms, icu_stays,
 input_events, lab_events, prescriptions, surgery_vital_signs) = [
    remove_admits(table)
    for table in (admissions, chartevents, diagnoses_icd, emr_symptoms,
                  icu_stays, input_events, lab_events, prescriptions,
                  surgery_vital_signs)
]


Helper functions to parse timestamp strings, plus per-patient lookups of admission time and age at admission.

In [0]:
from datetime import date, timedelta, time, datetime

def to_datetime(x):
  """Parse a ``'YYYY-MM-DD HH:MM:SS'`` string into a ``datetime``.

  Only the first two whitespace-separated fields are read (date, then time),
  and only the first three '-' / ':' separated parts of each; anything after
  that is ignored.
  """
  fields = x.split()
  ymd, hms = fields[0].split("-"), fields[1].split(":")
  year, month, day = int(ymd[0]), int(ymd[1]), int(ymd[2])
  hour, minute, second = int(hms[0]), int(hms[1]), int(hms[2])
  return datetime(year, month, day, hour, minute, second)

# Per-subject timestamps, all keyed by SUBJECT_ID:
#   birth_date       -- DOB from `patients`, parsed with to_datetime
#   admit_date       -- ADMITTIME from `admissions`, parsed with to_datetime
#   age_at_admission -- admit_date - birth_date (a datetime.timedelta)
age_at_admission, birth_date, admit_date = {}, {}, {}

birth_date.update(
    (subject, to_datetime(dob))
    for subject, dob in zip(patients.SUBJECT_ID, patients.DOB)
)

for subject, admit_time in zip(admissions.SUBJECT_ID, admissions.ADMITTIME):
  admitted = to_datetime(admit_time)
  admit_date[subject] = admitted
  age_at_admission[subject] = admitted - birth_date[subject]

In [0]:
def normalize_time(patient_id, x):
  """Hours elapsed from ``patient_id``'s admission until timestamp ``x``.

  ``x`` is a timestamp string parsed with ``to_datetime``; the reference point
  is ``admit_date[patient_id]``. A negative result means ``x`` precedes the
  admission time.
  """
  event_time = to_datetime(x)
  elapsed = event_time - admit_date[patient_id]
  seconds_per_hour = 3600.0
  return elapsed.total_seconds() / seconds_per_hour

In [0]:
# Every SUBJECT_ID in the (unfiltered) patients table, in table order.
# NOTE(review): despite its name this is a list, not a set.
patient_set = list(patients.SUBJECT_ID)

In [0]:
def _hours_since_admission(df, time_col):
  """Return, per row of ``df``, hours from the subject's admission to ``df[time_col]``.

  Walks the SUBJECT_ID and timestamp columns directly rather than using
  ``DataFrame.apply(axis=1)``. That avoids building a Series per row, and it
  avoids the empty-table case where ``apply`` can hand back a DataFrame that
  cannot be assigned to a single column. An empty ``df`` yields ``[]``.
  """
  return [normalize_time(subject, stamp)
          for subject, stamp in zip(df['SUBJECT_ID'], df[time_col])]


# HOURS_IN: event time relative to the subject's admission, in hours.
chartevents['HOURS_IN'] = _hours_since_admission(chartevents, 'CHARTTIME')
lab_events['HOURS_IN'] = _hours_since_admission(lab_events, 'CHARTTIME')
surgery_vital_signs['HOURS_IN'] = _hours_since_admission(surgery_vital_signs, 'MONITORTIME')

In [0]:
SEPTIC = {2793: 193.595, 1280: 14.989166666666666, 1319: 0.7147222222222223, 1338: 62.74388888888889, 1323: 84.45166666666667, 1292: 66.86444444444444, 1298: 63.765277777777776, 1330: 32.00194444444445, 3073: 67.66916666666667, 3308: 160.06861111111112, 3779: 341.9925, 4131: 196.6175, 1326: 182.1588888888889, 1262: 86.35083333333333, 1312: 400.67027777777776, 3068: 112.52888888888889, 4056: 110.33972222222222, 5082: 130.1888888888889, 5129: 906.4261111111111, 1315: 337.6511111111111, 1273: 11.250833333333333, 6436: 318.9583333333333, 6752: 2331.063888888889, 8198: 1127.8408333333334, 1318: 4.063888888888889, 1320: 42.129444444444445, 1333: 58.63444444444445, 3585: 82.62444444444445, 3786: 86.25638888888889, 1272: 67.58472222222223, 1275: 116.75777777777778, 1299: 179.26527777777778, 3925: 57.754444444444445, 3948: 122.77944444444445, 4492: 55.33861111111111, 5340: 2.3766666666666665, 2939: 66.9263888888889, 3839: 62.606944444444444, 3130: 227.7, 3167: 3.7055555555555557, 3742: 88.08833333333334, 4287: 72.23666666666666, 4502: 12.394166666666667, 4695: 90.98472222222222, 4782: 166.7513888888889, 8524: 4.151666666666666, 9312: 61.70055555555555, 9504: 1951.1802777777777, 9521: 217.85722222222222, 9946: 27.14472222222222, 9957: 321.5752777777778, 10088: 125.885, 10121: 57.72083333333333, 10166: 0.2286111111111111, 10172: 3770.6111111111113, 10247: 67.02583333333334, 10249: 170.19833333333332, 10280: 565.0188888888889, 10402: 99.44583333333334, 10409: 622.7705555555556, 10430: 290.94444444444446, 10435: 569.5397222222222, 10444: 21.735277777777778, 10509: 95.78166666666667, 10523: 166.8225, 10567: 124.41444444444444, 10600: 256.0683333333333, 10647: 190.41611111111112, 10655: 71.56527777777778, 10669: 1.355, 4319: 15.455555555555556, 4929: 254.065, 5092: 3.5255555555555556, 2782: 122.80333333333333, 5176: 63.47083333333333, 7579: 176.04361111111112, 7748: 87.85277777777777, 8152: 19.52888888888889, 8935: 3.9275, 9152: 404.3177777777778, 9370: 164.07, 9375: 
67.46555555555555, 9396: 248.24444444444444, 9447: 99.33638888888889, 9508: 475.7602777777778, 9529: 150.75555555555556, 9549: 91.45527777777778, 9568: 229.4013888888889, 9704: 143.8513888888889, 9726: 69.695, 9738: 20.369722222222222, 9741: 108.98194444444445, 1285: 141.9175, 7875: 394.52694444444444, 8088: 254.315, 8248: 6.469722222222222, 7291: 319.5075, 2972: 533.2352777777778, 3415: 178.0425, 3649: 9.02111111111111, 1288: 3.6125, 1294: 46.236111111111114, 1310: 315.58972222222224, 1199: 199.3788888888889, 5963: 339.5622222222222, 6167: 3.993333333333333, 6457: 194.45527777777778, 10728: 73.84277777777778, 10732: 71.80972222222222, 10766: 95.49777777777778, 10887: 132.22861111111112, 10893: 100.75083333333333, 10913: 223.33027777777778, 10954: 174.05777777777777, 11118: 1205.431388888889, 11145: 268.4841666666667, 11174: 334.21444444444444, 11250: 600.6988888888889, 11254: 61.841944444444444, 11403: 233.88722222222222, 11463: 236.24777777777777, 11543: 72.28666666666666, 11646: 80.75194444444445, 11649: 301.5316666666667, 11739: 506.30805555555554, 11773: 16.904722222222222, 11786: 1111.831111111111, 11790: 212.48611111111111, 11826: 322.19305555555553, 11834: 1413.1802777777777, 11849: 2.894722222222222, 11896: 2.275, 11915: 13.36138888888889, 11969: 95.67388888888888, 12084: 67.45138888888889, 12140: 11.063611111111111, 12176: 99.75972222222222, 12335: 113.81166666666667, 12468: 123.22805555555556, 12560: 291.8977777777778, 12574: 437.2625, 12636: 164.09694444444443, 12680: 17.435555555555556, 12713: 801.6494444444445, 12896: 35.2675, 12970: 22.635, 12986: 215.15444444444444, 12992: 1.8555555555555556, 12998: 265.0155555555556, 13113: 29.455, 13117: 666.6102777777778, 13122: 6.269722222222223, 13177: 145.80527777777777, 13196: 3.001388888888889, 13231: 84.62277777777778, 1289: 1.9461111111111111, 1308: 92.88666666666667, 1334: 66.15222222222222, 7035: 5.079444444444444, 7271: 29.36111111111111, 7654: 589.9316666666666, 5486: 192.7, 6854: 21.101944444444445, 
1325: 50.31027777777778, 1329: 37.35166666666667, 1306: 2.0027777777777778, 1266: 132.9361111111111, 1311: 0.57, 6058: 98.17277777777778, 9809: 298.9038888888889, 10025: 83.37861111111111, 10037: 494.7583333333333, 10094: 78.81361111111111, 10151: 0.2822222222222222, 10210: 215.27555555555554, 10212: 2666.486388888889, 10260: 114.53694444444444, 10440: 3.45, 10460: 143.38166666666666, 10549: 82.72527777777778, 10557: 501.7269444444444, 10616: 31.836944444444445, 10680: 278.90722222222223, 10735: 57.88527777777778, 10737: 501.7113888888889, 10785: 78.98527777777778, 10927: 192.81, 11019: 169.5775, 11068: 221.17361111111111, 11076: 38.684444444444445, 11078: 213.64222222222222, 11128: 56.50555555555555, 11143: 621.0858333333333, 11260: 525.0486111111111, 11426: 161.0688888888889, 11429: 1613.5227777777777, 11521: 241.88611111111112, 11545: 288.02416666666664, 11643: 47.759166666666665, 11725: 26.93888888888889, 11824: 312.5213888888889, 11942: 129.79361111111112, 11992: 1341.0283333333334, 12024: 7.134722222222222, 12094: 96.63944444444445, 12165: 144.39638888888888, 12173: 151.16027777777776, 12188: 570.27, 8290: 48.865, 8627: 58.58444444444444, 9089: 265.59277777777777, 9559: 75.06833333333333, 9602: 93.1525, 9644: 340.0602777777778, 9675: 77.09361111111112, 9733: 2060.5244444444443, 9745: 100.49305555555556, 9844: 266.2866666666667, 9852: 0.4633333333333333, 9905: 76.76222222222222, 9985: 96.82611111111112, 10218: 365.3677777777778, 10283: 186.08944444444444, 10284: 333.4036111111111, 10286: 49.22138888888889, 10374: 74.94166666666666, 10427: 582.2675, 3502: 4.221388888888889, 3557: 84.08083333333333, 4191: 17.41583333333333, 4275: 192.2613888888889, 3867: 53.37833333333333, 5812: 44.69166666666667, 7613: 577.0005555555556, 8058: 194.27277777777778, 8306: 334.5441666666667, 8794: 196.66861111111112, 9525: 13.445277777777777, 9657: 475.7475, 9673: 274.3911111111111, 9688: 71.70083333333334, 9917: 38.50361111111111, 9923: 118.03027777777778, 9954: 105.785, 9973: 
213.25833333333333, 9980: 785.1655555555556, 10125: 6.964444444444444, 10343: 1.1558333333333333, 10566: 237.41805555555555, 10625: 98.17972222222222, 10656: 464.9433333333333, 10821: 8.61888888888889, 10852: 131.9947222222222, 10889: 5.635833333333333, 11037: 173.38833333333332, 11132: 166.8138888888889, 11184: 463.2705555555556, 11207: 74.52333333333333, 11273: 164.06222222222223, 11296: 438.4775, 11305: 875.1633333333333, 11430: 162.8461111111111, 11443: 143.04861111111111, 11486: 77.73944444444444, 556: 776.8925, 1278: 34.888333333333335, 1300: 124.81194444444445, 8128: 119.59361111111112, 8439: 188.4936111111111, 8700: 90.39361111111111, 8810: 498.18638888888887, 9121: 49.74944444444444, 9479: 70.58666666666667, 9540: 226.57472222222222, 9573: 7.885277777777778, 9625: 253.48361111111112, 9723: 264.04, 9769: 81.79861111111111, 9771: 11.460555555555555, 9775: 453.7477777777778, 9800: 10.661944444444444, 9817: 178.60416666666666, 9819: 266.51666666666665, 9875: 809.3347222222222, 9955: 5.408055555555555, 9970: 109.68833333333333, 10184: 99.59972222222223, 10223: 2981.6805555555557, 10267: 179.24027777777778, 10289: 193.62277777777777, 12269: 469.76194444444445, 12337: 305.31444444444446, 12376: 221.8225, 12418: 15.044444444444444, 12491: 76.08638888888889, 12555: 47.376666666666665, 12577: 1.843611111111111, 12605: 323.72333333333336, 12783: 245.0361111111111, 13023: 0.8258333333333333, 13065: 64.01916666666666, 13105: 15.22638888888889, 13195: 36.3, 13313: 293.0855555555556, 13477: 172.03333333333333, 13508: 20.05361111111111, 10431: 0.4875, 10528: 13.068333333333333, 10668: 44.26638888888889, 10847: 280.5752777777778, 11009: 843.4447222222223, 11126: 5.453888888888889, 11482: 41.545833333333334, 11511: 1525.4469444444444, 11582: 166.04111111111112, 11599: 18.44361111111111, 11653: 121.0188888888889, 11728: 56.18527777777778, 11779: 8.62111111111111, 11872: 18.922777777777778, 11875: 233.05333333333334, 11881: 169.9961111111111, 12027: 421.8222222222222, 12034: 
149.3, 12111: 552.8938888888889, 12158: 26.698055555555555, 12182: 0.26916666666666667, 12209: 69.34527777777778, 12213: 90.52111111111111, 12295: 26.893333333333334, 12302: 74.89833333333333, 12347: 113.52666666666667, 12381: 309.68111111111114, 12412: 84.1763888888889, 12435: 52.12305555555555, 12597: 74.0075, 12703: 93.93138888888889, 12714: 60.99055555555555, 12806: 239.52388888888888, 12930: 39.318333333333335, 12999: 141.8188888888889, 8146: 357.3358333333333, 8305: 1553.6263888888889, 8757: 201.66916666666665, 9203: 68.07333333333334, 9342: 91.02277777777778, 9411: 21.965, 9431: 263.2288888888889, 9485: 704.0888888888888, 9518: 94.91944444444445, 9646: 171.50416666666666, 9788: 315.1933333333333, 9830: 239.8852777777778, 9984: 0.38, 10022: 158.59916666666666, 10024: 544.2280555555556, 10042: 84.92972222222222, 10109: 5.07, 10200: 153.37944444444443, 10213: 151.0063888888889, 10308: 165.0475, 10418: 18.224444444444444, 10488: 153.59694444444443, 10521: 53.01166666666666, 10525: 260.9947222222222, 10530: 31.852222222222224, 10579: 0.31333333333333335, 10593: 537.8997222222222, 13317: 192.16916666666665, 13362: 154.3486111111111, 9158: 0.3244444444444444, 9519: 474.64972222222224, 9558: 182.4138888888889, 9584: 127.6536111111111, 9669: 85.27638888888889, 9686: 99.03055555555555, 9872: 23.851111111111113, 9897: 158.4363888888889, 9948: 99.09, 10045: 525.64, 10057: 50.565, 10100: 13.713888888888889, 10101: 4.020277777777777, 10113: 28.232222222222223, 7986: 852.4825, 8372: 307.68611111111113, 9128: 593.7966666666666, 9209: 1534.1041666666667, 9283: 553.6713888888889, 9306: 507.6119444444444, 9469: 271.3597222222222, 9471: 388.19638888888886, 9475: 0.6363888888888889, 9494: 161.87694444444443, 9586: 184.07805555555555, 9643: 385.62888888888887, 9693: 76.07222222222222, 9732: 38.31777777777778, 9803: 139.0547222222222, 9890: 28.3525, 10028: 85.64222222222222, 10079: 444.6725, 10141: 165.4086111111111, 10179: 21.690555555555555, 10198: 4.755, 10233: 
269.23277777777776, 10302: 7.553333333333334, 8664: 333.92333333333335, 8760: 749.5038888888889, 8890: 1.6102777777777777, 9250: 292.35333333333335, 9310: 175.9961111111111, 9419: 1251.1294444444445, 9724: 55.701388888888886, 9766: 3.243611111111111, 9787: 116.70527777777778, 9838: 80.7275, 9851: 174.70166666666665, 10029: 106.80194444444444, 10055: 5.733055555555556, 10075: 18.56861111111111, 10093: 2.5052777777777777, 10255: 5.821111111111111, 10258: 0.25, 10291: 178.49305555555554, 10372: 233.61027777777778, 10381: 548.2769444444444, 10408: 770.0969444444445, 11597: 23.849166666666665, 11624: 2.0741666666666667, 11627: 113.90055555555556, 11635: 385.22277777777776, 11660: 90.02583333333334, 11717: 128.91305555555556, 11766: 598.1936111111111, 11785: 383.8697222222222, 11810: 3.6483333333333334, 11831: 410.0391666666667, 12030: 85.36777777777777, 12127: 457.185, 12139: 56.29972222222222, 12163: 58.15277777777778, 12252: 157.2625, 12298: 66.95472222222222, 12326: 146.915, 12391: 14.741111111111111, 12396: 71.5875, 12411: 3.2288888888888887, 12473: 3.039722222222222, 12489: 9.428055555555556, 12549: 311.8569444444444, 12559: 151.08833333333334, 12701: 386.25166666666667, 12708: 66.77583333333334, 12733: 123.73583333333333, 12786: 22.932777777777776, 12793: 357.12666666666667, 12847: 214.46527777777777, 12859: 96.16666666666667, 12873: 19.961388888888887, 12893: 51.93388888888889, 12933: 6.107777777777778, 13194: 226.34777777777776, 13259: 32.905, 13262: 27.9475, 13334: 126.62722222222222, 13433: 50.02388888888889, 13629: 19.009722222222223, 10311: 23.822777777777777, 10511: 35.5525, 10529: 459.465, 10533: 4.708888888888889, 10596: 173.0077777777778, 10624: 91.14888888888889, 10798: 432.0313888888889, 10810: 99.83833333333334, 10873: 883.0736111111111, 10892: 190.28944444444446, 11016: 101.04638888888888, 11114: 118.82194444444444, 11117: 1342.743611111111, 11142: 0.4922222222222222, 11147: 153.08916666666667, 11185: 1116.3347222222221, 11200: 167.83444444444444, 
11223: 85.66388888888889, 11340: 188.3025, 11397: 3.05, 11501: 326.9263888888889, 11522: 9.58388888888889, 11612: 21.988055555555555, 11907: 78.23861111111111, 11987: 237.4661111111111, 11998: 44.86194444444445, 12038: 267.1486111111111, 12130: 223.7888888888889, 12195: 174.03555555555556, 12293: 72.32333333333334, 12332: 70.29138888888889, 12360: 8.756944444444445, 12374: 137.38722222222222, 12379: 193.23833333333334, 12380: 2.4627777777777777, 12424: 3.0366666666666666, 12458: 151.86944444444444, 12507: 338.7841666666667, 12527: 71.59944444444444, 12585: 105.13083333333333, 12728: 151.00222222222223, 12833: 35.456944444444446, 12844: 0.7658333333333334, 12908: 170.09194444444444, 10316: 83.39916666666667, 10324: 22.58888888888889, 10328: 1.346111111111111, 10442: 157.69333333333333, 10674: 507.2758333333333, 10745: 1795.8130555555556, 10750: 173.76666666666668, 10792: 237.2513888888889, 10832: 95.04166666666667, 10850: 667.5294444444445, 10939: 243.9713888888889, 10962: 41.24472222222222, 11002: 10.0025, 11065: 502.2005555555556, 11119: 205.915, 11146: 252.44583333333333, 11190: 122.06583333333333, 11333: 221.6822222222222, 11339: 92.95555555555555, 11344: 96.9036111111111, 11438: 4.238333333333333, 11465: 331.66833333333335, 11473: 7.341666666666667, 11600: 226.98638888888888, 11629: 1.7766666666666666, 11755: 162.1861111111111, 11778: 90.75777777777778, 11782: 22.5225, 11861: 428.1936111111111, 11870: 28.456666666666667, 12048: 52.649166666666666, 12068: 114.20694444444445, 12105: 150.98888888888888, 12171: 6.237777777777778, 12236: 14.10361111111111, 12243: 108.77277777777778, 12275: 238.38111111111112, 12454: 3.3219444444444446, 12488: 159.70888888888888, 12503: 118.24583333333334, 12533: 16.64138888888889, 12596: 244.65, 12658: 169.67388888888888, 12666: 121.5038888888889, 10614: 424.1069444444444, 10677: 1101.8083333333334, 10733: 1061.611111111111, 10797: 190.52166666666668, 10833: 1647.1744444444444, 10908: 318.2994444444444, 10938: 197.78805555555556, 
11160: 100.4961111111111, 11295: 5.5377777777777775, 11298: 103.95777777777778, 11311: 166.60333333333332, 11380: 242.82944444444445, 11425: 1704.951388888889, 11557: 126.84305555555555, 11558: 378.14916666666664, 11566: 355.1175, 11574: 94.60333333333334, 11585: 21.794166666666666, 11626: 7.69, 11641: 3.966388888888889, 11645: 225.54611111111112, 11655: 648.4683333333334, 11674: 367.5786111111111, 11714: 535.6975, 11808: 202.29166666666666, 11839: 15.023333333333333, 11980: 242.60555555555555, 12168: 107.13722222222222, 12179: 39.864444444444445, 12279: 0.34194444444444444, 12300: 384.49666666666667, 12315: 278.4025, 12319: 91.85805555555555, 12354: 139.34583333333333, 12357: 25.668333333333333, 12362: 6.973055555555556, 12438: 748.7480555555555, 12568: 145.74055555555555, 12583: 443.01, 12615: 140.13666666666666, 12622: 307.1113888888889, 12632: 124.68555555555555, 12655: 154.59777777777776, 12662: 109.13944444444445, 12669: 119.13333333333334, 12710: 195.68472222222223, 12872: 0.3477777777777778, 13031: 0.6086111111111111, 13035: 111.27, 13052: 704.8783333333333, 13093: 73.47027777777778, 13235: 171.1472222222222, 13286: 120.84305555555555, 13294: 102.20944444444444, 10271: 74.46277777777777, 10336: 817.4286111111111, 10401: 148.68805555555556, 10472: 1842.983611111111, 10524: 48.0675, 10546: 151.40194444444444, 10555: 103.79305555555555, 10571: 2840.2180555555556, 10638: 190.73666666666668, 10658: 147.04944444444445, 10709: 435.11333333333334, 10749: 2.7177777777777776, 10783: 118.73166666666667, 10787: 77.80194444444444, 10796: 144.35722222222222, 10840: 458.13944444444445, 10858: 35.80916666666667, 10898: 49.46805555555556, 10942: 178.61138888888888, 10951: 24.03, 11109: 11.324166666666667, 11186: 164.36694444444444, 11244: 385.9894444444444, 11310: 58.2075, 11421: 403.20916666666665, 11452: 9.731944444444444, 11580: 302.24555555555554, 11690: 15.386666666666667, 11832: 7.653333333333333, 11895: 351.8802777777778, 11993: 199.31305555555556, 12018: 
283.6455555555556, 12043: 469.9588888888889, 12089: 32.428333333333335, 12172: 262.6125, 12341: 15.058055555555555, 12369: 235.3175, 12370: 122.51166666666667, 12392: 8.895277777777778, 12434: 358.75416666666666, 12512: 25.017777777777777, 12639: 118.44194444444445, 12717: 60.36138888888889, 10443: 0.8761111111111111, 10449: 245.0225, 10510: 172.20222222222222, 10519: 23.238611111111112, 10545: 524.4988888888889, 10667: 173.16583333333332, 10828: 143.96277777777777, 10855: 125.40555555555555, 10912: 16.226666666666667, 10918: 99.6061111111111, 10973: 1.301111111111111, 11000: 17.710555555555555, 11149: 115.70555555555555, 11182: 64.94305555555556, 11280: 530.4777777777778, 11317: 56.49472222222222, 11462: 407.8522222222222, 11520: 2.4138888888888888, 11605: 78.8561111111111, 11642: 6.253888888888889, 11791: 3.721388888888889, 11807: 146.67583333333334, 11814: 49.64194444444445, 11828: 2.700833333333333, 11864: 105.13333333333334, 11871: 186.67666666666668, 12011: 644.32, 12025: 171.5125, 12103: 99.25083333333333, 12234: 71.73138888888889, 12346: 3.6030555555555557, 12361: 4.328333333333333, 12368: 26.761111111111113, 12399: 314.52444444444444, 12445: 96.05472222222222, 12467: 110.28277777777778, 12484: 283.2838888888889, 12617: 160.20388888888888, 12744: 2.9530555555555558, 12824: 8.017222222222221, 12828: 152.1002777777778, 12858: 27.74, 12865: 93.29611111111112, 12926: 36.365833333333335, 12934: 0.9008333333333334, 12961: 64.2136111111111, 12682: 31.05388888888889, 12849: 46.605, 12875: 126.57388888888889, 12922: 33.53805555555556, 13060: 16.516666666666666, 13084: 222.15833333333333, 13097: 318.2236111111111, 13111: 314.56083333333333, 13227: 156.6225, 13246: 358.8786111111111, 13425: 0.41055555555555556, 13466: 68.83027777777778, 13517: 56.80638888888889, 13168: 352.4558333333333, 13266: 2.1575, 13267: 9.014166666666666, 13367: 6.1925, 13452: 10.146944444444445, 13497: 29.910833333333333, 8781: 23.078333333333333, 8787: 1313.2925, 9051: 105.55388888888889, 
9064: 317.7736111111111, 9252: 121.23694444444445, 9348: 178.75833333333333, 9548: 144.9947222222222, 9641: 69.83583333333333, 9691: 221.13888888888889, 9716: 63.6975, 9736: 236.31194444444444, 9747: 230.54666666666665, 9829: 118.23777777777778, 9840: 325.8797222222222, 9846: 102.16416666666667, 9858: 23.06888888888889, 9915: 191.57083333333333, 9920: 6.023055555555556, 10018: 389.13944444444445, 10162: 205.22222222222223, 10208: 415.1533333333333, 10297: 216.36583333333334, 10364: 336.7472222222222, 10385: 2.1330555555555555, 10441: 21.845833333333335, 10583: 720.4105555555556, 10609: 99.55666666666667, 10613: 90.70916666666666, 10619: 341.0394444444444, 10777: 228.04916666666668, 10781: 193.21277777777777, 10861: 36.80555555555556, 10886: 60.93833333333333, 10916: 118.33111111111111, 10998: 220.9986111111111, 11013: 304.62111111111113, 12920: 110.82722222222222, 12932: 0.5305555555555556, 12990: 5.868611111111111, 13028: 384.14472222222224, 13032: 1448.016388888889, 13089: 54.86138888888889, 13107: 65.6675, 13132: 96.17777777777778, 13228: 59.42944444444444, 13245: 23.761388888888888, 13305: 4.800833333333333, 13320: 83.62777777777778, 13411: 34.52972222222222, 13413: 24.5175, 13514: 66.00944444444444, 13548: 3.393611111111111, 13579: 12.748333333333333, 3457: 228.81916666666666, 3728: 37.705555555555556, 4891: 355.39222222222224, 3445: 21.4025, 4452: 69.305, 12977: 164.32527777777779, 12980: 217.70527777777778, 13092: 14.879722222222222, 13134: 452.635, 13332: 189.5425, 13387: 115.23611111111111, 13525: 35.13, 5727: 35.705, 5800: 89.48416666666667, 12791: 247.6772222222222, 12817: 56.780833333333334, 12827: 124.65611111111112, 12880: 55.51722222222222, 12892: 3.944722222222222, 12927: 1.6130555555555555, 12956: 62.73583333333333, 13099: 402.4563888888889, 13256: 347.0563888888889, 13299: 18.849444444444444, 13324: 138.44361111111112, 13526: 11.4325, 13336: 77.61916666666667, 13381: 46.215555555555554, 13403: 181.47416666666666, 13471: 141.4947222222222, 13632: 
69.25833333333334, 3093: 290.8722222222222, 5498: 83.89583333333333, 7118: 138.3411111111111, 4042: 33.395833333333336, 4214: 59.46944444444444, 4265: 208.4572222222222, 3749: 14.193055555555556, 3999: 12.450833333333334, 11023: 100.38888888888889, 11038: 840.8463888888889, 11057: 168.81694444444443, 11102: 326.00055555555554, 11237: 246.11944444444444, 11258: 626.1077777777778, 11308: 145.61277777777778, 11326: 294.5177777777778, 11513: 39.07138888888889, 11747: 95.56444444444445, 11761: 480.75055555555554, 11772: 402.4002777777778, 11848: 196.00083333333333, 11885: 6.086111111111111, 11997: 58.7675, 12057: 32.11388888888889, 12077: 215.76, 12090: 621.1294444444444, 12215: 205.89527777777778, 12270: 934.075, 12278: 207.57083333333333, 12285: 823.3205555555555, 12299: 33.405833333333334, 12324: 96.08361111111111, 12386: 24.369722222222222, 12404: 52.42, 12516: 24.976944444444445, 12603: 15.461944444444445, 12614: 12.471944444444444, 12626: 44.606944444444444, 12634: 79.82916666666667, 12663: 685.1758333333333, 12672: 11.402777777777779, 12879: 105.11777777777777, 12914: 22.828611111111112, 12928: 157.58777777777777, 12946: 140.32472222222222, 13044: 483.4555555555556, 13053: 98.92722222222223, 13074: 85.05833333333334, 13081: 69.58638888888889, 13091: 167.86833333333334, 13143: 68.34555555555555, 13217: 171.61111111111111, 13353: 3.0208333333333335, 8504: 104.37972222222223, 8957: 42.56333333333333, 9172: 95.59916666666666, 9477: 437.1327777777778, 9496: 2.9027777777777777, 9565: 10.2575, 9577: 479.81583333333333, 9629: 179.95666666666668, 9730: 74.75194444444445, 9845: 190.2525, 9854: 3.2794444444444446, 9899: 22.814166666666665, 10020: 3.9611111111111112, 10077: 1.5047222222222223, 5780: 2.141388888888889, 3431: 24.6675, 3508: 56.82805555555556, 4123: 165.99194444444444, 6364: 2.3705555555555557, 7088: 200.4686111111111, 9015: 3.8336111111111113, 9188: 278.24583333333334, 9193: 1.9302777777777778, 9435: 447.6877777777778, 9473: 62.86361111111111, 9600: 
20.329444444444444, 9640: 149.26222222222222, 9650: 144.01472222222222, 9748: 2.141111111111111, 9855: 75.72722222222222, 10027: 425.2997222222222, 10030: 221.23472222222222, 10098: 103.25472222222223, 10104: 108.9738888888889, 10118: 380.7947222222222, 10131: 4.520833333333333, 10178: 2628.2361111111113, 10263: 53.35527777777778, 10434: 402.14166666666665, 10469: 162.4475, 10505: 146.4711111111111, 10623: 235.2161111111111, 10659: 267.4433333333333, 10679: 140.04722222222222, 10757: 5.973888888888889, 10770: 357.4941666666667, 10831: 279.9425, 10321: 142.67, 10347: 1331.568888888889, 10361: 300.56805555555553, 10378: 213.32916666666668, 10416: 61.058055555555555, 10634: 31.28, 10635: 50.14416666666666, 10694: 8.060555555555556, 10704: 121.86027777777778, 10720: 245.90083333333334, 10773: 134.43583333333333, 10849: 156.91444444444446, 10890: 44.54361111111111, 10971: 135.16194444444446, 11001: 0.48055555555555557, 11014: 380.9583333333333, 11134: 0.5108333333333334, 11172: 242.21138888888888, 11193: 690.1769444444444, 11442: 182.7436111111111, 11464: 41.880833333333335, 11589: 172.96916666666667, 11594: 88.88166666666666, 11647: 73.86, 11678: 2.7425, 11763: 27.84861111111111, 11865: 47.986666666666665, 11887: 213.38805555555555, 11914: 18.725833333333334, 11920: 665.8247222222222, 11940: 478.67833333333334, 12033: 72.17194444444445, 12058: 215.62722222222223, 12060: 168.40444444444444, 12106: 7.647222222222222, 12131: 65.04666666666667, 12230: 1.3755555555555556, 12265: 190.9936111111111, 12309: 672.8875, 12323: 49.73722222222222, 12421: 11.893611111111111, 12475: 3.891111111111111, 12565: 264.3063888888889, 12573: 294.7961111111111, 12591: 511.91777777777776, 12740: 5.221111111111111, 12745: 1.0225, 5656: 144.74916666666667, 6208: 2.182777777777778, 7163: 248.7425, 7520: 1000.7547222222222, 4837: 192.99972222222223, 6919: 560.6516666666666, 7660: 2.361666666666667, 7941: 239.66694444444445, 8358: 243.1186111111111, 8378: 4.407222222222222, 8605: 
1453.5113888888889, 8820: 592.3063888888889, 9340: 663.2280555555556, 9500: 128.70138888888889, 9582: 312.63972222222225, 9628: 0.23333333333333334, 9713: 308.56444444444446, 9720: 83.35166666666667, 9757: 50.52611111111111, 9801: 169.05277777777778, 9804: 87.11277777777778, 9873: 4.476111111111111, 10063: 118.98777777777778, 10155: 549.4475, 10189: 1.0194444444444444, 10224: 146.81861111111112, 10306: 166.5088888888889, 12965: 148.9438888888889, 12975: 149.97916666666666, 13182: 72.50944444444444, 13289: 6.140555555555555, 13290: 62.98416666666667, 13366: 260.94166666666666, 13507: 74.015, 8514: 771.9983333333333, 9395: 13.150277777777777, 9442: 119.95416666666667, 9564: 345.1986111111111, 9607: 566.0705555555555, 9608: 263.94972222222225, 9698: 1.1725, 9710: 40.654444444444444, 10035: 127.14305555555555, 10091: 148.73805555555555, 10292: 443.7608333333333, 9175: 589.3063888888889, 9514: 4.045277777777778, 9530: 96.38111111111111, 9572: 30.6175, 9579: 5.143611111111111, 9587: 642.1955555555555, 9592: 352.49583333333334, 9655: 328.06416666666667, 9694: 105.64444444444445, 10388: 82.32916666666667, 10626: 271.4555555555556, 10652: 340.69666666666666, 10762: 3.1422222222222222, 10891: 147.42805555555555, 10964: 185.33138888888888, 11106: 84.98555555555555, 11155: 79.03166666666667, 11156: 81.49888888888889, 11214: 96.7461111111111, 11257: 336.95166666666665, 11277: 147.94027777777777, 11312: 101.61638888888889, 11315: 286.64166666666665, 11338: 278.2322222222222, 11518: 1157.2594444444444, 11519: 335.30555555555554, 11523: 237.41833333333332, 11560: 161.46305555555554, 11777: 44.1075, 11816: 99.27472222222222, 11919: 171.18722222222223, 11935: 48.80583333333333, 11985: 2.2475, 12076: 183.3475, 12112: 5.913611111111111, 12202: 71.6836111111111, 12253: 321.7825, 12313: 30.951944444444443, 12331: 95.86694444444444, 12344: 185.1825, 12352: 10.01388888888889, 12371: 13.742222222222223, 12385: 164.41694444444445, 12474: 603.1727777777778, 12494: 233.40527777777777, 12543: 
335.9144444444444, 12618: 735.2438888888889, 12640: 286.1791666666667, 12743: 21.11277777777778, 12751: 6.044444444444444, 12777: 52.16222222222222, 12848: 118.98777777777778, 8690: 441.4161111111111, 9360: 24.195, 9537: 42.885, 9590: 183.5225, 9593: 449.7425, 9636: 39.4475, 9683: 41.23694444444445, 9700: 10.789166666666667, 9712: 94.28055555555555, 9728: 85.07444444444444, 9795: 304.18611111111113, 9849: 551.21, 9901: 11.201666666666666, 10002: 278.7961111111111, 10044: 11.299444444444445, 10052: 127.18111111111111, 10120: 43.89666666666667, 10239: 37.82138888888889, 10310: 119.03138888888888, 10390: 11.606666666666667, 10445: 47.5925, 10458: 344.4113888888889, 10476: 329.1175, 10482: 312.02166666666665, 10641: 363.8666666666667, 10741: 336.49333333333334, 10914: 4.904166666666667, 10955: 299.6822222222222, 10956: 171.40916666666666, 10959: 54.71111111111111, 10961: 123.87333333333333, 11007: 1887.7627777777777, 11039: 77.71916666666667, 11205: 243.48611111111111, 11307: 176.22555555555556, 11475: 131.7236111111111, 11497: 3.915277777777778, 11506: 859.4436111111111, 11516: 2.278888888888889, 11531: 3.8466666666666667, 11676: 8.894166666666667, 11692: 189.47416666666666, 11716: 0.12472222222222222, 11746: 342.7547222222222, 11757: 150.05027777777778, 11820: 69.5725, 11897: 62.88194444444444, 11913: 275.13444444444445, 11941: 94.30111111111111, 11996: 4.134444444444444, 12022: 123.7275, 12023: 46.908055555555556, 12047: 50.844722222222224, 12085: 81.3475, 12119: 96.17611111111111, 12122: 103.00722222222223, 12226: 71.00194444444445, 12322: 71.56777777777778, 12409: 83.18111111111111, 12441: 549.8130555555556, 12469: 139.0036111111111, 12495: 121.30722222222222, 12594: 213.0525, 12613: 1163.381111111111, 12651: 295.1380555555556, 12748: 263.7347222222222, 12754: 149.94944444444445, 12797: 29.13, 11063: 6.210833333333333, 11077: 365.21305555555557, 11115: 234.2602777777778, 11133: 2.9808333333333334, 11159: 505.19055555555553, 11525: 163.33833333333334, 11603: 
65.25638888888889, 12001: 21.433888888888887, 12016: 213.05444444444444, 12114: 378.3066666666667, 12338: 64.50111111111111, 12351: 142.1977777777778, 12459: 198.16805555555555, 12511: 86.72722222222222, 12644: 35.32361111111111, 12739: 283.7897222222222, 12762: 29.216944444444444, 12798: 8.069166666666666, 12910: 495.86805555555554, 12941: 280.27444444444444, 13003: 162.07638888888889, 13163: 4.676666666666667, 13251: 5.933055555555556, 13339: 414.93694444444446, 13379: 253.27194444444444, 13434: 20.845555555555556, 13458: 222.865, 13494: 48.16305555555556, 13615: 27.13, 10334: 53.20527777777778, 10383: 157.35722222222222, 10464: 223.2675, 10532: 29.77166666666667, 10574: 453.5461111111111, 10607: 8.092222222222222, 10714: 246.83277777777778, 10752: 436.1286111111111, 10786: 30.85972222222222, 10800: 331.85083333333336, 10883: 226.8788888888889, 10897: 114.29333333333334, 10995: 89.81222222222222, 11072: 215.50027777777777, 11093: 73.68694444444445, 11161: 143.3202777777778, 11208: 169.98111111111112, 11216: 632.5825, 11266: 6.2027777777777775, 11287: 235.44833333333332, 11304: 20.475555555555555, 11533: 69.8463888888889, 11851: 0.5844444444444444, 11972: 2.118611111111111, 12041: 405.25638888888886, 12080: 163.07694444444445, 12082: 281.8, 12180: 38.34444444444444, 12192: 74.15333333333334, 12247: 116.49833333333333, 12417: 166.3538888888889, 12493: 3.0494444444444446, 12519: 253.25722222222223, 12584: 77.50916666666667, 12619: 1.341388888888889, 12721: 1655.6461111111112, 12723: 554.9266666666666, 12779: 288.4766666666667, 9816: 0.3005555555555556, 9867: 174.84194444444444, 9931: 100.53083333333333, 10115: 324.0308333333333, 10144: 187.16916666666665, 10331: 154.15416666666667, 10477: 343.4025, 10479: 3.2827777777777776, 10520: 318.34194444444444, 10572: -0.0002777777777777778, 10573: 1.4041666666666666, 10666: 313.1438888888889, 10804: 90.98361111111112, 10993: 7.701666666666667, 11074: 28.750833333333333, 11095: 130.32722222222222, 11300: 479.66, 11432: 
165.60111111111112, 11488: 48.07805555555556, 11499: 0.13194444444444445, 11526: 437.41, 11550: 5.165, 11637: 74.76166666666667, 11638: 1.4066666666666667, 11855: 314.7613888888889, 11862: 118.37333333333333, 11886: 265.5233333333333, 11950: 121.44638888888889, 11962: 5.131388888888889, 12036: 76.03861111111111, 12126: 57.90222222222222, 12133: 72.13444444444444, 12183: 139.64944444444444, 12214: 23.44361111111111, 12271: 221.76277777777779, 12320: 100.94472222222223, 12325: 86.90722222222222, 12327: 54.96472222222222, 12882: 197.97083333333333, 12897: 141.92944444444444, 12898: 25.762777777777778, 12911: 447.7005555555556, 12917: 402.24583333333334, 13010: 0.5641666666666667, 13124: 240.21305555555554, 13221: 3.0130555555555554, 13264: 0.06444444444444444, 13279: 714.3644444444444, 13386: 1.3841666666666668, 13489: 10.309722222222222, 13510: 31.2425, 13112: 319.0397222222222, 13123: 133.7736111111111, 13252: 264.7969444444444, 13331: 91.3825, 13405: 4.049722222222222, 13456: 236.04944444444445, 12784: 194.92111111111112, 12811: 2.435, 12821: 102.38166666666666, 12891: 54.32694444444444, 12996: 98.48166666666667, 13009: 16.886388888888888, 13018: 159.91361111111112, 13050: 225.24194444444444, 13095: 87.00944444444444, 13106: 15.115833333333333, 13191: 30.700277777777778, 13359: 178.02027777777778, 13428: 295.9266666666667, 13449: 103.07638888888889, 13481: 149.90805555555556, 12487: 20.066111111111113, 12508: 93.62944444444445, 12657: 946.9763888888889, 12707: 1.6488888888888888, 12822: 46.473333333333336, 12885: 1.4077777777777778, 12906: 187.7025, 12929: 149.05277777777778, 13129: 167.54361111111112, 13382: 2.6277777777777778, 13402: 201.91416666666666, 13455: 72.04277777777777, 13515: 21.72222222222222, 4699: 96.55277777777778, 4972: 51.53194444444444, 5578: 344.18805555555554, 6052: 138.5247222222222, 6352: 131.82, 8210: 1157.4519444444445, 8748: 195.60944444444445, 8763: 147.19194444444443, 8955: 402.24333333333334, 9526: 52.59916666666667, 9535: 
8.812777777777777, 9619: 146.08444444444444, 9679: 78.63, 9702: 20.649166666666666, 9706: 72.11194444444445, 9921: 483.34777777777776, 9939: 25.704722222222223, 9969: 132.5725, 9982: 69.29861111111111, 10060: 0.6841666666666667, 10110: 3.345, 10164: 183.9988888888889, 10187: 146.0125, 10300: 144.7722222222222, 10369: 330.65972222222223, 10394: 1.1333333333333333, 10457: 324.81166666666667, 10484: 330.65944444444443, 10489: 446.1380555555556, 10522: 203.83666666666667, 10588: 640.7122222222222, 10592: 327.2608333333333, 10716: 500.50138888888887, 10756: 152.68027777777777, 10815: 175.46555555555557, 10836: 118.95805555555556, 10894: 119.75055555555555, 10903: 394.25055555555554, 10968: 4.126111111111111, 11006: 3.2080555555555557, 11020: 138.815, 11087: 29.926388888888887, 11178: 2.611388888888889, 11286: 915.3666666666667, 11396: 74.38944444444445, 11447: 26.294166666666666, 11493: 194.04805555555555, 11496: 123.02583333333334, 11532: 46.948055555555555, 11578: 772.6672222222222, 11581: 106.09388888888888, 11648: 97.55805555555555, 11664: 5.374722222222222, 11715: 101.04833333333333, 11856: 103.33527777777778, 11953: 140.71083333333334, 11981: 117.91805555555555, 12037: 2.9277777777777776, 12101: 167.36083333333335, 12164: 123.5963888888889, 12334: 74.14277777777778, 12339: 50.909444444444446, 12353: 33.38805555555555, 12355: 49.46888888888889, 12394: 69.37694444444445, 8453: 222.15805555555556, 8458: 1347.945, 8749: 50.71944444444444, 8825: 132.73777777777778, 9264: 1042.4788888888888, 9332: 43.39555555555555, 9448: 114.22222222222223, 9534: 164.315, 9665: 52.243611111111115, 9780: 351.4611111111111, 9784: 350.0469444444444, 9796: 84.35777777777778, 9818: 339.68361111111113, 9978: 15.7525, 10017: 46.13027777777778, 10058: 68.98777777777778, 10314: 340.10527777777776, 10346: 271.6711111111111, 10353: 32.68944444444445, 10393: 50.14138888888889, 10563: 1372.7625, 10578: 149.30138888888888, 10584: 412.3777777777778, 10742: 148.71138888888888, 10982: 
99.33222222222223, 10987: 117.37166666666667, 11018: 124.70444444444445, 11148: 71.565, 11294: 120.76527777777778, 11332: 4202.185, 11359: 56.40083333333333, 11382: 115.08222222222223, 11535: 3.4291666666666667, 11538: 134.11138888888888, 11732: 564.1833333333333, 11934: 129.71583333333334, 11974: 174.35944444444445, 11989: 40.9075, 12258: 383.62111111111113, 12455: 24.886944444444445, 12529: 248.1797222222222, 12579: 262.2288888888889, 12659: 63.9925, 12750: 73.07694444444445, 12825: 17.826944444444443, 12890: 118.39694444444444, 12919: 81.8175, 12937: 53.727222222222224, 13033: 3.4705555555555554, 13141: 2.540277777777778, 13209: 5.676944444444445, 13226: 108.17277777777778, 13378: 4.3225, 13389: 309.2963888888889, 13436: 113.17333333333333, 13478: 218.48333333333332, 12431: 118.63194444444444, 12546: 4.284166666666667, 12563: 8.050833333333333, 12602: 45.95055555555555, 12611: 268.2125, 12625: 37.14333333333333, 12641: 8.369444444444444, 12652: 30.636944444444445, 12763: 43.27333333333333, 12894: 221.31833333333333, 13086: 185.64916666666667, 13183: 15.337777777777777, 13216: 5.792777777777777, 13233: 686.5669444444444, 13254: 21.01222222222222, 13306: 5.848888888888889, 13310: 0.2663888888888889, 13351: 150.05416666666667, 13355: 55.10888888888889, 13356: 47.21194444444444, 13376: 145.63583333333332, 13418: 0.5247222222222222, 13424: 116.67083333333333, 13439: 537.8541666666666, 4761: 22.17111111111111, 6682: 29.59111111111111, 9206: 285.7030555555556, 9544: 58.71888888888889, 9585: 140.89111111111112, 9617: 37.06722222222222, 9637: 140.8527777777778, 9660: 80.87833333333333, 9707: 431.18416666666667, 9974: 48.665277777777774, 9998: 85.54722222222222, 10033: 2.698888888888889, 10078: 68.3, 10107: 100.625, 10412: 120.72305555555556, 10542: 294.505, 10650: 318.6011111111111, 10654: 212.5586111111111, 10814: 5.163888888888889, 10901: 83.0886111111111, 10978: 31.13222222222222, 11248: 43.39194444444445, 11284: 32.93527777777778, 11334: 1.7447222222222223, 11345: 
45.43972222222222, 11468: 5.7225, 11476: 152.81027777777777, 11491: 147.53027777777777, 11541: 244.4111111111111, 11555: 9.566944444444445, 11587: 43.11, 11617: 337.29527777777776, 11640: 215.49527777777777, 11669: 53.41277777777778, 11687: 124.13027777777778, 11758: 100.34944444444444, 11762: 0.2588888888888889, 11923: 193.16694444444445, 11943: 186.12777777777777, 12044: 349.01944444444445, 12166: 108.14083333333333, 12167: 123.69083333333333, 12185: 1.9966666666666666, 12190: 334.3275, 12401: 58.531666666666666, 12415: 69.08055555555555, 12466: 130.0572222222222, 12521: 71.69805555555556, 12554: 460.73805555555555, 12581: 347.5877777777778, 12637: 49.30972222222222, 12727: 102.57861111111112, 12803: 1037.1522222222222, 12860: 2.9138888888888888, 12909: 210.08916666666667, 12916: 2.2327777777777778, 12931: 22.595, 12988: 188.50444444444443, 13066: 207.92138888888888, 13152: 116.66166666666666, 13296: 210.3947222222222, 13373: 239.40472222222223, 13421: 134.4438888888889, 13422: 7.5675, 13448: 684.0313888888888, 13483: 3.4702777777777776, 13518: 76.9636111111111, 5606: 880.3475, 7415: 676.7766666666666, 8174: 329.9791666666667, 8277: 466.3911111111111, 8398: 119.91666666666667, 8963: 434.53194444444443, 9204: 528.1563888888888, 9313: 87.97416666666666, 9336: 4654.768055555555, 9386: 61.75888888888889, 9472: 59.31138888888889, 9597: 166.90305555555557, 9622: 217.70916666666668, 9705: 30.261111111111113, 9734: 958.4508333333333, 9755: 60.55861111111111, 9762: 6.176111111111111, 9971: 238.38472222222222, 10066: 139.80555555555554, 10114: 82.23166666666667, 10202: 93.0038888888889, 10366: 550.1838888888889, 10471: 518.1830555555556, 10527: 50.76416666666667, 10531: 49.82972222222222, 10629: 63.54055555555556, 10753: 188.25972222222222, 10799: 136.86666666666667, 10859: 109.01972222222223, 10944: 94.54777777777778, 11069: 187.9575, 11080: 242.26527777777778, 11121: 25.34722222222222, 11138: 99.10055555555556, 11249: 549.4180555555556, 11362: 53.93055555555556, 11454: 
161.9261111111111, 11472: 13.056666666666667, 11505: 7.001111111111111, 11527: 203.6825, 11559: 884.5269444444444, 11666: 52.806111111111115, 11684: 18.753888888888888, 11697: 193.46555555555557, 11770: 294.74333333333334, 11911: 407.7963888888889, 11916: 382.2238888888889, 11971: 113.94055555555556, 11984: 69.1463888888889, 12026: 6.051111111111111, 12064: 2.23, 12135: 5.725833333333333, 12136: 233.10444444444445, 12175: 1.896111111111111, 12216: 121.71694444444445, 12244: 9.889722222222222, 12277: 837.3716666666667, 12317: 68.97777777777777, 12395: 214.1275, 12660: 186.5713888888889, 12772: 239.86138888888888, 12826: 201.5377777777778, 12853: 74.10472222222222, 12870: 123.38944444444445, 12877: 39.44138888888889, 12989: 15.311666666666667, 13046: 69.04916666666666, 13062: 73.33194444444445, 13116: 269.16777777777776, 13130: 46.48694444444445, 13205: 167.91722222222222, 13298: 289.58472222222224, 13325: 139.76611111111112, 13404: 261.04805555555555, 13492: 78.92444444444445}

In [0]:
import math 
## Feature Set

## Chart Features (item IDs are looked up as strings in item_dict)
chart_feats = [1001, 1002, 1003, 1004, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016]
for chart_item in chart_feats:
  print(item_dict[str(chart_item)])

# Surgery Vital Signs: every item present in the table, most frequently recorded first
surgery_feats = surgery_vital_signs['ITEMID'].value_counts().index.tolist()
for surgery_item in surgery_feats:
  print(item_dict[surgery_item])

# Lab features
lab_feats = [
    5225, 5097, 5141, 5129, 5257, 5114, 5113, 5115, 5132, 5136,
    5226, 5230, 5218, 5224, 5212, 5033, 5041, 5223, 5215, 5174,
    5111, 6317, 5094, 5492, 5002, 5075, 5237, 5249, 5235, 5239,
    5227, 5026, 5031, 5024, 6085,
]
for lab_item in lab_feats:
  print(lab_item_dict[lab_item])


In [0]:
def get_feature_name(idx):
  """Human-readable name of column `idx` of an hourly model input.

  Column layout: lab features, chart features, surgery vital signs,
  then two gender columns, then age.
  """
  lab_end = len(lab_feats)
  chart_end = lab_end + len(chart_feats)
  surgery_end = chart_end + len(surgery_feats)

  if idx >= surgery_end + 2:
    return 'age'
  if idx >= surgery_end:
    return 'gender'
  if idx >= chart_end:
    return item_dict[surgery_feats[idx - chart_end]]
  if idx >= lab_end:
    return item_dict[str(chart_feats[idx - lab_end])]
  return lab_item_dict[lab_feats[idx]]

def get_feature_name_flattened(idx):
  """Name of column `idx` of a flattened RF design matrix.

  The matrix holds WINDOW_SIZE hourly blocks of (lab, chart, surgery) features,
  hour-major, followed by the static block (two gender columns, then age).
  """
  per_hour = len(lab_feats) + len(chart_feats) + len(surgery_feats)
  hours_in, offset = divmod(idx, per_hour)

  if hours_in == WINDOW_SIZE:
    return 'gender' if offset < 2 else 'age'
  # offset always lies inside one hourly block here, so the time-series
  # branches of get_feature_name give the name directly.
  return get_feature_name(offset)

In [0]:
def key_fn(tup):
  """Sort key: magnitude of a tuple's first element."""
  first = tup[0]
  return abs(first)

def sort_importance(coefficients, feat_name_fn):
  """Pair row 0 of `coefficients` with feature names, largest |coefficient| first.

  Args:
    coefficients: 2-D array; only row 0 is used (one coefficient per column).
    feat_name_fn: maps a column index to its feature name.
  Returns:
    list of (coefficient, name) tuples sorted by descending magnitude.
  """
  print(coefficients.shape)
  n_columns = coefficients.shape[1]
  ranked = [(coefficients[0, col], feat_name_fn(col)) for col in range(n_columns)]
  ranked.sort(key=key_fn, reverse=True)
  return ranked

In [0]:
def convert_float_tensor(X):
  """Copy array-like `X` into a float32 torch tensor."""
  return torch.tensor(X).float()


def _module_device(model):
  """Device of `model`'s first parameter; CPU if it has no parameters."""
  try:
    return next(model.parameters()).device
  except (AttributeError, StopIteration):
    return torch.device('cpu')


def perturb_features(model, X, feature_range=None):
  """Rank input features by how much Gaussian noise on each one moves the output.

  For every feature index in [feature_range[0], feature_range[1]), N(0, 0.2)
  noise is added to X[:, :, ind] (fresh noise per sample and hour). The effect
  is the RMS difference between the perturbed and unperturbed model outputs.

  Args:
    model: callable taking a float tensor shaped like X; when it is an
      nn.Module it is put in eval mode for the analysis and its previous mode
      is restored afterwards.
    X: numpy array (N, T, F). It is modified temporarily and restored exactly.
    feature_range: (start, stop) feature indices; defaults to all F features.
  Returns:
    list of (effect, variable_name, ind) sorted by descending |effect|, where
    effect is a 0-d CPU tensor.
  """
  if feature_range is None:
    feature_range = (0, X.shape[2])

  print(feature_range)
  model_device = _module_device(model)
  is_module = isinstance(model, nn.Module)
  was_training = model.training if is_module else None
  if is_module:
    # With dropout active the "effect" would also measure dropout noise.
    model.eval()

  perturb_effects = []
  try:
    # Pure inference: no autograd graphs need to be kept alive.
    with torch.no_grad():
      orig_out = model(convert_float_tensor(X).to(model_device))
      for ind in range(feature_range[0], feature_range[1]):
        variable_name = get_feature_name(ind)
        perturbation = np.random.normal(0.0, 0.2, size=X.shape[:2])
        original_column = np.copy(X[:, :, ind])
        X[:, :, ind] = original_column + perturbation
        try:
          new_out = model(convert_float_tensor(X).to(model_device))
        finally:
          # Restore from a copy rather than subtracting, so X is bit-identical.
          X[:, :, ind] = original_column
        effect = (((orig_out - new_out) ** 2).mean() ** 0.5).cpu()
        print(f'Variable {ind+1} name ({variable_name}), perturbation effect: {effect:.4f}')
        perturb_effects.append((effect, variable_name, ind))
  finally:
    if is_module:
      model.train(was_training)

  return sorted(perturb_effects, key=key_fn, reverse=True)

We use these dictionaries to index into the tensors that follow, i.e. `chart_X[patient_index_of[subject_id]]` is what you want, not `chart_X[subject_id]`. The same applies to item IDs (e.g. `chart_index_of[item_id]`).

In [0]:
# More Helper Dicts: position of each item ID / subject ID along the matching
# axis of the feature tensors built below.
chart_index_of = {item: pos for pos, item in enumerate(chart_feats)}
lab_index_of = {item: pos for pos, item in enumerate(lab_feats)}
surgery_index_of = {item: pos for pos, item in enumerate(surgery_feats)}


print(chart_index_of)
print(lab_index_of)
print(surgery_index_of)

# The following cells index patient_set positionally (patient_set[i]), which
# fails if it is still a set. Freeze its order into a list once, here, so those
# positions always agree with patient_index_of. list() of a list is a no-op,
# and list() of a set keeps the iteration order the index below is built from.
patient_set = list(patient_set)
patient_index_of = {subject: pos for pos, subject in enumerate(patient_set)}
  
  

# Feature / Label Definition Section

Let's only sample up to 100 hours. That seems reasonable: the whole stay of ~75% of patients is captured within this window.

In [0]:
MAX_HOURS = 100   # only HOURS_IN in [0, MAX_HOURS) are sampled
WINDOW_SIZE = 24  # hours of history in each input window
GAP_TIME = 6      # hours between the end of the input window and the prediction window
PRED_SIZE = 6     # length of the prediction window, in hours

In [0]:
def is_number(s):
    """Return True if `s` converts to a finite float, else False.

    Used to decide whether an event VALUE may be added to the hourly sums.
    NaN (e.g. a missing cell in a pandas frame) and +/-inf convert with float()
    but would poison those sums and the global means, so they are rejected.
    None and other non-convertible objects return False instead of raising.
    """
    try:
        return math.isfinite(float(s))
    except (TypeError, ValueError, OverflowError):
        return False

Here, the per-hour label value can be a min, max, or mean depending on the task you want. Change the code beneath it to fit the task, and remember to change the loop to go over the events of the item ID you want. (This template was written for the **minimum WBC count** measured over an hour-long window; in this notebook the cell below instead marks the sepsis onset hour taken from `SEPTIC`.)

**feat_cnt** marks, for each patient and hour, whether the event occurred (here: 1 at the sepsis onset hour). It is used to check whether the event falls inside the prediction interval.

**feat_agg** is the label we use in prediction. `feat_agg[i][j]` is the aggregate for patient i when our input window starts at hour j. In other words, it is the aggregate over hours [j + WINDOW_SIZE + GAP_TIME, j + WINDOW_SIZE + GAP_TIME + PRED_SIZE). Here it is the sum of **feat_cnt** over that interval, so it is positive exactly when sepsis onset falls inside the prediction window. Change this if you want max/mean/whatever.

**window_cnt** is the number of times the test was taken in the desired window. The window is [j + WINDOW_SIZE + GAP_TIME, j + WINDOW_SIZE + GAP_TIME + PRED_SIZE).

This should take a LONG time, like ~30 minutes. In the meantime, you can probably open a second colab and tune some hyperparameters on Mortality/LOS.

In [0]:
subjects_to_remove = set()

# feat_cnt[i][h] = 1 when patient i's sepsis onset (from SEPTIC) is at hour h.
# feat_agg[i][j] = number of onset markers in the prediction window of the
#                  input window starting at hour j.
feat_agg = np.zeros((len(patient_set), MAX_HOURS))
feat_cnt = np.zeros((len(patient_set), MAX_HOURS))

pred_offset = WINDOW_SIZE + GAP_TIME   # first prediction hour relative to j
num_windows = MAX_HOURS - WINDOW_SIZE - GAP_TIME - PRED_SIZE

for i, subject in enumerate(patient_set):
  if subject not in SEPTIC:
    continue
  onset_hour = int(SEPTIC[subject])
  # A negative onset hour would otherwise wrap around (Python negative index)
  # and mark an hour near MAX_HOURS.
  if 0 <= onset_hour < MAX_HOURS:
    feat_cnt[i, onset_hour] = 1
  for j in range(num_windows):
    feat_agg[i, j] = feat_cnt[i, j + pred_offset : j + pred_offset + PRED_SIZE].sum()

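This cell builds the hourly feature tensors. For each patient, hour (< MAX_HOURS) and lab / chart / surgery-vital-sign item, it accumulates the sum (`*_X`) and the number (`*_Xcnt`) of numeric measurements. Subjects with measurements at negative hours are collected in `subjects_to_remove`.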

In [0]:
chart_X = np.zeros((len(patient_set), MAX_HOURS, len(chart_feats)))
chart_Xcnt = np.zeros((len(patient_set), MAX_HOURS, len(chart_feats)))
lab_X = np.zeros((len(patient_set), MAX_HOURS, len(lab_feats)))
lab_Xcnt = np.zeros((len(patient_set), MAX_HOURS, len(lab_feats)))
surgery_X = np.zeros((len(patient_set), MAX_HOURS, len(surgery_feats)))
surgery_Xcnt = np.zeros((len(patient_set), MAX_HOURS, len(surgery_feats)))


def _accumulate_hourly(events, feats, item_index_of, sums, counts, value_col):
  """Add each in-scope event's value and a count of 1 into sums/counts[patient, hour, item].

  Rows of `events` for items in `feats` with HOURS_IN < MAX_HOURS are visited.
  Rows with negative HOURS_IN only flag their subject in subjects_to_remove.
  Other rows are accumulated when their VALUE passes is_number; the amount
  added is float(row.<value_col>).
  """
  # One combined mask: chaining a second mask built on the unfiltered frame makes
  # pandas reindex the boolean key (with a warning).
  in_scope = (events['HOURS_IN'] < MAX_HOURS) & events['ITEMID'].isin(feats)
  for row in events[in_scope].itertuples(index=False):
    if row.HOURS_IN < 0:
      subjects_to_remove.add(row.SUBJECT_ID)
    elif is_number(row.VALUE):
      cell = (patient_index_of[row.SUBJECT_ID], int(row.HOURS_IN), item_index_of[row.ITEMID])
      # float() so a numeric VALUE stored as text (surgery table) can be summed.
      sums[cell] += float(getattr(row, value_col))
      counts[cell] += 1


_accumulate_hourly(lab_events, lab_feats, lab_index_of, lab_X, lab_Xcnt, 'VALUENUM')
_accumulate_hourly(surgery_vital_signs, surgery_feats, surgery_index_of, surgery_X, surgery_Xcnt, 'VALUE')
_accumulate_hourly(chartevents, chart_feats, chart_index_of, chart_X, chart_Xcnt, 'VALUENUM')




Here I have implemented simple forward/backward imputation. If time permits, we can try the other methods described in https://www.nature.com/articles/s41598-018-24271-9

`global_mean` is the mean of each feature over all time points and all patients. If a patient has no occurrences of a feature at any time point, the feature is replaced by the global mean. Otherwise, we propagate values forward/backward to replace missing values (by default, a value ≤ 0 is treated as missing).

In [0]:
# Missing Data Imputation

# Forward/Backward Imputation

# Compute Global means first: per-feature mean of every observed measurement
# (all patients, all hours). A feature that is never observed gets 0.0 instead
# of 0/0 = NaN, which would otherwise propagate into the imputed model inputs.

def _global_mean(values, counts):
  """Per-feature sum(values) / sum(counts) over axes (patient, hour); 0.0 where count is 0."""
  totals = values.sum(axis=(0, 1))
  n_obs = counts.sum(axis=(0, 1))
  return np.divide(totals, n_obs, out=np.zeros(totals.shape, dtype=float), where=n_obs > 0)

global_chart_mean = _global_mean(chart_X, chart_Xcnt)
global_lab_mean = _global_mean(lab_X, lab_Xcnt)
global_surgery_mean = _global_mean(surgery_X, surgery_Xcnt)


def forward_backward_impute(feats, global_mean, observed=None):
  """Fill missing entries of one patient's (hours, num_feats) feature matrix.

  For each feature column, the previous observed value is carried forward.
  The next observed value is then carried backward to fill the leading gap.
  A column with no observed value at all falls back to global_mean[j].

  INPUTS:
    feats -- (num_hours, num_feats), normally (MAX_HOURS, num_feats); not modified
    global_mean -- (num_feats)
    observed -- optional boolean array shaped like feats that marks real
                measurements. When omitted, any value <= 0 counts as missing,
                which also discards genuine zero/negative measurements. Pass
                e.g. `chart_Xcnt[i] > 0` to keep those.
  OUTPUTS:
    ret -- (num_hours, num_feats) imputed copy of feats
  """
  ret = np.array(feats, copy=True)
  if observed is None:
    missing = ret <= 0
  else:
    missing = ~np.asarray(observed, dtype=bool)

  num_hours = ret.shape[0]
  for j in range(ret.shape[1]):
    col = ret[:, j]                # view: writes land in ret
    gap = missing[:, j].copy()     # still-unfilled positions of this column
    # Forward pass: copy the preceding observed (or already filled) value.
    for i in range(1, num_hours):
      if gap[i] and not gap[i - 1]:
        col[i] = col[i - 1]
        gap[i] = False
    # Backward pass: fills the hours before the first observation.
    for i in range(num_hours - 2, -1, -1):
      if gap[i] and not gap[i + 1]:
        col[i] = col[i + 1]
        gap[i] = False
    # Never observed for this patient: use the population mean.
    col[gap] = global_mean[j]
  return ret



In [0]:
# Set up X, Y 


# Set up labels

# The positional order of patient_set defines the first axis of every tensor below.
patient_set = list(patient_set)

gender_one_hot = np.zeros((len(patient_set), 2))
age_vec = np.zeros((len(patient_set), 1))
for _, row in patients.iterrows():
  # patient_index_of is keyed by exactly the subjects in patient_set; a dict
  # lookup is O(1), whereas `in` on the list scanned every patient per row.
  if row.SUBJECT_ID not in patient_index_of:
    continue
  p_idx = patient_index_of[row.SUBJECT_ID]
  # Age at admission, in hours.
  age_vec[p_idx][0] = age_at_admission[row.SUBJECT_ID].total_seconds() / 3600.0
  if row.GENDER == 'M':
    gender_one_hot[p_idx][0] = 1
  else:
    gender_one_hot[p_idx][1] = 1

static_vec = np.concatenate((gender_one_hot, age_vec), axis = 1)
# [num_patients, 3] = [gender 'M', gender other, age in hours]

# Hourly means; hours without a measurement divide by 1 and stay 0, which the
# imputation below treats as missing.
chart_vec = chart_X / (chart_Xcnt + (chart_Xcnt == 0))
lab_vec = lab_X / (lab_Xcnt + (lab_Xcnt == 0))
surgery_vec = surgery_X / (surgery_Xcnt + (surgery_Xcnt == 0))

for i in range(len(patient_set)):
  chart_vec[i] = forward_backward_impute(chart_vec[i], global_chart_mean)
  lab_vec[i] = forward_backward_impute(lab_vec[i], global_lab_mean)
  surgery_vec[i] = forward_backward_impute(surgery_vec[i], global_surgery_mean)

time_vec = np.concatenate((lab_vec, chart_vec, surgery_vec), axis=2)
# time_vec [num_patients, max_hours, num_lab_features + num_chart_features + num_vital_features]

# concatenate this with static_vec [num_patients, 3]

**cohort** is a list of indices that you're training / testing on. For instance, if I want the patients with IDs 5, 10 and 6, then cohort is `[patient_index_of[5], patient_index_of[10], patient_index_of[6]]`.

If you want to test a certain age cohort, then **cohort** should be a list of indices of patients in that age cohort.

Ok here I'm going to set up a DataLoader for the tasks.

In [0]:
from torch.utils.data import DataLoader, Dataset

BATCH_SIZE = 8  # windows per training batch (the test loader uses batch_size=1)

class MyDataset(Dataset):
  """Map-style dataset over pre-built tensors that are indexed in parallel on dim 0."""

  def __init__(self, feats, values, statics):
    self.feats = feats      # features you're ingesting
    self.values = values    # value aggregate in the prediction window
    self.statics = statics  # demographic features corresponding to this guy

  def __len__(self):
    return self.feats.shape[0]

  def __getitem__(self, index):
    sample = (self.feats[index], self.values[index], self.statics[index])
    return sample

def get_dataloader(indices, model_type, test = False, include_non_septic = False):
  """Cut patients' feature series into (window, label, static) samples.

  For every patient index in `indices` and every window start t_start in
  [0, MAX_HOURS - WINDOW_SIZE - GAP_TIME - PRED_SIZE), the sample is:
  the time_vec slice [t_start, t_start + WINDOW_SIZE), the label aggregate
  feat_agg[i][t_start], and static_vec[i].

  Septic patients only contribute windows whose prediction window starts
  before onset (t_start + WINDOW_SIZE + GAP_TIME < onset hour).
  NOTE(review): by default patients not in SEPTIC contribute no samples at all,
  so every sample comes from a septic patient. Pass include_non_septic=True to
  add all of their windows, labelled 0 by feat_agg, as negatives.

  Returns:
    'LSTM': a shuffled DataLoader (batch_size 1 if test else BATCH_SIZE,
            drop_last=True) over MyDataset of float tensors.
    otherwise: the (feats, values, statics) Python lists.
  """
  num_windows = MAX_HOURS - WINDOW_SIZE - GAP_TIME - PRED_SIZE
  my_feats = []
  my_values = []
  my_statics = []
  for i in indices:
    subject = patient_set[i]
    if subject in SEPTIC:
      onset_hour = int(SEPTIC[subject])
      starts = [t for t in range(num_windows) if t + WINDOW_SIZE + GAP_TIME < onset_hour]
    elif include_non_septic:
      starts = range(num_windows)
    else:
      continue
    for t_start in starts:
      my_feats.append(time_vec[i][t_start:t_start + WINDOW_SIZE])
      my_values.append(feat_agg[i][t_start])
      my_statics.append(static_vec[i])

  if model_type != 'LSTM':
    return my_feats, my_values, my_statics

  # Stack into one ndarray first: torch.tensor() on a list of ndarrays is very slow.
  dataset = MyDataset(torch.tensor(np.array(my_feats)).float(),
                      torch.tensor(np.array(my_values)).float(),
                      torch.tensor(np.array(my_statics)).float())
  return DataLoader(dataset,
                    batch_size=1 if test else BATCH_SIZE,
                    shuffle=True,
                    drop_last=True)


Here, we're going to set up a cohort to test. Call this prior to `train_rf` or `train_rnn`.

In [0]:
def set_up_cohort(cohort):
  """Standardize, in place, the rows of the global time_vec / static_vec in `cohort`.

  Each time-series feature (pooled over the cohort's patients and hours) and each
  static feature is scaled to zero mean / unit variance with StandardScaler.
  Rows of patients outside the cohort are left untouched.
  NOTE(review): the scalers are fit on the whole cohort, before train_rf /
  train_rnn split it, so test-patient statistics leak into the scaling. Calling
  this again re-standardizes already standardized rows.
  """
  # StandardScaler is otherwise only imported in a later cell.
  from sklearn.preprocessing import StandardScaler

  global time_vec
  global static_vec
  msk = np.zeros(len(time_vec), dtype=bool)
  msk[np.asarray(cohort, dtype=int)] = True
  scaler1 = StandardScaler()
  scaler2 = StandardScaler()
  time_vec = np.array(time_vec)
  static_vec = np.array(static_vec)
  shp = time_vec[msk, ...].shape
  # Flatten (patients, hours, feats) -> (patients*hours, feats) so each feature is scaled jointly.
  time_vec[msk, ...] = scaler1.fit_transform(time_vec[msk, ...].reshape(-1, shp[-1])).reshape(shp)
  static_vec[msk, ...] = scaler2.fit_transform(static_vec[msk, ...])


# Here you must change depending on the task.

Now we have our RF helper functions.

In [0]:
from sklearn.model_selection import KFold
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

def _rf_design_matrix(split):
  """Turn get_dataloader(..., 'RF') output into an (x, y) pair for sklearn.

  x: each window of shape (WINDOW_SIZE, num_feats), flattened hour-major, followed
  by that sample's static features. y: 1 where the label aggregate is > 0, else 0.
  """
  feats, values, statics = split
  windows = np.array(feats)
  windows = windows.reshape((windows.shape[0], windows.shape[1] * windows.shape[2]))
  x = np.concatenate((windows, np.array(statics)), axis=1)
  ### Obviously change this (the label definition depends on the task).
  y = np.int_(np.array(values) > 0)
  ###
  return x, y


def train_rf(task, cohort, n_estimators=200, bootstrap=True, max_features='sqrt', class_weight=None, random_state=None):
  """Fit a RandomForestClassifier on an 80/20 patient-level split of `cohort`.

  `task` is currently unused. `random_state`, when given, seeds both the split
  and the forest so runs are reproducible.
  Returns (model, x_test, y_test).
  """
  train_indices, test_indices = train_test_split(np.array(cohort), test_size=0.2, random_state=random_state)

  x_train, y_train = _rf_design_matrix(get_dataloader(train_indices, 'RF'))
  x_test, y_test = _rf_design_matrix(get_dataloader(test_indices, 'RF'))

  model = RandomForestClassifier(n_estimators=n_estimators,
                                 bootstrap=bootstrap,
                                 max_features=max_features,
                                 class_weight=class_weight,
                                 random_state=random_state)
  model.fit(x_train, y_train)

  return model, x_test, y_test

def evaluate_rf(model, x_test, y_test):
  """Score a fitted classifier on held-out data.

  Returns (auc, acc, rf_predictions, rf_probs). rf_probs is the predicted
  probability of class 1 (column 1 of predict_proba).
  """
  rf_predictions = model.predict(x_test)
  rf_probs = model.predict_proba(x_test)[:, 1]

  auc = roc_auc_score(y_test, rf_probs)
  n_correct = np.sum(rf_predictions == y_test)
  acc = n_correct / len(y_test)

  return auc, acc, rf_predictions, rf_probs
 

In [0]:
from sklearn.metrics import roc_curve
def plot_roc(title, labels, probs):
  """Draw the ROC curve of `probs` against binary `labels`, with the chance diagonal.

  `title` is used as the curve's legend label; the axes title is always 'ROC'.
  """
  fpr, tpr, _thresholds = roc_curve(labels, probs)
  plt.figure()
  ax = plt.gca()
  ax.plot(fpr, tpr, label=title)
  ax.plot([0, 1], [0, 1], 'r--')
  ax.set_xlim([0.0, 1.0])
  ax.set_ylim([0.0, 1.05])
  ax.set_xlabel('1 - Specificity')
  ax.set_ylabel('Sensitivity')
  ax.set_title('ROC')
  ax.legend(loc="lower right")
  plt.show()

Note that we need to further preprocess this data (zero mean, unit variance, PCA, etc.).


RF Helper Funcs

In [0]:
def get_mask(remove_these):
  """Keep-mask over patient_set positions: False for every subject in `remove_these`."""
  msk = [True] * len(patient_set)
  for subject in remove_these:
    msk[patient_index_of[subject]] = False
  return msk

def run_task_rf(task, min_age=-1, max_age=10000000):
  """Train and evaluate the RF on kept patients whose age is in [min_age, max_age].

  Ages are in hours (age_vec). Patients in subjects_to_remove are excluded.
  Returns (auc, acc) on the held-out split.
  """
  cohort = np.arange(len(patient_set))[get_mask(subjects_to_remove)]

  # Read ages from the raw age_vec: set_up_cohort() standardizes static_vec in
  # place, so after the first run static_vec[:, -1] holds z-scores, not hours,
  # and repeated runs would filter on the wrong scale.
  ages = age_vec[cohort, 0]
  keep = ~((ages < min_age) | (ages > max_age))
  cohort = cohort[keep]

  set_up_cohort(cohort)
  model, x_test, y_test = train_rf(task, cohort)

  auc, acc, _, _ = evaluate_rf(model, x_test, y_test)

  return auc, acc



In [0]:
 #for i in range(50): 
 # auc, acc = run_task_rf('Sepsis Prediction', -1, 24 * 60)
 # print("AUC of ", auc)
 # print("ACC of ", acc)

In [0]:
TRAIN_BATCHES_PER_EPOCH = 1000  # cap on training batches per epoch in train_model
TEST_SAMPLES = 2000             # cap on test samples scored by evaluate_rnn

In [0]:
def get_XY(feats, values, statics):
  """Build model inputs and binary targets for one batch.

  Args:
    feats: B x WINDOW x dim_input time-series features (tensor or array-like).
    values: B label aggregates.
    statics: B x S static features, repeated at every time step.
  Returns:
    X: B x WINDOW x (dim_input + S) tensor.
    Y: B float tensor, 1.0 where values > 0 else 0.0.
  """
  # as_tensor avoids the copy-construct warning torch.tensor(tensor) emits for
  # the tensors that come out of the DataLoader.
  feats_t = torch.as_tensor(feats)
  statics_t = torch.as_tensor(statics)
  statics_seq = statics_t.unsqueeze(1).expand(-1, feats_t.shape[1], -1)
  X = torch.cat((feats_t, statics_seq), dim=-1)
  Y = (torch.as_tensor(values) > 0).float()

  return X, Y


In [0]:
class LSTM_Classifier(nn.Module):
  """(Bi)LSTM followed by a linear layer that emits one logit per time step."""

  def __init__(self, input_size, hidden_size, num_layers=1, dropout=0., bidirectional=False):
    super().__init__()

    self.input_size = input_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.bidirectional = bidirectional
    self.dropout = dropout

    self.rnn = nn.LSTM(input_size, hidden_size,
                       num_layers=num_layers,
                       batch_first=True,
                       dropout=dropout,
                       bidirectional=bidirectional)
    # A bidirectional LSTM concatenates both directions' hidden states.
    rnn_out_width = hidden_size * (1 + int(bidirectional))
    self.out = nn.Linear(rnn_out_width, 1)

  def forward(self, input):
    """Map (B, seq_len, input_size) to logits of shape (B, seq_len, 1)."""
    per_step_hidden, _state = self.rnn(input)
    return self.out(per_step_hidden)

def rnn_train_one_sample(model, criterion, rnn_optimizer, sent_tensor, tag_tensor, alpha = 0.5, clip=None):
    """Run one optimisation step on a batch; return (per-step outputs, loss value).

    sent_tensor is (B, Num Hours, Num feats) and tag_tensor is (B). Only the
    logit at the last time step is trained against tag_tensor.
    `alpha` is kept for backward compatibility but is unused; it weighted the
    commented-out per-step loss variant below.
    If `clip` is not None, the total gradient norm is clipped to `clip`
    before the optimizer step.
    """
    model.zero_grad()

    outputs = model(sent_tensor).squeeze(2)

    # loss = criterion(outputs, tag_tensor) * alpha + criterion(outputs[-1], tag_tensor[-1]) * (1.0-alpha)
    loss = criterion(outputs[:, -1], tag_tensor)

    loss.backward()

    if clip is not None:
      # clip_grad_norm (no underscore) is deprecated; the in-place form is the supported API.
      torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip)

    rnn_optimizer.step()

    return outputs, loss.item()


In [0]:
import time
import math
import sklearn
from sklearn.metrics import precision_recall_fscore_support

def timeSince(since):
    """Format wall-clock time elapsed since `since` (a time.time() value) as 'Xm Ys'."""
    elapsed = time.time() - since
    whole_minutes = math.floor(elapsed / 60)
    leftover_seconds = elapsed - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, leftover_seconds)


def evaluate_result(true_tag_list, predicted_tag_list, probs):
  """Return (accuracy, ROC AUC) for one evaluation pass.

  Args:
    true_tag_list: tensor of 0/1 labels.
    predicted_tag_list: predicted 0/1 tags, same length.
    probs: predicted probabilities of class 1.
  AUC is NaN when the labels contain fewer than two classes, because
  roc_auc_score raises in that case and a small test split can hit it.
  """
  labels = true_tag_list.numpy()
  acc = np.mean(labels == np.asarray(predicted_tag_list))
  if np.unique(labels).size < 2:
    return acc, float('nan')
  return acc, roc_auc_score(labels, probs)

# Make prediction for one sentence (a batch of size 1).
def rnn_predict_one_sent(model, sent_tensor):
    """Classify one sequence from the model's logit at its last time step.

    Returns (1 if sigmoid(last logit) > 0.5 else 0, that probability as a float).
    """
    per_step_logits = model(sent_tensor).squeeze(2).squeeze(0)
    last_prob = torch.sigmoid(per_step_logits[-1])
    predicted_tag_id = 1 if last_prob > 0.5 else 0
    return predicted_tag_id, last_prob.item()


def evaluate_rnn(model, test_dataloader):
  """Score `model` on up to TEST_SAMPLES batches of `test_dataloader`.

  Expects batch_size=1 batches, as built by get_dataloader(..., test=True):
  only the first label of each batch is recorded, and rnn_predict_one_sent
  squeezes the batch dimension.
  Returns (auc, acc, predicted_tags, probs). Leaves the model in eval mode.
  """
  model.eval()
  predicted_tags = []
  probs = []
  labels = []

  # Inference only: don't build autograd graphs for every test sample.
  with torch.no_grad():
    for iter_count, (feats, values, statics) in enumerate(test_dataloader, start=1):
      if iter_count > TEST_SAMPLES:
        break

      sent_tensor, tag_tensor = get_XY(feats, values, statics)
      sent_tensor = sent_tensor.to(device)
      predicted_tag_id, prob = rnn_predict_one_sent(model, sent_tensor)
      predicted_tags.append(predicted_tag_id)
      probs.append(prob)
      labels.append(tag_tensor[0].item())

  acc, auc = evaluate_result(torch.tensor(labels), predicted_tags, probs)

  return auc, acc, predicted_tags, probs

def train_model(model, criterion, optimizer, train_dataloader, test_dataloader, n_epochs=5, print_every=100, plot_every=50, learning_rate=1e-3, alpha = 0.5, clip=None):
  """Train for n_epochs (at most TRAIN_BATCHES_PER_EPOCH batches each), evaluating after every epoch.

  Every `print_every` iterations it prints the iteration, elapsed time and the
  average loss since the previous print. `plot_every` and `learning_rate` are
  unused (the optimizer already carries its learning rate). `alpha` and `clip`
  are forwarded to rnn_train_one_sample.
  Returns (all_losses, all_norms): all_losses holds each printed average loss;
  all_norms is currently always empty.
  """
  iter_count = 0

  current_loss = 0.0
  losses_since_print = 0
  all_losses = []
  all_norms = []

  start = time.time()

  for epoch_i in range(n_epochs):
    # evaluate_rnn() switches the model to eval mode at the end of each epoch,
    # so training mode (dropout on) must be re-enabled every epoch, not once.
    model.train()
    for num_batches, (feats, values, statics) in enumerate(train_dataloader, start=1):
      if num_batches > TRAIN_BATCHES_PER_EPOCH:
        break

      sent_tensor, tag_tensor = get_XY(feats, values, statics)

      sent_tensor = sent_tensor.to(device)
      tag_tensor = tag_tensor.to(device)

      _, loss = rnn_train_one_sample(model, criterion, optimizer, sent_tensor, tag_tensor, alpha=alpha, clip=clip)
      current_loss += loss
      losses_since_print += 1

      if iter_count % print_every == 0:
        # Average over the losses actually accumulated (only 1 at iteration 0).
        avg_loss = current_loss / losses_since_print
        print('%d %s %.4f' % (iter_count, timeSince(start), avg_loss))
        all_losses.append(avg_loss)
        current_loss = 0.0
        losses_since_print = 0

      iter_count += 1

    auc, acc, _, _ = evaluate_rnn(model, test_dataloader)
    print("Epoch ", epoch_i, " ACC of ", acc, " AUC of ", auc)
  return all_losses, all_norms


In [0]:
def train_rnn(task, cohort, n_epochs=10, 
              rnn_clip = None, 
              rnn_hidden_size = 64, 
              rnn_num_layers = 2, 
              learning_rate = 1e-4, 
              rnn_dropout = 0.5, 
              rnn_alpha = 0.5, 
              weight_decay = 1e-4,
              rnn_bidirectional = False,
              pos_weight = 10.0): 
  """Split `cohort` 80/20, build an LSTM_Classifier, and train it with train_model.

  Args:
    task: accepted for interface compatibility; not used in this function.
    cohort: patient indices to split into train and test sets.
    n_epochs, rnn_clip, rnn_alpha: forwarded to train_model (as n_epochs, clip, alpha).
    rnn_hidden_size, rnn_num_layers, rnn_dropout, rnn_bidirectional: LSTM_Classifier settings.
    learning_rate, weight_decay: Adam settings.
    pos_weight: positive-class weight for BCEWithLogitsLoss (previously hard-coded to 10).

  Returns:
    (rnn_model, test_dataloader). The training-loss history is discarded.
  """
  # Per-timestep input width: time-series features plus 3 extra columns
  # (presumably the static features - TODO confirm).
  num_feats = time_vec.shape[-1] + 3

  train_indices, test_indices = train_test_split(np.array(cohort), test_size=0.2)

  train_dataloader = get_dataloader(train_indices, 'LSTM')
  test_dataloader = get_dataloader(test_indices, 'LSTM', test=True)

  rnn_model = LSTM_Classifier(input_size=num_feats, hidden_size=rnn_hidden_size, num_layers=rnn_num_layers, dropout=rnn_dropout, bidirectional=rnn_bidirectional)
  # train_model and evaluate_rnn move every input batch to `device`, so the
  # model must live there too (a no-op on CPU).
  rnn_model = rnn_model.to(device)
  # A floating-point pos_weight, moved (as a registered buffer) to the same
  # device as the logits.
  criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(float(pos_weight))).to(device)
  rnn_optimizer = torch.optim.Adam(rnn_model.parameters(), lr=learning_rate, weight_decay=weight_decay)

  train_model(rnn_model, criterion, rnn_optimizer, train_dataloader, test_dataloader, n_epochs=n_epochs, alpha=rnn_alpha, clip=rnn_clip)

  return rnn_model, test_dataloader

Evaluation helpers. The hyperparameters still need tuning. In particular, `run_task_rnn` trains for only a few epochs (`n_epochs=5`), because the goal so far was just to confirm that the pipeline works end to end.

In [0]:
# Evaluation

def get_mask(remove_these): 
  """Return a boolean keep-mask aligned with `patient_set`.

  Every entry is True except at patient_index_of[p] for each p in
  `remove_these`, where it is False.
  """
  keep = [True] * len(patient_set)
  for patient in remove_these:
    keep[patient_index_of[patient]] = False
  return keep

def run_task_rnn(task, min_age=-1, max_age=10000000, n_epochs=5, rnn_bidirectional=True): 
  """Build the age-filtered cohort, train an LSTM classifier on it, and evaluate it.

  Args:
    task: passed through to train_rnn.
    min_age, max_age: inclusive bounds on the age feature, read as the last
      entry of static_vec[patient_index]. Patients with age < min_age or
      age > max_age are dropped.
    n_epochs: training epochs (previously hard-coded to 5).
    rnn_bidirectional: whether the LSTM is bidirectional (previously hard-coded to True).

  Returns:
    (model, auc, acc, test_dataloader)

  Raises:
    ValueError: if no patient remains after removing subjects_to_remove and
      applying the age filter.
  """
  # All patient indices, minus those listed in subjects_to_remove.
  keep_subject = np.asarray(get_mask(subjects_to_remove), dtype=bool)
  cohort = np.arange(len(patient_set))[keep_subject]

  # Written as "not outside the range" so that a missing (NaN) age is kept,
  # exactly like the original exclusion test.
  in_age_range = []
  for idx in cohort:
    age_here = static_vec[idx][-1]
    in_age_range.append(not (age_here < min_age or age_here > max_age))
  cohort = cohort[np.asarray(in_age_range, dtype=bool)]

  if len(cohort) == 0:
    raise ValueError(
        f"No patients left for task {task!r} with age in [{min_age}, {max_age}]")

  set_up_cohort(cohort)
  model, test_dataloader = train_rnn(task, cohort, n_epochs=n_epochs, rnn_bidirectional=rnn_bidirectional)

  auc, acc, _, _ = evaluate_rnn(model, test_dataloader)

  return model, auc, acc, test_dataloader

 

Set up data and evaluate.

In [0]:
# Sepsis task, restricted to patients whose age feature lies in [-1, 24 * 60].
# NOTE(review): the unit of the age feature is not visible here - confirm.
my_model, auc, acc, td = run_task_rnn(
    'SEPSIS PREDICTION',
    min_age=-1,
    max_age=24 * 60,
)
print("AUC of ", auc)
print("ACC of ", acc)


In [0]:

# Collect up to TEST_SAMPLES windows from the test dataloader `td` into a single
# tensor for the feature-perturbation analysis in the next cell.
# NOTE(review): assumes get_XY returns sent_tensor shaped
# (1, WINDOW_SIZE, time_vec.shape[-1] + 3) - confirm.
xt = torch.zeros((TEST_SAMPLES, WINDOW_SIZE, time_vec.shape[-1] + 3))
cnt = 0
for feats, values, statics in td:
    if cnt >= TEST_SAMPLES:
        break
    sent_tensor, tag_tensor = get_XY(feats, values, statics)
    xt[cnt] = sent_tensor.squeeze(0)
    cnt += 1
# If `td` yielded fewer than TEST_SAMPLES samples, drop the unfilled all-zero
# rows so they are not analysed as if they were real patients.
xt = xt[:cnt]

In [0]:

# Run the perturbation analysis of my_model on the collected test windows `xt`.
# NOTE(review): (46, 61) is a hard-coded argument whose meaning (presumably a
# range of feature indices) is defined by perturb_features, which is not visible here - confirm.
effects = perturb_features(my_model, xt, (46, 61))