# Simple PINN demo

This notebook demonstrates how to use the PINN class to solve a simple Partial Differential Equation (PDE). It is largely inspired by the following [GitHub repository](https://github.com/nanditadoloi/PINN).

### Problem Statement
Let us consider the following PDE:
$$ \frac{\partial u}{\partial x} = 2\frac{\partial u}{\partial t} + u $$
with the initial condition (a boundary condition on the edge $t=0$ of the domain):
$$ u(x,0) = 6e^{-3x} $$

The variables are:
* $x,t$ for the input
* $u$ for the output

**Goal:** We are searching for $u(x,t)$ for all $x$ in range $[0,2]$ and $t$ in range $[0,1]$.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import numpy as np

# Torch
import torch
import torch.nn as nn
from torch.autograd import Variable
# Pytorch Lightning
import pytorch_lightning as pl


# Plotting
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter


### Analytical solution
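Solving the PDE analytically (e.g. by separation of variables), we obtain the following solution:
$$ u(x,t) = 6e^{-3x-2t} $$
Indeed, $\frac{\partial u}{\partial x} = -3u$ and $2\frac{\partial u}{\partial t} + u = -4u + u = -3u$, and $u(x,0) = 6e^{-3x}$.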

In [None]:


# Uniform grid over x in [0, 2) and t in [0, 1) with spacing 0.02
# (np.arange excludes the upper end point).
x = np.arange(0, 2, 0.02)
t = np.arange(0, 1, 0.02)
X, T = np.meshgrid(x, t)  # shape (len(t), len(x)): rows follow t, columns follow x

# Closed-form solution u(x, t) = 6 e^{-3x - 2t} evaluated on the grid.
U = 6 * np.exp(-3 * X - 2 * T)

# origin='lower' puts t = 0 on the bottom row of the image.
plt.imshow(U, origin='lower', extent=[0, 2, 0, 1])
plt.title('Exact')
plt.xlabel('x')
plt.ylabel('t')
plt.colorbar(aspect=5, shrink=0.5)


### PINN solution using torch

We start by defining an MLP (multilayer perceptron) with `tanh` activation functions.

In [None]:

class FCNet(nn.Module):
    """A fully connected feed-forward neural network with tanh activations.

    The layer stack is ``input_layer -> (n_hidden_layers - 1) hidden Linear
    layers -> output_layer``. ``tanh`` is applied after the input layer and
    after every hidden layer. The output layer is left linear.

    Parameters
    ----------
    input_dimension : int
        Dimension of the input (size of the last axis of ``x``).
    output_dimension : int
        Dimension of the output.
    n_hidden_layers : int
        Number of hidden layers, counting the activated input layer. Values
        below 1 behave like 1 (no neurons->neurons layers are created).
    neurons : int
        Number of neurons in each hidden layer.
    """

    def __init__(self, input_dimension, output_dimension, n_hidden_layers, neurons):
        super().__init__()
        self.input_dimension = input_dimension
        self.output_dimension = output_dimension
        self.neurons = neurons
        self.n_hidden_layers = n_hidden_layers
        self.activation = nn.Tanh()

        self.input_layer = nn.Linear(input_dimension, neurons)
        # The activated input layer already is the first hidden layer, so only
        # ``n_hidden_layers - 1`` additional neurons->neurons layers are built.
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(neurons, neurons) for _ in range(n_hidden_layers - 1)]
        )
        self.output_layer = nn.Linear(neurons, output_dimension)

    def forward(self, x):
        """Apply the affine + tanh stack, then the final linear output layer."""
        x = self.activation(self.input_layer(x))
        for layer in self.hidden_layers:
            x = self.activation(layer(x))
        return self.output_layer(x)
    
# Xavier initialization of network parameters.
def init_xavier(model):
    """Initialize the trainable ``nn.Linear`` layers of *model* with Xavier init.

    Weights are drawn with ``xavier_uniform_`` using the tanh gain, and biases
    are set to zero. This suits MLPs with tanh non-linearities. Layers whose
    parameters are frozen (``requires_grad=False``) are left untouched.
    Layers built with ``bias=False`` are supported.

    Parameters
    ----------
    model : torch.nn.Module
        Network initialized in place; every submodule is visited.
    """
    # The gain is the same for every layer, so compute it once.
    gain = nn.init.calculate_gain('tanh')

    def init_weights(m):
        if not isinstance(m, nn.Linear) or not m.weight.requires_grad:
            return
        # ``bias`` is None for ``nn.Linear(..., bias=False)``; only a present
        # but frozen bias should cause the layer to be skipped.
        if m.bias is not None and not m.bias.requires_grad:
            return
        nn.init.xavier_uniform_(m.weight, gain=gain)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    model.apply(init_weights)


In [None]:
# Architecture of the PINN surrogate: 3 hidden layers of 128 units each.
n_hidden_layers, neurons = 3, 128

class PINNNet(nn.Module):
    """PINN surrogate u(x, t): concatenates x and t and feeds them to an FCNet.

    Parameters
    ----------
    n_hidden_layers : int
        Number of hidden layers of the underlying FCNet.
    neurons : int
        Number of neurons per hidden layer.
    """

    def __init__(self, n_hidden_layers, neurons) -> None:
        super().__init__()
        self.n_hidden_layers = n_hidden_layers
        self.neurons = neurons
        # Two input features (x, t) -> one output (u).
        self.fcnet = FCNet(2, 1, n_hidden_layers, neurons)

    def forward(self, x, t):
        """Evaluate u at the points (x, t).

        ``x`` and ``t`` are concatenated along their last axis, which must
        give the 2 input features expected by the underlying FCNet.
        """
        # ``dim`` is the documented keyword of torch.cat (``axis`` is only a
        # numpy-style alias).
        return self.fcnet(torch.cat((x, t), dim=-1))

    def init(self):
        """Re-initialize all trainable Linear layers with Xavier (tanh gain)."""
        init_xavier(self)

# Seed the global RNG so the initial network weights are reproducible.
torch.manual_seed(1200)
net = PINNNet(n_hidden_layers=n_hidden_layers, neurons=neurons)
net.init()  # Xavier (tanh gain) weights, zero biases

Let us define the different loss functions for the PINN class.
We start with the PDE loss function:
$$ err_{\text{PDE}}(x,t) = \frac{\partial u}{\partial x}(x,t) - 2\frac{\partial u}{\partial t}(x,t) - u(x,t) $$


In [None]:
def pde_error(net, x, t):
    """Compute the PDE residual of ``net`` at the collocation points (x, t).

    The residual is ``du/dx - 2 du/dt - u``, which vanishes for an exact
    solution of the PDE.

    Parameters
    ----------
    net : callable
        Maps tensors ``(x, t)`` to ``u``; it must depend differentiably on
        both inputs.
    x, t : torch.Tensor
        Collocation coordinates. Detached copies are used internally, so the
        caller's tensors are not modified.

    Returns
    -------
    torch.Tensor
        The residual, i.e. the broadcast of ``u_x``, ``u_t`` and ``u``. The
        graph is kept (``create_graph=True``) so the residual can be
        back-propagated into the network parameters.
    """
    # Fresh leaf tensors that track gradients; this replaces the deprecated
    # ``torch.autograd.Variable(x, requires_grad=True)``.
    x = x.detach().requires_grad_(True)
    t = t.detach().requires_grad_(True)
    u = net(x, t)
    # Based on f = du/dx - 2du/dt - u, we need du/dx and du/dt. Differentiating
    # sum(u) yields pointwise derivatives, assuming net acts independently on
    # each point (true for an MLP); both come from a single autograd call.
    u_x, u_t = torch.autograd.grad(u.sum(), (x, t), create_graph=True)
    return u_x - 2 * u_t - u


Next, we compute the error on the boundary (initial) condition at $t=0$:
$$ err_{\text{BC}}(x) = u(x,0) - 6e^{-3x} $$

In [None]:
def boundary_error(net, n=512, x_min=0.0, x_max=2.0):
    """Compute the residual of the condition u(x, 0) = 6 e^{-3x}.

    ``n`` points are drawn uniformly in ``[x_min, x_max)`` at ``t = 0``.

    Parameters
    ----------
    net : callable
        Maps tensors ``(x, t)`` of shape ``(n, 1)`` to ``u``.
    n : int
        Number of random sample points.
    x_min, x_max : float
        Sampling interval for x (defaults to the notebook's domain [0, 2]).

    Returns
    -------
    torch.Tensor
        ``net(x, 0) - 6 e^{-3x}`` (shape ``(n, 1)`` for the notebook's network).
    """
    # Sample on the device of the network's parameters (when it has any), so
    # the function keeps working after Lightning moves the model to a GPU.
    device = None
    parameters = getattr(net, "parameters", None)
    if callable(parameters):
        first = next(iter(parameters()), None)
        if first is not None:
            device = first.device

    # BC: for any x in range and time = 0, u is given by 6e^(-3x).
    x_bc = torch.rand((n, 1), device=device) * (x_max - x_min) + x_min
    t_bc = torch.zeros((n, 1), device=device)
    u_bc = 6 * torch.exp(-3 * x_bc)
    return net(x_bc, t_bc) - u_bc

In [None]:
class RandomBoundedDataset(torch.utils.data.Dataset):
    """Map-style dataset of random collocation points in a rectangle.

    Every access draws a fresh point uniformly from
    ``[boundx[0], boundx[1]) x [boundt[0], boundt[1])``, independent of
    ``index``. ``n`` only fixes the dataset length (and so the epoch size).

    Parameters
    ----------
    n : int
        Number of samples reported by ``len()``.
    boundx, boundt : sequence of two numbers
        Lower and upper bounds of the x and t ranges.

    Each item is ``(x, t, y)``: three tensors of shape ``(1, 1)``, where ``y``
    is the zero target for the PDE residual.
    """

    def __init__(self, n=1024, boundx=(0, 2), boundt=(0, 1)):
        # Tuple defaults avoid the shared-mutable-default pitfall of lists.
        self.boundx = torch.tensor(boundx)
        self.boundt = torch.tensor(boundt)
        self.n = n

    def __getitem__(self, index):
        # ``index`` is intentionally unused: points are resampled on every access.
        d = torch.rand(size=(1, 2))
        x = d[:, :1] * (self.boundx[1] - self.boundx[0]) + self.boundx[0]
        t = d[:, 1:] * (self.boundt[1] - self.boundt[0]) + self.boundt[0]
        y = torch.zeros((1, 1))
        return x, t, y

    def __len__(self):
        return self.n

n = 512 * 256  # collocation points per epoch (256 mini-batches of 512)
dataset = RandomBoundedDataset(boundx=[0, 2], boundt=[0, 1], n=n)


In [None]:
class PINN(pl.LightningModule):
    """Lightning module that trains ``net`` on the PDE and boundary losses.

    Parameters
    ----------
    net : torch.nn.Module
        Network called as ``net(x, t)``; its parameters are optimized.
    """

    def __init__(self, net):
        super().__init__()
        self.net = net
        self.loss = nn.MSELoss()
        # ``net`` is an nn.Module whose weights already go into the checkpoint
        # state_dict, so it is excluded from the saved hyperparameters.
        self.save_hyperparameters(ignore=["net"])

    def training_step(self, batch, batch_idx):
        """Return the sum of the mean-squared PDE and boundary residuals."""
        x, t, y = batch

        # PDE loss at the interior collocation points; ``y`` is the dataset's
        # target, which has the same shape as the residual.
        pde_error_ = pde_error(self.net, x, t)
        pde_loss = self.loss(pde_error_, y)
        self.log('pde_loss', pde_loss)

        # BC loss at fresh random points on t = 0. This residual's shape
        # differs from ``y``'s in this notebook ((n, 1) vs (batch, 1, 1)), so it
        # is compared with zeros of its own shape instead of being broadcast
        # against ``y``.
        boundary_error_ = boundary_error(self.net, n=len(x))
        boundary_loss = self.loss(boundary_error_, torch.zeros_like(boundary_error_))
        self.log('boundary_loss', boundary_loss)

        # Total loss.
        loss = pde_loss + boundary_loss
        self.log('loss', loss)
        return loss

    def configure_optimizers(self):
        """Optimize the network parameters with Adam (lr = 1e-3)."""
        # This is probably not the best optimizer for this problem:
        # PINNs are usually trained with BFGS or L-BFGS.
        return torch.optim.Adam(self.net.parameters(), lr=1e-3)

In [None]:
# Training configuration.
max_epochs, batch_size = 10, 512
# Each dataset item is one random point; the loader stacks batch_size of them per step.
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
model = PINN(net)


In [None]:
# Only max_epochs is configured; every other Trainer option keeps its default.
trainer = pl.Trainer(max_epochs=max_epochs)
trainer.fit(
    model=model,
    train_dataloaders=train_loader,
)

In [None]:
import os


def _show_field(position, field, title, vmin, vmax):
    """Draw ``field`` as an image over the (x, t) domain in subplot ``position``."""
    plt.subplot(position)
    plt.imshow(field, extent=[0, 2, 0, 1], vmin=vmin, vmax=vmax, origin='lower')
    plt.title(title)
    plt.xlabel('x')
    plt.ylabel('t')
    plt.colorbar(shrink=0.5, aspect=5)


plt.figure(figsize=(12, 4))

# Evaluation grid and exact solution u(x,t) = 6e^{-3x-2t}.
x = np.arange(0, 2, 0.02)
t = np.arange(0, 1, 0.02)
X, T = np.meshgrid(x, t)
U = 6 * np.exp(-3 * X - 2 * T)
vmin = np.min(U)
vmax = np.max(U)
_show_field(131, U, 'Exact solution', vmin, vmax)

# Evaluate the trained network on the grid. Inputs get a trailing feature axis
# of size 1 and are moved to the network's device; no autograd graph is needed.
device = next(net.parameters()).device
pt_x = torch.from_numpy(np.expand_dims(X, 2)).float().to(device)
pt_t = torch.from_numpy(np.expand_dims(T, 2)).float().to(device)
with torch.no_grad():
    pt_u = net(pt_x, pt_t)[:, :, 0]
U_pred = pt_u.cpu().numpy()
_show_field(132, U_pred, 'PINN prediction', vmin, vmax)

abs_error = np.abs(U_pred - U)
_show_field(133, abs_error, 'Absolute difference', 0, np.max(abs_error))

# savefig does not create missing directories.
os.makedirs('img', exist_ok=True)
plt.savefig('img/output.png', dpi=300)

In [None]:
# 3-D surface view of the exact solution U on the (X, T) grid of the previous cell.
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

surf = ax.plot_surface(X, T, U, cmap=cm.coolwarm, linewidth=0, antialiased=False)
# Label the axes like every other plot in this notebook.
ax.set_xlabel('x')
ax.set_ylabel('t')
ax.set_zlabel('u')
ax.set_title('Exact solution')

# Ten evenly spaced z ticks, printed with two decimals.
ax.zaxis.set_major_locator(LinearLocator(10))
ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))

fig.colorbar(surf, shrink=0.5, aspect=5)