In [1]:
import torch

from torch.nn import Module, Embedding, Linear, ModuleList
from einops import rearrange, reduce, repeat
from utils.components import TransformerBlock, ConditionalLayerNorm, GatedFFN, Attention, Interface

from dataclasses import dataclass
from typing import Optional

from utils.vit import ViT

In [2]:
class ModalEncoder(Module):
    """Pool per-feature observations into one vector per token.

    Every feature id looks up its own input-projection matrix and bias from
    embedding tables. The projected features are normalised and then pooled
    by cross-attending from a single learned query.
    """

    def __init__(self, cfg):
        """Build lookup tables and the pooling block.

        Reads ``cfg.num_features``, ``cfg.dim``, ``cfg.dim_in`` and
        ``cfg.dim_ctx``.
        """
        super().__init__()
        # one flattened (dim x dim_in) projection matrix per feature id
        self.in_projection = Embedding(cfg.num_features, cfg.dim * cfg.dim_in)
        # one additive bias vector per feature id
        self.feature_bias = Embedding(cfg.num_features, cfg.dim)
        # single learned query used to pool over the feature axis
        self.queries = Embedding(1, cfg.dim)
        self.kv_norm = ConditionalLayerNorm(cfg.dim, cfg.dim_ctx)
        self.cross_attn = TransformerBlock(cfg.dim, dim_ctx=cfg.dim_ctx)

    def forward(self, data, idx, ctx = None):
        """Encode ``data`` using the projections selected by ``idx``.

        Args:
            data: 4-D tensor unpacked as (batch, tokens, features, channels).
            idx: integer feature ids, either ``(features,)`` (shared by the
                whole batch) or ``(batch, features)`` (one set per sample).
            ctx: optional conditioning, forwarded to ``kv_norm`` and
                ``cross_attn``.

        Returns:
            The output of ``cross_attn`` with dim 2 squeezed out; this is
            (batch, tokens, dim) assuming the block preserves the query
            shape -- TODO confirm against ``TransformerBlock``.
        """
        # ensure correct shapes
        B, N, _, I = data.size()
        # get dynamic weights: (..., f, d*i) -> (..., f, d, i)
        w = self.in_projection(idx)
        w = w.reshape(*w.shape[:-1], -1, I)
        # bias gets a singleton token axis: (..., f, d) -> (..., 1, f, d)
        b = self.feature_bias(idx).unsqueeze(-3)
        # Give the weights an explicit batch axis. With the batch axis hidden
        # behind '...' on the weights only, torch.einsum contracts it, i.e.
        # the per-sample matrices would be summed over the batch.
        w = w.expand(B, -1, -1, -1)
        # linear projection, sample b uses its own weights w[b]
        kv = torch.einsum('b n f i, b f d i -> b n f d', data, w)
        # normalize and add feature-bias
        kv = self.kv_norm(kv, ctx) + b
        # expand query vectors to (b, n, 1, d)
        q = self.queries.weight.expand(B, N, -1, -1)
        # cross attend
        q = self.cross_attn(q = q, kv = kv, ctx = ctx).squeeze(2)
        return q
    
class ModalDecoder(Module):
    """Expand one latent vector per token into per-feature outputs.

    The latent is normalised and passed through a gated feed-forward block,
    then multiplied by an output-projection matrix looked up per feature id.
    """

    def __init__(self, cfg):
        """Build the norm, feed-forward block and projection table.

        Reads ``cfg.dim``, ``cfg.dim_ctx``, ``cfg.num_features`` and
        ``cfg.dim_out``.
        """
        super().__init__()
        self.norm = ConditionalLayerNorm(cfg.dim, cfg.dim_ctx)
        self.ffn = GatedFFN(cfg.dim)
        # one flattened (dim x dim_out) projection matrix per feature id
        self.out_projection = Embedding(cfg.num_features, cfg.dim * cfg.dim_out)

    def forward(self, latent, idx, ctx = None):
        """Decode ``latent`` for the features selected by ``idx``.

        Args:
            latent: 3-D tensor unpacked as (batch, tokens, dim).
            idx: integer feature ids, either ``(features,)`` (shared by the
                whole batch) or ``(batch, features)`` (one set per sample).
            ctx: optional conditioning, forwarded to ``norm``.

        Returns:
            Tensor of shape (batch, tokens, features, dim_out).
        """
        B, _, D = latent.size()
        latent = self.ffn(self.norm(latent, ctx))
        # get dynamic weights: (..., f, d*o) -> (..., f, d, o)
        w = self.out_projection(idx)
        w = w.reshape(*w.shape[:-1], D, -1)
        # Give the weights an explicit batch axis. With the batch axis hidden
        # behind '...' on the weights only, torch.einsum contracts it, i.e.
        # the per-sample matrices would be summed over the batch.
        w = w.expand(B, -1, -1, -1)
        out = torch.einsum("b n d, b f d o -> b n f o", latent, w)
        return out

