Note: this notebook relies heavily on Aleksa Gordic's JAX notebook.

# Stateful Computations

## Problem

In [2]:
import functools
from copy import deepcopy
from typing import Tuple, NamedTuple

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax import grad, jit, vmap, pmap
from jax import make_jaxpr
from jax import random

In [3]:
# 1) Impure functions are problematic.

g = 0.  # global state that the function below reads but does not receive

def impure_uses_global(x):
    """Return ``x + g``; impure because it depends on the global ``g``."""
    return x + g

# jit traces the function on its first call and bakes in the value that `g`
# had at trace time.
jitted_impure = jit(impure_uses_global)
print("First call:", jitted_impure(4.))

g = 10.

# Same input signature -> the cached trace (made with g == 0.) is reused,
# so the update to `g` is never seen.
print("Second call:", jitted_impure(5.))

First call: 4.0
Second call: 5.0


In [6]:
# 2) Pattern how JAX's PRNG (not stateful) is handling state.

seed = 0
state = jax.random.PRNGKey(seed)

state1, state2 = jax.random.split(state)

In [8]:
# The problem of state!
# Neural networks are full of state -- model parameters, optimizer state,
# BatchNorm statistics, etc. -- and mutable state is a problem for JAX.

class Counter:
    """A simple stateful counter that keeps its value in ``self.n``."""

    def __init__(self):
        # Start counting from zero.
        self.n = 0

    def count(self) -> int:
        """Increment the stored value in place and return the new value."""
        self.n = self.n + 1
        return self.n

    def reset(self):
        """Set the stored value back to zero."""
        self.n = 0

counter = Counter()

# Each call mutates `counter.n`, so successive calls yield increasing values.
for _iteration in range(3):
    current = counter.count()
    print(current)


1
2
3


In [9]:
# `make_jaxpr` traces `counter.count` by calling it once in Python. The
# `self.n += 1` mutation happens only during that trace, so the resulting
# jaxpr has no inputs and just returns the constant 1.
counter.reset()
print(make_jaxpr(counter.count)())

{ lambda ; . let  in (1,) }


In [10]:
counter.reset()
fast_count = jax.jit(counter.count)

# The first call traces `count` (bumping counter.n to 1) and compiles a
# function that just returns that constant. Later calls reuse the compiled
# version without re-running the Python body, so every call prints 1.
for _call in range(3):
    print(fast_count())

1
1
1


## Solution

In [11]:
CounterState = int  # the whole state of the counter is a single integer


class CounterV2:
    """A stateless counter: the state is passed in and returned explicitly."""

    def count(self, n: CounterState) -> Tuple[int, CounterState]:
        """Advance the counter by one.

        Args:
            n: the current counter state.

        Returns:
            ``(output, new_state)``. Here both equal ``n + 1``. They are kept
            as separate values for didactic purposes, because in general the
            output can be an arbitrary function of the state.
        """
        new_state = n + 1
        output = new_state
        return output, new_state

    def reset(self) -> CounterState:
        """Return the initial counter state."""
        return 0

counter = CounterV2()
state = counter.reset()

# The caller owns the state: each call consumes the current state and hands
# back the next one, which is fed into the following call.
for _step in range(3):
    value, state = counter.count(state)
    print(value)

1
2
3


In [None]:
# `count` now depends only on its argument, so jit works as intended:
# the state flows in and out explicitly instead of being baked into the trace.
state = counter.reset()
fast_count = jax.jit(counter.count)

for _step in range(3):
    value, state = fast_count(state)
    print(value)

In summary, we used the following rule to convert a stateful class:

```python
class StatefulClass:

    state: State

    def stateful_method(self, *args, **kwargs) -> Output:
        ...
```

into a class of the form:
```python
class StatelessClass:

    def stateless_method(self, state: State, *args, **kwargs) -> Tuple[Output, State]:
        ...
```