In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import scipy.integrate as integrate

In [None]:
def _interpolate_between_samples(
        t_grid: torch.Tensor,
        idx_samples: torch.Tensor,
        x_samples: torch.Tensor,
):
    """
    Returns a linear interpolation between the points x_samples at the grid points t_grid[idx_samples].

    :param t_grid: Grid of times.
    :param idx_samples: Indices of times that x values are samples from.
    :param x_samples: Samples at times. Shape must match idx_samples.
    :return: Interpolated x values, should match the shape of t_grid.
    """
    x_interpolated = torch.tile(input=x_samples[0, :], dims=(t_grid.shape[0], 1))
    for j, (i1, i2) in enumerate(zip(idx_samples[:-1], idx_samples[1:])):
        aa = x_samples[j, :] + (x_samples[j + 1, :] - x_samples[j, :]) * ((
                    t_grid[i1 + 1:i2 + 1] - t_grid[i1]) / (t_grid[i2] - t_grid[i1])).reshape((-1, 1))
        x_interpolated[i1 + 1: i2 + 1, :] = aa
    x_interpolated[idx_samples[-1] + 1:, :] = x_samples[-1, :]
    return x_interpolated


def _default_w_ODE(it, n_iterations):
    """Default weight applied to the ODE loss at iteration `it` out of `n_iterations`.

    Returns 1e-2 for the first 10% of iterations and 1.0 for the final 10%. In between, it
    ramps linearly from 1e-2 to 1.0.
    """
    if it < 0.1 * n_iterations:
        # First 10% of steps: small weight on the ODE term.
        return 1e-2
    if it >= 0.9 * n_iterations:
        # Final 10% of steps: full weight on the ODE term.
        return 1.0
    # Linear ramp-up of w_ODE in between these iterations.
    return 1e-2 + (1 - 1e-2) * (it - 0.1 * n_iterations) / (0.8 * n_iterations)


def _project_onto_gradient_plane(x_current, grad, x_target):
    """Closest point to `x_target` on the hyperplane through `x_current` with normal `grad`.

    Solves  min_y 0.5 * |y - x_target|^2  subject to  grad . y = grad . x_current  in closed
    form. The data misfit is reduced while the (linearised) ODE loss is left unchanged.
    Inputs are flattened internally, and the result has the shape of `x_target`. If `grad`
    is identically zero the constraint is vacuous, so a copy of `x_target` is returned.
    """
    target = x_target.flatten()
    normal = grad.flatten()
    normal_norm = torch.linalg.norm(normal)
    if normal_norm.item() == 0.0:
        return x_target.clone()
    unit = normal / normal_norm
    # KKT solution: y = x_target + unit * (unit . (x_current - x_target))
    offset = torch.dot(unit, x_current.flatten() - target)
    return (target + unit * offset).reshape(x_target.shape)


