In [None]:
%load_ext jupyter_black

In [None]:
import torch
import math
import itertools
from model.utils import AutoPatcher

# Prototype of a 4D convolution (T, Z, Y, X) assembled from one Conv3d per
# time-offset of the kernel, evaluated on a small random input.
B = 1
C = 3
T = 2
Z = 2
Y = 4
X = 4
input_shape = (T, Z, Y, X)
in_channels = C
out_channels = C
batch_size = B
# Seed the global RNG so weights and the sample input are reproducible.
torch.manual_seed(0)

# -- __init__ --
kernel_size = (1, 1, 1, 1)
stride = (1, 1, 1, 1)
padding = (0, 0, 0, 0)
dilation = (1, 1, 1, 1)
groups = 1
use_bias = False

# torch.Tensor(...) would leave the weight uninitialized; use the same
# Kaiming-uniform scheme torch.nn.ConvNd uses so outputs are finite.
weight = torch.nn.Parameter(
    torch.empty(out_channels, in_channels // groups, *kernel_size)
)
torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
if use_bias:
    fan_in = (in_channels // groups) * math.prod(kernel_size)
    bound = 1 / math.sqrt(fan_in)
    bias = torch.nn.Parameter(torch.empty(out_channels).uniform_(-bound, bound))
else:
    bias = None

layers = torch.nn.ModuleList()
for i in range(kernel_size[0]):
    # Conv3d over (Z, Y, X) for time-offset i of the 4D kernel.
    m = torch.nn.Conv3d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size[1:],
        padding=padding[1:],
        dilation=dilation[1:],
        stride=stride[1:],
        groups=groups,
        bias=False,
    )
    # NOTE(review): nn.Parameter(view) shares storage with `weight` but is a new
    # leaf tensor, so gradients accumulate on m.weight rather than on `weight`.
    m.weight = torch.nn.Parameter(weight[:, :, i, :, :])
    layers.append(m)

# Output extent per axis, PyTorch's conv formula:
# floor((n + 2p - d * (k - 1) - 1) / s) + 1
dims = tuple(
    (n + 2 * p - d * (k - 1) - 1) // s + 1
    for n, k, p, d, s in zip(input_shape, kernel_size, padding, dilation, stride)
)

zeros = torch.zeros(B, out_channels, *dims)
# Seed the accumulator with the bias (broadcast over every output position)
# when present; it must not be overwritten afterwards.
out = zeros + bias.view(1, -1, 1, 1, 1, 1) if bias is not None else zeros.clone()

# Time-axis hyper-parameters.
k0, p0, d0, s0 = kernel_size[0], padding[0], dilation[0], stride[0]

sample_input = torch.randn(B, C, T, Z, Y, X)

l_i = sample_input.shape[2]  # number of input time frames

for i in range(k0):
    # Input frame that kernel frame i aligns with for output frame 0.
    zero_offset = -p0 + (i * d0)
    # Valid input frames j for kernel frame i (first one on the stride grid).
    j_start = max(zero_offset % s0, zero_offset)
    j_end = min(l_i, l_i + p0 - (k0 - i - 1) * d0)
    for j in range(j_start, j_end, s0):
        out_frame = (j - zero_offset) // s0
        # Accumulate the 3D convolution of input frame j into its output frame.
        out[:, :, out_frame, :, :, :] += layers[i](sample_input[:, :, j, :, :])


out, zeros

In [5]:

import torch
from model.conv4d import Conv4d
from model.utils import get_patch_encoding_functions

# Round-trip check for the patch encode/decode pair on a random 6D input.
B = 5
C = 3
T = 2
Z = 4
Y = 4
X = 4
patch_shape = (2, 2, 2, 2)
x = torch.randn(B, C, T, Z, Y, X)
# The input must not already be rank 3, otherwise the rank check below is vacuous.
assert x.ndim == 6

encode, decode = get_patch_encoding_functions(B, C, (T, Z, Y, X), patch_shape)
encoded = encode(x)
assert encoded.ndim == 3
# torch.equal also compares shapes; `(a == b).all()` would pass for any
# decoded shape that merely broadcasts against x.
assert torch.equal(decode(encoded), x)

In [None]:
# Compare two ways of moving the channel axis last after flattening (Z, Y, X):
# both must agree in shape and values.
x_flat = x.flatten(3)  # (B, C, T, Z*Y*X)
channels_last_moveaxis = x_flat.moveaxis(1, -1)
channels_last_einsum = torch.einsum("bczs->bzsc", x_flat)
expected_shape = (x.shape[0], x.shape[2], x_flat.shape[-1], x.shape[1])
assert channels_last_moveaxis.shape == channels_last_einsum.shape == expected_shape

torch.equal(channels_last_moveaxis, channels_last_einsum)

In [None]:
import os

import torch
import numpy as np

import mesoscaler as ms


# Local zarr stores, resolved relative to the notebook's working directory.
_local_data = os.path.abspath("data")

urma_store = os.path.join(_local_data, "urma.zarr")
era5_store = os.path.join(_local_data, "era5.zarr")
# Explicit error with the offending path (bare asserts carry no message and
# are stripped under `python -O`).
for _store in (urma_store, era5_store):
    if not os.path.exists(_store):
        raise FileNotFoundError(f"expected data store at {_store!r}")

from mesoscaler.enums import (
    # - ERA5
    GEOPOTENTIAL,
    SPECIFIC_HUMIDITY,
    TEMPERATURE,
    U_COMPONENT_OF_WIND,
    V_COMPONENT_OF_WIND,
    # - URMA
    SURFACE_PRESSURE,
    TEMPERATURE_2M,
    SPECIFIC_HUMIDITY_2M,
    U_WIND_COMPONENT_10M,
    V_WIND_COMPONENT_10M,
)

# Variables loaded from each store.
era5_dvars = [
    GEOPOTENTIAL,
    TEMPERATURE,
    SPECIFIC_HUMIDITY,
    U_COMPONENT_OF_WIND,
    V_COMPONENT_OF_WIND,
]

urma_dvars = [
    SURFACE_PRESSURE,
    TEMPERATURE_2M,
    SPECIFIC_HUMIDITY_2M,
    U_WIND_COMPONENT_10M,
    V_WIND_COMPONENT_10M,
]

dataset_sequence = ms.open_datasets([(urma_store, urma_dvars), (era5_store, era5_dvars)])
dataset_sequence

In [None]:
from model.mae import MaskedAutoencoder3d

# Image grid size and the ratios used to derive patch size and spatial extent.
width, height = 80, 40
distance_ratio = 2.5  # km
patch_ratio = 0.2

img_size = (width, height)
patch_size = tuple(int(side * patch_ratio) for side in img_size)
dx, dy = (int(side * distance_ratio) for side in img_size)

# Vertical levels passed to ms.Mesoscale (presumably hPa — verify).
levels = [1013.25, 1000, 925, 850]

# Texas bounding box; presumably (lon_min, lat_min, lon_max, lat_max) in
# degrees — confirm against ms.AreaOfInterestSampler.
aoi = (-106.6, 25.8, -93.5, 36.5)
texas = aoi

config = dict(img_size=img_size, patch_size=patch_size)
config, dx, dy

In [None]:
# Build the Mesoscale from the extents (dx, dy) and the levels configured above.
scale = ms.Mesoscale(
    dx,
    dy,
    levels=levels,
)
scale

In [None]:
from torch.utils.data import DataLoader, IterableDataset


class Dataset(IterableDataset):
    """Iterable dataset yielding ``resampler(lon, lat, time)`` for each index
    produced by an ``ms.AreaOfInterestSampler``.

    Args:
        scale: Mesoscale whose ``resample`` builds the resampler.
        aoi: Area of interest forwarded to ``ms.AreaOfInterestSampler``; the
            default is the Texas box this class previously hard-coded
            (presumably (lon_min, lat_min, lon_max, lat_max) — verify).
        datasets: Datasets to resample; ``None`` falls back to the notebook
            global ``dataset_sequence``.
    """

    def __init__(
        self,
        scale: ms.Mesoscale,
        aoi: tuple = (-106.6, 25.8, -93.5, 36.5),
        datasets=None,
    ) -> None:
        super().__init__()
        if datasets is None:
            # Notebook global created by the data-loading cell.
            datasets = dataset_sequence

        self.resampler = resampler = scale.resample(datasets)
        self.indices = ms.AreaOfInterestSampler(
            resampler.domain,
            aoi=aoi,
        )

    def __iter__(self):
        """Yield one resampled sample per ``((lon, lat), time)`` index."""
        for (lon, lat), time in self.indices:
            yield self.resampler(lon, lat, time)


# Use the AOI from the config cell instead of duplicating the literal.
ds = Dataset(scale, aoi=aoi)

In [None]:
# Draw the first sample from the dataset's iterator and show its shape.
# NOTE(review): next() raises StopIteration if the AOI sampler yields nothing.
b = next(iter(ds))
b.shape