In [1]:
import einops
import torch

from einops.layers.torch import EinMix

from utils.components import *
from utils.config import *

In [2]:
# Field geometry: full field extents per axis, patch size per axis, and batch size.
world = WorldConfig(
    field_sizes={"v": 6, "t": 36, "h": 64, "w": 120},
    patch_sizes={"vv": 2, "tt": 6, "hh": 4, "ww": 4},
    batch_size=16,
)

In [3]:
# Model hyper-parameters consumed by EinMask below.
network = NetworkConfig(
    dim=256,              # token / latent width (`d` in the EinMix patterns)
    num_latents=32,       # rows of the learned latent array
    num_layers=12,        # number of latent TransformerBlocks
    num_cls=4,            # size of the trailing `e` axis produced by to_fields
    use_checkpoint=False,
)

In [4]:
# Standalone copies of EinMask's I/O projections, used for a quick shape check.
# Both layers share the patch/token axis sizes and the model width `d`.
_einmix_sizes = dict(**world.patch_sizes, **world.token_sizes, d=network.dim)

test_in = EinMix(
    pattern=f"{world.field_pattern} -> b {world.flat_token_pattern} d",
    weight_shape=f"v d {world.patch_pattern}",
    bias_shape=f"{world.flat_token_pattern} d",
    **_einmix_sizes,
)

test_out = EinMix(
    pattern=f"b {world.flat_token_pattern} d -> {world.field_pattern} e",
    weight_shape=f"e v {world.patch_pattern} d",
    bias_shape=f"e v {world.patch_pattern}",
    e=network.num_cls,
    **_einmix_sizes,
)

del _einmix_sizes

In [5]:
# Smoke-test the standalone projections: fields -> flat tokens -> fields (+ class axis).
a = torch.randn((world.batch_size, world.field_sizes['v'], world.field_sizes['t'], world.field_sizes['h'], world.field_sizes['w']))
# Shape check only: skip building an autograd graph over the large `c` output.
with torch.no_grad():
    b = test_in(a)
    c = test_out(b)

print(a.shape)
print(b.shape)
print(c.shape)

print(test_in.weight.shape)
print(test_in.bias.shape)

# Round-trip invariants: one token per patch at width `dim`, and to_fields restores
# the field grid with a trailing axis of size num_cls.
assert b.shape == (world.batch_size, world.num_tokens, network.dim), b.shape
assert c.shape == (*a.shape, network.num_cls), c.shape

torch.Size([16, 6, 36, 64, 120])
torch.Size([16, 8640, 256])
torch.Size([16, 6, 36, 64, 120, 4])
torch.Size([3, 256, 2, 6, 4, 4])
torch.Size([1, 3, 6, 16, 30, 256])


