In [1]:
import torch
import einops
import math

from dataclasses import dataclass
from utils.model import *
from utils.config import NetworkConfig
from utils.loss_fn import f_kernel_crps

In [2]:
def exists(val):
    """Return True when ``val`` is anything other than ``None`` (falsy values still count)."""
    if val is None:
        return False
    return True

def default(val, d):
    """Return ``val``, falling back to ``d`` only when ``val`` is ``None``."""
    if val is None:
        return d
    return val

In [3]:
@dataclass
class WorldConfig:
    """Field, tokenization, masking and loss settings read by ``MaskedField`` via ``cfg.world``."""
    field_layout: tuple   # axis names of the unpatched field, in data order, e.g. ('v', 't', 'h', 'w')
    patch_layout: tuple   # patch-axis names; MaskedField looks up the patch of field axis 'x' as 'xx'
    field_sizes: dict     # field axis name -> length along that axis
    patch_sizes: dict     # patch axis name -> patch length; expected to divide the matching field size
    batch_size: int       # MaskedField builds masks/indices with this fixed batch size
    tau: int              # number of leading time tokens kept visible in MaskedField's non-'prior' mode
    alphas: dict          # Dirichlet concentration per axis; MaskedField reads alphas['t']
    mask_range: tuple     # NOTE(review): not read by MaskedField in this notebook
    loss_weights: dict    # per-variable weights, consumed in insertion order along the 'v' axis

@dataclass
class ConfigWrapper:
    """Bundle of the two config halves that ``MaskedField`` reads as ``cfg.world`` and ``cfg.network``."""
    world : WorldConfig      # field layout, masking and loss settings
    network : NetworkConfig  # passed unchanged to MaskedTokenField

In [None]:
# Network hyper-parameters. Several values mirror the world config below:
#   dim_in       -- presumably the values per token: 1 * 6 * 8 * 10 = 480 (product of patch_sizes)
#   dim_out      -- 2 * dim_in; MaskedField.forward splits the last axis as (d e) with d = 480
#   num_features -- presumably field_sizes['v'] (9 variables)
#   wavelengths  -- one (min, max) pair per token-index axis (v, t, h, w) -- TODO confirm
model_cfg = NetworkConfig(
    dim = 256,
    num_layers = 4,
    num_compute_blocks = 4,
    dim_in = 480,
    dim_out = 960,
    dim_coords = 128,
    wavelengths = [(1e-3, 1e2), (1e-3, 1e2), (1e-3, 1e2), (1e-3, 1e2)],
    num_features = 9,
    num_latents = 64)

world_cfg = WorldConfig(
    field_layout = ('v', 't', 'h', 'w'),
    patch_layout = ('vv', 'tt', 'hh', 'ww'),
    field_sizes = {'v': 9, 't': 36, 'h': 64, 'w': 120},
    patch_sizes = {'vv': 1, 'tt': 6, 'hh': 8, 'ww': 10},
    batch_size = 8,
    tau = 2,
    alphas = {'t': 0.5},
    # Required by WorldConfig (no default); omitting it raised a TypeError.
    # NOTE(review): MaskedField does not read it; (0., 1.) mirrors the [0, 1)
    # range its visible rates are drawn from -- confirm the intended value.
    mask_range = (0., 1.),
    # Insertion order must match the variable order of the data's 'v' axis.
    loss_weights = {
        'temp_ocn_0a': 1.,
        'temp_ocn_1a': 0.1,
        'temp_ocn_3a': 0.1,
        'temp_ocn_5a': 0.1,
        'temp_ocn_8a': 0.1,
        'temp_ocn_11a': 0.1,
        'temp_ocn_14a': 0.1,
        'tauxa': 0.01,
        'tauya': 0.01,
    },
)

cfg = ConfigWrapper(
    world = world_cfg,
    network = model_cfg
)

