In [1]:
%load_ext jupyter_black

In [3]:
from __future__ import annotations

import time
import itertools
import random
import numpy as np

from mesoscaler.generic import DataWorker, DataConsumer
from mesoscaler._typing import Array

In [4]:
Index = tuple[int, int, int, int]


# Stand-in dataset kept in memory: one random 100x100 array for every key in
# the 3 x 3 x 3 x 3 grid of index combinations (81 keys in total).
DUMMY_DATA: dict[Index, np.ndarray] = dict(
    (key, np.random.rand(100, 100))
    for key in itertools.product((1, 2, 3), (4, 5, 6), (4, 5, 6), (4, 5, 6))
)


def get_data_from_disk(key: Index) -> np.ndarray:
    """Return the ``DUMMY_DATA`` array stored under ``key``.

    The call blocks for 0.1 seconds before the lookup to mimic the latency of
    reading from disk. A missing ``key`` raises ``KeyError``.
    """
    simulated_latency = 0.1  # seconds
    time.sleep(simulated_latency)
    array = DUMMY_DATA[key]
    return array

In [5]:
# random.choices samples with replacement, so a key may appear more than once.
indices = random.choices(list(DUMMY_DATA), k=20)

print(len(DUMMY_DATA), len(indices))

81 20


The worker class is a `Mapping` that is instantiated with an iterable sequence of indices.

In [6]:
from numpy.typing import NDArray
from typing import NewType

# Placeholder type standing in for an axis of length 100 in the shape annotation.
_100 = NewType("100", int)  # type: ignore


# np.float64 is used instead of np.float_: the alias was removed in NumPy 2.0
# and the base-class subscript below is evaluated when the class is defined.
class MyWorker(DataWorker[Index, Array[[_100, _100], np.float64]]):
    """Worker that maps an ``Index`` key to the array returned by ``get_data_from_disk``."""

    def __getitem__(self, idx: Index) -> Array[[_100, _100], np.float64]:
        """Load and return the array for ``idx``."""
        return get_data_from_disk(idx)


# Build a worker over the sampled indices and fetch the array for the first one.
worker = MyWorker(indices=indices)
data = worker[indices[0]]
# Show the worker's text representation followed by the array shape.
print(worker)
print(data.shape)

MyWorker(size=20):
- (1, 4, 4, 5): ndarray[(100, 100), dtype[float64]]
- (1, 5, 4, 5): ndarray[(100, 100), dtype[float64]]
- (3, 5, 5, 6): ndarray[(100, 100), dtype[float64]]
- (1, 4, 5, 6): ndarray[(100, 100), dtype[float64]]
- (3, 6, 5, 6): ndarray[(100, 100), dtype[float64]]
...
- (3, 5, 5, 5): ndarray[(100, 100), dtype[float64]]
(100, 100)


In [7]:
# Split the worker into two workers; the recorded output shows the second
# ("test") part holding 4 of the 20 indices for a 0.8 split.
train, test = worker.split(0.8)
test

MyWorker[test](size=4):
- (2, 5, 6, 6): ndarray[(100, 100), dtype[float64]]
- (3, 5, 4, 5): ndarray[(100, 100), dtype[float64]]
- (1, 5, 6, 5): ndarray[(100, 100), dtype[float64]]
- (3, 5, 5, 5): ndarray[(100, 100), dtype[float64]]
...
- (3, 5, 5, 5): ndarray[(100, 100), dtype[float64]]

# DataConsumer

Assuming there is some I/O bottleneck involved with loading data from disk, the
`DataConsumer` can be used as a DataLoader that will queue up the data to be
loaded in the background while the model is training.

In [8]:
# Baseline: load each item synchronously, then "train" on it (0.1 s sleep),
# so the loading latency and the training time add up.
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# elapsed intervals; time.time() can jump if the system clock is adjusted.
start = time.perf_counter()
for idx in worker:
    data = worker[idx]
    time.sleep(0.1)
print("worker:", time.perf_counter() - start)

# Same "training" loop, but the items come from a DataConsumer wrapped around
# the worker, which is expected to overlap loading with the sleep.
start = time.perf_counter()
for x in DataConsumer(worker):
    time.sleep(0.1)
print("consumer:", time.perf_counter() - start)

worker: 4.005807399749756
consumer: 2.104191541671753


Combining the Resampling pipeline with the data consumer.

In [None]:
import os
from mesoscaler.core import Mesoscale, P0, DependentDataset, ArrayWorker
import time
from mesoscaler.enums import (
    # - ERA5
    GEOPOTENTIAL,
    SPECIFIC_HUMIDITY,
    TEMPERATURE,
    U_COMPONENT_OF_WIND,
    V_COMPONENT_OF_WIND,
    # - URMA
    SURFACE_PRESSURE,
    TEMPERATURE_2M,
    SPECIFIC_HUMIDITY_2M,
    U_WIND_COMPONENT_10M,
    V_WIND_COMPONENT_10M,
    SURFACE_PRESSURE,
)

