In [None]:
import re
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append('../')
import zoo_of_odes
import ode_collapser
from tqdm.auto import tqdm

# ODE selection experiment

The purpose of this notebook is to verify that the ODE-collapse method really can distinguish which ODE a given dataset was generated from.
To do this, we will select a number of ODEs (or "systems"), and construct a _realization_ of that system by integrating from initial conditions, and then sample a dataset from that.

For each dataset, we will run ODE-collapser with every system as the candidate ODE, not only the system that generated that dataset.

We will denote the datasets by, for example, `QD1-R2` to indicate _the second realization of the first harmonic-oscillator-with-quadratic-drag system_ (`QD1`).

In [None]:
# Candidate systems for the selection experiment, keyed by a short system name.
# Every entry carries the same four fields:
#   ode_name           - name of the ODE, as passed to zoo_of_odes
#   params             - parameter values for that ODE
#   initial_conditions - one dict per realization generated from the system
#   matplot_symbol     - matplotlib marker used for the system in the plots
systems = dict(
    # 'SHO' and 'DHO' share the same ODE and omega; they differ only in nu
    # (0.0 for 'SHO', 4.0 for 'DHO').
    SHO=dict(
        ode_name='damped_harmonic_oscillator',
        params=dict(omega=2*np.pi*3.0, nu=0.0),
        initial_conditions=[
            dict(amplitude0=1.2, phase0=np.pi/4),
            dict(amplitude0=0.8, phase0=0.0),
            dict(amplitude0=1.2, phase0=-np.pi/4),
        ],
        matplot_symbol='o',
    ),
    DHO=dict(
        ode_name='damped_harmonic_oscillator',
        params=dict(omega=2*np.pi*3.0, nu=4.0),
        initial_conditions=[
            dict(amplitude0=1.2, phase0=np.pi/4),
            dict(amplitude0=0.8, phase0=0.0),
            dict(amplitude0=1.2, phase0=-np.pi/4),
        ],
        matplot_symbol='X',
    ),
    # 'QD1' and 'QD2' share the same ODE and initial conditions but use
    # different omega and nu.
    QD1=dict(
        ode_name='harmonic_oscillator_with_quadratic_drag',
        params=dict(omega=2*np.pi*3.0, nu=1.0),
        initial_conditions=[
            dict(x0=0.85, v0=16.0),
            dict(x0=0.85, v0=-16.0),
        ],
        matplot_symbol='h',
    ),
    QD2=dict(
        ode_name='harmonic_oscillator_with_quadratic_drag',
        params=dict(omega=2*np.pi*2.0, nu=0.5),
        initial_conditions=[
            dict(x0=0.85, v0=16.0),
            dict(x0=0.85, v0=-16.0),
        ],
        matplot_symbol='H',
    ),
    # 'CA1' and 'CA2' share the same ODE and initial conditions but use
    # different values of a.
    CA1=dict(
        ode_name='constant_acceleration',
        params=dict(a=-10.0),
        initial_conditions=[
            dict(x0=0.0, v0=2.0),
            dict(x0=2.0, v0=0.0),
        ],
        matplot_symbol='s',
    ),
    CA2=dict(
        ode_name='constant_acceleration',
        params=dict(a=-5.0),
        initial_conditions=[
            dict(x0=0.0, v0=2.0),
            dict(x0=2.0, v0=0.0),
        ],
        matplot_symbol='D',
    ),
)

In [None]:
# Parameters

# Time window and step. These three values are passed together to
# zoo_of_odes.get_solution and to ode_collapser.collapse_to_solution.
t_start, t_end, h = 0.0, 1.0, 0.01

# Data generation:
#   sigma         - scale of the normally distributed noise added to each sample
#   N_samples     - number of grid points drawn (without replacement) per realization
#   rng_seed_data - seed of the random generator used for sampling and noise
sigma, N_samples, rng_seed_data = 0.1, 10, 1234

In [None]:
# Get the underlying solution of every realization, from either the analytic
# solution or a numerical integrator (whichever zoo_of_odes.get_solution uses).
# Keys have the form '<system name>-R<k>', where k counts the system's initial
# conditions starting from 1.
underlying_solutions = {
    f'{name}-R{number}': zoo_of_odes.get_solution(
        ode_name=spec['ode_name'],
        params=spec['params'],
        initial_conditions=initial_state,
        t_start=t_start,
        t_end=t_end,
        h=h,
    )
    for name, spec in systems.items()
    for number, initial_state in enumerate(spec['initial_conditions'], start=1)
}

