In [3]:
import math
import numpy as np
from IPython.display import clear_output
from tqdm import tqdm_notebook as tqdm

import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
import seaborn as sns
sns.color_palette("bright")
import matplotlib as mpl
import matplotlib.cm as cm

import torch
from torch import Tensor
from torch import nn
from torch.nn  import functional as F
from torch.autograd import Variable

def ode_solve(z0, t0, t1, f):
    """
    Simplest Euler ODE initial value solver
    """
    h_max = 0.05
    n_steps = math.ceil((abs(t1 - t0)/h_max).max().item())

    h = (t1 - t0)/n_steps
    t = t0
    z = z0

    for i_step in range(n_steps):
        z = z + h * f(z, t)
        t = t + h
    return z

class ODEF(nn.Module):
    def forward_with_grad(self, z, t, grad_outputs):
        """Compute f and a df/dz, a df/dp, a df/dt"""
        batch_size = z.shape[0]

        out = self.forward(z, t)

        a = grad_outputs
        adfdz, adfdt, *adfdp = torch.autograd.grad(
            (out,), (z, t) + tuple(self.parameters()), grad_outputs=(a),
            allow_unused=True, retain_graph=True
        )
        # grad method automatically sums gradients for batch items, we have to expand them back
        if adfdp is not None:
            adfdp = torch.cat([p_grad.flatten() for p_grad in adfdp]).unsqueeze(0)
            adfdp = adfdp.expand(batch_size, -1) / batch_size
        if adfdt is not None:
            adfdt = adfdt.expand(batch_size, 1) / batch_size
        return out, adfdz, adfdt, adfdp

    def flatten_parameters(self):
        p_shapes = []
        flat_parameters = []
        for p in self.parameters():
            p_shapes.append(p.size())
            flat_parameters.append(p.flatten())
        return torch.cat(flat_parameters)

class ODEAdjoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, z0, t, flat_parameters, func):
        assert isinstance(func, ODEF)
        bs, *z_shape = z0.size()
        time_len = t.size(0)

        with torch.no_grad():
            z = torch.zeros(time_len, bs, *z_shape).to(z0)
            z[0] = z0
            for i_t in range(time_len - 1):
                z0 = ode_solve(z0, t[i_t], t[i_t+1], func)
                z[i_t+1] = z0

        ctx.func = func
        ctx.save_for_backward(t, z.clone(), flat_parameters)
        return z

    @staticmethod
    def backward(ctx, dLdz):
        """
        dLdz shape: time_len, batch_size, *z_shape
        """
        func = ctx.func
        t, z, flat_parameters = ctx.saved_tensors
        time_len, bs, *z_shape = z.size()
        n_dim = np.prod(z_shape)
        n_params = flat_parameters.size(0)

        # Dynamics of augmented system to be calculated backwards in time
        def augmented_dynamics(aug_z_i, t_i):
            """
            tensors here are temporal slices
            t_i - is tensor with size: bs, 1
            aug_z_i - is tensor with size: bs, n_dim*2 + n_params + 1
            """
            z_i, a = aug_z_i[:, :n_dim], aug_z_i[:, n_dim:2*n_dim]  # ignore parameters and time

            # Unflatten z and a
            z_i = z_i.view(bs, *z_shape)
            a = a.view(bs, *z_shape)
            with torch.set_grad_enabled(True):
                t_i = t_i.detach().requires_grad_(True)
                z_i = z_i.detach().requires_grad_(True)
                func_eval, adfdz, adfdt, adfdp = func.forward_with_grad(z_i, t_i, grad_outputs=a)  # bs, *z_shape
                adfdz = adfdz.to(z_i) if adfdz is not None else torch.zeros(bs, *z_shape).to(z_i)
                adfdp = adfdp.to(z_i) if adfdp is not None else torch.zeros(bs, n_params).to(z_i)
                adfdt = adfdt.to(z_i) if adfdt is not None else torch.zeros(bs, 1).to(z_i)

            # Flatten f and adfdz
            func_eval = func_eval.view(bs, n_dim)
            adfdz = adfdz.view(bs, n_dim)
            return torch.cat((func_eval, -adfdz, -adfdp, -adfdt), dim=1)

        dLdz = dLdz.view(time_len, bs, n_dim)  # flatten dLdz for convenience
        with torch.no_grad():
            ## Create placeholders for output gradients
            # Prev computed backwards adjoints to be adjusted by direct gradients
            adj_z = torch.zeros(bs, n_dim).to(dLdz)
            adj_p = torch.zeros(bs, n_params).to(dLdz)
            # In contrast to z and p we need to return gradients for all times
            adj_t = torch.zeros(time_len, bs, 1).to(dLdz)

            for i_t in range(time_len-1, 0, -1):
                z_i = z[i_t]
                t_i = t[i_t]
                f_i = func(z_i, t_i).view(bs, n_dim)

                # Compute direct gradients
                dLdz_i = dLdz[i_t]
                dLdt_i = torch.bmm(torch.transpose(dLdz_i.unsqueeze(-1), 1, 2), f_i.unsqueeze(-1))[:, 0]

                # Adjusting adjoints with direct gradients
                adj_z += dLdz_i
                adj_t[i_t] = adj_t[i_t] - dLdt_i

                # Pack augmented variable
                aug_z = torch.cat((z_i.view(bs, n_dim), adj_z, torch.zeros(bs, n_params).to(z), adj_t[i_t]), dim=-1)

                # Solve augmented system backwards
                aug_ans = ode_solve(aug_z, t_i, t[i_t-1], augmented_dynamics)

                # Unpack solved backwards augmented system
                adj_z[:] = aug_ans[:, n_dim:2*n_dim]
                adj_p[:] += aug_ans[:, 2*n_dim:2*n_dim + n_params]
                adj_t[i_t-1] = aug_ans[:, 2*n_dim + n_params:]

                del aug_z, aug_ans

            ## Adjust 0 time adjoint with direct gradients
            # Compute direct gradients
            dLdz_0 = dLdz[0]
            dLdt_0 = torch.bmm(torch.transpose(dLdz_0.unsqueeze(-1), 1, 2), f_i.unsqueeze(-1))[:, 0]

            # Adjust adjoints
            adj_z += dLdz_0
            adj_t[0] = adj_t[0] - dLdt_0
        return adj_z.view(bs, *z_shape), adj_t, adj_p, None


class NeuralODE(nn.Module):
    def __init__(self, func):
        super(NeuralODE, self).__init__()
        assert isinstance(func, ODEF)
        self.func = func

    def forward(self, z0, t=Tensor([0., 1.]), return_whole_sequence=False):
        t = t.to(z0)
        z = ODEAdjoint.apply(z0, t, self.func.flatten_parameters(), self.func)
        if return_whole_sequence:
            return z
        else:
            return z[-1]

class LinearODEF(ODEF):
    def __init__(self, W):
        super(LinearODEF, self).__init__()
        self.lin = nn.Linear(5, 5, bias=False)
        self.lin.weight = nn.Parameter(W)

    def forward(self, x, t):
        return self.lin(x)

class RandomLinearODEF(LinearODEF):
    def __init__(self):
        super(RandomLinearODEF, self).__init__(torch.randn(5, 5)/2.)


def to_np(x):
    return x.detach().cpu().numpy()