In [6]:
class EinMask(torch.nn.Module):
    """Masked field model: patch tokens -> latent blocks -> per-token queries -> fields.

    ``to_tokens`` maps a field tensor to flat tokens ``(b, n, dim)`` with an ``EinMix``
    patch projection. Only the tokens selected by ``context`` are kept. A stack of
    ``TransformerBlock``s then updates a learned latent array; each block is called
    with ``q=latents`` and ``kv=[selected tokens, latents]``. A final ``write`` block is
    called with one learned query per token and ``kv=latents``. ``to_fields`` maps the
    result back to the field layout with a trailing axis of size ``network.num_cls``.

    Args:
        network: model hyper-parameters (``dim``, ``dim_heads``, ``dim_noise``,
            ``num_latents``, ``num_layers``, ``num_cls``; optionally ``use_checkpoint``).
        world: field/patch geometry (einops patterns, patch/token sizes, ``num_tokens``).
    """

    def __init__(self, network: NetworkConfig, world: WorldConfig):
        super().__init__()
        # Activation checkpointing for the transformer blocks. NetworkConfig accepts a
        # `use_checkpoint` flag; read it defensively so configs without it still work.
        self.use_checkpoint = bool(getattr(network, "use_checkpoint", False))

        # I/O: fields -> flat tokens (weight indexed by the token-level `v` axis, bias
        # per token position), and flat tokens -> fields with a trailing `e` axis.
        self.to_tokens = EinMix(
            pattern=f"{world.field_pattern} -> b {world.flat_token_pattern} d",
            weight_shape=f'v d {world.patch_pattern}',
            bias_shape=f'{world.flat_token_pattern} d',
            **world.patch_sizes, **world.token_sizes,
            d=network.dim,
        )
        self.to_fields = EinMix(
            pattern=f"b {world.flat_token_pattern} d -> {world.field_pattern} e",
            weight_shape=f'e v {world.patch_pattern} d',
            bias_shape=f'e v {world.patch_pattern}',
            **world.patch_sizes, **world.token_sizes,
            d=network.dim, e=network.num_cls,
        )

        # embeddings: learned latent array and one learned query per output token
        self.latents = torch.nn.Embedding(network.num_latents, network.dim)
        self.queries = torch.nn.Embedding(world.num_tokens, network.dim)

        # perceiver: `num_layers` latent blocks plus one write-out block, all built
        # with identical hyper-parameters (construction order matches the original).
        def make_block():
            return TransformerBlock(
                dim=network.dim,
                dim_heads=network.dim_heads,
                dim_ctx=network.dim_noise,
            )

        self.transformer = torch.nn.ModuleList([make_block() for _ in range(network.num_layers)])
        self.write = make_block()

        # Weight initialization (overrides each submodule's default init)
        self.apply(self.base_init)

    @staticmethod
    def base_init(m):
        '''Explicit weight initialization for all components (applied via ``self.apply``).

        - ``Linear``: truncated normal weight with std from ``get_weight_std``; bias zeroed.
        - ``Embedding``: truncated normal weight with std from ``get_weight_std``.
        - ``EinMix``: truncated normal weight with std 0.02; bias zeroed.
        - ``ConditionalLayerNorm``: bias zeroed; weight drawn with std 1e-7 (close to 0).

        NOTE(review): ``trunc_normal_`` cutoffs ``a``/``b`` are absolute values
        (default -2/2), not multiples of ``std``. With the small stds used here the
        truncation never triggers, so these draws are effectively plain normals.
        Pass ``a=-2*std, b=2*std`` if 2-sigma truncation is intended.
        '''
        # linear
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std = get_weight_std(m.weight))
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
        # embedding
        if isinstance(m, torch.nn.Embedding):
            torch.nn.init.trunc_normal_(m.weight, std = get_weight_std(m.weight))
        # einmix
        if isinstance(m, EinMix):
            torch.nn.init.trunc_normal_(m.weight, std = 0.02)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
        # conditional layer norm
        if isinstance(m, ConditionalLayerNorm):
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
            if m.weight is not None: # CLN weight close to 0
                torch.nn.init.trunc_normal_(m.weight, std = 1e-7)

    def _run_block(self, block, q, kv):
        """Call ``block(q=q, kv=kv)``, checkpointing activations when enabled and training."""
        if self.use_checkpoint and self.training and torch.is_grad_enabled():
            from torch.utils.checkpoint import checkpoint
            # Non-reentrant checkpointing forwards keyword arguments to `block`.
            return checkpoint(block, q=q, kv=kv, use_reentrant=False)
        return block(q=q, kv=kv)

    def forward(self, fields, context):
        """Reconstruct full fields from the subset of tokens selected by ``context``.

        Args:
            fields: field tensor laid out as ``world.field_pattern``; in this notebook
                it is ``(batch, v, t, h, w)``.
            context: index tensor ``(batch, k)`` of flat token positions to keep. It is
                used as a ``gather`` index, so it must be int64.

        Returns:
            Tensor in ``world.field_pattern`` layout with a trailing axis of size
            ``network.num_cls``.
        """
        # project field to tokens: (b, n, d)
        tokens = self.to_tokens(fields)
        batch, width = tokens.size(0), tokens.size(-1)
        # broadcast learned parameters over the batch
        latents = einops.repeat(self.latents.weight, 'm d -> b m d', b=batch)
        queries = einops.repeat(self.queries.weight, 'n d -> b n d', b=batch)
        # keep only the selected tokens: (b, k, d)
        index = einops.repeat(context, 'b n -> b n d', d=width)
        visible = tokens.gather(1, index)
        # latent blocks read [visible tokens, latents]
        for block in self.transformer:
            latents = self._run_block(block, latents, torch.cat([visible, latents], dim=1))
        # one query per token reads the latents, then project back to field layout
        written = self._run_block(self.write, queries, latents)
        return self.to_fields(written)


In [7]:
# Sample a random subset of visible tokens per batch element.
# A symmetric Dirichlet over the time-token axis is broadcast to every token, so
# tokens sharing a time step share a selection weight.
TIME_ALPHA = 0.5                       # Dirichlet concentration (< 1 favours uneven weights)
MIN_VISIBLE, MAX_VISIBLE = 128, 4096   # range for k; randint's upper bound is exclusive

# Public API; Dirichlet.sample draws via the same sampler as torch._sample_dirichlet.
dirichlet = torch.distributions.Dirichlet(
    torch.full((world.batch_size, world.token_sizes['t']), TIME_ALPHA)
).sample()
prior = einops.repeat(dirichlet, f'b t -> b {world.flat_token_pattern}', **world.token_sizes)
# number of visible tokens (shared by the whole batch)
k = torch.randint(MIN_VISIBLE, MAX_VISIBLE, (1,)).item()
# (b, k) flat token indices, drawn without replacement (multinomial default)
src = torch.multinomial(prior, k)
# True at tokens NOT selected into `src`
mask = torch.ones_like(prior, dtype=torch.bool).scatter_(1, src, False)
assert src.shape == (world.batch_size, k)

In [8]:
# Build the model on the GPU (its constructor runs base_init), then draw a fresh
# random field batch shaped (batch, v, t, h, w).
model = EinMask(network, world).to("cuda")
a = torch.randn(world.batch_size, *(world.field_sizes[axis] for axis in ("v", "t", "h", "w")))

In [9]:
# Forward smoke test under bfloat16 autocast; `src` selects the visible tokens.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    d = model(fields=a.to("cuda"), context=src.to("cuda"))