[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/merantix/mxlabs-datasets/blob/main/examples/Squirrel_Tutorial_Plugins.ipynb)

# Install Squirrel and Squirrel Datasets

In [None]:
# Please uncomment and run the following cells if needed

In [None]:
# !pip install keyring keyrings.google-artifactregistry-auth
# from google.colab import auth
# auth.authenticate_user()

In [None]:
# !pip install squirrel-core squirrel-datasets --extra-index-url=https://europe-west1-python.pkg.dev/mx-labs-devops/labs-pypi-registry/simple/ --ignore-requires-python --upgrade

In [None]:
# !pip install networkx --ignore-requires-python --upgrade

# The Squirrel Plugin System

Squirrel uses the amazing [pluggy](https://pluggy.readthedocs.io/en/latest/) library to provide extensibility.
Currently, plugins can provide two things: custom Drivers, which add support for new data formats, and Sources, which are added to the default Catalog.
There are multiple ways of adding plugins.
See [here](https://pluggy.readthedocs.io/en/stable/#the-plugin-registry) to explore the various possibilities.
In this tutorial you will learn how to write custom Drivers and Sources and add them to Squirrel. Let's go!

# Catalog
A Catalog is a dictionary-like data structure that lets you add and remove `Source` objects.
A `Source` object has an attribute `driver_name`, which must match the `name` class attribute of a driver class (explained below).
Let's see an example:

In [None]:
from squirrel.catalog import Catalog, Source

driver_name = "my_driver_name"
identifier = "my_identifier"
version = 1

cat = Catalog()

In [None]:
cat[identifier, version] = Source(driver_name=driver_name, driver_kwargs={}, metadata={})

In [None]:
cat[identifier, version]
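Conceptually, the catalog behaves like a dictionary keyed by `(identifier, version)` pairs. The sketch below is a hypothetical, simplified model: `ToyCatalog`, `ToySource` and the `versions()` helper are illustrative names, not Squirrel's API. It shows how several versions of one identifier can coexist until one is explicitly deleted.

```python
from dataclasses import dataclass, field


@dataclass
class ToySource:
    """Hypothetical stand-in for a catalog entry (not squirrel.catalog.Source)."""
    driver_name: str
    driver_kwargs: dict = field(default_factory=dict)
    metadata: dict = field(default_factory=dict)


class ToyCatalog:
    """Hypothetical dict-like catalog keyed by (identifier, version) tuples."""

    def __init__(self):
        self._entries = {}

    def __setitem__(self, key, source):
        identifier, version = key  # keys are always (identifier, version)
        self._entries[(identifier, version)] = source

    def __getitem__(self, key):
        return self._entries[key]

    def __delitem__(self, key):
        del self._entries[key]

    def versions(self, identifier):
        # All versions stored for one identifier, in ascending order
        return sorted(v for (i, v) in self._entries if i == identifier)


toy = ToyCatalog()
toy["my_identifier", 1] = ToySource(driver_name="my_driver_name")
toy["my_identifier", 2] = ToySource(driver_name="my_driver_name", metadata={"note": "v2"})
print(toy.versions("my_identifier"))  # [1, 2]
del toy["my_identifier", 1]  # only an explicit delete removes a version
print(toy.versions("my_identifier"))  # [2]
```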

Each entry in the catalog is versioned, and all versions are kept unless explicitly deleted.
The information in a single catalog entry is all we need to instantiate a driver and start training.
But first, a `Driver` must be associated with this entry.
Squirrel ships with several useful drivers, and also lets you write your own driver and register it using `pluggy`.
Your driver should inherit from `squirrel.driver.Driver` or one of its subclasses, such as `IterDriver` or `MapDriver`.

In [None]:
from squirrel.driver import Driver
from squirrel.iterstream import IterableSource


class MyAwesomeDriver(Driver):

    name = "my_driver_name"

    def __init__(self, catalog=None):
        super().__init__(catalog)

    def get_iter(self):
        return IterableSource(range(10))

At this point, we have implemented our driver and declared it in the catalog.
To connect the two, we register the driver with Squirrel's `register_driver` function.

In [None]:
from squirrel.framework.plugins.plugin_manager import register_driver

register_driver(MyAwesomeDriver)

In [None]:
cat[identifier, version].get_driver().get_iter().collect()

Note that the class attribute `name` in the driver must be present, and must be the same as `driver_name` of the `Source`. 
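The matching rule can be pictured as a registry that indexes driver classes by their `name` attribute and resolves a source's `driver_name` against it. The sketch below is a hypothetical, stdlib-only model of that idea (`BaseDriver`, `register`, `driver_for` and `RangeDriver` are illustrative names); it is not how `register_driver` is implemented, which relies on `pluggy`.

```python
class BaseDriver:
    """Hypothetical minimal driver base class (not squirrel.driver.Driver)."""
    name = None


_DRIVERS = {}  # hypothetical registry: driver name -> driver class


def register(driver_cls):
    """Index a driver class by its `name` attribute, rejecting unnamed drivers."""
    if not getattr(driver_cls, "name", None):
        raise ValueError(f"{driver_cls.__name__} must define a non-empty `name`")
    _DRIVERS[driver_cls.name] = driver_cls
    return driver_cls


def driver_for(driver_name):
    """Resolve a catalog entry's `driver_name` to a registered driver class."""
    try:
        return _DRIVERS[driver_name]
    except KeyError:
        raise KeyError(f"no driver registered under {driver_name!r}") from None


@register
class RangeDriver(BaseDriver):
    name = "my_driver_name"  # must equal the Source's driver_name

    def get_iter(self):
        return iter(range(10))


print(list(driver_for("my_driver_name")().get_iter()))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

A missing or mismatched `name` fails at lookup time, which is why the `name` attribute and `driver_name` have to agree exactly.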

# Custom Drivers

Let's implement a custom driver for graph data.
Suppose we have a huge graph stored in a distributed manner, and a microservice that accesses this storage engine and returns random walks on the graph.
The `GraphService` below simulates this, using `sleep()` to mimic the latency of a remote call.
Our graph service randomly selects a node in the graph and returns a list of `num` random walks of length `length`, all starting from this node.

In [None]:
import random
from time import sleep

import networkx as nx


class GraphService:
    """Mock service that samples random walks from an example graph."""

    def __init__(self, size):
        # A complete graph: every node is connected to every other node
        self.g = nx.complete_graph(size)

    def multi_walks(self, length, num):
        sleep(0.1)  # simulate the latency of a remote call
        node = random.choice(list(self.g.nodes()))
        return [self.random_walk(node, length) for _ in range(num)]

    def random_walk(self, node, length):
        # A walk of `length` nodes; each step moves to a random neighbor
        walk = [node]
        n = node
        for _ in range(length - 1):
            neighbors = list(self.g.neighbors(n))
            next_node = random.choice(neighbors)
            walk.append(next_node)
            n = next_node
        return walk
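For readers without `networkx`, the same sampling logic can be sketched with the standard library alone. This is an illustrative, dependency-free approximation of `GraphService` (the function names are hypothetical), and it makes the output shape explicit: `num` walks of `length` nodes each, all starting at the same randomly chosen node.

```python
import random


def complete_graph_adjacency(size):
    """Neighbor lists of a complete graph on nodes 0..size-1 (no self-loops)."""
    return {n: [m for m in range(size) if m != n] for n in range(size)}


def random_walk(adjacency, start, length):
    """Walk of `length` nodes; each step moves to a uniformly random neighbor."""
    walk = [start]
    for _ in range(length - 1):
        walk.append(random.choice(adjacency[walk[-1]]))
    return walk


def multi_walks(adjacency, length, num):
    """Pick one random start node and sample `num` walks from it."""
    start = random.choice(list(adjacency))
    return [random_walk(adjacency, start, length) for _ in range(num)]


random.seed(0)  # fixed seed so repeated runs give the same walks
adj = complete_graph_adjacency(10)
walks = multi_walks(adj, length=5, num=3)
print(walks)
```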

## Our custom driver

`GraphDriver` inherits from `IterDriver`. Its `get_iter` method maps each of `num_samples` indices to one call of `GraphService.multi_walks`.

In [None]:
from squirrel.driver import IterDriver


class GraphDriver(IterDriver):

    name = "graph_driver"

    def __init__(self, catalog=None, size=10):
        super().__init__(catalog)  # pass the catalog on, as in MyAwesomeDriver
        self.graph_service = GraphService(size)

    def get_iter(self, num_samples, length, num):
        return IterableSource(range(num_samples)).map(lambda i: self.graph_service.multi_walks(length, num))

In [None]:
# `size` must be passed by keyword: the first positional parameter is `catalog`
driver = GraphDriver(size=10)

for i in driver.get_iter(1, 2, 3):
    print(i)
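`get_iter` returns a lazy stream: `IterableSource(...).map(...)` describes the computation, and `multi_walks` is called as items are pulled, for example by a `for` loop or `.collect()`. As a rough stdlib analogy (not Squirrel's implementation), Python's built-in `map` behaves the same way, with `list()` playing the role of `.collect()`; `fake_multi_walks` is a hypothetical stand-in.

```python
calls = []


def fake_multi_walks(i):
    """Hypothetical stand-in for GraphService.multi_walks that records each call."""
    calls.append(i)
    return [[i, i + 1]]


stream = map(fake_multi_walks, range(3))  # lazy: no call has happened yet
print(len(calls))  # 0
result = list(stream)  # like .collect(): pull every item through the pipeline
print(len(calls))  # 3
```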

In [None]:
cat["gd"] = Source(driver_name="graph_driver")
register_driver(GraphDriver)

In [None]:
cat["gd"]

Note that in this case we did not specify a version.
Squirrel automatically assigns version 1 if the source does not yet exist in the catalog.
We can now use the new driver through the Catalog API.
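The default-version rule can be sketched as a tiny, hypothetical helper over a plain dict (`entries` and `add_source` are illustrative names, not Squirrel's API). The first call mirrors the behavior described above. The toy goes one step further and assumes that re-adding an existing identifier creates version `latest + 1`; that part is an assumption for illustration, not documented Squirrel behavior.

```python
# Hypothetical toy store: {(identifier, version): driver_name}
entries = {}


def add_source(identifier, driver_name):
    """New identifier -> version 1; existing identifier -> latest + 1 (assumed)."""
    existing = [v for (i, v) in entries if i == identifier]
    version = max(existing, default=0) + 1
    entries[(identifier, version)] = driver_name
    return version


print(add_source("gd", "graph_driver"))  # 1
print(add_source("gd", "graph_driver"))  # 2, under the toy's assumed rule
```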

In [None]:
cat["gd"].get_driver().get_iter(1, 2, 3).collect()

# How to share Catalogs, Drivers, and Sources

We suggest sharing your Drivers and Sources in different ways, depending on the scope:

## Share with outside collaborators

Publish your Drivers and Sources using [entry points](https://pluggy.readthedocs.io/en/stable/#loading-setuptools-entry-points) in pluggy.
The squirrel-datasets package is our reference implementation.
You can see all available drivers with `squirrel.framework.plugins.plugin_manager.list_driver_names()`.

## Share within your project

Use the Python API to define your Catalog in your package, and register your Drivers using the `squirrel.framework.plugins.plugin_manager` module.

## Automatic pipelines

For CI4ML pipelines, Squirrel supports sharing Catalogs as YAML files.
Have a look at `squirrel.catalog.Catalog.to_file()` and `squirrel.catalog.Catalog.from_dirs()` to get started.

# Improve performance

In this example, replacing `map` with `async_map` in the `get_iter` method is all it takes to gain more than an order of magnitude in speed.
Under the hood, Squirrel uses multithreading to parallelize the IO-bound calls to the mock `GraphService`.
This is especially useful when interfacing with remote data sources such as object stores, databases, or HTTP endpoints.
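The speedup comes from overlapping waits: while one thread sleeps (or waits on the network), others can make progress. The sketch below reproduces the effect with only the standard library, using `concurrent.futures.ThreadPoolExecutor` as a stand-in; it is an analogy for what `async_map` does, not Squirrel's code, and `slow_io` is a hypothetical placeholder for `GraphService.multi_walks`.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def slow_io(i):
    """Hypothetical IO-bound call: mostly waiting, very little computation."""
    time.sleep(0.05)
    return i * i


def timed(fn):
    """Run fn() and return (result, elapsed seconds)."""
    start = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - start


items = range(20)
seq_result, seq_time = timed(lambda: list(map(slow_io, items)))
with ThreadPoolExecutor(max_workers=20) as pool:
    # pool.map runs the calls concurrently but yields results in input order
    par_result, par_time = timed(lambda: list(pool.map(slow_io, items)))

print(f"sequential: {seq_time:.2f}s, threaded: {par_time:.2f}s")
```

Threads help here because the work is dominated by waiting; for CPU-bound functions in CPython, the GIL would limit this kind of gain.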

In [None]:
class BetterGraphDriver(IterDriver):

    name = "better_graph_driver"

    def __init__(self, catalog=None, size=10):
        super().__init__(catalog)  # pass the catalog on, as in MyAwesomeDriver
        self.graph_service = GraphService(size)

    def get_iter(self, num_samples, length, num):
        return IterableSource(range(num_samples)).async_map(
            lambda i: self.graph_service.multi_walks(length, num), buffer=1000
        )

In [None]:
# Pass `size` by keyword; a positional 100 would be taken as `catalog`
gd = GraphDriver(size=100)
bgd = BetterGraphDriver(size=100)

In [None]:
%%timeit

for i in gd.get_iter(10 * 3, 10, 10):
    pass

In [None]:
%%timeit

for i in bgd.get_iter(10 * 3, 10, 10):
    pass