def collapse_to_solution(
    rhs,
    h,
    t_start,
    t_end,
    idx_samples,
    x_samples,
    dim,
    transformation_x2z=None,
    N_iter=5000,
    get_w_ODE=None,
    initialize_by_interpolation=True,
    logging_freq_scalars=1,
    logging_freq_grids=10,
    show_progress=False,
    get_optimizer_from_params=None,
):
    """
    Find a grid solution of x'(t) = rhs(x, t) that stays close to (noisy) samples.

    Each iteration does two things:
      1. An optimizer step on w_ODE * loss_ODE, which measures the violation of the ODE.
      2. An exact minimisation of the data misfit over the sampled grid points, restricted
         to the plane through the current point that is orthogonal to the gradient from
         step 1. This keeps the first-order ODE loss unchanged. The step is done in x-space,
         using dL/dx = T^T dL/dz for z = T x.

    :param rhs: Right-hand side f(x, t), written for a single grid point (x_i with shape
        (dim,), scalar t_i). It is vectorised over the grid with torch.vmap.
    :param h: Grid spacing.
    :param t_start: First grid time.
    :param t_end: End of the time grid (exclusive, as in torch.arange).
    :param idx_samples: Grid indices at which samples were taken. They need not be sorted.
    :param x_samples: Sample values, shape (n_samples, dim).
    :param dim: Dimension of the state x.
    :param transformation_x2z: Optional invertible (n_grid, n_grid) matrix T mapping the grid
        of x values to the optimisation variables z = T x. Defaults to the identity.
    :param N_iter: Number of outer iterations.
    :param get_w_ODE: Callable (iteration, N_iter) -> weight applied to the ODE loss.
        Defaults to a ramp from 1e-2 to 1.0.
    :param initialize_by_interpolation: If True, start from a linear interpolation of the
        samples. Otherwise start from zeros.
    :param logging_freq_scalars: Log the losses every this many iterations.
    :param logging_freq_grids: Log copies of the z and x grids every this many iterations.
    :param show_progress: Show a tqdm progress bar.
    :param get_optimizer_from_params: Callable mapping a parameter list to a torch optimizer.
        Defaults to LBFGS(lr=1, history_size=10).
    :return: dict with keys 'success', 'failure_reason', 'x_solution_grid' (numpy array with
        one row per grid point), 'log_scalars' and 'log_grids'.
    :raises ValueError: If |det(transformation_x2z)| < 1e-4.
    """
    # Time grid in float64 to match the float64 solution grid. torch.arange would otherwise
    # return the default float32 dtype for Python-float arguments.
    t_grid = torch.arange(t_start, t_end, h, dtype=torch.float64)
    n_grid = t_grid.shape[0]

    # Convert the samples to tensors *before* checking their shapes, so that lists and numpy
    # arrays are accepted as well as tensors.
    idx_samples = torch.as_tensor(idx_samples, dtype=torch.long)
    x_samples = torch.as_tensor(x_samples, dtype=torch.float64)
    assert x_samples.shape[1] == dim
    assert idx_samples.shape[0] == x_samples.shape[0]  # FIXME - will need to modify in the case of vector x with partial observations
    # Jointly sort the samples by grid index. The interpolation requires increasing indices.
    order = torch.argsort(idx_samples)
    idx_samples = idx_samples[order]
    x_samples = x_samples[order]

    # Vectorise rhs over the leading (grid) axis of x and t.
    rhs_batched = torch.vmap(rhs, in_dims=(0, 0), out_dims=0)

    if transformation_x2z is None:
        # Default transformation is the identity. A possible alternative is the
        # "gradient-space" transformation z_0 = x(t_0), z_i = x(t_i) - x(t_{i-1}) for i > 0.
        transformation_x2z = torch.eye(n_grid, dtype=torch.float64)
    else:
        transformation_x2z = torch.as_tensor(transformation_x2z, dtype=torch.float64)

    if get_w_ODE is None:
        get_w_ODE = _default_w_ODE

    # Invert the transformation matrix once; it is needed repeatedly later.
    # abs(): a negative determinant (e.g. a sign flip) is perfectly invertible.
    if abs(torch.linalg.det(transformation_x2z).item()) < 1e-4:
        raise ValueError('Transformation matrix from [x(t_i)] to [z_i] has a small determinant.')
    transformation_z2x = torch.linalg.inv(transformation_x2z)

    # Initialize the solution grid. Random initialization brings no benefit here: unlike the
    # hidden neurons of a neural network, there is no symmetry to break.
    # FIXME - do I need to change this in the case of partial observations?
    #         e.g. what if I get (x, xdot) at point i, but only x at point i+1?
    if initialize_by_interpolation:
        z_initial = transformation_x2z @ _interpolate_between_samples(t_grid, idx_samples, x_samples)
    else:
        z_initial = torch.zeros((n_grid, dim), dtype=torch.float64)
    z_solution_grid = z_initial.detach().clone().to(torch.float64).requires_grad_(True)

    # LBFGS (second order) is the default: the Hessian of loss_ODE appears to be badly conditioned.
    if get_optimizer_from_params is None:
        optimizer = torch.optim.LBFGS(lr=1, history_size=10, params=[z_solution_grid])
    else:
        optimizer = get_optimizer_from_params([z_solution_grid])

    # FIXME - add an option to weight data points (and dimensions within points)
    def loss_data(z):
        """0.5 * mean squared error between the solution at the sampled grid points and the samples."""
        x = transformation_z2x @ z
        error_of_samples = x[idx_samples, :] - x_samples
        return 0.5 * torch.mean(error_of_samples ** 2)

    def loss_ODE(z):
        """0.5 * h * sum of squared ODE residuals over the N-1 grid intervals 0, 1, ..., N-2.

        The derivative is a forward difference, and rhs is evaluated at the interval midpoints.
        There is no forward difference at the last point. Demanding loss_ODE = 0 therefore
        gives d*(N-1) equations in d*N unknowns: typically a d-dimensional family of solutions,
        as expected for a first-order d-dimensional ODE. The factor h makes the sum
        approximate the integral of the squared ODE violation.
        """
        x = transformation_z2x @ z
        first_deriv = h**(-1) * (x[1:] - x[:-1])
        rhs_val = rhs_batched(0.5 * (x[1:] + x[:-1]), 0.5 * (t_grid[1:] + t_grid[:-1]))
        return 0.5 * h * torch.sum((first_deriv - rhs_val) ** 2)

    log_scalars = []
    log_grids = []

    success_flag = True  # Set to False if an error condition is encountered
    failure_reason = None  # Set whenever success_flag gets set to False
    iterations = tqdm(range(N_iter)) if show_progress else range(N_iter)
    for iteration in iterations:
        w_ODE = get_w_ODE(iteration, N_iter)

        loss_data_torch = loss_data(z_solution_grid)
        loss_ODE_torch = loss_ODE(z_solution_grid)
        # Only the ODE term is minimised by the optimizer; the data term is handled by the
        # projection step below. (Alternative: (1 - w_ODE) * loss_data + w_ODE * loss_ODE.)
        loss_total_torch = w_ODE * loss_ODE_torch

        loss_val_data = loss_data_torch.detach().item()
        loss_val_ODE = loss_ODE_torch.detach().item()
        loss_val_total = loss_total_torch.detach().item()

        # Step the optimizer, updating z_solution_grid.
        if isinstance(optimizer, torch.optim.LBFGS):
            # LBFGS re-evaluates the closure several times and reads p.grad after each call,
            # so the closure must recompute the gradient itself.
            def closure(w=w_ODE):
                optimizer.zero_grad()
                loss = w * loss_ODE(z_solution_grid)
                loss.backward()
                return loss
            optimizer.step(closure)
        else:
            optimizer.zero_grad()
            loss_total_torch.backward()
            optimizer.step()

        # Fail early on numerical instability, and record at which step it happened.
        if z_solution_grid.isnan().any().item():
            failure_reason = f'NaNs appeared in solution after the gradient descent step of iteration {iteration}'
            success_flag = False

        # Exactly minimise L_data over the sampled points within the plane through the current
        # point orthogonal to the gradient. Skipped after a failure so the first failure reason is kept.
        if success_flag and z_solution_grid.grad is not None:
            with torch.no_grad():
                grad_x = transformation_x2z.T @ z_solution_grid.grad
                x_current = transformation_z2x @ z_solution_grid
                # Only sampled coordinates are moved: the other directions do not enter L_data
                # and would make the minimiser non-unique.
                x_current[idx_samples] = _project_onto_gradient_plane(
                    x_current[idx_samples], grad_x[idx_samples], x_samples
                )
                z_solution_grid.copy_(transformation_x2z @ x_current)

            if z_solution_grid.isnan().any().item():
                failure_reason = f'NaNs appeared in solution after the quadratic optimization step of iteration {iteration}.'
                success_flag = False

        if iteration % logging_freq_scalars == 0:
            log_scalars.append({
                'iteration': iteration,
                'w_ODE': w_ODE,
                'loss_data': loss_val_data,
                'loss_ODE': loss_val_ODE,
                'loss_total': loss_val_total,
            })

        if iteration % logging_freq_grids == 0:
            with torch.no_grad():
                log_grids.append({
                    'iteration': iteration,
                    # clone(): detach().numpy() shares memory with the parameter, so later
                    # in-place optimizer updates would silently rewrite this snapshot.
                    'z_grid': z_solution_grid.detach().clone().numpy(),
                    'x_grid': (transformation_z2x @ z_solution_grid).numpy(),
                })

        if not success_flag:
            break

    # Convert the solution back to x-space.
    with torch.no_grad():
        x_solution_grid = (transformation_z2x @ z_solution_grid).numpy()

    return {
        'success': success_flag,
        'failure_reason': failure_reason,
        'x_solution_grid': x_solution_grid,
        'log_scalars': log_scalars,
        'log_grids': log_grids,
    }


