In [1]:
import os

from fastai.imports import *
from lemonpie.basics import *
from lemonpie.preprocessing import vocab
from lemonpie.preprocessing.transform import *
from lemonpie.data import *
from lemonpie import models

import torch
import pytorch_lightning as pl

In [2]:
# Root directory of the generated Coherent dataset.
# Set the COHERENT_DATA_STORE environment variable to run this notebook on another
# machine; the fallback is the location the notebook was originally developed with.
COHERENT_DATA_STORE = os.environ.get('COHERENT_DATA_STORE', '/home/vinod/code/datasets/coherent')
# Date stamp of the data generation run (not referenced by any other visible cell).
COHERENT_DATAGEN_DATE = '08-10-2021'
# Prediction target name -> condition code (as a string).
# NOTE(review): the codes look like SNOMED CT concept ids - confirm.
COHERENT_CONDITIONS = {
    "heart_failure" : "88805009",
    "coronary_heart" : "53741008",
    "myocardial_infarction" : "22298006",
    "stroke" : "230690007",
    "cardiac_arrest" : "410429000"
}

In [3]:
# Label names, in the same order as the entries of COHERENT_CONDITIONS
COHERENT_LABELS = [*COHERENT_CONDITIONS]
COHERENT_LABELS

['heart_failure',
 'coronary_heart',
 'myocardial_infarction',
 'stroke',
 'cardiac_arrest']

# Get data

In [4]:
# Build the multimodal dataset wrapper for the COHERENT_LABELS targets from the data store.
# NOTE(review): with age_in_months=True the age arguments presumably describe a window that
# starts at 240 months (20 years) and spans 120 months (10 years) - confirm against
# MultimodalEHRData, whose definition is not visible from this notebook.
coherent_data = MultimodalEHRData(
    COHERENT_DATA_STORE, 
    COHERENT_LABELS,     
    age_start=240,
    age_range=120,
    start_is_date=False,
    age_in_months=True, 
    lazy_load_gpu=True)

In [5]:
# Embedding dimensions are derived from the vocabularies stored with the dataset
ehr_vocabs = vocab.EhrVocabList.load(COHERENT_DATA_STORE)
emb_dims = vocab.get_all_emb_dims(ehr_vocabs)
demograph_dims, rec_dims, demograph_dims_wd, rec_dims_wd = emb_dims

# get_data() returns two objects that later cells index by "train" / "valid" / "test"
# (dataloaders) and by "train" / "valid" (positive-class weights)
dls, pos_wts = coherent_data.get_data()

In [6]:
coherent_data.modality_types

{'train': ['0', '21', '30', '31', '11', '1', '20', '10'],
 'valid': ['21', '30', '31', '11', '1', '20', '10'],
 'test': ['21', '30', '31', '11', '1', '20', '10']}

In [7]:
dls

{'train': <torch.utils.data.dataloader.DataLoader at 0x7f04e5ffda30>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x7f04e5f9b340>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x7f04e6163550>}

In [8]:
# Pull the dataloader of each split out of the dict returned by get_data()
train_dl, valid_dl, test_dl = (dls[split] for split in ("train", "valid", "test"))

# Unimodal Models - One Per Modality
These are stubs (dummies) that will be replaced by the real models.

