In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from neurolib.models.jax.wc import WCModel
from neurolib.models.jax.wc.timeIntegration import timeIntegration_args, timeIntegration_elementwise
from neurolib.optimize.autodiff.wc_optimizer import args_names

import logging

In [2]:
# Instantiate the model with its default parameters.
model = WCModel()

# Simulation length; presumably in ms — TODO confirm against the model's parameter documentation.
model.params.duration = 203
# NOTE(review): presumably switches off the Ornstein-Uhlenbeck noise input — confirm.
model.params.sigma_ou = 0

In [3]:
def update_control_with_limit(N, dim_in, T, control, step, gradient, u_max):
    """Return ``control`` shifted by ``step`` along ``gradient``.

    ``N``, ``dim_in``, ``T`` and ``u_max`` are accepted so the call signature stays fixed, but this
    implementation does not use them: despite the function's name, no limit is applied to the result.

    :param control:     Current control signal.
    :param step:        Step size that scales ``gradient``.
    :param gradient:    Direction in which the control is moved.

    :return:    ``control + step * gradient``.
    """
    increment = step * gradient
    updated_control = control + increment
    return updated_control

In [69]:
from neurolib.control.optimal_control.oc import getdefaultweights

class OcWc:
    """Optimal control by adaptive-step gradient descent on a differentiable model simulation.

    The control replaces the integration argument(s) named in ``opt_params``. The gradient of the total cost
    with respect to the control is obtained with ``jax.grad``; the step size along the descent direction is
    chosen adaptively in ``step_size``.
    """

    def __init__(
        self,
        model,
        target,
        opt_params=("exc_ext",),
    ):
        """
        :param model:       Model whose ``params`` are passed to ``timeIntegration_args``.
        :param target:      Target the ``exc`` output of the simulation is compared against.
        :param opt_params:  Names of the integration arguments that are replaced by the control.
        """
        self.model = model
        self.target = target
        # Copy into a fresh list so neither the default nor a caller-owned sequence is shared between instances.
        self.opt_params = list(opt_params)
        self.weights = getdefaultweights()
        self.M = 1  # Number of noise realizations; values > 1 select the "noisy" cost branches.

        args_values = timeIntegration_args(self.model.params)
        self.args = dict(zip(args_names, args_values))

        self.loss = self.get_loss()
        self.compute_gradient = jax.jit(jax.grad(self.loss))
        self.T = self.args['exc_ext'].shape[1]
        self.control = jnp.zeros_like(self.args['exc_ext'], dtype=float)#TODO: depend on opt_params

        self.step = 10.0  # Initial step size in first optimization iteration.
        self.count_noisy_step = 10
        self.count_step = 30  # Maximum number of step-size adjustments per gradient evaluation.

        self.factor_down = 0.5  # Factor for adaptive step size reduction.
        self.factor_up = 2.0  # Factor for adaptive step size increment.

        self.cost_history = []
        self.step_sizes_history = []
        self.step_sizes_loops_history = []

        self.dim_vars = len(self.model.state_vars)
        self.dim_in = 1
        self.dim_out = len(self.model.output_vars)
        self.maximum_control_strength = 0  # Forwarded to update_control_with_limit as ``u_max``.

        self.print_array = []  # Iteration numbers for which the cost is printed.
        self.zero_step_encountered = False  # deterministic gradient descent cannot further improve

    @staticmethod
    def _is_worse(cost, reference):
        """Return True if ``cost`` is larger than ``reference`` or not comparable to it (NaN)."""
        return not (cost <= reference)

    def simulate(self, control):
        """Run the time integration with ``control`` substituted for the optimized argument.

        :param control: Control signal that replaces the argument(s) listed in ``self.opt_params``.

        :return:    Output of ``timeIntegration_elementwise``; unpacked by the callers in this class as
                    ``t, exc, inh, exc_ou, inh_ou``.
        """
        args_local = self.args.copy()
        args_local.update(dict(zip(self.opt_params, [control])))
        return timeIntegration_elementwise(**args_local)

    def get_loss(self):
        """Build the jitted scalar loss ``control -> total cost`` that is differentiated for the gradient."""
        @jax.jit
        def loss(control):
            t, exc, inh, exc_ou, inh_ou = self.simulate(control)
            return self.compute_total_cost(control, exc)
        return loss

    def accuracy_cost(self, exc):
        """Weighted (``w_p``) half sum of squared deviations of ``exc`` from the target, scaled by ``dt``."""
        return self.weights["w_p"] * 0.5 * self.model.params.dt * jnp.sum((exc - self.target)**2)

    def control_strength_cost(self, control):
        """Weighted (``w_2``) half sum of the squared control, scaled by ``dt``."""
        return self.weights["w_2"] * 0.5 * self.model.params.dt * jnp.sum(control**2)

    def compute_total_cost(self, control, exc):
        """Compute the total cost as the sum of the accuracy cost and the control strength cost.

        :param control: Control signal the control strength cost is evaluated on.
        :param exc:     Simulated ``exc`` output the accuracy cost is evaluated on.

        :rtype: float
        """
        accuracy_cost = self.accuracy_cost(exc)
        control_strength_cost = self.control_strength_cost(control)
        return accuracy_cost + control_strength_cost

    def optimize_deterministic(self, n_max_iterations):
        """Compute the optimal control signal for noise averaging method 0 (deterministic, M=1).

        :param n_max_iterations: maximum number of iterations of gradient descent
        :type n_max_iterations: int
        """

        # (I) forward simulation
        t, exc, inh, exc_ou, inh_ou = self.simulate(self.control)  # yields x(t)

        cost = self.compute_total_cost(self.control, exc)
        print(f"Cost in iteration 0: %s" % (cost))
        if len(self.cost_history) == 0:  # add only if control model has not yet been optimized
            self.cost_history.append(cost)

        for i in range(1, n_max_iterations + 1):
            self.gradient = self.compute_gradient(self.control)

            # Move along the negative gradient; step_size updates self.control in place.
            self.step_size(-self.gradient)
            t, exc, inh, exc_ou, inh_ou = self.simulate(self.control)

            cost = self.compute_total_cost(self.control, exc)
            if i in self.print_array:
                print(f"Cost in iteration %s: %s" % (i, cost))
            self.cost_history.append(cost)

            if self.zero_step_encountered:
                print(f"Converged in iteration %s with cost %s" % (i, cost))
                break

        print(f"Final cost : %s" % (cost))

    def step_size(self, cost_gradient):
        """Adaptively choose a step size for the control update and apply it to ``self.control``.

        The step size of the previous call is tried first. If it improves the cost, larger steps are searched
        (``increase_step``); if it makes the cost worse or yields a NaN cost (diverged simulation), smaller steps
        are searched (``decrease_step``); if the cost is unchanged, the control is left as it was and the
        optimization is flagged as converged.

        :param cost_gradient:   Direction the control is moved along; ``optimize_deterministic`` passes the
                                negative gradient of the total cost wrt. the control.

        :return:    Step size that got multiplied with the 'cost_gradient'.
        :rtype:     float
        """
        # NOTE: compute_cost_noisy is not defined in this class; it is only reached for M > 1 and M is
        # fixed to 1 in __init__.
        noisy = self.M > 1

        control0 = self.control  # Memorize unchanged control throughout step-size computation.

        t, exc, inh, exc_ou, inh_ou = self.simulate(control0)
        if noisy:
            cost0 = self.compute_cost_noisy(self.M)
        else:
            # Current cost without updating the control according to the "cost_gradient".
            cost0 = self.compute_total_cost(control0, exc)

        step = self.step  # Load step size of last optimization-iteration as initial guess.

        self.control = update_control_with_limit(
            self.model.params.N, self.dim_in, self.T, control0, step, cost_gradient, self.maximum_control_strength
        )
        t, exc, inh, exc_ou, inh_ou = self.simulate(self.control)

        if noisy:
            cost = self.compute_cost_noisy(self.M)
        else:
            cost = self.compute_total_cost(self.control, exc)

        if cost < cost0:
            # The first step size already improves the cost: search for larger steps with even better
            # reduction of cost.
            step, counter = self.increase_step(cost, cost0, step, control0, self.factor_up, cost_gradient)

        elif cost == cost0:
            # No change of the cost at all: keep the previous control so that the recorded zero step
            # matches the control actually in use.
            step = 0.0  # For later analysis only.
            counter = 0
            self.control = control0
            self.zero_step_encountered = True

        else:
            # Cost got worse, or is NaN because the simulation diverged (NaN fails both comparisons above):
            # reduce the step size.
            step, counter = self.decrease_step(cost, cost0, step, control0, self.factor_down, cost_gradient)

        self.step = step  # Memorize the last step size for the next optimization step with next gradient.

        self.step_sizes_loops_history.append(counter)
        self.step_sizes_history.append(step)

        return step

    def decrease_step(self, cost, cost0, step, control0, factor_down, cost_gradient):
        """Find a step size which leads to improved cost given the gradient. The step size is iteratively decreased.
        ``self.control`` is updated in place according to the found step size. If no improving step is found
        within ``self.count_step`` reductions, the control is reset to ``control0``, the step is set to zero and
        ``self.zero_step_encountered`` is set.

        :param cost:    Cost after applying control update according to gradient with first valid step size (numerically
                        stable).
        :type cost:     float
        :param cost0:   Cost without updating the control.
        :type cost0:    float
        :param step:    Step size initial to the iterative decreasing.
        :type step:     float
        :param control0:    The unchanged control signal.
        :param factor_down:  Factor the step size is scaled with in each iteration until cost is improved.
        :type factor_down:   float
        :param cost_gradient:   Direction the control is moved along.

        :return:    The selected step size and the count-variable how often step-adjustment-loop was executed.
        :rtype:     tuple[float, int]
        """
        noisy = self.M > 1
        counter = 0

        # A NaN cost counts as "worse", so a diverged simulation keeps shrinking the step.
        while self._is_worse(cost, cost0):
            if counter == self.count_step:
                # Maximum search depth reached without improvement: fall back to the unchanged control.
                step = 0.0  # For later analysis only.
                self.control = control0
                self.zero_step_encountered = True
                break

            step *= factor_down  # Decrease step size.
            counter += 1

            self.control = update_control_with_limit(
                self.model.params.N, self.dim_in, self.T, control0, step, cost_gradient, self.maximum_control_strength
            )

            # Simulate with control updated according to new step and evaluate cost.
            t, exc, inh, exc_ou, inh_ou = self.simulate(self.control)

            if noisy:
                cost = self.compute_cost_noisy(self.M)
            else:
                cost = self.compute_total_cost(self.control, exc)

        return step, counter

    def increase_step(self, cost, cost0, step, control0, factor_up, cost_gradient):
        """Find the largest step size which leads to the biggest improvement of cost given the gradient. The step size is
        iteratively increased as long as each increase lowers the cost further. ``self.control`` is updated in
        place according to the found step size.

        :param cost:    Cost after applying control update according to gradient with first valid step size (numerically
                        stable).
        :type cost:     float
        :param cost0:   Cost without updating the control.
        :type cost0:    float
        :param step:    Step size initial to the iterative increasing.
        :type step:     float
        :param control0:    The unchanged control signal.
        :param factor_up:  Factor the step size is scaled with in each iteration while the cost keeps improving.
        :type factor_up:   float
        :param cost_gradient:   Direction the control is moved along.

        :return:    The selected step size and the count-variable how often step-adjustment-loop was executed.
        :rtype:     tuple[float, int]
        """
        noisy = self.M > 1
        counter = 0

        if not (cost < cost0):  # Nothing to build on: the incoming step does not improve the cost.
            return step, counter

        cost_best = cost  # Cost reached with the currently accepted `step`.

        while counter < self.count_step:
            trial_step = step * factor_up
            counter += 1

            self.control = update_control_with_limit(
                self.model.params.N, self.dim_in, self.T, control0, trial_step, cost_gradient,
                self.maximum_control_strength
            )

            t, exc, inh, exc_ou, inh_ou = self.simulate(self.control)

            if noisy:
                cost = self.compute_cost_noisy(self.M)
            else:
                cost = self.compute_total_cost(self.control, exc)

            if not (cost < cost_best):
                # The larger step is no further improvement (or produced a NaN cost): go back to the last
                # accepted step, which resulted in the best cost so far, and exit.
                self.control = update_control_with_limit(
                    self.model.params.N, self.dim_in, self.T, control0, step, cost_gradient,
                    self.maximum_control_strength
                )
                break

            # Accept the larger step and memorize its cost for comparison with the next increase.
            step = trial_step
            cost_best = cost

        return step, counter



In [70]:
# Argument values for the time integration, derived from the model parameters.
args_values = timeIntegration_args(model.params)

# Pair each value with its argument name so single entries (e.g. 'exc_ext') can be looked up by name.
args = dict(zip(args_names, args_values))

In [71]:
# Target of all ones (float) with the same shape as the 'exc_ext' integration argument.
ones_target = jnp.ones_like(args['exc_ext'], dtype=float)

In [72]:
# Set up the control problem for `model` with the all-ones target; opt_params keeps its default.
oc_wc = OcWc(model, ones_target)

In [73]:
# Run at most 10 gradient-descent iterations; the initial and final cost are printed.
oc_wc.optimize_deterministic(10)

Cost in iteration 0: 99.194
Final cost : 25.754984


In [12]:
# Fresh copy of the integration arguments in which 'exc_ext' is replaced by the optimized control.
args_local = {**args, 'exc_ext': oc_wc.control}