In [1]:
from typing import List
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from torchdiffeq import odeint


In [2]:
# Interactive (nbagg) matplotlib backend: figures render as live JavaScript widgets.
# NOTE(review): `%matplotlib notebook` only works in the classic Notebook UI; in
# JupyterLab use `%matplotlib widget` (requires ipympl) -- confirm target front end.
%matplotlib notebook

# When an Animation object is a cell's last expression, render it with the
# JavaScript HTML player.
plt.rcParams["animation.html"] = "jshtml"
# Higher-resolution figures throughout the notebook.
plt.rcParams['figure.dpi'] = 150  

def animate_2d_trajectories(trajectories, ax=None, xlim=(-5, 5), ylim=(-5, 5)):
    """Animate a cloud of 2-D points moving over time.

    Args:
        trajectories: array-like of shape (T, N, 2) -- N 2-D points at T
            different times (e.g. the output of ``odeint``). Torch tensors are
            detached and moved to the CPU before plotting.
        ax: Matplotlib ``Axes`` to draw into. If None, a new figure and axes
            are created.
        xlim, ylim: fixed axis limits, re-applied on every frame.

    Returns:
        The ``FuncAnimation``. Keep a reference to it (e.g. as the cell's last
        expression); otherwise it can be garbage-collected and stop animating.

    Raises:
        ValueError: if ``trajectories`` is not 3-D with at least 2 coordinates
            on the last axis.
    """
    # Tensors that require grad (e.g. produced by a learned vector field)
    # cannot be converted to NumPy implicitly, so detach them first.
    if hasattr(trajectories, "detach"):
        trajectories = trajectories.detach().cpu().numpy()
    trajectories = np.asarray(trajectories)
    if trajectories.ndim != 3 or trajectories.shape[-1] < 2:
        raise ValueError(
            f"trajectories must have shape (T, N, 2), got {trajectories.shape}"
        )

    # Draw into the caller's axes when given; only create a figure otherwise
    # (creating one unconditionally ignored `ax` and left an empty extra figure).
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure

    def draw_frame(frame):
        ax.cla()  # redraw from scratch so earlier frames do not accumulate
        ax.scatter(trajectories[frame, :, 0], trajectories[frame, :, 1], s=1)
        ax.set_title(str(frame))
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

    # With an integer `frames`, FuncAnimation passes frame indices 0..T-1.
    return FuncAnimation(fig, func=draw_frame, frames=trajectories.shape[0])

In [3]:
class F(nn.Module):
    """Time-varying linear vector field dy/dt = y @ A(t) for row-vector states.

    A(t) = [[cos t, -sin t], [-sin t, cos t]]. Both off-diagonal entries are
    -sin t, so A(t) is symmetric rather than a rotation matrix.
    NOTE(review): a rotation would have +sin t in the lower-left entry --
    confirm the sign is intended.
    """

    def forward(self, t, y):
        """Evaluate the field at time ``t``.

        Args:
            t: current time; a one-element tensor (as passed by odeint) or a
                Python number. It is converted to a float, so no gradient flows
                back to ``t``.
            y: (N, 2) tensor of states, one 2-D point per row.

        Returns:
            (N, 2) tensor ``y @ A(t)`` with the same dtype/device as ``y``.
        """
        t = float(t)
        c, s = np.cos(t), np.sin(t)
        # Build A with y's dtype/device: the legacy torch.Tensor constructor
        # always gives CPU float32, which makes torch.mm fail for float64/GPU y.
        A = torch.tensor([[c, -s], [-s, c]], dtype=y.dtype, device=y.device)
        return torch.mm(y, A)
dydt = F()

# Seed so the random initial conditions (and hence the animation) are reproducible.
torch.manual_seed(0)

t_span = (0, 10)
n_snapshots = 30
n_points = 300
y0 = 2 * torch.randn(n_points, 2)  # (n_points, 2) initial states

