In [1]:
%load_ext jupyter_black

In [2]:
from __future__ import annotations

import time
import itertools
import random
import numpy as np

from mesoscaler.generic import DataWorker, DataGenerator
from mesoscaler._typing import Array

In [3]:
# Key type used throughout the demo: a 4-tuple of integers.
Index = tuple[int, int, int, int]


# Stand-in for data that would normally be read from storage: one random
# (100, 100) array per key, with keys enumerated in `itertools.product` order.
DUMMY_DATA: dict[Index, np.ndarray] = dict(
    (key, np.random.rand(100, 100))
    for key in itertools.product([1, 2, 3], *itertools.repeat([4, 5, 6], 3))
)


def get_data_from_disk(key: Index) -> np.ndarray:
    """Return the dummy array stored under ``key`` after a simulated I/O delay.

    Raises ``KeyError`` if ``key`` is not present in ``DUMMY_DATA``.
    """
    io_latency_seconds = 0.1
    time.sleep(io_latency_seconds)  # io latency
    array = DUMMY_DATA[key]
    return array

In [4]:
# Draw 20 keys (with replacement) from a dedicated, seeded generator so the
# selection is identical on every run and the global `random` state is untouched.
indices = random.Random(0).choices(list(DUMMY_DATA.keys()), k=20)

print(len(DUMMY_DATA), len(indices))

81 20


The worker class is a `Mapping` that is instantiated with an iterable sequence of indices.

In [5]:
from numpy.typing import NDArray
from typing import NewType

# Placeholder type standing for an axis of length 100 in the `Array` shape annotation.
_100 = NewType("100", int)  # type: ignore


class MyWorker(DataWorker[Index, Array[[_100, _100], np.float64]]):
    """Worker that maps an ``Index`` key to the array returned by ``get_data_from_disk``.

    ``np.float64`` is spelled out explicitly: the ``np.float_`` alias was removed
    in NumPy 2.0, and the base-class subscript is evaluated when the class is
    created, so the alias would raise ``AttributeError`` there.
    """

    def __getitem__(self, idx: Index) -> Array[[_100, _100], np.float64]:
        """Load and return the array for ``idx``."""
        return get_data_from_disk(idx)


# Build the worker over the sampled keys, then fetch a single item by key.
worker = MyWorker(indices=indices)
data = worker[indices[0]]

# Show the worker's repr followed by the shape of the fetched array.
print(worker)
print(data.shape)

MyWorker(size=20):
- (3, 5, 5, 4): ndarray[(100, 100), dtype[float64]]
- (2, 6, 6, 4): ndarray[(100, 100), dtype[float64]]
- (1, 4, 5, 6): ndarray[(100, 100), dtype[float64]]
- (3, 4, 5, 5): ndarray[(100, 100), dtype[float64]]
- (3, 6, 4, 4): ndarray[(100, 100), dtype[float64]]
...
- (2, 5, 5, 6): ndarray[(100, 100), dtype[float64]]
(100, 100)


In [6]:
# Split the worker into two workers. With 0.8, the repr below shows `test`
# holding 4 of the 20 indices (presumably the remaining 16 go to `train`).
train, test = worker.split(0.8)
test

MyWorker[test](size=4):
- (3, 5, 4, 6): ndarray[(100, 100), dtype[float64]]
- (2, 4, 4, 6): ndarray[(100, 100), dtype[float64]]
- (3, 6, 6, 6): ndarray[(100, 100), dtype[float64]]
- (2, 5, 5, 6): ndarray[(100, 100), dtype[float64]]
...
- (2, 5, 5, 6): ndarray[(100, 100), dtype[float64]]

# DataConsumer

Assuming there is an I/O bottleneck involved in loading data from disk, the
`DataConsumer` (`DataGenerator` in the code below) can be used as a data loader
that queues up the data to be loaded in the background while the model is training.

In [7]:
# Baseline: fetch each item synchronously, then simulate a training step.
# `time.perf_counter()` is used instead of `time.time()` because it is a
# monotonic, high-resolution clock intended for measuring elapsed intervals.
start = time.perf_counter()
for idx in worker:
    data = worker[idx]  # blocking load (includes the simulated 0.1 s latency)
    time.sleep(0.1)  # simulated training step
print("worker:", time.perf_counter() - start)

# Same workload, but iterating the worker through DataGenerator.
start = time.perf_counter()
for x in DataGenerator(worker):
    time.sleep(0.1)  # simulated training step
