In [1]:
from fastai.imports import *
from lemonpie.basics import *
from lemonpie.preprocessing import vocab
from lemonpie.preprocessing.transform import *
from lemonpie.data import *
from lemonpie import models

import torch
import pytorch_lightning as pl
from torchmetrics import MetricCollection, AUROC, Accuracy


In [2]:
COHERENT_DATA_STORE = '/home/vinod/code/datasets/coherent'
COHERENT_DATAGEN_DATE = '08-10-2021'
COHERENT_CONDITIONS = {
    "heart_failure" : "88805009",
    "coronary_heart" : "53741008",
    "myocardial_infarction" : "22298006",
    "stroke" : "230690007",
    "cardiac_arrest" : "410429000"
}

In [3]:
COHERENT_LABELS = list(COHERENT_CONDITIONS.keys())
COHERENT_LABELS

['heart_failure',
 'coronary_heart',
 'myocardial_infarction',
 'stroke',
 'cardiac_arrest']

# Get data

In [4]:
coherent_data = MultimodalEHRData(
    COHERENT_DATA_STORE, 
    COHERENT_LABELS,     
    age_start=240,
    age_range=120,
    start_is_date=False,
    age_in_months=True, 
    lazy_load_gpu=True)

In [5]:
demograph_dims, rec_dims, demograph_dims_wd, rec_dims_wd = vocab.get_all_emb_dims(vocab.EhrVocabList.load(COHERENT_DATA_STORE))
dls, pos_wts = coherent_data.get_data()

In [6]:
coherent_data.modality_types

{'train': ['0', '21', '30', '31', '11', '1', '20', '10'],
 'valid': ['21', '30', '31', '11', '1', '20', '10'],
 'test': ['21', '30', '31', '11', '1', '20', '10']}

In [7]:
dls

{'train': <torch.utils.data.dataloader.DataLoader at 0x7f874511d250>,
 'valid': <torch.utils.data.dataloader.DataLoader at 0x7f8744f2e8e0>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x7f8745106580>}

In [8]:
train_dl = dls["train"]
valid_dl = dls["valid"]
test_dl = dls["test"]

# Unimodal Models - One Per Modality
These are stubs / dummys to be replaced by the real ones.

In [9]:
class UnimodalModel(pl.LightningModule):
    def __init__(self, input_dims: tuple):
        super().__init__()

        # args
        self.input_dims = input_dims

        # model
        self.model = nn.Sequential(
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 30),
            nn.ReLU(),
            nn.Linear(30, 5),
            nn.Sigmoid()
            )
        
    
    def forward(self, x):
        assert x[0].size() == self.input_dims, f"Expected {self.input_dims}, got {x[0].size()}"
        bs = len(x)
        fake_x = torch.randn((bs, 10), device=self.device)
        return self.model(fake_x)

    def training_step(self, *args, **kwargs):
        return super().training_step(*args, **kwargs)

    def configure_optimizers(self):
        return super().configure_optimizers()



## MRI

In [194]:
mri_model = UnimodalModel((4,4))
mri_model(torch.randn(3,4,4)) # batch_size=3

tensor([[0.5356, 0.5292, 0.4622, 0.5050, 0.5145],
        [0.5703, 0.5492, 0.4860, 0.4505, 0.4999],
        [0.5437, 0.5564, 0.4891, 0.4984, 0.5146]], grad_fn=<SigmoidBackward0>)

