In [45]:
import sys
sys.path.append('../')
import datasets
from dataproc import extract_wvs
from dataproc import get_discharge_summaries
from dataproc import group_and_sort
from dataproc import filter_patients_and_labels
from dataproc import concat_and_split_disch
from dataproc import build_vocab
from dataproc import sort_by_length
from dataproc import vocab_index_descriptions
from dataproc import word_embeddings
from constants import DISCH_DIR, DATA_DIR

import numpy as np
import pandas as pd

from collections import Counter, defaultdict
import csv
import operator

Let's do some data processing in a much better way, with a notebook.

First, let's define some stuff.

In [2]:
Y = 'full' #use all available labels in the dataset for prediction
notes_file = '%s/NOTEEVENTS.csv' % DISCH_DIR # raw note events downloaded from MIMIC-III
vocab_size = 'full' #don't limit the vocab size to a specific number
vocab_min = 3 #discard tokens appearing in fewer than this many documents
split = [.7, .15, .15] #train/dev/test

# Data processing

## Combine diagnosis and procedure codes and reformat them

In [119]:
dfproc = pd.read_csv('%s/raw/PROCEDURES_ICD.csv' % DATA_DIR)
dfdiag = pd.read_csv('%s/raw/DIAGNOSES_ICD.csv' % DATA_DIR)

In [120]:
reload(datasets)
dfdiag['absolute_code'] = dfdiag.apply(lambda row: str(datasets.reformat(str(row[4]), True)), axis=1)
dfproc['absolute_code'] = dfproc.apply(lambda row: str(datasets.reformat(str(row[4]), False)), axis=1)

In [121]:
dfcodes = pd.concat([dfdiag, dfproc])

In [122]:
dfcodes.to_csv('%s/ALL_CODES.csv' % DISCH_DIR, index=False,
               columns=['ROW_ID', 'SUBJECT_ID', 'HADM_ID', 'SEQ_NUM', 'absolute_code'],
               header=['ROW_ID', 'SUBJECT_ID', 'HADM_ID', 'SEQ_NUM', 'ICD9_CODE'])

## How many codes are there?

In [3]:
#In the full dataset (not just discharge summaries)
df = pd.read_csv('%s/ALL_CODES.csv' % DISCH_DIR, dtype={"ICD9_CODE": str})
len(df['ICD9_CODE'].unique())

8994

In [4]:
#In the discharge summaries
codes = set()
df = pd.read_csv('%s/notes_full_labeled.csv' % DISCH_DIR)
for row in df.itertuples():
    for c in str(row[4]).split(';'):
        codes.add(c)
len(codes)

8922

## Tokenize and preprocess raw text

Let's start from scratch by making the disch_full.csv from the raw notes

This will:
- Select only discharge summaries and their addenda
- remove stop words
- remove punctuation and numeric-only tokens, removing 500 but keeping 250ml
- lowercase all tokens

In [5]:
reload(get_discharge_summaries)
#This reads all notes, selects only the discharge summaries, and tokenizes them, returning the output filename
disch_full_file = get_discharge_summaries.write_discharge_summaries(Y, notes_file=notes_file,
                                                                    out_file="%s/disch_full.csv" % DISCH_DIR)

50it [00:00, 496.27it/s]

processing notes file
writing to /nethome/jmullenbach3/mimicdata/disch/disch_full.csv
0
0 disch, 0 not disch


10093it [00:19, 484.08it/s]

10000
10000 disch, 0 not disch


20134it [00:36, 674.64it/s]

20000
20000 disch, 0 not disch


30103it [00:52, 570.01it/s]

30000
30000 disch, 0 not disch


40076it [01:12, 474.28it/s]

40000
40000 disch, 0 not disch


50050it [01:35, 431.22it/s]

50000
50000 disch, 0 not disch


64726it [01:48, 5471.45it/s]

60000
59607 disch, 393 not disch


73272it [01:49, 11693.06it/s]

70000
59652 disch, 10348 not disch


85084it [01:49, 21338.51it/s]

80000
59652 disch, 20348 not disch


93357it [01:50, 24700.33it/s]

90000
59652 disch, 30348 not disch


107507it [01:50, 32482.03it/s]

100000
59652 disch, 40348 not disch


130823it [01:50, 50844.77it/s]

110000
59652 disch, 50348 not disch
120000
59652 disch, 60348 not disch
130000
59652 disch, 70348 not disch
140000
59652 disch, 80348 not disch


178708it [01:50, 92187.19it/s]

150000
59652 disch, 90348 not disch
160000
59652 disch, 100348 not disch
170000
59652 disch, 110348 not disch
180000
59652 disch, 120348 not disch


208499it [01:51, 108684.57it/s]

190000
59652 disch, 130348 not disch
200000
59652 disch, 140348 not disch
210000
59652 disch, 150348 not disch


238519it [01:51, 126904.93it/s]

220000
59652 disch, 160348 not disch
230000
59652 disch, 170348 not disch
240000
59652 disch, 180348 not disch
250000
59652 disch, 190348 not disch


287313it [01:51, 148472.99it/s]

260000
59652 disch, 200348 not disch
270000
59652 disch, 210348 not disch
280000
59652 disch, 220348 not disch
290000
59652 disch, 230348 not disch


303223it [01:51, 151506.44it/s]

300000
59652 disch, 240348 not disch
310000
59652 disch, 250348 not disch


318915it [01:52, 81325.99it/s] 

320000
59652 disch, 260348 not disch


331072it [01:52, 43433.66it/s]

330000
59652 disch, 270348 not disch


340206it [01:53, 30795.96it/s]

340000
59652 disch, 280348 not disch


352583it [01:53, 23064.44it/s]

350000
59652 disch, 290348 not disch


363602it [01:54, 19814.43it/s]

360000
59652 disch, 300348 not disch


373123it [01:54, 18723.67it/s]

370000
59652 disch, 310348 not disch


382752it [01:55, 18060.80it/s]

380000
59652 disch, 320348 not disch


391901it [01:56, 17845.46it/s]

390000
59652 disch, 330348 not disch


402593it [01:56, 17274.07it/s]

400000
59652 disch, 340348 not disch


413782it [01:57, 18624.67it/s]

410000
59652 disch, 350348 not disch


422039it [01:57, 19896.96it/s]

420000
59652 disch, 360348 not disch


432457it [01:58, 19906.13it/s]

430000
59652 disch, 370348 not disch


442450it [01:58, 19688.94it/s]

440000
59652 disch, 380348 not disch


452218it [01:59, 19167.98it/s]

450000
59652 disch, 390348 not disch


462039it [01:59, 17920.61it/s]

460000
59652 disch, 400348 not disch


473455it [02:00, 18713.38it/s]

470000
59652 disch, 410348 not disch


483296it [02:00, 18970.70it/s]

480000
59652 disch, 420348 not disch


492631it [02:01, 16961.69it/s]

490000
59652 disch, 430348 not disch


501926it [02:02, 18352.76it/s]

500000
59652 disch, 440348 not disch


513373it [02:02, 18605.42it/s]

510000
59652 disch, 450348 not disch


522848it [02:03, 18787.67it/s]

520000
59652 disch, 460348 not disch


531955it [02:03, 17827.55it/s]

530000
59652 disch, 470348 not disch


542225it [02:04, 19947.33it/s]

540000
59652 disch, 480348 not disch


552223it [02:04, 19697.77it/s]

550000
59652 disch, 490348 not disch


561857it [02:05, 18872.12it/s]

560000
59652 disch, 500348 not disch


573593it [02:05, 19512.11it/s]

570000
59652 disch, 510348 not disch


583070it [02:06, 17185.27it/s]

580000
59652 disch, 520348 not disch


592506it [02:06, 18597.95it/s]

590000
59652 disch, 530348 not disch


603740it [02:07, 18418.82it/s]

600000
59652 disch, 540348 not disch


613201it [02:07, 18589.16it/s]

610000
59652 disch, 550348 not disch


622836it [02:08, 18760.84it/s]

620000
59652 disch, 560348 not disch


632206it [02:09, 18427.88it/s]

630000
59652 disch, 570348 not disch


644123it [02:09, 17008.90it/s]

640000
59652 disch, 580348 not disch


651996it [02:10, 18223.84it/s]

650000
59652 disch, 590348 not disch


663418it [02:10, 19006.69it/s]