In [None]:
class MaskedField(torch.nn.Module):
    """Masked-token training objective over a patched ("tokenized") field.

    The field is cut into non-overlapping patches (tokens). For each sample a
    visible/masked split is drawn, ``MaskedTokenField`` predicts from the
    visible tokens, and ``forward`` returns the kernel CRPS averaged over the
    masked tokens, re-weighted by ``1 / (1 - visible_rate)``.

    Masking modes (``forward(mode=...)``):
        'prior'  -- stratified random visible rates; Gumbel noise plus a
                    Dirichlet log-marginal over the time axis ranks the tokens.
        other    -- the first ``tau`` time-token slices are visible, the rest masked.

    Args:
        cfg: object with ``.world`` (layouts, sizes, ``batch_size``, ``tau``,
            ``alphas``, ``loss_weights``) and ``.network`` (passed to
            ``MaskedTokenField``).
    """

    def __init__(self, cfg):
        super().__init__()
        self._cfg = cfg.world
        self.device_type: str = 'cpu'
        # NOTE(review): not updated by .to()/.cuda(); forward() moves its input to
        # this device, so keep it in sync manually when moving the module.
        self.device: torch.device = torch.device('cpu')
        self.network = MaskedTokenField(cfg.network)

    ### SHAPES
    @property # arrangement of axes in the unpatched field
    def field_layout(self):
        return self._cfg.field_layout

    @property # arrangement of axes in the patch dimension
    def patch_layout(self):
        return self._cfg.patch_layout

    @property # fixed batch size used to build masks and indices
    def batch_size(self):
        return self._cfg.batch_size

    @property # field sizes per axis from config
    def field_sizes(self):
        return self._cfg.field_sizes

    @property # patch sizes per axis from config
    def patch_sizes(self):
        return self._cfg.patch_sizes

    @property # patch counts per axis derived from field and patch sizes
    def token_sizes(self):
        # The patch axis of field axis 'x' is named 'xx'. Floor division: a field size
        # that is not a multiple of its patch size is rejected later by einops.
        return {ax: self.field_sizes[ax] // self.patch_sizes[ax * 2] for ax in self.field_layout}

    @property # tuple of integers for the patched dimensions
    def token_shape(self):
        return tuple(self.token_sizes[ax] for ax in self.field_layout)

    @property # total number of patches in the flatland representation
    def num_tokens(self):
        return math.prod(self.token_shape)

    @property # total number of values in a single patch
    def dim_tokens(self):
        return math.prod(self.patch_sizes[p] for p in self.patch_layout)

    @property # einops pattern for the unpatched field, e.g. 'b (v vv) (t tt) ...'
    def field_pattern(self):
        field = " ".join(f"({f} {p})" for f, p in zip(self.field_layout, self.patch_layout))
        return f"b {field}"

    @property # einops pattern for the flattened token dimension, e.g. '(v t h w)'
    def flat_token_pattern(self):
        # Outer double quotes: re-using the inner quote type is only legal from Python 3.12.
        return f"({' '.join(self.field_layout)})"

    @property # einops pattern for the flattened patch dimension, e.g. '(vv tt hh ww)'
    def flat_patch_pattern(self):
        return f"({' '.join(self.patch_layout)})"

    @property # einops pattern for the patched flatland representation
    def flatland_pattern(self):
        return f'b {self.flat_token_pattern} {self.flat_patch_pattern}'

    @property # (batch_size, num_tokens) flat token indices, identical for every sample
    def flatland_index(self):
        return torch.arange(self.num_tokens, device = self.device).expand(self.batch_size, -1)

    @property # (batch_size, num_tokens, len(field_layout)) per-axis token coordinates
    def token_index(self):
        return torch.stack(torch.unravel_index(self.flatland_index, self.token_shape), dim = -1)

    ### TOKENIZATION
    def field_to_tokens(self, field):
        """Patch a field laid out as ``field_pattern`` into ``flatland_pattern``."""
        return einops.rearrange(field, f'{self.field_pattern} -> {self.flatland_pattern}', **self.patch_sizes)

    def tokens_to_field(self, patch):
        """Inverse of ``field_to_tokens``; trailing extra axes (``...``) are carried along."""
        return einops.rearrange(patch, f"{self.flatland_pattern} ... -> {self.field_pattern} ...", **self.token_sizes, **self.patch_sizes)

    ### MASKING
    def gumbel_noise(self, shape: tuple, eps: float = 1e-6):
        """Standard Gumbel samples; uniforms are clamped away from 0 and 1 to keep the logs finite."""
        u = torch.rand(shape, device = self.device).clamp(min=eps, max=1-eps)
        return -torch.log(-torch.log(u))

    def dirichlet_marginal(self, ax: str):
        """Per-sample log-probabilities over axis ``ax``, broadcast to every token.

        A Dirichlet(alphas[ax]) draw over the token positions of ``ax`` is repeated
        across all other token axes, giving a (batch_size, num_tokens) tensor.
        """
        concentration = torch.full((self.batch_size, self.token_sizes[ax]), self._cfg.alphas[ax], device = self.device)
        # Public API; samples via the same underlying kernel as torch._sample_dirichlet.
        probs = torch.distributions.Dirichlet(concentration).sample()
        probs = einops.repeat(probs, f'b {ax} -> b {self.flat_token_pattern}', **self.token_sizes)
        return probs.log()

    def get_masking_weights(self):
        """Ranking scores for 'prior' mode: Gumbel noise plus the time-axis Dirichlet log-marginal."""
        G = self.gumbel_noise((self.batch_size, self.num_tokens))
        D = self.dirichlet_marginal('t')
        return G + D

    def get_frcst_weights(self):
        """Ranking scores for forecast mode: +inf on the first ``tau`` time tokens, 0 elsewhere."""
        step = torch.zeros((self.token_sizes['t'],), device = self.device)
        step[:self._cfg.tau] = float('inf')
        return einops.repeat(step, f't -> b {self.flat_token_pattern}', **self.token_sizes, b=self.batch_size)

    def get_visible_rate(self):
        """Stratified visible rates in [0, 1): one random shift ``u`` plus offsets ``i / batch_size``.

        ``linspace(0, 1, B)`` would include both 0 and 1, which coincide modulo 1 and
        give two samples the same rate; ``arange(B) / B`` spaces them evenly.
        """
        offsets = torch.arange(self.batch_size, device = self.device) / self.batch_size
        u = torch.rand((1,), device = self.device)
        return (u + offsets) % 1

    def get_history_rate(self):
        """Visible rate for forecast mode: fraction ``tau / T`` of time tokens, for every sample."""
        return torch.full((self.batch_size,), self._cfg.tau / self.token_sizes['t'], device = self.device)

    def get_binary_mask(self, weights, rates):
        """Mark the ``round(num_tokens * rate)`` highest-weighted tokens of each sample as visible.

        Args:
            weights: (batch_size, num_tokens) ranking scores; larger is kept visible.
            rates: (batch_size,) visible fraction per sample.

        Returns:
            (visible, masked): complementary boolean tensors of shape (batch_size, num_tokens).
        """
        # Round instead of truncating: e.g. 180 * float32(13/20) evaluates to 116.99999,
        # which truncation would turn into 116 visible tokens instead of 117.
        # The clamp keeps at least one visible and at least one masked token.
        k = torch.round(self.num_tokens * rates).long().clamp(1, self.num_tokens - 1).view(-1, 1)
        order = weights.argsort(dim = -1, descending=True)
        in_top_k = self.flatland_index < k
        # Scatter the rank-space "within top k" flags back to token positions.
        visible = torch.zeros_like(in_top_k, dtype=torch.bool).scatter(1, order, in_top_k)
        return visible, ~visible

    ### FORWARD
    @property # placeholder land-sea mask (all ones) in flatland layout
    def land_sea_mask(self):
        lsm = torch.ones((1, self.field_sizes['h'], self.field_sizes['w']), device = self.device)
        return einops.repeat(lsm, f'1 (h hh) (w ww) -> {self.flatland_pattern}', **self.patch_sizes, **self.token_sizes, b=self.batch_size)

    @property # per-variable loss weights broadcast to flatland layout
    def loss_weights(self):
        # Values are placed along the 'v' axis in dict insertion order, so the dict must
        # list variables in the data's 'v' order. Built on self.device so it can multiply
        # device-resident scores.
        w = torch.as_tensor(list(self._cfg.loss_weights.values()), device = self.device)
        return einops.repeat(w, f'(v vv) -> {self.flatland_pattern}', **self.patch_sizes, **self.token_sizes, b=self.batch_size)

    def forward(self, data, mode: str = 'prior'):
        """Return the masked, rate-weighted kernel-CRPS loss for one batch.

        Args:
            data: field laid out as ``field_pattern``; its batch axis must equal
                ``batch_size`` because the masks are built with that size.
            mode: 'prior' for random masking; any other value uses forecast masking.
        """
        # tokens
        data = data.to(self.device)
        tokens = self.field_to_tokens(data)

        # masking
        visible_rate = self.get_visible_rate() if mode == 'prior' else self.get_history_rate()
        weights = self.get_masking_weights() if mode == 'prior' else self.get_frcst_weights()
        visible, masked = self.get_binary_mask(weights, visible_rate)

        # network
        pred = self.network(tokens, visible, self.token_index)
        # NOTE(review): presumably the leading axis of pred stacks batch and ensemble
        # members ('(b n)') and the last axis packs d values per member -- confirm
        # against MaskedTokenField.
        pred = einops.rearrange(pred, '(b n) ... (d e) -> b ... d (n e)', d = tokens.size(-1), b = self.batch_size)

        # scoring rule
        # TODO: re-enable re-scaling by variable and land-sea:
        #   score = score * self.loss_weights * self.land_sea_mask
        score = f_kernel_crps(tokens, pred)

        # masked diffusion loss: weight each sample by 1 / masked fraction, then
        # average over masked tokens only
        score = score * (1 / (1 - visible_rate)).view(-1, 1, 1)
        loss = score[masked].sum() / masked.sum()

        return loss

In [6]:
# Instantiate the training wrapper from the combined world/network config.
mf = MaskedField(cfg=cfg)

In [7]:
# Smoke test on random data. The shape follows field_layout explicitly instead of
# relying on the insertion order of the field_sizes dict.
data = torch.randn((mf.batch_size, *(mf.field_sizes[ax] for ax in mf.field_layout)))
loss = mf(data)
loss