In [48]:
import math
from dataclasses import dataclass
from typing import Optional

import einops
import torch

from utils.config import NetworkConfig
from utils.loss_fn import f_kernel_crps

In [49]:
def exists(val):
    """Return True for any value except ``None`` (falsy values such as 0 or '' count as existing)."""
    return False if val is None else True

def default(val, d):
    """Return ``val``, substituting the fallback ``d`` only when ``val`` is ``None``."""
    if val is None:
        return d
    return val

In [50]:
@dataclass
class WorldConfig:
    """Geometry, batching and masking settings consumed by ``MaskedField``.

    Attributes:
        field_layout: axis names of the unpatched field, e.g. ('v', 't', 'h', 'w').
        patch_layout: within-patch axis names, paired position-wise with
            ``field_layout``, e.g. ('vv', 'tt', 'hh', 'ww').
        field_sizes: full length of every field axis, keyed by field axis name.
        patch_sizes: patch length along every axis, keyed by patch axis name.
        batch_size: batch size the masks and index tensors are built for.
        tau: number of leading *time tokens* (patches along 't') kept visible
            in forecast mode.
        alphas: Dirichlet concentration per axis name, used to bias the random
            mask along that axis.
        mask_range: (low, high) bounds for the number of visible tokens in the
            random mask; ``high`` is exclusive (passed to ``torch.randint``).
        loss_weights: optional per-variable loss weights, one entry per element
            of the unpatched 'v' axis; ``None`` means none were configured.
            Declared here because ``MaskedField.loss_weights`` reads it.
    """
    field_layout: tuple
    patch_layout: tuple
    field_sizes: dict
    patch_sizes: dict
    batch_size: int
    tau: int
    alphas: dict
    mask_range: tuple
    loss_weights: Optional[tuple] = None  # trailing + defaulted: positional construction unchanged

@dataclass
class ConfigWrapper:
    """Bundles the world (data/masking) settings with the network settings."""

    world: WorldConfig      # read by MaskedField as ``cfg.world``
    network: NetworkConfig  # handed to WeatherField as ``cfg.network``

In [51]:
# Network hyper-parameters. ``wavelengths`` holds one (min, max) range per entry;
# presumably one per coordinate axis fed to the positional embedding (4 here).
model_cfg = NetworkConfig(
    dim=256,
    num_layers=12,
    dim_in=4,
    dim_out=96,
    dim_coords=128,
    wavelengths=[(1e-3, 1e2)] * 4,
    num_features=8,
    num_latents=64,
)

# Field geometry, batching and masking settings.
world_cfg = WorldConfig(
    field_layout=('v', 't', 'h', 'w'),
    patch_layout=('vv', 'tt', 'hh', 'ww'),
    field_sizes=dict(v=8, t=36, h=64, w=120),
    patch_sizes=dict(vv=1, tt=6, hh=8, ww=10),
    batch_size=8,
    tau=2,                  # leading time tokens visible in forecast mode
    mask_range=(64, 4096),  # [low, high) number of visible tokens in prior mode
    alphas=dict(t=3.0),     # Dirichlet prior concentration parameters
)

cfg = ConfigWrapper(world=world_cfg, network=model_cfg)

In [None]:
from utils.components import *

