This tutorial demonstrates preprocessing with Spark in squirrel.
To run it, please make sure that [Apache Spark](https://spark.apache.org/) and [pyspark](https://spark.apache.org/docs/latest/api/python/getting_started/install.html) are installed.
Installation instructions for Spark can be found, for example, [here for Ubuntu](https://phoenixnap.com/kb/install-spark-on-ubuntu) or [here for Mac](https://medium.com/beeranddiapers/installing-apache-spark-on-mac-os-ce416007d79f).
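Before running the cells below, it can help to confirm that the setup is visible from the notebook's Python environment. The following is a quick, heuristic sanity check using only the standard library; `check_spark_setup` is a hypothetical helper and not part of squirrel or pyspark. It assumes Spark is reachable via the `spark-submit` launcher on `PATH` and that Java is found via `java` on `PATH` (Spark runs on the JVM); setups that rely only on `JAVA_HOME` may report `java` as missing even though they work.

```python
import importlib.util
import shutil
from typing import Dict


def check_spark_setup() -> Dict[str, bool]:
    """Heuristically report whether the pieces a local Spark setup needs are visible."""
    return {
        # pyspark importable from this Python environment (checked without importing it)
        "pyspark": importlib.util.find_spec("pyspark") is not None,
        # Spark's launcher script on PATH (a pip-installed pyspark may also provide one)
        "spark-submit": shutil.which("spark-submit") is not None,
        # Spark runs on the JVM, so a Java runtime is required
        "java": shutil.which("java") is not None,
    }


for name, found in check_spark_setup().items():
    print(f"{name}: {'found' if found else 'missing'}")
```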

In [None]:
try:
    import squirrel
    import squirrel_datasets_core
    import numpy as np
    import matplotlib.pyplot as plt
except ImportError:
    # install the missing packages inside the notebook, then retry all imports
    !pip install -q --ignore-requires-python --upgrade squirrel-datasets-core numpy matplotlib # noqa
    import squirrel
    import squirrel_datasets_core
    import numpy as np
    import matplotlib.pyplot as plt

print(squirrel.__version__)
print(squirrel_datasets_core.__version__)

Any squirrel `Composable` can be used as input to a preprocessing pipeline.
In this example, we use the `TorchvisionDriver` with the `CIFAR10` dataset.
The `get_spark` function provides convenient access to a Spark session.

In [None]:
from squirrel.catalog import Catalog
from squirrel_datasets_core.spark import get_spark

it = Catalog.from_plugins()["cifar10"].get_driver().get_iter()
spark_session = get_spark("preprocess-cifar")

The function `save_composable_to_shards` processes the data and saves it in the fast MessagePack format.
The `hooks` parameter takes a list of functions that transform each sample. Here, as an example, we simply convert the PIL image returned by the `TorchvisionDriver` to a numpy array.
The output can be written to local disk, as in this case, or directly to a Google Cloud Storage bucket.
The number of shards for the MessagePack output is specified as well.
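To make the role of `hooks` concrete, here is a minimal pure-Python sketch of how a list of per-sample transformation functions is typically composed. It assumes the hooks are applied to every sample in list order, each one receiving the output of the previous one; whether `save_composable_to_shards` does exactly this is an assumption. The helper `apply_hooks`, the toy hooks, and the toy `(image, label)` samples are hypothetical and contain no real CIFAR-10 data.

```python
from typing import Any, Callable, Iterable, Iterator, List

Sample = Any
Hook = Callable[[Sample], Sample]


def apply_hooks(samples: Iterable[Sample], hooks: List[Hook]) -> Iterator[Sample]:
    """Hypothetical helper: apply each hook, in list order, to every sample."""
    for sample in samples:
        for hook in hooks:
            sample = hook(sample)  # each hook sees the previous hook's output
        yield sample


# Toy hooks operating on (image, label) pairs, mirroring map_image_to_np's shape
def scale_pixels(sample):
    image, label = sample
    return [p / 255 for p in image], label


def relabel(sample):
    image, label = sample
    return image, f"class-{label}"


toy = [([0, 255], 3), ([51, 102], 7)]
processed = list(apply_hooks(toy, [scale_pixels, relabel]))
print(processed)  # [([0.0, 1.0], 'class-3'), ([0.2, 0.4], 'class-7')]
```

Because each hook only needs to accept and return one sample, hooks such as `map_image_to_np` stay small and can be reused or reordered independently.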

In [None]:
from squirrel_datasets_core.preprocessing.save_shards import save_composable_to_shards

local_store = "cifar_local"
num_shards = 10


def map_image_to_np(sample):
    return np.array(sample[0]), sample[1]


save_composable_to_shards(it, spark_session, local_store, num_shards, hooks=[map_image_to_np])
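The `num_shards` argument controls how many files the dataset is split into. The sketch below illustrates one common way of distributing samples over a fixed number of shards, round-robin assignment. This is an assumption for illustration only: how `save_composable_to_shards` actually partitions the data with Spark may differ, and `split_into_shards` is a hypothetical helper.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_into_shards(samples: Sequence[T], num_shards: int) -> List[List[T]]:
    """Hypothetical helper: distribute samples round-robin over num_shards lists."""
    if num_shards < 1:
        raise ValueError("num_shards must be at least 1")
    shards: List[List[T]] = [[] for _ in range(num_shards)]
    for i, sample in enumerate(samples):
        shards[i % num_shards].append(sample)  # sample i goes to shard i mod num_shards
    return shards


shards = split_into_shards(list(range(23)), 10)
print([len(s) for s in shards])  # [3, 3, 3, 2, 2, 2, 2, 2, 2, 2]
```

With round-robin assignment, shard sizes differ by at most one sample, so the work of reading shards in parallel stays balanced.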

As the data has been processed and saved locally, it can now be loaded using the squirrel `MessagepackDriver`:

In [None]:
from squirrel.driver.msgpack import MessagepackDriver

it_msgpack = MessagepackDriver(local_store).get_iter()
sample = it_msgpack.take(1).collect()[0]

plt.title(f"Class: {sample[1]}")
plt.imshow(sample[0])

Below, we compare the time needed to load the full dataset with the `TorchvisionDriver` and with the `MessagepackDriver`; `tqdm` reports the throughput in iterations per second (it/s).
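`tqdm` shows the throughput interactively. If you prefer the numbers as values, for example to compare both drivers programmatically, a minimal standard-library sketch could look like this. `measure_throughput` is a hypothetical helper; it only assumes that the object passed in is iterable, which squirrel iterators are.

```python
import time
from typing import Iterable, Tuple


def measure_throughput(samples: Iterable) -> Tuple[int, float]:
    """Hypothetical helper: consume `samples` fully, return (item count, items per second)."""
    start = time.perf_counter()
    count = 0
    for _ in samples:
        count += 1
    elapsed = time.perf_counter() - start
    # guard against a zero elapsed time on very small, very fast iterables
    rate = count / elapsed if elapsed > 0 else float("inf")
    return count, rate


n, rate = measure_throughput(range(100_000))
print(f"{n} items at {rate:,.0f} it/s")
```

Applied to the notebook objects, a call such as `measure_throughput(it_msgpack)` would iterate over the whole MessagePack store once. Caching effects (e.g., the operating system's file cache) can make a repeated run faster than the first one.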

In [None]:
# measure time to load full dataset (it/s) with torchvision driver (default in squirrel catalog)
it.tqdm().collect()

In [None]:
# measure time to load full dataset (it/s) with messagepack driver from local store
it_msgpack.tqdm().collect()