In [None]:
def predict_ambient_src(self, src, tgt, latents, ctx):
    """Run the network with source and target tokens sharing the ``x`` stream.

    ``src`` and ``tgt`` are concatenated along dim 1 and passed as ``x``;
    ``latents`` seed the ``z`` stream. Both streams are carried from one
    block to the next.

    Args:
        self: object exposing an iterable ``network`` of blocks, each
            called as ``block(x=..., z=..., ctx=...)`` and returning ``(x, z)``.
        src, tgt: tensors concatenated and later split along dim 1.
        latents: initial value of the ``z`` stream.
        ctx: conditioning forwarded unchanged to every block.

    Returns:
        ``(tgt, z)``: the target slice of the final ``x`` and the final ``z``.
        With an empty network ``z`` is ``latents`` itself.
    """
    # original author's cost note (not verified here): O(2*z*(src+tgt) + z**2)
    x = torch.cat([src, tgt], dim = 1)
    # carry z across blocks; re-feeding `latents` would discard every
    # intermediate update and leave z unbound for an empty network
    z = latents
    for block in self.network:
        x, z = block(x = x, z = z, ctx = ctx)
    _, tgt = x.split([src.size(1), tgt.size(1)], dim = 1)
    return tgt, z

def predict_latent_src(self, src, tgt, latents, ctx):
    """Run the network with source tokens riding along in the ``z`` stream.

    ``src`` and ``latents`` are concatenated along dim 1 and passed as ``z``;
    ``tgt`` is the ``x`` stream. After the last block the source part of
    ``z`` is dropped again.

    Returns:
        ``(tgt, z)`` where ``z`` holds only the latent slice.
    """
    # original author's cost note (not verified here):
    # O(2*tgt*(src + z) + (src * z)**2)
    n_src, n_latents = src.size(1), latents.size(1)
    z = torch.cat((src, latents), dim=1)
    for layer in self.network:
        tgt, z = layer(x=tgt, z=z, ctx=ctx)
    # keep the trailing latent slice, discard the leading source slice
    z = z.split([n_src, n_latents], dim=1)[1]
    return tgt, z

def predict_self_vit(self, src, tgt, latents, ctx):
    """Encode ``[src, latents]``, then decode jointly with the targets.

    All blocks but the last act as an encoder applied to the concatenated
    source/latent sequence (each block is called with the sequence as both
    positional inputs). The last block is called on the encoded sequence
    concatenated with ``tgt``, and the target slice of its result is returned.

    Returns:
        ``(tgt, None)``.
    """
    memory = torch.cat((src, latents), dim=1)
    layers, head = self.network[:-1], self.network[-1]
    for layer in layers:
        memory, _ = layer(memory, memory, ctx)
    joint = torch.cat((memory, tgt), dim=1)
    # NOTE(review): encoder blocks are unpacked as 2-tuples but this result
    # is used directly as a tensor -- confirm what the last block returns.
    joint = head(joint, joint, ctx)
    tgt = joint.split([memory.size(1), tgt.size(1)], dim=1)[1]
    return tgt, None

