In [3]:
import logging
import sys

import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp
import jax.random as jr
import optimistix as optx

# Configure JAX before any computation runs: 64-bit floats, CPU only.
jax.config.update('jax_enable_x64', True)
jax.config.update('jax_platform_name', 'cpu')
jax.config.update('jax_platforms', 'cpu')

# Make the project package importable (presumably the repo root holds `lib/`),
# and only then import from it. The membership check keeps this cell idempotent
# so re-running it does not keep appending the same entry to `sys.path`.
_PROJECT_ROOT = "../../.."
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)
from lib.ml.base_models import ICNNObsDecoder

In [None]:

class OptimiserPath(eqx.Module):
    """Record of an optimisation trajectory: the iterates and their objective values."""

    # Iterates visited by the solver, stacked along a leading axis.
    parameters: jax.Array
    # Objective value computed for each stacked iterate in `parameters`.
    values: jax.Array


class InteractiveICNNObsDecoder(ICNNObsDecoder):
    """``ICNNObsDecoder`` whose partial-input optimisation can be stepped manually.

    Instead of a single end-to-end solve, the solver is driven through
    Optimistix's low-level ``init``/``step``/``terminate`` API so that every
    intermediate iterate can be collected and inspected.
    """

    def interactive_partial_input_optimise(self, input: jnp.ndarray, fixed_mask: jnp.ndarray,
                                           max_steps: int = 2 ** 10,
                                           include_initial: bool = False) -> OptimiserPath:
        """Optimise ``self.f_energy`` from ``input`` step by step and record the path.

        Args:
            input: initial point for the solver.
            fixed_mask: forwarded to the solver as ``options['fixed_mask']``.
            max_steps: maximum number of solver steps to take.
            include_initial: if True, ``input`` is recorded as the first entry
                of the path. ``input`` is also recorded when the solver reports
                termination before taking any step, so the path is never empty.

        Returns:
            ``OptimiserPath`` whose ``parameters`` stack the recorded iterates
            along a new leading axis and whose ``values`` are ``self.f_energy``
            evaluated (via ``eqx.filter_vmap``) on each recorded iterate.

        A warning is logged if the solver reports a non-successful result, or
        if ``max_steps`` is exhausted before the solver reports termination.
        """
        solver = self.optax_solver(self.optax_optimiser_name)
        fn = lambda y, args: self.f_energy(y, args)
        args = None
        options = {'fixed_mask': fixed_mask}
        tags = frozenset()
        # NOTE(review): hard-coded float64 assumes `jax_enable_x64` is enabled and
        # that `f_energy` yields a scalar objective — TODO confirm.
        f_struct = jax.ShapeDtypeStruct((), jnp.float64)
        aux_struct = None

        # These arguments are fixed for the whole solve; bind them once so the
        # jitted step/terminate only vary over (y, state).
        step = eqx.filter_jit(
            eqx.Partial(solver.step, fn=fn, args=args, options=options, tags=tags)
        )
        terminate = eqx.filter_jit(
            eqx.Partial(solver.terminate, fn=fn, args=args, options=options, tags=tags)
        )

        # Initial state before we start solving.
        y = input
        state = solver.init(fn, y, args, options, f_struct, aux_struct, tags)
        done, result = terminate(y=y, state=state)

        ys = [y] if include_initial else []
        num_steps = 0
        while not done and num_steps < max_steps:
            y, state, _ = step(y=y, state=state)
            ys.append(y)
            done, result = terminate(y=y, state=state)
            num_steps += 1

        if result != optx.RESULTS.successful:
            logging.warning(f"Oh no! Got error {optx.RESULTS[result]}.")
        if not done:
            # `result` stays "successful" when we merely run out of steps, so the
            # check above would not catch non-convergence.
            logging.warning(f"Solver did not terminate within max_steps={max_steps} steps.")

        if not ys:
            # Solver terminated at the starting point; `jnp.stack` needs at least one array.
            ys.append(input)
        ys = jnp.stack(ys)
        vals = eqx.filter_vmap(eqx.Partial(self.f_energy, args=args))(ys)
        return OptimiserPath(parameters=ys, values=vals)





