In [25]:
import jax
import jax.numpy as jnp
import diffrax as dfx
from jax import random

# Define the ODE function
def ode_func(t, y, params):
    return y * params

# Modified function to solve the ODE and track the maximum value with params
def solve_ode_and_track_max(ode_func, y0, t0, t1, params, max_steps=1000):
    def max_value_ode(t, y, args):
        # args is a tuple (params, max_val)
        params, max_val = args
        dy_dt = ode_func(t, y[0], params)
        max_val = jnp.maximum(max_val, y[0])[0]
        return dy_dt, max_val

    solver = dfx.Dopri5()
    initial_max = jnp.array(y0[0])  # Initial max is the initial condition
    max_callback = dfx.SaveAt(ts=jnp.linspace(t0, t1, max_steps))
    
    # dy_dt, max_val = max_value_ode(1, (jnp.array(y0), initial_max), (params, initial_max))

    # Solve the ODE with an auxiliary variable to track the max
    solution = dfx.diffeqsolve(
        dfx.ODETerm(max_value_ode),
        solver=solver,
        t0=t0,
        t1=t1,
        dt0=0.1,
        y0=(jnp.array(y0), initial_max),
        args=(params, initial_max),  # Pass params and initial max as args
        saveat=max_callback,
    )
    
    max_value = solution.ys[1][-1]  # Get the max value reached
    return max_value, solution

# Define a loss function to optimize params
def loss_fn(params, y0, t0, t1, target_max):
    max_value = solve_ode_and_track_max(ode_func, y0, t0, t1, params)
    return (max_value - target_max) ** 2  # Squared error loss

# Define optimization step
@jax.jit
def optimize_step(params, y0, t0, t1, target_max, learning_rate=0.01):
    grads = jax.grad(loss_fn)(params, y0, t0, t1, target_max)
    return params - learning_rate * grads

# Initialize parameters and settings
key = random.PRNGKey(0)
params = random.normal(key, ())  # Start with a random parameter
y0 = jnp.array([1.0])            # Initial condition
t0, t1 = 0.0, 10.0               # Start and end times
target_max = 5.0                 # Desired maximum value
n_steps = 1000                   # Number of optimization steps

# Optimization loop
for i in range(n_steps):
    params = optimize_step(params, y0, t0, t1, target_max)
    if i % 100 == 0:
        current_loss = loss_fn(params, y0, t0, t1, target_max)
        print(f"Step {i}, Loss: {current_loss}, Params: {params}")

print(f"Optimized Params: {params}")


Step 0, Loss: 36.00006866455078, Params: -0.2058422565460205
Step 100, Loss: 36.00006866455078, Params: -0.2058422565460205
Step 200, Loss: 36.00006866455078, Params: -0.2058422565460205
Step 300, Loss: 36.00006866455078, Params: -0.2058422565460205
Step 400, Loss: 36.00006866455078, Params: -0.2058422565460205


KeyboardInterrupt: 