In [None]:
def get_w_ODE(it, N_iterations):
    """Constant ODE-loss weight schedule: always 1.0, whatever the iteration.

    The arguments are accepted only to match the get_w_ODE(iteration, N_iter) callback signature.
    """
    del it, N_iterations  # unused
    return 1.0

def get_optim(params):
    """Build a plain SGD optimizer with learning rate 1e-3.

    :param params: Iterable of tensors to optimize.
    :return: A torch.optim.SGD instance over ``params``.
    """
    learning_rate = 1e-3
    return torch.optim.SGD(params, lr=learning_rate)

In [None]:
def f_rhs(x, t):
    """Right-hand side of the linear test ODE x'(t) = x(t). The time argument is unused."""
    del t  # the ODE is autonomous
    return x

In [None]:
# Synthetic data for x' = x on [t_start, t_end): the exact solution is exp(t), observed with
# Gaussian noise at n_samples distinct grid points.
sigma = 0.1        # standard deviation of the observation noise
n_samples = 100    # number of noisy observations (drawn without replacement from the grid)

t_start, t_end, h = 0.0, 1.0, 0.01
t_grid = np.arange(t_start, t_end, h)
n_grid = t_grid.shape[0]
x_grid = np.exp(t_grid)[:, None]

rng = np.random.RandomState(123)
idx_samples = np.sort(rng.choice(n_grid, size=n_samples, replace=False))
x_samples = x_grid[idx_samples, :] + rng.normal(size=(n_samples, 1), scale=sigma)
t_samples = t_grid[idx_samples]
del rng