660000
59652 disch, 600348 not disch


672888it [02:11, 14062.46it/s]

670000
59652 disch, 610348 not disch


682330it [02:12, 17651.94it/s]

680000
59652 disch, 620348 not disch


693584it [02:12, 18685.86it/s]

690000
59652 disch, 630348 not disch


703200it [02:13, 18734.45it/s]

700000
59652 disch, 640348 not disch


712635it [02:13, 18524.40it/s]

710000
59652 disch, 650348 not disch


722055it [02:14, 18721.99it/s]

720000
59652 disch, 660348 not disch


733608it [02:14, 19240.22it/s]

730000
59652 disch, 670348 not disch


743699it [02:15, 26770.54it/s]

740000
59652 disch, 680348 not disch


755563it [02:15, 34021.51it/s]

750000
59652 disch, 690348 not disch


767787it [02:15, 38115.97it/s]

760000
59652 disch, 700348 not disch


775789it [02:15, 39101.80it/s]

770000
59652 disch, 710348 not disch


783670it [02:16, 33943.86it/s]

780000
59652 disch, 720348 not disch


795485it [02:16, 37294.02it/s]

790000
59652 disch, 730348 not disch


807378it [02:16, 38843.69it/s]

800000
59652 disch, 740348 not disch


815183it [02:16, 38614.91it/s]

810000
59652 disch, 750348 not disch


826766it [02:17, 38130.03it/s]

820000
59652 disch, 760348 not disch


834500it [02:17, 38399.79it/s]

830000
59652 disch, 770348 not disch


846014it [02:17, 38278.83it/s]

840000
59652 disch, 780348 not disch


857446it [02:18, 37788.20it/s]

850000
59652 disch, 790348 not disch


865110it [02:18, 37910.13it/s]

860000
59652 disch, 800348 not disch


876548it [02:18, 37953.54it/s]

870000
59652 disch, 810348 not disch


884180it [02:18, 37988.76it/s]

880000
59652 disch, 820348 not disch


895654it [02:19, 38139.38it/s]

890000
59652 disch, 830348 not disch


907263it [02:19, 38192.41it/s]

900000
59652 disch, 840348 not disch


914944it [02:19, 38132.31it/s]

910000
59652 disch, 850348 not disch


926344it [02:19, 36046.66it/s]

920000
59652 disch, 860348 not disch


933624it [02:20, 36182.97it/s]

930000
59652 disch, 870348 not disch


944546it [02:20, 35598.72it/s]

940000
59652 disch, 880348 not disch


954558it [02:20, 30322.80it/s]

950000
59652 disch, 890348 not disch


963648it [02:21, 27104.97it/s]

960000
59652 disch, 900348 not disch


976312it [02:21, 30385.38it/s]

970000
59652 disch, 910348 not disch


987429it [02:21, 34488.87it/s]

980000
59652 disch, 920348 not disch


994814it [02:22, 35726.28it/s]

990000
59652 disch, 930348 not disch


1006084it [02:22, 36905.88it/s]

1000000
59652 disch, 940348 not disch


1017605it [02:22, 37679.16it/s]

1010000
59652 disch, 950348 not disch


1025447it [02:22, 38465.35it/s]

1020000
59652 disch, 960348 not disch


1036966it [02:23, 38159.77it/s]

1030000
59652 disch, 970348 not disch


1044627it [02:23, 38159.61it/s]

1040000
59652 disch, 980348 not disch


1056169it [02:23, 38167.13it/s]

1050000
59652 disch, 990348 not disch


1063766it [02:23, 37110.07it/s]

1060000
59652 disch, 1000348 not disch


1075038it [02:24, 37364.01it/s]

1070000
59652 disch, 1010348 not disch


1085911it [02:24, 34781.19it/s]

1080000
59652 disch, 1020348 not disch


1096790it [02:24, 35759.98it/s]

1090000
59652 disch, 1030348 not disch


1104034it [02:25, 35437.52it/s]

1100000
59652 disch, 1040348 not disch


1114705it [02:25, 35300.27it/s]

1110000
59652 disch, 1050348 not disch


1125341it [02:25, 35260.47it/s]

1120000
59652 disch, 1060348 not disch


1136008it [02:25, 35477.88it/s]

1130000
59652 disch, 1070348 not disch


1139557it [02:26, 35272.58it/s]

1140000
59652 disch, 1080348 not disch


1153555it [02:26, 30184.76it/s]

1150000
59652 disch, 1090348 not disch


1164053it [02:26, 33177.10it/s]

1160000
59652 disch, 1100348 not disch


1174540it [02:27, 34249.15it/s]

1170000
59652 disch, 1110348 not disch


1184903it [02:27, 34174.55it/s]

1180000
59652 disch, 1120348 not disch


1195533it [02:27, 35020.21it/s]

1190000
59652 disch, 1130348 not disch


1206039it [02:28, 34402.92it/s]

1200000
59652 disch, 1140348 not disch


1216449it [02:28, 34610.02it/s]

1210000
59652 disch, 1150348 not disch


1226862it [02:28, 34510.40it/s]

1220000
59652 disch, 1160348 not disch


1233750it [02:28, 34298.17it/s]

1230000
59652 disch, 1170348 not disch


1246247it [02:29, 38782.35it/s]

1240000
59652 disch, 1180348 not disch
1250000
59652 disch, 1190348 not disch


1264661it [02:29, 42767.15it/s]

1260000
59652 disch, 1200348 not disch


1281543it [02:29, 48996.53it/s]

1270000
59652 disch, 1210348 not disch
1280000
59652 disch, 1220348 not disch


1299954it [02:30, 53789.85it/s]

1290000
59652 disch, 1230348 not disch
1300000
59652 disch, 1240348 not disch


1318300it [02:30, 58669.30it/s]

1310000
59652 disch, 1250348 not disch
1320000
59652 disch, 1260348 not disch


1337045it [02:30, 61151.99it/s]

1330000
59652 disch, 1270348 not disch
1340000
59652 disch, 1280348 not disch


1361212it [02:31, 56852.10it/s]

1350000
59652 disch, 1290348 not disch
1360000
59652 disch, 1300348 not disch


1379936it [02:31, 59713.29it/s]

1370000
59652 disch, 1310348 not disch
1380000
59652 disch, 1320348 not disch


1398852it [02:31, 61928.97it/s]

1390000
59652 disch, 1330348 not disch
1400000
59652 disch, 1340348 not disch


1417672it [02:32, 61883.66it/s]

1410000
59652 disch, 1350348 not disch
1420000
59652 disch, 1360348 not disch


1436785it [02:32, 63128.81it/s]

1430000
59652 disch, 1370348 not disch
1440000
59652 disch, 1380348 not disch


1462090it [02:32, 63092.90it/s]

1450000
59652 disch, 1390348 not disch
1460000
59652 disch, 1400348 not disch


1481199it [02:33, 63505.05it/s]

1470000
59652 disch, 1410348 not disch
1480000
59652 disch, 1420348 not disch


1500294it [02:33, 63280.47it/s]

1490000
59652 disch, 1430348 not disch
1500000
59652 disch, 1440348 not disch


1519387it [02:33, 63114.87it/s]

1510000
59652 disch, 1450348 not disch
1520000
59652 disch, 1460348 not disch


1538480it [02:34, 63259.04it/s]

1530000
59652 disch, 1470348 not disch
1540000
59652 disch, 1480348 not disch


1557410it [02:34, 62948.42it/s]

1550000
59652 disch, 1490348 not disch
1560000
59652 disch, 1500348 not disch


1576528it [02:34, 63154.50it/s]

1570000
59652 disch, 1510348 not disch
1580000
59652 disch, 1520348 not disch


1601799it [02:35, 63100.89it/s]

1590000
59652 disch, 1530348 not disch
1600000
59652 disch, 1540348 not disch


1620729it [02:35, 62507.91it/s]

1610000
59652 disch, 1550348 not disch
1620000
59652 disch, 1560348 not disch


1639592it [02:35, 62438.22it/s]

1630000
59652 disch, 1570348 not disch
1640000
59652 disch, 1580348 not disch


1658561it [02:36, 62874.02it/s]

1650000
59652 disch, 1590348 not disch
1660000
59652 disch, 1600348 not disch


1679073it [02:36, 63094.30it/s]

