# PINN Solution of the Gray Scott PDEs

This PyTorch code demonstrates the application of physically-informed neural networks (PINN) in the solution of a well-known Gray Scott PDEs with periodic boundary condition
\begin{aligned}
  &u_t = \epsilon_1\Delta u + b(1 - u) - uv^2, \quad (t, x) \in [0, T]\times[-L, L]\\
  &v_t = \epsilon_2\Delta v -dv + uv^2, \\
  &u(0, x) = u_0(x), \quad v(0, x) = v_0(x), \quad \forall x \in [-L, L] \\
  &u(t, -L) = u(t, L), \quad v(t, -L) = v(t, L), \quad \forall t \in [0, T]
\end{aligned}
where $\epsilon_1, \epsilon_2, b, d > 0$ are some parameters, and $[x_{\min}, x_{\max}]$ covers one full period.

See [this link](https://www.chebfun.org/examples/pde/GrayScott.html) for a description

## Libraries and Dependencies

In [1]:
import torch
import torch.optim.lr_scheduler as lr_scheduler
from itertools import chain
from collections import OrderedDict
from pyDOE import lhs
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import scipy.io
from scipy.interpolate import griddata
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.gridspec as gridspec
import time
# set the random seed
np.random.seed(1234)

In [2]:
def figsize(scale, nplots = 1):
    fig_width_pt = 390.0                          # Get this from LaTeX using \the\textwidth
    inches_per_pt = 1.0/72.27                       # Convert pt to inch
    golden_mean = (np.sqrt(5.0)-1.0)/2.0            # Aesthetic ratio (you could change this)
    fig_width = fig_width_pt*inches_per_pt*scale    # width in inches
    fig_height = nplots*fig_width*golden_mean       # height in inches
    fig_size = [fig_width,fig_height]
    return fig_size

In [3]:
pgf_with_latex = {                      # setup matplotlib to use latex for output
    "pgf.texsystem": "pdflatex",        # change this if using xetex or lautex
    "text.usetex": True,                # use LaTeX to write all text
    "font.family": "serif",
    "font.serif": [],                   # blank entries should cause plots to inherit fonts from the document
    "font.sans-serif": [],
    "font.monospace": [],
    "axes.labelsize": 10,               # LaTeX default is 10pt font.
    "font.size": 10,
    "legend.fontsize": 8,               # Make the legend/label fonts a little smaller
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
    "figure.figsize": figsize(1.0),     # default fig size of 0.9 textwidth
    "pgf.preamble": r"\usepackage[utf8x]{inputenc} \usepackage[T1]{fontenc}"
    # use utf8 fonts becasue your computer can handle it :)
    # plots will be generated using this preamble
    }
mpl.rcParams.update(pgf_with_latex)

In [4]:
# MPS or CUDA or CPU
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
#
print(f"Working on {device}")

Working on mps


In [5]:
"""
Problem Definition
"""
# define grid for quadrature solution
epsilon1 = 1
epsilon2 = 1e-2
b = 2e-2
d = 8.26e-2
L = 50.0
xlo = -L
xhi = L
period = xhi - xlo
tlo = 0.0
thi = 80.0
u0 = lambda x: 1 - 0.5 * np.power(np.sin(np.pi * (x - L)/(2.0 * L)), 100.0)
v0 = lambda x: 0.25 * np.power(np.sin(np.pi * (x - L)/(2.0 * L)), 100.0)
# we need to prepare the true solution from chebfun 

## Physics-informed Neural Networks

In [6]:
# the deep neural network
class DNN(torch.nn.Module):
    def __init__(self, layers):
        super(DNN, self).__init__()
        # parameters
        self.depth = len(layers) - 1
        # set up layer order dict
        self.activation = torch.nn.Tanh
        layer_list = list()
        for i in range(self.depth - 1): 
            layer_list.append(
                ('layer_%d' % i, torch.nn.Linear(layers[i], layers[i+1]))
            )
            layer_list.append(('activation_%d' % i, self.activation()))
            
        layer_list.append(
            ('layer_%d' % (self.depth - 1), torch.nn.Linear(layers[-2], layers[-1]))
        )
        layerDict = OrderedDict(layer_list)
        # deploy layers
        self.layers = torch.nn.Sequential(layerDict)
    def forward(self, x):
        out = self.layers(x)
        return out

In [7]:
# the physics-guided neural network
class PhysicsInformedNN():
    def __init__(self, X_data, u_data, v_data, X_pbc, period, X_pde, layers, epsilon1, epsilon2, b, d):
        # data (Dirichelet BC or Initial Condition or noise data)
        self.N_data = X_data.shape[0]
        self.t_data = torch.tensor(X_data[:, 0:1]).float().to(device)
        self.x_data = torch.tensor(X_data[:, 1:2]).float().to(device)
        self.u_data = torch.tensor(u_data).float().to(device)
        self.v_data = torch.tensor(v_data).float().to(device)
        # Periodic BC data
        self.N_pbc = X_pbc.shape[0]
        self.t_pbc = torch.tensor(X_pbc[:, 0:1]).float().to(device)
        self.x_pbc = torch.tensor(X_pbc[:, 1:2]).float().to(device)
        self.period = torch.tensor(period).float().to(device)
        # PDE data, gradients will be computed on these points so requires_grad = True
        self.N_pde = X_pde.shape[0]
        self.t_pde = torch.tensor(X_pde[:, 0:1], requires_grad=True).float().to(device)
        self.x_pde = torch.tensor(X_pde[:, 1:2], requires_grad=True).float().to(device)
        # layers to build Neural Net
        self.layers = layers
        # equation related parameters
        self.epsilon1 = epsilon1
        self.epsilon2 = epsilon2
        self.b = b
        self.d = d
        # construct the loss weights
        self.lws = torch.tensor(np.ones((self.N_data + self.N_pbc + self.N_pde, 1), dtype = np.single), requires_grad=True).to(device)
        self.lws = torch.nn.Parameter(self.lws)
        # deep neural networks
        self.dnn = DNN(layers).to(device)
        self.optimizer_Adam = torch.optim.Adam(self.dnn.parameters(), lr = 1e-3)
        self.optimizer_lws = torch.optim.Adam(chain((self.lws, )), lr = 1e-3)
        self.iter = 0
    # evaluater neural network
    def NN_eval(self, t, x):  
        NN = self.dnn(torch.cat([t, x], dim = 1))
        u = NN[:, 0][:, None]
        v = NN[:, 1][:, None]
        return u, v
    # compute the PDE
    def pde_eval(self, t, x):
        """ The pytorch autograd version of calculating residual """
        u, v = self.NN_eval(t, x)
        # compute the derivatives for u
        u_t  = torch.autograd.grad(u,   t, grad_outputs = torch.ones_like(u), retain_graph = True, create_graph=True)[0]
        u_x  = torch.autograd.grad(u,   x, grad_outputs = torch.ones_like(u), retain_graph = True, create_graph=True)[0]
        u_xx = torch.autograd.grad(u_x, x, grad_outputs = torch.ones_like(u), retain_graph = True, create_graph=True)[0]
        v_t  = torch.autograd.grad(v,   t, grad_outputs = torch.ones_like(v), retain_graph = True, create_graph=True)[0]
        v_x  = torch.autograd.grad(v,   x, grad_outputs = torch.ones_like(v), retain_graph = True, create_graph=True)[0]
        v_xx = torch.autograd.grad(v_x, x, grad_outputs = torch.ones_like(v), retain_graph = True, create_graph=True)[0]
        Eq1  = u_t - self.epsilon1 * u_xx - self.b * (1.0 - u) + u * torch.pow(v, 2.0)
        Eq2  = v_t - self.epsilon2 * v_xx + self.d * v - u * torch.pow(v, 2.0)
        return Eq1, Eq2
    # # compute the total loss for the second-order optimizer
    # def loss_func(self):
    #     # reset the gradient
    #     self.optimizer.zero_grad()
    #     # compute the data loss
    #     u_pred, v_pred = self.NN_eval(self.t_data, self.x_data)
    #     loss_data = torch.mean((self.u_data - u_pred)**2.0) + torch.mean((self.v_data - v_pred)**2.0)
    #     # compute the PBC loss
    #     u_pbc1, v_pbc1 = self.NN_eval(self.t_pbc, self.x_pbc)
    #     u_pbc2, v_pbc2 = self.NN_eval(self.t_pbc, self.x_pbc + self.period)
    #     loss_pbc = torch.mean((u_pbc1 - u_pbc2)**2.0) + torch.mean((v_pbc1 - v_pbc2)**2.0)
    #     # compute PDE loss
    #     pde1_pred, pde2_pred = self.pde_eval(self.t_pde, self.x_pde)
    #     loss_pde = torch.mean(pde1_pred ** 2) + torch.mean(pde2_pred ** 2)
    #     # compute the total loss, it can be weighted
    #     loss = loss_data + loss_pbc + loss_pde
    #     # backward propagation
    #     loss.backward()
    #     # increase the iteration counter
    #     self.iter += 1
    #     # output
    #     if (self.iter + 1) % 1000 == 0:
    #         print(
    #             'Iter %5d, Total: %10.4e, Data: %10.4e, PBC: %10.4e, PDE: %10.4e' 
    #             % (self.iter + 1, loss.item(), loss_data.item(), loss_pbc.item(), loss_pde.item())
    #         )          
    #     return loss
    #
    def train(self, nIter):
        self.dnn.train()
        for epoch in range(nIter):
            # compute the data loss
            u_pred, v_pred = self.NN_eval(self.t_data, self.x_data)
            lws_IC = torch.pow(self.lws[:self.N_data, 0][:, None], 2.0)
            loss_data = torch.mean(lws_IC * (self.u_data - u_pred)**2.0) + torch.mean(lws_IC * (self.v_data - v_pred)**2.0)
            # compute the PBC loss
            u_pbc1, v_pbc1 = self.NN_eval(self.t_pbc, self.x_pbc)
            u_pbc2, v_pbc2 = self.NN_eval(self.t_pbc, self.x_pbc + self.period)
            lws_pbc = torch.pow(self.lws[self.N_data : (self.N_data + self.N_pbc), 0][0, None], 2.0)
            loss_pbc = torch.mean(lws_pbc * (u_pbc1 - u_pbc2)**2.0) + torch.mean(lws_pbc * (v_pbc1 - v_pbc2)**2.0)
            # compute PDE loss
            pde1_pred, pde2_pred = self.pde_eval(self.t_pde, self.x_pde)
            lws_pde = torch.pow(self.lws[(self.N_data + self.N_pbc) : , 0][0, None], 2.0)
            loss_pde = torch.mean(lws_pde * (pde1_pred ** 2)) + torch.mean(lws_pde * (pde2_pred ** 2))
            # compute the total loss, it can be weighted
            loss = loss_data + loss_pbc + loss_pde
            # Backward and optimize
            self.optimizer_Adam.zero_grad()
            loss.backward()
            self.optimizer_Adam.step()   
            # update the loss weights
            for lw in self.lws:
                if lw.grad is not None:
                    lw.grad.data.mul_(-1)
            self.optimizer_lws.step()
            # output the progress
            if (epoch + 1) % 1000 == 0:
                print(
                    'Iter %5d, Total: %10.4e, Data: %10.4e, PBC: %10.4e, PDE: %10.4e' 
                    % (epoch + 1, loss.item(), loss_data.item(), loss_pbc.item(), loss_pde.item())
                )
            
        
    #        
    def predict(self, X):
        t = torch.tensor(X[:, 0:1], requires_grad=True).float().to(device)
        x = torch.tensor(X[:, 1:2], requires_grad=True).float().to(device)
        self.dnn.eval()
        u, v = self.NN_eval(t, x)
        u = u.detach().cpu().numpy()
        v = v.detach().cpu().numpy()
        return u, v

## Configurations

In [8]:
N_IC = 512
dt = 2
N_PBC = int(thi/dt)
# where we know the data
# IC
# IC: t = 0, x \in [xlo, xhi], it can be on a gridd or random
# make sure they are all column vectors, which is a 2-dim tensor
# we will use chebyshev points
cheb_pts = (xhi - xlo)/2.0 * np.cos((2.0 * np.arange(1, N_IC + 1) - 1.0)/N_IC * np.pi/2.0) + (xhi + xlo)/2.0
xIC = np.expand_dims(cheb_pts, axis = 1)
tIC = np.zeros_like(xIC)
ptsIC = np.hstack((tIC, xIC))
uIC = u0(xIC)
vIC = v0(xIC)
# PBC: x = xlo, t \in [0, T], it can be on a grid or random
tPBC = np.expand_dims(np.linspace(tlo, thi, N_PBC + 1)[1:], axis = 1)
xPBC = xlo * np.ones_like(tPBC)
ptsPBC = np.hstack((tPBC, xPBC))
# assemble them (the data first)
data_pts = ptsIC
u_pts = uIC
v_pts = vIC
# not going to use random points, but use grided points
# # collocation points for PDE loss
# pts_PDE = lhs(2, N_PDE)
# # transform it back to proper range
# t_pde = tlo + (thi - tlo) * pts_PDE[:, 0:1]
# x_pde = xlo + (xhi - xlo) * pts_PDE[:, 1:2]
# pts_PDE = np.hstack((t_pde, x_pde))
t = np.linspace(tlo, thi, N_PBC + 1)[1:] # not taking the initial time
x = cheb_pts
T, X = np.meshgrid(t, x)
pts_PDE = np.hstack((T.flatten()[:, None], X.flatten()[:, None]))
N_PDE = pts_PDE.shape[0]

## Training

In [9]:
layers = [2, 20, 20, 20, 20, 2]
model = PhysicsInformedNN(data_pts, u_pts, v_pts, ptsPBC, period, pts_PDE, layers, epsilon1, epsilon2, b, d)

In [None]:
%%time
model.train(50000)

  if lw.grad is not None:


Iter  1000, Total: 8.9627e-06, Data: 6.1960e-06, PBC: 1.2601e-07, PDE: 2.6406e-06
Iter  2000, Total: 2.1632e-04, Data: 7.2769e-05, PBC: 1.2234e-04, PDE: 2.1208e-05
Iter  3000, Total: 3.5134e-05, Data: 2.1005e-05, PBC: 3.4842e-06, PDE: 1.0644e-05
Iter  4000, Total: 3.2183e-04, Data: 1.2887e-04, PBC: 1.9285e-04, PDE: 1.0905e-07
Iter  5000, Total: 2.6222e-04, Data: 1.1059e-04, PBC: 1.3972e-04, PDE: 1.1912e-05
Iter  6000, Total: 2.5997e-04, Data: 1.5130e-04, PBC: 9.9342e-05, PDE: 9.3295e-06


In [None]:
# apply PINN to the same grid as the quadrature solution for comparison
t = np.linspace(tlo, thi, 101)
x = np.linspace(xlo, xhi, 101)
T, X = np.meshgrid(t, x)
pts_flat = np.hstack((T.flatten()[:, None], X.flatten()[:, None]))
u_pred, v_pred = model.predict(pts_flat)
# #
# Exact = u_quad.T
# Exact_vec = Exact.flatten()[:, None]
# error_u = np.linalg.norm(Exact_vec-u_pred,2)/np.linalg.norm(Exact_vec,2)
# print('Error u: %e' % (error_u))                     
u_pred = griddata(pts_flat, u_pred.flatten(), (T, X), method='cubic')
v_pred = griddata(pts_flat, v_pred.flatten(), (T, X), method='cubic')
# Error = np.abs(Exact - U_pred)

## Visualizations

In [None]:
""" The aesthetic setting has changed. """

####### Row 0: u(t,x) ##################    

fig = plt.figure(figsize=(11, 5))
ax = fig.add_subplot(1, 1, 1)
#
ax.plot(ptsIC[:, 0], ptsIC[:, 1], 
    'kx', label = 'IC Data (%d points)' % (ptsIC.shape[0]), 
    markersize = 4,  # marker size doubled
    clip_on = False,
    alpha=1.0
)
#
ax.plot(ptsPBC[:, 0], ptsPBC[:, 1], 
    'bo', label = 'LPBC Data (%d points)' % (ptsPBC.shape[0]), 
    markersize = 4,  # marker size doubled
    clip_on = False,
    alpha=1.0
)
#
ax.plot(ptsPBC[:, 0], ptsPBC[:, 1] + period, 
    'bo', label = 'RPBC Data (%d points)' % (ptsPBC.shape[0]), 
    markersize = 4,  # marker size doubled
    clip_on = False,
    alpha=1.0
)
#
ax.plot(pts_PDE[:, 0], pts_PDE[:, 1], 
    'rd', label = 'PDE Data (%d points)' % (pts_PDE.shape[0]), 
    markersize = 4,  # marker size doubled
    clip_on = False,
    alpha=1.0
)
#
ax.set_xlabel('$t$', size=15)
ax.set_ylabel('$x$', size=15)
ax.legend(
    loc='upper center', 
    bbox_to_anchor=(0.9, -0.05), 
    ncol=5, 
    frameon=False, 
    prop={'size': 15}
)
ax.legend()
ax.set_title('Points', fontsize = 15) # font size doubled
ax.tick_params(labelsize=12)
#
plt.show()

In [None]:
####### Row 1: u(t,x) slices ################## 

""" The aesthetic setting has changed. """

fig = plt.figure(figsize=(14, 10))
#
ax = plt.subplot(1, 3, 1)
#ax.plot(x, Exact[:, 25], 'bo-', linewidth = 2, label = 'Exact')       
ax.plot(x, u_pred[:, 0], 'rx--', linewidth = 2, label = 'u')
ax.plot(x, v_pred[:, 0], 'bo--', linewidth = 2, label = 'v')
ax.set_xlabel('$x$')
ax.set_ylabel('$u(t,x)$')    
ax.set_title('$t = %.1f$' %(t[0]), fontsize = 15)
#ax.axis('square')
#ax.set_xlim([-1.1,1.1])
#ax.set_ylim([-1.1,1.1])
plt.locator_params(axis = 'y', nbins = 5)
plt.locator_params(axis = 'x', nbins = 5)
for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
             ax.get_xticklabels() + ax.get_yticklabels()):
    item.set_fontsize(15)

