In [1]:
import torch
import einops
import math

from einops.layers.torch import Rearrange
from dataclasses import dataclass

from utils.config import NetworkConfig, WorldConfig
from utils.components import *

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

In [2]:
from einops.layers.torch import EinMix

In [3]:
from utils.components import InterfaceBlock

In [25]:
# Axis sizes and matching per-axis patch sizes (keys "vv", "tt", "hh", "ww" pair with
# "v", "t", "h", "w"). NOTE(review): exact semantics live in utils.config.WorldConfig.
world_axis_sizes = {"v": 4, "t": 36, 'h': 64, 'w': 120}
world_patch_sizes = {'vv': 1, 'tt': 6, 'hh': 4, "ww": 4}

world = WorldConfig(
    world_axis_sizes,
    world_patch_sizes,
    tau=2,
    alphas={},
    timestep='framewise',
    mask='bernoulli',
    schedule='cosine',
)

In [None]:
D = 128  # token embedding width (EinMix axis `d`)

# Hoisted out of the f-strings below: reusing the outer quote character inside an
# f-string replacement field is a SyntaxError before Python 3.12 (PEP 701).
patch_axes = ' '.join(world.patch_layout)

# Patch -> token embedding. Weight and bias are both indexed by axis `v`, so each
# `v` slice gets its own projection.
to_tokens = EinMix(f'b {world.flat_token_pattern} {world.flat_patch_pattern} -> b {world.flat_token_pattern} d',
                   weight_shape=f'v {patch_axes} d',
                   bias_shape='v d',
                   d=D, **world.patch_sizes, **world.token_sizes)

# Token -> patch projection with an extra output axis `e` (size 4) and no bias.
# NOTE(review): the meaning of `e` is not evident from this cell — confirm.
to_field = EinMix(f'b {world.flat_token_pattern} d -> b {world.flat_token_pattern} {world.flat_patch_pattern} e',
                  weight_shape=f'v {patch_axes} e d',
                  e=4,
                  d=D, **world.patch_sizes, **world.token_sizes)

In [33]:
# Sanity check: the EinMix weight should be registered as an nn.Parameter.
weight_is_param = isinstance(to_tokens.weight, torch.nn.Parameter)
weight_is_param

True

In [27]:
# Show the learned weight shape of each projection, encoder first.
for layer in (to_tokens, to_field):
    print(layer.weight.shape)


torch.Size([4, 1, 6, 4, 4, 128])
torch.Size([4, 1, 6, 4, 4, 4, 128])


In [28]:
# Random stand-in input of shape (batch, world.num_tokens, world.dim_tokens).
batch_size = 1
field = torch.randn(batch_size, world.num_tokens, world.dim_tokens)

In [29]:
# Run on the GPU when one is available, otherwise on CPU. The original unconditional
# .cuda() raised on machines without CUDA. nn.Module.to moves the layer in place,
# exactly like .cuda() did.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokens = to_tokens.to(device)(field.to(device))
rec = to_field.to(device)(tokens)

In [10]:

def octahedral_grid(n_lats: int) -> np.ndarray:
    """
    Create an octahedral reduced Gaussian grid (ECMWF "O<n_lats>" grid).

    Parameters
    ----------
    n_lats : int
        Number of latitude bands per hemisphere. The grid is mirrored about the
        equator, so it has 2 * n_lats bands in total.

    Returns
    -------
    np.ndarray
        float32 array of shape (4 * n_lats * (n_lats + 9), 2) holding
        (latitude, longitude) pairs in degrees. Points are ordered north to south,
        and within each band by increasing longitude in [0, 360).

    Raises
    ------
    ValueError
        If n_lats < 1.
    """
    if n_lats < 1:
        raise ValueError(f"n_lats must be a positive integer, got {n_lats}")

    # Gaussian latitudes: arcsin of the Gauss-Legendre nodes of degree 2 * n_lats.
    # Unlike an equally spaced 90..-90 range, these never sit exactly on the poles.
    nodes, _ = np.polynomial.legendre.leggauss(2 * n_lats)
    latitudes = np.sort(np.degrees(np.arcsin(nodes)))[::-1]  # north -> south

    # Octahedral reduction: band k (1-based, counted from the nearest pole) has
    # 4*k + 16 longitudes. That is 20 next to each pole and 4*n_lats + 16 at the equator.
    northern_counts = 4 * np.arange(1, n_lats + 1) + 16
    lon_counts = np.concatenate([northern_counts, northern_counts[::-1]])

    # endpoint=False so 0 and 360 are not both included in a band.
    lats = np.repeat(latitudes, lon_counts)
    lons = np.concatenate(
        [np.linspace(0.0, 360.0, n, endpoint=False) for n in lon_counts]
    )
    return np.stack([lats, lons], axis=1).astype(np.float32)

In [11]:
# O48 grid: 48 latitude bands per hemisphere.
og = octahedral_grid(n_lats=48)