In [22]:
import concurrent.futures as futures
import json
import os
import random
import tempfile
from collections import Counter
from typing import Generator, Iterator, Tuple

import numpy as np
import pytest
import torch.utils.data as tud

from squirrel.driver import JsonlDriver
from squirrel.iterstream.torch_composables import SplitByWorker, TorchIterable, skip_k
from squirrel.serialization import JsonSerializer
from squirrel.store import SquirrelStore

N_SHARDS = 20
MIN_SAMPLES_PER_SHARD = 50
MAX_SAMPLES_PER_SHARD = 100

def create_data(test_folder: str) -> int:
    """Helper function to create test data."""

    def create_shard(id_: int, min: int, max: int) -> Tuple[int, int]:
        num_samples = random.Random().randint(a=min, b=max)
        store = SquirrelStore(url=test_folder, serializer=JsonSerializer())
        shard = [{"shard_idx": id_, "sample_idx": idx} for idx in range(num_samples)]
        store.set(key=str(id_), value=shard)
        return id_, num_samples

    os.makedirs(test_folder, exist_ok=True)

    with futures.ThreadPoolExecutor(max_workers=4) as pool:
        futs = [pool.submit(create_shard, idx, MIN_SAMPLES_PER_SHARD, MAX_SAMPLES_PER_SHARD) for idx in range(N_SHARDS)]

    n_samples = 0
    for fut in futures.as_completed(futs):
        shard_id, _n_samps = fut.result()
        assert MIN_SAMPLES_PER_SHARD <= _n_samps <= MAX_SAMPLES_PER_SHARD
        n_samples += _n_samps
    
    print(n_samples)
    return n_samples

def test_data() -> Tuple[str, int]:
    """Fixture for this modules test data"""

    test_folder = tempfile.TemporaryDirectory()

    n_samples = create_data(test_folder=test_folder.name)
    return test_folder.name, n_samples

test_data_folder, total_samples = test_data()

it = (
    JsonlDriver(test_data_folder)
    # In order to not have duplication of all kinds of buffers, we need to limit the buffer size
    # Increasing the buffer will lead to failing tests due to over-sampling.
    .get_iter(shuffle_key_buffer=0, prefetch_buffer=0, shuffle_item_buffer=0, max_workers=0)
    # .map(print)
    # # change Image.open handles to np arrays which is pickable,
    # # allowing python to send the objects to other processes.
    .compose(SplitByWorker)
    .shuffle(size=400)
    .take(n=N_SHARDS * MIN_SAMPLES_PER_SHARD)
    .compose(TorchIterable)
)

# dl = tud.DataLoader(it, batch_size=None, num_workers=4)
# elems = [json.dumps(item) for item in dl]
for item in it:
    print(item)

1451
