# Introduction to Dask-checkpoint

Dask-checkpoint is a Python package
that adds customizable caching capabilities to [dask](https://dask.org).
It builds on top of `dask.delayed`,
adding load and save instructions
to the dask graph.

```python
from dask_checkpoint import Storage, task
```

## The task decorator

A task can be created from a function using `task` as a decorator.

For readers familiar with `dask.delayed`,
`task` is built on `dask.delayed(pure=True)` internally.

```python
@task
def double(x: float) -> float:
    return 2 * x
```

Calling the function doesn't perform the computation immediately,
but returns a `dask.delayed` object:

```python
double(21)
```

```
Delayed('double/3d00a6e20b85ee3517c62a761f774f72')
```

To obtain the result,
you must call the `.compute` method.

```python
double(21).compute()
```

```
42
```
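To make the `pure=True` idea concrete, here is a minimal sketch that uses only the standard library. It builds a deferred task whose key depends only on the function name and its arguments, so identical calls get identical keys. This is not how `dask.delayed` or `dask_checkpoint` build their keys. `lazy_task`, `Deferred`, and the MD5-of-`repr` token are hypothetical and chosen only for illustration.

```python
import functools
import hashlib
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Deferred:
    """Hypothetical stand-in for a delayed call: nothing runs until compute()."""

    func: Callable[..., Any]
    args: tuple
    kwargs: tuple  # sorted (name, value) pairs, so keyword order does not matter

    @property
    def key(self) -> str:
        # Hypothetical key scheme: function name + MD5 of the arguments' repr.
        # This is only meaningful for pure functions, whose output depends
        # on their inputs alone.
        token = hashlib.md5(repr((self.args, self.kwargs)).encode()).hexdigest()
        return f"{self.func.__name__}/{token}"

    def compute(self) -> Any:
        return self.func(*self.args, **dict(self.kwargs))


def lazy_task(func: Callable[..., Any]) -> Callable[..., Deferred]:
    """Hypothetical decorator: calling the wrapped function only records the call."""

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Deferred:
        return Deferred(func, args, tuple(sorted(kwargs.items())))

    return wrapper


@lazy_task
def double(x: float) -> float:
    return 2 * x


d = double(21)
print(d.key)  # "double/" followed by 32 hex digits
print(d.compute())  # 42
```

Because the key is derived only from the call, it can also serve as a storage key. Hashing the `repr` is only reliable for arguments whose `repr` reflects their full content.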

## Storing results to disk

So far,
a task behaves like a `dask.delayed` object.
In addition,
it can save and load results
to avoid recomputing expensive tasks,
even across different Python sessions.

To do that,
we first need to create a `Storage` object,
which accepts any object implementing a `MutableMapping[str, bytes]` interface.
Alternatively,
`Storage.from_fsspec` accepts a path or URL `str`,
which is passed to `fsspec` to create an `fsspec.FSMap`.
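`Storage` accepts any `MutableMapping[str, bytes]`, so a custom backend only needs the five abstract methods of `collections.abc.MutableMapping`. The sketch below is a hypothetical `DirectoryMapping` that stores each value as a file. It is not part of `dask_checkpoint`. For real use, `Storage.from_fsspec` is the documented route.

```python
from __future__ import annotations

from collections.abc import Iterator, MutableMapping
from pathlib import Path


class DirectoryMapping(MutableMapping[str, bytes]):
    """Hypothetical MutableMapping[str, bytes] storing one file per key.

    Keys such as "double/3d00..." contain "/", so they map to subdirectories.
    Keys are not validated (a sketch): "../x" would escape the root.
    """

    def __init__(self, root: str | Path) -> None:
        self.root = Path(root)
        self.root.mkdir(parents=True, exist_ok=True)

    def _path(self, key: str) -> Path:
        return self.root / key

    def __getitem__(self, key: str) -> bytes:
        path = self._path(key)
        if not path.is_file():
            raise KeyError(key)
        return path.read_bytes()

    def __setitem__(self, key: str, value: bytes) -> None:
        path = self._path(key)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_bytes(value)

    def __delitem__(self, key: str) -> None:
        path = self._path(key)
        if not path.is_file():
            raise KeyError(key)
        path.unlink()

    def __iter__(self) -> Iterator[str]:
        # Yield keys as POSIX-style paths relative to the root directory.
        for path in sorted(self.root.rglob("*")):
            if path.is_file():
                yield path.relative_to(self.root).as_posix()

    def __len__(self) -> int:
        return sum(1 for _ in self)
```

If `Storage` accepts it the same way it accepts a `dict`, it could be used as `Storage(DirectoryMapping("cache"))`.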

Here,
we will just use a plain Python `dict`:

```python
storage = Storage({})
```

To save (and then load) a task,
we must explicitly create it with `save=True`,
as we probably don't want to save tasks
that are cheap to compute but expensive to store.

```python
@task(save=True)
def double_with_print(x: float) -> float:
    out = 2 * x
    # The next line is just to demonstrate
    # that the function does not run
    # when it is loaded from the Storage
    print(f"Calculating: 2 * {x} = {out}")
    return out
```

Again,
calling the function does not actually run it:

```python
my_task = double_with_print(21)

my_task
```

```
Delayed('double_with_print/3d00a6e20b85ee3517c62a761f774f72')
```

Instead, it creates a `dask.delayed` object.

If we call `.compute()`,
the `print` call inside the function runs
and the result is returned:

```python
my_task.compute()
```

```
Calculating: 2 * 21 = 42
```

```
42
```

Now,
if we call `.compute()` inside a storage context manager,
that result will be saved to the storage:

```python
with storage():
    result = my_task.compute()

result
```

```
Calculating: 2 * 21 = 42
```

```
42
```

We can inspect the `data` attribute of the storage,
which in this case corresponds to a `dict`,
and see that the result was saved:

```python
storage.data
```

```
{'double_with_print/3d00a6e20b85ee3517c62a761f774f72': b'(\xb5/\xfd \x05)\x00\x00\x80\x05K*.'}
```

When we recompute the task inside the context manager,
we obtain the same result,
but we see no output from the `print` function,
because `double_with_print` was not actually run;
the result was retrieved from storage instead.

```python
with storage():
    result = my_task.compute()

result
```

```
42
```

Recomputing outside the context manager does run the function again:

```python
my_task.compute()
```

```
Calculating: 2 * 21 = 42
```

```
42
```

**Important:**
functions are expected to be **pure**,
that is,
their output depends only on their input parameters,
and have no side effects.

A function will not be called again
when its output is already stored.

Examples of non-pure functions:

- **Mutating input**: using in-place operations.
- **Non-deterministic output**: drawing random numbers, or relying on global variables.
- **Side effects**: updating global variables.

We've seen an example of a side effect:
the `print` function in the previous example
was not called when recomputing the task
(inside the storage context manager).

## Custom encoding

To store data,
results need to be encoded into a bytes representation.
By default,
`task` encodes data by
serializing with `cloudpickle` and
compressing with `zstandard`.

However,
`cloudpickle` is not appropriate for long-term storage,
as its output depends on the Python version used.

To customize the encoding,
`task` accepts any encoder implementing the `Encoder[T, bytes]` protocol:

```python
from typing import Protocol, TypeVar
T = TypeVar("T")
E = TypeVar("E")

class Encoder(Protocol[T, E]):
    def encode(self, value: T) -> E: ...
    def decode(self, value: E) -> T: ...
```
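As a sketch of the protocol, here is a hypothetical `JsonEncoder` that turns JSON-compatible values into UTF-8 bytes and back. It is not the library's `serializer.json`. It only shows that any object with matching `encode` and `decode` methods satisfies `Encoder[T, bytes]` structurally, without inheriting from anything. The protocol is repeated so the block is self-contained. A text format like JSON also avoids the Python-version dependence noted above for `cloudpickle`.

```python
import json
from typing import Any, Protocol, TypeVar

T = TypeVar("T")
E = TypeVar("E")


class Encoder(Protocol[T, E]):
    def encode(self, value: T) -> E: ...
    def decode(self, value: E) -> T: ...


class JsonEncoder:
    """Hypothetical Encoder[Any, bytes]: JSON text encoded as UTF-8."""

    def encode(self, value: Any) -> bytes:
        # sort_keys makes the byte output deterministic for dicts
        return json.dumps(value, sort_keys=True).encode("utf-8")

    def decode(self, value: bytes) -> Any:
        return json.loads(value.decode("utf-8"))


encoder: Encoder[Any, bytes] = JsonEncoder()  # accepted structurally by type checkers
data = encoder.encode({"y": 2, "x": 1})
print(data)  # b'{"x": 1, "y": 2}'
print(encoder.decode(data))  # {'x': 1, 'y': 2}
```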

We offer a customizable `DefaultEncoder`,
which accepts the following arguments:

```python
@dataclass(kw_only=True)
class DefaultEncoder:
    encoders: tuple[Encoder, ...] = ()
    serializer: Serializer | None = cloudpickle
    compressor: Compressor | None = zstandard
    encrypter: Encrypter | None = None
```

where

| type       | implements           |
|------------|----------------------|
| Encoder    | encode, decode       |
| Serializer | dumps, loads         |
| Compressor | compress, decompress |
| Encrypter  | encrypt, decrypt     |

The `DefaultEncoder` applies the following transformations:

`encoders[0] -> ... -> encoders[-1] -> serializer -> compressor -> encrypter`

when encoding,
and in reverse order when decoding.
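This ordering can be sketched as a pair of methods. The sketch is a hypothetical stand-in for `DefaultEncoder`, not its implementation. It uses the standard-library `pickle` (with `dumps`/`loads`) and `zlib` (with `compress`/`decompress`) in place of the defaults, and it skips encryption. `SetEncoder` is a toy example of a step that could go in `encoders`.

```python
import pickle
import zlib
from dataclasses import dataclass
from typing import Any


class SetEncoder:
    """Toy pre-serialization step: sets become sorted lists (and lists become sets)."""

    def encode(self, value: Any) -> Any:
        return sorted(value) if isinstance(value, set) else value

    def decode(self, value: Any) -> Any:
        return set(value) if isinstance(value, list) else value


@dataclass
class PipelineEncoder:
    """Hypothetical sketch of the DefaultEncoder ordering (no encrypter)."""

    encoders: tuple[Any, ...] = ()
    serializer: Any = pickle
    compressor: Any = zlib  # None disables compression

    def encode(self, value: Any) -> bytes:
        for encoder in self.encoders:  # encoders[0] -> ... -> encoders[-1]
            value = encoder.encode(value)
        data = self.serializer.dumps(value)
        if self.compressor is not None:
            data = self.compressor.compress(data)
        return data

    def decode(self, data: bytes) -> Any:
        # Undo each step in reverse order.
        if self.compressor is not None:
            data = self.compressor.decompress(data)
        value = self.serializer.loads(data)
        for encoder in reversed(self.encoders):
            value = encoder.decode(value)
        return value


pipeline = PipelineEncoder(encoders=(SetEncoder(),))
blob = pipeline.encode({3, 1, 2})
print(pipeline.decode(blob))  # {1, 2, 3}
```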

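Many serializers, compressors and encrypters already implement these interfaces
(for instance, the standard-library `pickle` module provides `dumps` and `loads`,
and `zlib` provides `compress` and `decompress`),
so you can often pass a module, class or object directly to `DefaultEncoder()`.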

For instance,
[`numcodecs`](https://numcodecs.readthedocs.io/en/stable/)
implements several encoders
that might be useful to apply before serializing.

In the submodule `serializer`,
some commonly used serializers are already included:

```python
from dask_checkpoint import DefaultEncoder, serializer
```

For the following function,
we will create 3 task variants,
with different serializers,
and no compression:

```python
def point_as_dict(x, y):
    return dict(x=x, y=y)


# compressor=None disables compression, so the stored bytes stay
# human-readable for text-based serializers such as JSON and YAML
@task(save=True, encoder=DefaultEncoder(compressor=None))
def point_as_cloudpickle(x, y):
    return point_as_dict(x, y)


@task(save=True, encoder=DefaultEncoder(serializer=serializer.json, compressor=None))
def point_as_json(x, y):
    return point_as_dict(x, y)


@task(save=True, encoder=DefaultEncoder(serializer=serializer.yaml, compressor=None))
def point_as_yaml(x, y):
    return point_as_dict(x, y)


with storage():
    point_as_cloudpickle(x=1, y=2).compute()
    point_as_json(x=1, y=2).compute()
    point_as_yaml(x=1, y=2).compute()
```

If we inspect the storage,
we can see the different outputs:

```python
storage.data
```

```
{'double_with_print/3d00a6e20b85ee3517c62a761f774f72': b'(\xb5/\xfd \x05)\x00\x00\x80\x05K*.',
 'point_as_cloudpickle/b6e4f13ef66a0d1496c9561971b0b66c': b'\x80\x05\x95\x11\x00\x00\x00\x00\x00\x00\x00}\x94(\x8c\x01x\x94K\x01\x8c\x01y\x94K\x02u.',
 'point_as_json/b6e4f13ef66a0d1496c9561971b0b66c': b'{"x": 1, "y": 2}',
 'point_as_yaml/b6e4f13ef66a0d1496c9561971b0b66c': b'x: 1\ny: 2\n'}
```

## Custom output naming

By default,
an output's name is given by `f"{task.name}{arguments_hash}"`.

The `task` decorator accepts a `name` argument
to customize its name,
either by passing a `str`,
or a callable which receives the underlying function `func`.
By default,
it uses a callable which returns `func.__name__ + "/"`.

The hash of the arguments can also be customized,
passing a function to the `hasher` parameter of `task`.

```python
from collections.abc import Callable


def my_function_namer(func: Callable) -> str:
    return f"{func.__name__}-"


def my_hasher(kwargs: dict) -> str:
    return str(kwargs["x"])


@task(name=my_function_namer, hasher=my_hasher)
def double(x):
    return 2 * x


double(21)
```

```
Delayed('double-21')
```
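`my_hasher` above is deliberately simple. It ignores every argument except `x`, and it maps `21` and `"21"` to the same name. A sturdier hasher could hash a canonical serialization of all keyword arguments. `json_md5_hasher` below is a hypothetical example that assumes JSON-serializable arguments and, like `my_hasher`, receives the arguments as a `dict`.

```python
import hashlib
import json
from typing import Any


def json_md5_hasher(kwargs: dict[str, Any]) -> str:
    """Hypothetical hasher: MD5 of the arguments serialized as canonical JSON."""
    # sort_keys makes the hash independent of argument order, and separators
    # drop optional whitespace. Non-JSON-serializable values raise TypeError.
    # Note that JSON maps tuples and lists to the same text.
    canonical = json.dumps(kwargs, sort_keys=True, separators=(",", ":"))
    return hashlib.md5(canonical.encode("utf-8")).hexdigest()


print(json_md5_hasher({"x": 21}) == json_md5_hasher({"x": "21"}))  # False
print(json_md5_hasher({"x": 1, "y": 2}) == json_md5_hasher({"y": 2, "x": 1}))  # True
```

It would be passed the same way as `my_hasher`, e.g. `@task(hasher=json_md5_hasher)`. MD5 is used here only for naming, not for security.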

## Combined storages

There are two ways to combine storages:

#### Chaining storages

To join multiple storages,
we can use `Storage.from_chain`
to create a combined `MutableMapping`:

```python
storage_local = Storage.from_fsspec("/local_folder")
storage_remote = Storage.from_fsspec("ssh://user@server/home/user/remote_folder")
storage_combined = Storage.from_chain(storage_local, storage_remote)

with storage_combined():
    ...
```

In this case,
`storage_combined` will try to load in order,
first from `local_folder`
and then from `remote_folder`.
If it needs to save a task,
it will be saved to the first one (`storage_local`).

*Note: under the hood, it uses a `collections.ChainMap` to join them.*
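The lookup and write rules of `collections.ChainMap` account for the behaviour described above. Reads search the mappings in order, while writes and deletions only affect the first mapping. In this sketch, two plain `dict`s stand in for the local and remote storages.

```python
from collections import ChainMap

local: dict[str, bytes] = {}
remote: dict[str, bytes] = {"double/abc": b"from remote"}

combined = ChainMap(local, remote)

print(combined["double/abc"])  # b'from remote' (found in the second mapping)

combined["double/xyz"] = b"new result"  # writes always go to the first mapping
print("double/xyz" in local, "double/xyz" in remote)  # True False

local["double/abc"] = b"from local"
print(combined["double/abc"])  # b'from local' (the first mapping wins)
```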

#### Nested storages

Alternatively, we can simply nest the context managers:

```python
with storage_remote():
    # Tasks computed here load from and save to remote only
    my_task.compute()

    with storage_local():
        # Tasks computed here load from local or remote (in that order)
        # and save to local
        my_task.compute()

    with storage_local(nested=False):
        # Tasks computed here use local only, ignoring remote
        my_task.compute()
```

## Read more at dask.org

As `task` works on top of `dask.delayed`,
it is useful to check out dask's documentation:

- Delayed: https://docs.dask.org/en/stable/delayed.html
- Delayed Collections: https://docs.dask.org/en/stable/delayed-collections.html
- Delayed Best Practices: https://docs.dask.org/en/stable/delayed-best-practices.html
- General Best Practices: https://docs.dask.org/en/stable/best-practices.html

Many tips discussed there apply to `task` too,
in particular the following points,
which were discussed above:

- Don’t mutate inputs
- Avoid global state
- Don’t rely on side effects