In [1]:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import optax
import numpy as np
from neurolib.models.jax.wc import WCModel
from neurolib.models.jax.wc.timeIntegration import timeIntegration_args, timeIntegration_elementwise
from neurolib.optimize.autodiff.wc_optimizer import args_names
from neurolib.control.optimal_control.oc_jax import OcWc

import logging

In [2]:
model = WCModel()

model.params.duration = 20
model.params.sigma_ou = 0
model.params.exc_init = jnp.array([[0.1]])
model.params.inh_init = jnp.array([[0.1]])

In [3]:
args_values = timeIntegration_args(model.params)

args = dict(zip(args_names, args_values))

In [4]:
ones_target = jnp.ones_like(args['exc_ext'], dtype=float)

In [5]:
oc_wc = OcWc(model, ones_target, optimizer=optax.adam(5e0))

In [6]:
oc_wc.optimize_deterministic(5)

Cost in iteration 0: 9.631590191704923
Final cost : 2.8145833213130036


### Oscillating Target

In [20]:
# We import the model
model = WCModel()

# Some parameters to define stimulation signals
dt = model.params["dt"]
duration = 10.
amplitude = 1.
period = duration /4.

# We define a "zero-input", and a sine-input
input = jnp.zeros((1, int(duration/dt) + 1))
input = input.at[0,1:-1].set(amplitude * jnp.sin(2.*jnp.pi*jnp.arange(0,duration-0.1, dt)/period)) # other functions or random values can be used as well
zero_input = jnp.zeros_like(input)

# We set the duration of the simulation and the initial values
model.params["duration"] = duration
x_init = 0.011225367461896877
y_init = 0.013126741089502588
model.params["exc_init"] = jnp.array([[x_init]])
model.params["inh_init"] = jnp.array([[y_init]])

In [23]:
# We set the stimulus in x and y variables, and run the simulation
model.params["exc_ext"] = input
model.params["inh_ext"] = zero_input
model.run()

# Define the result of the stimulation as target
target = jnp.concatenate((jnp.concatenate( (model.params["exc_init"], model.params["inh_init"]), axis=1)[:,:, jnp.newaxis],
                          jnp.stack( (model.exc, model.inh), axis=1)), axis=2).transpose((1, 0, 2))

In [25]:
oc_wc = OcWc(model, target, optimizer=optax.adam(5e0))

In [26]:
oc_wc.optimize_deterministic(10)

Cost in iteration 0: 0.0004655563711360722
Final cost : 0.2273074401605294


In [27]:
plt.plot(oc_wc.control)

NameError: name 'plt' is not defined

In [None]:
target_input = jnp.concatenate((input, zero_input), axis=0)[jnp.newaxis,:,:]

# Remove stimuli and re-run the simulation
model.params["exc_ext"] = zero_input
model.params["inh_ext"] = zero_input
control = jnp.concatenate( (zero_input,zero_input), axis=0)[jnp.newaxis,:,:]
model.run()

# combine initial value and simulation result to one array
state = jnp.concatenate((jnp.concatenate( (model.params["exc_init"], model.params["inh_init"]), axis=1)[:,:, jnp.newaxis],
    jnp.stack( (model.exc, model.inh), axis=1)), axis=2)