In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append('../')
import zoo_of_odes
import ode_collapser

In [None]:
# --- Parameters ---------------------------------------------------------------

# Which zoo_of_odes system to solve, and its coefficients.
ode_name = 'quartic_double_well'
ode_params = dict(
    a=2 * np.pi * 3 / np.sqrt(2),  # 3Hz oscillation in the wells
    nu=3.0,
)

# Three initial conditions, all starting at rest, with a small positive,
# a (nearly) zero and a small negative starting position.
ode_initial_conditions_1 = dict(x0=1e-3, v0=0.0)
ode_initial_conditions_2 = dict(x0=1e-11, v0=0.0)
ode_initial_conditions_3 = dict(x0=-1e-3, v0=0.0)

# Time window and grid spacing shared by every solve below.
t_start, t_end = 0.0, 3.0
h = 0.01  # Grid resolution

# Settings for the synthetic 'measurements'.
rng_seed_data = 123  # Seeds the RNG used only for data generation
sigma = 0.1  # Std-dev of the Gaussian noise added to each measured point
N_samples = 10  # How many points to 'measure' from each 'true' curve

In [None]:
# Get the 'true' solution of the ODE for each initial condition.
# All runs share (t_start, t_end, h), so they are expected to return the same time grid.
initial_conditions_list = (
    ode_initial_conditions_1,
    ode_initial_conditions_2,
    ode_initial_conditions_3,
)
true_solutions = [
    zoo_of_odes.get_solution(
        ode_name,
        params=ode_params,
        initial_conditions=initial_conditions,
        t_start=t_start,
        t_end=t_end,
        h=h,
    )
    for initial_conditions in initial_conditions_list
]
x_true_grid_1, x_true_grid_2, x_true_grid_3 = (x_true for x_true, _t in true_solutions)
t_grid = true_solutions[-1][1]
# Every later cell plots all three curves against this single t_grid, so the runs must agree on it.
assert all(np.array_equal(t, t_grid) for _x, t in true_solutions), 'time grids differ between runs'

# Plot the three true trajectories, one colour per initial condition.
fig, ax = plt.subplots(figsize=(6, 3))
for x_true, initial_conditions, color in zip(
    (x_true_grid_1, x_true_grid_2, x_true_grid_3),
    initial_conditions_list,
    ('tab:blue', 'tab:purple', 'tab:red'),
):
    ax.plot(t_grid, x_true, ls='-', color=color, label=f"x0 = {initial_conditions['x0']:g}")
ax.set_xlabel('t')
ax.set_ylabel('x(t)')
ax.set_xlim(t_grid[0], t_grid[-1])
ax.set_title(f'True solutions: {ode_name}', fontsize='small')
ax.legend(fontsize='small')
plt.show()

In [None]:
rng_data = np.random.RandomState(rng_seed_data)  # Instantiate the RNG in the same cell that we will do all the calls.
N_grid = t_grid.shape[0]


def get_samples(x_true):
    """Draw one sparse, noisy 'measurement' set from a true trajectory.

    Picks ``N_samples`` distinct grid indices (sorted ascending) and adds
    Gaussian noise with standard deviation ``sigma`` to the true values there.
    Randomness comes from the cell-level ``rng_data``, so the order of the calls
    below determines which data each curve receives.

    Parameters
    ----------
    x_true : array indexable by grid position (assumed length ``N_grid`` — TODO confirm).

    Returns
    -------
    t_samples : times of the chosen points (used only for plotting in this example).
    idx_samples : sorted grid indices of the chosen points.
    x_samples : noisy values at those points.
    """
    # Draw order (choice, then normal) is part of the reproducible data; keep it.
    idx_samples = np.sort(rng_data.choice(N_grid, size=N_samples, replace=False))
    x_noise = rng_data.normal(scale=sigma, size=(N_samples,))
    return t_grid[idx_samples], idx_samples, x_true[idx_samples] + x_noise


# Draw different sample sets for each of the three curves
t_samples_1, idx_samples_1, x_samples_1 = get_samples(x_true_grid_1)
t_samples_2, idx_samples_2, x_samples_2 = get_samples(x_true_grid_2)
t_samples_3, idx_samples_3, x_samples_3 = get_samples(x_true_grid_3)

