In [None]:
from __future__ import annotations

%load_ext jupyter_black

In [None]:
import os

import mesoscaler as ms
import matplotlib.pyplot as plt
import numpy as np
import mesoscaler as ms
import zarr

# Data directory, resolved against the kernel's current working directory.
_local_data = os.path.abspath(os.path.join("..", "data"))

# Root of the cities zarr store; the array used below lives in its "data" node.
cities_zarr = os.path.join(_local_data, "cities.zarr")
array = zarr.open_array(os.path.join(cities_zarr, "data"), mode="r")

# Last expression of the cell: rendered via its rich repr.
ms.ZarrAttributes.from_array(array)

In [None]:
from mesoscaler.sampling.display import PlotArray, PlotOption

# Layer styles for the overlay plot. `dim` presumably selects which variable
# each layer draws (a pair of indices for the barbs layer).
# NOTE(review): confirm the meaning of `dim` against PlotOption.
options = [
    PlotOption("pressure", "contour", dim=0, colors="k", linestyles="-"),
    PlotOption(
        "temperature",
        "contour",
        dim=1,
        colors="r",
        linestyles="--",
        linewidths=0.5,
    ),
    PlotOption("dp", "contourf", dim=2, cmap="Greens"),
    PlotOption("winds", "barbs", dim=(3, 4), alpha=0.5),
]

# NOTE(review): the positional `1` (from_group) and `0` (all_levels) are not
# explained here — confirm what they select.
PlotArray.from_group(array, 1).all_levels(0, options=options)

In [None]:
import torch.utils.data
from typing import Iterable, Iterator
from mesoscaler._typing import Array, Nv, Nt, Nz, Nx, Ny
import numpy as np


class Dataset(torch.utils.data.Dataset[Array[[Nv, Nt, Nz, Nx, Ny], np.float32]]):
    """Map-style torch dataset over a zarr array, one sample per first-axis entry.

    ``ds[i]`` is ``array[i]`` and ``len(ds)`` is ``len(array)``. The ``shape``
    and ``zattrs`` properties and the reprs expose the wrapped array's metadata.
    """

    def __init__(self, array: zarr.Array) -> None:
        super().__init__()
        # The wrapped zarr array; every accessor below reads through it.
        self.array = array

    @classmethod
    def from_disk(cls, path: str):
        """Open the zarr array at ``path`` read-only (``mode="r"``) and wrap it."""
        read_only = zarr.open_array(path, mode="r")
        return cls(read_only)

    # -- sequence protocol -------------------------------------------------

    def __len__(self) -> int:
        return len(self.array)

    def __getitem__(self, index: int) -> np.ndarray:
        sample = self.array[index]
        return sample  # type: ignore

    def __iter__(self) -> Iterator[int]:
        # NOTE(review): this yields the indices 0..len-1, not the samples, so
        # `for x in ds` produces ints. DataLoader presumably draws indices from
        # its sampler instead of calling this — confirm the intent.
        yield from range(len(self))

    # -- metadata ----------------------------------------------------------

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of a single sample: the array's shape without its first axis."""
        return self.array.shape[1:]

    @property
    def zattrs(self) -> ms.ZarrAttributes:
        """``ms.ZarrAttributes`` built from the wrapped array on every access."""
        return ms.ZarrAttributes.from_array(self.array)

    # -- notebook display --------------------------------------------------

    def __repr__(self) -> str:
        return f"{type(self).__name__}:\n{self.array.info}"

    def _repr_html_(self):
        return self.array.info._repr_html_()


# Wrap the same on-disk array that `array` was opened from in the Dataset class.
ds = Dataset.from_disk(path=os.path.join(cities_zarr, "data"))
ds

In [None]:
class RandomSampler(torch.utils.data.Sampler[int]):
    def __init__(self, num_samples: int, generator: torch.Generator | None = None, seed: int | None = None) -> None:
        super().__init__()
        if seed is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
        self.generator = generator or torch.Generator().manual_seed(seed)
        self.num_samples = num_samples

    def __iter__(self) -> Iterator[int]:
        n = self.num_samples
        g = self.generator

        for _ in range(self.num_samples // n):
            yield from map(int, torch.randperm(n, generator=g).numpy())
        yield from map(int, torch.randperm(n, generator=g)[: self.num_samples % n].numpy())

    def __len__(self) -> int:
        return self.num_samples


num_epoch = 10
batch_size = 12
num_workers = 10
sampler_seed = 0  # fixed so the sample order is reproducible across runs

# drop_last=True drops the final batch only when it is incomplete
# (len(ds) % batch_size != 0); when the length divides evenly there is
# nothing to drop, so this matches dropping "only if needed".
loader = torch.utils.data.DataLoader(
    ds,
    batch_size=batch_size,
    num_workers=num_workers,
    sampler=RandomSampler(len(ds), seed=sampler_seed),
    drop_last=True,
)
# With drop_last, a dataset smaller than batch_size would make every epoch empty.
assert len(loader) > 0, (
    f"dataset has {len(ds)} samples, fewer than batch_size={batch_size}"
)

for epoch in range(num_epoch):
    for batch in loader:
        print(batch.shape)