print("consumer:", time.perf_counter() - start)

worker: 4.006197929382324
consumer: 2.1038269996643066


Combining the Resampling pipeline with the data consumer.

In [8]:
import os
from mesoscaler.core import Mesoscale, P0, DependentDataset, DataProducer
import time
from mesoscaler.enums import (
    # - ERA5
    GEOPOTENTIAL,
    SPECIFIC_HUMIDITY,
    TEMPERATURE,
    U_COMPONENT_OF_WIND,
    V_COMPONENT_OF_WIND,
    # - URMA
    SURFACE_PRESSURE,
    TEMPERATURE_2M,
    SPECIFIC_HUMIDITY_2M,
    U_WIND_COMPONENT_10M,
    V_WIND_COMPONENT_10M,
    SURFACE_PRESSURE,
)

# Directory holding the sample zarr stores, relative to this notebook.
_test_data = "../tests/data"

# One store path per dataset.
era5_store = os.path.join(_test_data, "era5.zarr")
urma_store = os.path.join(_test_data, "urma.zarr")

# Variables requested from the ERA5 store.
era5_dvars = [GEOPOTENTIAL, TEMPERATURE, SPECIFIC_HUMIDITY, U_COMPONENT_OF_WIND, V_COMPONENT_OF_WIND]

# Variables requested from the URMA store.
urma_dvars = [SURFACE_PRESSURE, TEMPERATURE_2M, SPECIFIC_HUMIDITY_2M, U_WIND_COMPONENT_10M, V_WIND_COMPONENT_10M]

In [9]:
# Levels handed to the Mesoscale definition; P0 comes from mesoscaler.core.
pressure_levels = [P0, 1000.0, 925.0, 850.0, 700.0, 500.0, 300.0]
dx, dy = np.array([200, 175])
scale = Mesoscale(dx, dy, levels=pressure_levels, rate=15)

# Open both sample stores with the variable lists defined above.
urma = DependentDataset.from_zarr(urma_store, urma_dvars)
era5 = DependentDataset.from_zarr(era5_store, era5_dvars)

In [10]:
import mesoscaler as ms

N_SAMPLES = 5
SEED = 0  # fixed so the sampled locations are identical on every run
rng = np.random.default_rng(SEED)

start, stop = urma.time.to_numpy().astype("datetime64[h]")  # sample data only has 2 times


# Wrap longitudes into [-180, 180) before sampling.
# NOTE(review): lons and lats are drawn independently of each other, so a
# sampled (lon, lat) pair is not necessarily an actual grid point — confirm
# this is intended.
lons = rng.choice(((urma.lons - 180) % 360 - 180).to_numpy().ravel(), N_SAMPLES)
lats = rng.choice(urma.lats.to_numpy().ravel(), N_SAMPLES)

# Each index pairs a rounded (lon, lat) point with the same start:stop time slice.
indices = zip(
    np.c_[lons, lats].round(2),
    itertools.repeat(np.s_[start:stop]),
)

worker = ms.create.producer([urma, era5], indices=indices)  # DataProducer(indices, urma, era5, scale=scale)
worker

DataProducer(size=5):
-   ([-85.34  31.25], 2019-01-02T00:00:00Z:2019-01-02T01:00:00Z:None): ndarray[(Nv, Nt, Nz, Ny, Nx), dtype[float64]]
- ([-130.67   32.47], 2019-01-02T00:00:00Z:2019-01-02T01:00:00Z:None): ndarray[(Nv, Nt, Nz, Ny, Nx), dtype[float64]]
- ([-118.24   46.65], 2019-01-02T00:00:00Z:2019-01-02T01:00:00Z:None): ndarray[(Nv, Nt, Nz, Ny, Nx), dtype[float64]]
- ([-106.63   25.14], 2019-01-02T00:00:00Z:2019-01-02T01:00:00Z:None): ndarray[(Nv, Nt, Nz, Ny, Nx), dtype[float64]]
- ([-127.65   48.43], 2019-01-02T00:00:00Z:2019-01-02T01:00:00Z:None): ndarray[(Nv, Nt, Nz, Ny, Nx), dtype[float64]]
...
- ([-127.65   48.43], 2019-01-02T00:00:00Z:2019-01-02T01:00:00Z:None): ndarray[(Nv, Nt, Nz, Ny, Nx), dtype[float64]]

