In [None]:
import os

import jax
import jax.numpy as jnp

jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

import matplotlib.pyplot as plt

import optimal_control.constraints as constraints
import optimal_control.trainers as trainers
import optimal_control.solvers as solvers
import optimal_control.controls as controls
import optimal_control.environments.examples as examples

from jaxtyping import ArrayLike
import equinox as eqx
import optax

In [None]:
# Seeded root PRNG key; it is re-split below whenever a fresh sub-key is needed.
key = jax.random.PRNGKey(1234)

# Keep the advanced `key` for later cells and hand `subkey` to the network.
key, subkey = jax.random.split(key)
# NOTE(review): the Siren positional arguments (1, 1, 32, 2) are presumably
# input size, output size, hidden width and depth -- confirm against
# `controls.Siren`. The trailing `0` and `10 * 60` are presumably the values
# later read back as `control.t_start` / `control.t_end` -- confirm.
control = controls.ImplicitControl(controls.Siren(1, 1, 32, 2, subkey), 0, 10 * 60)


def normal_pdf(x, mean, std):
    """Evaluate the normal (Gaussian) probability density at ``x``.

    Args:
        x: Point(s) at which to evaluate the density.
        mean: Centre of the distribution.
        std: Standard deviation of the distribution.

    Returns:
        ``1 / (std * sqrt(2 * pi)) * exp(-0.5 * ((x - mean) / std) ** 2)``.
    """
    # Peak height of the density (its value at x == mean).
    peak = 1 / (std * jnp.sqrt(2 * jnp.pi))
    # Distance from the mean measured in standard deviations.
    standardized = (x - mean) / std
    return peak * jnp.exp(-0.5 * standardized**2)


# Number of sample points; reused for the kernel grid here and passed to the
# solver and to `build_control` further down.
num_control_points = 128
# Evenly spaced sample positions across the control's time window, reshaped
# into a column vector of shape (num_control_points, 1).
kernel_x = jnp.linspace(control.t_start, control.t_end, num_control_points).reshape(
    -1, 1
)
# Gaussian kernel centred on the midpoint of the time window. The midpoint is
# (t_start + t_end) / 2; the previous (t_end - t_start) / 2 is half the window
# *length* and only coincides with the midpoint when t_start == 0.
kernel = normal_pdf(
    x=kernel_x,
    mean=(control.t_start + control.t_end) / 2,
    std=60.0,
)
# One-element array holding 100.0; the same array is handed to both integral
# constraints below.
integral = jnp.full((1,), 100.0)
# Constraints listed in the order: non-negative integral constraint,
# convolution with the Gaussian kernel, then a second integral constraint.
# NOTE(review): presumably the chain is applied front to back and the final
# integral constraint restores the target integral after the convolution --
# confirm against `optimal_control.constraints`.
constraint_chain = [
    constraints.NonNegativeConstantIntegralConstraint(integral),
    constraints.ConvolutionConstraint(
        kernel=kernel,
        padding_type="clip",
        # Pad each side by the full kernel length.
        pad_left=kernel.shape[0],
        pad_right=kernel.shape[0],
    ),
    constraints.ConstantIntegralConstraint(integral),
]

# Direct solver driven by Adam (learning rate 1e-4), using the same number of
# control points as the kernel grid.
solver = solvers.DirectSolver(
    optimizer=optax.adam(learning_rate=1e-4),
    num_control_points=num_control_points,
)

# Location of the couples data file. The default is the path this notebook was
# originally written against; set the OPTIMAL_CONTROL_COUPLES_FILE environment
# variable before starting the kernel to run it on another machine without
# editing the notebook.
COUPLES_FILEPATH = os.environ.get(
    "OPTIMAL_CONTROL_COUPLES_FILE",
    "/home/lena/master-thesis/repos/optimal-control/data/Repository_data_210919.mat",
)

# NOTE(review): couple_idx=-1 presumably selects the last couple in the data
# file -- confirm against `examples.StressEnvironment`.
environment = examples.StressEnvironment(
    couples_filepath=COUPLES_FILEPATH,
    couple_idx=-1,
)
# Initial environment state, reused by every `integrate` call below.
environment_state = environment.init()


def reward_fn(args):
    """Compute the scalar reward from an integration result.

    Args:
        args: A two-element sequence ``(ys, sg)``. Only the second element is
            used; the first is ignored.

    Returns:
        The negated mean of ``sg``.
    """
    # Unpacking still enforces that exactly two elements were supplied.
    _ys, sg_values = args
    return jnp.negative(jnp.mean(sg_values))


# Visual sanity check: plot the Gaussian kernel over its sample grid.
plt.figure()
plt.plot(kernel_x, kernel)
plt.show()

