In [None]:
import torch
import einops
import math

from dataclasses import dataclass


from utils.config import NetworkConfig
from utils.model import WeatherField

In [55]:
@dataclass
class TrainerProxy:
    """Plain container for layout, size and device settings.

    NOTE(review): judging by the name this presumably stands in for a real
    trainer object while experimenting in the notebook — confirm.

    Attributes:
        field_layout: tuple of axis names (presumably for the raw field — confirm).
        patch_layout: tuple of axis names (presumably for the per-patch axes — confirm).
        token_layout: tuple of axis names (presumably for the token grid — confirm).
        field_sizes: dict of sizes (presumably keyed by ``field_layout`` names — confirm).
        patch_sizes: dict of sizes (presumably keyed by ``patch_layout`` names — confirm).
        device_type: device kind as a string, ``'cpu'`` by default.
        device: ``torch.device`` object, CPU by default.
        generator: random generator. When omitted (or ``None``), each instance
            receives its own freshly created CPU generator.
    """
    field_layout: tuple
    patch_layout: tuple
    token_layout: tuple
    field_sizes: dict
    patch_sizes: dict
    device_type: str = 'cpu'
    # torch.device is the runtime class of device objects (torch.DeviceObjType
    # is the TorchScript type descriptor, not what torch.device('cpu') returns).
    device: torch.device = torch.device('cpu')
    # A Generator carries mutable RNG state, so it must not be created once at
    # class-definition time and shared by every instance; None is a sentinel
    # that __post_init__ replaces with a per-instance generator.
    generator: torch.Generator = None

    def __post_init__(self):
        """Create a private CPU generator when the caller did not supply one."""
        if self.generator is None:
            self.generator = torch.Generator()

def exists(val):
    """Report whether ``val`` is anything other than ``None``.

    Falsy values such as ``0``, ``''`` or ``[]`` still count as existing.
    """
    if val is None:
        return False
    return True

# Network hyper-parameters; passed to MaskedField when the model is built.
# NOTE(review): dim_in / dim_out = 96 presumably equals the number of values
# per token patch (1 * 6 * 4 * 4) — confirm against MaskedField.patch_sizes.
model_cfg = NetworkConfig(
    dim=512,
    num_layers=12,
    dim_in=96,
    dim_out=96,
    # One identical (min, max) pair per entry, three entries in total.
    wavelengths=[(1e-3, 1e2)] * 3,
    num_features=8,
    num_latents=256,
)

