In [3]:
import math
import numpy as np
from IPython.display import clear_output
from tqdm import tqdm_notebook as tqdm

import torch
from torch import Tensor
from torch import nn
from torch.nn  import functional as F
from torch.autograd import Variable

In [132]:
def ode_solve(z0, U, t0, t1, f):
    """
    Simplest Euler ODE initial value solver
    """
    h_max = 0.05
    n_steps = math.ceil((abs(t1 - t0)/h_max).max().item())

    h = (t1 - t0)/n_steps
    t = t0
    z = z0

    for i_step in range(n_steps):
        z = z + h * f(z, U, t)
        t = t + h
    return z

In [133]:
#the below cell for forward and backward may need to be rewrite
class ODEAdjoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, z0, U, t, flat_parameters, func):
        assert isinstance(func, ODEF)
        bs, *z_shape = z0.size()
        time_len = t.size(0)

        with torch.no_grad():
            z = torch.zeros(time_len, bs, *z_shape).to(z0)
            z[0] = z0
            for i_t in range(time_len - 1):
                z0 = ode_solve(z0, U, t[i_t], t[i_t+1], func)
                z[i_t+1] = z0

        ctx.func = func
        ctx.save_for_backward(t, U, z.clone(), flat_parameters)
        return z

    @staticmethod
    def backward(ctx, dLdz):
        """
        dLdz shape: time_len, batch_size, *z_shape
        """
        func = ctx.func
        t, U, z, flat_parameters = ctx.saved_tensors
        time_len, bs, *z_shape = z.size()
        n_dim = np.prod(z_shape)
        n_params = flat_parameters.size(0)

        # Dynamics of augmented system to be calculated backwards in time
        def augmented_dynamics(aug_z_i, t_i):
            """
            tensors here are temporal slices
            t_i - is tensor with size: bs, 1
            aug_z_i - is tensor with size: bs, n_dim*2 + n_params + 1
            """
            z_i, a = aug_z_i[:, :n_dim], aug_z_i[:, n_dim:2*n_dim]  # ignore parameters and time

            # Unflatten z and a
            z_i = z_i.view(bs, *z_shape)
            a = a.view(bs, *z_shape)
            with torch.set_grad_enabled(True):
                t_i = t_i.detach().requires_grad_(True)
                z_i = z_i.detach().requires_grad_(True)
                U_i = U.detach().requires_grad_(True)
                func_eval, adfdz, adfdt, adfdp = func.forward_with_grad(z_i, U_i, t_i, grad_outputs=a)  # bs, *z_shape
                adfdz = adfdz.to(z_i) if adfdz is not None else torch.zeros(bs, *z_shape).to(z_i)
                adfdp = adfdp.to(z_i) if adfdp is not None else torch.zeros(bs, n_params).to(z_i)
                adfdt = adfdt.to(z_i) if adfdt is not None else torch.zeros(bs, 1).to(z_i)

            # Flatten f and adfdz
            func_eval = func_eval.view(bs, n_dim)
            adfdz = adfdz.view(bs, n_dim)
            return torch.cat((func_eval, -adfdz, -adfdp, -adfdt), dim=1)

        dLdz = dLdz.view(time_len, bs, n_dim)  # flatten dLdz for convenience
        with torch.no_grad():
            ## Create placeholders for output gradients
            # Prev computed backwards adjoints to be adjusted by direct gradients
            adj_z = torch.zeros(bs, n_dim).to(dLdz)
            adj_p = torch.zeros(bs, n_params).to(dLdz)
            # In contrast to z and p we need to return gradients for all times
            adj_t = torch.zeros(time_len, bs, 1).to(dLdz)

            for i_t in range(time_len-1, 0, -1):
                z_i = z[i_t]
                t_i = t[i_t]
                f_i = func(z_i, t_i).view(bs, n_dim)

                # Compute direct gradients
                dLdz_i = dLdz[i_t]
                dLdt_i = torch.bmm(torch.transpose(dLdz_i.unsqueeze(-1), 1, 2), f_i.unsqueeze(-1))[:, 0]

                # Adjusting adjoints with direct gradients
                adj_z += dLdz_i
                adj_t[i_t] = adj_t[i_t] - dLdt_i

                # Pack augmented variable
                aug_z = torch.cat((z_i.view(bs, n_dim), adj_z, torch.zeros(bs, n_params).to(z), adj_t[i_t]), dim=-1)

                # Solve augmented system backwards
                aug_ans = ode_solve(aug_z, t_i, t[i_t-1], augmented_dynamics)

                # Unpack solved backwards augmented system
                adj_z[:] = aug_ans[:, n_dim:2*n_dim]
                adj_p[:] += aug_ans[:, 2*n_dim:2*n_dim + n_params]
                adj_t[i_t-1] = aug_ans[:, 2*n_dim + n_params:]

                del aug_z, aug_ans

            ## Adjust 0 time adjoint with direct gradients
            # Compute direct gradients
            dLdz_0 = dLdz[0]
            dLdt_0 = torch.bmm(torch.transpose(dLdz_0.unsqueeze(-1), 1, 2), f_i.unsqueeze(-1))[:, 0]

            # Adjust adjoints
            adj_z += dLdz_0
            adj_t[0] = adj_t[0] - dLdt_0
        return adj_z.view(bs, *z_shape), adj_t, adj_p, None