In [11]:
# Take the first index held by the producer and fetch its sample; the array
# displayed below contains NaN values in its leading slab.
idx = worker.indices[0]
worker[idx]

array([[[[[            nan,             nan,             nan, ...,
                       nan,             nan,             nan],
          [            nan,             nan,             nan, ...,
                       nan,             nan,             nan],
          [            nan,             nan,             nan, ...,
                       nan,             nan,             nan],
          ...,
          [            nan,             nan,             nan, ...,
                       nan,             nan,             nan],
          [            nan,             nan,             nan, ...,
                       nan,             nan,             nan],
          [            nan,             nan,             nan, ...,
                       nan,             nan,             nan]],

         [[ 8.86587695e+03,  8.72861621e+03,  8.42906348e+03, ...,
            6.63720117e+03,  6.64645508e+03,  6.58310400e+03],
          [ 8.84618359e+03,  8.70216113e+03,  8.46892480e+03, ...,
      

In [12]:
import functools
from mesoscaler._typing import Self
from typing import TypeVar

_T = TypeVar("_T")


class A:
    """Toy class used to try out a constructor helper that returns a ``functools.partial``."""

    def __init__(self, x: int, *, world) -> None:
        """Accept a positional ``x`` and a keyword-only ``world``; neither is stored."""
        return None

    @classmethod
    def partial(cls: type[_T], *args, **kwargs) -> functools.partial[_T]:
        """Return a ``functools.partial`` of ``cls`` with the given arguments pre-bound."""
        factory = functools.partial(cls, *args, **kwargs)
        return factory


def f(*, a: int, b: int, c: int) -> int:
    """Return the sum of the three keyword-only arguments."""
    partial_sum = a + b
    return partial_sum + c


# Pre-bind a=1 and b=2, then override a at call time: keywords supplied in the
# call take precedence over the pre-bound ones, so this evaluates f(a=0, b=2, c=3).
c = functools.partial(f, a=1, b=2)
x = c(a=0, c=3)
x

5

In [13]:
class C:
    """Scratch class whose only method is wrapped with ``functools.partialmethod``."""

    @functools.partialmethod
    def f(self):
        """Do nothing; the decorator pre-binds no arguments."""
        return None

In [14]:
# Call the producer's `get_array` for the same index (presumably the array form of the sample — TODO confirm).
worker.get_array(idx)

In [15]:
# Call the producer's `get_dataset` for the same index (presumably the dataset form of the sample — TODO confirm).
worker.get_dataset(idx)

In [16]:
TIME2TRAIN = 1.5  # seconds slept per sample to simulate a training step


def _report(total: float) -> None:
    """Print the total elapsed time and the part not spent in simulated training."""
    print(
        f"""\
total = {total}
io_time = {total - TIME2TRAIN * N_SAMPLES}
"""
    )


# Baseline: load each sample synchronously, then "train" on it.
# `time.perf_counter()` is a monotonic, high-resolution clock meant for
# measuring intervals, unlike the wall-clock `time.time()`.
start = time.perf_counter()
for idx in worker:
    sample = worker[idx]
    print(sample.shape)
    time.sleep(TIME2TRAIN)
total = time.perf_counter() - start
_report(total)

# Same workload, iterating the producer through DataGenerator.
start = time.perf_counter()
for sample in DataGenerator(worker):
    print(sample.shape)
    time.sleep(TIME2TRAIN)
total = time.perf_counter() - start
_report(total)

(5, 2, 6, 80, 80)
(5, 2, 6, 80, 80)
(5, 2, 6, 80, 80)
(5, 2, 6, 80, 80)
(5, 2, 6, 80, 80)
total = 14.409813642501831
io_time = 6.909813642501831

(5, 2, 6, 80, 80)
(5, 2, 6, 80, 80)
(5, 2, 6, 80, 80)
(5, 2, 6, 80, 80)
(5, 2, 6, 80, 80)
total = 8.993643283843994
io_time = 1.4936432838439941



In [17]:
from torch.utils.data import DataLoader


# Wrap the generator in a torch DataLoader so samples are collated in batches of two.
loader = DataLoader(DataGenerator(worker), batch_size=2)
for batch in loader:
    print(batch.shape)

torch.Size([2, 5, 2, 6, 80, 80])
torch.Size([2, 5, 2, 6, 80, 80])
torch.Size([1, 5, 2, 6, 80, 80])
