In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np
import tqdm.notebook as tqdm
import time

import matplotlib.pyplot as plt
%matplotlib inline

# Fixed-step solver: explicit (forward) Euler method.
def euler(func, t, dt, y):
    """Return the explicit-Euler increment for one step of size ``dt``.

    Args:
        func: callable ``func(t, y)`` giving dy/dt at ``(t, y)``.
        t: time at the start of the step.
        dt: step size.
        y: state at time ``t``.

    Returns:
        The increment ``dt * func(t, y)``; the caller adds it to ``y`` to get the next state.
    """
    slope = func(t, y)
    return dt * slope

# Fixed-step solver: classical 4th-order Runge-Kutta (RK4). The step size is supplied by the caller.
def rk4(func, t, dt, y):
    """Return the classical RK4 increment for one step of size ``dt``.

    Args:
        func: callable ``func(t, y)`` giving dy/dt at ``(t, y)``.
        t: time at the start of the step.
        dt: step size (fixed; nothing here adapts it).
        y: state at time ``t``.

    Returns:
        The increment ``dt/6 * (k1 + 2*k2 + 2*k3 + k4)``; the caller adds it to ``y``.
    """
    half = dt / 2
    k1 = func(t, y)                        # slope at the start of the step
    k2 = func(t + half, y + half * k1)     # slope at the midpoint, using k1
    k3 = func(t + half, y + half * k2)     # slope at the midpoint, using k2
    k4 = func(t + dt, y + dt * k3)         # slope at the END of the step (was wrongly t + dt/2)
    return dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

#Neural ODE
class NeuralODE(nn.Module):
    """Integrate ``dy/dt = func(t, y)`` over a given time grid with a pluggable one-step solver.

    When ``func`` is an ``nn.Module`` it is registered as a submodule, so its parameters are
    returned by ``NeuralODE.parameters()`` and can be trained through the solver.
    """

    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, y0, t, solver):
        """Solve the ODE from ``y0`` at every point of ``t``.

        Args:
            y0: initial state (any shape) at time ``t[0]``.
            t: 1-D sequence of time points; consecutive differences are used as step sizes.
            solver: callable ``solver(func, t, dt, y)`` returning the increment for one step
                (e.g. ``euler`` or ``rk4``).

        Returns:
            Tensor of shape ``(len(t), *y0.shape)`` with ``y0``'s dtype and device; entry ``k``
            is the state at ``t[k]``.
        """
        n_points = len(t)
        trajectory = torch.empty((n_points, *y0.shape), dtype=y0.dtype, device=y0.device)
        trajectory[0] = y0
        state = y0
        for idx in range(1, n_points):
            t_prev = t[idx - 1]
            step = t[idx] - t_prev
            state = state + solver(self.func, t_prev, step, state)
            trajectory[idx] = state
        return trajectory

In [None]:
#Experiment: compare the Euler and RK4 solvers on a known ODE.
# NOTE(review): this cell previously used `cos()`, `y0` and `t`, none of which are defined in the
# notebook, so it failed on a fresh kernel. They are defined here as dy/dt = cos(t), y(0) = 0
# (exact solution y = sin(t)) — confirm this matches the intended test problem.

class Cos(nn.Module):
    """ODE right-hand side dy/dt = cos(t); independent of y, broadcast to y's shape."""

    def forward(self, t, y):
        return torch.cos(t) * torch.ones_like(y)


t_demo = torch.linspace(0., 10., 101)
y0_demo = torch.zeros(1)  # one scalar state, so each trajectory is 1-D after the transpose below

ode_test = NeuralODE(func=Cos())
with torch.no_grad():  # pure evaluation: Cos has no parameters to train
    test_result = ode_test(y0=y0_demo, t=t_demo, solver=euler)   # (len(t), 1)
    test_result2 = ode_test(y0=y0_demo, t=t_demo, solver=rk4)
print(test_result.size())
test_result = test_result.transpose(0, 1)                        # (1, len(t))
print(test_result.size())
test_result2 = test_result2.transpose(0, 1)

#Visualize
plt.plot(t_demo.numpy(), torch.sin(t_demo).numpy(), label='exact sin(t)', color='black', linewidth=0.8)
plt.plot(t_demo.numpy(), test_result[0].numpy(), label='euler', color='blue')
plt.plot(t_demo.numpy(), test_result2[0].numpy(), label='rk4', color='red', linestyle='--')
plt.xlabel('Time [s]')
plt.ylabel('y(t)')
plt.title('Neural ODE')
plt.legend()
plt.grid(True)
plt.show()


