## Imports

In [None]:
from healnet.models import HealNet
from healnet.etl import MMDataset
import torch
import einops
from torch.utils.data import Dataset, DataLoader
from typing import *

%load_ext autoreload
%autoreload 2

## Synthetic modalities

We instantiate a synthetic multimodal dataset (one tabular and one image modality, plus a target) for demo purposes.

In [None]:
# Dataset size and batching
n = 1000  # samples in the synthetic dataset
b = 4  # batch size

# Tabular modality
tab_c = 1  # channels
tab_d = 5000  # features per sample

# Image modality
img_c = 3  # channels
h = w = 512  # height and width (square images)

# Task
n_classes = 4  # number of target classes (classification)

# One random tensor per modality, with the sample axis first.
tab_tensor = torch.rand(size=(n, tab_c, tab_d))  # (n, c, d): 5k tabular features
# NOTE: with the default float32 dtype, (n, img_c, h, w) = (1000, 3, 512, 512)
# takes roughly 3 GB of RAM; lower `n`, `h` or `w` if the kernel runs out of memory.
img_tensor = torch.rand(size=(n, img_c, h, w))  # (n, c, h, w)


# Derive a target: one integer class label in [0, n_classes) per sample.
# Uniform floats from torch.rand would not be valid class labels for a
# classification task with `n_classes` classes.
target = torch.randint(low=0, high=n_classes, size=(n,))

In [None]:
# Bundle both modality tensors and the target into one multimodal dataset.
data = MMDataset([tab_tensor, img_tensor], target)
# 70-15-15 split; subsets come back in the order of the fractions,
# i.e. train, test, val.
train, test, val = torch.utils.data.random_split(data, [0.7, 0.15, 0.15])

loader_args = {
    "batch_size": b,  # was omitted before, so every loader fell back to batch_size=1
    "shuffle": True,
    "num_workers": 8,
    "pin_memory": True,
    "multiprocessing_context": "fork",  # "fork" start method is POSIX-only
    "persistent_workers": True,
}
# Evaluation loaders keep a fixed sample order so results are reproducible.
eval_loader_args = {**loader_args, "shuffle": False}

train_loader = DataLoader(train, **loader_args)
val_loader = DataLoader(val, **eval_loader_args)
test_loader = DataLoader(test, **eval_loader_args)

In [None]:
# Example use: take the first item of the dataset, which unpacks into a pair
# of modality tensors and a target.
# NOTE: this rebinds `target` to that single item's target.
sample_modalities, target = data[0]
tab_sample, img_sample = sample_modalities

# Emulate a batch dimension by prepending a unit axis; the image's two
# spatial axes are flattened into a single axis at the same time.
tab_sample = einops.rearrange(tab_sample, 'c d -> 1 c d')
img_sample = einops.rearrange(img_sample, 'c h w -> 1 c (h w)')

In [None]:
# Inspect the batched image sample: three axes (batch of 1, channels, flattened
# spatial positions) after the batch axis was added and h, w were merged.
img_sample.shape

In [None]:
# Build a two-modality HealNet. The list arguments are given in
# (tabular, image) order, matching the order of the tensors in the dataset.
model = HealNet(
    modalities=2,
    input_channels=[tab_c, img_c],
    # channel axes (0-indexed), per the original note — TODO confirm against HealNet
    input_axes=[1, 1],
    num_classes=n_classes,
)

In [None]:
# Forward pass on the single-sample batch, passing the modalities as a list in
# (tabular, image) order; the notebook displays the returned value.
# Presumably the output holds one score per class (n_classes) — TODO confirm.
model([tab_sample, img_sample])