In [None]:
# default_exp preprocessing.clean

In [None]:
#all_slow

# Clean

> Functions to split the raw EHR dataset, clean it, and save it for further processing and vocab creation.

In [None]:
#hide
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:85% !important; }</style>"))

In [None]:
#hide
%reload_ext autoreload
%autoreload 2

In [None]:
#export
from lemonpie.basics import *
from lemonpie.preprocessing import clean
from fastai.imports import *
import ray

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
ray.init()

2022-09-20 13:43:38,393	INFO services.py:1245 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m


{'node_ip_address': '192.168.86.91',
 'raylet_ip_address': '192.168.86.91',
 'redis_address': '192.168.86.91:6379',
 'object_store_address': '/tmp/ray/session_2022-09-20_13-43-36_807789_9927/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2022-09-20_13-43-36_807789_9927/sockets/raylet',
 'webui_url': '127.0.0.1:8265',
 'session_dir': '/tmp/ray/session_2022-09-20_13-43-36_807789_9927',
 'metrics_export_port': 59579,
 'node_id': '76cb48652fa309ae956b87798a80ffb0fddcc57927b3d8b43eee85c7'}

In [None]:
# Root directory of the Coherent dataset. Defaults to the original local path;
# set the COHERENT_DATA_STORE environment variable to point somewhere else
# without editing the notebook.
COHERENT_DATA_STORE = os.environ.get('COHERENT_DATA_STORE', '/home/vinod/code/datasets/coherent')
# Passed to the cleaning / preprocessing calls below as the data-generation date.
COHERENT_DATAGEN_DATE = '08-10-2021'
# Label name -> condition code (kept as strings). `test_extract_ys` below looks
# these codes up in the conditions dataframe as "<code>||START".
COHERENT_CONDITIONS = {
    "heart_failure" : "88805009",
    "coronary_heart" : "53741008",
    "myocardial_infarction" : "22298006",
    "stroke" : "230690007",
    "cardiac_arrest" : "410429000"
}

# Coherent Preprocessing

**Retain only patients with FHIR bundles.**

In [None]:
def retain_fhir_patients(coherent_path, csv_names):
    """Retain only patients with FHIR bundles.

    Reads every ``{coherent_path}/output/csv/{name}.csv`` listed in
    ``csv_names``, keeps the rows whose patient id has a file in
    ``{coherent_path}/output/fhir``, and writes the filtered frame to
    ``{coherent_path}/raw_original/{name}.csv``.

    Args:
        coherent_path: Root directory of the Coherent dataset.
        csv_names: Iterable of csv base names (without ``.csv``). The
            ``patients`` file is filtered on its ``Id`` column, every other
            file on its ``PATIENT`` column.
    """

    # The patient id is the last "_"-separated token of the bundle file name,
    # with everything from the first "." onwards removed.
    fhir_pids = {
        fname.split("_")[-1].split(".")[0]
        for fname in os.listdir(f'{coherent_path}/output/fhir')
    }

    # Create the destination up front so this function also works when it is
    # called on its own (previously it relied on the caller to create it).
    out_dir = f"{coherent_path}/raw_original"
    os.makedirs(out_dir, exist_ok=True)

    # filter and retain only FHIR patients in all files
    print(f"Writing filtered files to {coherent_path}/raw_original/")
    for file in csv_names:
        old_df = pd.read_csv(f"{coherent_path}/output/csv/{file}.csv", low_memory=False)
        id_col = "Id" if file == 'patients' else "PATIENT"
        fhir_mask = old_df[id_col].isin(fhir_pids)
        new_df = old_df[fhir_mask]
        assert len(new_df) == fhir_mask.sum(), f"Count error in {file}"
        new_df.to_csv(f"{out_dir}/{file}.csv", index=False)
        print(f"Created {file} with {len(new_df)} records.")
    

**Remove ECG from observations and create ecg.csv**