def conduct_experiment(obs, ode_trained, n_steps, name, plot_freq=50, epoch=5):
    # Create data
    z0 = Variable(torch.Tensor([[-1.0, 0, 0.1, 0.1, 0.42]]))

    t_max = 6.29*5
    n_points = 200

    print(f"Training Epoch {epoch}...")

    index_np = np.arange(0, n_points, 1, dtype=np.int)
    index_np = np.hstack([index_np[:, None]])
    times_np = np.linspace(0, t_max, num=n_points)
    times_np = np.hstack([times_np[:, None]])

    times = torch.from_numpy(times_np[:, :, None]).to(z0)
    
    # Get trajectory of random timespan
    min_delta_time = 1.0
    max_delta_time = 5.0
    max_points_num = 32
    def create_batch():
        t0 = np.random.uniform(0, t_max - max_delta_time)
        t1 = t0 + np.random.uniform(min_delta_time, max_delta_time)

        idx = sorted(np.random.permutation(index_np[(times_np > t0) & (times_np < t1)])[:max_points_num])

        obs_ = obs[idx]
        ts_ = times[idx]
        return obs_, ts_

    # Train Neural ODE
    optimizer = torch.optim.Adam(ode_trained.parameters(), lr=0.01)
    train_losses = []
    for i in range(n_steps):
        obs_, ts_ = create_batch()
        z_ = ode_trained(obs_[0], ts_, return_whole_sequence=True)
        loss = F.mse_loss(z_, obs_.detach())

        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()

        train_losses.append(loss.item())

    final_pars = list(ode_trained.parameters())
    return train_losses, final_pars

import scipy.io
mat = scipy.io.loadmat('sim/sim1.mat')

def cal_tp_acc(mat1, mat2):
    loc_gt = mat1 == 1
    loc_sim = mat2 == 1
    n_1_gt = len(mat1[loc_gt])
    n_1_sim = len(mat2[loc_sim])
    #cal tp,acc
    n_corr = sum(mat1[loc_gt] == mat2[loc_gt])
    tp = n_corr/n_1_gt
    acc = n_corr/n_1_sim
    return tp, acc

n_epochs = 20
result = []
for i in range(0, 50):
    print(f"subject{i}")
    pt_down = i*200
    pt_up = (i+1)*200-1
    obs = mat['ts'][pt_down:pt_up]
    obs2 = obs.reshape(obs, (200, 1, 5))
    obs2 = torch.from_numpy(obs2)
    gt = mat['net'][i]
    pars_ls = []
    ode_trained = NeuralODE(RandomLinearODEF())
    for epoch in range(1, n_epochs + 1):
        train_losses, final_pars = conduct_experiment(obs2, ode_trained, 3000, i, epoch = epoch)        
        pars = final_pars[0].detach()
        pars_ls.append(pars)
    par = pars_ls[-1].numpy()
    ##evaluate performance
    par_nom = (par - np.mean(par))/np.std(par)
    par_bin = (abs(par_nom)>1).astype(int_)
    gt_bin = (gt!=0).astype(int_)
    tp, acc = cal_tp_acc(gt_bin, par_bin)
    result.append([tp,acc])
result_df = pd.DataFrame(result)
result_df.columns = ['tp', 'acc']
result_df.to_csv("sim.csv")

In [None]:
mat['ts']

In [4]:
import scipy.io
mat = scipy.io.loadmat('sim/sim1.mat')
mat