In [None]:
# Sample points from each realization.
# A dedicated generator with a fixed seed is used for all draws in this cell.
rng_data = np.random.RandomState(seed=rng_seed_data)

def sample(xs, ts, rng=None, n_samples=None, noise_scale=None):
    """Draw a noisy, time-ordered subsample from a gridded trajectory.

    Grid indices are chosen uniformly without replacement and then sorted, so
    the returned samples are in increasing index order. Normally distributed
    noise with zero mean is added to the selected values of ``xs``.

    Parameters
    ----------
    xs : array-like
        Values on the grid; indexed with the chosen indices.
        NOTE(review): the noise has shape ``(n_samples,)``, so this presumably
        expects a 1-D array with the same length as ``ts`` -- confirm.
    ts : numpy.ndarray
        Grid of times; ``ts.shape[0]`` is the population that is sampled from.
    rng : numpy.random.RandomState, optional
        Generator used for both the index choice and the noise. Defaults to
        the module-level ``rng_data``.
    n_samples : int, optional
        Number of points to draw. Defaults to the module-level ``N_samples``.
    noise_scale : float, optional
        Scale (standard deviation) of the added noise. Defaults to the
        module-level ``sigma``.

    Returns
    -------
    x_samples, t_samples, idx_samples
        Noisy values, their times, and the sorted grid indices they came from.

    Raises
    ------
    ValueError
        If ``n_samples`` is larger than ``ts.shape[0]`` (raised by
        ``rng.choice`` because sampling is without replacement).
    """
    # Fall back to the notebook-level configuration only when the caller did
    # not supply a value, so the function can also be used (and tested)
    # without relying on globals that may have been deleted.
    if rng is None:
        rng = rng_data
    if n_samples is None:
        n_samples = N_samples
    if noise_scale is None:
        noise_scale = sigma

    # Order of the two draws matters for reproducibility: indices first, then noise.
    idx_samples = np.sort(rng.choice(ts.shape[0], size=n_samples, replace=False))
    x_noise = rng.normal(scale=noise_scale, size=(n_samples,))

    t_samples = ts[idx_samples]
    x_samples = xs[idx_samples] + x_noise
    return x_samples, t_samples, idx_samples

# One sampled dataset per realization. Each value is the
# (x_samples, t_samples, idx_samples) tuple returned by `sample`.
samples = {
    name: sample(x_grid, t_grid)
    for name, (x_grid, t_grid) in underlying_solutions.items()
}
# The data generator is removed once all datasets have been drawn
# (presumably so that later cells cannot draw from it by accident).
del rng_data

In [None]:
# Optional: uncomment the lines below to plot the underlying solution and the sampled data for one realization.
# realization_to_plot = 'QD1-R2'
# fig, ax = plt.subplots()
# ax.plot(underlying_solutions[realization_to_plot][1], underlying_solutions[realization_to_plot][0],
#         ls='-', marker='none', label='True solution')
# ax.plot(samples[realization_to_plot][1], samples[realization_to_plot][0],
#         ls='none', marker='o', alpha=0.7, label='Samples / measurements')
# ax.set_xlim(left=underlying_solutions[realization_to_plot][1][0], right=underlying_solutions[realization_to_plot][1][-1])
# ax.set_xlabel('t')
# ax.set_ylabel('x(t)')
# ax.legend()

# plt.show()

In [None]:
# Run optimization with every system on every realization, so that each dataset
# is also fitted with the systems that did not generate it.
# Results are stored as optimization_results[realization_name][system_name].
# Note that this may take some time.
optimization_results = {}
for realization_name, (x_samples, t_samples, idx_samples) in tqdm(
    samples.items(), desc='Realizations', leave=True
):
    # t_samples is unpacked but not used: the collapser is given the grid
    # indices and the sampled values only.
    results_for_realization = {}
    optimization_results[realization_name] = results_for_realization

    for system_name, system_spec in tqdm(systems.items(), desc='Systems', leave=False):
        candidate_rhs = zoo_of_odes.get_rhs_func(
            system_spec['ode_name'],
            system_spec['params'],
        )
        results_for_realization[system_name] = ode_collapser.collapse_to_solution(
            rhs=candidate_rhs,
            h=h,
            t_start=t_start,
            t_end=t_end,
            idx_samples=idx_samples,
            x_samples=x_samples,
            show_progress=False,
        )