In [195]:
mri_trainer = pl.Trainer(max_epochs=3, accelerator='gpu', devices=1)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [196]:
mri_trainer.fit(mri_model, train_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 1.0 K 
-------------------------------------
1.0 K     Trainable params
0         Non-trainable params
1.0 K     Total params
0.004     Total estimated model params size (MB)


Epoch 2: 100%|██████████| 21/21 [00:03<00:00,  5.27it/s, loss=nan, v_num=13]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 21/21 [00:03<00:00,  5.26it/s, loss=nan, v_num=13]


In [197]:
mri_trainer.save_checkpoint("./mri_model.pth")

## DNA

In [198]:
dna_model = UnimodalModel((3,2))
dna_model(torch.randn(2,3,2)) 

tensor([[0.4793, 0.4843, 0.5255, 0.5074, 0.5071],
        [0.4920, 0.4513, 0.5616, 0.5017, 0.4978]], grad_fn=<SigmoidBackward0>)

In [199]:
dna_trainer = pl.Trainer(max_epochs=3, accelerator='gpu', devices=1)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [200]:
dna_trainer.fit(dna_model, train_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 1.0 K 
-------------------------------------
1.0 K     Trainable params
0         Non-trainable params
1.0 K     Total params
0.004     Total estimated model params size (MB)


Epoch 2: 100%|██████████| 21/21 [00:03<00:00,  5.40it/s, loss=nan, v_num=14]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 21/21 [00:03<00:00,  5.39it/s, loss=nan, v_num=14]


In [201]:
dna_trainer.save_checkpoint("./dna_model.pth")

## ECG

In [202]:
ecg_model = UnimodalModel((5,))
ecg_model(torch.randn((4,5,)))

tensor([[0.5351, 0.5408, 0.5217, 0.4983, 0.4411],
        [0.5453, 0.5590, 0.5234, 0.5139, 0.4308],
        [0.5662, 0.5273, 0.4979, 0.4695, 0.4290],
        [0.5508, 0.5548, 0.5337, 0.4901, 0.4438]], grad_fn=<SigmoidBackward0>)

In [203]:
ecg_trainer = pl.Trainer(max_epochs=3, accelerator='gpu', devices=1)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [204]:
ecg_trainer.fit(ecg_model, train_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 1.0 K 
-------------------------------------
1.0 K     Trainable params
0         Non-trainable params
1.0 K     Total params
0.004     Total estimated model params size (MB)


Epoch 2: 100%|██████████| 21/21 [00:04<00:00,  5.24it/s, loss=nan, v_num=15]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 21/21 [00:04<00:00,  5.23it/s, loss=nan, v_num=15]


In [205]:
ecg_trainer.save_checkpoint("./ecg_model.pth")

## EHR

In [207]:
ehr_model = models.EHR_LSTM(
    demograph_dims,
    rec_dims,
    demograph_dims_wd,
    rec_dims_wd,
    len(COHERENT_LABELS),
    pos_wts["train"], 
    pos_wts["valid"],
    optim="adam",
    base_lr=0.001,
)


In [208]:
ehr_trainer = pl.Trainer(precision=16, accelerator='gpu', devices=-1, max_epochs=3)
# ehr_output = ehr_trainer.test(ehr_model, test_dl)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [209]:
ehr_trainer.fit(ehr_model, train_dl, valid_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type              | Params
----------------------------------------------------
0 | train_loss_fn | BCEWithLogitsLoss | 0     
1 | valid_loss_fn | BCEWithLogitsLoss | 0     
2 | embs          | ModuleList        | 12.0 K
3 | embgs         | ModuleList        | 19.5 K
4 | input_dp      | InputDropout      | 0     
5 | lstm          | LSTM              | 250 K 
6 | linear        | Sequential        | 7.4 M 
7 | train_metrics | MetricCollection  | 0     
8 | valid_metrics | MetricCollection  | 0     
9 | test_metrics  | MetricCollection  | 0     
----------------------------------------------------
7.7 M     Trainable params
0         Non-trainable params
7.7 M     Total params
15.320    Total estimated model params size (MB)


Epoch 2: 100%|██████████| 28/28 [00:09<00:00,  2.92it/s, loss=1.03, v_num=16]

`Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 2: 100%|██████████| 28/28 [00:09<00:00,  2.82it/s, loss=1.03, v_num=16]


- [Lightning warning details](https://github.com/Lightning-AI/lightning/issues/10349#issuecomment-961340903)

In [210]:
ehr_trainer.save_checkpoint("./ehr_model.pth")

# Unimodal Datasets

```python
class MRIDataset(torch.utils.data.Dataset):
    def __init__(self, datastore: str, tensor_sz: tuple):
        super().__init__()
        self.mri_dir = f"{datastore}/output/dicom"
        self.tensor_sz = tensor_sz
    
    def __getitem__(self, i):
        mri_fname = glob.glob(f"{self.mri_dir}/*{i}*")
        if len(mri_fname) == 1:
            return torch.full(self.tensor_sz, 1)
        else:
            raise Exception(f"MRI filename match error - found {len(mri_fname)} files with ptid: {i}.")

    def __len__(self):
        return 1
```

```python
class DNADataset(torch.utils.data.Dataset):
    def __init__(self, datastore: str, tensor_sz: tuple):
        super().__init__()
        self.dna_dir = f"{datastore}/output/dna"
        self.tensor_sz = tensor_sz
    
    def __getitem__(self, i):
        dna_fname = glob.glob(f"{self.dna_dir}/*{i}*")
        if len(dna_fname) == 1:
            return torch.full(self.tensor_sz, 10)
        else:
            raise Exception(f"DNA filename match error - found {len(dna_fname)} files with ptid: {i}.")

    def __len__(self):
        return 1
```

```python
class ECGDataset(torch.utils.data.Dataset):
    def __init__(self, datastore: str, tensor_sz: tuple):
        super().__init__()
        ecg_data = pd.read_csv(f"{datastore}/ecg.csv")
        self.ecg_pids = ecg_data.patient.unique()
        self.tensor_sz = tensor_sz
    
    def __getitem__(self, i):
        
        if i in self.ecg_pids:
            return torch.full(self.tensor_sz, 20)
        else:
            raise Exception(f"ptid: {i} - not found in ECG data.")

    def __len__(self):
        return 1
```

# Late Fusion

| Modality Type | Modalities            |   
|---	        |---	                |
| **0**	        | **EHR**               |
| **1**         | EHR + **MRI**         |
| **10**        | EHR + **DNA**         |      
| 11   	        | EHR + MRI + DNA       |
| **20**        | EHR + **ECG**         |
| 21            | EHR + MRI + ECG       |
| 30            | EHR + DNA + ECG       |
| 31            | EHR + MRI + DNA + ECG |


In [16]:
batch = next(iter(valid_dl))   
batch

[[ptid:1a82483d-7eb2-d5e0-1e1f-398ba129b18b, birthdate:1936-12-22, [('heart_failure', True), ('coronary_heart', False)].., device:cpu,
  ptid:6dc8bd6b-e2a8-92bf-613d-8b477eb87d7c, birthdate:1911-12-23, [('heart_failure', True), ('coronary_heart', False)].., device:cpu,
  ptid:844a37ff-ce26-6338-fd6a-0bc1e925a702, birthdate:1933-03-15, [('heart_failure', True), ('coronary_heart', False)].., device:cpu],
 tensor([[1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.]]),
 [21, 21, 21],
 [tensor([[[1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1]],
  
          [[1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1]],
  
          [[1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1],
           [1, 1, 1, 1]]]),
  tensor([[20, 20, 20, 20, 20],
          [20, 20, 20, 20, 20],
          [20, 20, 20, 20, 20]])]]

In [17]:
pts, ys, mod_type, other = batch

In [18]:
torch.sigmoid(ehr_model(pts))

tensor([[0.5317, 0.5461, 0.4880, 0.4909, 0.4987],
        [0.5642, 0.5709, 0.4845, 0.5405, 0.4524],
        [0.5330, 0.4908, 0.3036, 0.4859, 0.5562]], grad_fn=<SigmoidBackward0>)

In [19]:
mri_input, ecg_input = other

In [20]:
mri_model(mri_input)

NameError: name 'mri_model' is not defined

In [None]:
ecg_model(ecg_input)

tensor([[0.4668, 0.4984, 0.5840, 0.4353, 0.3662],
        [0.5483, 0.4521, 0.5888, 0.5441, 0.3307],
        [0.5181, 0.3867, 0.4538, 0.5974, 0.3182]], grad_fn=<SigmoidBackward0>)

In [21]:
def run_late_fusion(dl, verbose=False):

    all_batches = []
    total_length = 0

    for batch in dl:

        batch_output = []
        pts, ys, mod_type, other = batch
        
        total_length += len(pts)
        if verbose:
            print(f"modality: {mod_type[0]}, length: {len(pts)}, first ptid: {pts[0].ptid}")

        batch_output.append(torch.sigmoid(ehr_model(pts)))

        if mod_type[0] == 1:
            mri_input = other 
            batch_output.append(mri_model(mri_input))

        elif mod_type[0] == 10:
            dna_input = other 
            batch_output.append(dna_model(dna_input))
            
        elif mod_type[0] == 11:
            mri_input, dna_input = other 
            batch_output.append(mri_model(mri_input))
            batch_output.append(dna_model(dna_input))

        elif mod_type[0] == 20:
            ecg_input = other 
            batch_output.append(ecg_model(ecg_input))

        elif mod_type[0] == 21:
            mri_input, ecg_input = other 
            batch_output.append(mri_model(mri_input))
            batch_output.append(ecg_model(ecg_input))

        elif mod_type[0] == 30:
            dna_input, ecg_input = other 
            batch_output.append(dna_model(dna_input))
            batch_output.append(ecg_model(ecg_input))

        elif mod_type[0] == 31:
            mri_input, dna_input, ecg_input = other
            batch_output.append(mri_model(mri_input))
            batch_output.append(dna_model(dna_input))
            batch_output.append(ecg_model(ecg_input))

        # avg across multimodal models for every patient in batch
        avgd = torch.mean(torch.stack(batch_output), dim=0)         
        all_batches.append(avgd)


    print(f"Completed {total_length} patients.")

    # flatten across batches - list of lists (batches) to flat
    output = [pt for batch in all_batches for pt in batch]
    return output

In [22]:
test_result = run_late_fusion(test_dl)

NameError: name 'mri_model' is not defined

In [22]:
valid_result = run_late_fusion(valid_dl, verbose=True)

modality: 21, length: 3, first ptid: 1a82483d-7eb2-d5e0-1e1f-398ba129b18b
modality: 30, length: 13, first ptid: 2e1cf98c-70ce-4f8f-36da-b2eef4960ecc
modality: 31, length: 2, first ptid: 4a1a224f-54d6-66f1-3755-4c8489f2a5de
modality: 11, length: 14, first ptid: aead835e-66f2-d3b9-099c-c33844a70748
modality: 1, length: 12, first ptid: 972f6a59-3921-36bc-64bb-253d6241748b
modality: 20, length: 34, first ptid: e1875ec6-e10f-c9d8-d388-fd3abfbc4a87
modality: 10, length: 50, first ptid: 1e73e6da-68f0-0ffc-4062-a4647b5b67c4
Completed 128 patients.


In [23]:
train_result = run_late_fusion(train_dl)

Completed 1022 patients.


In [24]:
assert len(test_result) == len(valid_result) == 128
assert len(train_result) == 1022

# Joint Fusion

In [48]:
pts, y, m, other = next(iter(valid_dl))

In [49]:
pts

[ptid:1a82483d-7eb2-d5e0-1e1f-398ba129b18b, birthdate:1936-12-22, [('heart_failure', True), ('coronary_heart', False)].., device:cpu,
 ptid:6dc8bd6b-e2a8-92bf-613d-8b477eb87d7c, birthdate:1911-12-23, [('heart_failure', True), ('coronary_heart', False)].., device:cpu,
 ptid:844a37ff-ce26-6338-fd6a-0bc1e925a702, birthdate:1933-03-15, [('heart_failure', True), ('coronary_heart', False)].., device:cpu]

In [66]:
m[0]

21

## MRI

In [52]:
mri_model = UnimodalModel.load_from_checkpoint("./mri_model_pth", input_dims=(4,4))

In [53]:
mri_layers_trunc = list(mri_model.model.children())[:-2]

In [54]:
mri_model.model = nn.Sequential(*mri_layers_trunc)

In [55]:
# mri_model

In [56]:
mri_repr = mri_model(mri_input)

In [57]:
mri_repr.shape

torch.Size([3, 30])

## EHR

In [58]:
ehr_model = models.EHR_LSTM.load_from_checkpoint(
    "./ehr_model_pth",
    demograph_dims=demograph_dims,
    rec_dims=rec_dims,
    demograph_dims_wd=demograph_dims_wd,
    rec_dims_wd=rec_dims_wd,
    labels=len(COHERENT_LABELS),
    train_pos_wts=pos_wts["train"], 
    valid_pos_wts=pos_wts["valid"],
    optim="adam",
    base_lr=0.001,
)



In [59]:
ehr_layers_trunc = list(ehr_model.linear.children())[:-2]

In [60]:
ehr_model.linear = nn.Sequential(*ehr_layers_trunc)

In [61]:
# ehr_model

In [62]:
ehr_repr = ehr_model(pts)

In [63]:
ehr_repr.shape

torch.Size([3, 3328])

## Joint Model

In [64]:
concated = torch.concat((ehr_repr, mri_repr), dim=1)

In [65]:
concated.shape

torch.Size([3, 3358])

In [20]:
class JointFusion(pl.LightningModule):
    def __init__(self,
            demograph_dims,
            rec_dims,
            demograph_wd,
            rec_wd,
            num_labels,
            train_pos_wts,
            valid_pos_wts
            ):

        super().__init__()

        ## args
        self.train_loss_fn = nn.BCEWithLogitsLoss(pos_weight=train_pos_wts)
        self.valid_loss_fn = nn.BCEWithLogitsLoss(pos_weight=valid_pos_wts)


        # ehr
        self.ehr_model = models.EHR_LSTM.load_from_checkpoint(
            "./ehr_model.pth",
            demograph_dims=demograph_dims,
            rec_dims=rec_dims,
            demograph_dims_wd=demograph_dims_wd,
            rec_dims_wd=rec_dims_wd,
            labels=len(COHERENT_LABELS),
            train_pos_wts=pos_wts["train"], 
            valid_pos_wts=pos_wts["valid"],
            optim="adam",
            base_lr=0.001,
        )
        ehr_layers_trunc = list(self.ehr_model.linear.children())[:-2]
        self.ehr_model.linear = nn.Sequential(*ehr_layers_trunc)

        # mri
        self.mri_model = UnimodalModel.load_from_checkpoint("./mri_model.pth", input_dims=(4,4))
        mri_layers_trunc = list(self.mri_model.model.children())[:-2]
        self.mri_model.model = nn.Sequential(*mri_layers_trunc)
        
        # dna
        self.dna_model = UnimodalModel.load_from_checkpoint("./dna_model.pth", input_dims=(3,2))
        dna_layers_trunc = list(self.dna_model.model.children())[:-2]
        self.dna_model.model = nn.Sequential(*dna_layers_trunc)

        # ecg
        self.ecg_model = UnimodalModel.load_from_checkpoint("./ecg_model.pth", input_dims=(5,))
        ecg_layers_trunc = list(self.ecg_model.model.children())[:-2]
        self.ecg_model.model = nn.Sequential(*ecg_layers_trunc)

        # model
        self.repr_dims = {}
        self.repr_dims["ehr"] = self.ehr_model.repr_dim
        self.repr_dims["mri"] = 30
        self.repr_dims["dna"] = 30
        self.repr_dims["ecg"] = 30
        self.repr_dims["total"] = sum(self.repr_dims.values())

        self.model = nn.Sequential(nn.Linear(self.repr_dims["total"], 4000),
                      nn.ReLU(),
                      nn.Linear(4000, 5000),
                      nn.ReLU(),
                      nn.Linear(5000, 5),
                      nn.Sigmoid())
        
        ## metrics
        metrics = MetricCollection(
            [
                # Accuracy(),
                AUROC(num_classes=num_labels, pos_label=1, average="micro"),
                # Recall(),
                # Precision(),
                # AveragePrecision(num_classes)
            ]
        )
        self.train_metrics = metrics.clone(prefix="train/")
        self.valid_metrics = metrics.clone(prefix="valid/")
        self.test_metrics = metrics.clone(prefix="test/")


    
    def forward(self, batch):
        pts = batch["patients"]
        pts = [pt.to_gpu(non_block=True) for pt in pts]

        concated_reprs = torch.zeros((len(pts), self.repr_dims["total"]), device=self.device)

        # ehr
        concated_reprs[:, :self.ehr_model.repr_dim] = self.ehr_model(pts)
            
        # other modalities 
        if "mri" in batch.keys():
            # mri_input.to(self.device)
            concated_reprs[:, 3328:3358] = self.mri_model(batch["mri"])
        if "dna" in batch.keys():
            # dna_input.to(self.device)
            concated_reprs[:, 3358:3388] = self.dna_model(batch["dna"])
        if "ecg" in batch.keys():
            # ecg_input.to(self.device)
            concated_reprs[:, 3388:] = self.ecg_model(batch["ecg"])

        # send through fusion
        # concated_reprs.to(self.device, non_blocking=True)
        return self.model(concated_reprs)

    def training_step(self, batch, batch_idx):
        yb = batch["ys"]
        y_hat = self(batch)
        train_loss = self.train_loss_fn(y_hat, yb)

        self.log("train_loss", train_loss, on_step=True, on_epoch=True)
        self.train_metrics.update(y_hat, yb.int())
        self.log_dict(self.train_metrics.compute(), on_step=False, on_epoch=True)

        return train_loss

    def validation_step(self, batch, batch_idx):
        yb = batch["ys"]
        y_hat = self(batch)
        valid_loss = self.valid_loss_fn(y_hat, yb)

        self.log("valid_loss", valid_loss, on_step=True, on_epoch=True)
        self.valid_metrics.update(y_hat, yb.int())
        self.log_dict(self.valid_metrics.compute(), on_step=False, on_epoch=True)

        return valid_loss

    def test_step(self, batch, batch_idx):
        yb = batch["ys"]
        y_hat = self(batch)
        
        self.test_metrics.update(y_hat, yb.int())
        self.log_dict(self.test_metrics.compute(), on_step=False, on_epoch=True)

        return 
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters())
        return optimizer



In [21]:
fusion_model = JointFusion(
    demograph_dims,
    rec_dims,
    demograph_dims_wd,
    rec_dims_wd,
    len(COHERENT_LABELS),
    pos_wts["train"], 
    pos_wts["valid"])



In [22]:
fusion_trainer = pl.Trainer(precision=16, accelerator='gpu', devices=-1, max_epochs=5)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [23]:
fusion_trainer.fit(fusion_model, train_dataloaders=valid_dl)

  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type              | Params
----------------------------------------------------
0 | train_loss_fn | BCEWithLogitsLoss | 0     
1 | valid_loss_fn | BCEWithLogitsLoss | 0     
2 | ehr_model     | EHR_LSTM          | 7.6 M 
3 | mri_model     | UnimodalModel     | 850   
4 | dna_model     | UnimodalModel     | 850   
5 | ecg_model     | UnimodalModel     | 850   
6 | model         | Sequential        | 33.7 M
7 | train_metrics | MetricCollection  | 0     
8 | valid_metrics | MetricCollection  | 0     
9 | test_metrics  | MetricCollection  | 0     
----------------------------------------------------
41.4 M    Trainable params
0         Non-trainable params
41.4 M    Total params
82.704    Total estimated model params size (MB)


Epoch 0:   0%|          | 0/7 [00:00<?, ?it/s] 

  rank_zero_warn(


Epoch 4: 100%|██████████| 7/7 [00:01<00:00,  4.28it/s, loss=0.993, v_num=23]

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 7/7 [00:02<00:00,  2.72it/s, loss=0.993, v_num=23]


In [24]:
fusion_trainer.test(fusion_model, test_dl)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 8/8 [00:00<00:00, 16.96it/s]


[{'test/AUROC': 0.55997633934021}]