In [67]:
# The basic thing
import math
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes._axes import Axes

# The class thing
from abc import ABC, abstractmethod
from typing import Optional
import seaborn as sns

# The pytorch 
import torch
import torch.distributions as D
from torch.func import vmap, jacrev

from tqdm import tqdm

# Compute device: use CUDA when PyTorch can see a GPU, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(type(device), device)

<class 'torch.device'> cpu


In [69]:
class ODE(ABC):
    """Abstract ordinary differential equation dx/dt = drift(x, t)."""

    @abstractmethod
    def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the drift (vector field) of the ODE at state xt and time t.

        Args:
            - xt: state at time t, shape (batch_size, dim)
            - t: time, shape ()
        Returns:
            - drift_coefficient: shape (batch_size, dim)
        """
        ...

class SDE(ABC):
    """Abstract stochastic differential equation dX = drift(X, t) dt + diffusion(X, t) dW."""

    @abstractmethod
    def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the drift (deterministic part) of the SDE at state xt and time t.

        Args:
            - xt: state at time t, shape (batch_size, dim)
            - t: time, shape ()
        Returns:
            - drift_coefficient: shape (batch_size, dim)
        """
        ...

    @abstractmethod
    def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the diffusion (noise scale) of the SDE at state xt and time t.

        Args:
            - xt: state at time t, shape (batch_size, dim)
            - t: time, shape ()
        Returns:
            - diffusion_coefficient: a tensor with the same dtype as xt. It is
              presumably of shape (batch_size, dim), or broadcastable to it,
              since it multiplies per-element Gaussian noise (TODO confirm).
        """
        ...
    

In [65]:
class Simulator(ABC):
    """
    Base class for fixed-grid numerical integrators.

    Subclasses implement `step`, which advances the state by one time step.
    `simulate` and `simulate_with_trajectory` then drive `step` over a
    caller-supplied, possibly non-uniform, time grid `ts`.
    """

    @abstractmethod
    def step(self, xt: torch.Tensor, t: torch.Tensor, dt: torch.Tensor) -> torch.Tensor:
        """
        Takes one simulation step.
        Args:
            - xt: state at time t, shape (batch_size, dim)
            - t: current time, shape ()
            - dt: step size, shape ()
        Returns:
            - nxt: state at time t + dt, shape (batch_size, dim)
        """
        pass

    @torch.no_grad()  # disable autograd: simulation needs no gradients, so skip graph building to save memory
    def simulate(self, x: torch.Tensor, ts: torch.Tensor) -> torch.Tensor:
        """
        Simulates using the time discretization given by ts.
        Args:
            - x: initial state at time ts[0], shape (batch_size, dim)
            - ts: time grid, shape (num_timesteps,)
        Returns:
            - x_final: final state at time ts[-1], shape (batch_size, dim)
        """
        for t_index in range(len(ts) - 1):
            t = ts[t_index]
            # Step size may vary along the grid, so recompute it each iteration.
            h = ts[t_index + 1] - ts[t_index]
            x = self.step(x, t, h)
        return x

    @torch.no_grad()
    def simulate_with_trajectory(self, x: torch.Tensor, ts: torch.Tensor) -> torch.Tensor:
        """
        Simulates using the time discretization given by ts and records every state.
        Args:
            - x: initial state at time ts[0], shape (batch_size, dim)
            - ts: time grid, shape (num_timesteps,)
        Returns:
            - xs: trajectory of states over ts, shape (batch_size, num_timesteps, dim)
        """
        # clone() must be *called*: storing the bound method `x.clone` would put a
        # non-tensor into the list and make torch.stack below fail.
        xs = [x.clone()]
        for t_index in tqdm(range(len(ts) - 1)):
            t = ts[t_index]
            h = ts[t_index + 1] - ts[t_index]
            x = self.step(x, t, h)
            xs.append(x.clone())
        # Stack along a new time axis (dim=1) -> (batch_size, num_timesteps, dim).
        return torch.stack(xs, dim=1)
            

In [71]:
class EulerMethod(Simulator):
    """Explicit (forward) Euler integrator for an ODE: x_{t+h} = x_t + h * drift(x_t, t)."""

    def __init__(self, ode: ODE):
        # The ODE whose drift_coefficient defines the vector field to integrate.
        self.ode = ode

    def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor):
        """
        Advance the state by one Euler step.
        Args:
            - xt: state at time t
            - t: current time
            - h: step size
        Returns:
            - state at time t + h
        """
        return xt + h * self.ode.drift_coefficient(xt, t)

In [75]:
class EulerMaruyamaSimulator(Simulator):
    """
    Euler-Maruyama integrator for an SDE:
        x_{t+h} = x_t + drift(x_t, t) * h + diffusion(x_t, t) * sqrt(h) * z,   z ~ N(0, I)
    """

    def __init__(self, sde: SDE):
        # The SDE supplying drift_coefficient and diffusion_coefficient.
        self.sde = sde

    def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor):
        """
        Advance the state by one Euler-Maruyama step.
        Args:
            - xt: state at time t
            - t: current time
            - h: step size (a tensor, since torch.sqrt is applied to it)
        Returns:
            - state at time t + h
        """
        drift = self.sde.drift_coefficient(xt, t)
        sigma = self.sde.diffusion_coefficient(xt, t)
        # Fresh standard-normal sample with the same shape, dtype and device as xt.
        z = torch.randn_like(xt)
        return xt + drift * h + sigma * torch.sqrt(h) * z