In [9]:
import torch
import einops
from utils.components import *
from einops.layers.torch import EinMix
from utils.masking import *
from utils.einmask import EinMask

In [10]:
from utils.config import *
from omegaconf import OmegaConf
from dataclasses import replace, dataclass
from math import prod

from utils.components import *
from utils.config import *
from utils.random_fields import RandomField

In [11]:
# Experiment configuration objects (all defined in utils.config).
# First dict: sizes keyed by axis name 'v', 't', 'h', 'w'; these match the data shape printed further down,
# so they appear to be the raw field extents.
# Second dict: keyed 'vv', 'tt', 'hh', 'ww' -- presumably per-axis patch sizes; TODO confirm against WorldConfig.
world = WorldConfig({'v': 6, 't':36,'h': 64, 'w':120}, {'vv':6, 'tt':6,'hh':4, 'ww':4}, batch_size=4)
# Objective settings left at their defaults; the masking classes below read tau, c0, c1, alpha, k_min, k_max
# and event_dims from this object.
objective = ObjectiveConfig()
# NOTE(review): the meaning of the three positional values (512, 8, 256) is not visible here -- confirm
# against the NetworkConfig definition.
network = NetworkConfig(512, 8, 256, dim_noise=32, num_tails=1)

In [None]:
import torch
import einops
from utils.config import WorldConfig, ObjectiveConfig, default
    
class ForecastMasking(torch.nn.Module):
    '''Deterministic masking that marks the first `objective.tau` steps of the 't' axis.

    A per-frame 0/1 vector is built (1 for the leading `tau` frames) and repeated onto
    `world.flat_token_pattern` with einops. Depending on `return_indices`, the forward pass
    returns either the positions flagged 1 or the logical negation of the flags as a bool mask.
    '''
    def __init__(self, world: WorldConfig, objective: ObjectiveConfig):
        super().__init__()
        self.world = world
        self.objective = objective
        # Only the time axis drives this mask
        self.event_pattern = 't'
        # Frame counts are kept as long buffers so they follow the module across devices
        self.register_buffer("prefix_frames", torch.tensor(objective.tau, dtype = torch.long))
        self.register_buffer("total_frames", torch.tensor(world.token_sizes["t"], dtype = torch.long))

    def forward(self, shape: tuple, return_indices: bool = True):
        '''Build the prefix mask and expand it over the leading `shape` dimensions.

        Args:
            shape: leading (batch) dimensions to expand the result to.
            return_indices: when True return the flagged positions, otherwise a bool mask.

        Returns:
            If `return_indices`: positions where the flag is 1, expanded to `(*shape, -1)`.
            Otherwise: a bool tensor expanded to `(*shape, -1)` that is False on the prefix
            and True everywhere else.
        '''
        # 1 on the leading prefix frames, 0 on the rest
        frame_flags = torch.zeros((self.total_frames,), device = self.total_frames.device)
        frame_flags[:self.prefix_frames] = 1

        # Repeat the per-frame flags onto the flattened token layout
        repeat_pattern = f'{self.event_pattern} -> {self.world.flat_token_pattern}'
        token_flags = einops.repeat(frame_flags, repeat_pattern, **self.world.token_sizes)

        if not return_indices:
            # Boolean view with the prefix set to False
            return token_flags.expand(*shape, -1).bool().logical_not()

        prefix_positions = token_flags.nonzero(as_tuple=True)[0]
        return prefix_positions.expand(*shape, -1)
    
# KUMARASWAMY DISTRIBUTION
class StableKumaraswamy(torch.nn.Module):
    '''Kumaraswamy(c1, c0) sampler evaluated in log-space for numerical stability
    (methods courtesy of Wasserman et al 2024).

    Args:
        c1: first concentration parameter, must be > 0.
        c0: second concentration parameter, must be > 0.
        epsilon: margin that keeps the uniform draws inside [epsilon, 1 - epsilon].
    '''
    def __init__(self, c1: float = 1., c0: float = 1., epsilon=1e-3):
        super().__init__()
        assert c1 > 0. and c0 > 0., 'invalid concentration'

        # Hyperparameters live in buffers so they move with the module across devices
        for buffer_name, value in (("c1", c1), ("c0", c0), ("epsilon", epsilon)):
            self.register_buffer(buffer_name, torch.as_tensor(value))

    @staticmethod
    def log1mexp(t: torch.FloatTensor):
        '''Numerically stable log(1 - exp(t)), switching formula at t = -log(2).'''
        below_threshold = torch.log1p(-torch.exp(t))   # used where t < -log(2)
        above_threshold = torch.log(-torch.expm1(t))   # used elsewhere
        return torch.where(t < -0.69314718, below_threshold, above_threshold)

    def quantile_dt(self, t: torch.Tensor):
        '''Derivative of the quantile function with respect to t, assembled in log-space.

        Closed form: (1 - t)**((1 - c0) / c0) * (1 - (1 - t)**(1 / c0))**((1 - c1) / c1) / (c1 * c0)
        '''
        log_survival = torch.log1p(-t)                          # log(1 - t)
        log_constant = -self.c1.log() - self.c0.log()           # log(1 / (c1 * c0))
        inner_power = (1 - self.c1) / self.c1
        outer_power = (1 - self.c0) / self.c0
        log_inner = inner_power * self.log1mexp(log_survival / self.c0)
        log_outer = log_survival * outer_power
        return torch.exp(log_constant + log_inner + log_outer)

    def quantile(self, t: torch.Tensor):
        '''Quantile function: (1 - (1 - t)**(1 / c0))**(1 / c1).'''
        log_inner = self.log1mexp(torch.log1p(-t) / self.c0)
        return torch.exp(log_inner / self.c1)

    def cdf(self, t: torch.Tensor):
        '''Cumulative distribution function: 1 - (1 - t**c1)**c0.'''
        log_one_minus_power = self.log1mexp(self.c1 * t.log())
        return -torch.expm1(self.c0 * log_one_minus_power)

    def forward(self, shape: tuple, rng: torch.Generator = None):
        '''Draw samples of the given shape.

        Returns:
            A pair (samples, derivative): the quantile function and its t-derivative, both
            evaluated at the same uniform draws rescaled into [epsilon, 1 - epsilon].
        '''
        uniform = torch.rand(shape, device=self.epsilon.device, generator=rng)
        # Shrink the unit interval away from 0 and 1, where the logs diverge
        uniform = uniform * (1.0 - 2.0 * self.epsilon) + self.epsilon
        return self.quantile(uniform), self.quantile_dt(uniform)

