## Imports

In [21]:
from healnet.models import HealNet
from torch import nn
import torch
import einops
from torch.utils.data import Dataset, DataLoader
import numpy as np
import scipy
from typing import *

## Synthetic modalities

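We instantiate a synthetic multimodal dataset for demo purposes. It has a tabular modality and an image modality, both filled with uniform random values, plus a random per-sample target.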

In [47]:
# Configuration for the synthetic demo data.
SEED = 42  # seed torch's global RNG so the synthetic tensors are reproducible across runs
torch.manual_seed(SEED)

n = 1000  # number of samples
b = 4  # batch size
img_c = 3  # image channels
tab_c = 1  # tabular channels
h = 512  # image height
w = 512  # image width
n_tab_features = 5000  # number of synthetic tabular features

# NOTE: img_tensor holds n * img_c * h * w float32 values (~3.1 GB at the defaults above);
# lower n, h or w if memory is tight.
tab_tensor = torch.rand(size=(n, n_tab_features))  # (n, n_tab_features)
img_tensor = torch.rand(size=(n, img_c, h, w))  # (n, c, h, w)

# random target (uniform in [0, 1)), one scalar per sample; it is not derived from the features
target = torch.rand(size=(n,))


In [51]:
class MMDataset(Dataset):
    """Map-style dataset pairing several aligned modality tensors with a target.

    Sample ``idx`` is ``([t[idx] for t in tensors], target[idx])``: every
    modality tensor and the target are indexed along their first dimension.

    Args:
        tensors: one tensor per modality; each must have the same length
            (first-dimension size) as ``target``.
        target: per-sample targets, indexed along the first dimension.

    Raises:
        ValueError: if any modality tensor's length differs from ``target``'s.
    """

    def __init__(self, tensors: List[torch.Tensor], target: torch.Tensor):
        self.tensors = list(tensors)
        self.target = target
        n_samples = len(target)
        # Fail fast on misaligned modalities instead of at some later __getitem__.
        for i, t in enumerate(self.tensors):
            if len(t) != n_samples:
                raise ValueError(
                    f"modality {i} has {len(t)} samples but target has {n_samples}"
                )

    def __getitem__(self, idx) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Return the list of per-modality samples at ``idx`` and the matching target."""
        return [t[idx] for t in self.tensors], self.target[idx]

    def __len__(self) -> int:
        """Number of samples, taken from this instance's own target."""
        # Use self.target rather than any notebook-global `target`, which other
        # cells may rebind.
        return len(self.target)

data = MMDataset([tab_tensor, img_tensor], target)

# 70/15/15 train/val/test split; a seeded generator makes the split reproducible.
split_generator = torch.Generator().manual_seed(42)
train, val, test = torch.utils.data.random_split(
    data, [0.7, 0.15, 0.15], generator=split_generator
)

loader_args = {
    "batch_size": b,  # b is the configured batch size (DataLoader defaults to 1 otherwise)
    "shuffle": True,
    "num_workers": 8,
    "pin_memory": True,
    "multiprocessing_context": "fork",  # NOTE: the "fork" start method is POSIX-only (not on Windows)
    "persistent_workers": True,
}

# Shuffle only the training data; keep evaluation order fixed.
eval_loader_args = {**loader_args, "shuffle": False}

train_loader = DataLoader(train, **loader_args)
val_loader = DataLoader(val, **eval_loader_args)
test_loader = DataLoader(test, **eval_loader_args)

In [38]:
# example use: fetch one sample. Bind its target to a distinct name so the
# dataset-wide `target` tensor is not overwritten by a single scalar.
[tab_sample, img_sample], sample_target = data[0]
tab_sample.shape, img_sample.shape, sample_target