del rng_data  # Delete to prevent re-use of this RNG in the solution section.

# Plot the generated data: each true curve with its own noisy samples.
sample_plot_specs = (
    (x_true_grid_1, t_samples_1, x_samples_1, ode_initial_conditions_1, 'tab:blue'),
    (x_true_grid_2, t_samples_2, x_samples_2, ode_initial_conditions_2, 'tab:purple'),
    (x_true_grid_3, t_samples_3, x_samples_3, ode_initial_conditions_3, 'tab:red'),
)
fig, ax = plt.subplots(figsize=(6, 3))
for x_true, t_samp, x_samp, initial_conditions, color in sample_plot_specs:
    ax.plot(t_grid, x_true, ls='-', color=color, label=f"x0 = {initial_conditions['x0']:g}")
    ax.plot(t_samp, x_samp, ls='none', marker='o', color=color, alpha=0.7)
ax.set_xlabel('t')
ax.set_ylabel('x(t)')
ax.set_xlim(t_grid[0], t_grid[-1])
ax.set_title('solid: true x(t), markers: noisy samples', fontsize='small')
ax.legend(fontsize='small')
plt.show()



In [None]:
# Define our own w_ODE schedule, since we want a smaller w_ode (1e-4) during the
# warm-up period than ode_collapser's default schedule (1e-2).
def get_w_ODE(it, n_iterations, w_ode_warmup=1e-4, warmup_frac=0.1, ramp_end_frac=0.9):
    """Return the weight on the ODE-violation loss at optimisation step ``it``.

    Piecewise schedule over ``n_iterations`` steps:

    * ``it < warmup_frac * n_iterations``: constant ``w_ode_warmup``, i.e.
      optimise mainly for fitting the samples.
    * ``it >= ramp_end_frac * n_iterations``: constant 1.0, i.e. optimise
      mainly for satisfying the ODE.
    * in between: linear ramp from ``w_ode_warmup`` up to 1.0.

    With the defaults: first 10% of steps at 1e-4, final 10% at 1.0, and a
    linear ramp over the middle 80%. The keyword arguments default to those
    values, so ``get_w_ODE(it, n_iterations)`` behaves exactly as before.

    Raises
    ------
    ValueError
        If ``warmup_frac > ramp_end_frac`` (the schedule would be inverted).
    """
    if warmup_frac > ramp_end_frac:
        raise ValueError(
            f'warmup_frac ({warmup_frac}) must not exceed ramp_end_frac ({ramp_end_frac})'
        )
    if it < warmup_frac * n_iterations:
        # Warm-up (first 10% of steps by default): optimize mainly for fitting the samples
        return w_ode_warmup
    if it >= ramp_end_frac * n_iterations:
        # Final phase (last 10% of steps by default): optimize mainly for satisfying the ODE
        return 1.0
    # Linear ramp-up of w_ODE in between; evaluation order matches the original formula.
    return w_ode_warmup + (1 - w_ode_warmup) * (it - warmup_frac * n_iterations) / (
        (ramp_end_frac - warmup_frac) * n_iterations
    )

# Run the collapser once per noisy sample set, using the custom get_w_ODE schedule.
sample_sets = (
    (idx_samples_1, x_samples_1),
    (idx_samples_2, x_samples_2),
    (idx_samples_3, x_samples_3),
)
collapser_results_1, collapser_results_2, collapser_results_3 = [
    ode_collapser.collapse_to_solution(
        rhs=zoo_of_odes.get_rhs_func(ode_name, ode_params),
        h=h,
        t_start=t_start,
        t_end=t_end,
        idx_samples=idx_samples,
        x_samples=x_samples,
        show_progress=True,
        get_w_ODE=get_w_ODE,
    )
    for idx_samples, x_samples in sample_sets
]

