In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# Eliminating for-loops that have carry-over using `lax.scan`

We are now going to see how we can eliminate for-loops that have carry-over using `lax.scan`.

From the JAX docs, `lax.scan` replaces a for-loop with carry-over:

> Scan a function over leading array axes while carrying along state.
> 
> ...
> 
> ```python
> def scan(f, init, xs, length=None):
>     if xs is None:
>         xs = [None] * length
>     carry = init
>     ys = []
>     for x in xs:
>         carry, y = f(carry, x)
>         ys.append(y)
>     return carry, np.stack(ys)
> ```

A key requirement is that the function `f` must accept exactly two positional arguments: one for `carry` and one for `x`. You'll see how we can apply `functools.partial` (or a closure) to construct functions with this signature from other functions that have other signatures.
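As a minimal sketch of the two-argument requirement, the snippet below binds an extra parameter with `functools.partial` so that what remains is a function of `(carry, x)`. The function name `grow` and the values `1.5` and `1.0` are hypothetical and chosen only for illustration.

```python
from functools import partial

import jax.numpy as np
from jax import lax


def grow(interest_factor, prev_wealth, time):
    # Hypothetical three-argument function. The extra parameter comes first
    # so that partial can bind it positionally.
    new_wealth = prev_wealth * interest_factor
    return new_wealth, new_wealth  # (carry, value stacked into the output)


# Bind interest_factor, leaving a function with the (carry, x) signature.
step = partial(grow, 1.5)

final, history = lax.scan(step, 1.0, np.arange(3))
```

Here `step(carry, x)` is what `lax.scan` calls on each iteration. The `time` argument is unused, but it still has to be present to match the expected signature.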

Let's see some concrete examples of this in action.

## Updating a variable with new info on each loop iteration

One classic case where we might use a for-loop is in the cumulative sum or product. Here, we need the current loop information to update the information from the previous loop. Let's see it in action for the cumulative sum:

In [None]:
import jax.numpy as np
a = np.array([1, 2, 3, 5, 7, 11, 13, 17])

result = []
res = 0
for el in a:
    res += el
    result.append(res)
np.array(result)

This is identical to the cumulative sum:

In [None]:
np.cumsum(a)

Now, let's write it using `lax.scan`, so we can see the pattern in action:

In [None]:
from jax import lax
def scanfunc(res, el):
    res = res + el
    return res, res  # ("carryover", "accumulated")

result_init = 0
final, result = lax.scan(scanfunc, result_init, a)
result

As you can see, the scanned function has to return two things:

- One object that gets carried over to the next loop (`carryover`), and
- Another object that gets "accumulated" into an array (`accumulated`).

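The initial value, `result_init`, is passed into `scanfunc` as `res` on the first call. On each subsequent call, the first value returned by the previous call (the carry-over) is passed back into `scanfunc` as the new `res`. Once the loop finishes, `lax.scan` returns the last carry-over (`final`) together with the stacked array of accumulated values (`result`).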
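The same pattern covers the cumulative product mentioned earlier. The following is a small sketch, with the hypothetical name `cumprod_step` and a shortened input array, that can be compared against `np.cumprod`:

```python
import jax.numpy as np
from jax import lax

a = np.array([1, 2, 3, 5, 7])


def cumprod_step(res, el):
    res = res * el
    return res, res  # (carry-over, value stacked into the output)


# The multiplicative identity, 1, plays the role that 0 plays for the sum.
final, result = lax.scan(cumprod_step, 1, a)

# Element-wise comparison against the built-in cumulative product.
same_as_cumprod = bool((result == np.cumprod(a)).all())
```

Only the update rule and the initial value change. The structure of the scanned function stays the same.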

## Example 1: Simulating compound interest

We can use `lax.scan` to generate data that simulates the generation of wealth by compound interest. Here's an implementation using a plain vanilla for-loop:

In [None]:
wealth_record = []
starting_wealth = 100.
interest_factor = 1.01

prev_wealth = starting_wealth
for t in range(100):
    new_wealth = prev_wealth * interest_factor
    wealth_record.append(prev_wealth)
    prev_wealth = new_wealth

np.array(wealth_record)

Now, we'll try implementing it in a `lax.scan` form:

In [None]:
from functools import partial

starting_wealth = 100.
interest_factor = 1.01

timesteps = np.arange(100)

def make_wealth_at_time_func(interest_factor):
    def wealth_at_time(prev_wealth, time):
        new_wealth = prev_wealth * interest_factor
        return new_wealth, prev_wealth
    return wealth_at_time

wealth_func = make_wealth_at_time_func(interest_factor)

final, result = lax.scan(wealth_func, init=starting_wealth, xs=timesteps)
result

The two are equivalent, as the overlapping curves in the plot below show, so we know we have the `lax.scan` implementation right.

In [None]:
import matplotlib.pyplot as plt
plt.plot(wealth_record, label="for-loop")
plt.plot(result, label="lax.scan")
plt.legend();
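Beyond comparing the curves by eye, the two implementations can also be compared numerically. This is a self-contained sketch that recomputes both versions. The relative tolerance `rtol=1e-4` is an assumption, chosen loosely because JAX uses 32-bit floats by default while the plain Python loop works in 64-bit.

```python
import jax.numpy as np
from jax import lax

starting_wealth = 100.
interest_factor = 1.01

# Plain for-loop version: record the wealth at the start of each step.
wealth_record = []
prev_wealth = starting_wealth
for t in range(100):
    wealth_record.append(prev_wealth)
    prev_wealth = prev_wealth * interest_factor


def wealth_at_time(prev_wealth, time):
    # Carry the updated wealth forward and record the wealth before the update.
    return prev_wealth * interest_factor, prev_wealth


final, result = lax.scan(wealth_at_time, starting_wealth, np.arange(100))

# Compare the two records element-wise within the chosen tolerance.
matches = bool(np.allclose(np.array(wealth_record), result, rtol=1e-4))
```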

## Example 2: Compose `vmap` and `lax.scan` together

That was one simulation of wealth generation by compound interest for one individual. Now, let's simulate wealth generation for 300 different starting wealth levels (you may choose the starting points however you'd like). To do so, start with a function that accepts a scalar starting wealth and generates the simulated time series from it, and then `vmap` that function across an array of starting points.

In [None]:
from jax import vmap
def make_simulation_func(timesteps):
    def inner(starting_wealth):
        final, result = lax.scan(wealth_func, init=starting_wealth, xs=timesteps)
        return final, result
    return inner

simulation_func = make_simulation_func(timesteps=np.arange(200))
starting_wealth = np.arange(300).astype(float)

final, growth = vmap(simulation_func)(starting_wealth)
growth

In [None]:
plt.plot(growth[1])
plt.plot(growth[2])
plt.plot(growth[3]);
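To see what `vmap` does to the outputs of a scanned simulation, it helps to inspect their shapes. This is a self-contained sketch with the hypothetical names `simulate` and `step` and a hard-coded interest factor of `1.01`. It is not the notebook's exact code.

```python
import jax.numpy as np
from jax import lax, vmap


def simulate(starting_wealth):
    def step(prev_wealth, time):
        # Carry the updated wealth and record the wealth before the update.
        return prev_wealth * 1.01, prev_wealth

    # Returns (final wealth, time series of wealth) for one starting point.
    return lax.scan(step, starting_wealth, np.arange(200))


starting_wealth = np.arange(300.0)  # 300 scalar starting points
final, growth = vmap(simulate)(starting_wealth)

final_shape = final.shape    # one final value per starting point
growth_shape = growth.shape  # (starting points, timesteps)
```

`vmap` adds a leading batch axis to every output, so each row of `growth` is the time series for one starting wealth, and the first column reproduces the starting values.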