In [None]:
%matplotlib notebook
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import animation

# Sliding-window animation of the simulated trajectory `y`.
# NOTE(review): `y` is produced by the `sdeint` cell further down, so this cell
# only works after that cell has run — consider moving it below the simulation.

WINDOW = 1000   # samples visible in each frame
N_FRAMES = 200  # frame i shows samples [i, i + WINDOW)

# First batch element of each state column as NumPy arrays, so the per-frame
# callbacks only slice arrays instead of re-running .squeeze().detach().
prey = y[:, 0, 0].detach().cpu().numpy()
predator = y[:, 0, 1].detach().cpu().numpy()
assert prey.shape[0] >= WINDOW + N_FRAMES - 1, "trajectory too short for the animation"

# Set up the figure and the two line artists once; the callbacks only update y-data.
fig, ax = plt.subplots(figsize=(50, 5))
t_window = np.linspace(0, 10, WINDOW)  # separate name: must not overwrite the simulation grid `t`
ax.set_xlim(0, 10)
# set_ydata() does not rescale the axes, so fix the y-range to cover the whole trajectory.
ax.set_ylim(min(prey.min(), predator.min()), max(prey.max(), predator.max()))
ax.axis("off")
(line1,) = ax.plot(t_window, prey[:WINDOW], color="C0")
(line2,) = ax.plot(t_window, predator[:WINDOW], color="C1")


def init():
    """Draw the first window of both states; used by FuncAnimation as the initial frame."""
    line1.set_ydata(prey[:WINDOW])
    line2.set_ydata(predator[:WINDOW])
    return line1, line2


def animate(i):
    """Show samples [i, i + WINDOW) of both states; called once per frame index i."""
    line1.set_ydata(prey[i:i + WINDOW])
    line2.set_ydata(predator[i:i + WINDOW])
    return line1, line2


# blit=True: only the artists returned by init/animate are redrawn each frame.
anim = animation.FuncAnimation(fig, animate, init_func=init,
                               frames=N_FRAMES, interval=20, blit=True)

# To save as mp4 (requires ffmpeg; the libx264 codec makes the video embeddable
# in HTML5 — adjust extra_args for your system), see
# https://matplotlib.org/stable/api/animation_api.html
# anim.save('basic_animation.mp4', fps=30, extra_args=['-vcodec', 'libx264'])


In [3]:
from torchsde import sdeint
import torch 

In [4]:
class SDE(torch.nn.Module):
    """Stochastic Lotka-Volterra (predator-prey) model for ``torchsde.sdeint``.

    State layout: ``xy[..., 0]`` is the prey population ``x`` and
    ``xy[..., 1]`` the predator population ``y``. The noise is additive,
    ``sigma * I`` per sample, with one Brownian channel per state component.

    Args:
        alpha: prey growth rate.
        beta: predation rate (prey loss per predator encounter).
        gamma: predator death rate.
        delta: predator growth per prey encounter.
        sigma: diffusion scale (the same constant on both components).
    """

    noise_type = 'general'
    sde_type = 'ito'

    def __init__(self, alpha=1., beta=0.1, gamma=0.6, delta=0.1, sigma=2.):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.delta = delta
        self.sigma = sigma

    # Drift
    def f(self, t, xy):
        """Lotka-Volterra drift evaluated at non-negative populations.

        Args:
            t: time (unused; the drift is autonomous).
            xy: state tensor, presumably (batch, 2) as passed by torchsde.

        Returns:
            Tensor of shape (batch, 2): ``[dx/dt, dy/dt]``.
        """
        # clamp() is out-of-place. Masked assignment (x[mask] = 0) on these
        # reshaped views would write through into `xy`, i.e. mutate the
        # solver's state tensor and break autograd. With additive noise the
        # state itself can still go negative; the drift just treats it as 0.
        # If a reflecting/absorbing boundary is wanted, impose it explicitly.
        x = xy[..., 0].reshape(-1, 1).clamp(min=0)
        y = xy[..., 1].reshape(-1, 1).clamp(min=0)
        dx = self.alpha * x - self.beta * x * y
        dy = self.delta * x * y - self.gamma * y
        return torch.hstack([dx, dy])

    # Diffusion
    def g(self, t, xy):
        """Constant diagonal diffusion ``sigma * I``, shape (batch, 2, 2)."""
        # Match xy's dtype/device so the solver does not mix CPU/GPU or float widths.
        eye = torch.eye(2, dtype=xy.dtype, device=xy.device)
        return (self.sigma * eye).repeat(xy.shape[0], 1, 1)

In [None]:
# Simulate one sample path of the stochastic Lotka-Volterra model.
import random

SEED = 0
random.seed(SEED)        # torchsde seeds its default Brownian motion from Python's `random`
torch.manual_seed(SEED)  # seeds the initial-condition jitter below

t = torch.linspace(0, 60, 6000)
sde = SDE()

# (batch=1, state=2) initial populations, jittered with N(0, 0.5^2) noise.
y0 = torch.tensor([[9., 4.]]) + 0.5 * torch.randn(1, 2)
y = sdeint(sde, y0, t)

In [None]:
%matplotlib inline
plt.figure(figsize=(60,5))
plt.plot(t,y[...,0].squeeze().detach(), color="C0", lw=3)
plt.plot(t,y[...,1].squeeze().detach(), color="C1", lw=3)
plt.axis("off")
plt.ylim()