t_eval = torch.linspace(*t_span, n_snapshots)
trajectories = odeint(dydt, y0, t_eval)  # (n_snapshots, n_points, 2)
animate_2d_trajectories(trajectories, xlim=(-20, 20), ylim=(-20, 20))

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [4]:
class F(nn.Module):
    """Linear vector field dy/dt = y @ A with A = [[0, -1], [1, 0]].

    For a row vector y = (y1, y2) this gives dy1/dt = y2 and dy2/dt = -y1, so
    every point circles the origin with period 2*pi.
    """

    def forward(self, t, y):
        """Return ``y @ A`` for a batch ``y`` of shape (N, 2); ``t`` is unused."""
        # Build A with y's dtype/device: the legacy torch.Tensor constructor
        # always gives CPU float32, which makes torch.mm fail for float64/GPU y.
        A = torch.tensor([[0, -1], [1, 0]], dtype=y.dtype, device=y.device)
        return torch.mm(y, A)
dydt = F()

# Seed so the random initial conditions (and hence the animation) are reproducible.
torch.manual_seed(0)

t_span = (0, 6)  # just under one full revolution (period 2*pi ~ 6.28)
n_snapshots = 30
n_points = 100
y0 = 2 * torch.randn(n_points, 2)  # (n_points, 2) initial states

t_eval = torch.linspace(*t_span, n_snapshots)
trajectories = odeint(dydt, y0, t_eval)  # (n_snapshots, n_points, 2)
animate_2d_trajectories(trajectories, xlim=(-20, 20), ylim=(-20, 20))

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [5]:
class F(nn.Module):
    """Autonomous, coordinate-wise vector field dy/dt = cos(y).

    Each coordinate evolves independently of the others and of time.
    """

    def forward(self, t, y):
        # `t` is part of the odeint callback signature but is not used here.
        return y.cos()
dydt = F()

# Seed so the random initial conditions (and hence the animation) are reproducible.
torch.manual_seed(0)

# Each coordinate follows dy/dt = cos(y) on its own, so points drift toward the
# stable zeros of cos at pi/2 + 2*k*pi and the cloud collapses onto a grid.
t_span = (0, 10)
n_snapshots = 30
n_points = 100
y0 = 3 * torch.randn(n_points, 2)  # (n_points, 2) initial states

t_eval = torch.linspace(*t_span, n_snapshots)
trajectories = odeint(dydt, y0, t_eval)  # (n_snapshots, n_points, 2)
animate_2d_trajectories(trajectories, xlim=(-10, 10), ylim=(-10, 10))

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [6]:
class F(nn.Module):
    """Random vector field: returns fresh standard-normal noise on every call.

    NOTE(review): a right-hand side that changes randomly between evaluations
    is not a well-defined ODE, so an adaptive solver cannot meaningfully control
    its error; this is only a rough visual stand-in for diffusion.
    """

    def forward(self, t, y):
        """Return noise with the same shape, dtype and device as ``y``; ``t`` is unused."""
        # randn_like follows y's dtype/device; torch.randn(*y.shape) always
        # produced CPU float32 regardless of the state.
        return torch.randn_like(y)
dydt = F()

# Seed the global RNG: it drives both the initial conditions and the random
# vector field F, so this makes the whole animation reproducible.
torch.manual_seed(0)

t_span = (0, 1)
n_snapshots = 30
n_points = 1000
y0 = 3 * torch.randn(n_points, 2)  # (n_points, 2) initial states

t_eval = torch.linspace(*t_span, n_snapshots)
# The right-hand side is pure noise, so the adaptive solver's error estimate is
# meaningless; the very loose atol=1 presumably just keeps step-size control
# from shrinking dt indefinitely. NOTE(review): for genuine Brownian motion use
# an SDE solver or an explicit Euler-Maruyama loop instead.
trajectories = odeint(dydt, y0, t_eval, atol=1)  # (n_snapshots, n_points, 2)
animate_2d_trajectories(trajectories, xlim=(-10, 10), ylim=(-10, 10))

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>