[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/merantix/mxlabs-datasets/blob/main/examples/Squirrel_Tutorial_Plugins.ipynb)

# Catalog API

## Install Squirrel and Squirrel Datasets

Uncomment and run the following lines if you are working on Google Colab:

In [None]:
# from google.colab import auth

# auth.authenticate_user()

Uncomment and run the following lines if you have not installed squirrel and squirrel-datasets already:

In [None]:
# !pip install keyring keyrings.google-artifactregistry-auth
# !pip install squirrel-core squirrel-datasets --extra-index=https://europe-west1-python.pkg.dev/mx-labs-devops/labs-pypi-registry/simple/ --ignore-requires-python

For this tutorial, we also need matplotlib:

In [None]:
!pip install matplotlib

## Squirrel-datasets Catalog

Squirrel-datasets comes with built-in data sources that you can use.
Let's check them out:

In [None]:
from squirrel.catalog import Catalog


# loads a catalog that collects datasets from installed plug-ins such as squirrel-datasets
cat = Catalog.from_plugins()
print(list(cat.sources))

That is quite a lot, right?

A Catalog is a collection of sources.
It maintains these sources and keeps track of every version of each source.
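To make the versioning idea concrete, here is a minimal pure-Python sketch of a versioned collection. `ToyCatalog` and all of its methods are hypothetical names used for illustration only; this is not how squirrel implements `Catalog`, it merely mimics the behaviour described above, where plain indexing yields the latest version.

```python
class ToyCatalog:
    """Hypothetical stand-in for a catalog: identifier -> {version: source}."""

    def __init__(self):
        self._sources = {}

    def add(self, identifier, source, version=None):
        versions = self._sources.setdefault(identifier, {})
        if version is None:
            # Default to the next free version number, starting at 1.
            version = max(versions, default=0) + 1
        versions[version] = source
        return version

    def versions(self, identifier):
        # All known version numbers of one source, in ascending order.
        return sorted(self._sources[identifier])

    def get(self, identifier, version=None):
        versions = self._sources[identifier]
        if version is None:
            # No version requested: fall back to the latest one.
            version = max(versions)
        return versions[version]

    def __getitem__(self, identifier):
        return self.get(identifier)


toy = ToyCatalog()
toy.add("my_dataset", {"note": "first version"})
toy.add("my_dataset", {"note": "second version"})
latest = toy["my_dataset"]  # the entry stored as version 2
```

Keeping one dictionary of versions per identifier is only one possible design; it is chosen here because it keeps the "latest version by default" rule to a single line.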

To get more information about a Source, you can simply index the Catalog:

In [None]:
cat["cifar10"]

The Catalog returned the latest version of CIFAR-10 that it stores.
Note that:
- a driver called `"torchvision"` will be used to read from this source
- the driver will be passed the keyword arguments `name="CIFAR10"` and `download=True`
- no metadata was provided for the dataset
- the latest version stored in the catalog is `v2`
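The first two points can be pictured as a lookup followed by a call: the driver name selects a driver class, and the keyword arguments are forwarded to it. The sketch below uses hypothetical names (`ToySource`, `ToyDriver`, `DRIVER_REGISTRY`) and a dummy driver that loads nothing; it illustrates the idea under those assumptions and is not squirrel's actual source or driver code.

```python
from dataclasses import dataclass, field


class ToyDriver:
    """Dummy driver that only remembers the keyword arguments it received."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


# Hypothetical registry mapping driver names to driver classes.
DRIVER_REGISTRY = {"toy": ToyDriver}


@dataclass
class ToySource:
    """Hypothetical description of a source: which driver, with which arguments."""

    driver_name: str
    driver_kwargs: dict = field(default_factory=dict)
    metadata: dict = field(default_factory=dict)

    def get_driver(self, **extra_kwargs):
        # Look up the driver class by name, then forward the stored kwargs.
        driver_cls = DRIVER_REGISTRY[self.driver_name]
        return driver_cls(**{**self.driver_kwargs, **extra_kwargs})


source = ToySource(
    driver_name="toy",
    driver_kwargs={"name": "CIFAR10", "download": True},
)
driver = source.get_driver()
```

Storing only a name and arguments, rather than a live driver object, keeps a source description lightweight until data is actually needed.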

As mentioned earlier, a data source can have multiple versions.
For example, we can keep the raw dataset as version 1 and store a cleansed copy of it as version 2.

To see all versions of a source, we can use `Catalog.get_versions()`.
The cell below instead counts the versions of `cifar10` and indexes each one by its version number:

In [None]:
print(len(cat["cifar10"]))
cat["cifar10"][1], cat["cifar10"][2]

You can see that to load CIFAR-10, we have two options: one using the HuggingfaceDriver and one using the TorchvisionDriver.

Let's load data from both. To do that, we will instantiate the drivers with the help of the Catalog, and then load some samples using the `iterstream` API (if you haven't already, have a look at the `Iterstream Tutorial`).

Note that different drivers expect different keyword arguments:

In [None]:
source_hg = cat["cifar10"][1]  # source for version 1 (Huggingface)
source_tv = cat["cifar10"][2]  # source for version 2 (Torchvision)

N = 3  # just for demonstration, we don't need a lot of samples
# instantiate each driver from its source, then take the first N samples
samples_hg = source_hg.get_driver().get_iter(split="train").take(N).collect()
samples_tv = source_tv.get_driver().get_iter().take(N).collect()
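If the chained calls look unfamiliar, a rough standard-library analogy may help. It assumes that `take(N)` keeps only the first `N` items of the stream and that `collect()` materializes them into a list. The generator `fake_stream` is a hypothetical stand-in for a driver's iterator, not part of squirrel.

```python
from itertools import islice


def fake_stream():
    """Hypothetical stand-in for a driver iterator; yields samples lazily."""
    i = 0
    while True:
        yield {"id": i}
        i += 1


N = 3
# islice stops after N items, so the endless stream is never exhausted.
first_n = list(islice(fake_stream(), N))
```

Because the stream is consumed lazily, only the first `N` samples are ever produced, which is why taking a handful of samples from a large dataset stays cheap.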

The format of the returned samples also depends on the driver:

In [None]:
(
    samples_hg[0],  # a dictionary with keys "img" and "label"
    samples_tv[0],  # a tuple of (image, label id)
)
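When downstream code should not care which driver produced a sample, a small adapter can bring both layouts into one shape. The helper `to_image_and_label` below is a hypothetical convenience function, not part of squirrel, and it assumes exactly the two layouts noted in the comments above. Placeholder strings stand in for real image objects.

```python
def to_image_and_label(sample):
    """Return (image, label) for either of the two sample layouts."""
    if isinstance(sample, dict):
        # Huggingface-style sample: {"img": ..., "label": ...}
        return sample["img"], sample["label"]
    # Torchvision-style sample: (image, label id)
    image, label = sample
    return image, label


# Placeholder samples; real ones would hold image objects instead of strings.
hg_like = {"img": "image-a", "label": 6}
tv_like = ("image-b", 9)
pairs = [to_image_and_label(s) for s in (hg_like, tv_like)]
```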

That's it! Let's plot the samples returned by each driver:

In [None]:
import matplotlib.pyplot as plt

fig = plt.figure()
subfigs = fig.subfigures(nrows=2, ncols=1)
subfigs[0].suptitle("Huggingface")
subfigs[1].suptitle("Torchvision")

# plot huggingface
axs = subfigs[0].subplots(nrows=1, ncols=N)
for col, ax in enumerate(axs):
    sample = samples_hg[col]
    ax.imshow(sample["img"])
    label = f"Class: {sample['label']}"
    ax.set_title(label)

# plot torchvision
axs = subfigs[1].subplots(nrows=1, ncols=N)
for col, ax in enumerate(axs):
    sample = samples_tv[col]
    ax.imshow(sample[0])
    label = f"Class: {sample[1]}"
    ax.set_title(label)
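The titles above show numeric label ids. Assuming both versions follow the standard CIFAR-10 class order (an assumption worth checking against the dataset documentation), a small lookup turns an id into a readable name. `CIFAR10_CLASSES` and `label_name` are illustrative names, not part of squirrel.

```python
# Standard CIFAR-10 class order (assumed to match both drivers).
CIFAR10_CLASSES = (
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)


def label_name(label_id):
    """Map a numeric CIFAR-10 label id to its class name."""
    return CIFAR10_CLASSES[int(label_id)]


example_title = f"Class: {label_name(6)}"
```

Passing `label_name(sample['label'])` or `label_name(sample[1])` to `ax.set_title` would then give named titles in the plots.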

## End
You are up to speed!

If you want to learn more, check out the `Plugins Tutorial` to see how you can implement and register a new plugin, which will extend the sources provided by squirrel-datasets.
You can also refer to the API reference to discover more information such as implementation details.