[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/merantix/mxlabs-datasets/blob/main/examples/Squirrel_Tutorial_Create_Squirrel_Store.ipynb)

In [None]:
!pip install keyring keyrings.google-artifactregistry-auth
# from google.colab import auth
# auth.authenticate_user()

In [None]:
!pip install squirrel-core squirrel-datasets pyspark --extra-index-url=https://europe-west1-python.pkg.dev/mx-labs-devops/labs-pypi-registry/simple/ --ignore-requires-python --upgrade

# Squirrel Store

`SquirrelStore` is responsible for reading and writing data. It inherits from `squirrel.store.AbstractStore` and defines three main methods: `set`, `get`, and `keys`. A `SquirrelStore` requires a serializer. Two serializers are provided, `MessagepackSerializer` and `JsonSerializer`, and it is straightforward to write your own. A store can be instantiated as in the example below.

In [None]:
import numpy as np
import tempfile

from squirrel.store import SquirrelStore
from squirrel.serialization import MessagepackSerializer
from squirrel.driver import MessagepackDriver

In [None]:
tmpdir = tempfile.TemporaryDirectory()
msg_store = SquirrelStore(url=tmpdir.name, serializer=MessagepackSerializer())

You can also get a store instance from a driver. This is the recommended approach unless you need low-level control.

In [None]:
driver = MessagepackDriver(tmpdir.name)
store = driver.store

In [None]:
# assert isinstance(store, SquirrelStore)
# assert len(list(store.keys())) == 0
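
To make the `set`/`get`/`keys` contract concrete, here is a toy, stdlib-only sketch of a directory-backed store with a pluggable serializer. `ToyStore` and `ToyJsonSerializer` are hypothetical names used only for illustration. This is not how `SquirrelStore` is implemented; it only mirrors the shape of the interface described above.

```python
import json
import tempfile
import uuid
from pathlib import Path


class ToyJsonSerializer:
    """Hypothetical serializer: turns a shard (list of dicts) into bytes and back."""

    def serialize(self, shard):
        return json.dumps(shard).encode("utf-8")

    def deserialize(self, raw):
        return json.loads(raw.decode("utf-8"))


class ToyStore:
    """Hypothetical store: one file per shard, named by its key."""

    def __init__(self, url, serializer):
        self.root = Path(url)
        self.serializer = serializer

    def set(self, value, key=None):
        # Assign a random key when the caller does not provide one.
        key = key or uuid.uuid4().hex
        (self.root / key).write_bytes(self.serializer.serialize(value))
        return key

    def get(self, key):
        return self.serializer.deserialize((self.root / key).read_bytes())

    def keys(self):
        return (path.name for path in self.root.iterdir())


with tempfile.TemporaryDirectory() as toy_dir:
    toy_store = ToyStore(toy_dir, ToyJsonSerializer())
    toy_key = toy_store.set([{"label": 1}, {"label": 2}])
    toy_roundtrip = toy_store.get(toy_key)
    toy_keys = list(toy_store.keys())
```

Swapping the serializer changes only the on-disk encoding, while the three store methods stay the same.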

# Creating a SquirrelStore


## First approach: SquirrelStore itself
You can populate a store through its low-level interface by calling `set()` once per shard.

In [None]:
def get_sample(i):
    return {"image": np.random.random((3, 3, 3)), "label": np.random.choice([1, 2]), "metadata": {"key": i}}


samples = [get_sample(i) for i in range(100)]
# Step by 10 so that the shards are consecutive and do not overlap.
shards = [samples[i : i + 10] for i in range(0, len(samples), 10)]
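
Sharding here means cutting the sample list into consecutive, non-overlapping chunks. The following is a minimal sketch, assuming a fixed shard size. `make_shards` is a hypothetical helper and not part of squirrel.

```python
def make_shards(items, shard_size):
    """Split a list into consecutive, non-overlapping chunks of at most shard_size items."""
    if shard_size <= 0:
        raise ValueError("shard_size must be positive")
    return [items[start : start + shard_size] for start in range(0, len(items), shard_size)]


demo_shards = make_shards(list(range(100)), shard_size=10)
```

Every item lands in exactly one shard, so reading all shards back yields each sample once.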

In [None]:
for shard in shards:
    store.set(shard)

In [None]:
list(store.keys())

The `set()` method accepts an optional argument `key`. If it is not provided, a random key is assigned automatically.

In [None]:
for key in store.keys():
    shard = store.get(key)
    for sample in shard:
        print(sample)
        break
    break

In [None]:
tmpdir.cleanup()

## Second approach: Iterstream API

`SquirrelStore` does not buffer any data: as soon as `set()` is called, the data is written to the store. Because of this, writing to the store is easy to parallelize. In the following example, we use `async_map` from the `iterstream` module to write shards to the store in parallel and then read them back in parallel.

In [None]:
from squirrel.iterstream import IterableSource

tmpdir = tempfile.TemporaryDirectory()
store = MessagepackDriver(tmpdir.name).store

IterableSource(shards).async_map(store.set).join()
# assert len(list(store.keys())) == 10

samples = IterableSource(store.keys()).async_map(store.get).flatten().collect()
# assert len(samples) == 100
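
Because each write stores its shard immediately and independently, concurrent writers need no coordination. The stdlib-only sketch below illustrates that idea with a thread pool standing in for `async_map`. `write_shard` and `read_shard` are hypothetical helpers and do not reflect squirrel's internals.

```python
import json
import tempfile
import uuid
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path


def write_shard(root, shard):
    """Write one shard to its own uniquely named file and return the key."""
    key = uuid.uuid4().hex
    (Path(root) / key).write_text(json.dumps(shard))
    return key


def read_shard(root, key):
    """Read one shard back by key."""
    return json.loads((Path(root) / key).read_text())


demo_shards = [[{"id": 10 * s + j} for j in range(10)] for s in range(10)]

with tempfile.TemporaryDirectory() as root, ThreadPoolExecutor(max_workers=4) as pool:
    # Each write touches a different file, so the calls can run concurrently.
    written_keys = list(pool.map(lambda shard: write_shard(root, shard), demo_shards))
    # Reads are independent too; flatten the shards back into samples.
    loaded = [
        sample
        for shard in pool.map(lambda key: read_shard(root, key), written_keys)
        for sample in shard
    ]
```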

In [None]:
tmpdir.cleanup()

## Reading and writing to the store using Dask

Scaling out with Dask is as easy as replacing `async_map(store.set)` with `async_map(store.set, executor=dask.distributed.Client())` in the example above.

## Reading and writing to the store using Spark

Squirrel stores can also be read and written from Spark jobs. To illustrate this, we first create a Squirrel store and write some data to it, then read from this store into a Spark DataFrame, and finally write the DataFrame back into another store.

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

In [None]:
tmpdir = tempfile.TemporaryDirectory()
driver = MessagepackDriver(tmpdir.name)
store = driver.store

In [None]:
def get_sample(i):
    return {
        "name": np.random.choice(["John", "Jane"]),
        "identifier": int(np.random.choice([1, 2])),
        "age": int(np.random.choice([20, 30])),
    }


samples = [get_sample(i) for i in range(100)]
# Step by 10 so that the shards are consecutive and do not overlap.
shards = [samples[i : i + 10] for i in range(0, len(samples), 10)]

IterableSource(shards).async_map(store.set).join()

In [None]:
# assert len(list(store.keys())) == 10

In [None]:
spark = SparkSession.builder.appName("test").getOrCreate()

In [None]:
SCHEMA = StructType(
    [
        StructField("name", StringType(), False),
        StructField("identifier", IntegerType(), False),  # samples store `identifier` as an int
        StructField("age", IntegerType(), False),
    ]
)

parallel_collection_rdd = spark.sparkContext.parallelize(driver.get_iter())
df = spark.createDataFrame(parallel_collection_rdd, SCHEMA)
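
By default, `createDataFrame` verifies each record against the declared schema, so the field types have to match the Python values in the samples. The stdlib-only sketch below mirrors that check in plain Python. `EXPECTED_TYPES` and `find_type_mismatches` are hypothetical names, and the type mapping is an assumption based on the sample generator above.

```python
# Assumed Python types for each field, based on how the samples are generated.
EXPECTED_TYPES = {"name": str, "identifier": int, "age": int}


def find_type_mismatches(record, expected_types):
    """Return (field, expected, actual) tuples for missing or wrongly typed fields."""
    problems = []
    for field, expected in expected_types.items():
        if field not in record:
            problems.append((field, expected.__name__, "missing"))
        elif not isinstance(record[field], expected):
            problems.append((field, expected.__name__, type(record[field]).__name__))
    return problems


good = {"name": "Jane", "identifier": 1, "age": 30}
bad = {"name": "John", "identifier": "1", "age": 30}

good_problems = find_type_mismatches(good, EXPECTED_TYPES)
bad_problems = find_type_mismatches(bad, EXPECTED_TYPES)
```

Running such a check on a few samples before building the DataFrame surfaces a schema mismatch without starting a Spark job.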

In [None]:
tmpdir.cleanup()
tmpdir = tempfile.TemporaryDirectory()
driver = MessagepackDriver(tmpdir.name)
store = driver.store

In [None]:
from functools import partial


def save_iterable_as_shard(_it, _url) -> None:
    """Helper to save one partition as a shard in a messagepack store using squirrel."""
    SquirrelStore(_url, serializer=MessagepackSerializer()).set(value=list(_it))


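num_shards = 10
# Use repartition rather than coalesce: coalesce without a shuffle cannot increase the
# number of partitions, so it yields fewer than num_shards shards when the RDD starts
# with fewer partitions. Each partition is written as one shard.
_ = (
    df.rdd.map(lambda row: row.asDict())
    .repartition(num_shards)
    .foreachPartition(partial(save_iterable_as_shard, _url=tmpdir.name))
)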

In [None]:
# assert len(list(store.keys())) == 10