[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/merantix/mxlabs-datasets/blob/main/examples/Squirrel_Tutorial_Create_Squirrel_Store.ipynb)

In [None]:
# !pip install keyring keyrings.google-artifactregistry-auth
# from google.colab import auth
# auth.authenticate_user()

In [None]:
# !pip install mxlabs-squirrel squirrel-datasets dask --extra-index=https://europe-west1-python.pkg.dev/mx-labs-devops/labs-pypi-registry/simple/ --ignore-requires-python --upgrade

If you have not already, refer to the documentation page for `Store` to cover the basics first.

## Reading and writing to the store using Spark

Squirrel's drivers and stores can be called from inside a distributed framework, which lets a data workload scale out across workers.
To illustrate this using Spark, we:

1. Create a data source

2. Initialize a Driver that can read from the data source

3. Construct an RDD in Spark from the data loaded using the Driver

4. Create a DataFrame from the RDD and write data into shards using SquirrelStore

Let's first create some dummy data and save it into a .csv file.

In [None]:
import tempfile

import numpy as np
import pandas as pd


N_SAMPLES = 1_000


def create_sample():
    return {
        "name": np.random.choice(["John", "Jane"]),
        "identifier": int(np.random.choice([1, 2])),
        "age": int(np.random.choice([20, 30])),
    }


samples = [create_sample() for _ in range(N_SAMPLES)]

tmpdir = tempfile.TemporaryDirectory()
csv_path = f"{tmpdir.name}/my_source.csv"
pd.DataFrame(samples).to_csv(csv_path, index=False)

Now, we can read the source using the CsvDriver.

In [None]:
from squirrel.driver import CsvDriver

driver = CsvDriver(csv_path)
# Each row carries an "Index" field, which we use as the DataFrame index.
df = pd.DataFrame(driver.get_iter().collect()).set_index("Index")
df.head()

We will convert the data into an RDD first.
For this we need to provide a schema.
Note that the schema is the only data-source-specific part in the pipeline.
As long as we provide the correct schema, we can use the pipeline with any driver or store.
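
Since the schema is the one part that has to match the data source, it can help to check a few rows against it before starting the Spark job. The following is a minimal plain-Python sketch of that idea, assuming rows arrive as tuples whose fields line up with the schema by position. `CsvRow`, `EXPECTED_FIELDS`, and `check_row` are hypothetical names used for illustration and are not part of Squirrel or PySpark.

```python
from collections import namedtuple

# Hypothetical stand-in for one row yielded by the driver: a named tuple whose
# first field is the row index.
CsvRow = namedtuple("CsvRow", ["Index", "name", "identifier", "age"])

# Plain-Python mirror of the schema: (field name, accepted Python types), in order.
# "identifier" is generated as an int but declared as a string column, so this
# sketch tolerates both.
EXPECTED_FIELDS = [
    ("index", (int,)),
    ("name", (str,)),
    ("identifier", (int, str)),
    ("age", (int,)),
]


def check_row(row, expected=EXPECTED_FIELDS):
    """Return a list of problems found in `row`; an empty list means it fits."""
    if len(row) != len(expected):
        return [f"expected {len(expected)} fields, got {len(row)}"]
    problems = []
    for value, (name, types) in zip(row, expected):
        if value is None:
            # Every field in the schema is declared with nullable=False.
            problems.append(f"{name}: None is not allowed (nullable=False)")
        elif not isinstance(value, types):
            problems.append(f"{name}: got {type(value).__name__}")
    return problems


good_problems = check_row(CsvRow(0, "Jane", 1, 20))
bad_problems = check_row(CsvRow(1, "John", 2, "thirty"))
print(good_problems)  # []
print(bad_problems)   # ['age: got str']
```

The check compares values by position rather than by field name, which is why the differing capitalisation of `Index` and `index` does not matter in this sketch.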

In [None]:
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

SCHEMA = StructType(
    [
        StructField("index", IntegerType(), False),
        StructField("name", StringType(), False),
        StructField("identifier", StringType(), False),
        StructField("age", IntegerType(), False),
    ]
)

In [None]:
from pyspark.sql import SparkSession


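spark = SparkSession.builder.appName("test").getOrCreate()
# Distribute the rows read by the driver as an RDD.
parallel_collection_rdd = spark.sparkContext.parallelize(driver.get_iter())
# Note: this rebinds `df` from the pandas DataFrame above to a Spark DataFrame.
df = spark.createDataFrame(parallel_collection_rdd, SCHEMA)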

The Spark DataFrame is ready.
Now we can write the data into shards.
Here we opt for the SquirrelStore, which is the store used by the MessagepackDriver.

In [None]:
from functools import partial

from squirrel.serialization import MessagepackSerializer
from squirrel.store import SquirrelStore


def save_iterable_as_shard(it, url) -> None:
    """Helper to save a shard into a messagepack store using squirrel."""
    SquirrelStore(url, serializer=MessagepackSerializer()).set(value=list(it))


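tmpdir2 = tempfile.TemporaryDirectory()
N_SHARDS = 10

# coalesce() only ever reduces the number of partitions, so N_SHARDS is an upper
# bound on the number of shards; repartition(N_SHARDS) would request exactly N_SHARDS.
# foreachPartition() passes each partition to the helper as an iterator, so every
# partition is written as one shard. It returns None, hence the throwaway `_`.
_ = (
    df.rdd.map(lambda row: row.asDict())
    .coalesce(N_SHARDS)
    .foreachPartition(partial(save_iterable_as_shard, url=tmpdir2.name))
)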
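
To make the partition-to-shard mapping in the cell above more concrete, here is a rough plain-Python analogue that runs without Spark or Squirrel. It is only a sketch under stated assumptions: JSON files stand in for messagepack shards, and `split_into_partitions` and `save_partition_as_json_shard` are hypothetical helpers that mimic what `coalesce` and `foreachPartition` do conceptually, not how they are implemented.

```python
import json
import tempfile
import uuid
from functools import partial
from pathlib import Path


def save_partition_as_json_shard(it, url):
    """Write all rows of one partition into a single, randomly named shard file."""
    rows = list(it)
    shard_path = Path(url) / f"{uuid.uuid4().hex}.json"
    shard_path.write_text(json.dumps(rows))


def split_into_partitions(rows, n_partitions):
    """Distribute rows round-robin over n_partitions lists."""
    return [rows[i::n_partitions] for i in range(n_partitions)]


# Dummy rows shaped like the dictionaries produced by row.asDict().
rows = [{"index": i, "name": "Jane", "identifier": "1", "age": 20} for i in range(25)]

with tempfile.TemporaryDirectory() as shard_dir:
    writer = partial(save_partition_as_json_shard, url=shard_dir)
    # Each partition is handed to the writer as an iterator, one call per shard.
    for partition in split_into_partitions(rows, 4):
        writer(iter(partition))

    shard_files = sorted(Path(shard_dir).glob("*.json"))
    n_shards = len(shard_files)
    n_rows = sum(len(json.loads(p.read_text())) for p in shard_files)

print(n_shards, n_rows)  # 4 25
```

The point of the sketch is that the writer never sees the whole dataset: it receives one partition at a time and turns it into one file, so the number of partitions determines the number of shards.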

Finally, clean up the temporary directories.

In [None]:
tmpdir.cleanup()
tmpdir2.cleanup()