In [None]:
# Solve the linear example with the constant w_ODE schedule and the SGD optimizer defined above.
# The grid bounds and dimension come from the data cell so the solution grid always matches t_grid.
collapser_results = collapse_to_solution(
    f_rhs,
    h=h,
    t_start=t_start,
    t_end=t_end,
    idx_samples=idx_samples,
    x_samples=x_samples,
    dim=x_samples.shape[1],
    get_w_ODE=get_w_ODE,
    get_optimizer_from_params=get_optim,
    show_progress=True,
)

In [None]:
# Data-fit loss at each logged iteration (x-axis uses the recorded iteration numbers).
loss_log = collapser_results['log_scalars']
fig, ax = plt.subplots()
ax.plot(
    [entry['iteration'] for entry in loss_log],
    [entry['loss_data'] for entry in loss_log],
    ls='-', marker='none', color='tab:blue'
)
ax.set(xlabel='iteration', ylabel='loss_data', title='Data loss during optimization');

In [None]:
# ODE-violation loss at each logged iteration (x-axis uses the recorded iteration numbers).
loss_log = collapser_results['log_scalars']
fig, ax = plt.subplots()
ax.plot(
    [entry['iteration'] for entry in loss_log],
    [entry['loss_ODE'] for entry in loss_log],
    ls='-', marker='none', color='tab:orange'
)
ax.set(xlabel='iteration', ylabel='loss_ODE', title='ODE loss during optimization');