In [None]:
# NOTE(review): `input_size`, `imputer_hidden_size_multiplier`, `imputer_depth`,
# `icnn_optax_optimiser_name` and `seed` do not appear to be defined in an earlier
# cell — TODO define them in the configuration cell so Restart & Run All works.
# `jax.random` is used via the `jax` module import, so no separate alias is needed.
icnn_dec = InteractiveICNNObsDecoder(observables_size=input_size, state_size=0,
                                     hidden_size_multiplier=imputer_hidden_size_multiplier,
                                     depth=imputer_depth,
                                     optax_optimiser_name=icnn_optax_optimiser_name,
                                     key=jax.random.PRNGKey(seed))

In [None]:
# Seek `y` such that `y - tanh(y + 1) = 0`.
@eqx.filter_jit
def fn(y, args):
    """Residual of the fixed-point equation ``y = tanh(y + 1)``.

    Returns the ``(output, aux)`` pair expected by Optimistix's low-level
    solver API; there is no auxiliary output, so ``aux`` is ``None``.
    """
    residual = y - jnp.tanh(1 + y)
    return residual, None


solver = optx.Bisection(rtol=1e-3, atol=1e-3)
# The initial guess for the solution. A float literal, so `y` has a floating
# dtype from the outset (`jnp.array(0)` would be an integer array).
y = jnp.array(0.0)
# Any auxiliary information to pass to `fn`.
args = None
# The interval to search over. Required for `optx.Bisection`.
options = dict(lower=-1, upper=1)
# The shape+dtype of the output of `fn` and of its aux. Derived by tracing
# `fn` instead of hard-coding float32: with `jax_enable_x64` enabled (as at the
# top of this notebook) `fn` returns float64.
f_struct, aux_struct = jax.eval_shape(lambda: fn(y, args))
# Any Lineax tags describing the structure of the Jacobian matrix d(fn)/dy.
# (In this case it's just a 1x1 matrix, so these don't matter.)
tags = frozenset()


def solve(y, solver, max_steps=None):
    """Drive ``solver`` interactively from ``y``, printing every evaluated point.

    Uses the module-level ``fn``, ``args``, ``options``, ``f_struct``,
    ``aux_struct`` and ``tags``.

    Args:
        y: initial guess.
        solver: an Optimistix solver exposing ``init``/``step``/``terminate``/``postprocess``.
        max_steps: optional cap on the number of steps; ``None`` (the default)
            keeps stepping until ``terminate`` reports done.

    Returns:
        The solution returned by ``solver.postprocess``.
    """
    # These arguments are always fixed throughout interactive solves.
    step = eqx.filter_jit(
        eqx.Partial(solver.step, fn=fn, args=args, options=options, tags=tags)
    )
    terminate = eqx.filter_jit(
        eqx.Partial(solver.terminate, fn=fn, args=args, options=options, tags=tags)
    )

    # Initial state before we start solving.
    state = solver.init(fn, y, args, options, f_struct, aux_struct, tags)
    done, result = terminate(y=y, state=state)
    # `aux` is normally produced by `step`. Seed it with a zero-filled value of the
    # expected structure so `postprocess` still works if no step is ever taken.
    aux = jax.tree_util.tree_map(lambda s: jnp.zeros(s.shape, s.dtype), aux_struct)

    # Alright, enough setup. Let's do the solve!
    num_steps = 0
    while not done and (max_steps is None or num_steps < max_steps):
        print(f"Evaluating point {y} with value {fn(y, args)[0]}.")
        y, state, aux = step(y=y, state=state)
        done, result = terminate(y=y, state=state)
        num_steps += 1
    if result != optx.RESULTS.successful:
        print(f"Oh no! Got error {result}.")
    if not done:
        print(f"Stopped after max_steps={max_steps} steps without terminating.")
    y, _, _ = solver.postprocess(fn, y, aux, args, options, state, tags, result)
    print(f"Found solution {y} with value {fn(y, args)[0]}.")
    return y


solve(y, solver)