Note that this repository is no longer actively maintained. Please use the latest version of the deep branching solver.
Authors: Jiang Yu Nguwi and Nicolas Privault.
If this code is used for research purposes, please cite it as:
J.Y. Nguwi, G. Penent, and N. Privault.
A deep branching solver for fully nonlinear partial differential equations.
arXiv preprint arXiv:2203.03234, 2022.
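The deep branching solver based on [NPP22] aims to solve fully nonlinear PDEs of the form

$$
\partial_t u(t, x) + \frac{1}{2} \Delta u(t, x) + f\big(\partial_{\lambda^1} u(t, x), \ldots, \partial_{\lambda^n} u(t, x)\big) = 0, \qquad (t, x) \in [0, T] \times \mathbb{R}^d,
$$

with the terminal condition $u(T, x) = g(x)$, where each $\lambda^i \in \mathbb{N}^d$ is a multi-index and $\partial_{\lambda^i}$ denotes the corresponding partial derivative in $x$.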
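We illustrate the use of the deep branching solver with the particular example of

$$
f(u, \partial_{x_1} u, \ldots, \partial_{x_d} u) = \sum_{i=1}^d \partial_{x_i} u + d\, e^{-u} \left(1 - 2 e^{-u}\right)
\quad \text{and} \quad
g(x) = \log\Big(1 + \Big(\sum_{i=1}^d x_i\Big)^2\Big).
$$

This example admits the true PDE solution

$$
u(t, x) = \log\Big(1 + \Big(\sum_{i=1}^d x_i + d (T - t)\Big)^2\Big).
$$

For illustration purposes, suppose that we are only interested in the solution $u(0, x)$ for $x = (x_1, 0, 0)$ with $x_1 \in [-4, 4]$, where $d = 3$ and $T = 0.05$.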
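As a quick sanity check of this example, the sketch below uses central finite differences in plain Python to confirm that the closed-form solution coded in `exact_fun` further down satisfies a PDE of the form $\partial_t u + \frac{1}{2}\Delta u + f = 0$ with terminal condition $g$. The expressions for `f`, `g`, and `u` are copied from the code in this README; the PDE form, the evaluation point, and the step size `h` are assumptions made for illustration only.

```python
import math

d, T = 3, 0.05  # dimension and terminal time, as in the README example


def u(t, x):
    # closed-form solution: log(1 + (sum_i x_i + d * (T - t))^2)
    return math.log(1 + (sum(x) + d * (T - t)) ** 2)


def g(x):
    # terminal condition: log(1 + (sum_i x_i)^2)
    return math.log(1 + sum(x) ** 2)


def f(y):
    # y[0] = u, y[1:] = first-order partial derivatives of u
    return sum(y[1:]) + d * math.exp(-y[0]) * (1 - 2 * math.exp(-y[0]))


def residual(t, x, h=1e-4):
    """Finite-difference estimate of u_t + 0.5 * Laplacian(u) + f(u, grad u)."""

    def shifted(i, s):
        # copy of x with the i-th coordinate moved by s
        return [xj + (s if j == i else 0.0) for j, xj in enumerate(x)]

    u0 = u(t, x)
    u_t = (u(t + h, x) - u(t - h, x)) / (2 * h)
    grad = [(u(t, shifted(i, h)) - u(t, shifted(i, -h))) / (2 * h) for i in range(d)]
    lap = sum(
        (u(t, shifted(i, h)) - 2 * u0 + u(t, shifted(i, -h))) / h**2 for i in range(d)
    )
    return u_t + 0.5 * lap + f([u0] + grad)


point = [0.3, -1.2, 0.7]  # arbitrary evaluation point (assumption)
res = residual(0.02, point)
terminal_gap = abs(u(T, point) - g(point))
print(f"PDE residual: {res:.2e}, terminal gap: {terminal_gap:.2e}")
```

Both printed quantities should be zero up to finite-difference and rounding error, which is what makes this example convenient for benchmarking the solver.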
There are two ways to use the deep branching solver:
- Edit the templates inside the `__main__` block of `branch.py`, then run `python branch.py` from your terminal.
- Write your own code and import the solver via `from branch import Net`; see the sections on defining the derivatives map and the functions, and on training the model.
Functions f and g must be written in the PyTorch framework, e.g.
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# terminal time, spatial range [x_lo, x_hi], and spatial dimension
T, x_lo, x_hi, dim = 0.05, -4.0, 4.0, 3

# deriv_map is an n x d array defining lambda_1, ..., lambda_n
deriv_map = np.array(
    [
        [0, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
    ]
)


def f_fun(y):
    """
    idx 0 -> no deriv
    idx 1 to d -> first deriv
    """
    return y[1:].sum(dim=0) + dim * torch.exp(-y[0]) * (1 - 2 * torch.exp(-y[0]))


def g_fun(x):
    return torch.log(1 + x.sum(dim=0) ** 2)
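To make the indexing convention concrete, here is a minimal sketch of how the rows of `deriv_map` line up with the entries of `y` passed to `f_fun`. It assumes that each row is a multi-index (row `[1, 0, 0]` meaning one derivative in the first coordinate) and that tensors carry their components on the leading axis with the batch on the trailing axis, as the `dim=0` reductions above suggest. The helper `apply_multi_index` is hypothetical and is not part of `branch.py`; it differentiates `g_fun` with autograd purely for illustration.

```python
import numpy as np
import torch

dim = 3
deriv_map = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])


def g_fun(x):
    return torch.log(1 + x.sum(dim=0) ** 2)


def f_fun(y):
    return y[1:].sum(dim=0) + dim * torch.exp(-y[0]) * (1 - 2 * torch.exp(-y[0]))


def apply_multi_index(fn, x, multi_index):
    """Hypothetical helper: differentiate fn multi_index[j] times in x[j]."""
    x = x.clone().requires_grad_(True)
    out = fn(x)  # shape (batch,)
    for j, order in enumerate(multi_index):
        for _ in range(int(order)):
            # each sample depends only on its own column of x, so the
            # gradient of the batch sum yields per-sample derivatives
            out = torch.autograd.grad(out.sum(), x, create_graph=True)[0][j]
    return out.detach()


x = torch.randn(dim, 5, dtype=torch.float64)  # 5 sample points in R^3
# y[i] holds the derivative of g_fun selected by row i of deriv_map
y = torch.stack([apply_multi_index(g_fun, x, row) for row in deriv_map])
print(y.shape)         # n rows, one per multi-index, and 5 columns
print(f_fun(y).shape)  # one value of f per sample point
```

Under this reading, a second-order term would be requested by adding a row such as `[2, 0, 0]` or `[1, 1, 0]` to `deriv_map` and reading the matching index of `y` inside `f_fun`.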
We can now initialize the model and train it. After training, we plot the deep branching solution against the true solution.
import matplotlib.pyplot as plt
from branch import Net

# initialize the model and train it
model = Net(
    deriv_map=deriv_map,
    f_fun=f_fun,
    phi_fun=g_fun,
    T=T,
    x_lo=x_lo,
    x_hi=x_hi,
    device=device,
    verbose=True,
)
model.train_and_eval()


# define the exact solution and plot the graph
def exact_fun(t, x, T):
    return np.log(1 + (x.sum(axis=0) + dim * (T - t)) ** 2)


# evaluate at t = 0 along x = (x_1, 0, 0) with x_1 in [x_lo, x_hi]
grid = torch.linspace(x_lo, x_hi, 100).unsqueeze(dim=-1)
nn_input = torch.cat((torch.zeros((100, 1)), grid, torch.zeros((100, 2))), dim=-1)
# send the input to the model's device and bring the output back to the CPU for plotting
plt.plot(grid, model(nn_input.to(device)).detach().cpu(), label="Deep branching")
plt.plot(grid, exact_fun(0, nn_input[:, 1:].numpy().T, T), label="True solution")
plt.legend()
plt.show()
The resulting plot is embedded below:
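Beyond the visual comparison, it can help to have a number to compare against. The sketch below evaluates the true solution on the same slice $x = (x_1, 0, 0)$ at $t = 0$ and computes the relative $L^2$ error of the trivial guess $u(0, x) \approx g(x)$, which any trained solver should beat. The helper `relative_l2_error` and the choice of this baseline are assumptions for illustration and are not part of the repository; no solver output is used here.

```python
import numpy as np

T, x_lo, x_hi, dim = 0.05, -4.0, 4.0, 3


def exact_fun(t, x, T):
    # x has shape (dim, n_points)
    return np.log(1 + (x.sum(axis=0) + dim * (T - t)) ** 2)


def relative_l2_error(pred, ref):
    # hypothetical helper: ||pred - ref||_2 / ||ref||_2
    return np.linalg.norm(pred - ref) / np.linalg.norm(ref)


# slice x = (x_1, 0, 0) with x_1 in [x_lo, x_hi], shape (dim, 100)
x1 = np.linspace(x_lo, x_hi, 100)
points = np.stack([x1, np.zeros(100), np.zeros(100)])

ref = exact_fun(0.0, points, T)      # true solution u(0, x)
baseline = exact_fun(T, points, T)   # terminal condition g(x) = u(T, x)

print("baseline relative L2 error:", relative_l2_error(baseline, ref))
print("x_1 at the minimum of u(0, .):", x1[np.argmin(ref)])
```

To score the trained network the same way, pass its output on `nn_input` (converted to a NumPy array) as `pred` in place of `baseline`. On this slice the true solution attains its minimum near $x_1 = -dT = -0.15$, which gives a quick visual check of the plotted curve.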
The deep branching solver can be compared with other deep learning solvers such as the deep BSDE method [HJE18] and the deep Galerkin method [SS18]. Implementations of both are available in `bsde.py` and `galerkin.py`. The `comparison.ipynb` notebook presents such comparisons on the five PDE examples considered in [NPP22].
We also analyze the blow-up time of the deep branching solver in the notebooks in the `blow_up_analysis` directory. This part was developed after the latest version of the solver became available, and therefore uses the code from the latest version. `comparison.ipynb` still uses the code from this repository for legacy purposes.
[HJE18] J. Han, A. Jentzen, and W. E. Solving high-dimensional partial differential equations using deep learning. Proceedings of the National Academy of Sciences, 115(34):8505–8510, 2018.
[NPP22] J.Y. Nguwi, G. Penent, and N. Privault. A deep branching solver for fully nonlinear partial differential equations. arXiv preprint arXiv:2203.03234, 2022.
[SS18] J. Sirignano and K. Spiliopoulos. DGM: A deep learning algorithm for solving partial differential equations. Journal of Computational Physics, 375:1339–1364, 2018.