## Imports

In [None]:
# from mm_lego.models.lego import LegoBlock
from mm_lego.models import LegoBlock, MILAttentionNet, SNN, LegoMerge, LegoFuse
import torch

## Wrapping Encoders in LegoBlocks

One of MM-Lego's key components is the LegoBlock - you can take your existing unimodal architecture and fit a LegoBlock around it. Let's assume a case where you have a tabular and an imaging modality.

In [None]:
# Build random stand-in inputs for a tabular and an imaging modality, create one
# encoder per modality, wrap each encoder in a LegoBlock, and run a single
# forward pass through the imaging block.
b = 10 # batch size
# Note - the input tensors below are built as (batch, channels, dims)
t_c = 1  # number of channels (1 for tabular) ; note that channels correspond to modality input/features
t_d = 2189  # dimensions of each channel
i_c = 100  # >1 if using MIL setup (presumably the number of patches per sample - confirm)
i_d = 1024  # dimensions per patch
# NOTE(review): `latent` is not used anywhere in the cells visible here; confirm
# whether a later cell needs it before removing it.
latent = torch.randn(b, 256, 32)

# Shape (b, t_c, t_d), i.e. (batch, channels, dims).
# NOTE(review): this comment previously read "expects (b dims channels)", which
# does not match the order used here - confirm the layout LegoBlock expects.
tab_data = torch.randn(b, t_c, t_d)
img_data = torch.randn(b, i_c, i_d)  # shape (b, i_c, i_d)

# Unimodal encoders, both created with final_head=False.
tab_enc = SNN(t_d, final_head=False)
img_enc = MILAttentionNet(torch.Size((i_c, i_d)), final_head=False, size_arg="tcga")

# Lego Wrapper: in_shape is the per-sample (channels, dims) pair of each modality
tab_block = LegoBlock(in_shape=(t_c, t_d), encoder=tab_enc)
img_block = LegoBlock(in_shape=(i_c, i_d), encoder=img_enc)

# Forward pass of block; the input is passed as a one-element list of tensors
print(img_block([img_data], return_embeddings=True))

## Merging & Fusing Blocks

After fitting each unimodal block, we can merge or fuse them into a single multimodal model.

In [None]:
# Combine the two unimodal blocks into multimodal models in two ways (LegoMerge
# and LegoFuse) and run one forward pass through each. Inputs are passed as a
# list in the same order as `blocks`: tabular first, then imaging.
merged_model = LegoMerge(blocks=[tab_block, img_block], head_method="slerp", final_head=False)

# Forward pass of merged model. Printed explicitly: a bare expression that is
# not the last statement of a cell is evaluated but never displayed.
print(merged_model([tab_data, img_data], return_embeddings=True))

fusion_model = LegoFuse(blocks=[tab_block, img_block], fuse_method="stack", head_method="slerp", final_head=False)

# Forward pass of fusion model. Printed as well so both results are shown the
# same way, including when this code is run as a plain script.
print(fusion_model([tab_data, img_data], return_embeddings=True))