In [134]:
class ODEF(nn.Module):
    def forward_with_grad(self, z, U, t, grad_outputs):
        """Compute f and a df/dz, a df/dp, a df/dt"""
        batch_size = z.shape[0]

        out = self.forward(z, U, t)

        a = grad_outputs
        adfdz, adfdt, *adfdp = torch.autograd.grad(
            (out,), (z, U, t) + tuple(self.parameters()), grad_outputs=(a),
            allow_unused=True, retain_graph=True
        )
        # grad method automatically sums gradients for batch items, we have to expand them back
        if adfdp is not None:
            adfdp = torch.cat([p_grad.flatten() for p_grad in adfdp]).unsqueeze(0)
            adfdp = adfdp.expand(batch_size, -1) / batch_size
        if adfdt is not None:
            adfdt = adfdt.expand(batch_size, 1) / batch_size
        return out, adfdz, adfdt, adfdp

    def flatten_parameters(self):
        p_shapes = []
        flat_parameters = []
        for p in self.parameters():
            p_shapes.append(p.size())
            flat_parameters.append(p.flatten())
        return p_shapes, flat_parameters

class NeuralODE(nn.Module):
    def __init__(self, func):
        super(NeuralODE, self).__init__()
        assert isinstance(func, ODEF)
        self.func = func

    def forward(self, z0, U, t=Tensor([0., 1.]), return_whole_sequence=False):
        t = t.to(z0)
        z = ODEAdjoint.apply(z0, U, t, self.func.flatten_parameters(), self.func)
        if return_whole_sequence:
            return z
        else:
            return z[-1]

class LinearODEF_STI(ODEF):
    def __init__(self, W, C):
        super().__init__()
        self.lin = nn.Linear(W.shape[0], W.shape[1], bias=False)
        self.lin.weight = nn.Parameter(W)
        
        self.sti = nn.Linear(C.shape[0], C.shape[1], bias=False)
        self.sti.weight = nn.Parameter(C)
        
    def forward(self, x, U, t):
        return self.lin(x) + self.sti(U)

        
class RandomLinearODEF_STI(LinearODEF_STI):
    def __init__(self, W, C):
        super().__init__(torch.randn(W.shape[0], W.shape[1])/2., torch.randn(C.shape[0], C.shape[1])/2.)

In [135]:
def to_np(x):
    return x.detach().cpu().numpy()

def conduct_experiment(ode_true, ode_trained, n_steps, epoch=5):
    # Create data
    z0 = Variable(torch.Tensor([[0.5, 0.3, -0.1]]))

    t_max = 6.29*5
    n_points = 200

    print(f"Training Epoch {epoch}...")

    index_np = np.arange(0, n_points, 1, dtype=np.int)
    index_np = np.hstack([index_np[:, None]])
    times_np = np.linspace(0, t_max, num=n_points)
    times_np = np.hstack([times_np[:, None]])

    times = torch.from_numpy(times_np[:, :, None]).to(z0)
    obs = ode_true(z0, U, times, return_whole_sequence=True).detach()
    obs = obs + torch.randn_like(obs) * 0.01

    # Get trajectory of random timespan
    min_delta_time = 1.0
    max_delta_time = 5.0
    max_points_num = 32
    def create_batch():
        t0 = np.random.uniform(0, t_max - max_delta_time)
        t1 = t0 + np.random.uniform(min_delta_time, max_delta_time)

        idx = sorted(np.random.permutation(index_np[(times_np > t0) & (times_np < t1)])[:max_points_num])

        obs_ = obs[idx]
        ts_ = times[idx]
        return obs_, ts_

    # Train Neural ODE
    optimizer = torch.optim.Adam(ode_trained.parameters(), lr=0.01)
    train_losses = []
    for i in range(n_steps):
        obs_, ts_ = create_batch()
        z_ = ode_trained(obs_[0], U, ts_, return_whole_sequence=True)
        loss = F.mse_loss(z_, obs_.detach())

        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()

        train_losses.append(loss.item())

    print('Mean Train loss: {:.5f}'.format(np.mean(train_losses)))
    print(list(ode_trained.parameters()))
    final_pars = list(ode_trained.parameters())
    return train_losses, final_pars