# Print the headline result: how well did we fit the data?
all_collapser_results = (collapser_results_1, collapser_results_2, collapser_results_3)
for run_id, results in enumerate(all_collapser_results, start=1):
    final_log = results["log_scalars"][-1]
    print(f'Loss due to MSE of data ({run_id}): {final_log["loss_data"]}')
    print(f'Loss due to ODE violation ({run_id}): {final_log["loss_ODE"]}')

# Plot the fitted x(t) next to the truth and the samples it was fitted to.
fit_plot_specs = (
    (x_true_grid_1, t_samples_1, x_samples_1, collapser_results_1, ode_initial_conditions_1, 'tab:blue'),
    (x_true_grid_2, t_samples_2, x_samples_2, collapser_results_2, ode_initial_conditions_2, 'tab:purple'),
    (x_true_grid_3, t_samples_3, x_samples_3, collapser_results_3, ode_initial_conditions_3, 'tab:red'),
)
fig, ax = plt.subplots(figsize=(6, 3))
for x_true, t_samp, x_samp, results, initial_conditions, color in fit_plot_specs:
    ax.plot(t_grid, x_true, ls='-', color=color, label=f"x0 = {initial_conditions['x0']:g}")
    ax.plot(t_samp, x_samp, ls='none', marker='o', color=color, alpha=0.7)
    ax.plot(t_grid, results['x_solution_grid'], ls='--', marker='none', color=color)
ax.set_xlabel('t')
ax.set_ylabel('x(t)')
ax.set_xlim(t_grid[0], t_grid[-1])
ax.set_title('solid: true, markers: noisy samples, dashed: collapser fit', fontsize='small')
ax.legend(fontsize='small')

# Save before plt.show(): some backends close the figure when it is shown.
fig.savefig('./double_well_collapser.png', bbox_inches='tight')
plt.show()

In [None]:
# Inspect the timeseries of the losses for diagnostic purposes if optimization is poor (run 1)
log_scalars_1 = collapser_results_1["log_scalars"]

fig, ax = plt.subplots()
ax.plot([d['loss_data'] for d in log_scalars_1], color='tab:blue', label='loss_data')
ax.plot([d['loss_ODE'] for d in log_scalars_1], color='tab:red', label='loss_ODE')
ax.set_ylim(bottom=0.0, top=300.0)  # Fixed range; raise `top` if the curves are clipped.
ax.set_xlabel('log entry')
ax.set_ylabel('loss')
ax.set_title('Loss diagnostics (run 1)')

# w_ODE has a very different scale, so it gets its own right-hand y-axis.
ax2 = ax.twinx()
ax2.plot([d['w_ODE'] for d in log_scalars_1], ls='--', color='tab:orange', label='w_ODE')
ax2.set_ylabel('w_ODE')

# Merge the handles of both axes into a single legend.
handles, labels = ax.get_legend_handles_labels()
handles_w, labels_w = ax2.get_legend_handles_labels()
ax.legend(handles + handles_w, labels + labels_w, loc='upper right')
plt.show()

In [None]:
# Inspect the timeseries of the losses for diagnostic purposes if optimization is poor (run 2)
log_scalars_2 = collapser_results_2["log_scalars"]

fig, ax = plt.subplots()
ax.plot([d['loss_data'] for d in log_scalars_2], color='tab:blue', label='loss_data')
ax.plot([d['loss_ODE'] for d in log_scalars_2], color='tab:red', label='loss_ODE')
ax.set_ylim(bottom=0.0, top=300.0)  # Fixed range; raise `top` if the curves are clipped.
ax.set_xlabel('log entry')
ax.set_ylabel('loss')
ax.set_title('Loss diagnostics (run 2)')

# w_ODE has a very different scale, so it gets its own right-hand y-axis.
ax2 = ax.twinx()
ax2.plot([d['w_ODE'] for d in log_scalars_2], ls='--', color='tab:orange', label='w_ODE')
ax2.set_ylabel('w_ODE')

# Merge the handles of both axes into a single legend.
handles, labels = ax.get_legend_handles_labels()
handles_w, labels_w = ax2.get_legend_handles_labels()
ax.legend(handles + handles_w, labels + labels_w, loc='upper right')
plt.show()