In [1]:
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import argparse

parser = argparse.ArgumentParser('ODE Test')
parser.add_argument('--method', type=str, choices=['dopri5', 'adams'], default='adams')
parser.add_argument('--data_size', type=int, default=1000)
parser.add_argument('--batch_time', type=int, default=10)
parser.add_argument('--batch_size', type=int, default=20)
parser.add_argument('--niters', type=int, default=1000)
parser.add_argument('--test_freq', type=int, default=20)
parser.add_argument('--adjoint', action='store_true', default=False)
args = parser.parse_args(args=[])

if args.adjoint:
    from torchdiffeq import odeint_adjoint as odeint
else:
    from torchdiffeq import odeint

In [2]:
def visualize(true_y: torch.Tensor, pred_y: torch.Tensor, odefunc):
    t_np = t.cpu().numpy()
    true_y_np = true_y.cpu().numpy()
    pred_y_np = pred_y.cpu().numpy()

    # Clear previous contents
    ax_traj.cla()
    ax_phase.cla()
    ax_vecfield.cla()

    # Trajectories
    ax_traj.set_title('Trajectory')
    ax_traj.set_xlabel('Time')
    ax_traj.set_ylabel('Position')
    ax_traj.plot(t_np, true_y_np[:, 0, 0], 'g-', label='True X')
    ax_traj.plot(t_np, true_y_np[:, 0, 1], 'g-', label='True Y')
    ax_traj.plot(t_np, pred_y_np[:, 0, 0], 'b--', label='Pred X')
    ax_traj.plot(t_np, pred_y_np[:, 0, 1], 'b--', label='Pred Y')
    ax_traj.set_xlim(t_np.min(), t_np.max())
    ax_traj.set_ylim(-2, 2)
    ax_traj.legend()

    # Phase Spaces
    ax_phase.set_title('Phase Space')
    ax_phase.set_xlabel('Position')
    ax_phase.set_ylabel('Momentum')
    ax_phase.plot(true_y_np[:, 0, 0], true_y_np[:, 0, 1], 'g-')
    ax_phase.plot(pred_y_np[:, 0, 0], pred_y_np[:, 0, 1], 'b--')
    ax_phase.set_xlim(-2, 2)
    ax_phase.set_ylim(-2, 2)

    # Vector Field
    ax_vecfield.set_title('Vector Field')
    ax_vecfield.set_xlabel('X')
    ax_vecfield.set_ylabel('Y')
    y, x = np.mgrid[-2:2:21j, -2:2:21j]
    dydt = odefunc(0, torch.Tensor(np.stack([x, y], -1).reshape(21 * 21, 2)).to(device)).cpu().detach().numpy()
    mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1)
    dydt = (dydt / mag)
    dydt = dydt.reshape(21, 21, 2)

    ax_vecfield.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color="black")
    ax_vecfield.set_xlim(-2, 2)
    ax_vecfield.set_ylim(-2, 2)


In [3]:
class ODEFunc(nn.Module):
    def __init__(self):
        super(ODEFunc, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(2, 50),
            nn.Tanh(),
            nn.Linear(50, 2),
        )

        for m in self.net.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.1)
                nn.init.constant_(m.bias, val=0)

    def forward(self, t, y):
        return self.net(y**3)

In [4]:
def get_batch():
    s = torch.from_numpy(
        np.random.choice(
            np.arange(args.data_size - args.batch_time, dtype=np.int64),
            args.batch_size,
            replace=False
        )
    )
    batch_y0 = true_y[s]  # (M, D)
    batch_t = t[:args.batch_time]  # (T)
    batch_y = torch.stack([true_y[s + i] for i in range(args.batch_time)], dim=0)  # (T, M, D)
    return batch_y0.to(device), batch_t.to(device), batch_y.to(device)
        
device = torch.device('cpu')
true_y0 = torch.tensor([[2., 0.]]).to(device)
t = torch.linspace(0., 25., args.data_size).to(device)
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]]).to(device)

class Lambda(nn.Module):
    def forward(self, t, y):
        return torch.mm(y**3, true_A)

with torch.no_grad():
    true_y = odeint(Lambda(), true_y0, t, method=args.method)

ValueError: Invalid method "adams". Must be one of {"dopri8", "dopri5", "bosh3", "fehlberg2", "adaptive_heun", "euler", "midpoint", "rk4", "explicit_adams", "implicit_adams", "fixed_adams", "scipy_solver"}.

In [None]:


from matplotlib.animation import FuncAnimation
from functools import partial
from IPython.display import HTML

fig, (ax_traj, ax_phase, ax_vecfield) = plt.subplots(1, 3, figsize=(12, 4))
func = ODEFunc().to(device)
optimizer = optim.RMSprop(func.parameters(), lr=1e-3)

def train_loop(optimizer, true_y0, t):
    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()
        batch_y0, batch_t, batch_y = get_batch()
        pred_y = odeint(func, batch_y0, batch_t).to(device)
        loss = torch.mean(torch.abs(pred_y - batch_y))
        loss.backward()
        optimizer.step()

        if itr % args.test_freq == 0:
            with torch.no_grad():
                test_pred_y = odeint(func, true_y0, t)
                test_loss = torch.mean(torch.abs(test_pred_y - true_y))
                print(f'Loss {test_loss.item()}')
                yield true_y, test_pred_y

def viz_update(frame):
    true_y, pred_y = frame
    visualize(true_y, pred_y, func)

if __name__ == '__main__':
    viz_frames = list(train_loop(optimizer, true_y0, t))

    ani = FuncAnimation(
        fig,
        viz_update,
        frames=viz_frames,
        repeat=False
    )
    html_animation = HTML(ani.to_jshtml())
    display(html_animation)
