In [None]:
# import
from progeval import ProgEval

# How-to guide

Let's consider the following computation:

![toy computation](toy-computation.svg)

## Dynamically construct graph
The most straightforward way to define the computational graph is to assign attributes to a `ProgEval` object.
If a callable object is assigned, it is automatically interpreted as specifying how to compute the quantity of that name.
The quantities it depends on are detected from the argument names of the function.
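Signature-based dependency detection can be illustrated with the standard library's `inspect` module. The sketch below uses a hypothetical helper `dependencies`; it shows the idea only and is not necessarily how `ProgEval` implements it.

```python
import inspect

def dependencies(func):
    """Return the argument names of `func`, i.e. the quantities it reads.

    Hypothetical helper illustrating signature-based dependency detection.
    """
    return list(inspect.signature(func).parameters)

def compute_alpha(x):
    return 2 * x + 1

gamma = lambda alpha, beta, y: alpha * beta - y

print(dependencies(compute_alpha))  # ['x']
print(dependencies(gamma))          # ['alpha', 'beta', 'y']
```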

In [None]:
def compute_alpha(x):
    print(f'computing alpha = 2 * {x} + 1')
    return 2 * x + 1

def compute_beta(y):
    print(f'computing beta = {y} * {y}')
    return y * y

In [None]:
graph = ProgEval()

graph.alpha = compute_alpha
graph.beta = compute_beta
# any callable object works
graph.gamma = lambda alpha, beta, y: alpha * beta - y

Having constructed the graph, we can set the input values and compute the outputs:

In [None]:
graph.x, graph.y = 3, 4
graph.gamma

If we now request intermediate values, they are not computed again (note that no message is printed)!

In [None]:
graph.beta
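To make the caching behavior concrete, here is a minimal sketch of a lazy graph that computes a quantity on first access and stores the result. `LazyGraph` is a hypothetical class written for illustration only; it is not ProgEval's implementation and deliberately omits invalidation of dependent values.

```python
import inspect

class LazyGraph:
    """Hypothetical minimal lazy, caching computational graph (not ProgEval)."""

    def __init__(self):
        # use object.__setattr__ so internal state bypasses our own __setattr__
        object.__setattr__(self, '_funcs', {})
        object.__setattr__(self, '_cache', {})

    def __setattr__(self, name, value):
        if callable(value):
            self._funcs[name] = value  # a node: how to compute `name`
        else:
            self._cache[name] = value  # an input value

    def __getattr__(self, name):
        # only called when normal attribute lookup fails
        if name in self._cache:
            return self._cache[name]
        if name not in self._funcs:
            raise AttributeError(name)
        func = self._funcs[name]
        # recursively resolve the arguments by name
        args = [getattr(self, arg) for arg in inspect.signature(func).parameters]
        value = func(*args)
        self._cache[name] = value  # later requests are served from the cache
        return value


calls = []

def compute_beta(y):
    calls.append('beta')
    return y * y

g = LazyGraph()
g.beta = compute_beta
g.gamma = lambda beta, y: beta - y
g.y = 4
print(g.gamma)  # 12
print(g.beta)   # 16, taken from the cache
print(calls)    # ['beta']: beta was computed only once
```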

### Evaluate everything
We can evaluate and collect all quantities in the graph by invoking `compute_all_quantities`:

In [None]:
graph.compute_all_quantities()
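Evaluating every quantity amounts to visiting the nodes in an order where each one comes after its dependencies. A minimal sketch using the standard library's `graphlib.TopologicalSorter` (Python 3.9+) follows; `compute_all` is a hypothetical helper, not ProgEval's API.

```python
import inspect
from graphlib import TopologicalSorter  # Python 3.9+

def compute_all(nodes, inputs):
    """Evaluate every node once, in dependency order.

    `nodes` maps quantity names to functions, `inputs` maps input names to values.
    Hypothetical helper for illustration.
    """
    deps = {name: list(inspect.signature(f).parameters) for name, f in nodes.items()}
    values = dict(inputs)
    # static_order() yields each name only after all of its dependencies
    for name in TopologicalSorter(deps).static_order():
        if name in nodes:  # inputs such as x and y are already in `values`
            values[name] = nodes[name](*(values[d] for d in deps[name]))
    return values

nodes = {
    'alpha': lambda x: 2 * x + 1,
    'beta': lambda y: y * y,
    'gamma': lambda alpha, beta, y: alpha * beta - y,
}
result = compute_all(nodes, {'x': 3, 'y': 4})
print(result['gamma'])  # 108, i.e. 7 * 16 - 4
```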

### Clear cache
By removing all saved intermediate values, we can force the computational graph to be recomputed in full.

In [None]:
graph.clear_cache()
graph.compute_all_quantities()

### Changing the computational graph
When we overwrite input values, only the quantities that depend on the changed inputs are recomputed (note that no message is printed for `alpha`):

In [None]:
graph.y = 8
graph.gamma
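One way to decide which cached values become stale when an input changes is to follow the dependency edges in reverse. The sketch below assumes the dependencies of the toy computation are written out by hand as a dictionary; `affected` is a hypothetical helper, not ProgEval's internal mechanism. Changing `y` invalidates `beta` and `gamma` but leaves `alpha` cached, consistent with the behavior above.

```python
from collections import deque

def affected(deps, changed):
    """Return all quantities that (transitively) depend on `changed`.

    `deps` maps each quantity to the names it reads. Hypothetical helper.
    """
    # invert the edges: for each name, which quantities read it?
    readers = {}
    for node, args in deps.items():
        for arg in args:
            readers.setdefault(arg, set()).add(node)
    # breadth-first search over the inverted edges
    stale, queue = set(), deque([changed])
    while queue:
        for reader in readers.get(queue.popleft(), ()):
            if reader not in stale:
                stale.add(reader)
                queue.append(reader)
    return stale

deps = {'alpha': ['x'], 'beta': ['y'], 'gamma': ['alpha', 'beta', 'y']}
print(sorted(affected(deps, 'y')))  # ['beta', 'gamma']
print(sorted(affected(deps, 'x')))  # ['alpha', 'gamma']
```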

Besides re-assigning values to the inputs, we can also change the structure of the graph itself.

In [None]:
graph.gamma = lambda alpha, beta: alpha - beta
graph.gamma

### Disabling recomputation
If recomputation is not desired for any reason, it can be disabled by specifying `track_dependence=False`.
In that case, the graph no longer records which quantities each computation requests, so changing an input does not invalidate values that were already computed from it.

In [None]:
graph = ProgEval(track_dependence=False)

graph.alpha = compute_alpha
graph.beta = compute_beta
graph.gamma = lambda alpha, beta, y: alpha * beta - y

graph.x, graph.y = 3, 4
graph.gamma

In [None]:
graph.y = 8
graph.gamma  # returns the previously cached value, unchanged

### Specifying input arguments
Above, the inputs to the node functions were derived from their call signatures.
Alternatively, their names can be passed explicitly to `register`; here, both arguments of `prod` receive `y`.

In [None]:
def prod(a, b):
    return a * b

graph = ProgEval()
graph.register('beta', prod, ['y', 'y'])

graph.y = 5
graph.beta
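The effect of explicit argument names can be sketched with a hypothetical helper `call_node` (not part of ProgEval): if a list of names is given, it takes precedence over the function's own parameter names, which lets one value feed several parameters.

```python
import inspect

def call_node(func, values, arg_names=None):
    """Call `func` with values looked up by name.

    Explicit `arg_names` override the function's parameter names.
    Hypothetical helper for illustration.
    """
    if arg_names is None:
        arg_names = list(inspect.signature(func).parameters)
    return func(*(values[name] for name in arg_names))

def prod(a, b):
    return a * b

print(call_node(prod, {'y': 5}, ['y', 'y']))  # 25: both a and b receive y
print(call_node(prod, {'a': 2, 'b': 3}))      # 6: names taken from the signature
```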

## Define computations as classes

Instead of defining computational graphs by assigning nodes to a `ProgEval` object, we can also define a new class that represents the computation.
This can be convenient for two reasons:
1. All functions/quantities are in one place and are registered automatically.
2. We can easily specify all input values and efficiently create the corresponding graph.

The only thing we need to do is to sub-class `ProgEval`.

In [None]:
class MyComputation(ProgEval):
    
    # this says the function below does not have a `self` argument
    @staticmethod
    def alpha(x):
        print(f'computing alpha = 2 * {x} + 1')
        return 2 * x + 1

    @staticmethod
    def beta(y):
        print(f'computing beta = {y} * {y}')
        return y * y
    
    @staticmethod
    def gamma(y, alpha, beta):
        print(f'computing gamma = {alpha} * {beta} - {y}')
        return alpha * beta - y

In [None]:
comp = MyComputation()

To evaluate it, we must assign the input values `x` and `y`:

In [None]:
comp.x, comp.y = 5, 3
comp.gamma

The structure can be made even cleaner by passing the inputs of the computation as arguments to the constructor.

In [None]:
class MyComputation(ProgEval):
    
    def __init__(self, x, y):
        super().__init__(x=x, y=y)
    
    @staticmethod
    def alpha(x):
        print(f'computing alpha = 2 * {x} + 1')
        return 2 * x + 1

    @staticmethod
    def beta(y):
        print(f'computing beta = {y} * {y}')
        return y * y
    
    @staticmethod
    def gamma(y, alpha, beta):
        print(f'computing gamma = {alpha} * {beta} - {y}')
        return alpha * beta - y

In [None]:
MyComputation(5, 3).gamma

In [None]:
MyComputation(4, 5).gamma

In this setting, dependency tracking may not be required, since one would simply create a new instance with different inputs instead of replacing `x` and `y`.
It can be turned off by declaring the class as `class MyComputation(ProgEval, track_dependence=False)`.
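Passing keywords next to the base class is standard Python: class keywords are forwarded to the base class's `__init_subclass__` hook. The sketch below uses a hypothetical `Base` class (not `ProgEval`) to show how a base class could receive such an option and collect a subclass's static methods as nodes; whether ProgEval works exactly this way is an assumption.

```python
class Base:
    """Hypothetical stand-in for a graph base class."""

    def __init_subclass__(cls, track_dependence=True, **kwargs):
        super().__init_subclass__(**kwargs)
        cls.track_dependence = track_dependence
        # collect the static methods defined on the subclass as graph nodes
        cls.nodes = {
            name: attr.__func__
            for name, attr in vars(cls).items()
            if isinstance(attr, staticmethod)
        }


class DemoComputation(Base, track_dependence=False):
    @staticmethod
    def alpha(x):
        return 2 * x + 1

    @staticmethod
    def beta(y):
        return y * y


print(DemoComputation.track_dependence)  # False
print(sorted(DemoComputation.nodes))     # ['alpha', 'beta']
```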

### Accessing quantities as attributes of ``self``

In the above examples, the dependencies of quantities were made explicit by the arguments the functions take.
It is also possible to use regular methods instead of static methods, i.e. methods that access quantities as attributes of `self`.
However:

```{eval-rst}
.. warning::
    If a method accesses computational quantities as attributes of ``self`` (instead of taking them as explicit arguments), the dependencies in the computational graph currently cannot be tracked.
    This means that quantities are not properly recomputed when intermediate values change.
    This is only a problem if the computational graph is modified after it was created, i.e. if nodes are replaced or deleted.
```
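This limitation follows from detecting dependencies through argument names: the signature of a regular method only names `self`, so nothing in it reveals which quantities the body reads. A small illustration with a hypothetical `Example` class:

```python
import inspect

class Example:
    def gamma(self):
        return self.alpha * self.beta - self.y

    @staticmethod
    def gamma_static(y, alpha, beta):
        return alpha * beta - y

# the plain method's signature only names `self` ...
print(list(inspect.signature(Example.gamma).parameters))         # ['self']
# ... whereas the static version lists every dependency
print(list(inspect.signature(Example.gamma_static).parameters))  # ['y', 'alpha', 'beta']
```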

In [None]:
class MyComputation(ProgEval):
    
    def __init__(self, x, y):
        super().__init__(x=x, y=y)

    def alpha(self):
        print(f'computing alpha = 2 * {self.x} + 1')
        return 2 * self.x + 1

    def beta(self):
        print(f'computing beta = {self.y} * {self.y}')
        return self.y * self.y
    
    def gamma(self):
        print(f'computing gamma = {self.alpha} * {self.beta} - {self.y}')
        return self.alpha * self.beta - self.y

In [None]:
MyComputation(5, 3).gamma

In [None]:
MyComputation(4, 5).gamma

## Advanced: transforming functions
It is possible to specify an optional `transformer` when constructing the computational graph, which can modify the node functions before they are added.
It must take three arguments: `transformer(function, static, name)`.
The first is the function used to compute the quantity called `name`.
`static` is a boolean that is `False` if the function takes `self` as its first argument.

The transformer should return a function with the same signature. If it changes the signature, it must instead return a tuple of the transformed function and the new signature, given as an `inspect.Signature` instance.
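As a sketch of both cases, the two hypothetical transformers below follow the `transformer(function, static, name)` convention described above; they ignore `static` and assume plain positional-or-keyword parameters. `log_calls` wraps the function with `functools.wraps`, which sets `__wrapped__` so that `inspect.signature` still reports the original parameters. `scaled` adds a dependency on a hypothetical quantity `scale` and therefore also returns the new `inspect.Signature`.

```python
import functools
import inspect

def log_calls(fun, static, name):
    """Signature-preserving transformer: print each evaluation."""
    @functools.wraps(fun)  # sets __wrapped__, so inspect.signature reports fun's parameters
    def wrapper(*args, **kwargs):
        print(f'evaluating {name}')
        return fun(*args, **kwargs)
    return wrapper

def scaled(fun, static, name):
    """Signature-changing transformer: multiply the result by a quantity `scale`."""
    sig = inspect.signature(fun)
    new_sig = sig.replace(parameters=[
        *sig.parameters.values(),
        inspect.Parameter('scale', inspect.Parameter.POSITIONAL_OR_KEYWORD),
    ])

    def wrapper(*args, **kwargs):
        bound = new_sig.bind(*args, **kwargs)
        scale = bound.arguments.pop('scale')
        return scale * fun(**bound.arguments)
    return wrapper, new_sig


def beta(y):
    return y * y

logged = log_calls(beta, True, 'beta')
print(list(inspect.signature(logged).parameters))  # ['y']
print(logged(4))  # prints "evaluating beta", then 16

fun, sig = scaled(beta, True, 'beta')
print(list(sig.parameters))  # ['y', 'scale']
print(fun(4, 10))            # 160
```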

Below are two examples of how this can be used.
They require [JAX](https://github.com/google/jax/) and [Dask](https://docs.dask.org/) to be installed, respectively.

### Just in time compilation with JAX

In [None]:
import jax

In [None]:
def jit_if_static(fun, static, name):
    # only jit compile if the function doesn't depend on self
    if static:
        return jax.jit(fun)
    return fun


class Computation(ProgEval, transformer=jit_if_static):
    
    def __init__(self, x, y):
        super().__init__(x=x, y=y)
        
    @staticmethod
    def alpha(x, y):
        return jax.numpy.trace(x @ x) * jax.numpy.trace(y)
    
    @staticmethod
    def beta(x, y, alpha):
        return jax.numpy.trace(x @ y) * alpha
    
    @staticmethod
    def total(alpha, beta):
        return (alpha + beta) / alpha.size

The above construction only makes a noticeable difference if the individual functions are sufficiently costly.
Another useful pattern with JAX is to define efficient functions for just the parts of the computational graph we are interested in, without repeating code.

In [None]:
@jax.jit
def compute_alpha(x, y):
    return Computation(x, y).alpha

@jax.jit
def compute_beta(x, y):
    return Computation(x, y).beta

@jax.jit
def compute_total(x, y):
    return Computation(x, y).total

In [None]:
rng = jax.random.PRNGKey(0)
x, y = jax.random.normal(rng, (2, 32, 32))

%timeit compute_alpha(x, y).block_until_ready()

In [None]:
%timeit compute_beta(x, y).block_until_ready()

In [None]:
%timeit compute_total(x, y).block_until_ready()

Comparing this with a manual implementation, we see that the construction via the computational graph has virtually no cost after jit-compilation.

In [None]:
@jax.jit
def computation(x, y):
    alpha = jax.numpy.trace(x @ x) * jax.numpy.trace(y)
    beta = jax.numpy.trace(x @ y) * alpha
    total = (alpha + beta) / alpha.size
    return total

In [None]:
%timeit computation(x, y).block_until_ready()
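The low overhead comes from how `jax.jit` works: the Python body of a jitted function runs only while it is being traced, and later calls with arguments of the same shape and dtype reuse the compiled computation. The snippet below checks this with a counter in a hypothetical function `double`, independent of ProgEval; it requires JAX.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # counts how often the Python body actually runs

@jax.jit
def double(v):
    global trace_count
    trace_count += 1  # Python side effect: only happens during tracing
    return 2 * v

arr = jnp.ones((32, 32))
double(arr).block_until_ready()
double(arr).block_until_ready()  # same shape and dtype: reuses the compiled function
print(trace_count)  # 1
```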

### Generating Dask delayed objects

In [None]:
import dask

In [None]:
def delay(fun, _, name):
    # wrap each node function with dask.delayed, named after its quantity
    return dask.delayed(fun, name)

comp = ProgEval(transformer=delay)

def inc(a):
    return a + 1

def add(a, b):
    return a + b

comp.x = 5
comp.y = 3
comp.register('a', inc, ['x'])
comp.register('b', inc, ['y'])
comp.register('total', add)  # arguments `a` and `b` are taken from the signature

# comp.total.visualize()  # requires graphviz
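With this transformer, `comp.total` should be a Dask delayed object rather than a number, so nothing has been computed yet. For reference, the snippet below builds the equivalent delayed graph directly with Dask, without ProgEval; it is a sketch of what the transformed graph corresponds to, not output produced by ProgEval. Calling `.compute()` evaluates it.

```python
import dask

def inc(a):
    return a + 1

def add(a, b):
    return a + b

# each call to a delayed function returns a lazy Delayed object
a = dask.delayed(inc, 'a')(5)
b = dask.delayed(inc, 'b')(3)
total = dask.delayed(add, 'total')(a, b)

print(total.compute())  # 10; nothing is evaluated before this call
```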