## Imports

In [1]:
from healnet.models import HealNet
from healnet.etl import MMDataset
import torch
import einops
from torch.utils.data import Dataset, DataLoader
from typing import *

%load_ext autoreload
%autoreload 2

## Synthetic modalities

We instantiate a synthetic multimodal dataset (a 1D tabular modality and a 2D image modality) for demo purposes.

In [2]:
SEED = 42  # fixed seed so the synthetic data (and thus the demo) is reproducible
torch.manual_seed(SEED)

n = 1000 # number of samples
b = 4 # batch size

# latent channels x dims
l_c = 256
l_d = 512

# 2D image
img_c = 3 # image channels
h = 512 # image height
w = 512 # image width


# 1D tabular
tab_c = 1 # tabular channels
tab_d = 5000 # tabular features

# classification task
n_classes = 4 # number of target classes

tab_tensor = torch.rand(size=(n, tab_c, tab_d)) # n c d; assume 5k tabular features
# n c h w; note: at these sizes this is n * 3 * 512 * 512 float32 values (~3 GB of RAM)
img_tensor = torch.rand(size=(n, img_c, h, w))


# derive a classification target: one integer class label in [0, n_classes) per sample
target = torch.randint(low=0, high=n_classes, size=(n,))

In [3]:
data = MMDataset([tab_tensor, img_tensor], target)

# 70-15-15 train-val-test split; a seeded generator makes the split reproducible
train, val, test = torch.utils.data.random_split(
    data, [0.7, 0.15, 0.15], generator=torch.Generator().manual_seed(42)
)

loader_args = {
    "batch_size": b,  # without this DataLoader silently falls back to batch_size=1
    "num_workers": 8,
    "pin_memory": True,
    "multiprocessing_context": "fork",  # "fork" is POSIX-only; not available on Windows
    "persistent_workers": True,
}

# shuffle only the training data; evaluation sets keep a fixed order
train_loader = DataLoader(train, shuffle=True, **loader_args)
val_loader = DataLoader(val, shuffle=False, **loader_args)
test_loader = DataLoader(test, shuffle=False, **loader_args)

In [8]:
# example use: take the first sample and fake a batch by repeating it b times.
# The label goes into `target_sample` so the dataset-wide `target` tensor is not overwritten.
[tab_sample, img_sample], target_sample = data[0]

# emulate batch dimension
tab_sample = einops.repeat(tab_sample, 'c d -> b c d', b=b)
# also flatten the spatial dims (h, w) into a single axis -> (b, c, h*w)
img_sample = einops.repeat(img_sample, 'c h w -> b c (h w)', b=b)

In [9]:
img_sample.shape  # (b, img_c, h*w): batch, image channels, flattened spatial positions

torch.Size([4, 3, 262144])

In [12]:
# one entry per modality, in the same order as the inputs later passed to the model
modality_channels = [tab_c, img_c]

model = HealNet(
    modalities=len(modality_channels),  # derived, so it cannot drift from the per-modality lists
    input_channels=modality_channels,
    input_axes=[1, 1],  # one entry per modality; NOTE(review): confirm exact semantics in HealNet's docstring
    num_classes=n_classes,
    l_c=l_c,
    l_d=l_d,
)

In [13]:
# forward pass: by default the model returns class logits;
# with return_embeddings=True it returns the latent array instead
logits = model([tab_sample, img_sample])
latent = model([tab_sample, img_sample], return_embeddings=True)

print(logits.shape, latent.shape)

# sanity-check shapes: (batch, classes) and (batch, latent channels, latent dims)
assert tuple(logits.shape) == (b, n_classes)
assert tuple(latent.shape) == (b, l_c, l_d)

torch.Size([4, 4]) torch.Size([4, 256, 512])


## Handling missing modalities

HEALNet natively handles missing modalities through its iterative architecture. If a modality is missing for a data point in your pipeline, you can simply skip it by passing in `None` instead of the tensor. The model will still return the embedding or prediction based on the modalities that are present.

In [20]:
# `None` marks an absent modality; verbose=True logs which fusion-layer updates are skipped
logits_missing = model([tab_sample, None], verbose=True)  # image modality missing
latent_missing = model([None, img_sample], return_embeddings=True, verbose=True)  # tabular modality missing

# expect the same output shapes as with all modalities present
assert logits_missing.shape == (b, n_classes)
assert latent_missing.shape == (b, l_c, l_d)

Missing modalities indices: [1]
Skipping update in fusion layer 1 for missing modality 2
Skipping update in fusion layer 2 for missing modality 2
Skipping update in fusion layer 3 for missing modality 2
Missing modalities indices: [0]
Skipping update in fusion layer 1 for missing modality 1
Skipping update in fusion layer 2 for missing modality 1
Skipping update in fusion layer 3 for missing modality 1
