In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append('../')
import zoo_of_odes
import ode_collapser

In [None]:
h = 0.01  # Grid resolution (spacing of the time grid on [0, 1))
rng_seed_data = 123  # Seed for the RNG that picks which grid points are 'measured'
sigma = 0.1  # Intended noise level for data generation. NOTE: not currently applied -- the generated samples are noise-free
N_samples = 10  # Total number of datapoints to 'measure', split between the left and right edges of the interval
omega = 2*np.pi*5  # Angular frequency: sin(omega*t) completes 5 full periods on t in [0, 1)

In [None]:
# Build the evaluation grid and two reference curves that are 90 degrees out
# of phase. "Measurements" are taken from sin(omega t) near the left edge of
# the interval and from cos(omega t) near the right edge, so the two groups of
# samples come from different curves (hence "mismatched oscillators").
# NOTE: `sigma` is not applied here -- the samples are noise-free.
t_grid = np.arange(0, 1, h)
N_grid = t_grid.shape[0]

x_left = np.sin(omega*t_grid)
x_right = np.cos(omega*t_grid)

# Samples are drawn without replacement from the first and the last
# `edge_fraction` of the grid. N_samples is split between the two edges
# (the right edge gets the extra point when N_samples is odd).
edge_fraction = 0.2
n_edge = int(N_grid*edge_fraction)
n_left = N_samples // 2
n_right = N_samples - n_left
assert max(n_left, n_right) <= n_edge, (
    f"cannot draw {max(n_left, n_right)} distinct indices from an edge window of {n_edge} grid points"
)

rng = np.random.RandomState(rng_seed_data)
# The draw order (left first, then right) determines which indices a given
# rng_seed_data produces, so keep it fixed for reproducibility.
idx_left = np.sort(rng.choice(n_edge, n_left, replace=False))
# Right window is the mirror image of the left one: the last n_edge grid points.
idx_right = np.sort(rng.choice(n_edge, n_right, replace=False)) + (N_grid - n_edge)

idx_samples = np.concatenate([idx_left, idx_right])
t_samples = t_grid[idx_samples]
# Left-edge samples come from the sine curve, right-edge samples from the cosine curve.
x_samples = np.concatenate([x_left[idx_left], x_right[idx_right]])
assert idx_samples.shape[0] == N_samples

# Plot the generated data
fig, ax = plt.subplots()
ax.plot(t_grid, x_left, ls='dotted', marker='none', color='tab:gray', label=r'$\sin(\omega t)$')
ax.plot(t_grid, x_right, ls='dotted', marker='none', color='black', label=r'$\cos(\omega t)$')
ax.plot(t_samples, x_samples, ls='none', marker='o', color='tab:orange', alpha=0.7, label='Samples / measurements')
ax.set_xlim(left=t_grid[0], right=t_grid[-1])
ax.set_xlabel('t')
ax.set_ylabel('x(t)')
ax.set_title('Reference curves and sampled data')
# The plotted lines carry labels; without this call they are never shown.
ax.legend()

plt.show()

In [None]:
# Ask the collapser for a single ODE solution on [t_start, t_end] that best
# matches the sampled points. The right-hand side comes from the project's
# zoo of ODEs; nu=0.0 presumably switches the damping off -- confirm against
# zoo_of_odes.get_rhs_func.
collapser_results = ode_collapser.collapse_to_solution(
    rhs=zoo_of_odes.get_rhs_func(
        'damped_harmonic_oscillator',
        dict(omega=omega, nu=0.0),
    ),
    # Discretisation of the time axis.
    h=h,
    t_start=0.0,
    t_end=1.0,
    # The 'measurements': grid indices and the observed x values there.
    idx_samples=idx_samples,
    x_samples=x_samples,
    show_progress=True,
)

In [None]:
# Plot the fitted x(t) against both reference curves and the samples.
fig, ax = plt.subplots(figsize=(6, 3))
ax.plot(t_grid, x_left, ls='dotted', marker='none', color='tab:gray', label=r'$\sin(\omega t)$')
ax.plot(t_grid, x_right, ls='dotted', marker='none', color='black', label=r'$\cos(\omega t)$')
ax.plot(t_samples, x_samples, ls='none', marker='o', color='tab:orange', alpha=0.7, label='Samples / measurements')
# NOTE(review): assumes collapser_results['x_solution_grid'] is sampled on the
# same points as t_grid (same length) -- confirm against ode_collapser.
ax.plot(t_grid, collapser_results['x_solution_grid'], ls='--', marker='none', color='tab:green', label='Optimization result solution')
ax.set_xlim(left=t_grid[0], right=t_grid[-1])
ax.set_xlabel('t')
ax.set_ylabel('x(t)')
ax.set_title('Fitted solution vs. mismatched reference curves')
ax.legend()
# Keep axis labels inside the small 6x3 canvas in the saved image.
fig.tight_layout()

# Save before showing: depending on the backend, plt.show() may close or
# release the figure, so saving first is the backend-independent order.
fig.savefig('./mismatched_oscillators.png')
plt.show()