## Imports

In [38]:
from healnet.models import HealNet
from healnet.etl import MMDataset
from healnet.utils.train_utils import count_parameters
import torch
import einops
from torch.utils.data import Dataset, DataLoader

from typing import *

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Synthetic modalities

We instantiate a synthetic multimodal dataset for demo purposes. 

In [39]:
# Seed the global torch RNG so the synthetic data below (and any weights
# initialised afterwards) are identical after Restart Kernel -> Run All.
torch.manual_seed(0)

n = 100  # number of samples
b = 4  # batch size

# latent channels x dims
l_c = 16
l_d = 64

# 2D image
img_c = 3  # image channels
h = 224  # image height
w = 224  # image width
# 3D image
d = 12  # depth

# 1D tabular
tab_c = 1  # tabular channels
tab_d = 5000  # tabular features

# prediction task
n_classes = 4  # classification

# Synthetic modalities. Tabular data is (n, channels, features); images are
# channels-last.
tab_tensor = torch.rand(size=(n, tab_c, tab_d))  # n c d
img_2d_tensor = torch.rand(size=(n, h, w, img_c))  # n h w c
img_3d_tensor = torch.rand(size=(n, d, h, w, img_c))  # n d h w c

# derive a target: one integer class label in [0, n_classes) per sample, so
# the labels match the n_classes-way classification task declared above
target = torch.randint(low=0, high=n_classes, size=(n,))

In [40]:
# Wrap the three synthetic modalities and their target in one dataset.
data = MMDataset([tab_tensor, img_2d_tensor, img_3d_tensor], target)

# 70-15-15 train-val-test split; the names are unpacked in the same order as
# the fractions so that `val` and `test` match this description.
train, val, test = torch.utils.data.random_split(data, [0.7, 0.15, 0.15])

# Settings shared by every loader. `batch_size` must be passed explicitly:
# without it DataLoader falls back to batches of a single sample and `b` is
# never used.
# NOTE: the "fork" start method is not available on Windows; adjust
# `multiprocessing_context` / `num_workers` there.
loader_args = {
    "batch_size": b,
    "shuffle": True,
    "num_workers": 8,
    "pin_memory": True,
    "multiprocessing_context": "fork",
    "persistent_workers": True,
}
# Evaluation loaders keep a fixed sample order so repeated runs are comparable.
eval_loader_args = {**loader_args, "shuffle": False}

train_loader = DataLoader(train, **loader_args)
val_loader = DataLoader(val, **eval_loader_args)
test_loader = DataLoader(test, **eval_loader_args)

In [41]:
# Example use: take the first sample of the dataset.
# The label is bound to `target_sample` rather than `target` so the
# dataset-level `target` tensor is not overwritten; otherwise re-running the
# dataset cell would build the dataset from a single label.
[tab_sample, img_sample_2d, img_sample_3d], target_sample = data[0]

# Emulate the batch dimension by tiling the single sample `b` times.
# Axis labels follow the layout the tensors were created with
# (tabular: channels x features; images: channels-last).
tab_sample = einops.repeat(tab_sample, 'c d -> b c d', b=b)
img_sample_2d = einops.repeat(img_sample_2d, 'h w c -> b h w c', b=b)
img_sample_3d = einops.repeat(img_sample_3d, 'd h w c -> b d h w c', b=b)


In [43]:
# Constructor arguments gathered in one place; the list-valued entries hold
# one element per modality (tabular, 2D image, 3D image).
healnet_kwargs = dict(
    modalities=3,
    input_channels=[tab_c, img_c, img_c],
    input_axes=[1, 2, 3],  # spatial/temporal tensor dimensions
    num_classes=n_classes,
    l_d=l_d,
    l_c=l_c,
    fourier_encode_data=True,
)
model = HealNet(**healnet_kwargs)

# Show the module tree, then the number of parameters.
print(model)
print(f"HEALNet parameters: {count_parameters(model)}")

HealNet(
  (layers): ModuleList(
    (0-2): 3 x ModuleList(
      (0): PreNorm(
        (fn): Attention(
          (to_q): Linear(in_features=64, out_features=64, bias=False)
          (to_kv): Linear(in_features=6, out_features=128, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
          (to_out): Sequential(
            (0): Linear(in_features=64, out_features=64, bias=True)
            (1): LeakyReLU(negative_slope=0.01)
          )
        )
        (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm_context): LayerNorm((6,), eps=1e-05, elementwise_affine=True)
      )
      (1): PreNorm(
        (fn): FeedForward(
          (net): Sequential(
            (0): Linear(in_features=64, out_features=512, bias=True)
            (1): SELU()
            (2): Linear(in_features=256, out_features=64, bias=True)
            (3): Dropout(p=0.0, inplace=False)
          )
        )
        (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    

In [45]:
# Forward pass over all three modalities. A fresh list is built for each call
# so the two invocations never share a container.
full_inputs = (tab_sample, img_sample_2d, img_sample_3d)
logits = model(list(full_inputs))
latent = model(list(full_inputs), return_embeddings=True)

# Sanity-check the output shapes: one row of class scores per batch element,
# and a (latent channels x latent dims) embedding per batch element.
assert logits.shape == (b, n_classes)
assert latent.shape == (b, l_c, l_d)

# Report input and output shapes.
print(f"{tab_sample.shape=}")
print(f"{img_sample_2d.shape=}")
print(f"{img_sample_3d.shape=}")
print(f"{logits.shape=}")
print(f"{latent.shape=}")

tab_sample.shape=torch.Size([4, 1, 5000])
img_sample_2d.shape=torch.Size([4, 224, 224, 3])
img_sample_3d.shape=torch.Size([4, 12, 224, 224, 3])
logits.shape=torch.Size([4, 4])
latent.shape=torch.Size([4, 16, 64])


## Handling missing modalities

HEALNet natively handles missing modalities through its iterative architecture. If you encounter a missing data point in your pipeline, you can simply skip it by passing in `None` instead of the tensor. The model will still return the embedding or prediction based on the present modalities.

In [48]:
# 2D image withheld (None in slot 1): request a prediction from the
# remaining tabular and 3D-image inputs.
logits_missing = model(
    [tab_sample, None, img_sample_3d],
    verbose=True,
)

# Only the 2D image supplied (slots 0 and 2 are None): request the latent
# embedding instead of a prediction.
latent_missing = model(
    [None, img_sample_2d, None],
    return_embeddings=True,
    verbose=True,
)

Missing modalities indices: [1]
Skipping update in fusion layer 1 for missing modality 2
Skipping update in fusion layer 2 for missing modality 2
Skipping update in fusion layer 3 for missing modality 2
Missing modalities indices: [0, 2]
Skipping update in fusion layer 1 for missing modality 1
Skipping update in fusion layer 1 for missing modality 3
Skipping update in fusion layer 2 for missing modality 1
Skipping update in fusion layer 2 for missing modality 3
Skipping update in fusion layer 3 for missing modality 1
Skipping update in fusion layer 3 for missing modality 3