In [None]:
# Summary figure: one panel per generating system. Each panel shows, for every
# candidate ODE, the final (loss_ODE, loss_data) pair obtained on each
# realization of that generating system.
system_names_list = ['CA1', 'CA2', 'QD1', 'QD2', 'SHO', 'DHO']  # Ensure consistent order

# Realization names have the form '<system>-R<k>'; group 1 is the system part.
# Compiled once, outside the panel loop, since the pattern never changes.
regex_system_part = re.compile(r'(.*)-R[0-9]+')

fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(6, 8))

fig.suptitle('Timeseries samples from...')

for idx_plt, system_name in enumerate(system_names_list):
    # Col-wise population: consecutive systems fill a column top to bottom.
    idx_col, idx_row = divmod(idx_plt, 2)
    ax = axs[idx_row][idx_col]

    # Get all of the realizations corresponding to this system
    matching_realizations = [
        rn
        for rn in optimization_results.keys()
        if regex_system_part.match(rn).group(1) == system_name
    ]

    for system_name2 in system_names_list:
        is_right_system = system_name2 == system_name

        # Last entry of 'log_scalars' for each realization, looked up once and
        # used for both losses.
        final_scalars = [
            optimization_results[rn][system_name2]['log_scalars'][-1]
            for rn in matching_realizations
        ]
        loss_data = [scalars['loss_data'] for scalars in final_scalars]
        loss_ODE = [scalars['loss_ODE'] for scalars in final_scalars]

        # Marker identifies the candidate ODE; colour says whether the
        # candidate is the system that generated the data.
        marker = systems[system_name2]['matplot_symbol']
        color = 'tab:green' if is_right_system else 'tab:red'
        ax.plot(loss_ODE, loss_data, ls='none', marker=marker, color=color)

    # Put the legend only on one plot. We choose the one with the most empty space.
    if idx_plt == 0:
        # Make fake (empty) artists to put in the legend with neutral colours,
        # so the legend explains the markers without implying right/wrong.
        for sn in system_names_list:
            ax.plot([], [], ls='none', marker=systems[sn]['matplot_symbol'], color='tab:gray', label=sn)
        ax.legend(title='Candidate\nODE', loc='upper left')

    # Same fixed log-log window on every panel so panels are directly
    # comparable; points outside these limits are not visible.
    ax.set_xscale('log')
    ax.set_xlim(5e-9, 5e-6)
    ax.set_yscale('log')
    ax.set_ylim(1e-3, 2)
    ax.set_title(f'...{system_name}')

    if idx_row == 1:
        # Bottom row carries the x-axis label.
        ax.set_xlabel('Loss_ODE')
    else:
        # Remove the labels from the major ticks, but keep the ticks themselves
        ax.set_xticks([1e-8, 1e-7, 1e-6], ['', '', ''])
    if idx_col == 0:
        # Left column carries the y-axis label.
        ax.set_ylabel('Loss_data')
    else:
        # Remove the labels from the major ticks, but keep the ticks themselves
        ax.set_yticks([1e-3, 1e-2, 1e-1, 1e0], ['', '', '', ''])

plt.show()
# Saving through the Figure reference still works after plt.show().
fig.savefig('./quantitative_experiments.png', bbox_inches='tight')

In [None]:
# Old version that plots the losses for each ODE on each timeseries separately.
# This provides more information, but also results in an unwieldy number of plots.

# for realization_name in optimization_results.keys():
#     fig, ax = plt.subplots()
    
#     regex_system_part = re.compile(r'(.*)-R[0-9]+')
#     for system_name in systems.keys():
#         is_right_system = regex_system_part.match(realization_name).group(1) == system_name
    
#         loss_data = optimization_results[realization_name][system_name]['log_scalars'][-1]['loss_data']
#         loss_ODE = optimization_results[realization_name][system_name]['log_scalars'][-1]['loss_ODE']
    
#         marker = 'x' if is_right_system else 'o'
#         color = 'tab:blue' if is_right_system else 'tab:red'
#         ax.plot(loss_ODE, loss_data, ls='none', marker=marker, color=color)
    
#     ax.set_xlabel('Loss_ODE')
#     ax.set_ylabel('Loss_data')
#     ax.set_xscale('log')
#     ax.set_yscale('log')
#     ax.set_title(f'Collapse results for realization {realization_name}')
#     plt.show()