In [74]:
class MaskedField(torch.nn.Module):
    """Wraps a ``WeatherField`` network with patching and mask-sampling helpers.

    A field with axes ``(v, t, h, w)`` is cut into patches with per-axis sizes
    ``(vv, tt, hh, ww)``. The resulting token grid (called "flatland" here)
    has ``field_size // patch_size`` entries per axis and is flattened into a
    single token axis, with the patch contents flattened into a second axis.
    """

    def __init__(self, cfg):
        """Build the wrapped network from ``cfg``; everything defaults to CPU."""
        super().__init__()
        self.device_type: str = 'cpu'
        self.device: torch.device = torch.device('cpu')
        self.network = WeatherField(cfg)

    ### SHAPES
    @property
    def batch_size(self):
        """Number of samples per batch (fixed)."""
        return 8

    @property
    def field_layout(self):
        """Axis names of the un-patched field, outermost first."""
        return ('v', 't', 'h', 'w')

    @property
    def flatland_layout(self):
        """Axis names of the token grid, outermost first."""
        return ('v', 't', 'h', 'w')

    @property
    def patch_layout(self):
        """Axis names inside one patch, positionally paired with ``field_layout``."""
        return ('vv', 'tt', 'hh', 'ww')

    @property
    def field_sizes(self):
        """Length of every field axis."""
        return {'v': 8, 't': 36, 'h': 64, 'w': 120}

    @property
    def patch_sizes(self):
        """Length of every patch axis."""
        return {'vv': 1, 'tt': 6, 'hh': 4, 'ww': 4}

    @property
    def flatland_sizes(self):
        """Number of patches along every token-grid axis (field size // patch size)."""
        # Pair token and patch axes by position, exactly as field_pattern does,
        # instead of relying on patch names being the doubled axis letter.
        return {
            ax: self.field_sizes[ax] // self.patch_sizes[patch_ax]
            for ax, patch_ax in zip(self.flatland_layout, self.patch_layout)
        }

    @property
    def field_pattern(self):
        """einops pattern of the un-patched field: ``... (v vv) (t tt) (h hh) (w ww)``."""
        field = " ".join(f"({f} {p})" for f, p in zip(self.field_layout, self.patch_layout))
        return f"... {field}"

    @property
    def flatland_pattern(self):
        """einops pattern of the tokenised field: ``... (v t h w) (vv tt hh ww)``."""
        tokens = ' '.join(self.flatland_layout)
        patches = ' '.join(self.patch_layout)
        return f'... ({tokens}) ({patches})'

    ### SHAPE MANIPULATIONS
    def field_to_patch(self, field):
        """Rearrange ``field`` from ``field_pattern`` to ``flatland_pattern``.

        The last four axes are split into (token, patch) factors using
        ``patch_sizes``; leading axes are left untouched.
        """
        return einops.rearrange(field, f'{self.field_pattern} -> {self.flatland_pattern}', **self.patch_sizes)

    def patch_to_field(self, patch):
        """Inverse of ``field_to_patch``: ``flatland_pattern`` back to ``field_pattern``."""
        return einops.rearrange(patch, f"{self.flatland_pattern} -> {self.field_pattern}", **self.flatland_sizes, **self.patch_sizes)

    def get_flatland_dims(self):
        """Return ``(B, N, D)``: batch size, token count, values per patch."""
        B = self.batch_size
        N = math.prod([self.flatland_sizes[t] for t in self.flatland_layout])
        D = math.prod([self.patch_sizes[p] for p in self.patch_layout])
        return B, N, D

    def get_token_dims(self):
        """Return the token-grid size along every ``flatland_layout`` axis, in order."""
        return tuple(self.flatland_sizes[ax] for ax in self.flatland_layout)

    ### COORDINATES
    @staticmethod
    def compute_strides(layout: tuple, sizes: dict):
        """Row-major strides for ``layout``: the last axis gets stride 1.

        Args:
            layout: axis names, outermost first.
            sizes: axis name -> axis length.

        Returns:
            dict mapping every axis name to its flat-index stride.
        """
        strides = {}
        stride = 1
        for ax in reversed(layout):
            strides[ax] = stride
            stride *= sizes[ax]
        return strides

    @staticmethod
    def index_to_coords(idx: torch.LongTensor, layout: tuple, strides: dict):
        """Convert flat indices to per-axis coordinates.

        Axes are peeled off outermost first: floor-divide by the stride, then
        carry the remainder to the next axis. ``fmod`` is used for the
        remainder, so the decomposition is meant for non-negative indices.

        Returns:
            Tensor shaped ``idx.shape + (len(layout),)``.
        """
        coords = []
        for ax in layout:
            val = idx.div(strides[ax], rounding_mode="floor")
            idx = idx.fmod(strides[ax])
            coords.append(val)
        return torch.stack(coords, dim=-1)

    @staticmethod
    def coords_to_index(coords: torch.LongTensor, layout: tuple, strides: dict):
        """Convert per-axis coordinates (last dim ordered as ``layout``) to flat indices."""
        return sum(coords[..., i] * strides[ax] for i, ax in enumerate(layout))

    ### SAMPLING
    def gumbel_topk(self, phi: torch.Tensor, k: int):
        """Pick ``k`` indices along the last dim of ``phi`` via the Gumbel top-k trick.

        Independent Gumbel(0, 1) noise ``-log(-log(u))`` is added to the scores
        before taking the top ``k``. The returned indices are not sorted.
        """
        noise = torch.rand_like(phi)
        gumbel = phi - torch.log(-torch.log(noise))
        return gumbel.topk(k, sorted=False).indices

    @staticmethod
    def as_logit(p: torch.Tensor):
        """Log-odds ``log(p) - log1p(-p)``; yields -inf / +inf at ``p == 0`` / ``p == 1``."""
        return torch.log(p) - torch.log1p(-p)

    def trunc_normal_rate(self, mean, std, a, b):
        """Draw one value from a normal(mean, std) truncated to ``[a, b]``.

        Returns:
            Tensor of shape ``(1,)`` on ``self.device``.
        """
        rate = torch.empty((1,), device=self.device)
        return torch.nn.init.trunc_normal_(rate, mean, std, a, b)

    ### MASKING
    def mask_prior(self):
        """Sample source and target token selections for one batch.

        Per batch element, a Dirichlet sample over the ``t`` axis and one over
        the ``v`` axis are turned into logits, combined into a ``(B, V, T)``
        score and broadcast over ``h`` and ``w`` to a score per token. The
        number of selected tokens is a truncated-normal fraction of ``N``.

        Returns:
            ``(src_mask, tgt_mask)``: index tensors (not boolean masks) of shape
            ``(B, k_src)`` and ``(B, k_tgt)`` into the flattened token axis.
            The two selections are drawn independently and may overlap.
        """
        B, N, _ = self.get_flatland_dims()
        V, T, H, W = self.get_token_dims()

        # Create the concentrations on self.device so they live on the same
        # device as the rates drawn by trunc_normal_rate.
        alpha_t = torch.full((B, T), 0.5, device=self.device)
        alpha_v = torch.full((B, V), 5., device=self.device)
        # Public Dirichlet API instead of the private torch._sample_dirichlet;
        # the time sample is drawn first, then the variable sample.
        logit_t, logit_v = (
            self.as_logit(torch.distributions.Dirichlet(alpha).sample())
            for alpha in (alpha_t, alpha_v)
        )

        # NOTE(review): this is an outer *product* of two log-odds tensors, which
        # is not the log-odds of a joint probability — confirm this is intended.
        joint = torch.einsum('b t, b v -> b v t', logit_t, logit_v)
        joint = einops.repeat(joint, 'b v t -> b (v t h w)', h=H, w=W)

        k_src = int(self.trunc_normal_rate(0.2, 0.1, 0.05, 0.3) * N)
        k_tgt = int(self.trunc_normal_rate(0.5, 0.1, 0.4, 0.6) * N)

        src_mask = self.gumbel_topk(joint, k_src)
        # Top-k is unaffected by a constant offset, so scoring with `1 - joint`
        # ranks tokens exactly like `-joint`: targets favour low-score tokens.
        tgt_mask = self.gumbel_topk(1 - joint, k_tgt)

        return src_mask, tgt_mask
        
        


In [75]:
# Build the masking wrapper (and the WeatherField network inside it) from model_cfg.
mf = MaskedField(model_cfg)

In [76]:
# Draw one random selection: `s` holds the source token indices and `t` the
# target token indices returned by mask_prior (index tensors, not boolean masks).
s, t = mf.mask_prior()

In [77]:
# Shape is (batch, k_src); k_src is drawn at random inside mask_prior, so the
# second dimension changes from run to run.
s.shape

torch.Size([8, 4065])