# Harmonizing `Response`s with JAX

In [10]:
from jax import jit
from gcdyn import responses
from jaxopt import GradientDescent, ScipyBoundedMinimize

## Understanding object flattening

In [2]:
resp = responses.ConstantResponse(-2.)

# tree_flatten() returns a pair; for this object it is ((-2.0,), ('value',)),
# i.e. the leaf values followed by the field names needed to rebuild it.
flat = resp.tree_flatten()
children, aux_data = flat

print(resp, flat, sep="\n")

# tree_unflatten takes the pair in the opposite order: (aux_data, children).
rebuilt = responses.ConstantResponse().tree_unflatten(aux_data, children)
print(rebuilt)

ConstantResponse(value=-2.0)
((-2.0,), ('value',))
ConstantResponse(value=-2.0)


## Gradient descent

In [8]:
@jit
def objective(resp):
    """Squared distance between the response evaluated at 1 and the target 5.

    The function is ``jit``-compiled, so ``resp`` must be something JAX can
    trace as an argument (the ``Response`` objects used in this notebook
    presumably qualify through their ``tree_flatten``/``tree_unflatten``).
    """
    residual = resp.f(1) - 5
    return residual ** 2

objective(resp)  # sanity check before optimizing; resp is ConstantResponse(-2.) from the flattening cell

DeviceArray(49., dtype=float32, weak_type=True)

In [9]:
# Unconstrained gradient descent on the objective, starting from a constant
# response of -2; the Response object itself is passed as the parameter pytree.
optimizer = GradientDescent(fun=objective)
init_resp = responses.ConstantResponse(-2.)
optimizer.run(init_resp)

OptStep(params=ConstantResponse(value=5.0), state=ProxGradState(iter_num=DeviceArray(1, dtype=int32, weak_type=True), stepsize=DeviceArray(1., dtype=float32, weak_type=True), error=DeviceArray(0., dtype=float32), aux=None, velocity=ConstantResponse(value=5.0), t=DeviceArray(1.618034, dtype=float32, weak_type=True)))

## Bounded optimization

Note that the bounds passed to `optimizer.run` must be given as plain numbers rather than as `Response` objects. The reason for this is not yet clear (**TODO**: investigate).

In [14]:
optimizer = ScipyBoundedMinimize(fun=objective)

# Box constraints as a (lower, upper) pair of plain numbers, not Response objects.
param_bounds = (-10, 10)
optimizer.run(responses.ConstantResponse(-2.), param_bounds)

OptStep(params=ConstantResponse(value=5.0), state=ScipyMinimizeInfo(fun_val=DeviceArray(0., dtype=float32, weak_type=True), success=True, status=0, iter_num=2))