1670000
59652 disch, 1610348 not disch
1680000
59652 disch, 1620348 not disch


1700022it [02:36, 66332.94it/s]

1690000
59652 disch, 1630348 not disch
1700000
59652 disch, 1640348 not disch


1722980it [02:37, 72735.16it/s]

1710000
59652 disch, 1650348 not disch
1720000
59652 disch, 1660348 not disch


1738081it [02:37, 74233.74it/s]

1730000
59652 disch, 1670348 not disch
1740000
59652 disch, 1680348 not disch


1760787it [02:37, 73497.22it/s]

1750000
59652 disch, 1690348 not disch
1760000
59652 disch, 1700348 not disch


1783408it [02:37, 74913.11it/s]

1770000
59652 disch, 1710348 not disch
1780000
59652 disch, 1720348 not disch


1798571it [02:38, 75030.11it/s]

1790000
59652 disch, 1730348 not disch
1800000
59652 disch, 1740348 not disch


1821462it [02:38, 75645.05it/s]

1810000
59652 disch, 1750348 not disch
1820000
59652 disch, 1760348 not disch


1836762it [02:38, 75917.34it/s]

1830000
59652 disch, 1770348 not disch
1840000
59652 disch, 1780348 not disch


1859376it [02:38, 74249.80it/s]

1850000
59652 disch, 1790348 not disch
1860000
59652 disch, 1800348 not disch


1881908it [02:39, 74398.64it/s]

1870000
59652 disch, 1810348 not disch
1880000
59652 disch, 1820348 not disch


1904572it [02:39, 75071.84it/s]

1890000
59652 disch, 1830348 not disch
1900000
59652 disch, 1840348 not disch


1919734it [02:39, 75306.52it/s]

1910000
59652 disch, 1850348 not disch
1920000
59652 disch, 1860348 not disch


1942302it [02:39, 74981.36it/s]

1930000
59652 disch, 1870348 not disch
1940000
59652 disch, 1880348 not disch


1964877it [02:40, 74950.82it/s]

1950000
59652 disch, 1890348 not disch
1960000
59652 disch, 1900348 not disch


1980063it [02:40, 75468.78it/s]

1970000
59652 disch, 1910348 not disch
1980000
59652 disch, 1920348 not disch


2002680it [02:40, 74143.13it/s]

1990000
59652 disch, 1930348 not disch
2000000
59652 disch, 1940348 not disch


2017628it [02:40, 72049.27it/s]

2010000
59652 disch, 1950348 not disch
2020000
59652 disch, 1960348 not disch


2039197it [02:41, 64681.91it/s]

2030000
59652 disch, 1970348 not disch
2040000
59652 disch, 1980348 not disch


2059484it [02:41, 55376.28it/s]

2050000
59652 disch, 1990348 not disch
2060000
59652 disch, 2000348 not disch


2083180it [02:42, 12834.33it/s]

2070000
59652 disch, 2010348 not disch
2080000
59652 disch, 2020348 not disch





Let's read this in and see what kind of data we're working with

In [6]:
df = pd.read_csv('%s/disch_full.csv' % DISCH_DIR)

In [7]:
#How many admissions?
len(df['HADM_ID'].unique())

52726

In [8]:
#Any of them missing HADM_ID?
sum(df['HADM_ID'].isnull())

0

In [9]:
#What about text?
sum(df['TEXT'] == '')

0

In [10]:
#Tokens and types
types = set()
num_tok = 0
for row in df.itertuples():
    for w in row[4].split():
        types.add(w)
        num_tok += 1

In [11]:
print("Num types", len(types))
print("Num tokens", str(num_tok))

('Num types', 150854)
('Num tokens', '79801387')


In [12]:
#Let's sort by SUBJECT_ID and HADM_ID to make a correspondence with our label file
df = df.sort_values(['SUBJECT_ID', 'HADM_ID'])

In [13]:
#Sort the label file by the same
dfl = pd.read_csv('%s/ALL_CODES.csv' % DISCH_DIR)
dfl = dfl.sort_values(['SUBJECT_ID', 'HADM_ID'])

  interactivity=interactivity, compiler=compiler, result=result)


In [14]:
len(dfl['HADM_ID'].unique())

58976

## Consolidate labels with set of discharge summaries

Looks like there were some HADM_ID's that didn't have discharge summaries, so they weren't included with our notes

In [15]:
#Let's filter out these HADM_ID's
hadm_ids = set(df['HADM_ID'])
with open('%s/ALL_CODES.csv' % DISCH_DIR, 'r') as lf:
    with open('%s/ALL_CODES_filtered.csv' % DISCH_DIR, 'w') as of:
        w = csv.writer(of)
        w.writerow(['SUBJECT_ID', 'HADM_ID', 'ICD9_CODE', 'ADMITTIME', 'DISCHTIME'])
        r = csv.reader(lf)
        #header
        next(r)
        for i,row in enumerate(r):
            hadm_id = int(row[2])
            #print(hadm_id)
            #break
            if hadm_id in hadm_ids:
                w.writerow(row[1:3] + [row[-1], '', ''])

In [16]:
dfl = pd.read_csv('%s/ALL_CODES_filtered.csv' % DISCH_DIR, index_col=None)

  interactivity=interactivity, compiler=compiler, result=result)


In [17]:
len(dfl['HADM_ID'].unique())

52726

In [18]:
#we still need to sort it by HADM_ID though
dfl = dfl.sort_values(['SUBJECT_ID', 'HADM_ID'])
dfl.to_csv('%s/ALL_CODES_filtered.csv' % DISCH_DIR, index=False)

## Append labels to notes in a single file

In [21]:
#Now let's append each instance with all of its codes
#this is pretty non-trivial so let's use this script I wrote, which requires the notes to be written to file
sorted_file = '%s/disch_full.csv' % DISCH_DIR
df.to_csv(sorted_file, index=False)

In [22]:
reload(concat_and_split_disch)
labeled = concat_and_split_disch.concat_data('%s/ALL_CODES_filtered.csv' % DISCH_DIR, 'full', sorted_file)

CONCATENATING
0 done
10000 done
20000 done
30000 done
40000 done
50000 done


In [23]:
#what was the name of the file we just created again?
print(labeled)

/nethome/jmullenbach3/mimicdata/disch/notes_full_labeled.csv


Let's sanity check our results. Do we have all hadm id's accounted for, and the same vocab stats?

In [24]:
dfnl = pd.read_csv(labeled)
#Tokens and types
types = set()
num_tok = 0
for row in dfnl.itertuples():
    for w in row[3].split():
        types.add(w)
        num_tok += 1

In [25]:
print("num types", len(types), "num tokens", num_tok)

('num types', 150854, 'num tokens', 79801387)


In [26]:
len(dfnl['HADM_ID'].unique())

52726

## Create train/dev/test splits

In [27]:
#Okay, now we can split and create a vocabulary for real
reload(concat_and_split_disch)
fname = '%s/notes_full_labeled.csv' % DISCH_DIR
base_name = "%s/disch" % DISCH_DIR #for output
tr, dv, te = concat_and_split_disch.split_data(fname, 'full', base_name=base_name, split=split)

SPLITTING
0 read
10000 read
20000 read
30000 read
40000 read
50000 read


## Build vocabulary from training data

In [28]:
vname = '%s/vocab.csv' % DISCH_DIR
build_vocab.build_vocab(vocab_size, Y, vocab_min, infile=tr, vocab_filename=vname)

reading in data...
0 read
10000 read
20000 read
30000 read
40000 read
building matrix
C.shape: (140795, 47723)
removing rare terms
51917 terms qualify out of 140795 total
calculating tf-idf scores
sorting to get top full
writing output
inds: [   0  863  851 ..., 1418  731  702]
sampling of the kept terms
['postoperative', 'disp', 'life', 'refills', 'q', 'valve', 'neg', 'infant', 'her', 'she']
[47.744687730138004, 48.046882977641545, 48.23472141945431, 49.0270146113522, 49.256788036045528, 53.04457515099444, 70.896451957724608, 84.779922783819572, 124.85710216984242, 180.57381166385363]


In [3]:
labeled = '%s/notes_full_labeled.csv' % DISCH_DIR
dfnl = pd.read_csv(labeled)
#Tokens and types
types = set()
num_tok = 0
for row in dfnl.itertuples():
    for w in row[3].split():
        types.add(w)
        num_tok += 1