# Q1:given addition parameter C

given equation below<br>
$\frac{dz}{dt} =  \begin{bmatrix}
-0.6 & 0.7 & -0.8 \\
-0.2 & 0.3 & 1.1 \\
0.2 & -0.5 & -0.5 
\end{bmatrix}z +  \begin{bmatrix}
-0.4 & 0 \\
0 & 0 \\
0 & 0 
\end{bmatrix} * u $

In [136]:
A = Tensor([[-0.6, 0.7, -0.8], [-0.2, 0.3, 1.1], [0.2, -0.5, -0.5]])
C = Tensor([[-0.4, 0], [0, 0], [0, 0]])
U = Tensor([[0.5, 0]])

ode_true = NeuralODE(LinearODEF_STI(A, C))
ode_trained = NeuralODE(RandomLinearODEF_STI(A, C))

print(ode_true.parameters)
print(ode_trained.parameters)

<bound method Module.parameters of NeuralODE(
  (func): LinearODEF_STI(
    (lin): Linear(in_features=3, out_features=3, bias=False)
    (sti): Linear(in_features=3, out_features=2, bias=False)
  )
)>
<bound method Module.parameters of NeuralODE(
  (func): RandomLinearODEF_STI(
    (lin): Linear(in_features=3, out_features=3, bias=False)
    (sti): Linear(in_features=3, out_features=2, bias=False)
  )
)>


In [142]:
index_np = np.arange(0, n_points, 1, dtype=np.int)
index_np = np.hstack([index_np[:, None]])
times_np = np.linspace(0, t_max, num=n_points)
times_np = np.hstack([times_np[:, None]])

times = torch.from_numpy(times_np[:, :, None]).to(z0)
obs = ode_true(z0, U, times, return_whole_sequence=True)
obs = obs + torch.randn_like(obs) * 0.01

In [143]:
obs

tensor([[[ 0.5062,  0.2966, -0.1073]],

        [[ 0.4803,  0.2892, -0.1017]],

        [[ 0.4405,  0.2530, -0.1007]],

        [[ 0.3884,  0.2383, -0.1042]],

        [[ 0.3733,  0.2210, -0.0983]],

        [[ 0.3499,  0.2012, -0.0822]],

        [[ 0.3351,  0.1849, -0.1084]],

        [[ 0.2965,  0.1886, -0.1146]],

        [[ 0.2527,  0.1602, -0.0799]],

        [[ 0.2508,  0.1572, -0.0806]],

        [[ 0.2087,  0.1271, -0.0654]],

        [[ 0.1769,  0.1286, -0.0899]],

        [[ 0.1434,  0.0973, -0.0824]],

        [[ 0.1303,  0.0740, -0.0563]],

        [[ 0.0919,  0.0794, -0.0725]],

        [[ 0.0773,  0.0502, -0.0751]],

        [[ 0.0517,  0.0540, -0.0387]],

        [[ 0.0169,  0.0452, -0.0559]],

        [[ 0.0015,  0.0422, -0.0475]],

        [[-0.0167,  0.0363, -0.0424]],

        [[-0.0249,  0.0129, -0.0602]],

        [[-0.0497,  0.0032, -0.0495]],

        [[-0.0687, -0.0025, -0.0358]],

        [[-0.0790,  0.0032, -0.0542]],

        [[-0.1198,  0.0089, -0.0344]],



In [None]:
min_delta_time = 1.0
max_delta_time = 5.0
max_points_num = 32


def create_batch():
    t0 = np.random.uniform(0, t_max - max_delta_time)
    t1 = t0 + np.random.uniform(min_delta_time, max_delta_time)

    idx = sorted(np.random.permutation(index_np[(times_np > t0) & (times_np < t1)])[:max_points_num])

    obs_ = obs[idx]
    ts_ = times[idx]
    return obs_, ts_

    # Train Neural ODE
optimizer = torch.optim.Adam(ode_trained.parameters(), lr=0.01)
train_losses = []

obs_, ts_ = create_batch()
z_ = ode_trained(obs_[0], U, ts_, return_whole_sequence=True)