This notebook contains a PyTorch implementation for the paper "Solving stochastic differential equations and Kolmogorov equations by means of deep learning" by Christian Beck et al.

In [91]:
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('ggplot')
import torch
import torch.nn as nn
import torch.nn.functional as func
from torch.distributions.multivariate_normal import MultivariateNormal as mvn
# Honestly though I wanted to do this with JAX+STAX (or FLAX if anyone ever had the guts to use a 3 month old library for something like this)

In [92]:
class ItoDiffusion:
    def __init__(self, mu, sigma, N):
        # Initializing the Ito Diffusion
        # mu    :: R^dxJ -> R^dxJ
        # sigma :: R^dxJ -> R^dxdxJ
        # where J is the batch size
        # N :: Int (Number of points in discretization)
        self.mu = mu
        self.sigma = sigma
        self.N = N
    
    def sample(self, x, T):
        # Generate J samples of X_T^x
        # x :: R^dxJ
        dt        = T/self.N
        self.d    = x.size(0)
        batchsize = x.size(1)
        for i in range(self.N):
            x = x + self.mu(x)*dt + self.sigma(x)*mvn(torch.zeros(self.d), dt*torch.eye(self.d)).sample([batchsize]).T
        return x

In [93]:
class FCNN(nn.Module):
    def __init__(self, dim, hidden_layers):
        # dim    :: Int (Input dimension and also neurons in each hidden layer)
        # hidden :: Int (Number of hidden layers)
        super().__init__()
        self.model = nn.Sequential()
        for i in range(hidden_layers):
            self.model.add_module("Hidden"+str(i), nn.Linear(dim,dim))
            self.model.add_module("Batch norm "+str(i),nn.BatchNorm1d(dim))
            self.model.add_module("Actvation "+str(i),nn.Tanh())
        self.model.add_module("Output",nn.Linear(dim,1))
        

    def forward(self,x):
        # x :: R^d (Spatial point)
        return self.model(x)
        # I know, this makes is harder to debug comparing to running the 
        # input through each layer separately

In [94]:
def Lp_rel_error(neural_net, exact_sol, a,b,d, N,p):
    # neural_net :: FCNN (instance of neural network)
    # exact_sol  :: RxR^d -> R (exact solution of the PDE)
    # a          :: Float (a in [a,b]^d)
    # b          :: Float (b in [a,b]^d)
    # d          :: Int (d in [a,b]^d)
    # N          :: Int (Number of points to sample in the MC estimate)
    # p          :: Int (p in the Lp)
    spatial_points = (a-b)*torch.rand(N,d) + b
    u_exact        = exact_sol(spatial_points)
    return float((torch.mean(torch.abs((neural_net(spatial_points)[:,0] - u_exact)/u_exact).pow(p))).pow(1/p))

In [None]:
# Heat equation approximation
T,N   = 1,1      # Time value we solve PDE for, number of discretization points to sample Ito diffusion
a,b = 0,1        # a,b in [a,b]^d
d   = 10         # d   in [a,b]^d
s   = 3          # total number of layers excluding input (including output, meaning we have s-1 hidden layers)
m   = 35000      # SGD iterations
J   = 8192
learning_rate = 0.001

# The Ito diffusion corresponding to the Heat equation is Brownian motion. Here we define the initial condition as 
# u(0,x)=||x||2 (squared Euclidean norm) and via an Ansatz you can get the exact solution as u(t,x) = ||x||2 + td
brownian_motion = ItoDiffusion(lambda x: torch.zeros(x.size(0),x.size(1)), lambda x: torch.sqrt(torch.tensor(2.0)), N)
phi = lambda x: torch.norm(x,2,1)
u   = lambda x: phi(x) + T*d

# Creating our neural network model
model = FCNN(d,s-1)

# Setting the optimizer and setting the learning rate to decay
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=25000, gamma=0.1)

L1_errors = []
L2_errors = []
eta  = (a-b)*torch.rand(J,d) + b


for i in range(m):
    model.train()
    X_T  = brownian_motion.sample(eta,T)
    loss = torch.mean((model(eta)[:,0]-phi(X_T))**2)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
    if i%10 == 0:
        model.eval()
        L1_errors.append(Lp_rel_error(model, u, a,b,d, 1000,1))
        L2_errors.append(Lp_rel_error(model, u, a,b,d, 1000,2))
        print("Step {} L1 error is {}, L2 error is {} and loss function value is {}".format(i,L1_errors[-1],L2_errors[-1],loss))
        
        
plt.plot(range(len(L1_errors)),L1_errors)
plt.plot(range(len(L2_errors)),L2_errors)
plt.show()

Step 0 L1 error is 0.9761679768562317, L2 error is 0.9761002063751221 and loss function value is 21.64076805114746
Step 10 L1 error is 0.9777742028236389, L2 error is 0.9780119061470032 and loss function value is 21.24695587158203
Step 20 L1 error is 0.9781864285469055, L2 error is 0.9782872200012207 and loss function value is 21.14417266845703
Step 30 L1 error is 0.9766857624053955, L2 error is 0.9769785404205322 and loss function value is 20.78600311279297
Step 40 L1 error is 0.9725968241691589, L2 error is 0.9739687442779541 and loss function value is 20.484375
