In [None]:
import torch

from torch.nn import Module, ModuleList, Embedding, Linear
from torch.nn.init import trunc_normal_, zeros_
from einops import rearrange, reduce, repeat
from utils.components import TransformerBlock, ConditionalLayerNorm


In [None]:
class ModalEncoder(Module):
    """Pool a set of per-feature inputs into one vector per (batch, item) slot.

    Every feature id in ``idx`` looks up its own input projection matrix
    (``in_projection``) and its own additive bias (``feature_bias``). The
    projected features are normalised by ``kv_norm`` and then summarised by a
    single learned query through ``cross_attn`` (a project ``TransformerBlock``
    called with separate ``q`` / ``kv`` tensors).
    """

    def __init__(self, cfg):
        super().__init__()
        # Creation order is kept fixed: it determines parameter registration
        # order and the RNG draws consumed by each module's initialisation.
        dim = cfg.dim
        self.in_projection = Embedding(cfg.num_features, dim * cfg.dim_input)
        self.feature_bias = Embedding(cfg.num_features, dim)
        self.queries = Embedding(1, dim)
        self.kv_norm = ConditionalLayerNorm(dim, cfg.dim_ctx)
        self.cross_attn = TransformerBlock(dim, dim_ctx=cfg.dim_ctx)

    def forward(self, data, idx, ctx = None):
        """Encode ``data`` into a single pooled vector per (b, n) position.

        Args:
            data: 4-D tensor unpacked as (B, N, F, I). I is expected to equal
                ``cfg.dim_input`` so the looked-up weights split into (dim, I);
                F must line up with the feature ids in ``idx``.
            idx: integer feature ids indexing ``in_projection`` and ``feature_bias``.
                NOTE(review): appears to be 1-D (F,). Any leading dims of ``idx``
                fall under the ``...`` of the einsum below and are summed over,
                so only a size-1 leading dim is harmless. Confirm with callers.
            ctx: conditioning passed to both ``kv_norm`` and ``cross_attn``.
                NOTE(review): defaults to None; confirm ConditionalLayerNorm
                accepts ``ctx=None``.

        Returns:
            The ``cross_attn`` output with the singleton query axis (dim 2)
            squeezed out, presumably shaped (B, N, dim).
        """
        batch, items, _, in_dim = data.size()

        # Per-feature projection matrices: (..., F, dim * I) -> (..., F, dim, I).
        weights = rearrange(self.in_projection(idx), '... f (d i) -> ... f d i', i = in_dim)

        # Each feature is projected with its own matrix, then normalised and biased.
        projected = torch.einsum('b n f i, ... f d i -> b n f d', data, weights)
        keys_values = self.kv_norm(projected, ctx) + self.feature_bias(idx)

        # One learned query, broadcast to every (batch, item) slot: (B, N, 1, dim).
        query = repeat(self.queries.weight, 'q d -> b n q d', b = batch, n = items)

        pooled = self.cross_attn(q = query, kv = keys_values, ctx = ctx)
        return pooled.squeeze(2)

In [None]:
class Model(Module):
    def __init__(self, cfg):
        """Build the embedding tables, encoder and processor stack, then initialise.

        Args:
            cfg: config object read for ``num_features``, ``dim``, ``dim_output``,
                ``num_positions``, ``num_latents``, ``dim_ctx`` and ``num_blocks``.
                It is also handed to ``ModalEncoder``.
        """
        super().__init__()
        dim = cfg.dim

        # Learned tables. Creation order fixes parameter order and the RNG draws
        # consumed by initialisation, so keep it unchanged.
        self.out_projection = Embedding(cfg.num_features, dim * cfg.dim_output)
        self.positions = Embedding(cfg.num_positions, dim)
        self.queries = Embedding(cfg.num_latents, dim)

        # Networks: input encoder followed by a stack of processor blocks.
        self.encoder = ModalEncoder(cfg)
        self.processor = ModuleList(
            TransformerBlock(dim, dim_ctx=cfg.dim_ctx) for _ in range(cfg.num_blocks)
        )

        # Initialisation runs over every submodule in two passes: first a generic
        # truncated-normal pass, then zeroing of selected TransformerBlock weights.
        for initializer in (self.base_init, self.zero_init):
            self.apply(initializer)

    @staticmethod
    def base_init(m):
        if isinstance(m, Linear):
            trunc_normal_(m, std = 0.02)
        if isinstance(m, Embedding):
            trunc_normal_(m, std = 0.02)

    @staticmethod
    def zero_init(m):
        """Zero the ``to_out`` weight of each sub-layer of a ``TransformerBlock``.

        Meant to be run through ``Module.apply`` after ``base_init``; modules that
        are not ``TransformerBlock`` instances are ignored. Only ``to_out.weight``
        is zeroed, and any ``to_out`` bias is left as is.
        NOTE(review): presumably this makes each (residual?) branch start near a
        no-op; confirm against TransformerBlock's definition.
        """
        if not isinstance(m, TransformerBlock):
            return
        # Visited in a fixed order: attention, feed-forward, then their norm layers.
        for branch in ('att', 'ffn', 'att_norm', 'ffn_norm'):
            zeros_(getattr(m, branch).to_out.weight)
        

    def forward(self, data, coords, noise):
