Sometimes it is useful to store data in different stores based on a categorical label. In this notebook, we demonstrate how to do this with the help of Spark.

In [None]:
!pip install squirrel-core pyspark
!pip install more-itertools

In [None]:
import tempfile
from random import randint
from functools import partial
from pyspark.sql import SparkSession
from squirrel.store import SquirrelStore
from squirrel.serialization import MessagepackSerializer
from squirrel.iterstream import IterableSource, FilePathGenerator

In [None]:
def generate_categorical_samples(N, *, n_categories=10):
    """Generate ``N`` dummy samples whose ``uid`` field is a categorical label.

    The ``uid`` is the label later used to split the samples into separate
    stores.

    Args:
        N: Number of samples to generate.
        n_categories: Number of distinct labels; each ``uid`` is drawn
            uniformly from ``1..n_categories`` (inclusive). Defaults to 10.

    Returns:
        A list of ``N`` dicts of the form ``{"uid": <int>, "data": 0}``.
    """
    return [{"uid": randint(1, n_categories), "data": 0} for _ in range(N)]


def save_shards(tuple_, shard_size, uri):
    """Write one categorical group into its own SquirrelStore, split into shards.

    Meant to be bound with ``functools.partial`` (fixing ``shard_size`` and
    ``uri``) and passed to ``RDD.foreach``, so it is called once per group.

    Args:
        tuple_: A ``(key, samples)`` pair, where ``key`` is the categorical
            label and ``samples`` is an iterable of the samples carrying it.
        shard_size: Maximum number of samples written into a single shard.
        uri: Base URI; the store for ``key`` is created at ``f"{uri}/{key}"``.

    Raises:
        ValueError: If ``shard_size`` is smaller than 1.
    """
    if shard_size < 1:
        raise ValueError(f"shard_size must be >= 1, got {shard_size}")
    key, samples = tuple_
    store = SquirrelStore(url=f"{uri}/{key}", serializer=MessagepackSerializer())
    samples = list(samples)
    # Honour shard_size: previously the whole group went into one shard no
    # matter how large it was. Each chunk gets a deterministic, unique key so
    # shards of the same group never overwrite each other.
    for shard_idx, start in enumerate(range(0, len(samples), shard_size)):
        store.set(value=samples[start:start + shard_size], key=f"{key}_{shard_idx}")


N_SHARDS = 50  # NOTE(review): not referenced anywhere below; SHARD_SIZE controls sharding
SHARD_SIZE = 100  # maximum number of samples per shard, forwarded to save_shards


# combineByKey helpers: build one Python list of samples per uid.
def to_list(a):
    """Create a new combiner (a one-element list) from the first value of a key."""
    return [a]


def append(a, b):
    """Merge value ``b`` into combiner list ``a`` (within a partition)."""
    a.append(b)
    return a


def extend(a, b):
    """Merge combiner list ``b`` into combiner list ``a`` (across partitions)."""
    a.extend(b)
    return a


# Generate samples
samples = IterableSource(generate_categorical_samples(100))

# Initiate Spark
# NOTE(review): the session is left running so later cells can reuse `spark`;
# call spark.stop() when it is no longer needed.
spark = SparkSession.builder.appName("test").getOrCreate()
rdd = spark.sparkContext.parallelize(samples)
with tempfile.TemporaryDirectory() as tempdir:
    # Key each sample by its uid, gather every uid's samples into a list, then
    # write each group into its own store under tempdir. foreach is an action
    # run purely for its side effect; it returns None, so nothing is bound.
    (
        rdd.map(lambda x: (x["uid"], x))
        .combineByKey(to_list, append, extend)
        .foreach(partial(save_shards, uri=tempdir, shard_size=SHARD_SIZE))
    )
    # We can see that each uid now has its own storage URI
    print(FilePathGenerator(tempdir, nested=True).collect())