In [1]:
import torch
import einops
import math

from dataclasses import dataclass

from utils.objective import DiscreteDiffusion
from utils.config import NetworkConfig, WorldConfig


In [2]:
import matplotlib.pyplot as plt
import numpy as np

In [3]:
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True

def default(val, d):
    """Return ``val`` when it is set; otherwise return the fallback ``d``.

    "Set" is decided by ``exists``, i.e. any value other than ``None``.
    """
    if exists(val):
        return val
    return d


def count_parameters(model):
    """Count the scalar elements held by the trainable parameters of ``model``.

    Only parameters whose ``requires_grad`` flag is truthy contribute; the
    result is the sum of their ``numel()`` values (0 if there are none).
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
@dataclass
class ConfigWrapper:
    """Bundle the world and network configurations into a single object."""

    # World-side configuration (a ``WorldConfig`` instance).
    world: WorldConfig
    # Network-side configuration (a ``NetworkConfig`` instance).
    network: NetworkConfig

In [5]:
# World configuration. Key order of `field_sizes` matters: a later cell builds
# the random input shape from `field_sizes.values()` in insertion order.
# NOTE(review): the keys presumably name field axes and their patch
# counterparts — confirm against WorldConfig.
world_cfg = WorldConfig(
    field_sizes=dict(v=9, t=36, h=64, w=120),
    patch_sizes=dict(vv=1, tt=6, hh=4, ww=4),
    batch_size=16,
    tau=2,
    alphas=dict(t=0.5, v=1.0),
)

# Network configuration; the input/output widths are derived from the world's
# token dimension (`world_cfg.dim_tokens`).
model_cfg = NetworkConfig(
    dim=768,
    num_layers=2,
    num_compute_blocks=12,
    dim_in=world_cfg.dim_tokens,
    dim_out=world_cfg.dim_tokens * 4,
    dim_coords=128,
    # Four identical (1e-3, 1e2) pairs.
    # NOTE(review): presumably one (min, max) wavelength pair per coordinate
    # axis — confirm against NetworkConfig.
    wavelengths=[(1e-3, 1e2) for _ in range(4)],
    num_latents=256,
    use_checkpoint=True,
)

# Combined configuration handed to the diffusion objective.
cfg = ConfigWrapper(world=world_cfg, network=model_cfg)

In [6]:
### TOKENIZATION
def field_to_tokens(world, field):
    """Rearrange ``field`` from the world's field layout into its token layout.

    This is a plain module-level function, so it carries no ``@staticmethod``
    decorator: outside a class body the decorator only wraps the function in a
    ``staticmethod`` object, which is not callable on Python versions before
    3.10.

    Args:
        world: Object providing ``field_pattern`` and ``flatland_pattern``
            (einops axis expressions) and ``patch_sizes`` (a mapping that is
            forwarded to ``einops.rearrange`` as axis-length keyword
            arguments).
        field: Array/tensor accepted by ``einops.rearrange``; presumably laid
            out as ``world.field_pattern`` describes — TODO confirm.

    Returns:
        The result of ``einops.rearrange`` for the pattern
        ``"<field_pattern> -> <flatland_pattern>"``.
    """
    pattern = f'{world.field_pattern} -> {world.flatland_pattern}'
    return einops.rearrange(field, pattern, **world.patch_sizes)

def tokens_to_field(world, patch):
    """Rearrange ``patch`` from the world's token layout back to its field layout.

    This is a plain module-level function, so it carries no ``@staticmethod``
    decorator: outside a class body the decorator only wraps the function in a
    ``staticmethod`` object, which is not callable on Python versions before
    3.10.

    Args:
        world: Object providing ``flatland_pattern`` and ``field_pattern``
            (einops axis expressions) plus ``token_sizes`` and ``patch_sizes``
            (mappings forwarded together to ``einops.rearrange`` as
            axis-length keyword arguments; they must not share keys, or the
            call raises ``TypeError`` for a repeated keyword).
        patch: Array/tensor accepted by ``einops.rearrange``. The trailing
            ``...`` in the pattern carries any extra trailing axes through
            unchanged.

    Returns:
        The result of ``einops.rearrange`` for the pattern
        ``"<flatland_pattern> ... -> <field_pattern> ..."``.
    """
    pattern = f"{world.flatland_pattern} ... -> {world.field_pattern} ..."
    return einops.rearrange(patch, pattern, **world.token_sizes, **world.patch_sizes)

In [7]:
# Build the diffusion objective on the selected device and compile its network.
dd = DiscreteDiffusion(cfg, device=device)
dd.network.compile(fullgraph=True)

# Report the number of trainable parameters with thousands separators.
print(f'{count_parameters(dd.network):,}')

# Random stand-in input: batch axis first, then one axis per `field_sizes`
# entry in insertion order.
data = torch.randn(
    world_cfg.batch_size,
    *world_cfg.field_sizes.values(),
    device=device,
)

169,242,240


In [None]:
# Run only the forward pass and loss computation under autocast.
# NOTE(review): the autocast device type is hard-coded to 'cuda' although
# `device` may resolve to the CPU — confirm this is intended.
with torch.amp.autocast('cuda'):
    tokens = field_to_tokens(world_cfg, data)
    loss, obs, ens, visible = dd(tokens)

# Backward runs outside the autocast region, as the PyTorch AMP documentation
# recommends; backward ops run in the dtype autocast chose for the
# corresponding forward ops.
loss.backward()