## Imports

In [None]:
from healnet.models import HealNet
from healnet.etl import MMDataset
import torch
import einops
from torch.utils.data import Dataset, DataLoader

from typing import *

%load_ext autoreload
%autoreload 2

## Synthetic data example

To illustrate how HEALNet can be used in any pipeline, we create three synthetic modalities and a target:

* Tabular data: `(1, 5000)`
    * Table with 5000 features (`tab_d`). For 1D modalities, we add a singleton axis with `tab_c=1`, so that the features sit on the last (channel) axis.
* 2D image: `(224, 224, 3)`
    * Image dims corresponding to height, width, and colour channel.
* 3D image: `(12, 224, 224, 3)`
    * Image dims corresponding to depth, height, width, and colour channel.
* Target: `(n, )`
    * Assuming `n` observations.
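The synthetic tensors above are channels-last, whereas many PyTorch image pipelines (e.g., torchvision transforms) produce channels-first batches `(b, c, h, w)`. A minimal sketch for moving the channel axis to the end with `Tensor.permute`; `to_channels_last` is a hypothetical helper, and `einops.rearrange(x, "b c h w -> b h w c")` is an equivalent alternative for the 2D case.

```python
import torch


def to_channels_last(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: reorder a (b, c, *spatial) tensor to (b, *spatial, c)."""
    # keep the batch axis first, shift the spatial axes left, put channels last
    spatial = list(range(2, x.dim()))
    return x.permute(0, *spatial, 1).contiguous()


# channels-first 2D batch (b, c, h, w) and 3D batch (b, c, d, h, w)
img_2d_cf = torch.rand(2, 3, 224, 224)
img_3d_cf = torch.rand(2, 3, 12, 224, 224)

img_2d = to_channels_last(img_2d_cf)
img_3d = to_channels_last(img_3d_cf)
print(img_2d.shape)  # torch.Size([2, 224, 224, 3])
print(img_3d.shape)  # torch.Size([2, 12, 224, 224, 3])
```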


In [None]:
n = 100 # number of samples
b = 4 # batch size

# latent channels x dims
l_c = 16
l_d = 16

# 2D image
img_c = 3 # image channels
h = 224 # image height
w = 224 # image width
# 3D image
d = 12 # depth

# 1D tabular
tab_c = 1 # tabular channels
tab_d = 5000 # tabular features
# target
n_classes = 4 # number of classes

tab_tensor = torch.rand(size=(n, tab_c, tab_d)) 
img_2d_tensor = torch.rand(size=(n, h, w, img_c))
img_3d_tensor = torch.rand(size=(n, d, h, w, img_c))


# derive a target: one random class label in [0, n_classes) per sample
target = torch.randint(low=0, high=n_classes, size=(n,))

Given the original data as tensors, we instantiate `MMDataset`, a lightweight wrapper around the torch `Dataset`, and pass it into a `DataLoader`.

In [None]:
data = MMDataset([tab_tensor, img_2d_tensor, img_3d_tensor], target)
train, val, test = torch.utils.data.random_split(data, [0.7, 0.15, 0.15]) # create 70-15-15 train-val-test split

loader_args = {
    "shuffle": True, 
    "batch_size": b, # batch size defined above
}

train_loader = DataLoader(train, **loader_args)
val_loader = DataLoader(val, **loader_args)
test_loader = DataLoader(test, **loader_args)
# fetch one batch: a list of modality tensors and the target
[tab_sample, img_sample_2d, img_sample_3d], target = next(iter(train_loader))
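`MMDataset` comes from `healnet.etl`, and its internals are not shown in this notebook. Purely as an illustration of the call pattern used above (a list of per-modality tensors plus a target), a wrapper that yields the same `[modalities], target` batches could look like the hypothetical `SimpleMMDataset` below. This is an assumption about the behaviour, not the library's implementation: each item is `([m[idx] for m in modalities], target[idx])`, and the default `DataLoader` collation then stacks every modality separately.

```python
from typing import List, Sequence, Tuple

import torch
from torch.utils.data import DataLoader, Dataset


class SimpleMMDataset(Dataset):
    """Hypothetical multimodal dataset: one tensor per modality, indexed along dim 0."""

    def __init__(self, modalities: Sequence[torch.Tensor], target: torch.Tensor):
        n = len(target)
        # every modality must provide exactly one entry per target
        if any(len(m) != n for m in modalities):
            raise ValueError("all modalities must have the same length as the target")
        self.modalities = list(modalities)
        self.target = target

    def __len__(self) -> int:
        return len(self.target)

    def __getitem__(self, idx: int) -> Tuple[List[torch.Tensor], torch.Tensor]:
        # one sample from each modality, plus its label
        return [m[idx] for m in self.modalities], self.target[idx]


demo_ds = SimpleMMDataset([torch.rand(10, 1, 50), torch.rand(10, 8, 8, 3)], torch.randint(0, 4, (10,)))
demo_loader = DataLoader(demo_ds, batch_size=4, shuffle=True)
(demo_tab, demo_img), demo_y = next(iter(demo_loader))
print(demo_tab.shape, demo_img.shape, demo_y.shape)
# torch.Size([4, 1, 50]) torch.Size([4, 8, 8, 3]) torch.Size([4])
```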

### Instantiate HEALNet

The non-optional arguments to instantiate HEALNet are: 

* **n_modalities** (int): Maximum number of modalities for forward pass. Note that fewer modalities can be passed if modalities for individual samples are missing (see `.forward()`)
* **channel_dims** (List[int]): Number of channels or tokens for each modality. Length must match ``n_modalities``. Channels are the non-spatial dimensions, for which positional encoding is not required.
* **num_spatial_axes** (List[int]): Number of spatial axes for each modality. Each spatial axis is assigned positional encodings, so ``num_spatial_axes`` is 2 for 2D images and 3 for video/3D images.
* **out_dims** (int): Output dimension of the task-specific head. The forward pass returns logits of shape ``(b, out_dims)``.


As such, the input for each modality should be of shape ``(b, *spatial_dims, c)``, where ``c`` is the last (channel) axis, for which positional encoding does not matter (e.g., colour channels, set-based features, or tabular features). The `spatial_dims` are the dimensions where preserving structural signal is crucial for the model to learn (e.g., the depth x height x width of the 3D image).
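Under this convention, `channel_dims` and `num_spatial_axes` follow directly from the batched shapes, and both lists need one entry per modality. A minimal pure-Python sketch that derives and checks them; `derive_healnet_args` is a hypothetical helper, not part of the `healnet` package, and the "at least one spatial axis" check is an assumption that matches the tensors in this notebook.

```python
from typing import List, Sequence, Tuple


def derive_healnet_args(
    shapes: Sequence[Tuple[int, ...]], n_modalities: int
) -> Tuple[List[int], List[int]]:
    """Derive (channel_dims, num_spatial_axes) from shapes laid out as (b, *spatial_dims, c)."""
    if len(shapes) != n_modalities:
        raise ValueError(f"expected {n_modalities} modalities, got {len(shapes)}")
    channel_dims, num_spatial_axes = [], []
    for shape in shapes:
        # assumption: batch axis + at least one spatial axis + channel axis
        if len(shape) < 3:
            raise ValueError(f"shape {shape} needs batch, spatial and channel axes")
        _, *spatial_dims, c = shape
        channel_dims.append(c)
        num_spatial_axes.append(len(spatial_dims))
    return channel_dims, num_spatial_axes


shapes = [(16, 1, 5000), (16, 224, 224, 3), (16, 12, 224, 224, 3)]
print(derive_healnet_args(shapes, n_modalities=3))  # ([5000, 3, 3], [1, 2, 3])
```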

#### On tabular modalities

Tabular modalities are a common exception. Many tabular modalities have no inherent structure and are just an unordered bag of features. In this case, positional encodings only add noise, since the position of a feature carries no meaning. Therefore, we encode the 5000 tabular features as channels, giving the tabular tensor the shape ``(b, 1, 5000)``.
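To see what positional encodings would add in each layout: the first attention block in the printed model summary further down projects keys/values from `in_features=5005`, which matches the tabular input as 5000 channels plus 5 features for its single spatial axis. That count is consistent with a Perceiver-style Fourier encoding, which appends `2 * num_freq_bands + 1` features per spatial axis (a sine and cosine per frequency band, plus the raw coordinate). The sketch below assumes that scheme; `fourier_encode_positions`, `append_positions`, `num_freq_bands=2` (inferred from the 5005 above) and `max_freq=10.0` are illustrative choices, not HEALNet's documented implementation or defaults.

```python
import math
from typing import Sequence

import torch


def fourier_encode_positions(
    axis_sizes: Sequence[int], num_freq_bands: int, max_freq: float = 10.0
) -> torch.Tensor:
    """Perceiver-style Fourier features, shape (*axis_sizes, n_axes * (2 * num_freq_bands + 1))."""
    # normalised coordinates in [-1, 1] along each spatial axis
    axes = [torch.linspace(-1.0, 1.0, steps=s) for s in axis_sizes]
    pos = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1).unsqueeze(-1)  # (*axis_sizes, n_axes, 1)
    # frequency bands spaced linearly from 1 to max_freq / 2
    scales = torch.linspace(1.0, max_freq / 2, num_freq_bands)
    x = pos * scales * math.pi  # (*axis_sizes, n_axes, num_freq_bands)
    # a sine and cosine per band, plus the raw coordinate
    feats = torch.cat([x.sin(), x.cos(), pos], dim=-1)  # (*axis_sizes, n_axes, 2 * num_freq_bands + 1)
    return feats.flatten(start_dim=-2)


def append_positions(x: torch.Tensor, num_freq_bands: int = 2) -> torch.Tensor:
    """Concatenate Fourier features onto the channel axis of a (b, *spatial_dims, c) tensor."""
    b, *spatial_dims, _ = x.shape
    enc = fourier_encode_positions(spatial_dims, num_freq_bands)
    return torch.cat([x, enc.expand(b, *enc.shape)], dim=-1)


# features as channels: one spatial axis of length 1, 5000 channels
tab_as_channels = torch.rand(4, 1, 5000)
# the same features treated as ordered positions: 5000 positions with 1 channel each
tab_as_positions = tab_as_channels.transpose(1, 2)

print(append_positions(tab_as_channels).shape)   # torch.Size([4, 1, 5005])
print(append_positions(tab_as_positions).shape)  # torch.Size([4, 5000, 6])
```

With the features as channels, the encoding contributes a constant 5 features; with the features as positions, each of the 5000 unordered features would receive its own position code, which is the noise described above.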


In [None]:
spatial_axes = []
channels = []
for tensor in [tab_sample, img_sample_2d, img_sample_3d]:
    b, *spatial_dims, c = tensor.shape
    spatial_axes.append(len(spatial_dims))
    channels.append(c)
    
print(f"{spatial_axes=}, {channels=}")

model = HealNet(
            n_modalities=3, 
            channel_dims=channels, # [5000 (tabular), 3 (2D image), 3 (3D image)]
            num_spatial_axes=spatial_axes, # spatial/temporal tensor dimensions: [1, 2, 3]
            out_dims=n_classes,  
            l_d=l_d, 
            l_c=l_c, 
            fourier_encode_data=True, 
        )

print(model)

spatial_axes=[1, 2, 3], channels=[5000, 3, 3]
HealNet(
  (layers): ModuleList(
    (0-2): 3 x ModuleList(
      (0): PreNorm(
        (fn): Attention(
          (to_q): Linear(in_features=64, out_features=512, bias=False)
          (to_kv): Linear(in_features=5005, out_features=1024, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
          (to_out): Sequential(
            (0): Linear(in_features=512, out_features=64, bias=True)
            (1): LeakyReLU(negative_slope=0.01)
          )
        )
        (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm_context): LayerNorm((5005,), eps=1e-05, elementwise_affine=True)
      )
      (1): PreNorm(
        (fn): FeedForward(
          (net): Sequential(
            (0): Linear(in_features=64, out_features=512, bias=True)
            (1): SELU()
            (2): Linear(in_features=256, out_features=64, bias=True)
            (3): Dropout(p=0.0, inplace=False)
          )
        )
        (norm): La


In [None]:
# forward pass
logits = model([tab_sample, img_sample_2d, img_sample_3d])
latent = model([tab_sample, img_sample_2d, img_sample_3d], return_embeddings=True)

assert logits.shape == (b, n_classes)
assert latent.shape == (b, l_c, l_d)

print(f"{tab_sample.shape=}")
print(f"{img_sample_2d.shape=}")
print(f"{img_sample_3d.shape=}")
print(f"{logits.shape=}")
print(f"{latent.shape=}")
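The forward pass returns raw logits of shape `(b, n_classes)`, so for classification they can be passed directly to `torch.nn.CrossEntropyLoss` together with integer class labels. A minimal sketch of one training epoch for any model that, like `HealNet` above, takes a list of modality tensors; `train_one_epoch` is a hypothetical helper, and the tiny `ToyFusion` model is only a stand-in for `HealNet` so the block runs on its own.

```python
from typing import List, Optional

import torch
from torch import nn
from torch.utils.data import DataLoader


def train_one_epoch(model: nn.Module, loader: DataLoader, optimizer: torch.optim.Optimizer) -> float:
    """Run one pass over `loader` and return the mean cross-entropy loss per batch."""
    loss_fn = nn.CrossEntropyLoss()  # expects raw logits and int64 class indices
    model.train()
    total, n_batches = 0.0, 0
    for modalities, target in loader:  # batches unpack as ([modality tensors], target)
        optimizer.zero_grad()
        logits = model(modalities)  # (b, n_classes)
        loss = loss_fn(logits, target)
        loss.backward()
        optimizer.step()
        total += loss.item()
        n_batches += 1
    return total / max(n_batches, 1)


class ToyFusion(nn.Module):
    """Stand-in for HealNet: averages the outputs of one linear head per present modality."""

    def __init__(self, in_dims: List[int], n_classes: int):
        super().__init__()
        self.heads = nn.ModuleList([nn.Linear(d, n_classes) for d in in_dims])

    def forward(self, xs: List[Optional[torch.Tensor]]) -> torch.Tensor:
        # modalities passed as None are skipped
        outs = [head(x.flatten(1)) for head, x in zip(self.heads, xs) if x is not None]
        return torch.stack(outs).mean(dim=0)


torch.manual_seed(0)
toy_tab, toy_img, toy_y = torch.rand(32, 1, 20), torch.rand(32, 4, 4, 3), torch.randint(0, 4, (32,))
toy_samples = [([toy_tab[i], toy_img[i]], toy_y[i]) for i in range(32)]  # ([modalities], target) per sample
toy_loader = DataLoader(toy_samples, batch_size=8, shuffle=True)

toy_model = ToyFusion(in_dims=[20, 48], n_classes=4)
toy_opt = torch.optim.Adam(toy_model.parameters(), lr=1e-2)
losses = [train_one_epoch(toy_model, toy_loader, toy_opt) for _ in range(20)]
print(f"first epoch: {losses[0]:.3f}, last epoch: {losses[-1]:.3f}")
```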

## Handling missing modalities

HEALNet natively handles missing modalities through its iterative architecture. If a modality is missing for a batch in your pipeline, you can simply skip it by passing `None` in place of its tensor. The model will still return the embedding or prediction based on the modalities that are present.

Note that `verbose=True` logs a message on every forward pass, so it is best turned off in the training loop.

In [None]:
logits_missing = model([tab_sample, None, img_sample_3d], verbose=True)
latent_missing = model([None, img_sample_2d, None], return_embeddings=True, verbose=True)

Missing modalities indices: [1]
Skipping update in fusion layer 1 for missing modality 2
Skipping update in fusion layer 2 for missing modality 2
Skipping update in fusion layer 3 for missing modality 2
Missing modalities indices: [0, 2]
Skipping update in fusion layer 1 for missing modality 1
Skipping update in fusion layer 1 for missing modality 3
Skipping update in fusion layer 2 for missing modality 1
Skipping update in fusion layer 2 for missing modality 3
Skipping update in fusion layer 3 for missing modality 1
Skipping update in fusion layer 3 for missing modality 3
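Because each modality is passed either as a whole batch tensor or as `None`, the missing-modality pattern of a forward call as shown above applies to every sample in that batch. If individual samples have different modalities missing, one option is to group them by their availability pattern and run one forward pass per group. A minimal pure-Python sketch; `group_by_missing_pattern` is a hypothetical helper, and the per-sample availability flags are assumed to come from your own metadata.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def group_by_missing_pattern(available: Sequence[Sequence[bool]]) -> Dict[Tuple[bool, ...], List[int]]:
    """Map each availability pattern (one bool per modality) to the sample indices sharing it."""
    groups: Dict[Tuple[bool, ...], List[int]] = defaultdict(list)
    for idx, flags in enumerate(available):
        groups[tuple(flags)].append(idx)
    return dict(groups)  # insertion order = order in which patterns first appear


# per-sample availability for (tabular, 2D image, 3D image)
available = [
    (True, True, True),
    (True, False, True),
    (True, True, True),
    (False, True, False),
    (True, False, True),
]
for pattern, idxs in group_by_missing_pattern(available).items():
    print(pattern, idxs)
# (True, True, True) [0, 2]
# (True, False, True) [1, 4]
# (False, True, False) [3]
```

Each group can then be indexed out of the batch, e.g. `tab_sample[idxs]` where the pattern marks the modality as present and `None` otherwise, and passed to the model as in the cell above.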
