In [1]:
import torch
import einops
import math

from dataclasses import dataclass
from utils.model import *
from utils.config import NetworkConfig
from utils.loss_fn import f_kernel_crps

In [2]:
def exists(val):
    """Return ``True`` when ``val`` is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``False`` still count as existing.
    """
    if val is None:
        return False
    return True

def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if exists(val):
        return val
    return d

In [3]:
@dataclass
class WorldConfig:
    """Data-layout and masking settings read by ``DiscreteDiffusion`` via ``cfg.world``."""
    # axis name -> full extent of the field along that axis (keys define the field layout order)
    field_sizes: dict
    # doubled axis name ('tt' for 't', ...) -> patch extent along that axis;
    # each field size is floor-divided by the matching patch size to get the token count
    patch_sizes: dict
    # fixed batch size used to build the index buffers and to sample masks
    batch_size: int
    # number of token steps along 't' used for the history mask (rate = tau / tokens along 't')
    tau: int
    # axis name -> Dirichlet concentration used when sampling per-axis masking weights
    alphas: dict

@dataclass
class ConfigWrapper:
    """Bundle of the two configs ``DiscreteDiffusion`` expects on its ``cfg`` argument."""
    # data layout / masking settings (read as ``cfg.world``)
    world : WorldConfig
    # network hyper-parameters, passed unchanged to ``MaskedTokenModel`` (read as ``cfg.network``)
    network : NetworkConfig

In [4]:
# Network hyper-parameters; NetworkConfig comes from utils.config.
# NOTE(review): dim_in == 480 == 1*6*8*10 (product of the patch sizes below) and
# dim_out == 2 * dim_in — presumably flattened patch length and a 2-member
# ensemble; confirm against NetworkConfig / MaskedTokenModel.
model_cfg = NetworkConfig(
    dim=256,
    num_layers=4,
    num_compute_blocks=4,
    dim_in=480,
    dim_out=960,
    dim_coords=128,
    wavelengths=[(1e-3, 1e2)] * 4,  # same (min, max) pair repeated four times
    num_features=9,
    num_latents=64,
)

# Field layout: key order fixes the axis order; patch keys are the doubled axis names.
world_cfg = WorldConfig(
    field_sizes=dict(v=9, t=36, h=64, w=120),
    patch_sizes=dict(vv=1, tt=6, hh=8, ww=10),
    batch_size=8,
    tau=2,
    alphas=dict(t=0.5),
)

cfg = ConfigWrapper(world=world_cfg, network=model_cfg)

