In [None]:
import diffrax
import equinox as eqx
import imageio.v3 as iio
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from tqdm.auto import trange

import optimal_control.constraints as constraints
import optimal_control.controls as controls
import optimal_control.environments.examples as examples
import optimal_control.solvers as solvers
import optimal_control.trainers as trainers

# As a constrained control in a training run

# Applying a constraint to a control

In [None]:
# Compare three views of the same control sampled on a uniform time grid:
#   top    - the raw (unconstrained) SIREN control,
#   middle - the continuous constrained control from `transform_continuous`,
#   bottom - the discrete `transform` applied to the sampled raw signal.
# The grid size and the rescale factor of the discrete transform must agree,
# so they share one constant.
NUM_SAMPLES = 1024

key = jax.random.PRNGKey(1234)
control = controls.ImplicitControl(controls.Siren(1, 2, 256, 4, key), 0.0, 1.0)
constraint = constraints.NonNegativeConstantIntegralConstraint(jnp.asarray([1.0, 2.0]))

constrained_control = constraint.transform_continuous(control)

t = jnp.linspace(0.0, 1.0, NUM_SAMPLES).reshape(-1, 1)
signal = jax.vmap(control)(t)

constrained_signal1 = jax.vmap(constrained_control)(t)
# The discrete transform's output is multiplied by the sample count so it is
# plotted on the same scale as the continuous version. Presumably the discrete
# output sums (rather than integrates) to the constraint target - TODO confirm.
constrained_signal2 = constraint.transform(signal) * NUM_SAMPLES

fig, ax = plt.subplots(3, 1, sharex=True)
ax[0].plot(t, signal)
ax[0].set_title("raw control")
ax[1].plot(t, constrained_signal1)
ax[1].set_title("transform_continuous")
ax[2].plot(t, constrained_signal2)
ax[2].set_title(f"transform (x {NUM_SAMPLES})")
ax[2].set_xlabel("t")
fig.tight_layout()
plt.show()

# Fit an image with a SIREN network

In [None]:
# Load the target image.
# PNGs may decode as grayscale (H, W) or carry an alpha channel (H, W, 4).
# The network below predicts exactly three channels, so normalise to an
# (H, W, 3) RGB array before any later cell uses `image`.
image_raw = iio.imread("../data/testpattern.png")
if image_raw.ndim == 2:
    image_raw = np.repeat(image_raw[..., np.newaxis], 3, axis=-1)
image = image_raw[..., :3]
assert image.ndim == 3 and image.shape[-1] == 3, f"expected RGB image, got {image.shape}"

fig, ax = plt.subplots()
ax.imshow(image, origin="upper")
ax.set_title("Target image")
plt.show()

# Convert to a suitable format: one RGB row per pixel (row-major, matching
# `image`'s layout), intensities rescaled from [0, max] to [-1, 1].
# Using the dtype's max (255 for uint8) also handles 16-bit PNGs.
max_value = np.iinfo(image.dtype).max if np.issubdtype(image.dtype, np.integer) else 1.0
data = image.reshape(-1, 3).astype(np.float32) / max_value * 2 - 1

In [None]:
# Prepare network

# One (x, y) input per pixel on [-1, 1]^2.
# With meshgrid's default "xy" indexing, the first vector varies along columns
# and the second along rows. Passing (width, height) therefore yields
# (H, W) grids, whose row-major flattening lines up pixel-for-pixel with `data`
# (flattened from `image` in the same order). The previous (height, width)
# order produced (W, H) grids and mis-paired pixels for non-square images.
height, width = image.shape[0], image.shape[1]
coords = jnp.stack(
    jnp.meshgrid(
        jnp.linspace(-1.0, 1.0, width), jnp.linspace(-1.0, 1.0, height), indexing="xy"
    ),
    axis=-1,
).reshape(-1, 2)
assert coords.shape[0] == data.shape[0], "need exactly one coordinate per pixel"

key = jax.random.PRNGKey(1234)
siren_net = controls.Siren(2, 3, 256, 4, key)

optimizer = optax.adam(learning_rate=1e-3)
# The optimizer state covers only the array leaves of the model; static
# (non-array) parts are partitioned out.
opt_state = optimizer.init(eqx.partition(siren_net, eqx.is_array)[0])

In [None]:
# Train: full-batch optimisation of the SIREN on every (coordinate, pixel) pair.
# Note: re-running this cell continues from the current `siren_net`/`opt_state`;
# re-run the "Prepare network" cell first to start from scratch.

NUM_STEPS = 64


@eqx.filter_value_and_grad
def loss_fn(model, x, y):
    """Mean squared error between ``vmap(model)(x)`` and ``y``.

    Because of the ``eqx.filter_value_and_grad`` decorator, calling this returns
    ``(loss, grads)``, with gradients taken w.r.t. ``model``.
    """
    return jnp.mean(jnp.square(y - jax.vmap(model)(x)))


@eqx.filter_jit
def train_step(model, opt_state, x, y):
    """Apply one optimizer update; compiled on the first call and reused after.

    Returns the updated model, the updated optimizer state and the loss
    evaluated *before* the update.
    """
    loss, grads = loss_fn(model, x, y)
    params, static = eqx.partition(model, eqx.is_array)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return eqx.combine(params, static), opt_state, loss


progress = trange(NUM_STEPS)
for _ in progress:
    siren_net, opt_state, loss = train_step(siren_net, opt_state, coords, data)
    progress.set_postfix(loss=float(loss))

In [None]:
# Evaluate: query the trained network at every pixel coordinate and compare the
# reconstruction with the target image.

pred_data = jax.vmap(siren_net)(coords)

# Undo the [-1, 1] normalisation. The network output is unbounded, while
# matplotlib expects float RGB in [0, 1], so clip explicitly.
# Reshaping with (height, width, -1) does not depend on how many channels
# `image` itself has, e.g. if it still carries an alpha channel.
pred_image = jnp.clip(
    pred_data.reshape(image.shape[0], image.shape[1], -1) * 0.5 + 0.5, 0.0, 1.0
)

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(image, origin="upper")
ax[0].set_title("Target")
ax[1].imshow(pred_image, origin="upper")
ax[1].set_title("SIREN reconstruction")
for panel in ax:
    panel.axis("off")
plt.show()