class WeatherField(torch.nn.Module):
    """Latent-bottleneck network that predicts every token from a visible subset.

    A learned set of ``cfg.num_latents`` latents is refined by ``cfg.num_layers``
    encoder blocks, each given the latents and the concatenation of the embedded
    visible tokens with the latents. A single decoder block then updates the
    positional code of *every* token given the latents. A modality-wise
    projection maps the result to ``cfg.dim_out``. ``TransformerBlock``,
    ``SegmentLinear`` and ``ContinuousPositionalEmbedding`` come from
    ``utils.components``; their internals are not visible here.
    """

    def __init__(self, cfg):
        super().__init__()
        # embeddings
        self.latent_embedding = torch.nn.Embedding(cfg.num_latents, cfg.dim)
        self.world_embedding = ContinuousPositionalEmbedding(cfg.dim_coords, cfg.wavelengths, cfg.dim)

        # modality-wise linear projections (segment chosen per token at call time)
        self.proj_in = SegmentLinear(cfg.dim_in, cfg.dim, cfg.num_features)
        self.proj_out = SegmentLinear(cfg.dim, cfg.dim_out, cfg.num_features)

        # Transformer blocks
        self.encoder = torch.nn.ModuleList([TransformerBlock(cfg.dim, dim_heads=cfg.dim_heads) for _ in range(cfg.num_layers)])
        self.decoder = TransformerBlock(cfg.dim, dim_heads=cfg.dim_heads, has_skip=False)

        # Initialization
        self.apply(self.base_init)

    @staticmethod
    def base_init(m):
        """Initialise one submodule in place (applied recursively via ``self.apply``).

        Linear/Embedding weights use a truncated normal with std from
        ``get_weight_std``. Biases are zeroed and LayerNorm affine parameters
        are reset to identity. The conditioning projection of
        ``ConditionalLayerNorm`` is set to ~0 (std 1e-8), presumably so the
        conditioning starts out inactive.
        """
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std = get_weight_std(m.weight))
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

        if isinstance(m, torch.nn.Embedding):
            torch.nn.init.trunc_normal_(m.weight, std = get_weight_std(m.weight))

        if isinstance(m, torch.nn.LayerNorm):
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
            if m.weight is not None:
                torch.nn.init.ones_(m.weight)

        if isinstance(m, ConditionalLayerNorm) and m.linear is not None:
            torch.nn.init.trunc_normal_(m.linear.weight, std = 1e-8)

    def forward(self, tokens, visible, coordinates):
        """Predict all token positions from the visible tokens.

        Args:
            tokens: token values indexed as ``tokens[b, n]``; assumed
                (batch, num_tokens, dim_in) — TODO confirm.
            visible: (batch, k) integer indices of the tokens the encoder sees.
            coordinates: per-token integer coordinates indexed ``[b, n, axis]``;
                column 0 selects the modality-wise projection segment.

        Returns:
            Output of ``proj_out`` for every token position (one row per entry
            of ``coordinates``).
        """
        n_batch = tokens.size(0)
        # (batch, 1) so it broadcasts against ``visible`` (batch, k) and row b
        # only gathers tokens of sample b. A (batch, batch) grid would fail to
        # broadcast, or silently mix samples when k == batch.
        batch = torch.arange(n_batch, device=tokens.device).unsqueeze(-1)
        modality = coordinates[..., 0]  # segment index for the modality-wise linear layers

        # positional embedding for all available coordinates
        world = self.world_embedding(coordinates)

        # embed visible values and add their positional code
        src = self.proj_in(tokens[batch, visible], modality[batch, visible])
        src = src + world[batch, visible]

        # refine the shared latents given the visible tokens and the latents themselves
        latents = self.latent_embedding.weight.expand(n_batch, -1, -1)
        for block in self.encoder:
            kv = torch.cat([src, latents], dim=1)
            latents = block(latents, kv)

        # update every token position given the latents, then project per modality
        out = self.decoder(world, latents)
        return self.proj_out(out, modality)

