# Harmonizing `Response`s with JAX

In [1]:
from jax import jit
from gcdyn import responses, bdms, mutators, model
from jaxopt import GradientDescent, ScipyBoundedMinimize
from jax.tree_util import register_pytree_node_class

# Switch `gcdyn.responses` to its JAX backend (presumably jax.numpy) so that
# Response objects can be traced by `jit` and the jaxopt solvers used below.
responses.init_numpy(use_jax=True)

## Understanding object flattening

In [2]:
resp = responses.ConstantResponse(-2.)
flat = resp.tree_flatten()
# JAX's flatten protocol returns (children, aux_data): the leaves that JAX
# traces, and the static metadata needed to rebuild the object.
children, aux_data = flat

print(resp)
print(flat)
# JAX calls `tree_unflatten` as a classmethod, so call it on the class instead
# of constructing a throwaway default instance. Note the argument order is
# (aux_data, children), the reverse of what `tree_flatten` returns.
print(responses.ConstantResponse.tree_unflatten(aux_data, children))

ConstantResponse(value=-2.0)
((-2.0,), ('value',))
ConstantResponse(value=-2.0)


## Gradient descent

In [3]:
# Make ConstantResponse a JAX pytree so it can flow through `jit` and the
# jaxopt solvers. Registering the same class twice raises ValueError, so
# tolerate re-running this cell.
try:
    register_pytree_node_class(responses.ConstantResponse)
except ValueError:
    pass  # already registered by an earlier execution of this cell

@jit
def objective(resp):
    """Squared error between the response evaluated at 1 and the target value 5."""
    x = resp.f(1)
    return (x - 5)**2

objective(resp)

DeviceArray(49., dtype=float32, weak_type=True)

In [4]:
# Unconstrained gradient descent on the pytree-registered response,
# starting from value = -2.
init_resp = responses.ConstantResponse(-2.)
optimizer = GradientDescent(fun=objective)
optimizer.run(init_resp)

OptStep(params=ConstantResponse(value=5.0), state=ProxGradState(iter_num=DeviceArray(1, dtype=int32, weak_type=True), stepsize=DeviceArray(1., dtype=float32, weak_type=True), error=DeviceArray(0., dtype=float32), aux=None, velocity=ConstantResponse(value=5.0), t=DeviceArray(1.618034, dtype=float32, weak_type=True)))

## Bounded optimization

Note that the bounds passed to the optimizer must be numeric values (here a `(lower, upper)` pair) rather than `Response` objects; it is not yet clear why.

In [5]:
# Box-constrained version of the same problem; the bounds are passed as a
# plain numeric (lower, upper) pair rather than as Response objects.
optimizer = ScipyBoundedMinimize(fun=objective)
lower, upper = -10, 10
optimizer.run(responses.ConstantResponse(-2.), (lower, upper))

OptStep(params=ConstantResponse(value=5.0), state=ScipyMinimizeInfo(fun_val=DeviceArray(0., dtype=float32, weak_type=True), success=True, status=0, iter_num=2))

## BDMS inference

In [6]:
# Simulate one tree to use as data for the inference below.
tree = bdms.TreeNode()

# The JAX numpy backend breaks `evolve` (cause unknown), so simulate with plain numpy.
responses.init_numpy(use_jax=False)
true_birth_rate = responses.SigmoidResponse(4, 0, 1, 0)  # ground truth, for comparison with the fit
tree.evolve(
    t=6,
    birth_rate=true_birth_rate,
    death_rate=responses.ConstantResponse(0),
    mutation_rate=responses.ConstantResponse(1),
    mutator=mutators.GaussianMutator(-1, 1),
    seed=10,
)
tree.sample_survivors(p=1)

print(tree)


         /- /- /- /- /-43
      /-|
     |   \- /- /- /- /- /-31
     |
     |         /- /- /- /-42
-- /-|      /-|
     |     |   \- /- /-47
     |   /-|
     |  |  |   /- /-49
     |  |   \-|
      \-|      \- /-40
        |
        |   /- /- /- /- /- /-48
        |  |
         \-|      /- /- /-30
           |   /-|
            \-|   \- /- /-45
              |
               \- /- /- /- /- /-38


In [7]:
# Switch back to the JAX backend for fitting. The fixed death rate, mutation
# rate and mutator mirror the simulation settings; no birth rate is given
# here, so it is presumably what `fit` estimates.
responses.init_numpy(use_jax=True)
mod = model.BdmsModel(
    trees=[tree],
    death_rate=responses.ConstantResponse(0),
    mutation_rate=responses.ConstantResponse(1),
    mutator=mutators.GaussianMutator(-1, 1),
    sampling_probability=1,
)

In [8]:
# Fit starting from an initial guess; the tree was simulated with
# SigmoidResponse(4, 0, 1, 0) as the birth rate.
init_birth_rate = responses.SigmoidResponse(2., 0., 1., 0.)
mod.fit(init_value=init_birth_rate)

OptStep(params=SigmoidResponse(xscale=8.502071380615234, xshift=-0.02343890629708767, yscale=1.6619614362716675, yshift=0.0), state=ScipyMinimizeInfo(fun_val=DeviceArray(151.65233, dtype=float32, weak_type=True), success=True, status=0, iter_num=16))

In [9]:
# Identical fit again. JIT compilation from the previous call is reused, so
# this run is faster; with identical inputs it should return the same result.
mod.fit(
    init_value=responses.SigmoidResponse(2., 0., 1., 0.)
)

OptStep(params=SigmoidResponse(xscale=8.502071380615234, xshift=-0.02343890629708767, yscale=1.6619614362716675, yshift=0.0), state=ScipyMinimizeInfo(fun_val=DeviceArray(151.65233, dtype=float32, weak_type=True), success=True, status=0, iter_num=16))

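**TODO:** `Responses._param_dict` may not guarantee any particular parameter order. If it does not, the initial value and the optimization bounds could be matched to the wrong parameters. Python dicts preserve insertion order, so check how `_param_dict` is built and whether that order matches what the optimizer assumes.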