In [None]:
import numpy as np
import scipy.integrate as integrate
import matplotlib.pyplot as plt
import matplotlib
from tqdm.auto import tqdm

In [None]:
# Right-hand side of the first-order system, in the f(t, y) form expected by
# scipy.integrate.ode. No Jacobian is supplied: d(sign x)/dx is not defined at x = 0.
def f(t, y):
    """Return dy/dt for the state y = (x, v) of  x'' = -sign(x).

    ``t`` is unused (the system is autonomous) but is part of the
    ``f(t, y)`` call signature required by ``scipy.integrate.ode``.
    Returns a length-2 array ``[dx/dt, dv/dt]``.
    """
    position, velocity = y[0], y[1]
    acceleration = -np.sign(position)
    return np.array([velocity, acceleration])

# Plot solutions for the silly oscillator

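Plot a few solutions of
$$
\ddot{x} = -\mathrm{sign}(x)
$$
starting from $x(0) = 0$ with several initial velocities $v(0) = \dot{x}(0)$. The orange markers show the two data samples, $x = 0$ at $t = 0$ and $t = 1$.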

In [None]:
# Time window and evaluation grid. These must be defined before the x = 0
# reference line below uses `t_end`, so the cell runs on a fresh kernel.
t_end = 1.2
t_grid = np.linspace(0.0, t_end, 100)

# Initial velocities to plot; every trajectory starts from x(0) = 0.
v0s = [0.65, 0.5, 0.35, 0.25, 0.2]
cmap = matplotlib.colormaps['pink']
# Sample only the lower half of the colormap (presumably to avoid its
# near-white end).
colors = cmap(np.linspace(0, 0.5, len(v0s)))

fig, ax = plt.subplots(figsize=(6, 3))

# Dashed x = 0 reference line across the whole time window.
ax.plot([0.0, t_end], [0.0, 0.0], ls='--', color='tab:gray', alpha=0.5)

for v0, c in zip(v0s, colors):
    ode_integrator = integrate.ode(f)
    ode_integrator.set_initial_value(np.array([0, v0]))
    # x_grid[0] = x(0) = 0 is already set by zeros_like; integrate the rest.
    x_grid = np.zeros_like(t_grid)
    for i, t in enumerate(t_grid[1:], start=1):
        x_grid[i] = ode_integrator.integrate(t)[0]
    ax.plot(t_grid, x_grid, label=f'v(0)={v0:.2f}', ls='-', marker='none', color=c)

# Data samples: target x = 0 at t = 0 and t = 1 (the two times used in L_data).
ax.plot([0.0, 1.0], [0.0, 0.0], ls='none', marker='o', color='tab:orange', label='Samples')

ax.set_xlim(left=0.0, right=t_end)
ax.set_xlabel('t')
ax.set_ylabel('x(t)')
ax.set_title(r'Solutions of $\ddot{x} = -\mathrm{sign}(x)$ with $x(0) = 0$')
ax.legend()

fig.savefig('./silly_oscillator_solns.png', bbox_inches='tight')
plt.show()

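# Contour plot of $L_{\mathrm{data}}$

Here $L_{\mathrm{data}} = \tfrac{1}{4}\left(x(0)^2 + x(1)^2\right)$ is evaluated on a grid of initial conditions $(x(0), v(0))$. The value $x(1)$ is obtained by integrating $\ddot{x} = -\mathrm{sign}(x)$ numerically up to $t = 1$.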

In [None]:
# Grid of initial conditions for the loss landscape:
# 1001 positions x(0) in [-0.25, 0.25] and 501 velocities v(0) in [-0.7, 0.7].
x0 = np.linspace(-0.25, 0.25, num=1001)
v0 = np.linspace(-0.7, 0.7, num=501)

# Cartesian ('xy') indexing: X and V both have shape (len(v0), len(x0)),
# so rows index v(0) and columns index x(0).
X, V = np.meshgrid(x0, v0, indexing='xy')

In [None]:
# Evaluate L_data on the (x(0), v(0)) grid.
# L_data = 0.25 * (x(0)**2 + x(1)**2) is half the mean squared deviation of
# x from the target 0 at the two sample times t = 0 and t = 1.
t_sample = 1.0  # time of the second data sample

L = np.zeros_like(X)

for i in tqdm(range(X.shape[0]), desc='L_data', leave=True):
    for j in range(X.shape[1]):
        ode_integrator = integrate.ode(f)
        ode_integrator.set_initial_value(np.array([X[i, j], V[i, j]]))
        x1 = ode_integrator.integrate(t_sample)[0]

        if not ode_integrator.successful():
            # The integration can fail (the right-hand side is discontinuous
            # at x = 0). The returned state is then unreliable, so record NaN
            # rather than a misleading loss value; matplotlib's contour
            # masks NaN entries.
            L[i, j] = np.nan
            continue

        L[i, j] = 0.25 * (X[i, j]**2 + x1**2)

In [None]:
fig, ax = plt.subplots(figsize=(6, 3))

# Non-uniform contour levels, concentrated at small loss values
# (mostly 1-2-3 per decade).
levels = [1e-5, 2e-5, 3e-5, 1e-4, 2e-4, 3e-4, 1e-3, .01, .02, .03, .1]
colors = matplotlib.colormaps['autumn'](np.linspace(0.2, 0.8, len(levels)))

# L[i, j] was computed at (X[i, j], V[i, j]); v(0) goes on the horizontal axis.
ax.contour(V, X, L, levels=levels, colors=colors, linewidths=0.5)
ax.set_ylabel('x(0)')
ax.set_xlabel('v(0)')

# Zero-loss initial conditions. L_data = 0 requires x(0) = 0 and x(1) = 0.
# From x(0) = 0, x'' = -sign(x) traces parabolic arches of duration 2|v(0)|.
# Apart from the trivial v(0) = 0, x(1) = 0 exactly when |v(0)| = 1/(2k),
# k = 1, 2, ... Only the first `n_arches` of these are marked.
n_arches = 4
zero_loss_v0 = [sign / (2 * k) for k in range(1, n_arches + 1) for sign in (1, -1)]
ax.plot(zero_loss_v0, [0.0] * len(zero_loss_v0), ls='none', marker='x',
        color='red')

ax.set_title(r'Contours of $L_{\mathrm{data}}$')
ax.set_aspect(1.0)

# bbox_inches='tight' for consistency with the solutions figure.
fig.savefig('./Ldata_contours.png', bbox_inches='tight')
plt.show()