class MaskedField(torch.nn.Module):
    """Masked-token training wrapper around ``WeatherField``.

    Patches a field into a "flatland" token layout and chooses which tokens the
    network may see: a Dirichlet-biased random mask in ``'prior'`` mode, or a
    leading-time-window mask otherwise. The loss from ``f_kernel_crps`` is
    evaluated on the hidden tokens only.
    """

    def __init__(self, cfg):
        super().__init__()
        self._cfg = cfg.world
        # NOTE(review): nothing in this class updates these when the module is
        # moved with ``.to(...)``; callers must set them explicitly.
        self.device_type: str = 'cpu'
        self.device: torch.device = torch.device('cpu')
        self.network = WeatherField(cfg.network)

    ### SHAPES
    @property
    def field_layout(self):
        """Arrangement of axes in the unpatched field, e.g. ('v', 't', 'h', 'w')."""
        return self._cfg.field_layout

    @property
    def patch_layout(self):
        """Arrangement of within-patch axes, paired position-wise with ``field_layout``."""
        return self._cfg.patch_layout

    @property
    def batch_size(self):
        """Batch size the masks and index tensors are built for."""
        return self._cfg.batch_size

    @property
    def field_sizes(self):
        """Field length per axis, from the config."""
        return self._cfg.field_sizes

    @property
    def patch_sizes(self):
        """Patch length per patch axis, from the config."""
        return self._cfg.patch_sizes

    @property
    def token_sizes(self):
        """Number of patches (tokens) along every field axis.

        Each field axis is paired with the patch axis at the same position in
        ``patch_layout``, the same pairing ``field_pattern`` uses.
        """
        return {
            ax: self.field_sizes[ax] // self.patch_sizes[p]
            for ax, p in zip(self.field_layout, self.patch_layout)
        }

    @property
    def token_shape(self):
        """Patch counts as a tuple ordered like ``field_layout``."""
        return tuple(self.token_sizes[ax] for ax in self.field_layout)

    @property
    def num_tokens(self):
        """Total number of patches in the flatland representation."""
        return math.prod(self.token_shape)

    @property
    def dim_tokens(self):
        """Total number of values in a single patch."""
        return math.prod(self.patch_sizes[p] for p in self.patch_layout)

    @property
    def batch_index(self):
        """(batch_size, 1) batch indices for fancy indexing along the batch dimension."""
        return torch.arange(self.batch_size, device = self.device).view(-1, 1)

    @property
    def coordinate_index(self):
        """Per-axis patch indices of every token: (batch_size, num_tokens, len(field_layout))."""
        flat = torch.arange(self.num_tokens, device = self.device).expand(self.batch_size, -1)
        return torch.stack(torch.unravel_index(flat, self.token_shape), dim=-1)

    @property
    def field_pattern(self):
        """einops pattern of the unpatched field, e.g. 'b (v vv) (t tt) (h hh) (w ww) ...'."""
        field = " ".join(f"({f} {p})" for f, p in zip(self.field_layout, self.patch_layout))
        return f"b {field} ..."

    @property
    def flat_token_pattern(self):
        """einops group for the flattened token dimension, e.g. '(v t h w)'."""
        # No re-use of the enclosing quote inside the f-string: that is a
        # SyntaxError before Python 3.12 (PEP 701).
        return f"({' '.join(self.field_layout)})"

    @property
    def flat_patch_pattern(self):
        """einops group for the flattened patch dimension, e.g. '(vv tt hh ww)'."""
        return f"({' '.join(self.patch_layout)})"

    @property
    def flatland_pattern(self):
        """einops pattern of the patched flatland representation."""
        return f'b {self.flat_token_pattern} {self.flat_patch_pattern} ...'

    ### TOKENIZATION
    def field_to_tokens(self, field):
        """Cut ``field`` (laid out as ``field_pattern``) into flatland tokens."""
        return einops.rearrange(field, f'{self.field_pattern} -> {self.flatland_pattern}', **self.patch_sizes)

    def tokens_to_field(self, patch):
        """Inverse of ``field_to_tokens``."""
        return einops.rearrange(patch, f"{self.flatland_pattern} -> {self.field_pattern}", **self.token_sizes, **self.patch_sizes)

    ### MASKING
    def gumbel_noise(self, shape: tuple, eps: float = 1e-6):
        """Standard Gumbel samples of ``shape``; uniforms are clamped to avoid log(0)."""
        u = torch.rand(shape, device = self.device).clamp(min=eps, max=1-eps)
        return -torch.log(-torch.log(u))

    def dirichlet_marginal(self, ax: str):
        """Per-sample Dirichlet log-probabilities along axis ``ax``, broadcast to tokens.

        Returns a (batch_size, num_tokens) tensor in which every token that shares
        the same ``ax`` index receives the same log-probability.
        """
        concentration = torch.full(
            (self.batch_size, self.token_sizes[ax]),
            float(self._cfg.alphas[ax]),  # float: Dirichlet needs a floating concentration
            device=self.device,
        )
        probs = torch.distributions.Dirichlet(concentration).sample()
        # ``b`` must appear on both sides: einops.repeat cannot drop an axis.
        probs = einops.repeat(probs, f'b {ax} -> b {self.flat_token_pattern}', **self.token_sizes)
        return probs.log()

    def get_frcst_mask(self):
        """Forecast mask: tokens of the first ``tau`` time steps are visible, the rest hidden.

        Returns:
            ``(visible, masked)`` token indices of shape (batch_size, k) and
            (batch_size, num_tokens - k), with k = num_tokens / n_time_tokens * tau.
        """
        n_time = self.token_sizes['t']
        # exact integer arithmetic: num_tokens is a multiple of n_time
        k = self.num_tokens // n_time * self._cfg.tau
        step = torch.zeros((n_time,), device=self.device)
        step[:self._cfg.tau] = float('inf')  # context steps sort to the front
        step = einops.repeat(step, f't -> b {self.flat_token_pattern}', **self.token_sizes, b=self.batch_size)
        step = step.argsort(dim = -1, descending=True)
        return step[..., :k], step[..., k:]

    def get_random_mask(self):
        """Random mask biased by per-axis Dirichlet priors (Gumbel-top-k).

        The number of visible tokens k is drawn uniformly from
        ``[mask_range[0], mask_range[1])``. Gumbel noise is added to the
        Dirichlet log-weights of every axis in ``alphas``, and the k
        highest-scoring tokens per sample are kept visible.

        Returns:
            ``(visible, masked)`` token indices of shape (batch_size, k) and
            (batch_size, num_tokens - k).
        """
        low, high = self._cfg.mask_range
        # plain int slice bound instead of slicing with a 1-element tensor
        k = int(torch.randint(low, high, (1,)).item())
        factors = [self.gumbel_noise((self.batch_size, self.num_tokens))]
        factors += [self.dirichlet_marginal(ax) for ax in self._cfg.alphas]
        score = torch.stack(factors, dim=0).sum(dim=0)
        order = score.argsort(dim = -1, descending=True)
        return order[..., :k], order[..., k:]

    ### FORWARD
    @property
    def land_sea_mask(self):
        """Land-sea weights in flatland layout, shape (1, num_tokens, dim_tokens).

        Currently all ones (TODO: load a real mask). The (h, w) grid is repeated
        over the other field axes; the leading singleton broadcasts over the batch.
        """
        lsm = torch.ones((self.field_sizes['h'], self.field_sizes['w']), device = self.device)
        lsm = einops.repeat(
            lsm,
            f'(h hh) (w ww) -> {self.flat_token_pattern} {self.flat_patch_pattern}',
            **self.token_sizes, **self.patch_sizes,
        )
        return lsm.unsqueeze(0)

    @property
    def loss_weights(self):
        """Per-variable loss weights in flatland layout, shape (1, num_tokens, dim_tokens).

        Uses ``loss_weights`` from the world config (one weight per entry of the
        unpatched 'v' axis). If the config has no such attribute, or it is None,
        uniform weights of 1 are used.
        """
        weights = getattr(self._cfg, 'loss_weights', None)
        if weights is None:
            w = torch.ones(self.field_sizes['v'], device=self.device)
        else:
            w = torch.as_tensor(weights, dtype=torch.float32, device=self.device)
        w = einops.repeat(
            w,
            f'(v vv) -> {self.flat_token_pattern} {self.flat_patch_pattern}',
            **self.token_sizes, **self.patch_sizes,
        )
        return w.unsqueeze(0)

    def forward(self, data, mode: str = 'prior'):
        """Masked-reconstruction loss for one batch.

        Args:
            data: field laid out as ``field_pattern``. Its batch size is assumed
                to equal ``batch_size``, because masks and ``batch_index`` are
                built for that size — TODO confirm with callers.
            mode: ``'prior'`` uses ``get_random_mask``; any other value uses
                ``get_frcst_mask``.

        Returns:
            Scalar mean of the re-weighted ``f_kernel_crps`` loss over masked tokens.
        """
        data = data.to(self.device)
        tokens = self.field_to_tokens(data)
        visible, masked = self.get_random_mask() if mode == 'prior' else self.get_frcst_mask()

        pred = self.network(tokens, visible, self.coordinate_index)

        # re-scale loss by variable and land-sea; the (1, N, D) weights assume
        # f_kernel_crps returns (batch, num_tokens, dim_tokens) — TODO confirm
        loss = f_kernel_crps(tokens, pred) * self.loss_weights * self.land_sea_mask
        loss = loss[self.batch_index, masked]  # only score masked tokens

        return loss.mean()

In [59]:
# Build the masking wrapper (and its WeatherField network) from the bundled config.
mf = MaskedField(cfg=cfg)

In [60]:
# Smoke test: random tokens with the network's input width plus one random mask.
torch.manual_seed(0)  # reproducible data and mask across Restart & Run All
data = torch.randn((mf.batch_size, mf.num_tokens, cfg.network.dim_in))
mask, complement = mf.get_random_mask()
assert mask.shape[0] == complement.shape[0] == mf.batch_size
assert mask.shape[1] + complement.shape[1] == mf.num_tokens



In [61]:
# Run the bare network: encode the visible tokens, decode every position.
pred = mf.network(tokens=data, visible=mask, coordinates=mf.coordinate_index)