# MASKING STRATEGIES
class MultinomialMasking(torch.nn.Module):
    '''Random masking: picks token positions by weighted sampling without replacement.

    Per-event weights come from a Kumaraswamy prior (`objective.alpha`), are repeated onto
    `world.flat_token_pattern`, and a second Kumaraswamy schedule (`objective.c0`, `objective.c1`)
    decides how many positions are drawn on each call.
    '''
    def __init__(self, world: WorldConfig, objective: ObjectiveConfig):
        super().__init__()
        # configs
        self.world = world
        self.objective = objective

        # schedule: controls the number of sampled tokens; prior: per-event sampling weights
        self.schedule = StableKumaraswamy(c0=objective.c0, c1=objective.c1)
        self.prior = StableKumaraswamy(c1= objective.alpha)

        # attributes: bounds on the number of sampled tokens
        # (presumably `default` returns its second argument when the first is None -- verify in utils.config)
        self.k_min = default(objective.k_min, 1)
        self.k_max = default(objective.k_max, world.num_tokens)

        # Events: the axes whose product defines how many independent weights are drawn
        assert all(d in world.flat_token_pattern for d in objective.event_dims), 'event dims not in token pattern'
        self.num_events = torch.tensor([world.token_sizes[d] for d in objective.event_dims]).prod()
        self.event_pattern = f'({" ".join(objective.event_dims)})'

    def forward(self, shape: tuple, rng: torch.Generator = None):
        '''Sample token positions and the matching boolean mask.

        Args:
            shape: leading (batch) dimensions; may be empty or hold several dimensions.
            rng: optional generator shared by every random draw in this call.

        Returns:
            indices: sampled token positions, shape `(*shape, k)`.
            mask: bool tensor shaped like the per-token weights, False at the sampled
                positions and True elsewhere.
            w: derivative of the schedule quantile at the draw that set `k`, shape `(1,)`.
        '''
        shape = tuple(shape)

        # Per-event weights, repeated so every token inherits the weight of its event
        p, _ = self.prior((*shape, self.num_events), rng)
        p = einops.repeat(p, f'... {self.event_pattern} -> ... {self.world.flat_token_pattern}',
                             **self.world.token_sizes)

        # One schedule draw per call sets the sample count (truncated toward zero)
        r, w = self.schedule((1,), rng)
        k = int(self.k_min + (self.k_max - self.k_min) * r)

        # torch.multinomial accepts only 1-D/2-D input: flatten the leading dims, sample,
        # then restore them. For a 1-tuple `shape` this is the same call as sampling on p directly.
        num_tokens = p.shape[-1]
        indices = torch.multinomial(p.reshape(-1, num_tokens), k, generator=rng)
        indices = indices.reshape(*shape, k)

        # Scatter along the token (last) axis so any number of leading dims is supported
        mask = torch.ones_like(p, dtype= torch.bool).scatter_(-1, indices, False)
        return indices, mask, w

In [13]:
# Everything in this notebook runs on the CPU
device = 'cpu'
# Random mask sampler defined above; .to(device) moves the Kumaraswamy buffers it owns
ks = MultinomialMasking(world=world, objective=objective).to(device)
# Model under test, built from the network and world configs (EinMask comes from utils.einmask)
model = EinMask(network, world).to(device)

In [14]:
# Gaussian-noise stand-in for a batch of fields: batch axis first, then one axis per
# entry of world.field_layout, sized from world.field_sizes
data = torch.randn(
    (world.batch_size, *(world.field_sizes[ax] for ax in world.field_layout)),
    device=device,
)
print(data.shape)

torch.Size([4, 6, 36, 64, 120])


In [15]:
# Two independent draws from the sampler; each call returns (indices, mask, w).
# First draw: keep the sampled indices as `src` and the schedule weight as `weight`.
src, _, weight = ks((world.batch_size,))
# Second draw: keep the sampled indices as `tgt` and the bool mask (False at the sampled positions).
tgt, mask, _ = ks((world.batch_size,))

In [16]:
# Forward pass with the two index sets sampled above.
# NOTE(review): `members = 2` presumably requests two ensemble members -- confirm against EinMask.forward.
xs = model(data, src, tgt, members = 2)