In [5]:
#Add all the words from code descriptions too! Will use this for 
#4-gram similarity importance baseline, full-vocab and mimick-vocab description embedding model
with open('%s/raw/D_ICD_DIAGNOSES.csv' % DATA_DIR, 'r') as f:
    r = csv.reader(f)
    for row in r:
        for w in row[3].split():
            types.add(w)
with open('%s/raw/D_ICD_PROCEDURES.csv' % DATA_DIR, 'r') as f:
    r = csv.reader(f)
    for row in r:
        for w in row[3].split():
            types.add(w)

In [6]:
len(types)

156407

Let's also make a vocab file with all present words that we can train word2vec on, for later comparison

In [7]:
vname_unfiltered = '%s/vocab_unfiltered.csv' % DISCH_DIR
with open(vname_unfiltered, 'w') as of:
    for word in types:
        of.write(word + '\n')

In [11]:
losses = [6, 5, 4, 4, 5]

In [12]:
import numpy as np
np.argmin(losses) - len(losses)

-3

## Sort each data split by length for batching

In [32]:
for splt in ['train', 'dev', 'test']:
    filename = '%s/disch_%s_split.csv' % (DISCH_DIR, splt)
    df = pd.read_csv(filename)
    df['length'] = df.apply(lambda row: len(str(row['TEXT']).split()), axis=1)
    df = df.sort_values(['length'])
    df.to_csv('%s/%s_full.csv' % (DISCH_DIR, splt), index=False)

## Pre-train word embeddings

Let's train word embeddings on all words

In [33]:
reload(word_embeddings)
w2v_file = word_embeddings.word_embeddings('full', 'processed', '%s/disch_full.csv' % DISCH_DIR, 0, 5)

building word2vec vocab on processed data...
training...
writing embeddings to /nethome/jmullenbach3/mimicdata/disch/processed_full.w2v


## Write pre-trained word embeddings with new vocab

In [36]:
reload(extract_wvs)
reload(datasets)
extract_wvs.gensim_to_embeddings('%s/processed_full.w2v' % DISCH_DIR, '%s/vocab.csv' % DISCH_DIR, Y)

51917it [16:45, 51.64it/s] 


In [38]:
reload(extract_wvs)
extract_wvs.gensim_to_embeddings('%s/processed_full.w2v' % DISCH_DIR, '%s/vocab_unfiltered.csv' % DISCH_DIR, Y, outfile='%s/processed_full_unfiltered.embed' % DISCH_DIR)

150854it [2:01:22, 16.34it/s]


## Pre-process code descriptions using the vocab

In [33]:
reload(vocab_index_descriptions)
reload(datasets)
vocab_index_descriptions.vocab_index_descriptions('%s/vocab.csv' % DISCH_DIR)