def predict_cross_vit(self, src, tgt, latents, ctx):
    """Encode ``[src, latents]``, then let the targets read the result.

    All blocks but the last act as an encoder applied to the concatenated
    source/latent sequence. The last block is called with ``tgt`` as its
    first input and the encoded sequence as its second.

    Returns:
        ``(tgt, None)`` where ``tgt`` is whatever the last block returns.
    """
    memory = torch.cat((src, latents), dim=1)
    layers, head = self.network[:-1], self.network[-1]
    for layer in layers:
        memory, _ = layer(memory, memory, ctx)
    # NOTE(review): encoder blocks are unpacked as 2-tuples but this result
    # is returned as-is -- confirm what the last block returns.
    return head(tgt, memory, ctx), None

def predict_perceiver(self, src, tgt, latents, ctx):
    """Refine the latents against ``src`` block by block, then decode ``tgt``.

    All blocks but the last are called as ``block(src, z, ctx)``; only the
    second element of each result (the updated ``z``) is kept and handed to
    the next block. The last block is called as ``decoder(tgt, z, ctx)``.

    Args:
        self: object exposing an indexable ``network`` of callables.
        src: input passed unchanged to every encoder block.
        tgt: queries passed to the last block.
        latents: initial value of ``z``.
        ctx: conditioning forwarded unchanged to every block.

    Returns:
        ``(tgt, None)`` where ``tgt`` is whatever the last block returns.
    """
    encoder, decoder = self.network[:-1], self.network[-1]
    # carry z across blocks; re-feeding `latents` would make every encoder
    # block except the last one dead work, and leave z unbound when there
    # are no encoder blocks
    z = latents
    for block in encoder:
        _, z = block(src, z, ctx)
    # NOTE(review): encoder blocks are unpacked as 2-tuples but this result
    # is returned as-is -- confirm what the last block returns.
    tgt = decoder(tgt, z, ctx)
    return tgt, None

In [43]:
class MaskedInterface(Module):
    """Predict target-position tokens from source tokens via interface blocks.

    Source tokens (projected input plus position embedding) are concatenated
    with learned latents and fed to each block as ``z``. Target tokens (one
    learned query plus position embedding) are fed as ``x``. The final ``x``
    is projected to ``cfg.dim_out``.
    """

    def __init__(self, cfg: dataclass):
        """Build projections, embeddings and the block stack.

        Reads ``cfg.dim_in``, ``cfg.dim``, ``cfg.dim_out``, ``cfg.dim_ctx``,
        ``cfg.num_latents``, ``cfg.num_tokens``, ``cfg.num_compute_blocks``,
        ``cfg.dim_heads`` and ``cfg.num_layers``.
        """
        super().__init__()

        # In/Out
        self.proj_in = Linear(cfg.dim_in, cfg.dim)
        self.proj_noise = Linear(cfg.dim_ctx, cfg.dim_ctx)
        self.dim_ctx = cfg.dim_ctx
        self.proj_out = Linear(cfg.dim, cfg.dim_out)

        # Embeddings
        self.queries = Embedding(1, cfg.dim)
        self.latents = Embedding(cfg.num_latents, cfg.dim)
        self.positions = Embedding(cfg.num_tokens, cfg.dim)

        # Interfaces
        self.network = ModuleList([
            Interface(cfg.dim, cfg.num_compute_blocks, dim_ctx= cfg.dim_ctx, dim_heads= cfg.dim_heads) for _ in range(cfg.num_layers)
        ])

    def forward(self, src, src_pos, tgt_pos, ctx = None):
        """Predict one ``dim_out`` vector per target position.

        Args:
            src: source features whose last dim is ``dim_in``; dim 0 is used
                as the batch size and dim 1 as the sequence axis.
            src_pos: integer position ids for ``src``.
            tgt_pos: integer position ids to predict.
            ctx: optional conditioning with last dim ``dim_ctx``; a zero
                vector of that length is used when omitted.

        Returns:
            ``proj_out`` applied to the final target stream.
        """
        # initialize context
        ctx = self.proj_noise(ctx if ctx is not None else src.new_zeros(self.dim_ctx))
        # initialize src
        src = self.proj_in(src) + self.positions(src_pos)
        # initialize latents, shared across the batch: (z, d) -> (b, z, d)
        latents = self.latents.weight.expand(src.size(0), -1, -1)
        # initialize tgt: the single learned query at every target position
        tgt = self.queries(torch.zeros_like(tgt_pos)) + self.positions(tgt_pos)
        # predict; only tgt is needed afterwards, so the final z is dropped
        # without splitting it back into source and latent parts
        z = torch.cat([src, latents], dim = 1)
        for block in self.network:
            tgt, z = block(x = tgt, z = z, ctx = ctx)
        # project out
        return self.proj_out(tgt)

