In [1]:
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from torchdiffeq import odeint

In [2]:
%matplotlib notebook

plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['figure.dpi'] = 150  

def animate_frames(frames: List[torch.Tensor], draw_func=None):
    if draw_func is None:
        def draw_func(ax, t, frames):
            ax.cla()
            ax.imshow(frames[t])
            ax.set_title(str(t))
    fig, ax = plt.subplots()
    f = lambda t: draw_func(ax, t % len(frames), frames)
    return FuncAnimation(fig, func=f, frames=len(frames))

In [3]:
class ResidualConvBlock(nn.Module):
    def __init__(self, d, kernel_size=3, padding=1):
        super(ResidualConvBlock, self).__init__()
        self.conv = nn.Conv2d(d, d, kernel_size, padding=padding)
        self.activation = nn.Softplus()

    def forward(self, x):
        residual = x
        out = self.conv(x)
        out = self.activation(out)
        out += residual
        return F.sigmoid(out)


class Func(nn.Module):
    def __init__(self, nchannels, nrows, ncols):
        super().__init__()
        # self.dim = nrows * ncols
        # self.fc = nn.Linear(self.dim, self.dim)
        self.rc = ResidualConvBlock(nchannels)
    
    def forward(self, t, y):
        """ y.shape should be (..., nchannels, nrows, ncols) """
        shape = y.shape
        # y = y.reshape(*shape[:-2], self.dim)
        # print('a', y.shape)
        # y = self.fc(y)
        # print('b', y.shape)
        # y = F.softplus(y)
        y = self.rc(y)
        return y
        


t_span = (0, 10)
nsnapshots = 30
nrows = 10
ncols = 10
nchannels = 3
nbatch = 1
y0 = torch.randn(nbatch, nchannels, nrows, ncols)
dydt = Func(nchannels, nrows, ncols)

t_eval = torch.linspace(*t_span, nsnapshots)
frames = odeint(dydt, y0, t_eval, atol=1)
print(frames.shape)
frames = frames[:, 0, ...] # first batch only
print(frames.shape)
frames = frames.permute(0, 2, 3, 1) # (t, c, h, w) -> (t, h, w, c)
animate_frames(frames.detach())

torch.Size([30, 1, 3, 10, 10])
torch.Size([30, 3, 10, 10])


<IPython.core.display.Javascript object>

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

In [4]:
## TODO make it so that image data in valid range
## TODO train the thing. first to evolve to a certain fixed point.
#       (think experiments from distill.pub article growing nca