In [None]:
def moveout_ecg(coherent_path, ecg_code="29303009"):
    """Move ECG data out of Observations into its own csv.

    Rewrites ``{coherent_path}/raw_original/observations.csv`` without the
    rows whose ``CODE`` equals ``ecg_code`` and saves a reduced copy of those
    rows to ``{coherent_path}/ecg.csv``.

    Args:
        coherent_path: Root directory of the Coherent dataset.
        ecg_code: Observation code (string) identifying ECG rows. Defaults to
            the code this notebook has always used.

    Raises:
        AssertionError: If the row counts do not add up after the removal.
        KeyError: If an expected column is missing from observations.csv.
    """

    obs_path = f"{coherent_path}/raw_original/observations.csv"
    # Read CODE as text: the comparison below is against a string, and with an
    # inferred dtype a file whose codes are all numeric would match nothing.
    old_obs = pd.read_csv(obs_path, low_memory=False, dtype={"CODE": str})

    ecg_rows = old_obs[old_obs["CODE"] == ecg_code]
    new_obs = old_obs.drop(ecg_rows.index)
    assert len(new_obs) == len(old_obs) - len(ecg_rows), "Mismatch after ECG removal from Observations"
    new_obs.to_csv(obs_path, index=False)
    print(f"Updated observations without ECG data = {len(new_obs)} records")

    # Build the ECG frame without mutating a slice of `old_obs` in place.
    ecg_obs = (
        ecg_rows
        # Keep every other ECG row (positions 0, 2, 4, ...).
        # NOTE(review): presumably each ECG reading is stored as two
        # consecutive rows -- confirm against the raw data.
        .iloc[::2]
        .drop(columns=["ENCOUNTER", "CODE", "DESCRIPTION", "UNITS", "TYPE"])
        .rename(str.lower, axis='columns')
    )
    ecg_obs.to_csv(f"{coherent_path}/ecg.csv", index=False)
    print(f"Saved ECG data to {coherent_path}/ecg.csv with {len(ecg_obs)} records")

**Create `modalities.csv`**

In [None]:
def create_modalities_csv(coherent_path):
    """Write ``modalities.csv``: one row per patient with mri/dna/ecg flags.

    Patient ids are collected from the file names under ``output/dna`` and
    ``output/dicom`` and from the ``patient`` column of ``ecg.csv``; each
    patient in ``raw_original/patients.csv`` is then flagged for membership
    in each of those id collections.
    """

    # dna -- id is the second-to-last "_" token of the file name.
    # Counts are off by 1, because there is no FHIR bundle for 1 patient with dna data.
    dna_pids = [name.split("_")[-2] for name in os.listdir(f'{coherent_path}/output/dna')]

    # mri -- last "_" token, cut at the first ".", with its final character dropped.
    mri_pids = []
    for name in os.listdir(f'{coherent_path}/output/dicom'):
        stem = name.split("_")[-1].split(".")[0]
        mri_pids.append(stem[:-1])

    # ecg -- distinct patient ids present in the extracted ECG file.
    ecg_pids = pd.read_csv(f"{coherent_path}/ecg.csv").patient.unique()

    # Identifying columns (lower-cased) plus one boolean column per modality.
    patients = pd.read_csv(f"{coherent_path}/raw_original/patients.csv", low_memory=False)
    modalities = (
        patients[["Id", "FIRST", "LAST"]]
        .rename(str.lower, axis='columns')
        .assign(
            mri=patients.Id.isin(mri_pids),
            dna=patients.Id.isin(dna_pids),
            ecg=patients.Id.isin(ecg_pids),
        )
    )

    modalities.to_csv(f"{coherent_path}/modalities.csv", index=False)
    print(f"Saved modalities to {coherent_path}/modalities.csv")

In [None]:
def coherent_preprocess(coherent_path=COHERENT_DATA_STORE, csv_names=FILENAMES):
    """Run the Coherent-specific preprocessing steps, in order.

    1. Create ``{coherent_path}/raw_original`` if needed.
    2. Filter every csv in ``csv_names`` down to patients with FHIR bundles.
    3. Split the ECG rows out of observations into ``ecg.csv``.
    4. Write ``modalities.csv``.
    """

    # Destination directory for the filtered csv files.
    Path(f'{coherent_path}/raw_original').mkdir(parents=True, exist_ok=True)

    # Each step is announced with its banner and then executed; order matters
    # because later steps read files written by earlier ones.
    pipeline = (
        ("--Filtering & retaining patients with FHIR bundles--",
         lambda: retain_fhir_patients(coherent_path, csv_names)),
        ("--Moving ECG data out of observations into its own ecg.csv--",
         lambda: moveout_ecg(coherent_path)),
        ("--Creating modalities.csv--",
         lambda: create_modalities_csv(coherent_path)),
    )
    for banner, run_step in pipeline:
        print(banner)
        run_step()

In [None]:
coherent_preprocess()

--Filtering & retaining patients with FHIR bundles--
Writing filtered files to /home/vinod/code/datasets/coherent/raw_original/
Created patients with 1278 records.
Created observations with 705436 records.
Created allergies with 106 records.
Created careplans with 6135 records.
Created medications with 209401 records.
Created imaging_studies with 3752 records.
Created procedures with 56092 records.
Created conditions with 15956 records.
Created immunizations with 11900 records.
--Moving ECG data out of observations into its own ecg.csv--
Updated observations without ECG data = 703292 records
Saved ECG data to /home/vinod/code/datasets/coherent/ecg.csv with 1072 records
--Creating modalities.csv--
Saved modalities to /home/vinod/code/datasets/coherent/modalities.csv


