In [1]:
%load_ext jupyter_black

In [2]:
from __future__ import annotations

import time
import itertools
import random
import numpy as np

from mesoscaler.generic import DataWorker, DataConsumer
from mesoscaler._typing import Array, N, N4

In [3]:
Index = tuple[int, int, int, int]


# Stand-in dataset: one random 100x100 array per 4-tuple key, held in memory
# in place of files that would normally be read from disk.
DUMMY_DATA: dict[Index, np.ndarray] = dict(
    (key, np.random.rand(100, 100))
    for key in itertools.product((1, 2, 3), (4, 5, 6), (4, 5, 6), (4, 5, 6))
)


def get_data_from_disk(key: Index) -> np.ndarray:
    """Return the dummy array stored under ``key`` after a simulated read delay.

    The call blocks for 0.1 seconds before the lookup to imitate I/O latency.

    Raises:
        KeyError: if ``key`` is not present in ``DUMMY_DATA``.
    """
    simulated_latency = 0.1  # seconds of pretend I/O wait
    time.sleep(simulated_latency)
    array = DUMMY_DATA[key]
    return array

In [4]:
# Draw 20 keys (sampling with replacement) to serve as the worker's indices.
indices = random.choices(list(DUMMY_DATA), k=20)

print(len(DUMMY_DATA), len(indices))

81 20


The worker class is a `Mapping` that is instantiated with an iterable sequence of indices.

In [5]:
from numpy.typing import NDArray
from typing import NewType

# Placeholder type used as the axis size in the `Array[[_100, _100], ...]`
# shape annotations of the worker class.
# NOTE(review): the `type: ignore` is presumably needed because the NewType
# name "100" does not match the variable name `_100` — confirm.
_100 = NewType("100", int)  # type: ignore


class MyWorker(DataWorker[Index, Array[[_100, _100], np.float64]]):
    """Worker that maps each ``Index`` key to its array via ``get_data_from_disk``.

    The type parameters declare the key type (``Index``) and the item type
    (a ``[_100, _100]``-shaped ``float64`` array).

    ``np.float64`` is spelled out rather than the alias ``np.float_``: the
    alias was removed in NumPy 2.0, and this base-class subscript is evaluated
    when the class is created, so the alias would raise ``AttributeError``.
    """

    def __getitem__(self, idx: Index) -> Array[[_100, _100], np.float64]:
        """Load and return the array stored under ``idx``."""
        return get_data_from_disk(idx)


# Build a worker over the sampled keys, fetch the item for the first key,
# then show the worker followed by the shape of the fetched array.
worker = MyWorker(indices=indices)
data = worker[indices[0]]
print(worker)
print(data.shape)

MyWorker(size=20):
- (2, 5, 4, 5): ndarray[(100, 100), dtype[float64]]
- (1, 4, 5, 4): ndarray[(100, 100), dtype[float64]]
- (1, 5, 4, 5): ndarray[(100, 100), dtype[float64]]
- (2, 4, 5, 6): ndarray[(100, 100), dtype[float64]]
- (2, 5, 4, 4): ndarray[(100, 100), dtype[float64]]
...
- (2, 4, 6, 6): ndarray[(100, 100), dtype[float64]]
(100, 100)


In [8]:
# Split the worker into a train part and a test part using a fraction of 0.8.
# In the recorded output the 20-index worker yields a test part of size 4.
train, test = worker.split(0.8)
test

MyWorker[test](size=4):
- (1, 4, 4, 5): ndarray[(100, 100), dtype[float64]]
- (1, 5, 5, 5): ndarray[(100, 100), dtype[float64]]
- (2, 5, 6, 5): ndarray[(100, 100), dtype[float64]]
- (2, 4, 6, 6): ndarray[(100, 100), dtype[float64]]
...
- (2, 4, 6, 6): ndarray[(100, 100), dtype[float64]]

# DataConsumer

Assuming there is some I/O bottleneck involved with loading data from disk, the
`DataConsumer` can be used as a DataLoader that will queue up the data to be
loaded in the background while the model is training.

In [10]:
# Baseline: fetch each item synchronously, then run a simulated training step.
# Every iteration pays the load latency plus the 0.1 s step.
# perf_counter() is used for the elapsed time because it is a monotonic,
# high-resolution clock that is unaffected by system clock adjustments.
start = time.perf_counter()
for idx in worker:
    data = worker[idx]
    time.sleep(0.1)  # simulated training step
print("worker:", time.perf_counter() - start)

# Same simulated training loop, but iterating over a DataConsumer wrapped
# around the worker instead of indexing the worker directly.
start = time.perf_counter()
for x in DataConsumer(worker):
    time.sleep(0.1)  # simulated training step
print("consumer:", time.perf_counter() - start)

worker: 4.006312370300293
consumer: 2.104593276977539
