In [1]:
from typing import List
import numpy as np
import scipy.integrate as spi
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from torchdiffeq import odeint

In [5]:
%matplotlib notebook

plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['figure.dpi'] = 150  

def get_ode_trajectories(dydt, t_span, n_snapshots, initial_points: np.ndarray, max_step=1e-1):
    trajectories = []
    for y0 in initial_points:
        soln = spi.solve_ivp(dydt, t_span, y0, t_eval=np.linspace(*t_span, n_snapshots), max_step=max_step)
        traj = soln.y
        trajectories.append(traj)
    return np.array(trajectories)

def animate_2d_trajectories(trajectories, xlim=(-5, 5), ylim=(-5, 5)):
    def draw_func(ax, t):
        ax.cla()
        ax.scatter(trajectories[:, 0, t], trajectories[:, 1, t], s=1)
        ax.set_title(str(t))
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
    fig, ax = plt.subplots()
    f = lambda t: draw_func(ax, t % trajectories.shape[-1])
    return FuncAnimation(fig, func=f, frames=trajectories.shape[2])

In [7]:
# 2d rotation matrix by angle t
R = lambda t: np.array([[np.cos(t), -np.sin(t)], [-np.sin(t), np.cos(t)]])

def dydt(t, y):
    return R(t) @ y

t_span = (0, 10)
dim = 2
n_points = 300
n_snapshots = 30
initial_points = 2*np.random.randn(n_points, dim)

trajectories = get_ode_trajectories(dydt, t_span, n_snapshots, initial_points, max_step=1)
animate_2d_trajectories(trajectories, xlim=(-20, 20), ylim=(-20, 20))

<IPython.core.display.Javascript object>

In [4]:
def dydt(t, y):
    A = np.array([[0, -1], [1, 0]])
    return A @ y

t_span = (0, 5)
dim = 2
n_points = 500
n_snapshots = 40
initial_points = 5*np.random.randn(n_points, dim)

trajectories = get_ode_trajectories(dydt, t_span, n_snapshots, initial_points, max_step=1)
print(trajectories.shape)
animate_2d_trajectories(trajectories, xlim=(-20, 20), ylim=(-20, 20))

(500, 2, 40)


<IPython.core.display.Javascript object>

<matplotlib.animation.FuncAnimation at 0x177f05d20>