In [44]:
class MaskedViT(Module):
    """ViT encoder followed by a single block that predicts target tokens.

    Target queries are one learned vector plus the encoder's own position
    embedding; they attend over the encoded source tokens concatenated with
    the queries themselves.
    """

    def __init__(self, cfg: dataclass):
        """Build the query embedding, the ViT and the prediction block.

        Reads ``cfg.dim`` and ``cfg.dim_ctx``; ``cfg`` is also handed to ``ViT``.
        """
        super().__init__()
        # single learned query shared by every target position
        self.queries = Embedding(1, cfg.dim)
        self.network = ViT(cfg)
        self.predict = TransformerBlock(cfg.dim, dim_ctx=cfg.dim_ctx)

    def forward(self, x, src_pos, tgt_pos, ctx = None):
        """Return the prediction block's output for the target positions.

        ``ctx`` is forwarded to the prediction block only; the ViT is called
        with ``x`` and ``src_pos``.
        """
        encoded = self.network(x, src_pos)
        # index 0 everywhere: every target position starts from the same query
        query_ids = torch.zeros_like(tgt_pos)
        queries = self.queries(query_ids) + self.network.positions(tgt_pos)
        # keys/values are the encoded source followed by the queries
        memory = torch.cat((encoded, queries), dim=1)
        return self.predict(queries, memory, ctx=ctx)

In [54]:
class ModalWrapper(Module):
    """Encoder -> processor -> decoder pipeline with multi-sample decoding.

    The encoder and processor run once per batch element without
    conditioning. The decoder runs once per noise draw, with ``noise`` as
    its conditioning.
    """

    def __init__(self, interface_cfg:dataclass, modal_cfg: dataclass):
        """Build the three sub-networks and initialise their weights.

        ``modal_cfg`` configures the encoder and decoder, ``interface_cfg``
        the processor.
        """
        super().__init__()
        # Networks
        self.encoder = ModalEncoder(modal_cfg)
        self.processor = MaskedInterface(interface_cfg)
        self.decoder = ModalDecoder(modal_cfg)

        # Init: generic scheme first, then the zero overrides on top
        for initializer in (self.base_init, self.zero_init):
            self.apply(initializer)

    @staticmethod
    def base_init(m):
        """Truncated-normal (std 0.02) init for Linear and Embedding weights."""
        if isinstance(m, (Linear, Embedding)):
            torch.nn.init.trunc_normal_(m.weight, std = 0.02)

    @staticmethod
    def zero_init(m):
        """Zero the ``to_out`` weight of Attention/GatedFFN modules and the
        ``linear`` weight of ConditionalLayerNorm modules."""
        if isinstance(m, (Attention, GatedFFN)):
            torch.nn.init.zeros_(m.to_out.weight)
        if isinstance(m, ConditionalLayerNorm):
            torch.nn.init.zeros_(m.linear.weight)

    def forward(self, data: torch.Tensor, coords: tuple, noise: torch.Tensor):
        """Decode one output per noise draw for every batch element.

        Args:
            data: encoder input; dim 0 is the batch.
            coords: ``(pos_src, pos_tgt, var_src, var_tgt)``.
            noise: decoder conditioning whose dim 0 is ``batch * samples``;
                the sample count is obtained by integer division.

        Returns:
            Decoder output with the sample axis moved last:
            ``(batch, ..., samples)``.
        """
        samples = noise.size(0) // data.size(0)
        pos_src, pos_tgt, var_src, var_tgt = coords
        tokens = self.encoder(data, var_src, ctx = None)
        latent = self.processor(tokens, pos_src, pos_tgt, ctx = None)
        # duplicate each batch element once per noise draw
        latent = repeat(latent, "b ... -> (b k) ...", k = samples)
        var_tgt = repeat(var_tgt, "b ... -> (b k) ...", k = samples)
        decoded = self.decoder(latent, var_tgt, noise)
        # un-flatten the sample axis and move it to the end
        return rearrange(decoded, "(b k) ... -> b ... k", k = samples)