ax = plt.subplot(1, 3, 2)
#ax.plot(x,Exact[:,50], 'b-', linewidth = 2, label = 'Exact')       
ax.plot(x, u_pred[:, 50], 'rx--', linewidth = 2, label = 'u')
ax.plot(x, v_pred[:, 50], 'bo--', linewidth = 2, label = 'v')
ax.set_xlabel('$x$')
ax.set_ylabel('$u(t,x)$')
#ax.axis('square')
#ax.set_xlim([-1.1,1.1])
#ax.set_ylim([-1.1,1.1])
ax.set_title('$t = %.1f$' %(t[50]), fontsize = 15)
ax.legend(
    loc='upper center', 
    bbox_to_anchor=(0.5, -0.15), 
    ncol=5, 
    frameon=False, 
    prop={'size': 15}
)
plt.locator_params(axis = 'y', nbins = 5)
plt.locator_params(axis = 'x', nbins = 5)
for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
             ax.get_xticklabels() + ax.get_yticklabels()):
    item.set_fontsize(15)

ax = plt.subplot(1, 3, 3)
#ax.plot(x,Exact[:,75], 'b-', linewidth = 2, label = 'Exact')       
ax.plot(x, u_pred[:, 100], 'rx--', linewidth = 2, label = 'u')
ax.plot(x, v_pred[:, 100], 'bo--', linewidth = 2, label = 'v')
ax.set_xlabel('$x$')
ax.set_ylabel('$u(t,x)$')
#ax.axis('square')
#ax.set_xlim([-1.1,1.1])
#ax.set_ylim([-1.1,1.1])    
ax.set_title('$t = %.1f$' %(t[100]), fontsize = 15)
plt.locator_params(axis = 'y', nbins = 5)
plt.locator_params(axis = 'x', nbins = 5)
for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
             ax.get_xticklabels() + ax.get_yticklabels()):
    item.set_fontsize(15)

plt.show()