# Fusion

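> Preprocess the Coherent dataset, run a baseline lemonpie `EHR_LSTM` on it, and lay the groundwork for a multimodal fusion model.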

In [3]:
from lemonpie.basics import *
from lemonpie.preprocessing import clean
from fastai.imports import *
import ray

In [39]:
ray.init()

2022-09-28 21:07:55,589	INFO services.py:1245 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m


{'node_ip_address': '192.168.86.91',
 'raylet_ip_address': '192.168.86.91',
 'redis_address': '192.168.86.91:6379',
 'object_store_address': '/tmp/ray/session_2022-09-28_21-07-53_738287_11841/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2022-09-28_21-07-53_738287_11841/sockets/raylet',
 'webui_url': '127.0.0.1:8265',
 'session_dir': '/tmp/ray/session_2022-09-28_21-07-53_738287_11841',
 'metrics_export_port': 55296,
 'node_id': 'f829f0cd921d340e36f6e16565515d477a36173ef7b06d427d37de6b'}

In [40]:
%reload_ext autoreload
%autoreload 2

In [5]:
# Root directory of the coherent dataset. Defaults to the original location, but can be
# overridden with the COHERENT_DATA_STORE environment variable so the notebook is not
# tied to one machine's home directory.
COHERENT_DATA_STORE = os.environ.get('COHERENT_DATA_STORE', '/home/vinod/code/datasets/coherent')
# Data generation date handed to `clean.clean_raw_ehrdata`.
COHERENT_DATAGEN_DATE = '08-10-2021'
# Label name -> condition code (kept as strings).
# NOTE(review): these look like SNOMED CT concept ids — confirm against the dataset docs.
COHERENT_CONDITIONS = {
    "heart_failure" : "88805009",
    "coronary_heart" : "53741008",
    "myocardial_infarction" : "22298006",
    "stroke" : "230690007",
    "cardiac_arrest" : "410429000"
}

# Coherent Preprocessing

**Retain only patients with FHIR bundles.**

In [42]:
def retain_fhir_patients(coherent_path, csv_names):
    """Retain only patients with FHIR bundles.

    Reads every `{coherent_path}/output/csv/{name}.csv`, keeps the rows whose
    patient id has a bundle file in `{coherent_path}/output/fhir`, and writes the
    filtered frame to `{coherent_path}/raw_original/{name}.csv`.

    Args:
        coherent_path: Root directory of the coherent dataset.
        csv_names: CSV base names (without `.csv`) to filter. `patients` is
            matched on its `Id` column, every other file on `PATIENT`.
    """

    # read pids with FHIR bundles: the id is the last "_"-separated token of
    # the bundle file name, up to the first "."
    fhir_dir = f'{coherent_path}/output/fhir'
    fhir_pids = {name.split("_")[-1].split(".")[0] for name in os.listdir(fhir_dir)}

    # make sure the destination exists so the function also works when it is
    # called on its own (not only through coherent_preprocess)
    os.makedirs(f"{coherent_path}/raw_original", exist_ok=True)

    # filter and retain only FHIR patients in all files
    print(f"Writing filtered files to {coherent_path}/raw_original/")
    for file in csv_names:
        old_df = pd.read_csv(f"{coherent_path}/output/csv/{file}.csv", low_memory=False)
        id_col = "Id" if file == 'patients' else "PATIENT"
        fhir_mask = old_df[id_col].isin(fhir_pids)
        new_df = old_df[fhir_mask]
        assert len(new_df) == fhir_mask.sum(), f"Count error in {file}"
        new_df.to_csv(f"{coherent_path}/raw_original/{file}.csv", index=False)
        print(f"Created {file} with {len(new_df)} records.")
    

**Remove ECG from observations and create ecg.csv**

In [43]:
def moveout_ecg(coherent_path):
    """Move ECG data out of Observations into its own csv.

    Rows of `{coherent_path}/raw_original/observations.csv` whose `CODE` is
    "29303009" are removed from that file (which is rewritten in place) and every
    other one of them, starting with the first, is saved to
    `{coherent_path}/ecg.csv` with lower-cased column names and without the
    ENCOUNTER, CODE, DESCRIPTION, UNITS and TYPE columns.

    Args:
        coherent_path: Root directory of the coherent dataset.
    """

    obs_path = f"{coherent_path}/raw_original/observations.csv"
    # read CODE as text so the comparison below does not depend on dtype inference
    old_obs = pd.read_csv(obs_path, low_memory=False, dtype={"CODE": str})
    is_ecg = old_obs["CODE"] == "29303009"

    ecg_rows = old_obs[is_ecg]
    new_obs = old_obs[~is_ecg]
    assert len(new_obs) == len(old_obs) - len(ecg_rows), "Mismatch after ECG removal from Observations"
    new_obs.to_csv(obs_path, index=False)
    print(f"Updated observations without ECG data = {len(new_obs)} records")

    # Build a new frame instead of mutating the slice in place; the in-place
    # drop/rename on a slice is what raised SettingWithCopyWarning.
    # iloc[::2] keeps positions 0, 2, 4, ... i.e. drops every odd-positioned ECG row.
    ecg_obs = (
        ecg_rows
        .iloc[::2]
        .drop(columns=["ENCOUNTER", "CODE", "DESCRIPTION", "UNITS", "TYPE"])
        .rename(columns=str.lower)
    )
    ecg_obs.to_csv(f"{coherent_path}/ecg.csv", index=False)
    print(f"Saved ECG data to {coherent_path}/ecg.csv with {len(ecg_obs)} records")

**Create `modalities.csv`**

In [81]:
def create_modalities_csv(coherent_path):
    """Create modalities csv.

    Writes `{coherent_path}/modalities.csv` with one row per patient in
    `raw_original/patients.csv` and the columns `id`, `first`, `last`, `mri`,
    `dna`, `ecg` and `type`. A modality column holds its weight (mri=1, dna=10,
    ecg=20) when the patient has that modality and 0 otherwise; `type` is the
    sum of the three, so each combination of modalities gets a distinct value.

    Args:
        coherent_path: Root directory of the coherent dataset. `ecg.csv` must
            already exist there (it is produced by `moveout_ecg`).
    """

    # dna - counts off by 1, because no FHIR bundle for 1 pt with dna data
    dna_files = os.listdir(f'{coherent_path}/output/dna')
    dna_pids = [file.split("_")[-2] for file in dna_files]

    # mri - id is the last "_" token up to the first ".", minus its final character
    mri_files = os.listdir(f'{coherent_path}/output/dicom')
    mri_pids = [file.split("_")[-1].split(".")[0][:-1] for file in mri_files]

    # ecg
    ecg_data = pd.read_csv(f"{coherent_path}/ecg.csv")
    ecg_pids = ecg_data.patient.unique()

    # create modalities csv
    patients = pd.read_csv(f"{coherent_path}/raw_original/patients.csv", low_memory=False)
    modalities = patients[["Id", "FIRST", "LAST"]].rename(columns=str.lower)

    # Assign each weighted flag column directly. The previous
    # `modalities[col].replace(..., inplace=True)` was chained assignment, which
    # does not write back to the frame under pandas Copy-on-Write.
    for name, pids, weight in (("mri", mri_pids, 1), ("dna", dna_pids, 10), ("ecg", ecg_pids, 20)):
        modalities[name] = modalities["id"].isin(pids).astype("int64") * weight
    modalities["type"] = modalities["mri"] + modalities["dna"] + modalities["ecg"]

    modalities.to_csv(f"{coherent_path}/modalities.csv", index=False)
    print(f"Saved modalities to {coherent_path}/modalities.csv")

In [82]:
def coherent_preprocess(coherent_path=COHERENT_DATA_STORE, csv_names=FILENAMES):
    """Run the coherent-specific preprocessing steps in order.

    Creates `{coherent_path}/raw_original`, then filters the CSVs down to
    patients with FHIR bundles, moves the ECG rows out of observations into
    `ecg.csv`, and finally writes `modalities.csv`.

    Args:
        coherent_path: Root directory of the coherent dataset.
        csv_names: CSV base names handed to `retain_fhir_patients`.
    """

    # the filtered copies are written here, so the directory has to exist first
    Path(f'{coherent_path}/raw_original').mkdir(parents=True, exist_ok=True)

    # (banner, action) pairs; order matters because each step reads files
    # written by the one before it
    pipeline = (
        ("--Filtering & retaining patients with FHIR bundles--",
         lambda: retain_fhir_patients(coherent_path, csv_names)),
        ("--Moving ECG data out of observations into its own ecg.csv--",
         lambda: moveout_ecg(coherent_path)),
        ("--Creating modalities.csv--",
         lambda: create_modalities_csv(coherent_path)),
    )
    for banner, run_step in pipeline:
        print(banner)
        run_step()

In [83]:
coherent_preprocess()

--Filtering & retaining patients with FHIR bundles--
Writing filtered files to /home/vinod/code/datasets/coherent/raw_original/
Created patients with 1278 records.
Created observations with 705436 records.
Created allergies with 106 records.
Created careplans with 6135 records.
Created medications with 209401 records.
Created imaging_studies with 3752 records.
Created procedures with 56092 records.
Created conditions with 15956 records.
Created immunizations with 11900 records.
--Moving ECG data out of observations into its own ecg.csv--
Updated observations without ECG data = 703292 records
Saved ECG data to /home/vinod/code/datasets/coherent/ecg.csv with 1072 records
--Creating modalities.csv--
Saved modalities to /home/vinod/code/datasets/coherent/modalities.csv


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  return super().drop(
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  return super().rename(


In [6]:
modalities = pd.read_csv(f"{COHERENT_DATA_STORE}/modalities.csv")

In [7]:
modalities

Unnamed: 0,id,first,last,mri,dna,ecg,type
0,9c452d24-00b0-d58f-4cd5-b82bd6695646,Sydney660,Champlin946,0,0,20,20
1,40c7c5d7-e21d-0aec-3023-bf613f37a5f1,Ryan260,Turcotte120,0,10,0,10
2,e4c173c1-99ab-865e-4094-970fd1ac8df8,Johanna547,Vandervort697,0,0,20,20
3,6f4d77e9-2203-03a3-8966-92a22a21000a,Shawnta32,Zboncak558,0,10,0,10
4,a58de3fd-f026-902b-55c1-872dc042e0c5,Contessa946,Leuschke194,1,10,0,11
...,...,...,...,...,...,...,...
1273,22e4f915-e209-4ff0-b9c9-112b5146b8c3,Sammie902,Crist667,0,0,20,20
1274,bc0cb6be-1caa-e53f-f4c4-e25d91363698,Herb645,Crooks415,0,10,0,10
1275,11d3a003-57bf-d281-60b0-8ba6a523c557,Bernardo699,Quesada500,0,10,0,10
1276,d70cac07-9852-8260-1daa-6dc227f70b39,Gilberto712,Jasso472,0,0,20,20


In [8]:
modalities.type.unique()

array([20, 10, 11,  1, 30, 31, 21,  0])

In [12]:
modalities.groupby(["type"]).count()

Unnamed: 0_level_0,id,first,last,mri,dna,ecg
type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0,2,2,2,2,2,2
1,110,110,110,110,110,110
10,615,615,615,615,615,615
11,145,145,145,145,145,145
20,261,261,261,261,261,261
21,17,17,17,17,17,17
30,102,102,102,102,102,102
31,26,26,26,26,26,26


**Clean**

In [88]:
clean.clean_raw_ehrdata(COHERENT_DATA_STORE, 0.1, 0.1, COHERENT_CONDITIONS, COHERENT_DATAGEN_DATE)

Splits:: train: 0.8, valid: 0.1, test: 0.1
Split patients into:: Train: 1022, Valid: 128, Test: 128 -- Total before split: 1278
Saved train data to /home/vinod/code/datasets/coherent/raw_split/train
Saved valid data to /home/vinod/code/datasets/coherent/raw_split/valid
Saved test data to /home/vinod/code/datasets/coherent/raw_split/test
Completed - valid
Completed - test
[2m[36m(pid=30058)[0m Saved cleaned "test" data to /home/vinod/code/datasets/coherent/cleaned/test
[2m[36m(pid=30062)[0m Saved cleaned "valid" data to /home/vinod/code/datasets/coherent/cleaned/valid
[2m[36m(pid=30065)[0m Saved cleaned "train" data to /home/vinod/code/datasets/coherent/cleaned/train
[2m[36m(pid=30065)[0m Saved vocab code tables to /home/vinod/code/datasets/coherent/cleaned/train/codes
Completed - train


# Tests 

### Data

In [89]:
train_dfs, valid_dfs, test_dfs = clean.load_cleaned_ehrdata(COHERENT_DATA_STORE)
code_dfs = clean.load_ehr_vocabcodes(COHERENT_DATA_STORE)

In [90]:
# for df in train_dfs:
#     display(df.head())

In [91]:
thispt = train_dfs[0].iloc[10]

In [92]:
thispt

patient                      967d5226-f8c4-60a8-b882-6ef803af88a6
birthdate                                              1930-04-29
heart_failure                                               False
heart_failure_age                                             NaN
coronary_heart                                              False
coronary_heart_age                                            NaN
myocardial_infarction                                       False
myocardial_infarction_age                                     NaN
stroke                                                       True
stroke_age                                                   87.0
cardiac_arrest                                              False
cardiac_arrest_age                                            NaN
Name: 10, dtype: object

In [93]:
# for df in code_dfs:
#     display(df.head())

Making sure condition counts match - after extracting `y` for each patient

`patients` dfs after cleaning, with `y` extracted

In [94]:
pts_train, pts_valid, pts_test = train_dfs[0], valid_dfs[0], test_dfs[0]

`conditions` dfs

In [95]:
cnd_train, cnd_valid, cnd_test = train_dfs[8], valid_dfs[8], test_dfs[8]

Tests to ensure counts match

In [96]:
clean.test_extract_ys([pts_train, pts_valid, pts_test],[cnd_train, cnd_valid, cnd_test], conditions_dict=COHERENT_CONDITIONS)

Checking train dfs...
Tests passed for train - all condition counts match
Checking valid dfs...
Tests passed for valid - all condition counts match
Checking test dfs...
Tests passed for test - all condition counts match


In [97]:
clean.get_label_counts([pts_train, pts_valid, pts_test], conditions_dict=COHERENT_CONDITIONS)

[{'heart_failure': 257,
  'coronary_heart': 260,
  'myocardial_infarction': 110,
  'stroke': 573,
  'cardiac_arrest': 137},
 {'heart_failure': 43,
  'coronary_heart': 38,
  'myocardial_infarction': 15,
  'stroke': 61,
  'cardiac_arrest': 20},
 {'heart_failure': 32,
  'coronary_heart': 38,
  'myocardial_infarction': 20,
  'stroke': 64,
  'cardiac_arrest': 23}]

### Modalities

In [None]:
# Group by the scalar column name, not the one-element list ["type"]: with a length-1 list,
# pandas >= 2.0 yields 1-tuple keys such as (0,) when the groupby is iterated, which would
# change the keys of `filtered` below.
ptids_by_modality = modalities.groupby("type")["id"]

In [135]:
# For every modality type, collect the patient ids of that type that fall into the
# train, valid and test splits (in that order).
splits = (pts_train, pts_valid, pts_test)
filtered = {}
for modality, ptids in ptids_by_modality:
    train_ids, valid_ids, test_ids = (
        split.loc[split["patient"].isin(ptids), "patient"] for split in splits
    )
    filtered[modality] = [train_ids, valid_ids, test_ids]

In [138]:
for modality in filtered.keys():
    print(modality)

0
1
10
11
20
21
30
31


In [122]:
modality, ptids = next(iter(ptids_by_modality))

In [127]:
pts_train[pts_train["patient"].isin(ptids)]["patient"]

indx
299    72d3121f-639c-4824-876f-7c19dd197b7c
567    47b3acc7-8688-6559-534f-daaec268e3c3
Name: patient, dtype: object

In [129]:
pts_valid[pts_valid["patient"].isin(ptids)]["patient"]

Series([], Name: patient, dtype: object)

# Small run through lemonpie

The steps below follow the lemonpie quick walkthrough: https://corazonlabs.github.io/lemonpie/quick_walkthru

In [25]:
labels = ['heart_failure', 'coronary_heart', 'myocardial_infarction', 'stroke', 'cardiac_arrest']

**20 years of patient data from Jan 01 1995**

In [26]:
from lemonpie import data

In [27]:
# Build the EHRData object over the coherent data store for the five condition labels.
# age_start is passed as a calendar date (start_is_date=True) and age_range=20 with
# age_in_months=False, matching the "20 years from Jan 01 1995" window stated above.
# NOTE(review): exact parameter semantics are defined by lemonpie.data.EHRData — confirm there.
coherent_data = data.EHRData(
    COHERENT_DATA_STORE, 
    labels,     
    age_start='1995-01-01',
    age_range=20,
    start_is_date=True,
    age_in_months=False, 
    lazy_load_gpu=False)

In [28]:
from lemonpie.preprocessing import vocab

In [29]:
demograph_dims, rec_dims, demograph_dims_wd, rec_dims_wd = vocab.get_all_emb_dims(vocab.EhrVocabList.load(COHERENT_DATA_STORE))
train_dl, valid_dl, train_pos_wts, valid_pos_wts = coherent_data.get_data(bs=64)

In [30]:
len(train_dl), len(valid_dl)

(12, 2)

#### `EHR_LSTM`

In [33]:
from lemonpie import models

In [12]:
model = models.EHR_LSTM(
    demograph_dims,
    rec_dims,
    demograph_dims_wd,
    rec_dims_wd,
    len(labels),
    train_pos_wts, 
    valid_pos_wts,
    optim="adam",
    base_lr=0.001,
)




In [13]:
model

EHR_LSTM(
  (train_loss_fn): BCEWithLogitsLoss()
  (valid_loss_fn): BCEWithLogitsLoss()
  (embs): ModuleList(
    (0): Embedding(40, 8)
    (1): Embedding(16, 8)
    (2): Embedding(128, 8)
    (3): Embedding(8, 8)
    (4): Embedding(8, 8)
    (5): Embedding(8, 8)
    (6): Embedding(8, 8)
    (7): Embedding(264, 16)
    (8): Embedding(192, 16)
    (9): Embedding(8, 8)
    (10): Embedding(184, 16)
  )
  (embgs): ModuleList(
    (0): EmbeddingBag(664, 16, mode=mean)
    (1): EmbeddingBag(16, 8, mode=mean)
    (2): EmbeddingBag(56, 8, mode=mean)
    (3): EmbeddingBag(240, 16, mode=mean)
    (4): EmbeddingBag(16, 8, mode=mean)
    (5): EmbeddingBag(144, 8, mode=mean)
    (6): EmbeddingBag(192, 16, mode=mean)
    (7): EmbeddingBag(16, 8, mode=mean)
  )
  (input_dp): InputDropout()
  (lstm): LSTM(88, 88, num_layers=4, batch_first=True, dropout=0.3)
  (lin): Sequential(
    (0): Linear(in_features=208, out_features=416, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.3, inplace=Fal

In [16]:
import pytorch_lightning as pl

In [17]:
trainer = pl.Trainer(precision=16, accelerator='gpu', devices=-1, max_epochs=5) #, callbacks=[checkpoint_callback])

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [18]:
trainer.fit(model, train_dl, valid_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name          | Type              | Params
-----------------------------------------------------
0  | train_loss_fn | BCEWithLogitsLoss | 0     
1  | valid_loss_fn | BCEWithLogitsLoss | 0     
2  | embs          | ModuleList        | 12.0 K
3  | embgs         | ModuleList        | 19.5 K
4  | input_dp      | InputDropout      | 0     
5  | lstm          | LSTM              | 250 K 
6  | lin           | Sequential        | 7.4 M 
7  | lin_o         | Linear            | 16.6 K
8  | train_metrics | MetricCollection  | 0     
9  | valid_metrics | MetricCollection  | 0     
10 | test_metrics  | MetricCollection  | 0     
-----------------------------------------------------
7.7 M     Trainable params
0         Non-trainable params
7.7 M     Total params
15.320    Total estimated model params size (MB)


Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  8.08it/s]

  rank_zero_warn(


Epoch 0:   0%|          | 0/14 [00:00<?, ?it/s]                            

  rank_zero_warn(
  rank_zero_warn(


Epoch 4: 100%|██████████| 14/14 [00:02<00:00,  5.56it/s, loss=0.871, v_num=5]

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 14/14 [00:02<00:00,  5.21it/s, loss=0.871, v_num=5]


In [20]:
test_dl, test_pos_wts = coherent_data.get_test_data()
len(test_dl), test_pos_wts

(2, tensor([2., 2., 6., 1., 5.]))

In [21]:
trainer.test(model, test_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  9.60it/s]

  rank_zero_warn(


Testing DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 10.14it/s]


[{'test/AUROC': 0.6931756734848022}]

# Fusion Model

In [None]:
# TODO: the fusion model is not implemented yet. This cell previously held the unfinished
# expression `coherent_data.`, which is a SyntaxError and stops a Restart & Run All.

# MRI

In [22]:
img = pd.read_csv(f"{COHERENT_DATA_STORE}/output/csv/imaging_studies.csv")
cnd = pd.read_csv(f"{COHERENT_DATA_STORE}/output/csv/conditions.csv")

# encounters that have an imaging study with modality code "MR"
is_mr = img["MODALITY_CODE"] == "MR"
mri_encs = img.loc[is_mr, "ENCOUNTER"]

# smh = Silent micro-hemorrhage: flag the conditions recorded at those encounters
# that carry code 723857007
at_mri_enc = cnd["ENCOUNTER"].isin(mri_encs)
smh = cnd.loc[at_mri_enc, "CODE"] == 723857007

# sanity check: the number of flagged conditions equals the number of MR encounters
n_smh, n_mri = smh.sum(), len(mri_encs)
assert n_smh == n_mri
n_smh, n_mri

(331, 331)