In [1]:
import os

from fastai.imports import *
from lemonpie.basics import *
from lemonpie.preprocessing import vocab
from lemonpie.preprocessing.transform import *
from lemonpie.data import *
from lemonpie import models

import torch
import pytorch_lightning as pl

In [2]:
# Location of the COHERENT dataset store. Set the COHERENT_DATA_STORE environment
# variable to point elsewhere instead of editing this machine-specific default.
COHERENT_DATA_STORE = os.environ.get("COHERENT_DATA_STORE", '/home/vinod/code/datasets/coherent')
COHERENT_DATAGEN_DATE = '08-10-2021'

# Prediction targets: condition name -> condition code.
# The codes appear to be SNOMED CT concept IDs (TODO confirm against the data generator).
COHERENT_CONDITIONS = {
    "heart_failure" : "88805009",
    "coronary_heart" : "53741008",
    "myocardial_infarction" : "22298006",
    "stroke" : "230690007",
    "cardiac_arrest" : "410429000"
}

# Seed torch so model weight initialisation and the random stand-in inputs used by
# the placeholder models below are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED);

In [3]:
# Label names, in the same order as the COHERENT_CONDITIONS mapping.
COHERENT_LABELS = [*COHERENT_CONDITIONS]
COHERENT_LABELS

['heart_failure',
 'coronary_heart',
 'myocardial_infarction',
 'stroke',
 'cardiac_arrest']

# Get data

In [11]:
# Build the multimodal EHR dataset for the COHERENT labels.
# With age_in_months=True, age_start / age_range are presumably expressed in months
# (240 = 20 years, 120 = 10 years) — TODO confirm against MultimodalEHRData.
coherent_data = MultimodalEHRData(
    COHERENT_DATA_STORE,
    COHERENT_LABELS,
    start_is_date=False,
    age_in_months=True,
    age_start=240,
    age_range=120,
    lazy_load_gpu=True,
)

In [12]:
# Embedding dimensions derived from the vocab list saved in the data store.
ehr_vocabs = vocab.EhrVocabList.load(COHERENT_DATA_STORE)
demograph_dims, rec_dims, demograph_dims_wd, rec_dims_wd = vocab.get_all_emb_dims(ehr_vocabs)

# Data loaders keyed by split, plus per-split weights (indexed by "train"/"valid"
# when building the EHR model below; presumably positive-class weights).
dls, pos_wts = coherent_data.get_data()

In [13]:
# Modality-type codes present in each split (see the Late Fusion table for what each code means).
coherent_data.modality_types

{'train': ['0', '21', '30', '31', '11', '1', '20', '10'],
 'valid': ['21', '30', '31', '11', '1', '20', '10'],
 'test': ['21', '30', '31', '11', '1', '20', '10']}

In [14]:
# One DataLoader per split ("train", "valid", "test").
dls

{'train': <torch.utils.data.dataloader.DataLoader at 0x7f32d59ec910>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x7f32d59b3b50>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x7f32d58c40d0>}

In [15]:
# Unpack the per-split loaders into their own names.
train_dl, valid_dl, test_dl = (dls[split] for split in ("train", "valid", "test"))

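# Models for each modality

The MRI, DNA and ECG models below are placeholders: they check the shape of their input but compute their predictions from random noise. The EHR model is lemonpie's `EHR_LSTM`.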

In [44]:
class UnimodalModel(torch.nn.Module):
    """Placeholder model for a single auxiliary modality (e.g. MRI, DNA, ECG).

    NOTE: this is a stub. `forward` only checks the per-sample input shape and
    then runs the MLP on a freshly sampled random vector. Its output therefore
    does not depend on the input and is a single unbatched vector of `n_labels`
    sigmoid probabilities. Replace the body of `forward` with a real encoder.

    Args:
        input_dims: expected shape of one sample, i.e. of `x[0]` (e.g. `(4, 4)`).
        n_labels: length of the output vector; defaults to 5, matching the five
            COHERENT conditions used in this notebook.
        hidden_dim: width of the hidden layer.
        stub_in_dim: size of the random stand-in input fed to the MLP.
    """

    def __init__(self, input_dims: tuple, n_labels: int = 5, hidden_dim: int = 20, stub_in_dim: int = 10):
        super().__init__()

        # Normalise to a tuple so list-valued dims (e.g. [5]) compare equal to torch.Size.
        self.input_dims = tuple(input_dims)
        self.stub_in_dim = stub_in_dim

        self.model = torch.nn.Sequential(
            torch.nn.Linear(stub_in_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, n_labels),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        """Validate the per-sample shape of `x` and return stub predictions.

        Args:
            x: indexable batch whose first element `x[0]` must have shape
               `self.input_dims` (e.g. a tensor of shape `(batch, *input_dims)`).

        Returns:
            Tensor of shape `(n_labels,)` with values in (0, 1).

        Raises:
            AssertionError: if `x[0]` does not have shape `self.input_dims`.
        """
        sample_shape = tuple(x[0].size())
        assert sample_shape == self.input_dims, (
            f"expected per-sample shape {self.input_dims}, got {sample_shape}"
        )
        # Build the stand-in input on the weights' device and dtype so the stub
        # keeps working after `.to(device)` or a dtype change (e.g. `.double()`).
        weight = self.model[0].weight
        fake_x = torch.randn(self.stub_in_dim, device=weight.device).to(weight.dtype)
        return self.model(fake_x)



In [46]:
mri_model = UnimodalModel(input_dims=(4, 4))
# Smoke test on a batch of 64 random 4x4 "MRI" samples.
mri_model(torch.randn(64, 4, 4))

tensor([0.4413, 0.4254, 0.5085, 0.3670, 0.5004], grad_fn=<SigmoidBackward0>)

In [47]:
dna_model = UnimodalModel(input_dims=(3, 2))
# Smoke test on a batch of 32 random 3x2 "DNA" samples.
dna_model(torch.randn(32, 3, 2))

tensor([0.6357, 0.4934, 0.4595, 0.5790, 0.4097], grad_fn=<SigmoidBackward0>)

In [51]:
ecg_model = UnimodalModel(input_dims=(5,))
# Smoke test on a single random length-5 "ECG" sample.
ecg_model(torch.randn(1, 5))

tensor([0.5200, 0.4877, 0.4805, 0.5705, 0.4664], grad_fn=<SigmoidBackward0>)

In [55]:
# EHR-only model; the positional arguments are the embedding dims computed above,
# the number of labels, and the train/valid weights returned by get_data().
ehr_model = models.EHR_LSTM(
    demograph_dims, rec_dims, demograph_dims_wd, rec_dims_wd,
    len(COHERENT_LABELS),
    pos_wts["train"], pos_wts["valid"],
    optim="adam",
    base_lr=0.001,
)

ehr_trainer = pl.Trainer(
    max_epochs=5,
    precision=16,
    accelerator='gpu',
    devices=-1,
)
# TODO: evaluate the EHR model on the test split when needed:
#   ehr_output = ehr_trainer.test(ehr_model, test_dl)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


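# Late Fusion

Late fusion runs a separate model on each modality and averages their predicted label probabilities. The EHR model's prediction is not included yet.

Each batch carries a modality-type code. Codes are additive: EHR is always present (code 0), and each extra modality adds its base code: MRI = 1, DNA = 10, ECG = 20. Base codes are shown in **bold**. For example, 31 = 1 + 10 + 20 means EHR + MRI + DNA + ECG.

| Modality Type | Modalities            |
|---------------|-----------------------|
| **0**         | **EHR**               |
| **1**         | EHR + **MRI**         |
| **10**        | EHR + **DNA**         |
| 11            | EHR + MRI + DNA       |
| **20**        | EHR + **ECG**         |
| 21            | EHR + MRI + ECG       |
| 30            | EHR + DNA + ECG       |
| 31            | EHR + MRI + DNA + ECG |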


In [80]:
# Grab one batch to prototype fusion on. Note: it comes from the *validation*
# loader; the name `test_batch` is kept because the following cells use it.
batch_iter = iter(valid_dl)
test_batch = next(batch_iter)
test_batch

[[ptid:1a82483d-7eb2-d5e0-1e1f-398ba129b18b, birthdate:1936-12-22, [('heart_failure', True), ('coronary_heart', False)].., device:cpu,
  ptid:6dc8bd6b-e2a8-92bf-613d-8b477eb87d7c, birthdate:1911-12-23, [('heart_failure', True), ('coronary_heart', False)].., device:cpu,
  ptid:844a37ff-ce26-6338-fd6a-0bc1e925a702, birthdate:1933-03-15, [('heart_failure', True), ('coronary_heart', False)].., device:cpu],
 tensor([[1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.]]),
 [21, 21, 21],
 [tensor([[[1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1]],
  
          [[1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1]],
  
          [[1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1]]]),
  tensor([[20, 20, 20, 20, 20],
          [20, 20, 20, 20, 20],
          [20, 20, 20, 20, 20]])]]

In [81]:
# Unpack the batch: patient records, the label tensor (one column per label),
# the per-sample modality-type codes, and the list of auxiliary-modality tensors
# (in MRI -> DNA -> ECG order for whichever modalities are present).
pts, ys, mod_type, other = test_batch

In [82]:
# Late fusion: run each auxiliary-modality model on its own input and average the
# predicted label probabilities.
#
# Modality-type codes are additive (MRI = 1, DNA = 10, ECG = 20; EHR is always
# present), and the auxiliary inputs in `other` arrive in MRI -> DNA -> ECG order.
AUX_MODELS_BY_MOD_TYPE = {
    0: (),
    1: (mri_model,),
    10: (dna_model,),
    11: (mri_model, dna_model),
    20: (ecg_model,),
    21: (mri_model, ecg_model),
    30: (dna_model, ecg_model),
    31: (mri_model, dna_model, ecg_model),
}

# Dispatch on the first sample's code; this assumes every sample in the batch
# shares one modality type, so check that explicitly.
batch_mod_type = mod_type[0]
assert all(m == batch_mod_type for m in mod_type), (
    f"late fusion expects one modality type per batch, got {sorted(set(mod_type))}"
)

# `other` is a list of per-modality tensors. Always unpack it so single-modality
# batches pass the tensor itself (not a one-element list) to their model; also
# tolerate a bare tensor or None.
if other is None:
    aux_inputs = []
elif isinstance(other, (list, tuple)):
    aux_inputs = list(other)
else:
    aux_inputs = [other]

aux_models = AUX_MODELS_BY_MOD_TYPE[batch_mod_type]
assert len(aux_inputs) == len(aux_models), (
    f"modality type {batch_mod_type} expects {len(aux_models)} auxiliary inputs, got {len(aux_inputs)}"
)

# TODO: add the EHR model's prediction once its forward interface is wired in
# (previously sketched as `ehr_model((pts, ys))`).
output_list = [aux_model(aux_input) for aux_model, aux_input in zip(aux_models, aux_inputs)]

if not output_list:
    # torch.stack([]) would fail with an opaque error; say what is missing instead.
    raise ValueError(
        f"modality type {batch_mod_type} has no auxiliary-modality outputs to fuse; "
        "EHR-only batches need the EHR model's prediction"
    )

torch.mean(torch.stack(output_list), dim=0)

tensor([0.4826, 0.5259, 0.4693, 0.5262, 0.4556], grad_fn=<MeanBackward1>)

# Joint Fusion

In [None]:
class JointFusion(torch.nn.Module):
    """Joint-fusion model over the MRI, DNA and ECG unimodal models.

    The model works in three steps:
    1. Each present modality is passed through its own unimodal model.
    2. The outputs are concatenated in MRI -> DNA -> ECG order. Absent
       modalities are filled with zeros.
    3. A small MLP head maps the concatenation to label probabilities.

    The unimodal models are registered as submodules, so they are trained
    together with the head.

    Args:
        mri_model, dna_model, ecg_model: per-modality models. Each is assumed to
            return a tensor whose last dimension is `n_labels`; this holds for the
            `UnimodalModel` stubs in this notebook.
        n_labels: number of output labels (default 5, the COHERENT conditions).
        hidden_dim: width of the fusion head's hidden layer.
    """

    def __init__(self, mri_model, dna_model, ecg_model, n_labels: int = 5, hidden_dim: int = 20):
        super().__init__()

        # per-modality submodels
        self.mri_model = mri_model
        self.dna_model = dna_model
        self.ecg_model = ecg_model
        self.n_labels = n_labels

        # Fusion head over the concatenated outputs of the three submodels.
        self.model = torch.nn.Sequential(
            torch.nn.Linear(3 * n_labels, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, n_labels),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        """Fuse the per-modality outputs and predict label probabilities.

        Args:
            x: sequence `(mri_input, dna_input, ecg_input)`. Use `None` for any
               modality missing from the batch.

        Returns:
            Output of the fusion head, whose last dimension is `n_labels`.

        Raises:
            ValueError: if `x` does not have exactly three entries, or if all
                three entries are None.
        """
        if len(x) != 3:
            raise ValueError(f"expected (mri, dna, ecg) inputs, got {len(x)} entries")

        sub_models = (self.mri_model, self.dna_model, self.ecg_model)
        outputs = [None if inp is None else sub_model(inp) for sub_model, inp in zip(sub_models, x)]
        present = [out for out in outputs if out is not None]
        if not present:
            raise ValueError("JointFusion needs at least one of the MRI, DNA or ECG inputs")

        # Zero-fill missing modalities so the head always sees a fixed-width input.
        template = present[0]
        features = torch.cat(
            [torch.zeros_like(template) if out is None else out for out in outputs],
            dim=-1,
        )
        return self.model(features)