In [None]:
fig, ax = plt.subplots()

ax.plot(t_grid, x_grid, ls='-', marker='none', color='tab:blue', label='true solution')
ax.plot(t_samples, x_samples, ls='none', marker='o', color='tab:orange', label='noisy samples')
ax.plot(t_grid, collapser_results['x_solution_grid'], ls='--', marker='none', color='tab:green',
        label='collapsed solution')
ax.set(xlabel='t', ylabel='x(t)', title="x' = x")
ax.legend();

In [None]:
# Both loss terms on a shared log-scale axis, plotted against the recorded iteration numbers.
fig, ax = plt.subplots()

loss_log = collapser_results['log_scalars']
logged_iterations = [entry['iteration'] for entry in loss_log]
ax.plot(logged_iterations, [entry['loss_data'] for entry in loss_log], label='loss_data')
ax.plot(logged_iterations, [entry['loss_ODE'] for entry in loss_log], label='loss_ODE')
ax.set_yscale('log')
ax.set(xlabel='iteration', ylabel='loss')
ax.legend();

In [None]:
# Example for xdot = -(x-1)*x*(x+1)

def f_rhs(x, t):
    """Cubic right-hand side x' = -(x - 1) x (x + 1). The time argument t is unused."""
    cubic = (x - 1) * x * (x + 1)
    return -cubic


def soln(t):
    """Closed-form solution x(t) = 1 / sqrt(1 - exp(-2 (t + 0.1))) of x' = -(x - 1) x (x + 1).

    Evaluates element-wise, so t may be a scalar or a numpy array.
    """
    decay = np.exp(-2*(t+0.1))
    return 1/np.sqrt(1 - decay)

sigma = 0.1
n_samples = 10

t_start = 0.0
t_end = 1.0
h = 0.01
t_grid = np.arange(t_start, t_end, h)
x_grid = soln(t_grid).reshape((-1, 1))
n_grid = t_grid.shape[0]


# Reference numerical solution from scipy, started from the exact initial value soln(t_start)
# at time t_start (instead of a hard-coded copy of that value).
ode_integrator = integrate.ode(lambda t, y: f_rhs(y, t))
ode_integrator.set_initial_value(np.array([soln(t_start)]), t_start)
x_numeric = np.zeros_like(t_grid)
x_numeric[0] = ode_integrator.y[0]
for i, t in enumerate(t_grid[1:], start=1):
    x_numeric[i] = ode_integrator.integrate(t)[0]
    # Fail loudly rather than silently plotting a bad reference curve.
    assert ode_integrator.successful(), f'scipy ODE integrator failed at t={t}'

rng = np.random.RandomState(123)
idx_samples = rng.choice(n_grid, size=n_samples, replace=False)
idx_samples = np.sort(idx_samples)
x_samples = x_grid[idx_samples, :] + rng.normal(size=(n_samples, 1), scale=sigma)
t_samples = t_grid[idx_samples]
del rng

# Solve the cubic example with the default optimizer and w_ODE schedule. The grid bounds
# and dimension come from the data cell so the solution grid always matches t_grid.
collapser_results = collapse_to_solution(
    f_rhs,
    h=h,
    t_start=t_start,
    t_end=t_end,
    idx_samples=idx_samples,
    x_samples=x_samples,
    dim=x_samples.shape[1],
    show_progress=True,
)

fig, ax = plt.subplots()

ax.plot(t_grid, x_grid, ls='-', marker='none', color='tab:blue', label='exact solution')
ax.plot(t_samples, x_samples, ls='none', marker='o', color='tab:orange', label='noisy samples')
ax.plot(t_grid, collapser_results['x_solution_grid'], ls='--', marker='none', color='tab:green',
        label='collapsed solution')
ax.plot(t_grid, x_numeric, ls='--', marker='none', color='black', label='scipy integrator')
ax.set(xlabel='t', ylabel='x(t)', title="x' = -(x - 1) x (x + 1)")
ax.legend()

plt.show()