In [None]:
clean.clean_raw_ehrdata(COHERENT_DATA_STORE, 0.2, 0.2, COHERENT_CONDITIONS, COHERENT_DATAGEN_DATE)

Splits:: train: 0.6, valid: 0.2, test: 0.2
Split patients into:: Train: 766, Valid: 256, Test: 256 -- Total before split: 1278
Saved train data to /home/vinod/code/datasets/coherent/raw_split/train
Saved valid data to /home/vinod/code/datasets/coherent/raw_split/valid
Saved test data to /home/vinod/code/datasets/coherent/raw_split/test
Completed - test
Completed - valid
[2m[36m(pid=11052)[0m Saved cleaned "valid" data to /home/vinod/code/datasets/coherent/cleaned/valid
[2m[36m(pid=11055)[0m Saved cleaned "test" data to /home/vinod/code/datasets/coherent/cleaned/test
[2m[36m(pid=11035)[0m Saved cleaned "train" data to /home/vinod/code/datasets/coherent/cleaned/train
Completed - train


[2m[36m(pid=11035)[0m Saved vocab code tables to /home/vinod/code/datasets/coherent/cleaned/train/codes


# COHERENT

In [None]:
# Load the per-split dataframes and the vocab code tables from the data store.
# NOTE(review): presumably these read what clean_raw_ehrdata saved under
# `cleaned/` above -- confirm in lemonpie.preprocessing.clean.
train_dfs, valid_dfs, test_dfs = clean.load_cleaned_ehrdata(COHERENT_DATA_STORE)
code_dfs = clean.load_ehr_vocabcodes(COHERENT_DATA_STORE)

In [None]:
# for df in train_dfs:
#     display(df.head())

In [None]:
thispt = train_dfs[0].iloc[10]

In [None]:
thispt

patient                      967d5226-f8c4-60a8-b882-6ef803af88a6
birthdate                                              1930-04-29
heart_failure                                               False
heart_failure_age                                             NaN
coronary_heart                                              False
coronary_heart_age                                            NaN
myocardial_infarction                                       False
myocardial_infarction_age                                     NaN
stroke                                                       True
stroke_age                                                   87.0
cardiac_arrest                                              False
cardiac_arrest_age                                            NaN
Name: 10, dtype: object

In [None]:
# for df in code_dfs:
#     display(df.head())

Make sure the condition counts still match after extracting `y` for each patient.

`patients` dfs after cleaning, with `y` extracted

In [None]:
# Element 0 of each split's dataframe list is used as the patients dataframe.
pts_train, pts_valid, pts_test = train_dfs[0], valid_dfs[0], test_dfs[0]

`conditions` dfs

In [None]:
# Element 8 of each split's dataframe list is used as the conditions dataframe.
# NOTE(review): positional index -- confirm it matches the order returned by
# clean.load_cleaned_ehrdata.
cnd_train, cnd_valid, cnd_test = train_dfs[8], valid_dfs[8], test_dfs[8]

Tests to ensure counts match

In [None]:
def test_extract_ys(pt_dfs, cnd_dfs, conditions_dict=COHERENT_CONDITIONS):
    """Check that the extracted labels agree with the conditions records.

    For each split (train, valid, test, in that order) and each condition in
    ``conditions_dict``, the number of rows in the conditions dataframe whose
    ``code`` is ``"<condition code>||START"`` must equal the number of
    patients whose label column for that condition equals 1.

    Raises:
        AssertionError: On the first split/condition whose counts differ.
    """
    split_names = ['train', 'valid', 'test']
    for split, pts_df, cnds_df in zip(split_names, pt_dfs, cnd_dfs):
        print(f"Checking {split} dfs...")
        for this_cnd, cnd_code in conditions_dict.items():
            start_code = f"{cnd_code}||START"
            # Count by summing the boolean masks instead of materialising the filtered frames.
            cnds_df_counts = int((cnds_df['code'] == start_code).sum())
            pts_df_counts = int((pts_df[this_cnd] == 1).sum())
            assert cnds_df_counts == pts_df_counts, f"Error in {split} for {this_cnd} -- {cnds_df_counts} != {pts_df_counts}"

        print(f"Tests passed for {split} - all condition counts match")

In [None]:
test_extract_ys([pts_train, pts_valid, pts_test],[cnd_train, cnd_valid, cnd_test])

Checking train dfs...
Tests passed for train - all condition counts match
Checking valid dfs...
Tests passed for valid - all condition counts match
Checking test dfs...
Tests passed for test - all condition counts match


In [None]:
clean.get_label_counts([pts_train, pts_valid, pts_test], conditions_dict=COHERENT_CONDITIONS)

[{'heart_failure': 189,
  'coronary_heart': 194,
  'myocardial_infarction': 82,
  'stroke': 435,
  'cardiac_arrest': 107},
 {'heart_failure': 68,
  'coronary_heart': 66,
  'myocardial_infarction': 28,
  'stroke': 138,
  'cardiac_arrest': 30},
 {'heart_failure': 75,
  'coronary_heart': 76,
  'myocardial_infarction': 35,
  'stroke': 125,
  'cardiac_arrest': 43}]

In [None]:
# Label names, in the same order as the keys of COHERENT_CONDITIONS. Derived
# from the dict (insertion-ordered) so the two cannot drift apart.
labels = list(COHERENT_CONDITIONS)

In [None]:
from lemonpie.preprocessing import vocab, transform
from lemonpie.data import *

In [None]:
transform.preprocess_ehr_dataset(
    COHERENT_DATA_STORE, 
    COHERENT_DATAGEN_DATE, 
    conditions_dict=COHERENT_CONDITIONS, 
    age_start=240, 
    age_stop=360, 
    age_in_months=True)

Since data is pre-cleaned, skipping Cleaning, Splitting and Vocab-creation
------------------- Creating patient lists -------------------


FileNotFoundError: [Errno 2] No such file or directory: '/home/vinod/code/datasets/coherent/processed/vocabs.vocablist'

In [None]:
# Build the dataset object, embedding dimensions and dataloaders for the model below.
# NOTE(review): age_start/age_stop are presumably in months (age_in_months=True),
# i.e. ages 20-30 years -- confirm against EHRData.
# NOTE(review): EhrVocabList.load needs `processed/vocabs.vocablist`; the saved
# output shows it missing (FileNotFoundError), so the preprocessing cell above
# must complete successfully before this cell can run.
coherent_data = EHRData(COHERENT_DATA_STORE, labels, age_start=240, age_stop=360, age_in_months=True, lazy_load_gpu=False)
demograph_dims, rec_dims, demograph_dims_wd, rec_dims_wd = vocab.get_all_emb_dims(transform.EhrVocabList.load(COHERENT_DATA_STORE))
train_dl, valid_dl, train_pos_wts, valid_pos_wts = coherent_data.get_data(bs=1024)

FileNotFoundError: [Errno 2] No such file or directory: '/home/vinod/code/datasets/coherent/processed/vocabs.vocablist'

#### `EHR_LSTM`

In [None]:
# Instantiate the LSTM model with one output per label and move it to DEVICE.
model = EHR_LSTM(demograph_dims, rec_dims, demograph_dims_wd, rec_dims_wd, num_labels=len(labels)).to(DEVICE)
# Separate loss functions for train and valid, each built from that split's positive weights.
train_loss_fn, valid_loss_fn = get_loss_fn(train_pos_wts), get_loss_fn(valid_pos_wts)
# Adagrad with its default hyperparameters.
optimizer = torch.optim.Adagrad(model.parameters())

In [None]:
len(train_dl), len(valid_dl)

(1, 1)

In [None]:
model

EHR_LSTM(
  (embs): ModuleList(
    (0): Embedding(40, 8)
    (1): Embedding(16, 8)
    (2): Embedding(128, 8)
    (3): Embedding(8, 8)
    (4): Embedding(8, 8)
    (5): Embedding(8, 8)
    (6): Embedding(8, 8)
    (7): Embedding(248, 16)
    (8): Embedding(208, 16)
    (9): Embedding(8, 8)
    (10): Embedding(184, 16)
  )
  (embgs): ModuleList(
    (0): EmbeddingBag(536, 16, mode=mean)
    (1): EmbeddingBag(32, 8, mode=mean)
    (2): EmbeddingBag(56, 8, mode=mean)
    (3): EmbeddingBag(232, 16, mode=mean)
    (4): EmbeddingBag(16, 8, mode=mean)
    (5): EmbeddingBag(144, 8, mode=mean)
    (6): EmbeddingBag(184, 16, mode=mean)
    (7): EmbeddingBag(24, 8, mode=mean)
  )
  (input_dp): InputDropout()
  (lstm): LSTM(88, 88, num_layers=4, batch_first=True, dropout=0.3)
  (lin): Sequential(
    (0): Linear(in_features=208, out_features=416, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=416, out_features=832, bias=True)
    (4): ReL

In [None]:
h_1K = RunHistory(labels)

`use_amp=False`