0it [00:00, ?it/s][A
4it [00:00, 39.09it/s][A
9it [00:00, 41.51it/s][A
13it [00:00, 33.03it/s][A
16it [00:00, 22.81it/s][A
18it [00:00, 17.17it/s][A
20it [00:01, 12.05it/s][A
23it [00:01, 13.85it/s][A
26it [00:01, 15.58it/s][A
29it [00:01, 16.07it/s][A
32it [00:01, 16.80it/s][A
34it [00:01, 16.40it/s][A
37it [00:01, 18.82it/s][A
40it [00:02, 19.65it/s][A
43it [00:02, 21.07it/s][A
46it [00:02, 22.30it/s][A
49it [00:02, 21.28it/s][A
52it [00:02, 17.76it/s][A
54it [00:02, 17.28it/s][A
56it [00:02, 17.12it/s][A
59it [00:03, 19.16it/s][A
62it [00:03, 21.11it/s][A
65it [00:03, 21.32it/s][A
68it [00:03, 23.11it/s][A
71it [00:03, 19.91it/s][A
74it [00:03, 20.85it/s][A
77it [00:03, 18.14it/s][A
79it [00:04, 18.28it/s][A
81it [00:04, 17.42it/s][A
85it [00:04, 20.60it/s][A
93it [00:04, 25.87it/s][A
97it [00:04, 28.92it/s][A
101it [00:04, 26.02it/s][A
105it [00:04, 27.45it/s][A
109it [00:04, 26.61it/s][A
113it [00:05, 28.23it/s][A
117it [00:05, 28.54it/s][A

1

951it [00:44, 14.67it/s][A
954it [00:44, 17.13it/s][A
958it [00:44, 20.42it/s][A
962it [00:44, 23.29it/s][A
965it [00:44, 23.35it/s][A
969it [00:44, 24.54it/s][A
972it [00:45, 23.41it/s][A
977it [00:45, 27.59it/s][A
982it [00:45, 29.96it/s][A
986it [00:45, 26.68it/s][A
989it [00:45, 22.97it/s][A
992it [00:45, 16.32it/s][A
996it [00:46, 19.56it/s][A
1001it [00:46, 23.39it/s][A
1005it [00:46, 21.54it/s][A
1010it [00:46, 25.62it/s][A
1014it [00:46, 23.75it/s][A
1017it [00:46, 24.71it/s][A
1023it [00:46, 29.16it/s][A
1027it [00:47, 24.08it/s][A
1030it [00:47, 20.76it/s][A
1033it [00:47, 21.31it/s][A
1037it [00:47, 22.57it/s][A
1040it [00:47, 19.04it/s][A
1043it [00:48, 16.99it/s][A
1046it [00:48, 17.85it/s][A
1048it [00:48, 18.23it/s][A
1051it [00:48, 18.85it/s][A
1053it [00:48, 16.67it/s][A
1055it [00:48, 14.97it/s][A
1058it [00:48, 17.51it/s][A
1061it [00:49, 19.60it/s][A
1064it [00:49, 19.82it/s][A
1069it [00:49, 23.32it/s][A
1072it [00:49, 21.82it/s][

1866it [01:28, 21.28it/s][A
1870it [01:28, 22.04it/s][A
1873it [01:28, 16.32it/s][A
1875it [01:28, 17.25it/s][A
1877it [01:28, 16.00it/s][A
1880it [01:28, 17.16it/s][A
1883it [01:28, 18.79it/s][A
1886it [01:29, 19.14it/s][A
1889it [01:29, 19.90it/s][A
1893it [01:29, 22.89it/s][A
1897it [01:29, 25.48it/s][A
1901it [01:29, 27.97it/s][A
1905it [01:29, 27.02it/s][A
1908it [01:29, 23.70it/s][A
1911it [01:30, 22.34it/s][A
1914it [01:30, 23.76it/s][A
1920it [01:30, 28.58it/s][A
1925it [01:30, 32.08it/s][A
1929it [01:30, 29.03it/s][A
1934it [01:30, 31.89it/s][A
1938it [01:30, 32.18it/s][A
1942it [01:31, 24.45it/s][A
1945it [01:31, 20.32it/s][A
1950it [01:31, 23.02it/s][A
1954it [01:31, 26.26it/s][A
1958it [01:31, 24.20it/s][A
1961it [01:31, 22.29it/s][A
1964it [01:32, 20.45it/s][A
1967it [01:32, 19.10it/s][A
1970it [01:32, 19.98it/s][A
1974it [01:32, 23.15it/s][A
1977it [01:32, 21.16it/s][A
1980it [01:32, 22.78it/s][A
1983it [01:32, 21.59it/s][A
1986it [01:33,

3739it [02:56, 18.49it/s][A
3742it [02:57, 16.44it/s][A
3745it [02:57, 18.21it/s][A
3748it [02:57, 20.10it/s][A
3751it [02:57, 22.30it/s][A
3754it [02:57, 20.06it/s][A
3757it [02:57, 15.24it/s][A
3759it [02:58, 14.54it/s][A
3761it [02:58, 14.15it/s][A
3763it [02:58, 14.50it/s][A
3765it [02:58, 14.78it/s][A
3767it [02:58, 14.67it/s][A
3770it [02:58, 16.21it/s][A
3773it [02:58, 18.07it/s][A
3775it [02:59, 16.87it/s][A
3782it [02:59, 21.36it/s][A
3785it [02:59, 21.26it/s][A
3789it [02:59, 22.98it/s][A
3792it [02:59, 22.29it/s][A
3795it [02:59, 21.86it/s][A
3800it [02:59, 23.58it/s][A
3804it [03:00, 25.12it/s][A
3807it [03:00, 22.61it/s][A
3811it [03:00, 24.00it/s][A
3814it [03:00, 18.45it/s][A
3817it [03:00, 19.18it/s][A
3820it [03:00, 20.37it/s][A
3823it [03:01, 20.23it/s][A
3827it [03:01, 22.68it/s][A
3830it [03:01, 17.10it/s][A
3833it [03:01, 16.93it/s][A
3835it [03:01, 12.04it/s][A
3837it [03:02,  9.62it/s][A
3839it [03:02,  8.55it/s][A
3842it [03:02,

5631it [04:22, 21.15it/s][A
5635it [04:22, 23.19it/s][A
5640it [04:22, 26.92it/s][A
5644it [04:23, 25.15it/s][A
5647it [04:23, 22.62it/s][A
5650it [04:23, 23.24it/s][A
5653it [04:23, 22.56it/s][A
5656it [04:23, 22.02it/s][A
5661it [04:23, 24.96it/s][A
5664it [04:23, 22.83it/s][A
5667it [04:24, 23.63it/s][A
5670it [04:24, 23.52it/s][A
5673it [04:24, 17.01it/s][A
5676it [04:24, 14.78it/s][A
5678it [04:24, 11.95it/s][A
5686it [04:25, 15.75it/s][A
5689it [04:25, 17.82it/s][A
5692it [04:25, 19.73it/s][A
5695it [04:25, 21.53it/s][A
5699it [04:25, 23.19it/s][A
5702it [04:25, 23.19it/s][A
5705it [04:26, 16.54it/s][A
5708it [04:26, 15.42it/s][A
5710it [04:26, 14.83it/s][A
5713it [04:26, 17.17it/s][A
5716it [04:26, 17.42it/s][A
5719it [04:26, 18.12it/s][A
5722it [04:26, 19.32it/s][A
5725it [04:27, 20.85it/s][A
5728it [04:27, 21.01it/s][A
5731it [04:27, 20.01it/s][A
5734it [04:27, 19.57it/s][A
5738it [04:27, 21.73it/s][A
5742it [04:27, 23.81it/s][A
5745it [04:27,

7541it [05:48, 23.40it/s][A
7545it [05:48, 24.61it/s][A
7550it [05:48, 28.80it/s][A
7555it [05:49, 31.86it/s][A
7559it [05:49, 29.59it/s][A
7563it [05:49, 26.33it/s][A
7567it [05:49, 23.78it/s][A
7570it [05:49, 24.06it/s][A
7574it [05:49, 25.06it/s][A
7578it [05:49, 27.19it/s][A
7582it [05:50, 28.12it/s][A
7585it [05:50, 24.81it/s][A
7588it [05:50, 25.36it/s][A
7591it [05:50, 25.42it/s][A
7596it [05:50, 29.24it/s][A
7600it [05:50, 25.07it/s][A
7603it [05:50, 24.02it/s][A
7606it [05:51, 25.30it/s][A
7610it [05:51, 26.89it/s][A
7613it [05:51, 26.25it/s][A
7616it [05:51, 22.66it/s][A
7620it [05:51, 24.89it/s][A
7624it [05:51, 26.29it/s][A
7629it [05:51, 29.27it/s][A
7633it [05:51, 28.36it/s][A
7636it [05:52, 27.46it/s][A
7639it [05:52, 27.21it/s][A
7642it [05:52, 18.87it/s][A
7645it [05:52, 16.37it/s][A
7647it [05:52, 17.04it/s][A
7650it [05:52, 19.26it/s][A
7653it [05:53, 19.53it/s][A
7657it [05:53, 22.72it/s][A
7660it [05:53, 21.86it/s][A
7663it [05:53,

9399it [07:13, 17.48it/s][A
9401it [07:13, 17.17it/s][A
9403it [07:14, 16.11it/s][A
9407it [07:14, 18.66it/s][A
9410it [07:14, 17.50it/s][A
9412it [07:14, 15.88it/s][A
9414it [07:14, 15.14it/s][A
9416it [07:14, 13.00it/s][A
9418it [07:15, 13.48it/s][A
9420it [07:15, 14.09it/s][A
9422it [07:15, 13.65it/s][A
9424it [07:15, 12.42it/s][A
9426it [07:15, 13.32it/s][A
9428it [07:15, 12.02it/s][A
9430it [07:16, 12.59it/s][A
9432it [07:16, 13.88it/s][A
9434it [07:16, 13.42it/s][A
9437it [07:16, 15.60it/s][A
9440it [07:16, 17.96it/s][A
9443it [07:16, 20.24it/s][A
9447it [07:16, 22.02it/s][A
9450it [07:16, 21.32it/s][A
9455it [07:17, 25.27it/s][A
9458it [07:17, 21.63it/s][A
9461it [07:17, 18.18it/s][A
9464it [07:17, 17.26it/s][A
9467it [07:17, 16.12it/s][A
9469it [07:17, 15.00it/s][A
9474it [07:18, 18.71it/s][A
9478it [07:18, 20.94it/s][A
9481it [07:18, 17.82it/s][A
9485it [07:18, 18.29it/s][A
9488it [07:18, 18.65it/s][A
9491it [07:19, 15.27it/s][A
9493it [07:19,

11171it [08:38, 28.24it/s][A
11175it [08:38, 30.07it/s][A
11179it [08:38, 28.41it/s][A
11183it [08:38, 29.16it/s][A
11187it [08:39, 29.12it/s][A
11191it [08:39, 26.50it/s][A
11195it [08:39, 27.80it/s][A
11198it [08:39, 27.64it/s][A
11201it [08:39, 25.71it/s][A
11204it [08:39, 25.17it/s][A
11207it [08:39, 23.85it/s][A
11210it [08:40, 24.72it/s][A
11213it [08:40, 22.69it/s][A
11217it [08:40, 24.33it/s][A
11220it [08:40, 22.03it/s][A
11226it [08:40, 26.42it/s][A
11230it [08:40, 27.83it/s][A
11234it [08:41, 21.77it/s][A
11237it [08:41, 17.28it/s][A
11240it [08:41, 17.03it/s][A
11243it [08:41, 18.05it/s][A
11246it [08:41, 16.56it/s][A
11250it [08:41, 18.70it/s][A
11253it [08:42, 19.94it/s][A
11258it [08:42, 23.29it/s][A
11262it [08:42, 26.02it/s][A
11266it [08:42, 26.45it/s][A
11269it [08:42, 25.22it/s][A
11272it [08:42, 25.73it/s][A
11275it [08:42, 23.18it/s][A
11278it [08:43, 22.35it/s][A
11281it [08:43, 23.24it/s][A
11284it [08:43, 24.61it/s][A
11287it [0

13018it [10:02, 28.54it/s][A
13022it [10:02, 22.77it/s][A
13025it [10:02, 21.81it/s][A
13028it [10:02, 23.34it/s][A
13031it [10:02, 21.74it/s][A
13035it [10:02, 24.57it/s][A
13039it [10:03, 24.77it/s][A
13042it [10:03, 18.94it/s][A
13045it [10:03, 18.08it/s][A
13050it [10:03, 21.75it/s][A
13053it [10:03, 23.10it/s][A
13057it [10:03, 25.51it/s][A
13061it [10:04, 27.37it/s][A
13065it [10:04, 28.28it/s][A
13069it [10:04, 26.08it/s][A
13074it [10:04, 30.07it/s][A
13078it [10:04, 28.52it/s][A
13083it [10:04, 31.10it/s][A
13087it [10:04, 31.20it/s][A
13091it [10:04, 30.70it/s][A
13095it [10:05, 32.31it/s][A
13099it [10:05, 25.92it/s][A
13102it [10:05, 21.44it/s][A
13106it [10:05, 22.84it/s][A
13109it [10:05, 18.32it/s][A
13112it [10:06, 16.77it/s][A
13115it [10:06, 18.42it/s][A
13118it [10:06, 18.69it/s][A
13122it [10:06, 18.69it/s][A
13124it [10:06, 16.71it/s][A
13129it [10:06, 19.84it/s][A
13133it [10:07, 22.02it/s][A
13136it [10:07, 19.13it/s][A
13139it [1

14823it [11:25, 20.96it/s][A
14826it [11:26, 20.70it/s][A
14830it [11:26, 22.34it/s][A
14835it [11:26, 26.59it/s][A
14839it [11:26, 28.65it/s][A
14843it [11:26, 29.00it/s][A
14847it [11:26, 26.10it/s][A
14850it [11:26, 25.40it/s][A
14853it [11:27, 18.38it/s][A
14856it [11:27, 15.72it/s][A
14858it [11:27, 12.95it/s][A
14860it [11:27, 12.99it/s][A
14862it [11:27, 12.71it/s][A
14864it [11:28, 14.14it/s][A
14866it [11:28, 15.22it/s][A
14869it [11:28, 16.96it/s][A
14872it [11:28, 18.32it/s][A
14874it [11:28, 16.72it/s][A
14876it [11:28, 15.92it/s][A
14878it [11:28, 13.17it/s][A
14880it [11:29, 14.11it/s][A
14883it [11:29, 15.95it/s][A
14886it [11:29, 18.51it/s][A
14889it [11:29, 20.59it/s][A
14892it [11:29, 21.67it/s][A
14895it [11:29, 22.16it/s][A
14898it [11:29, 21.21it/s][A
14901it [11:29, 22.36it/s][A
14904it [11:30, 22.24it/s][A
14907it [11:30, 21.49it/s][A
14910it [11:30, 23.12it/s][A
14913it [11:30, 23.99it/s][A
14916it [11:30, 25.11it/s][A
14919it [1

16627it [12:49, 24.69it/s][A
16630it [12:49, 25.52it/s][A
16633it [12:49, 23.79it/s][A
16636it [12:49, 23.64it/s][A
16639it [12:49, 23.89it/s][A
16642it [12:49, 23.98it/s][A
16645it [12:49, 24.83it/s][A
16648it [12:50, 24.16it/s][A
16651it [12:50, 24.40it/s][A
16654it [12:50, 22.59it/s][A
16658it [12:50, 24.18it/s][A
16661it [12:50, 24.95it/s][A
16664it [12:50, 24.99it/s][A
16667it [12:50, 18.29it/s][A
16670it [12:51, 16.03it/s][A
16672it [12:51, 16.78it/s][A
16677it [12:51, 19.67it/s][A
16680it [12:51, 20.11it/s][A
16683it [12:51, 17.45it/s][A
16685it [12:51, 16.95it/s][A
16687it [12:52, 16.13it/s][A
16689it [12:52, 15.00it/s][A
16691it [12:52, 12.32it/s][A
16693it [12:52, 12.37it/s][A
16695it [12:52, 12.61it/s][A
16698it [12:52, 13.98it/s][A
16700it [12:53, 14.82it/s][A
16703it [12:53, 16.16it/s][A
16705it [12:53, 15.94it/s][A
16707it [12:53, 15.80it/s][A
16710it [12:53, 16.36it/s][A
16713it [12:53, 17.15it/s][A
16715it [12:53, 17.89it/s][A
16718it [1

18392it [14:12, 22.41it/s][A
18395it [14:12, 21.84it/s][A
18398it [14:12, 21.43it/s][A
18401it [14:12, 18.69it/s][A
18406it [14:12, 22.59it/s][A
18410it [14:12, 22.28it/s][A
18413it [14:12, 22.78it/s][A
18416it [14:13, 21.59it/s][A
18419it [14:13, 19.87it/s][A
18422it [14:13, 19.79it/s][A
18424it [14:13, 21.59it/s][A

In [5]:
reload(vocab_index_descriptions)
reload(datasets)
vocab_index_descriptions.vocab_index_descriptions('%s/vocab_unfiltered.csv' % DISCH_DIR, '%s/description_vectors_unfiltered.vocab' % DATA_DIR)

18424it [37:55,  8.10it/s]


## Filter each split to the top K diagnosis/procedure codes

In [2]:
#first calculate the top k
counts = Counter()
dfnl = pd.read_csv('%s/notes_full_labeled.csv' % DISCH_DIR)
for row in dfnl.itertuples():
    for label in str(row[4]).split(';'):
        counts[label] += 1

In [3]:
counts

Counter({'360.03': 2,
         '360.02': 1,
         '360.01': 5,
         '289.51': 19,
         '289.59': 114,
         '801.82': 3,
         '801.80': 2,
         '896.0': 1,
         '896.1': 1,
         '289.52': 3,
         '161.9': 10,
         '288.50': 139,
         '525.79': 1,
         '852.26': 92,
         '852.21': 363,
         '852.20': 239,
         '852.23': 1,
         '852.22': 77,
         '161.1': 6,
         '161.0': 6,
         '161.3': 2,
         '161.2': 3,
         '852.29': 4,
         '153.2': 16,
         '153.3': 36,
         '153.0': 7,
         '153.1': 15,
         '153.6': 37,
         '153.7': 8,
         '153.4': 29,
         '153.5': 4,
         '917.2': 5,
         '917.3': 1,
         '153.8': 31,
         '917.1': 2,
         '345.41': 23,
         '45.9': 1,
         '45.8': 55,
         '45.3': 14,
         '45.0': 877,
         '421.9': 8,
         '421.0': 380,
         '17.33': 6,
         '17.36': 7,
         '17.35': 3,
         '301.83'

In [4]:
proc_5k = set()
diag_5k = set()
i = 0
cnts = sorted(counts.items(), key=operator.itemgetter(1), reverse=True)
while len(proc_5k) < 1000 or len(diag_5k) < 4000:
    code = cnts[i][0]                              
    if '.' in code:             
        if code.index('.') == 3 and len(diag_5k) < 4000:
            diag_5k.add(code)
        elif code.index('.') == 2 and len(proc_5k) < 1000:
             proc_5k.add(code)
    else:                                                                                                                            
        if len(code) == 3 and len(diag_5k) < 4000:                                                                                   
            diag_5k.add(code)
        elif len(code) == 2 and len(proc_5k) < 1000:
            proc_5k.add(code)
    i += 1


In [5]:
#how many of each code are there?
num_diag = 0
num_proc = 0
proc = set()
for code, count in cnts:
    if '.' in code:
        if code.index('.') == 3 or code[0] in ['V', 'E']:
            num_diag += 1
        elif code.index('.') == 2:
            num_proc += 1
            proc.add(code)
    else:
        if len(code) == 3 or code[0] in ['V', 'E']:
            num_diag += 1
        elif len(code) == 2:
            num_proc += 1
            proc.add(code)

In [6]:
print(num_proc, num_diag)

(2003, 6919)


In [7]:
codes_5k = proc_5k.union(diag_5k)
with open('%s/TOP_5000_CODES.csv' % DISCH_DIR, 'w') as of:
    w = csv.writer(of)
    for code in codes_5k:
        w.writerow([code])

In [4]:
codes_50 = sorted(counts.items(), key=operator.itemgetter(1), reverse=True)

In [6]:
codes_50 = [code[0] for code in codes_50[:50]]

In [7]:
codes_50

['401.9',
 '38.93',
 '428.0',
 '427.31',
 '414.01',
 '96.04',
 '96.6',
 '584.9',
 '250.00',
 '96.71',
 '272.4',
 '518.81',
 '99.04',
 '39.61',
 '599.0',
 '530.81',
 '96.72',
 '272.0',
 '285.9',
 '88.56',
 '244.9',
 '486',
 '38.91',
 '285.1',
 '36.15',
 '276.2',
 '496',
 '99.15',
 '995.92',
 'V58.61',
 '507.0',
 '038.9',
 '88.72',
 '585.9',
 '403.90',
 '311',
 '305.1',
 '37.22',
 '412',
 '33.24',
 '39.95',
 '287.5',
 '410.71',
 '276.1',
 'V45.81',
 '424.0',
 '45.13',
 'V15.82',
 '511.9',
 '37.23']

In [8]:
with open('%s/TOP_50_CODES.csv' % DISCH_DIR, 'w') as of:
    w = csv.writer(of)
    for code in codes_50:
        w.writerow([code])

In [9]:
for splt in ['train', 'dev', 'test']:
    print(splt)
    with open('%s/%s_full.csv' % (DISCH_DIR, splt), 'r') as f:
        with open('%s/%s_50.csv' % (DISCH_DIR, splt), 'w') as of:
            r = csv.reader(f)
            w = csv.writer(of)
            #header
            w.writerow(next(r))
            i = 0
            for row in r:
#                 if splt == 'train' and i >= 36998:
#                     break
                codes = set(str(row[3]).split(';'))
                filtered_codes = codes.intersection(set(codes_50))
                if len(filtered_codes) > 0:
                    w.writerow(row[:3] + [';'.join(filtered_codes), row[4]])
                    i += 1

train
dev
test


In [14]:
df = pd.read_csv('%s/train_50.csv' % DISCH_DIR)

In [18]:
df = df.sample(frac=1)

In [27]:
for row in df.itertuples():
    print(row[2])

141121
134677
169066
113918
163301
122555
170502
190054
102641
157253
129578
161239
192006
139742
184266
119156
123036
146794
183668
166101
192707
103597
141254
194451
140382
130064
196050
134850
102783
113759
124463
194531
112680
142983
163230
135819
178435
118787
120032
141586
105445
187570
191070
122724
133936
124518
126054
129627
100445
181979
154591
174955
109080
174618
199483
197492
193191
179358
125898
173546
107153
145974
116604
160191
160863
123476
136227
160250
117669
127899
167243
192333
144959
100738
175324
169236
175892
135621
182444
135894
139888
150908
194440
103252
133641
123997
116110
175057
147304
172002
177983
122269
188025
144362
102889
125868
175955
184708
126881
149261
135714
196143
135859
121123
133449
122468
188803
136724
138528
187776
188732
187112
147717
172145
143701
189077
144831
184159
177735
147496
188297
114997
178691
118115
155360
101763
122585
136748
172959
135473
121710
129505
161398
182907
178271
130856
123745
171866
152593
170218
167553
173317
140498

192351
177026
103932
106229
190017
173465
192568
106770
136302
130565
146495
105279
146586
155493
171658
175760
103195
169041
148681
187281
185570
126744
147559
134189
147569
166507
151911
161071
181546
163133
199844
172567
106452
182705
187798
194969
149744
167883
134366
174440
155563
178914
175931
160658
136571
146468
194762
136694
109931
173517
134640
120222
181578
134624
190139
171997
188239
136572
157159
175150
175518
164690
120044
177445
148370
177529
199385
178481
136527
122857
129831
168130
188722
158283
116284
174504
111118
141062
157915
125302
119754
187804
175745
153265
175298
169690
192925
141701
137846
100811
191278
120702
150015
126020
117641
142325
156471
199593
141243
171920
115846
164341
150633
125596
191544
169220
192555
129828
134521
184834
183330
197042
121399
167280
117041
137107
171038
165655
111867
104594
195632
185989
153677
187735
157053
162560
113492
134350
172469
155000
161627
181716
103563
179998
139862
127415
147385
152548
131203
117993
103565
139617
105817

142290
191929
146962
161951
140497
178464
130176
151459
194386
176583
164846
139731
153733
177836
155525
175241
167090
135435
140212
146277
106153
144770
137140
116122
182546
196822
107917
130822
163740
196850
166154
198053
193335
172524
159865
119210
158369
147697
167420
107492
161558
198554
118689
149688
159785
100402
151618
176519
148557
129730
163511
143672
118722
129421
156697
140628
100211
112328
143406
174134
146788
179644
126813
160172
122192
159143
111319
131145
102234
177739
116367
118896
181403
129909
154357
116681
107075
109028
116501
191920
103005
137965
114105
134436
102339
146418
152667
174066
153143
164379
119061
162902
185392
171946
195147
114725
168629
179237
152440
134538
123859
105918
143660
197166
179603
172527
189751
125606
191424
123368
132056
135286
173520
149818
147794
148589
126390
178489
117467
190537
140840
110621
174578
114027
160904
121039
158006
107093
150402
165554
129887
110745
137235
140266
104850
143578
126406
126947
140252
126319
101994
149260
147567

182352
128120
147259
155072
119687
194254
182236
140753
199466
108045
139977
103819
123772
102513
127841
149884
165287
149256
149868
186293
150902
193024
194047
169184
117549
125023
118492
113567
113254
129564
175791
156530
105763
190598
113936
170023
141167
145764
169827
177278
162999
162426
165821
153722
101005
143224
139189
152237
135175
185691
193573
134128
147783
186746
171476
135402
195219
110739
165532
142459
130436
151192
193604
195432
147241
102898
111341
172919
149462
127959
138939
106827
185376
117745
105677
131003
116095
102912
106296
190281
117632
107534
120810
196460
131768
153652
153335
181185
100804
102124
179377
195996
186328
127785
152199
190183
198770
128643
119133
126688
100764
141014
177817
184591
159236
117413
130140
126330
199943
170026
180827
187852
101046
185952
127259
175905
132864
112403
180977
106319
138229
100665
171768
193636
146086
180701
114834
194894
172299
150225
137073
113312
190313
179653
158461
159459
166275
174646
178299
130416
176657
102010
156232

170280
166538
142953
147889
169312
107212
124628
175544
188404
154822
175004
176974
132745
127724
111092
123369
143298
155477
185356
133906
134182
108746
190470
108813
158469
180667
175438
121613
100257
141864
178330
178918
197101
175063
164976
184712
114318
168075
187527
123818
140282
143664
109145
173436
166377
145143
164088
171221
124082
175538
185061
116576
162299
160511
166336
134391
122216
194998
128224
119670
158939
118874
179533
115229
161234
149481
117319
100391
178447
199487
122163
181072
183314
188471
148452
192135
102195
114751
116864
101077
134304
100870
136948
183508
172332
136841
150534
104683
140803
178446
104158
183487
110636
155553
162085
108153
162867
190446
106108
143690
152652
153202
110395
189720
174216
138299
151775
115664
168046
113249
128878
182657
184117
128487
118760
124071
183233
173943
104898
123420
101092
192507
151985
195684
110445
171445
112399
128907
145967
122980
111465
161985
123235
142533
199482
123931
154558
154221
195716
134967
140608
145728
168067

KeyboardInterrupt: 

In [34]:
Y = 50
#make smaller datasets
for splt in ['train', 'dev', 'test']:
    print(splt)
    df = pd.read_csv('%s/%s_%s.csv' % (DISCH_DIR, splt, str(Y)))
    df = df.sample(frac=1)
    with open('%s/%s_%s_smaller.csv' % (DISCH_DIR, splt, str(Y)), 'w') as of:
        w = csv.writer(of)
        #header
        w.writerow(['SUBJECT_ID', 'HADM_ID', 'TEXT', 'LABELS', 'length'])
        i = 0
        for row in df.itertuples():
            if (splt == 'train' and i >= 8066) \
                or (splt == 'dev' and i >= 1728) \
                or (splt == 'test' and i >= 1729):
                break
            codes = set(str(row[4]).split(';'))
            filtered_codes = codes.intersection(set(codes_50))
            if len(filtered_codes) > 0:
                w.writerow([row[1], row[2], row[3], ';'.join(filtered_codes), row[5]])
                i += 1

train
dev
test


In [35]:
#sort them
for splt in ['train', 'dev', 'test']:
    print(splt)
    df = pd.read_csv('%s/%s_%s_smaller.csv' % (DISCH_DIR, splt, str(Y)))
    df = df.sort_values(['length'])
    df.to_csv('%s/%s_%s_smaller.csv' % (DISCH_DIR, splt, str(Y)), index=False)

train
dev
test


# Bucket codes by their frequency into a lookup

In [144]:
reload(datasets)
#redo the code lookups IF model was rare-label
code_freqs, n = datasets.load_code_freqs('%s/train_full.csv' % DISCH_DIR)
freqs = sorted(code_freqs.iteritems(), key=operator.itemgetter(1), reverse=True)

In [145]:
fvs = np.array([f[1] for f in freqs])

In [146]:
num_bins = 11

In [147]:
deciles = np.unique([np.percentile(fvs, (i+1)*(100/num_bins)) for i in range(num_bins)])

In [148]:
deciles

array([  1.89659750e-05,   3.79319501e-05,   5.68979251e-05,
         9.48298752e-05,   1.32761825e-04,   2.46557676e-04,
         4.36217426e-04,   9.29332777e-04,   2.63627053e-03,
         3.10852331e-02])

In [149]:
code_freq_bins = defaultdict(int)
for code, freq in code_freqs.iteritems():
    for i,dec in enumerate(deciles):
        if freq <= dec:
            code_freq_bins[code] = i
            break

In [150]:
bins = np.array(code_freq_bins.values())

In [151]:
[len(np.where(bins == i)[0]) for i in range(len(deciles))]

[1900, 1013, 607, 818, 485, 894, 738, 785, 796, 795]

In [152]:
with open('%s/code_freq_bins.csv', 'w') as of:
    w = csv.writer(of)
    for code, bin in code_freq_bins.iteritems():


defaultdict(int,
            {'289.52': 2,
             '360.02': 0,
             '289.50': 4,
             '360.00': 5,
             '289.59': 8,
             '801.82': 2,
             '757.5': 0,
             '801.80': 1,
             '896.0': 0,
             '896.1': 0,
             '360.03': 1,
             '852.25': 7,
             '852.24': 3,
             '525.79': 0,
             '852.26': 8,
             '852.21': 9,
             '852.20': 9,
             '852.23': 0,
             '852.22': 8,
             '161.1': 4,
             '161.0': 4,
             '161.3': 1,
             '289.51': 6,
             '852.29': 3,
             '153.2': 6,
             '153.3': 7,
             '153.0': 4,
             '153.1': 6,
             '153.6': 7,
             '153.7': 5,
             '153.4': 7,
             '153.5': 3,
             '917.2': 3,
             '917.3': 0,
             '153.8': 7,
             '917.1': 1,
             '345.41': 6,
             '45.9': 0,
             '4

## Augment each data split with its ancestor codes

Basically, we have two training regimes: 
- one with only leaf codes (ground truth codes are leaf codes only in MIMIC)
- one with codes augmented with all their ancestors

So we'll have two sets of codes. The stuff below will augment the leaf-only codes with the ancestor codes

In [69]:
from dataproc import format_test_set

In [70]:
reload(format_test_set)
ancestors = format_test_set.construct_ancestors()
_, _, _, c2ind, _, _ = datasets.load_lookups(Y='full', vocab_file='%s/vocab.csv' % DISCH_DIR)

constructing ancestors dictionary


In [127]:
ancestors

defaultdict(<function dataproc.format_test_set.<lambda>>,
            {'360.04': array(['360.0', '360', '360-379.99', '320-389.99', '001-999.99', '@',
                    '360.04'], 
                   dtype='|S10'),
             '360.03': array(['360.0', '360', '360-379.99', '320-389.99', '001-999.99', '@',
                    '360.03'], 
                   dtype='|S10'),
             '360.02': array(['360.0', '360', '360-379.99', '320-389.99', '001-999.99', '@',
                    '360.02'], 
                   dtype='|S10'),
             '360.01': array(['360.0', '360', '360-379.99', '320-389.99', '001-999.99', '@',
                    '360.01'], 
                   dtype='|S10'),
             '360.00': array(['360.0', '360', '360-379.99', '320-389.99', '001-999.99', '@',
                    '360.00'], 
                   dtype='|S10'),
             'E802.1': array(['E802', 'E800-E807.9', 'E800-E999.9', '@', 'E802.1'], 
                   dtype='|S11'),
             'E802.0': array

In [129]:
reload(format_test_set)
for split in ['train', 'dev', 'test']:
    print(split)
    infile = '%s/%s_full_disch.csv' % (DISCH_DIR, split)
    format_test_set.write_codes(ancestors, c2ind, infile)

16it [00:00, 142.45it/s]

train
writing dataset of indexified codes with ancestors
('leaf not in ancestors:', '11.')


47723it [07:20, 53.29it/s] 
13it [00:00, 127.32it/s]

num codes with just root code as ancestor: 366
dev
writing dataset of indexified codes with ancestors
('leaf not in ancestors:', '998.30')
('leaf not in ancestors:', '707.24')
('leaf not in ancestors:', '348.82')
('leaf not in ancestors:', 'V49.86')
('leaf not in ancestors:', '530.13')


1631it [00:17, 95.01it/s] 
13it [00:00, 129.33it/s]

num codes with just root code as ancestor: 150
test
writing dataset of indexified codes with ancestors
('leaf not in ancestors:', 'V49.86')
('leaf not in ancestors:', '32.')
('leaf not in ancestors:', '23.9')
('leaf not in ancestors:', 'V87.41')
('leaf not in ancestors:', 'V87.41')


3372it [00:36, 92.29it/s] 


num codes with just root code as ancestor: 210


In [44]:
reload(datasets)
ind2c, desc_dict = datasets.load_full_codes(True)

In [31]:
c2ind = {c:i for i,c in ind2c.iteritems()}

In [51]:
c2ind

{'289.52': 1,
 '360.02': 2,
 '289.50': 3,
 '360.00': 4,
 '289.59': 6,
 '801.82': 7,
 '801.80': 8,
 '896.0': 9,
 '896.1': 10,
 '360.03': 11,
 '852.25': 12,
 '288.50': 13,
 '525.79': 14,
 '852.26': 15,
 '852.21': 16,
 '852.20': 17,
 '852.23': 18,
 '852.22': 19,
 '161.1': 20,
 '161.0': 21,
 '161.3': 22,
 '289.51': 23,
 '852.29': 24,
 '153.2': 25,
 '153.3': 26,
 '153.0': 27,
 '153.1': 28,
 '153.6': 29,
 '153.7': 30,
 '153.4': 31,
 '153.5': 32,
 '917.2': 33,
 '917.3': 34,
 '153.8': 35,
 '917.1': 36,
 '345.41': 8019,
 '842.0': 37,
 '421.9': 38,
 '421.0': 39,
 '301.83': 41,
 '301.81': 42,
 '270': 43,
 '271': 44,
 '272': 45,
 '429.3': 46,
 '429.4': 251,
 '275': 48,
 '429.6': 256,
 '429.7': 50,
 '278': 51,
 '429.9': 52,
 'V15.06': 53,
 'V15.07': 54,
 'V15.01': 55,
 'V15.02': 56,
 '770.18': 58,
 '360.01': 59,
 '770.12': 60,
 '770.16': 61,
 'E935.2': 62,
 'E935.3': 316,
 'E935.1': 64,
 'E935.6': 65,
 'E935.7': 66,
 'E935.4': 67,
 '852.24': 68,
 'E935.8': 69,
 'E935.9': 70,
 '434.10': 71,
 '434.11

In [53]:
import hierarchical_eval

In [54]:
model_dir = '../../saved_models/conv_attn_Oct_09_06:33'
preds_file = '%s/preds_dev.csv' % model_dir

In [91]:
reload(hierarchical_eval)
y_true, y_true_np, ancestors, new_preds, TD = hierarchical_eval.load_stuff('%s/dev_full_disch_2.csv' % (DISCH_DIR), \
                                                                           '%s/dev_full_disch_2_ancs.csv' % DISCH_DIR, \
                                                                           '%s/heval_data/ICD9_parent_child_relations' % DISCH_DIR, \
                                                                           preds_file, model_dir, 8867, 6739, 8049, c2ind)

In [81]:
predictions = []
with open(preds_file) as f:
    for line in f:
        line = line.strip().split('|')
        if line[1] != '':
            predictions.append([int(x) for x in line[1:]])
        else:
            predictions.append([])

In [85]:
import numpy as np
root = 6739
pred_array = []
for i in range(TD):
    try:
        pred_array.append(np.array(list(predictions[i])))
    except:
        pred_array.append(np.array([root]))
new_preds = []
for i in range(len(pred_array)):
    preds = set(list(pred_array[i]))
    anc_preds = set([])
    for pred in preds:
        anc_preds.update(list(ancestors[pred]))
    new_preds.append(list(anc_preds))


In [98]:
y_true[0].nonzero()[0]

array([  49,   78,  534, 1313, 1449, 1860, 2289, 2531, 3336, 3391, 4121,
       4177, 4706, 5037, 5174, 5945, 5949, 6123, 6221, 6514, 6739, 7224,
       7227, 7436, 8659])

In [105]:
reload(hierarchical_eval)
print(new_preds[0])
pred = set(hierarchical_eval.only_leaves(new_preds[0], ancestors))

['001-999.99', '@', '335', '520-579.99', '383.3', '533.6', '383', '533', '335.0', '380-389.99', '320-389.99', '530-538', '330-337.99']


In [107]:
gs = set(y_true_np[0].nonzero()[0])

In [108]:
gs

{78, 2531, 4177, 4706, 5037}