In [31]:
import os

from fastai.imports import *
from lemonpie.basics import *
from lemonpie.preprocessing import vocab
from lemonpie.preprocessing.transform import *
from lemonpie.data import *
from lemonpie import models

import torch
import pytorch_lightning as pl

In [2]:
# Root of the generated Coherent dataset. Set the COHERENT_DATA_STORE environment
# variable to point elsewhere instead of editing this machine-specific default.
COHERENT_DATA_STORE = os.environ.get("COHERENT_DATA_STORE", '/home/vinod/code/datasets/coherent')
# NOTE(review): date string is ambiguous (MM-DD vs DD-MM) -- confirm the format.
COHERENT_DATAGEN_DATE = '08-10-2021'

# Seed torch's global RNG so the random stub-model outputs below are reproducible.
SEED = 42
torch.manual_seed(SEED)

# Label name -> condition code; insertion order defines the label order.
COHERENT_CONDITIONS = {
    "heart_failure" : "88805009",
    "coronary_heart" : "53741008",
    "myocardial_infarction" : "22298006",
    "stroke" : "230690007",
    "cardiac_arrest" : "410429000"
}

In [3]:
# One label per condition, in the insertion order of COHERENT_CONDITIONS
COHERENT_LABELS = [condition for condition in COHERENT_CONDITIONS]
COHERENT_LABELS

['heart_failure',
 'coronary_heart',
 'myocardial_infarction',
 'stroke',
 'cardiac_arrest']

# Get data

In [4]:
# Wrap the Coherent data store as a multimodal dataset with one label per entry
# of COHERENT_LABELS.
# NOTE(review): with age_in_months=True, age_start=240 and age_range=120 presumably
# mean a start age of 20 years and a 10-year window -- confirm against MultimodalEHRData.
# NOTE(review): lazy_load_gpu=True presumably defers moving data to the GPU -- TODO confirm.
coherent_data = MultimodalEHRData(
    COHERENT_DATA_STORE, 
    COHERENT_LABELS,     
    age_start=240,
    age_range=120,
    start_is_date=False,
    age_in_months=True, 
    lazy_load_gpu=True)

In [5]:
# Embedding dimensions consumed by EHR_LSTM below, computed from the vocabularies saved in the data store.
demograph_dims, rec_dims, demograph_dims_wd, rec_dims_wd = vocab.get_all_emb_dims(vocab.EhrVocabList.load(COHERENT_DATA_STORE))
# dls: DataLoaders keyed by "train" / "valid" / "test".
# pos_wts: per-split weights (presumably positive-class weights for the loss) keyed by split name.
dls, pos_wts = coherent_data.get_data()

In [6]:
# Modality-type codes present in each split (decoded in the Late Fusion table).
# Note: they are listed as strings here, while batches carry them as ints.
coherent_data.modality_types

{'train': ['0', '21', '30', '31', '11', '1', '20', '10'],
 'valid': ['21', '30', '31', '11', '1', '20', '10'],
 'test': ['21', '30', '31', '11', '1', '20', '10']}

In [7]:
# DataLoaders keyed by split name
dls

{'train': <torch.utils.data.dataloader.DataLoader at 0x7fb4d3dc0340>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x7fb4d3e1b880>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x7fb4d3dd90a0>}

In [8]:
# Unpack the per-split DataLoaders into individual names
train_dl, valid_dl, test_dl = (dls[split] for split in ("train", "valid", "test"))

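# Models for each modality

The MRI, DNA and ECG models below are placeholders. They check the shape of each input sample but return random predictions. The EHR model is lemonpie's `EHR_LSTM`.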

In [9]:
class UnimodalModel(torch.nn.Module):
    """Placeholder single-modality model used to prototype fusion strategies.

    It validates the per-sample shape of its input but does NOT learn from it:
    ``forward`` feeds a fresh standard-normal tensor of shape
    ``(batch_size, in_features)`` through a small MLP ending in a sigmoid, so the
    outputs are random values in (0, 1).

    Args:
        input_dims: expected shape of a single sample, e.g. ``(4, 4)``.
        n_outputs: number of outputs per sample (default 5, the number of
            Coherent labels in this notebook).
        in_features: width of the random placeholder input.
        hidden: width of the hidden layer.
    """

    def __init__(self, input_dims: tuple, n_outputs: int = 5, in_features: int = 10, hidden: int = 20):
        super().__init__()

        # args
        self.input_dims = tuple(input_dims)
        self.in_features = in_features

        # model
        self.model = nn.Sequential(nn.Linear(in_features, hidden),
                                   nn.ReLU(),
                                   nn.Linear(hidden, n_outputs),
                                   nn.Sigmoid())

    def forward(self, x):
        """Return ``(len(x), n_outputs)`` placeholder probabilities for batch ``x``.

        Raises:
            AssertionError: if a sample's shape differs from ``input_dims``.
        """
        sample_shape = tuple(x[0].size())
        assert sample_shape == self.input_dims, \
            f"expected per-sample shape {self.input_dims}, got {sample_shape}"
        bs = len(x)
        # Build the fake input where the weights live, so the stub keeps working
        # after the module has been moved to another device.
        device = next(self.model.parameters()).device
        fake_x = torch.randn(bs, self.in_features, device=device)
        return self.model(fake_x)



In [10]:
# Stub MRI model: each sample is expected to be a 4x4 tensor
mri_model = UnimodalModel(input_dims=(4, 4))
mri_model(torch.randn(3, 4, 4))  # smoke test with a batch of 3 samples

tensor([[0.5871, 0.4717, 0.4356, 0.5364, 0.5261],
        [0.5878, 0.5089, 0.4622, 0.5733, 0.5609],
        [0.4698, 0.4853, 0.5092, 0.6191, 0.5108]], grad_fn=<SigmoidBackward0>)

In [11]:
# Stub DNA model: each sample is expected to be a 3x2 tensor
dna_model = UnimodalModel(input_dims=(3, 2))
dna_model(torch.randn(2, 3, 2))  # smoke test with a batch of 2 samples

tensor([[0.5298, 0.4821, 0.4820, 0.5009, 0.5471],
        [0.6797, 0.5119, 0.4718, 0.4024, 0.5689]], grad_fn=<SigmoidBackward0>)

In [12]:
# Stub ECG model: each sample is expected to be a length-5 vector
ecg_model = UnimodalModel(input_dims=(5,))
ecg_model(torch.randn(4, 5))  # smoke test with a batch of 4 samples

tensor([[0.4962, 0.5424, 0.5331, 0.4642, 0.5518],
        [0.3958, 0.3892, 0.6022, 0.3669, 0.4795],
        [0.4391, 0.5904, 0.5600, 0.5358, 0.6003],
        [0.5202, 0.4764, 0.5483, 0.4702, 0.5856]], grad_fn=<SigmoidBackward0>)

In [13]:
# EHR model: lemonpie's EHR_LSTM, built from the embedding dims and per-split
# weights computed above, with one output per label.
# NOTE(review): no training call is visible in this section, so predictions below
# presumably come from freshly initialized weights -- confirm.
ehr_model = models.EHR_LSTM(
    demograph_dims,
    rec_dims,
    demograph_dims_wd,
    rec_dims_wd,
    len(COHERENT_LABELS),
    pos_wts["train"], 
    pos_wts["valid"],
    optim="adam",
    base_lr=0.001,
)
# Lightning trainer requesting a GPU (accelerator='gpu') with 16-bit mixed precision.
ehr_trainer = pl.Trainer(precision=16, accelerator='gpu', devices=-1, max_epochs=5)
# TODO(review): ehr_trainer is not used in the cells shown here; the evaluation call
# ehr_trainer.test(ehr_model, test_dl) was left commented out.

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


# Late Fusion

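Every patient has EHR data. The modality-type code records which extra modalities are also available: the ones digit flags MRI (1), and the tens digit flags DNA (1), ECG (2), or both (3). Bold highlights the EHR-only type and the three types that add a single modality. In late fusion, each available modality's model makes its own prediction, and the predictions are averaged per patient.

| Modality Type | Modalities            |
|---------------|-----------------------|
| **0**         | **EHR**               |
| **1**         | EHR + **MRI**         |
| **10**        | EHR + **DNA**         |
| 11            | EHR + MRI + DNA       |
| **20**        | EHR + **ECG**         |
| 21            | EHR + MRI + ECG       |
| 30            | EHR + DNA + ECG       |
| 31            | EHR + MRI + DNA + ECG |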


In [16]:
# Inspect the first validation batch: [patients, label matrix, modality codes, extra-modality inputs]
batch = next(iter(valid_dl))   
batch

[[ptid:1a82483d-7eb2-d5e0-1e1f-398ba129b18b, birthdate:1936-12-22, [('heart_failure', True), ('coronary_heart', False)].., device:cpu,
  ptid:6dc8bd6b-e2a8-92bf-613d-8b477eb87d7c, birthdate:1911-12-23, [('heart_failure', True), ('coronary_heart', False)].., device:cpu,
  ptid:844a37ff-ce26-6338-fd6a-0bc1e925a702, birthdate:1933-03-15, [('heart_failure', True), ('coronary_heart', False)].., device:cpu],
 tensor([[1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.]]),
 [21, 21, 21],
 [tensor([[[1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1]],
  
          [[1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1]],
  
          [[1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1]]]),
  tensor([[20, 20, 20, 20, 20],
          [20, 20, 20, 20, 20],
          [20, 20, 20, 20, 20]])]]

In [17]:
# Patient records, (batch, n_labels) label matrix, per-patient modality codes, extra-modality inputs
pts, ys, mod_type, other = batch

In [30]:
# EHR-only predictions for this batch; sigmoid is applied here, so ehr_model presumably returns logits
torch.sigmoid(ehr_model(pts))

tensor([[0.5097, 0.4871, 0.4612, 0.4247, 0.5785],
        [0.4101, 0.5711, 0.6798, 0.4514, 0.6669],
        [0.4152, 0.5649, 0.5485, 0.5303, 0.6301]], grad_fn=<SigmoidBackward0>)

In [32]:
# This batch has modality code 21 (EHR + MRI + ECG), so `other` holds two tensors: MRI then ECG
mri_input, ecg_input = other

In [33]:
# Stub MRI predictions for the same batch (random; see UnimodalModel)
mri_model(mri_input)

tensor([[0.6033, 0.4849, 0.4188, 0.5446, 0.5866],
        [0.5750, 0.4802, 0.4644, 0.5487, 0.5825],
        [0.4990, 0.4340, 0.4549, 0.5617, 0.3821]], grad_fn=<SigmoidBackward0>)

In [34]:
# Stub ECG predictions for the same batch (random; see UnimodalModel)
ecg_model(ecg_input)

tensor([[0.4335, 0.5439, 0.5146, 0.4945, 0.5128],
        [0.5305, 0.4065, 0.6597, 0.3474, 0.5529],
        [0.4288, 0.5427, 0.4915, 0.5302, 0.6858]], grad_fn=<SigmoidBackward0>)

In [60]:
# Extra-modality models that contribute for each modality-type code (the EHR model
# always contributes). The ones digit flags MRI; the tens digit is 1 for DNA,
# 2 for ECG and 3 for DNA + ECG.
LATE_FUSION_MODALITIES = {
    0: (),
    1: ("mri",),
    10: ("dna",),
    11: ("mri", "dna"),
    20: ("ecg",),
    21: ("mri", "ecg"),
    30: ("dna", "ecg"),
    31: ("mri", "dna", "ecg"),
}


def run_late_fusion(dl, verbose=False, *, ehr=None, mri=None, dna=None, ecg=None):
    """Late fusion: average per-modality predictions for every patient in ``dl``.

    Each batch from ``dl`` is unpacked as ``(pts, ys, mod_type, other)``. The EHR
    model always contributes ``sigmoid(ehr(pts))``. The MRI / DNA / ECG models
    contribute according to the batch's modality-type code (see
    ``LATE_FUSION_MODALITIES``). Each patient's result is the unweighted mean of
    those predictions. Runs under ``torch.no_grad()``.

    Args:
        dl: iterable of batches ``(pts, ys, mod_type, other)``. All patients in a
            batch must share one modality-type code. ``other`` is the bare input
            tensor when the code has one extra modality, or a sequence of tensors
            ordered MRI, DNA, ECG when it has several.
        verbose: print the modality code, size and first ptid of each batch.
        ehr, mri, dna, ecg: models to use; each defaults to the notebook-level
            ``ehr_model`` / ``mri_model`` / ``dna_model`` / ``ecg_model``.

    Returns:
        list with one fused prediction tensor per patient, in iteration order.

    Raises:
        ValueError: if a batch mixes modality codes, has an unknown code, or
            ``other`` holds the wrong number of inputs for its code.
    """
    ehr = ehr_model if ehr is None else ehr
    modality_models = {
        "mri": mri_model if mri is None else mri,
        "dna": dna_model if dna is None else dna,
        "ecg": ecg_model if ecg is None else ecg,
    }

    all_batches = []
    total_length = 0

    # Inference only: without no_grad every stored output would keep its autograd
    # graph (and activations) alive until the whole dataloader has been processed.
    with torch.no_grad():
        for pts, _ys, mod_type, other in dl:
            codes = {int(code) for code in mod_type}
            if len(codes) != 1:
                raise ValueError(f"expected one modality type per batch, got {sorted(codes)}")
            code = codes.pop()
            if code not in LATE_FUSION_MODALITIES:
                raise ValueError(f"unknown modality type {code}")

            total_length += len(pts)
            if verbose:
                print(f"modality: {code}, length: {len(pts)}, first ptid: {pts[0].ptid}")

            names = LATE_FUSION_MODALITIES[code]
            # A single extra modality arrives as the bare tensor; several arrive
            # as a sequence ordered MRI, DNA, ECG.
            if len(names) == 1:
                inputs = (other,)
            else:
                inputs = tuple(other) if names else ()
            if len(inputs) != len(names):
                raise ValueError(f"modality type {code} expects {len(names)} inputs, got {len(inputs)}")

            batch_output = [torch.sigmoid(ehr(pts))]
            batch_output += [modality_models[name](x) for name, x in zip(names, inputs)]
            # Unweighted mean across the contributing models, per patient
            all_batches.append(torch.stack(batch_output).mean(dim=0))

    print(f"Completed {total_length} patients.")

    # Flatten per-batch (batch_size, n_labels) tensors into one row per patient
    return [row for fused in all_batches for row in fused]

In [65]:
# Late-fusion predictions for the test split: one tensor per patient
test_result = run_late_fusion(test_dl)

Completed 128 patients.


In [63]:
# Validation split, verbose: shows that each batch carries a single modality code
valid_result = run_late_fusion(valid_dl, verbose=True)

modality: 21, length: 3, first ptid: 1a82483d-7eb2-d5e0-1e1f-398ba129b18b
modality: 30, length: 13, first ptid: 2e1cf98c-70ce-4f8f-36da-b2eef4960ecc
modality: 31, length: 2, first ptid: 4a1a224f-54d6-66f1-3755-4c8489f2a5de
modality: 11, length: 14, first ptid: aead835e-66f2-d3b9-099c-c33844a70748
modality: 1, length: 12, first ptid: 972f6a59-3921-36bc-64bb-253d6241748b
modality: 20, length: 34, first ptid: e1875ec6-e10f-c9d8-d388-fd3abfbc4a87
modality: 10, length: 50, first ptid: 1e73e6da-68f0-0ffc-4062-a4647b5b67c4
Completed 128 patients.


In [64]:
# Late-fusion predictions for the training split
train_result = run_late_fusion(train_dl)

Completed 1022 patients.


In [66]:
# Every patient in each split must receive a fused prediction (expected patient counts)
for split_result, expected in ((test_result, 128), (valid_result, 128), (train_result, 1022)):
    assert len(split_result) == expected

# Joint Fusion

In [None]:
class JointFusion(torch.nn.Module):
    """Work-in-progress joint-fusion model (placeholder).

    Stores the per-modality models as attributes (registered as submodules when
    they are ``nn.Module`` instances) alongside a small MLP head ending in a
    sigmoid. The per-modality models are NOT used yet: like ``UnimodalModel``,
    ``forward`` only checks the per-sample input shape and returns random values
    in (0, 1) for each sample in the batch.

    Args:
        mri_model, dna_model, ecg_model: per-modality models.
        input_dims: optional expected shape of a single sample; when ``None``
            the shape check in ``forward`` is skipped.
        n_outputs: number of outputs per sample (default 5).
    """

    def __init__(self, mri_model, dna_model, ecg_model, input_dims=None, n_outputs: int = 5):
        super().__init__()

        # args
        self.mri_model = mri_model
        self.dna_model = dna_model
        self.ecg_model = ecg_model
        # Optional: forward() only validates shapes when this is provided
        self.input_dims = None if input_dims is None else tuple(input_dims)

        # model
        self.model = nn.Sequential(nn.Linear(10, 20),
                                   nn.ReLU(),
                                   nn.Linear(20, n_outputs),
                                   nn.Sigmoid())

    def forward(self, x):
        """Return ``(len(x), n_outputs)`` placeholder probabilities for batch ``x``.

        Raises:
            AssertionError: if ``input_dims`` is set and a sample's shape differs.
        """
        # TODO: fuse features from mri_model / dna_model / ecg_model here.
        if self.input_dims is not None:
            sample_shape = tuple(x[0].size())
            assert sample_shape == self.input_dims, \
                f"expected per-sample shape {self.input_dims}, got {sample_shape}"
        # One random input row per sample, created where the weights live
        device = next(self.model.parameters()).device
        fake_x = torch.randn(len(x), 10, device=device)
        return self.model(fake_x)