In [9]:
class UnimodalModel(torch.nn.Module):
    """Placeholder for a real unimodal model (one instance per modality).

    The input batch is only shape-checked: its values are ignored and the network
    is fed random noise instead, so the outputs carry no information about the input.

    Args:
        input_dims: expected shape of ONE sample, i.e. without the batch dimension.
        n_labels: number of sigmoid outputs per sample (defaults to 5).
    """

    # Width of the random vector that stands in for real features, and of the hidden layer
    FAKE_FEATURES = 10
    HIDDEN_UNITS = 20

    def __init__(self, input_dims: tuple, n_labels: int = 5):
        super().__init__()

        # args
        self.input_dims = input_dims

        # model - `torch.nn` is spelled out so the class does not depend on a star import
        self.model = torch.nn.Sequential(
            torch.nn.Linear(self.FAKE_FEATURES, self.HIDDEN_UNITS),
            torch.nn.ReLU(),
            torch.nn.Linear(self.HIDDEN_UNITS, n_labels),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a `(len(x), n_labels)` tensor of values in [0, 1].

        `x` may be a batched tensor or a sequence of per-sample tensors; only `len(x)`
        and the size of its first element are inspected.

        Raises:
            AssertionError: if the first sample's shape differs from `input_dims`.
        """
        bs = len(x)
        # Only the first sample is checked; an empty batch has nothing to check
        if bs:
            sample_dims = tuple(x[0].size())
            assert sample_dims == tuple(self.input_dims), (
                f"expected samples of shape {tuple(self.input_dims)}, got {sample_dims}"
            )

        # Create the noise on the same device as the weights so the stub keeps working
        # after the module has been moved with .to(...) / .cuda()
        device = next(self.parameters()).device
        fake_x = torch.randn(bs, self.FAKE_FEATURES, device=device)
        return self.model(fake_x)



In [10]:
mri_model = UnimodalModel((4,4))
mri_model(torch.randn(3,4,4)) # batch_size=3

tensor([[0.5353, 0.5137, 0.4573, 0.4930, 0.4591],
        [0.5549, 0.5022, 0.5493, 0.4980, 0.4605],
        [0.5374, 0.5110, 0.5124, 0.5315, 0.4823]], grad_fn=<SigmoidBackward0>)

In [11]:
dna_model = UnimodalModel((3,2))
dna_model(torch.randn(2,3,2))

tensor([[0.5637, 0.5173, 0.4554, 0.5702, 0.3944],
        [0.5217, 0.5636, 0.4956, 0.6345, 0.4506]], grad_fn=<SigmoidBackward0>)

In [12]:
ecg_model = UnimodalModel((5,))
ecg_model(torch.randn((4,5,)))

tensor([[0.4998, 0.4238, 0.4822, 0.4581, 0.4511],
        [0.5122, 0.5331, 0.7241, 0.4770, 0.3822],
        [0.4849, 0.4792, 0.5786, 0.4899, 0.4497],
        [0.4949, 0.4775, 0.5812, 0.4526, 0.4469]], grad_fn=<SigmoidBackward0>)

In [13]:
# EHR model built from the vocabulary embedding dims computed earlier, the number of
# labels, and the positive-class weights of the train and valid splits.
# NOTE(review): no fit() call is visible in this notebook, so the later cells appear to run
# this model with whatever weights the constructor gives it - confirm that is intended.
ehr_model = models.EHR_LSTM(
    demograph_dims,
    rec_dims,
    demograph_dims_wd,
    rec_dims_wd,
    len(COHERENT_LABELS),
    pos_wts["train"], 
    pos_wts["valid"],
    optim="adam",
    base_lr=0.001,
)
# Lightning trainer on GPU with 16-bit precision (the cell output reports native AMP).
# NOTE(review): devices=-1 presumably selects every available GPU - confirm.
# NOTE(review): the trainer is not used by any visible cell; only the commented-out
# test call below refers to it.
ehr_trainer = pl.Trainer(precision=16, accelerator='gpu', devices=-1, max_epochs=5)
# ehr_output = ehr_trainer.test(ehr_model, test_dl)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


# Unimodal Datasets

```python
class MRIDataset(torch.utils.data.Dataset):
    """Placeholder MRI dataset, indexed by patient id rather than by position.

    A lookup succeeds when exactly one file under `<datastore>/output/dicom` has the
    patient id in its name; the file itself is never read - a constant tensor of ones
    with shape `tensor_sz` is returned instead.

    Args:
        datastore: root directory of the dataset.
        tensor_sz: shape of the tensor returned for every patient.
    """

    def __init__(self, datastore: str, tensor_sz: tuple):
        super().__init__()
        self.mri_dir = f"{datastore}/output/dicom"
        self.tensor_sz = tensor_sz

    def __getitem__(self, i):
        """Return the placeholder tensor for patient id `i`.

        Raises:
            Exception: if zero or more than one file name contains the patient id.
        """
        # Escape both parts so glob metacharacters ('[', '*', '?') in the directory or in
        # the patient id are matched literally instead of being treated as wildcards
        pattern = f"{glob.escape(self.mri_dir)}/*{glob.escape(str(i))}*"
        mri_fname = glob.glob(pattern)
        if len(mri_fname) == 1:
            return torch.full(self.tensor_sz, 1)
        raise Exception(f"MRI filename match error - found {len(mri_fname)} files with ptid: {i}.")

    def __len__(self):
        # Constant placeholder: it does not reflect the number of MRI files on disk
        return 1
```

```python
class DNADataset(torch.utils.data.Dataset):
    """Placeholder DNA dataset, indexed by patient id rather than by position.

    A lookup succeeds when exactly one file under `<datastore>/output/dna` has the
    patient id in its name; the file itself is never read - a constant tensor of tens
    with shape `tensor_sz` is returned instead.

    Args:
        datastore: root directory of the dataset.
        tensor_sz: shape of the tensor returned for every patient.
    """

    def __init__(self, datastore: str, tensor_sz: tuple):
        super().__init__()
        self.dna_dir = f"{datastore}/output/dna"
        self.tensor_sz = tensor_sz

    def __getitem__(self, i):
        """Return the placeholder tensor for patient id `i`.

        Raises:
            Exception: if zero or more than one file name contains the patient id.
        """
        # Escape both parts so glob metacharacters ('[', '*', '?') in the directory or in
        # the patient id are matched literally instead of being treated as wildcards
        pattern = f"{glob.escape(self.dna_dir)}/*{glob.escape(str(i))}*"
        dna_fname = glob.glob(pattern)
        if len(dna_fname) == 1:
            return torch.full(self.tensor_sz, 10)
        raise Exception(f"DNA filename match error - found {len(dna_fname)} files with ptid: {i}.")

    def __len__(self):
        # Constant placeholder: it does not reflect the number of DNA files on disk
        return 1
```

```python
class ECGDataset(torch.utils.data.Dataset):
    """Placeholder ECG dataset, indexed by patient id rather than by position.

    The patient ids are read once from the `patient` column of `<datastore>/ecg.csv`.
    No ECG signal is loaded: every known patient gets a constant tensor of twenties
    with shape `tensor_sz`.

    Args:
        datastore: root directory of the dataset.
        tensor_sz: shape of the tensor returned for every patient.
    """

    def __init__(self, datastore: str, tensor_sz: tuple):
        super().__init__()
        # Only the id column is needed, so the rest of the file is not parsed
        ecg_data = pd.read_csv(f"{datastore}/ecg.csv", usecols=["patient"])
        self.ecg_pids = ecg_data["patient"].unique()
        # Hashed copy of the ids: `in` on the array above is a full scan on every lookup
        self._ecg_pid_set = frozenset(self.ecg_pids)
        self.tensor_sz = tensor_sz

    def __getitem__(self, i):
        """Return the placeholder tensor for patient id `i`.

        Raises:
            Exception: if `i` is not one of the patient ids found in ecg.csv.
        """
        if i in self._ecg_pid_set:
            return torch.full(self.tensor_sz, 20)
        raise Exception(f"ptid: {i} - not found in ECG data.")

    def __len__(self):
        # Constant placeholder: it does not reflect the number of patients in ecg.csv
        return 1
```

# Late Fusion

| Modality Type | Modalities            |   
|---	        |---	                |
| **0**	        | **EHR**               |
| **1**         | EHR + **MRI**         |
| **10**        | EHR + **DNA**         |      
| 11   	        | EHR + MRI + DNA       |
| **20**        | EHR + **ECG**         |
| 21            | EHR + MRI + ECG       |
| 30            | EHR + DNA + ECG       |
| 31            | EHR + MRI + DNA + ECG |


In [14]:
batch = next(iter(valid_dl))   
batch

[[ptid:1a82483d-7eb2-d5e0-1e1f-398ba129b18b, birthdate:1936-12-22, [('heart_failure', True), ('coronary_heart', False)].., device:cpu,
  ptid:6dc8bd6b-e2a8-92bf-613d-8b477eb87d7c, birthdate:1911-12-23, [('heart_failure', True), ('coronary_heart', False)].., device:cpu,
  ptid:844a37ff-ce26-6338-fd6a-0bc1e925a702, birthdate:1933-03-15, [('heart_failure', True), ('coronary_heart', False)].., device:cpu],
 tensor([[1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.]]),
 [21, 21, 21],
 [tensor([[[1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1]],
  
          [[1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1]],
  
          [[1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1]]]),
  tensor([[20, 20, 20, 20, 20],
          [20, 20, 20, 20, 20],
          [20, 20, 20, 20, 20]])]]

In [15]:
pts, ys, mod_type, other = batch

In [16]:
torch.sigmoid(ehr_model(pts))

tensor([[0.6411, 0.6585, 0.4290, 0.5529, 0.5619],
        [0.5026, 0.4996, 0.3646, 0.4571, 0.5320],
        [0.5947, 0.6313, 0.4157, 0.5130, 0.5620]], grad_fn=<SigmoidBackward0>)

In [17]:
mri_input, ecg_input = other

In [18]:
mri_model(mri_input)

tensor([[0.5030, 0.4921, 0.4873, 0.5047, 0.4875],
        [0.5078, 0.4766, 0.5463, 0.5213, 0.5173],
        [0.5839, 0.5711, 0.5591, 0.5364, 0.4534]], grad_fn=<SigmoidBackward0>)

In [19]:
ecg_model(ecg_input)

tensor([[0.4668, 0.4984, 0.5840, 0.4353, 0.3662],
        [0.5483, 0.4521, 0.5888, 0.5441, 0.3307],
        [0.5181, 0.3867, 0.4538, 0.5974, 0.3182]], grad_fn=<SigmoidBackward0>)

In [20]:
def run_late_fusion(dl, verbose=False):
    """Late fusion: average the predictions of every model that applies to a batch.

    Uses the notebook-level models `ehr_model`, `mri_model`, `dna_model` and `ecg_model`.
    Every batch yielded by `dl` is unpacked as `(pts, ys, mod_type, other)`. The EHR model
    always runs (its output goes through a sigmoid); the modality code of the FIRST patient,
    `mod_type[0]`, selects which unimodal models run in addition, so each batch is assumed
    to contain a single modality type. `ys` is not used.

    Args:
        dl: iterable of batches as described above.
        verbose: when True, print modality, size and first patient id of every batch.

    Returns:
        A flat list with one prediction tensor per patient, in iteration order. The
        tensors are computed under `torch.no_grad()` and carry no autograd graph.

    Raises:
        ValueError: if a batch has an unknown modality code, or if `other` does not hold
            exactly one input per extra modality of a multi-modality batch.
    """
    # Modality code -> extra models, listed in the order their inputs appear in `other`.
    # Built at call time so that a model re-defined in a later cell is picked up.
    models_by_modality = {
        0: (),
        1: (mri_model,),
        10: (dna_model,),
        11: (mri_model, dna_model),
        20: (ecg_model,),
        21: (mri_model, ecg_model),
        30: (dna_model, ecg_model),
        31: (mri_model, dna_model, ecg_model),
    }

    all_batches = []
    total_length = 0

    # Inference only: without no_grad every stored batch would keep its autograd graph alive
    with torch.no_grad():
        for pts, ys, mod_type, other in dl:

            total_length += len(pts)
            if verbose:
                print(f"modality: {mod_type[0]}, length: {len(pts)}, first ptid: {pts[0].ptid}")

            modality = int(mod_type[0])
            if modality not in models_by_modality:
                # Fail loudly instead of silently treating the batch as EHR-only
                raise ValueError(f"Unknown modality type {modality!r} for batch starting at ptid: {pts[0].ptid}")
            extra_models = models_by_modality[modality]

            if len(extra_models) == 1:
                # With a single extra modality `other` is handed to the model as-is
                extra_inputs = [other]
            elif extra_models:
                extra_inputs = list(other)
                if len(extra_inputs) != len(extra_models):
                    raise ValueError(
                        f"Modality type {modality} expects {len(extra_models)} extra inputs, "
                        f"got {len(extra_inputs)}."
                    )
            else:
                extra_inputs = []

            # NOTE(review): the sigmoid is applied here on the assumption that ehr_model
            # returns logits; the unimodal stubs already end in a Sigmoid layer
            batch_output = [torch.sigmoid(ehr_model(pts))]
            batch_output.extend(
                model(model_input) for model, model_input in zip(extra_models, extra_inputs)
            )

            # avg across multimodal models for every patient in batch
            all_batches.append(torch.mean(torch.stack(batch_output), dim=0))

    print(f"Completed {total_length} patients.")

    # flatten across batches - one entry per patient
    return [pt for batch in all_batches for pt in batch]

In [21]:
test_result = run_late_fusion(test_dl)

Completed 128 patients.


In [22]:
valid_result = run_late_fusion(valid_dl, verbose=True)

modality: 21, length: 3, first ptid: 1a82483d-7eb2-d5e0-1e1f-398ba129b18b
modality: 30, length: 13, first ptid: 2e1cf98c-70ce-4f8f-36da-b2eef4960ecc
modality: 31, length: 2, first ptid: 4a1a224f-54d6-66f1-3755-4c8489f2a5de
modality: 11, length: 14, first ptid: aead835e-66f2-d3b9-099c-c33844a70748
modality: 1, length: 12, first ptid: 972f6a59-3921-36bc-64bb-253d6241748b
modality: 20, length: 34, first ptid: e1875ec6-e10f-c9d8-d388-fd3abfbc4a87
modality: 10, length: 50, first ptid: 1e73e6da-68f0-0ffc-4062-a4647b5b67c4
Completed 128 patients.


In [23]:
train_result = run_late_fusion(train_dl)

Completed 1022 patients.


In [24]:
assert len(test_result) == len(valid_result) == 128
assert len(train_result) == 1022

# Joint Fusion

In [None]:
class JointFusion(torch.nn.Module):
    """Skeleton of a joint-fusion model - work in progress.

    The three unimodal models are stored but not called yet; `forward` ignores the values
    of its input and pushes random noise through a small placeholder network.

    Args:
        mri_model: unimodal MRI model.
        dna_model: unimodal DNA model.
        ecg_model: unimodal ECG model.
        input_dims: optional expected shape of the input to `forward`. When None (the
            default) the input shape is not checked.
    """

    def __init__(self, mri_model, dna_model, ecg_model, input_dims=None):
        super().__init__()

        # args
        self.mri_model = mri_model
        self.dna_model = dna_model
        self.ecg_model = ecg_model
        # Previously never assigned, which made every forward() call fail with AttributeError
        self.input_dims = input_dims

        # model - `torch.nn` is spelled out so the class does not depend on a star import
        self.model = torch.nn.Sequential(
            torch.nn.Linear(10, 20),
            torch.nn.ReLU(),
            torch.nn.Linear(20, 5),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a placeholder tensor of 5 values in [0, 1] (no batch dimension).

        Raises:
            AssertionError: if `input_dims` was given and `x` has a different shape.
        """
        if self.input_dims is not None:
            assert tuple(x.size()) == tuple(self.input_dims), (
                f"expected input of shape {tuple(self.input_dims)}, got {tuple(x.size())}"
            )
        # Noise is created on the device of the placeholder network's weights
        device = next(self.model.parameters()).device
        fake_x = torch.randn(10, device=device)
        return self.model(fake_x)