# Directory holding the sample Zarr stores; the relative path is resolved
# against the notebook's working directory.
_test_data = "../tests/data"

# URMA store and the variables that are requested from it.
urma_store = os.path.join(_test_data, "urma.zarr")
urma_dvars = [
    SURFACE_PRESSURE,
    TEMPERATURE_2M,
    SPECIFIC_HUMIDITY_2M,
    U_WIND_COMPONENT_10M,
    V_WIND_COMPONENT_10M,
]

# ERA5 store and the variables that are requested from it.
era5_store = os.path.join(_test_data, "era5.zarr")
era5_dvars = [
    GEOPOTENTIAL,
    TEMPERATURE,
    SPECIFIC_HUMIDITY,
    U_COMPONENT_OF_WIND,
    V_COMPONENT_OF_WIND,
]

In [None]:
# Unpacking a 2-element array gives dx and dy as NumPy integer scalars.
# NOTE(review): the units of dx/dy and the meaning of `rate` are defined by
# Mesoscale — confirm against its documentation.
dx, dy = np.array([200, 175])
# P0 is imported from mesoscaler.core; the remaining levels look like hPa
# values — TODO confirm.
pressure_levels = [P0, 1000.0, 925.0, 850.0, 700.0, 500.0, 300.0]

scale = Mesoscale(dx, dy, pressure=pressure_levels, rate=15)
era5 = DependentDataset.from_zarr(era5_store, era5_dvars)  # get datasets
urma = DependentDataset.from_zarr(urma_store, urma_dvars)  # get datasets

In [None]:
N_SAMPLES = 5  # number of random (lon, lat) sample points to draw


# Unpacking into exactly two names relies on the dataset having two time steps.
start, stop = urma.time.to_numpy().astype("datetime64[h]")  # sample data only has 2 times


# Wrap longitudes into [-180, 180) before sampling. Longitudes and latitudes
# are drawn independently (with replacement, unseeded), so the pairs change on
# every run and need not coincide with an actual grid point.
lons = np.random.choice(((urma.lons - 180) % 360 - 180).to_numpy().ravel(), N_SAMPLES)
lats = np.random.choice(urma.lats.to_numpy().ravel(), N_SAMPLES)

# Pair each rounded (lon, lat) row with the same start:stop time slice.
# NOTE(review): zip() yields a one-shot iterator — presumably ArrayWorker
# materializes it, since a later cell reads worker.indices[0]; confirm.
indices = zip(np.c_[lons, lats].round(2), itertools.repeat(np.s_[start:stop]))

worker = ArrayWorker(indices, urma, era5, scale=scale)
worker

In [None]:
# Take the first index held by the worker and display the sample it maps to.
idx = worker.indices[0]
worker[idx]

In [None]:
# Display what get_array returns for the index chosen in the previous cell.
worker.get_array(idx)

In [None]:
# Display what get_dataset returns for the same index.
worker.get_dataset(idx)

In [None]:
TIME2TRAIN = 1.5  # seconds of simulated training per sample


# Baseline: fetch every sample synchronously, then "train" on it.
# perf_counter() is a monotonic, high-resolution clock meant for measuring
# elapsed intervals; time.time() can jump if the system clock is adjusted.
start = time.perf_counter()
for idx in worker:
    sample = worker[idx]
    print(sample.shape)
    time.sleep(TIME2TRAIN)
total = time.perf_counter() - start
# io_time is whatever remains after subtracting the simulated training time,
# which assumes the worker yields N_SAMPLES items.
print(
    f"""
total = {total}
io_time = {total - TIME2TRAIN * N_SAMPLES}
"""
)


# Same loop, but the samples come from a DataConsumer wrapped around the worker.
start = time.perf_counter()
for sample in DataConsumer(worker):
    print(sample.shape)
    time.sleep(TIME2TRAIN)
total = time.perf_counter() - start
print(
    f"""
total = {total}
io_time = {total - TIME2TRAIN * N_SAMPLES}
"""
)

In [None]:
from torch.utils.data import DataLoader


# Wrap the consumer in a PyTorch DataLoader with default settings and print
# the shape of every batch it yields.
# NOTE(review): with DataLoader's default batch_size=1 each batch presumably
# carries a leading dimension of 1 — confirm against the output.
for batch in DataLoader(DataConsumer(worker)):
    print(batch.shape)
    # time.sleep(TIME2TRAIN)