In [None]:
# Advance the notebook-level key and give the trainer its own sub-key.
key, subkey = jax.random.split(key)
# Run the optimisation for 4096 training steps. The call returns a pair that
# is unpacked into the final reward and the optimised control.
# NOTE(review): pbar_interval=10 presumably sets how often the progress bar
# is refreshed -- confirm against `trainers.solve_optimal_control_problem`.
optimized_reward, optimized_control = trainers.solve_optimal_control_problem(
    num_train_steps=4096,
    environment=environment,
    reward_fn=reward_fn,
    constraint_chain=constraint_chain,
    solver=solver,
    control=control,
    key=subkey,
    pbar_interval=10,
)

In [None]:
# Show the reward returned by the training run as this cell's output.
optimized_reward

In [None]:
from optimal_control.solvers.direct import build_control

@eqx.filter_jit
def integrate(
    control: controls.AbstractControl,
    environment: examples.StressEnvironment,
    environment_state: examples.StressState,
    integration_key=None,
) -> ArrayLike:
    """Integrate ``environment`` under ``control`` starting from ``environment_state``.

    Args:
        control: Control passed through to ``environment.integrate``.
        environment: Environment whose ``integrate`` method is called.
        environment_state: State passed through to ``environment.integrate``.
        integration_key: Optional PRNG key forwarded to
            ``environment.integrate``. When omitted, the notebook-level ``key``
            is used, which keeps existing three-argument calls working.

    Returns:
        Whatever ``environment.integrate`` returns, unchanged.
        NOTE(review): callers in this notebook unpack the result into two
        values, so the ``ArrayLike`` annotation is likely too narrow -- confirm
        and tighten.
    """
    if integration_key is None:
        # Backward-compatible fallback to the global `key`. Because this
        # function is JIT-compiled, a closed-over global is read at trace time
        # and later reassignments of `key` may not be picked up by the cached
        # compilation; pass `integration_key` explicitly to avoid relying on
        # that hidden state.
        integration_key = key

    return environment.integrate(control, environment_state, integration_key)

In [None]:
# Build the control to evaluate from the optimised control, the constraint
# chain and the number of control points.
# NOTE(review): presumably `build_control` applies `constraint_chain` to the
# optimised control -- confirm against `optimal_control.solvers.direct`.
eval_control = build_control(
    control=optimized_control,
    constraint_chain=constraint_chain,
    num_points=num_control_points,
)
# 1200 evenly spaced values from 0 to 1200 (endpoint included), used as the
# x-axis of every plot in this cell.
# NOTE(review): assumes the arrays returned by `integrate` have matching
# length -- TODO confirm.
eval_t = jnp.linspace(0.0, 20 * 60, 20 * 60)

# The integration result is unpacked into two arrays, plotted separately below.
eval_seq, eval_sg = integrate(eval_control, environment, environment_state)

# First component of the integration result.
plt.figure()
plt.plot(eval_t, eval_seq)
plt.show()

# Second component, shown on a fixed y-range slightly wider than [0, 1].
plt.figure()
plt.ylim([-0.05, 1.05])
plt.plot(eval_t, eval_sg)
plt.show()

# The evaluated control itself, sampled at `eval_t`.
plt.figure()
# Disabled alternative: plot only the first constraint applied to the raw
# optimised control.
#plt.plot(eval_t, constraint_chain[0].transform(jax.vmap(optimized_control)(eval_t.reshape(-1, 1))))
plt.plot(eval_t, eval_control(eval_t))
plt.show()

In [None]:
# Baseline for comparison: a control that returns 100.0 for every query time,
# flattened to a 1-D array. This rebinds `eval_control` from the previous cell.
eval_control = controls.LambdaControl(lambda t: jnp.full_like(t, 100.0).reshape(-1))
# Same evaluation grid as for the optimised control: 1200 evenly spaced values
# from 0 to 1200 (endpoint included).
eval_t = jnp.linspace(0.0, 20 * 60, 20 * 60)

eval_seq, eval_sg = integrate(eval_control, environment, environment_state)

# Reward of the baseline, for comparison with `optimized_reward`.
print(reward_fn((eval_seq, eval_sg)))

# First component of the integration result.
plt.figure()
plt.plot(eval_t, eval_seq)
plt.show()

# Second component, shown on a fixed y-range slightly wider than [0, 1].
plt.figure()
plt.ylim([-0.05, 1.05])
plt.plot(eval_t, eval_sg)
plt.show()

# The baseline control sampled at `eval_t`.
plt.figure()
plt.plot(eval_t, eval_control(eval_t))
plt.show()