In [1]:
from typing import List
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from torchdiffeq import odeint

In [2]:
%matplotlib notebook

plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['figure.dpi'] = 150  

def animate_frames(frames: List[torch.Tensor], draw_func=None, **anim_kwargs):
    """Build a matplotlib animation that shows one frame per animation step.

    Args:
        frames: Non-empty indexable sequence of images (e.g. a list of tensors,
            or a tensor whose first dimension indexes time). ``frames[t]`` must
            be something ``Axes.imshow`` accepts.
        draw_func: Optional callback ``draw_func(ax, t, frames)`` that renders
            frame ``t`` onto ``ax``. Defaults to an ``imshow`` of ``frames[t]``
            titled with the frame index, using one colour scale for all frames.
        **anim_kwargs: Extra keyword arguments forwarded to ``FuncAnimation``
            (e.g. ``interval``, ``repeat``).

    Returns:
        The ``FuncAnimation``. Keep a reference to it (e.g. make it the last
        expression of the cell) so it is not garbage-collected.

    Raises:
        ValueError: If ``frames`` is empty.
    """
    n_frames = len(frames)
    if n_frames == 0:
        raise ValueError("animate_frames needs at least one frame")

    if draw_func is None:
        # One colour scale for the whole sequence: without explicit vmin/vmax,
        # imshow rescales every frame to its own min/max, so colours would not
        # be comparable between time steps.
        vmin = min(float(frame.min()) for frame in frames)
        vmax = max(float(frame.max()) for frame in frames)

        def draw_func(ax, t, frames):
            ax.cla()
            ax.imshow(frames[t], vmin=vmin, vmax=vmax)
            ax.set_title(str(t))

    fig, ax = plt.subplots()
    # With an int `frames`, FuncAnimation passes t = 0 .. n_frames - 1 (and
    # restarts from 0 when repeating), so no modulo is needed.
    return FuncAnimation(
        fig,
        func=lambda t: draw_func(ax, t, frames),
        frames=n_frames,
        **anim_kwargs,
    )

In [5]:
class F(nn.Module):
    """Random vector field: dy/dt is fresh standard-normal noise on every call.

    Each evaluation returns new noise, so the right-hand side is not smooth in
    ``t`` or ``y`` and an adaptive solver's error control has little to work
    with; expect to need loose tolerances or a fixed-step method.

    Args:
        generator: Optional ``torch.Generator`` used to draw the noise, so a run
            can be reproduced independently of the global RNG. It must live on
            the same device as the state ``y``. Defaults to the global RNG.
    """

    def __init__(self, generator=None):
        super().__init__()
        self.generator = generator

    def forward(self, t, y):
        """Return noise with the same shape, dtype and device as ``y``; ``t`` is unused."""
        # torch.randn(*y.shape) always produced a default-dtype CPU tensor,
        # which mismatches float64 or GPU states handed in by the solver.
        return torch.randn(
            y.shape, dtype=y.dtype, device=y.device, generator=self.generator
        )

# Both the initial state and the vector field draw from torch's global RNG;
# seed it so Restart & Run All reproduces the same animation.
SEED = 0
torch.manual_seed(SEED)

dydt = F()

t_span = (0, 10)   # integration interval (t0, t1)
n_snapshots = 30   # time points returned by odeint, i.e. animation frames
nrows = 10
ncols = 10
y0 = torch.randn(nrows, ncols)

t_eval = torch.linspace(*t_span, n_snapshots)
# dydt returns random noise on every evaluation, so a tight error tolerance is
# presumably unreachable; the very loose absolute tolerance lets the adaptive
# solver accept its steps.
frames = odeint(dydt, y0, t_eval, atol=10)
assert frames.shape == (n_snapshots, nrows, ncols)
animate_frames(frames)

<IPython.core.display.Javascript object>