In [None]:
# Experiment 2 - spiral data set.
# Constants for the ground-truth trajectory; true_A is the matrix applied to y**3 by the
# ground-truth dynamics.

data_size = 2000  # number of time points in the ground-truth trajectory

true_y0 = torch.tensor([[2.0, 0.0]])       # initial state, shape (1, 2)
t = torch.linspace(0.0, 25.0, data_size)   # uniform time grid on [0, 25]
true_A = torch.tensor([
    [-0.1, 2.0],
    [-2.0, -0.1],
])

class Lambda(nn.Module):
    """Ground-truth dynamics dy/dt = (y**3) @ true_A; the time argument t is unused."""

    def forward(self, t, y):
        cubed = y.pow(3)
        return cubed.mm(true_A)
    
# Integrate the ground-truth dynamics once with the fixed-step Euler solver; this is data,
# so no autograd graph is recorded.
with torch.no_grad():
    node = NeuralODE(func=Lambda())
    true_y = node(true_y0, t, solver=euler)  # shape (len(t), *true_y0.shape)
    
def visualize(true_y, pred_y=None):
    """Plot a phase portrait (x vs y) of the ground truth and, optionally, a prediction.

    Args:
        true_y: ground-truth trajectory, indexed as ``[:, 0, 0]`` / ``[:, 0, 1]``, i.e. assumed
            shape ``(T, batch, 2)``; only batch element 0 is drawn.
        pred_y: optional predicted trajectory with the same layout, or ``None``.
    """
    # detach/cpu so tensors that require grad or live on a GPU can still be converted
    true_np = true_y.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(6, 6), facecolor='white')
    ax.set_title('Phase Portrait')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.plot(true_np[:, 0, 0], true_np[:, 0, 1], color='green', label='True Trajectory')
    ax.scatter(true_np[:, 0, 0], true_np[:, 0, 1], color='blue', label='Sampled Data', s=1)
    if pred_y is not None:
        pred_np = pred_y.detach().cpu().numpy()
        ax.plot(pred_np[:, 0, 0], pred_np[:, 0, 1], color='red', label='Predicted Trajectory')
    ax.set_xlim(-2.5, 2.5)
    ax.set_ylim(-2.5, 2.5)
    ax.legend()
    ax.grid(True)
    # show() must actually be called so each checkpoint's figure renders when produced;
    # closing afterwards keeps repeated calls from training loops from piling up open figures.
    plt.show()
    plt.close(fig)

visualize(true_y)

In [None]:
#Mini Batch
# Sample `batch_size` random windows of `batch_time` consecutive points from the ground truth.

batch_time = 10  # consecutive time points per window
batch_size = 16  # windows per mini-batch

def get_batch():
    """Draw a random mini-batch of short windows from ``true_y``.

    Returns:
        batch_y0: ``true_y`` at each window's start index (one entry per window).
        batch_t: ``t[:batch_time]``, shared by every window. Re-using the grid's first
            ``batch_time`` times is only valid because ``t`` is uniform (built with
            ``linspace``) and the dynamics in this notebook ignore ``t``.
        batch_y: the windows stacked along a new leading time axis, shape
            ``(batch_time, batch_size, *true_y.shape[1:])``.
    """
    # Distinct window starts in [0, data_size - batch_time), so start + batch_time - 1 stays in range.
    starts = torch.from_numpy(
        np.random.choice(np.arange(data_size - batch_time, dtype=np.int64), batch_size, replace=False)
    )
    # Row k of `window_idx` holds the indices `starts + k`.
    window_idx = torch.arange(batch_time).unsqueeze(1) + starts.unsqueeze(0)
    return true_y[starts], t[:batch_time], true_y[window_idx]

In [None]:
# Learnable right-hand side for the spiral experiment.
class ODEFunc(nn.Module):
    """Small MLP right-hand side ``dy/dt = net(y**3)`` for a 2-D state.

    The state is cubed before the MLP because the ground-truth dynamics are known to act on
    ``y**3``. ``t`` is accepted for solver compatibility but unused.
    """

    def __init__(self):
        super().__init__()
        hidden = 50
        self.net = nn.Sequential(
            nn.Linear(2, hidden),
            nn.Tanh(),
            nn.Linear(hidden, 2),
        )
        # Re-initialise every Linear layer: N(0, 0.1^2) weights and zero biases.
        linear_layers = [layer for layer in self.net.modules() if isinstance(layer, nn.Linear)]
        for layer in linear_layers:
            nn.init.normal_(layer.weight, mean=0, std=0.1)
            nn.init.constant_(layer.bias, val=0)

    def forward(self, t, y):
        return self.net(y.pow(3))

