In [1]:
import torch
import einops
import math

from dataclasses import dataclass

from utils.objective import DiscreteDiffusion
from utils.config import NetworkConfig, WorldConfig


In [2]:
import matplotlib.pyplot as plt
import numpy as np

In [3]:
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    return not (val is None)

def default(val, d):
    """Return ``val`` unless it is ``None``, in which case fall back to ``d``."""
    if not exists(val):
        return d
    return val

In [4]:

@dataclass
class ConfigWrapper:
    """Pair a ``WorldConfig`` with a ``NetworkConfig`` in a single object.

    The two configs are exposed as ``cfg.world`` and ``cfg.network``. This
    combined object is what gets handed to ``DiscreteDiffusion``.
    """

    # Field / tokenization settings (field sizes, patch sizes, batch size, ...).
    world: WorldConfig
    # Model settings (width, depth, input/output dims, ...).
    network: NetworkConfig

In [5]:
# Patch sizes live in one place so the network's input width cannot drift out
# of sync with the tokenization geometry.
PATCH_SIZES = {'vv': 1, 'tt': 6, 'hh': 8, 'ww': 10}

model_cfg = NetworkConfig(
    dim=256,
    num_layers=4,
    num_compute_blocks=4,
    # Flattened patch volume: 1 * 6 * 8 * 10 = 480.
    # Assumes each token is one flattened patch -- TODO confirm.
    dim_in=math.prod(PATCH_SIZES.values()),
    # NOTE(review): this equals 2 * dim_in; confirm whether it should be derived too.
    dim_out=960,
    dim_coords=128,
    # Four identical (min, max) ranges; presumably one per field axis (v, t, h, w) -- verify.
    wavelengths=[(1e-3, 1e2) for _ in range(4)],
    num_latents=64,
)

world_cfg = WorldConfig(
    field_sizes={'v': 9, 't': 36, 'h': 64, 'w': 120},
    patch_sizes=PATCH_SIZES,
    batch_size=8,
    tau=2,
    alphas={'t': 0.5, 'v': 1.},
)

cfg = ConfigWrapper(
    world=world_cfg,
    network=model_cfg,
)

In [6]:
# Seed torch's RNG before building the model, so both the weight init and the
# random `data` tensor drawn later are reproducible under Restart & Run All.
torch.manual_seed(0)
dd = DiscreteDiffusion(cfg)

### TOKENIZATION
def field_to_tokens(world, field):
    """Rearrange a dense field tensor into patch tokens.

    Runs ``einops.rearrange(field, "<field_pattern> -> <flatland_pattern>")``.
    The patterns and the patch-axis lengths are read from ``world``.

    Args:
        world: config exposing ``field_pattern``, ``flatland_pattern`` and
            ``patch_sizes`` (a ``WorldConfig`` here).
        field: tensor laid out as ``world.field_pattern``. Presumably
            ``b v t h w`` in this notebook -- TODO confirm against WorldConfig.

    Returns:
        The rearranged tensor, laid out as ``world.flatland_pattern``.
    """
    # Plain module-level function: ``@staticmethod`` only has meaning inside a
    # class body, and bare staticmethod objects are not callable before Python 3.10.
    return einops.rearrange(
        field,
        f'{world.field_pattern} -> {world.flatland_pattern}',
        **world.patch_sizes,
    )

def tokens_to_field(world, patch):
    """Fold patch tokens back into a dense field (reverse of ``field_to_tokens``).

    Runs ``einops.rearrange(patch, "<flatland_pattern> ... -> <field_pattern> ...")``.
    Any trailing axes after the token layout are carried through unchanged via ``...``.

    Args:
        world: config exposing ``flatland_pattern``, ``field_pattern``,
            ``token_sizes`` and ``patch_sizes``.
        patch: tensor laid out as ``world.flatland_pattern``, optionally
            followed by extra trailing axes.

    Returns:
        The rearranged tensor, laid out as ``world.field_pattern`` plus any
        trailing axes.
    """
    # token_sizes and patch_sizes presumably give the token-grid and within-patch
    # axis lengths that einops needs to unflatten. Both are splatted as keyword
    # arguments, so a key present in both dicts would raise TypeError.
    return einops.rearrange(
        patch,
        f"{world.flatland_pattern} ... -> {world.field_pattern} ...",
        **world.token_sizes,
        **world.patch_sizes,
    )

In [7]:
# Random stand-in field of shape (batch_size, *field_sizes.values()).
# Assumes field_sizes' insertion order matches world_cfg.field_pattern -- TODO confirm.
field_shape = (world_cfg.batch_size, *world_cfg.field_sizes.values())
data = torch.randn(field_shape)
tokens = field_to_tokens(world_cfg, data)
loss = dd(tokens)

In [8]:
# Display the loss returned by the diffusion objective for this random batch.
loss

tensor(0.7915, grad_fn=<MeanBackward0>)