In [55]:
@dataclass
class Config:
    """Hyper-parameters shared by the model classes in this notebook.

    The first three fields are required; the optional ones default to
    ``None`` and are only read by the classes that need them.
    """
    # last-dim size of the inputs (ModalEncoder channels, MaskedInterface.proj_in)
    dim_in: int
    # last-dim size of the outputs (ModalDecoder, MaskedInterface.proj_out)
    dim_out: int
    # model width
    dim: int
    # rows of the per-feature embedding tables in ModalEncoder / ModalDecoder
    num_features: Optional[int] = None
    # rows of the position embedding in MaskedInterface
    num_tokens: Optional[int] = None
    # forwarded to each Interface block as its second positional argument
    num_compute_blocks: Optional[int] = None
    # number of Interface blocks stacked in MaskedInterface
    num_layers: Optional[int] = None
    # rows of the latent embedding in MaskedInterface
    num_latents: Optional[int] = None
    # not read by any class defined in this notebook; presumably consumed by ViT -- TODO confirm
    num_cls: Optional[int] = None
    # forwarded to Interface as dim_heads
    dim_heads: int = 64
    # width of the conditioning vector (ConditionalLayerNorm, TransformerBlock, proj_noise)
    dim_ctx: int = 8

In [56]:
# batch size, position-vocabulary size, feature-vocabulary size, channels per feature
B = 8
N = 4096
F = 73
C = 4
# model width
D = 512
# samples per batch element
K = 4
# sampled source / target positions and sampled source / target features
M_src = 256
M_tgt = 1024
F_src = 32
F_tgt = 64

In [57]:
# per-feature encoder / decoder: C channels in and out, F feature ids
modal_cfg = Config(dim_in=C, dim_out=C, dim=D, num_features=F, num_tokens=N)
# processor: D-wide tokens in and out, N position ids
interface_cfg = Config(
    dim_in=D,
    dim_out=D,
    dim=D,
    num_tokens=N,
    num_compute_blocks=4,
    num_layers=2,
    num_latents=32,
)
# ViT variant
vit_cfg = Config(dim_in=D, dim_out=D, dim=D, num_tokens=N, num_layers=12, num_cls=8)

In [58]:
# For every batch row, draw distinct position ids out of N with uniform
# weights (torch.multinomial samples without replacement by default).
pos_src, pos_tgt = (
    torch.multinomial(torch.ones((B, N)), count) for count in (M_src, M_tgt)
)
# Same for feature ids out of F.
# NOTE(review): both draws use F_src; F_tgt is defined but never used --
# confirm whether var_tgt was meant to draw F_tgt ids.
var_src, var_tgt = (
    torch.multinomial(torch.ones((B, F)), count) for count in (F_src, F_src)
)

In [59]:
# random source observations: C channels for each sampled (position, feature) pair
x = torch.randn((B, M_src, F_src, C))
# K conditioning vectors per batch element, sized to the decoder's dim_ctx;
# the singleton middle axis presumably broadcasts over tokens -- TODO confirm
noise = torch.randn((B * K, 1, modal_cfg.dim_ctx))

In [60]:
# full pipeline: per-feature encoder -> masked interface -> per-feature decoder
model = ModalWrapper(interface_cfg=interface_cfg, modal_cfg=modal_cfg)

In [61]:
# NOTE: rebinds `x` from the model input to the model output, so this cell
# can only be re-run after re-running the cell that creates `x`.
x = model(
    data=x,
    coords=(pos_src, pos_tgt, var_src, var_tgt),
    noise=noise,
)

In [62]:
# (batch, target positions, target features, channels, samples)
x.shape

torch.Size([8, 1024, 32, 4, 4])