In [6]:
import torch
import einops

from utils.objective import AnyOrder_RIN
from utils.config import *
from utils.loss_fn import f_kernel_crps

In [2]:
# Data-side configuration: sizes of the field axes (v, t, h, w) and the patch
# sizes (vv, tt, hh, ww).
# NOTE(review): assuming vv/tt/hh/ww are the patch extents along v/t/h/w, each
# field axis divides evenly by its patch (4/4, 36/6, 64/4, 120/4) — confirm
# against WorldConfig.
world = WorldConfig(
    field_sizes=dict(v=4, t=36, h=64, w=120),
    patch_sizes=dict(vv=4, tt=6, hh=4, ww=4),
    batch_size=2,
    tau=2,
)

# Network-side configuration.
# NOTE(review): these values look sized for a quick CPU run rather than a real
# training job — confirm before reusing them elsewhere.
model = NetworkConfig(
    dim=128,
    num_layers=4,
    num_compute_blocks=4,
    num_latents=256,
    dim_noise=32,
)


In [3]:
# Build the objective from the two configs above; everything runs on the CPU.
ao_rin = AnyOrder_RIN(
    world=world,
    model=model,
    device='cpu',
)

In [29]:
# Random stand-in input: one standard-normal value per (batch, token, token-dim)
# entry, created on the same device as the objective.
tokens = torch.randn(
    (world.batch_size, world.num_tokens, world.dim_tokens),
    device=ao_rin.device,
)

# Forward pass; the call returns a prediction, a mask and a weights tensor.
pred, mask, weights = ao_rin(
    tokens,
    num_steps=2,
    num_ensemble=4,
    prior='dirichlet',
    schedule='cosine',
)

# Bring both the input tokens and the prediction from the flattened token
# layout back to the field layout. The trailing `...` carries any extra axes
# through unchanged (the recorded run shows one extra axis of size 4 on `pred`;
# presumably the ensemble axis — TODO confirm).
tokens, pred = [
    einops.rearrange(
        flat,
        f'{world.flatland_pattern} ... -> {world.field_pattern} ...',
        **world.token_sizes,
        **world.patch_sizes,
    )
    for flat in (tokens, pred)
]

# Expand the mask and the per-sample weights to the field layout so they line
# up element-wise with the loss computed later.
mask = einops.repeat(
    mask,
    f'b {world.flat_token_pattern} -> {world.field_pattern}',
    **world.token_sizes,
    **world.patch_sizes,
)
weights = einops.repeat(
    weights,
    f'b 1 -> {world.field_pattern}',
    **world.token_sizes,
    **world.patch_sizes,
)

In [30]:
# Sanity check: shapes of the prediction, mask and weights after the reshape.
print(*[tensor.shape for tensor in (pred, mask, weights)])

torch.Size([2, 4, 36, 64, 120, 4]) torch.Size([2, 4, 36, 64, 120]) torch.Size([2, 4, 36, 64, 120])


In [32]:
# Element-wise CRPS of the ensemble prediction against the (reshaped) input
# tokens, with the `fair` option switched on.
loss = f_kernel_crps(ensemble=pred, observation=tokens, fair=True)

# Weight the loss, keep only the entries selected by `mask`, then average.
# NOTE(review): assumes `mask` is a boolean tensor; an integer mask would be
# treated as indices instead — confirm against AnyOrder_RIN.
masked_loss = loss.mul(weights)[mask].mean()

In [33]:
# Display the shape of the unreduced loss (before masking and averaging).
loss.shape

torch.Size([2, 4, 36, 64, 120])