In [None]:
#Train NODE: the hand-written NeuralODE with fixed-step RK4; gradients flow through the solver steps.
niters = 400

node = NeuralODE(func=ODEFunc())
optimizer = optim.RMSprop(node.parameters(), lr=1e-3)

start_time = time.time()

for itr in tqdm.tqdm(range(niters + 1)):
    optimizer.zero_grad()
    batch_y0, batch_t, batch_y = get_batch()
    pred_y = node(y0=batch_y0, t=batch_t, solver=rk4)
    loss = torch.mean(torch.abs(pred_y - batch_y))  # mean absolute error over all windows
    loss.backward()
    optimizer.step()

    if itr % 50 == 0:  # every 50 iterations, evaluate on the full trajectory and plot it
        with torch.no_grad():
            pred_y = node(true_y0, t, solver=rk4)
            eval_loss = torch.mean(torch.abs(pred_y - true_y))
            print('Iteration {:04d} | Total Loss {:.6f}'.format(itr, eval_loss.item()))
            # visualize() creates and shows its own figure; no extra figure setup is needed here.
            visualize(true_y, pred_y)

elapsed = time.time() - start_time
print('process time: {} sec'.format(elapsed))

In [None]:
#Using Adaptive Solvers - torchdiffeq
# odeint(func, y0, t, rtol, atol, method): method='dopri5' selects an adaptive-step
# Dormand-Prince solver whose step size is controlled by rtol/atol.
from torchdiffeq import odeint

niters = 400

func = ODEFunc()
optimizer = optim.RMSprop(func.parameters(), lr=1e-3)

start_time = time.time()

for itr in tqdm.tqdm(range(niters + 1)):
    optimizer.zero_grad()
    batch_y0, batch_t, batch_y = get_batch()
    pred_y = odeint(func=func, y0=batch_y0, t=batch_t, rtol=1e-7, atol=1e-9, method='dopri5')
    loss = torch.mean(torch.abs(pred_y - batch_y))  # mean absolute error over all windows
    loss.backward()
    optimizer.step()

    if itr % 50 == 0:  # every 50 iterations, evaluate on the full trajectory and plot it
        with torch.no_grad():
            pred_y = odeint(func, true_y0, t, rtol=1e-7, atol=1e-9, method='dopri5')
            eval_loss = torch.mean(torch.abs(pred_y - true_y))
            print('Iteration {:04d} | Total Loss {:.6f}'.format(itr, eval_loss.item()))
            visualize(true_y, pred_y)

elapsed = time.time() - start_time
print('process time: {} sec'.format(elapsed))


In [None]:
#Adjoint Backpropagation Method
# odeint_adjoint has the same call signature as odeint but computes gradients with the adjoint
# method; it requires `func` to be an nn.Module (ODEFunc is).
from torchdiffeq import odeint_adjoint

niters = 400

func = ODEFunc()
optimizer = optim.RMSprop(func.parameters(), lr=1e-3)

start_time = time.time()

for itr in tqdm.tqdm(range(niters + 1)):
    optimizer.zero_grad()
    batch_y0, batch_t, batch_y = get_batch()
    pred_y = odeint_adjoint(func=func, y0=batch_y0, t=batch_t, rtol=1e-7, atol=1e-9, method='dopri5')
    loss = torch.mean(torch.abs(pred_y - batch_y))  # mean absolute error over all windows
    loss.backward()
    optimizer.step()

    if itr % 50 == 0:  # every 50 iterations, evaluate on the full trajectory and plot it
        with torch.no_grad():
            pred_y = odeint_adjoint(func, true_y0, t, rtol=1e-7, atol=1e-9, method='dopri5')
            eval_loss = torch.mean(torch.abs(pred_y - true_y))
            print('Iteration {:04d} | Total Loss {:.6f}'.format(itr, eval_loss.item()))
            visualize(true_y, pred_y)

elapsed = time.time() - start_time
print('process time: {} sec'.format(elapsed))