In [5]:
class DiscreteDiffusion(torch.nn.Module):
    """Masked-token training wrapper around ``MaskedTokenModel``.

    A field laid out as ``(b, *field_layout)`` is cut into patches and
    flattened to ``(b, num_tokens, dim_tokens)``. A per-sample subset of
    tokens is marked visible, the network predicts from tokens + visibility +
    token coordinates, and the prediction is scored with ``f_kernel_crps`` on
    the tokens that were not visible.

    Args:
        cfg: object exposing ``cfg.world`` (``field_sizes``, ``patch_sizes``,
            ``batch_size``, ``tau``, ``alphas``) and ``cfg.network`` (handed
            unchanged to ``MaskedTokenModel``).

    Notes:
        ``batch_size`` is baked into the registered index buffers, so inputs to
        ``forward`` must use exactly that batch size.
    """

    def __init__(self, cfg):
        super().__init__()
        self.network = MaskedTokenModel(cfg.network)

        # Config shapes
        self.batch_size = cfg.world.batch_size
        self.field_sizes = cfg.world.field_sizes
        self.patch_sizes = cfg.world.patch_sizes
        self.field_layout = tuple(self.field_sizes.keys())
        self.patch_layout = tuple(self.patch_sizes.keys())

        # derived sizes and shapes
        # patch axes are keyed by the doubled field-axis name ('t' -> 'tt')
        self.token_sizes = {ax: (self.field_sizes[ax] // self.patch_sizes[f'{ax*2}'])
                            for ax in self.field_layout}
        self.token_shape = tuple(self.token_sizes[ax] for ax in self.field_layout)
        self.num_tokens = math.prod(self.token_sizes[ax] for ax in self.field_layout)
        self.num_elements = math.prod(self.field_sizes[ax] for ax in self.field_layout)
        self.dim_tokens = math.prod(self.patch_sizes[ax] for ax in self.patch_layout)

        # einops patterns
        field = " ".join([f"({f} {p})" for f, p in zip(self.field_layout, self.patch_layout)])
        self.field_pattern = f"b {field}"
        self.flat_token_pattern = f"({' '.join(self.field_layout)})"
        self.flat_patch_pattern = f"({' '.join(self.patch_layout)})"
        self.flatland_pattern = f"b {self.flat_token_pattern} {self.flat_patch_pattern}"

        # Index tensors: register as buffers so .to() moves them.
        # They are created on the default device; `self.device` is derived
        # from them afterwards (see the property below).
        flatland_index = torch.arange(self.num_tokens).expand(self.batch_size, -1)
        self.register_buffer("flatland_index", flatland_index)
        token_index = torch.stack(torch.unravel_index(flatland_index, self.token_shape), dim=-1)
        self.register_buffer("token_index", token_index)

        # additional config attributes
        self.alphas = cfg.world.alphas
        self.tau = cfg.world.tau

    ### DEVICE
    @property
    def device(self):
        """Device of the registered index buffer.

        Derived rather than stored so that it stays correct after
        ``module.to(...)`` / ``.cuda()`` / ``.cpu()`` move the buffers.
        """
        return self.flatland_index.device

    @property
    def device_type(self):
        """Device type string (``'cpu'``, ``'cuda'``, ...) of ``self.device``."""
        return self.device.type

    ### TOKENIZATION
    def field_to_tokens(self, field):
        """Rearrange a ``(b, *field_layout)`` field into ``(b, num_tokens, dim_tokens)``."""
        return einops.rearrange(field, f'{self.field_pattern} -> {self.flatland_pattern}', **self.patch_sizes)

    def tokens_to_field(self, patch):
        """Inverse of ``field_to_tokens``; any trailing dimensions are carried through."""
        return einops.rearrange(patch, f"{self.flatland_pattern} ... -> {self.field_pattern} ...", **self.token_sizes, **self.patch_sizes)

    ### MASKING
    def gumbel_noise(self, shape: tuple, eps: float = 1e-6):
        """Sample standard Gumbel noise of the given shape on ``self.device``.

        The uniform draw is clamped to ``[eps, 1 - eps]`` so both logs stay finite.
        """
        u = torch.rand(shape, device=self.device).clamp(min=eps, max=1 - eps)
        return -torch.log(-torch.log(u))

    def dirichlet_marginal(self, ax: str, eps: float = 1e-6):
        """Log-probabilities of a Dirichlet draw along ``ax``, broadcast to all tokens.

        One probability vector of length ``token_sizes[ax]`` is drawn per batch
        entry with concentration ``alphas[ax]``, clamped to ``[eps, 1 - eps]``,
        repeated over the remaining token axes, and returned as
        ``(batch_size, num_tokens)`` logs.
        """
        concentration = torch.full(
            (self.batch_size, self.token_sizes[ax]), self.alphas[ax], device=self.device
        )
        # public API; draws the same distribution as the private torch._sample_dirichlet
        probs = torch.distributions.Dirichlet(concentration).sample().clamp(min=eps, max=1 - eps)
        probs = einops.repeat(probs, f'b {ax} -> b {self.flat_token_pattern}', **self.token_sizes)
        return probs.log()

    def k_from_rates(self, rates):
        """Convert per-sample rates in ``[0, 1]`` to token counts of shape ``(b, 1)``.

        Counts are truncated to integers and clamped to ``[1, num_tokens - 1]``
        so at least one token is selected and at least one is left out.
        """
        return (self.num_tokens * rates).long().clamp(1, self.num_tokens - 1).view(-1, 1)

    def binary_topk(self, weights, ks):
        """Boolean ``(b, num_tokens)`` mask that is True at the ``ks[b]`` largest weights of row ``b``."""
        index = weights.argsort(dim=-1, descending=True)
        # True for the first k ranks of each row ...
        topk = self.flatland_index < ks
        # ... scattered back to the token positions holding those ranks
        binary = torch.zeros_like(topk, dtype=torch.bool).scatter(1, index, topk)
        return binary

    ### PRIORS
    def get_visible_ws(self):
        """Random token weights: Gumbel noise plus a Dirichlet log-marginal along ``'t'``."""
        G = self.gumbel_noise((self.batch_size, self.num_tokens))
        D = self.dirichlet_marginal('t')
        return G + D

    def get_history_ws(self):
        """Deterministic token weights: ``+inf`` for the first ``tau`` steps along ``'t'``, 0 elsewhere."""
        step = torch.zeros((self.token_sizes['t'],), device=self.device)
        # was `self._cfg.tau`, an attribute that is never set on this module
        step[:self.tau] = float('inf')
        return einops.repeat(step, f't -> b {self.flat_token_pattern}', **self.token_sizes, b=self.batch_size)

    def get_visible_ks(self):
        """Stratified visible-token counts, one per batch entry, shape ``(b, 1)``.

        Rates are ``(u + i / b) % 1`` for a single shared uniform ``u``. The
        grid excludes the endpoint 1 because it would alias with 0 after the
        modulo and give the first and last batch entries the same rate.
        """
        offsets = torch.arange(self.batch_size, device=self.device) / self.batch_size
        u = torch.rand((1,), device=self.device)
        rates = (u + offsets) % 1
        return self.k_from_rates(rates)

    def get_history_ks(self):
        """Visible-token counts for the history mask: rate ``tau / token_sizes['t']`` for every entry."""
        rates = torch.full((self.batch_size,), self.tau / self.token_sizes['t'], device=self.device)
        return self.k_from_rates(rates)

    ### FORWARD
    @property
    def land_sea_mask(self):
        """All-ones boolean mask over ``(h, w)`` broadcast to ``(b, num_tokens, dim_tokens)``.

        Currently a placeholder: every element is True, so it does not exclude anything.
        """
        lsm = torch.ones((1, self.field_sizes['h'], self.field_sizes['w']), device=self.device, dtype=torch.bool)
        return einops.repeat(lsm, f'1 (h hh) (w ww) -> {self.flatland_pattern}', **self.patch_sizes, **self.token_sizes, b=self.batch_size)

    def forward(self, data, mode: str = 'prior'):
        """Compute the masked scoring-rule loss for one batch.

        Args:
            data: field tensor shaped ``(batch_size, *field_sizes.values())``.
            mode: ``'prior'`` samples a random visibility mask; any other value
                uses the deterministic history mask (first ``tau`` steps along ``'t'``).

        Returns:
            Scalar loss tensor.
        """
        # tokens
        data = data.to(self.device)
        tokens = self.field_to_tokens(data)

        # masks
        ws = self.get_visible_ws() if mode == 'prior' else self.get_history_ws()
        ks = self.get_visible_ks() if mode == 'prior' else self.get_history_ks()
        visible = self.binary_topk(ws, ks)[..., None]  # add singleton D dimension

        # predict
        pred = self.network(tokens, visible, self.token_index)

        # scoring rule
        # NOTE(review): assumes the network stacks n members along the batch axis and
        # e members along the last axis; both are folded into one trailing ensemble axis.
        ensemble = einops.rearrange(pred, '(b n) ... (d e) -> b ... d (n e)', d=tokens.size(-1), b=self.batch_size)
        score = f_kernel_crps(tokens, ensemble)

        # masked loss
        mask = torch.logical_and(self.land_sea_mask, ~visible)  # combine lsm [B, N, D] and mask [B, N, 1]
        # fraction of each sample's elements that contribute to the loss
        mask_rate = einops.reduce(mask, 'b n d -> b 1 1', reduction='sum') / self.num_elements
        loss = (score * mask / mask_rate).sum() / mask.sum()

        return loss

In [6]:
# Build the module from the combined world + network configuration defined above.
dd = DiscreteDiffusion(cfg)

In [7]:
# Smoke run: a standard-normal field of shape (batch, *field sizes) through one
# forward pass in the default masking mode.
data = torch.randn(dd.batch_size, *dd.field_sizes.values())
loss = dd(data)