{'__header__': b'MATLAB 5.0 MAT-file, Platform: GLNXA64, Created on: Fri Aug 24 15:45:37 2012',
 '__version__': '1.0',
 '__globals__': [],
 'ts': array([[-1.68689769, -1.28406135, -0.6760282 , -2.56011633, -1.05423662],
        [-0.73026723, -0.55486006, -2.55433115, -0.52633548,  1.10091925],
        [-1.26204546, -0.78184172, -0.73325796, -0.89806382,  0.08859272],
        ...,
        [ 0.81775919, -0.74486774, -0.03691457,  2.57935351,  0.95366508],
        [-2.39926171,  0.04882025, -1.39823262, -1.50496996,  2.38226749],
        [ 2.85295637,  1.79394477,  2.63520153, -1.28656521, -1.09581382]]),
 'net': array([[[-1.        ,  0.35674352,  0.        ,  0.        ,
           0.28535286],
         [ 0.        , -1.        ,  0.23344156,  0.        ,
           0.        ],
         [ 0.        ,  0.        , -1.        ,  0.41253323,
           0.        ],
         [ 0.        ,  0.        ,  0.        , -1.        ,
           0.42876764],
         [ 0.        ,  0.        ,  0.

In [6]:
mat['net'][2]

array([[-1.        ,  0.38132914,  0.        ,  0.        ,  0.38636041],
       [ 0.        , -1.        ,  0.47257905,  0.        ,  0.        ],
       [ 0.        ,  0.        , -1.        ,  0.34116835,  0.        ],
       [ 0.        ,  0.        ,  0.        , -1.        ,  0.6       ],
       [ 0.        ,  0.        ,  0.        ,  0.        , -1.        ]])

In [21]:
mat['Ntimepoints'][0][0]

50

In [None]:
5 6 9 19 20 25, 26, 27 not 200 

In [23]:
import scipy.io
for i in range(1,29):
    mat = scipy.io.loadmat(f'sim/sim{i}.mat')
    print(f"sim{i} has {mat['Ntimepoints'][0][0]} points")

sim1 has 200 points
sim2 has 200 points
sim3 has 200 points
sim4 has 200 points
sim5 has 1200 points
sim6 has 1200 points
sim7 has 5000 points
sim8 has 200 points
sim9 has 5000 points
sim10 has 200 points
sim11 has 200 points
sim12 has 200 points
sim13 has 200 points
sim14 has 200 points
sim15 has 200 points
sim16 has 200 points
sim17 has 200 points
sim18 has 200 points
sim19 has 2400 points
sim20 has 2400 points
sim21 has 200 points
sim22 has 200 points
sim23 has 200 points
sim24 has 200 points
sim25 has 100 points
sim26 has 50 points
sim27 has 50 points
sim28 has 100 points


In [9]:
ts1 = mat['ts']

In [10]:
ts1.shape

(2500, 5)

In [12]:
import numpy as np

In [16]:
pt_down

2400

In [17]:
pt_up

2600

In [18]:
mat['ts'].shape

(2500, 5)

In [13]:
i=12
pt_down = i*50
pt_up = (i+1)*200
obs = mat['ts'][pt_down:pt_up]
obs = obs.astype('float32')
#obs2 = np.reshape(obs, (200, 1, 5))
#obs2 = torch.from_numpy(obs2)
#gt = mat['net'][i]


In [15]:
obs.shape

(100, 5)

In [13]:
mat

{'__header__': b'MATLAB 5.0 MAT-file, Platform: GLNXA64, Created on: Fri Aug 24 15:45:50 2012',
 '__version__': '1.0',
 '__globals__': [],
 'ts': array([[-0.09369639,  0.30913995,  0.9171731 , -0.96691504,  0.53896468],
        [ 0.40971644,  0.58512361, -1.41434747,  0.6136482 ,  2.24090292],
        [ 0.20502844,  0.68523219,  0.73381594,  0.56901009,  1.55566662],
        ...,
        [-0.46984943, -2.03247636, -1.32452319,  1.29174489, -0.33394354],
        [-4.31158452, -1.86350256, -3.31055543, -3.41729277,  0.46994468],
        [ 1.68266271,  0.62365111,  1.46490787, -2.45685887, -2.26610748]]),
 'net': array([[[-1.        ,  0.35674352,  0.        ,  0.        ,
           0.28535286],
         [ 0.        , -1.        ,  0.23344156,  0.        ,
           0.        ],
         [ 0.        ,  0.        , -1.        ,  0.41253323,
           0.        ],
         [ 0.        ,  0.        ,  0.        , -1.        ,
           0.42876764],
         [ 0.        ,  0.        ,  0.

In [3]:
mat['net'][2]

array([[-1.        ,  0.38132914,  0.        ,  0.        ,  0.38636041],
       [ 0.        , -1.        ,  0.47257905,  0.        ,  0.        ],
       [ 0.        ,  0.        , -1.        ,  0.34116835,  0.        ],
       [ 0.        ,  0.        ,  0.        , -1.        ,  0.6       ],
       [ 0.        ,  0.        ,  0.        ,  0.        , -1.        ]])