# **Tutorial: PINN simulation with JINNAX**

In this tutorial, we present the main features of the JINNAX package for simulating Physics-Informed Neural Networks (PINNs) in JAX. The purpose of JINNAX is to simulate the solution of the PDE

\begin{align*}
  \mathcal{N}_{\boldsymbol{x},t}[u(\boldsymbol{x},t)] &= f(\boldsymbol{x},t), \ \ \boldsymbol{x} \in \Omega, \ t \in (0,T] \\
  u(\boldsymbol{x},t) &= g(\boldsymbol{x},t), \ \ \boldsymbol{x} \in \partial \Omega, \ t \in (0,T]\\
  u(\boldsymbol{x},0) &= h(\boldsymbol{x})\ \ \ , \ \ \boldsymbol{x} \in \bar{\Omega}
\end{align*}
in which $\mathcal{N}_{\boldsymbol{x},t}$ is a spatio-temporal differential operator; $\Omega = (x_{l},x_{u})^{d}$, for $d \geq 1$, is an open $d$-dimensional cube; and $T > 0$.
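As a concrete instance (our own illustrative choice, not something prescribed by JINNAX), take $d = 1$, $\Omega = (0,1)$ and the heat operator $\mathcal{N}_{x,t}[u] = \partial_{t} u - \partial_{xx} u$. The function $u(x,t) = x + t$, which the data-generation examples in this tutorial also use, then solves the problem with $f \equiv 1$, $g(x,t) = x + t$ on $\partial \Omega$ and $h(x) = x$. The NumPy sketch below checks $f \equiv 1$ numerically with central finite differences; the helper `heat_residual` and the step size `eps` are hypothetical.

```python
import numpy as np


def u(x, t):
    # Example solution used throughout this tutorial: u(x, t) = x + t
    return x + t


def heat_residual(u, x, t, eps=1e-3):
    """Approximate N[u] = u_t - u_xx by central finite differences (hypothetical helper)."""
    u_t = (u(x, t + eps) - u(x, t - eps)) / (2 * eps)
    u_xx = (u(x + eps, t) - 2 * u(x, t) + u(x - eps, t)) / eps**2
    return u_t - u_xx


# Points inside Omega = (0, 1) at times in (0, 1]
x = np.linspace(0.1, 0.9, 5)
t = np.linspace(0.2, 1.0, 5)
f = heat_residual(u, x, t)
print(f)  # close to 1 at every point, i.e. f(x, t) = 1
```

In practice a PINN would compute these derivatives with automatic differentiation rather than finite differences; the finite-difference version is used here only to keep the check dependency-free.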

The simulation pipeline in JINNAX has three steps:

1.   **Data generation**: Generate data in $\bar{\Omega} \times [0,T]$
2.   **Training**: Train a PINN with the generated data
3.   **Evaluation**: Evaluate the trained PINN

This tutorial will cover how these steps can be performed with JINNAX.


### **Import JINNAX**

The latest version of JINNAX can be installed from GitHub. It has two modules: `data` and `training`.

```python
#Install jinnax
import os
os.system("pip3 install --upgrade git+https://github.com/dmarcondes/JINNAX")

#Import jinnax
from jinnax import data as jd
from jinnax import training as jtr
```

### **Step 1: Data generation**

#### **Reasoning for data generation**

Data is generated by sensors placed in $\bar{\Omega}$ that take measurements over time as follows:

*  There are $N_{s}^{d}, N_{s} > 1,$ sensors in $\Omega$, placed either in a grid or in uniformly sampled positions.
*  There are $N_{b}^{d} - (N_{b} - 2)^{d}, N_{b} > 1,$ sensors in $\partial \Omega$. More specifically, for each coordinate $j$ of $\boldsymbol{x}$, there are $N_{b}^{d-1}$ sensors placed at points where $x_{j} = x_{l}$ and the remaining $d-1$ coordinates lie on a grid of $[x_{l},x_{u}]^{d-1}$, and another $N_{b}^{d-1}$ sensors placed at points where $x_{j} = x_{u}$ and the remaining $d-1$ coordinates lie on the same grid. Points shared by two or more faces (edges and corners of the cube) hold a single sensor, which gives the total above.
* These sensors take measurements at times $0 = t_{0} < t_{1} < \dots < t_{N_{t}} \leq T$, for $N_{t} > 1$. The last $N_{t}$ times are either on a grid or uniformly sampled from $(0,T]$.
* For $t > 0$, the sensors in $\Omega$ have a mean-zero Gaussian noise with variance $\sigma_{s}^{2} \geq 0$.
* For $t \geq 0$, the sensors in $\partial \Omega$ have a mean-zero Gaussian noise with variance $\sigma_{b}^{2} \geq 0$.
* For $t = 0$, the sensors in $\Omega$ have a mean-zero Gaussian noise with variance $\sigma_{i}^{2} \geq 0$.
* There are $N_{c}^{d}$, $N_{c} > 1$, collocation points in $\Omega$, placed either on a grid or at uniformly sampled positions. These spatial points are considered at $N_{ct}$ times, either on a grid of $[0,T]$ or uniformly sampled from it.
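These counts determine how many points each part of the dataset contains. The sketch below tallies them; `expected_sizes` is a hypothetical helper, not a JINNAX function, and it assumes that every spatial collocation point is paired with every collocation time. With the parameters used in the examples at the end of this step ($N_{s} = N_{t} = N_{b} = N_{c} = N_{ct} = 4$), its output can presumably be compared with the first dimension of the printed array shapes.

```python
def expected_sizes(Ns, Nt, Nb, Nc, Ntc, d=1):
    """Number of points implied by the sensor layout above (hypothetical helper)."""
    n_boundary = Nb**d - (Nb - 2)**d  # grid points with a coordinate equal to x_l or x_u
    return {
        "sensor": Ns**d * Nt,  # interior sensors at t_1, ..., t_Nt
        "initial": Ns**d,  # the same interior sensors at t_0 = 0
        "boundary": n_boundary * (Nt + 1),  # boundary sensors at t_0, ..., t_Nt
        "collocation": Nc**d * Ntc,  # assumes every spatial point is paired with every time
    }


print(expected_sizes(4, 4, 4, 4, 4, d=1))
# {'sensor': 16, 'initial': 4, 'boundary': 10, 'collocation': 16}
print(expected_sizes(4, 4, 4, 4, 4, d=2))
# {'sensor': 64, 'initial': 16, 'boundary': 60, 'collocation': 64}
```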

#### **JINNAX function for data generation**

Simulation data is generated by the function `generate_PINNdata` from the `data` module. Its parameters are:

*   `u`: Solution of the PDE (function)
*   `xlo`: Lower bound of each x coordinate
*   `xup`: Upper bound of each x coordinate
*   `tlo`: Lower bound of the time interval. Default 0
*   `tup`: Upper bound of the time interval
*   `Ns`: Number of points along each x coordinate for sensor data
*   `Nt`: Number of points along the time axis for sensor data
*   `Nb`: Number of points along each x coordinate for boundary data
*   `Nc`: Number of points along each x coordinate for collocation points
*   `Ntc`: Number of points along the time axis for collocation points
*   `d`: Domain dimension. Default 1
*   `poss`: Position of the sensor points in the spatial domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
*   `post`: Position of the sensor times in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
*   `posc`: Position of the collocation points in the x domain. Either 'grid' or 'random' for uniform sampling. Default 'grid'
*   `posct`: Position of the collocation points in the time interval. Either 'grid' or 'random' for uniform sampling. Default 'grid'
*   `sigmas`: Standard deviation of the Gaussian noise of sensor data (x inside the domain). Default 0
*   `sigmab`: Standard deviation of the Gaussian noise of boundary data. Default 0
*   `sigmai`: Standard deviation of the Gaussian noise of initial data. Default 0

The function returns a dictionary with:
* (x,t) sensor data ('sensor')
* u(x,t) sensor data ('usensor')
* (x,t) boundary data ('boundary')
* u(x,t) boundary data ('uboundary')
* (x,0) initial data ('initial')
* u(x,0) initial data ('uinitial')
* collocation data ('collocation')

#### **More details: Sensor data**

For each coordinate $j = 1,\dots,d$, we fix $N_{s}$ points in $(x_{l},x_{u})$: $x_{1}^{(j)},\dots,x_{N_{s}}^{(j)}$. If `poss = 'grid'`, then these points form a grid of $(x_{l},x_{u})$ with

$$x_{i}^{(j)} = x_{l} + i\delta \ \ \ \text{with} \ \ \ \delta = \frac{x_{u} - x_{l}}{N_{s} + 1}$$

Otherwise, if `poss = 'random'`, then these points are uniformly sampled from $(x_{l},x_{u})$.
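A minimal NumPy sketch of how the $N_{s}$ coordinates along one axis can be produced, assuming the grid formula above for `poss = 'grid'` and uniform sampling for `poss = 'random'`. The helper `sensor_axis` and the use of `numpy.random.default_rng` are our assumptions; JINNAX itself is built on JAX and may generate these points differently.

```python
import numpy as np


def sensor_axis(xlo, xup, Ns, pos="grid", rng=None):
    """Ns sensor coordinates in the open interval (xlo, xup) (hypothetical helper)."""
    if pos == "grid":
        delta = (xup - xlo) / (Ns + 1)
        # x_i = x_l + i * delta for i = 1, ..., Ns, which excludes both endpoints
        return xlo + delta * np.arange(1, Ns + 1)
    if pos == "random":
        rng = np.random.default_rng() if rng is None else rng
        # Uniform on [xlo, xup); the value xlo itself occurs with probability zero
        return rng.uniform(xlo, xup, size=Ns)
    raise ValueError("pos must be 'grid' or 'random'")


print(sensor_axis(0, 1, 4))  # [0.2 0.4 0.6 0.8]
```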

We also fix $N_{t}$ points in $(0,T]$: $t_{1},\dots,t_{N_{t}}$. Again, if `post = 'grid'`, then these points form a grid of $(0,T]$ with

$$t_{i} = i\delta \ \ \ \text{with} \ \ \ \delta = \frac{T}{N_{t}}$$

Otherwise, if `post = 'random'`, then these points are uniformly sampled from $(0,T]$.
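The time points can be sketched in the same way, assuming the time interval starts at $0$ (`tlo = 0`, the default). The helper `sensor_times` is hypothetical; sorting the random times, so that they respect $t_{1} < \dots < t_{N_{t}}$, is our own choice.

```python
import numpy as np


def sensor_times(T, Nt, pos="grid", rng=None):
    """Nt sensor times in (0, T], assuming tlo = 0 (hypothetical helper)."""
    if pos == "grid":
        # t_i = i * T / Nt for i = 1, ..., Nt, so the last time equals T
        return (T / Nt) * np.arange(1, Nt + 1)
    if pos == "random":
        rng = np.random.default_rng() if rng is None else rng
        # Uniform on [0, T); the value 0 itself occurs with probability zero
        return np.sort(rng.uniform(0, T, size=Nt))
    raise ValueError("pos must be 'grid' or 'random'")


print(sensor_times(1.0, 4))  # grid of (0, 1] with step 0.25
```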

The sample is then generated by the Cartesian product of these points over the $d$ spatial dimensions and the time dimension. This sample has $N_{s}^{d} \times N_{t}$ points of the form

$$\left((x_{i_{1}}^{(1)},\dots,x_{i_{d}}^{(d)}),t_{i_{t}},y_{i_{1},\dots,i_{d},i_{t}}\right)$$

with $i_{j} \in \{1,\dots,N_{s}\}$, $j \in \{1,\dots,d\}$, $i_{t} \in \{1,\dots,N_{t}\}$ and

$$y_{i_{1},\dots,i_{d},i_{t}} = u\left((x_{i_{1}}^{(1)},\dots,x_{i_{d}}^{(d)}),t_{i_{t}}\right) + ϵ_{i_{1},\dots,i_{d},i_{t}}$$

in which $ϵ_{i_{1},\dots,i_{d},i_{t}}$ is a mean-zero Normal random variable with standard deviation `sigmas`. The Normal random variables for distinct indices are independent.
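The Cartesian-product construction and the noise model can be sketched in NumPy as follows. This is an illustration under our own assumptions, not the JINNAX implementation: the helper `sensor_sample` is hypothetical, and it calls `u` with an `(n, d)` array of positions and an `(n,)` array of times, a calling convention that may differ from the one `generate_PINNdata` uses.

```python
import numpy as np
from itertools import product


def sensor_sample(u, xs, ts, d, sigma, rng):
    """Cartesian product of d copies of xs with ts, plus noisy observations (hypothetical helper)."""
    # Each row is (x_{i_1}, ..., x_{i_d}, t_{i_t})
    pts = np.array([(*x, t) for x in product(xs, repeat=d) for t in ts])
    x, t = pts[:, :d], pts[:, d]
    # Independent mean-zero Gaussian noise with standard deviation sigma
    y = u(x, t) + sigma * rng.standard_normal(len(pts))
    return pts, y


rng = np.random.default_rng(0)
xs = np.array([0.2, 0.4, 0.6, 0.8])  # N_s = 4 grid of (0, 1)
ts = np.array([0.25, 0.5, 0.75, 1.0])  # N_t = 4 grid of (0, 1]
pts, y = sensor_sample(lambda x, t: x.sum(axis=1) + t, xs, ts, d=2, sigma=0.0, rng=rng)
print(pts.shape, y.shape)  # (64, 3) (64,)
```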

#### **More details: Initial data**

The sensors that measure data for $t > 0$ also measure at $t = 0$, so the initial data is of the form

$$\left((x_{i_{1}}^{(1)},\dots,x_{i_{d}}^{(d)}),0,y_{i_{1},\dots,i_{d},0}\right)$$

in which the points $x_{i_{j}}^{(j)}$ are the same considered for sensor data and

$$y_{i_{1},\dots,i_{d},0} = u\left((x_{i_{1}}^{(1)},\dots,x_{i_{d}}^{(d)}),0\right) + ϵ_{i_{1},\dots,i_{d},0}$$

in which $ϵ_{i_{1},\dots,i_{d},0}$ is a mean-zero Normal random variable with standard deviation `sigmai`. The Normal random variables for distinct indices are independent.
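Under the same assumptions as for the sensor data (a hypothetical helper, and `u` called with an `(n, d)` array of positions and an `(n,)` array of times), the initial sample reuses the interior sensor positions at $t = 0$ with noise level `sigmai`:

```python
import numpy as np
from itertools import product


def initial_sample(u, xs, d, sigma_i, rng):
    """Sensor positions at t = 0 with noisy observations of u(x, 0) (hypothetical helper)."""
    x = np.array(list(product(xs, repeat=d)))  # the same N_s**d positions as the sensor data
    t = np.zeros(len(x))  # every initial observation is taken at t_0 = 0
    # Independent mean-zero Gaussian noise with standard deviation sigma_i
    y = u(x, t) + sigma_i * rng.standard_normal(len(x))
    return np.column_stack([x, t]), y


rng = np.random.default_rng(1)
xs = np.array([0.2, 0.4, 0.6, 0.8])
pts, y = initial_sample(lambda x, t: x.sum(axis=1) + t, xs, d=1, sigma_i=0.1, rng=rng)
print(pts.shape)  # (4, 2)
```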

#### **More details: Boundary data**

Boundary data is generated on a grid of $\partial \Omega$. For each $j \in \{1,\dots,d\}$, we fix $N_{b}$ points in $[x_{l},x_{u}]$: $x_{0}^{(j,b)},\dots,x_{N_{b}-1}^{(j,b)}$. These points form a grid of $[x_{l},x_{u}]$ with

$$x_{i}^{(j,b)} = x_{l} + i\delta \ \ \ \text{with} \ \ \ \delta = \frac{x_{u} - x_{l}}{N_{b} - 1}$$

Sensors are positioned at points of the form $\left(x_{i_{1}}^{(1,b)},\dots,x_{i_{d}}^{(d,b)}\right)$ in which $x_{i_{j}}^{(j,b)} \in \{x_{l},x_{u}\}$ for at least one coordinate $j$. Observe that there are $N_{b}^{d} - (N_{b} - 2)^{d}$ such points.
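The boundary positions can be enumerated directly from this description. In the sketch below, the helper `boundary_points` is hypothetical and `numpy.linspace` stands in for the grid formula above; it keeps the index tuples that touch the first or last grid index in some coordinate, which yields exactly $N_{b}^{d} - (N_{b} - 2)^{d}$ points.

```python
import numpy as np
from itertools import product


def boundary_points(xlo, xup, Nb, d):
    """Grid points of [xlo, xup]^d with at least one coordinate in {xlo, xup} (hypothetical helper)."""
    # x_i = x_l + i * (x_u - x_l) / (Nb - 1) for i = 0, ..., Nb - 1, endpoints included
    grid = np.linspace(xlo, xup, Nb)
    # Keep index tuples that use index 0 (x_l) or index Nb - 1 (x_u) in some coordinate
    idx = [p for p in product(range(Nb), repeat=d) if 0 in p or Nb - 1 in p]
    return grid[np.array(idx)]  # shape (Nb**d - (Nb - 2)**d, d)


bp = boundary_points(0.0, 1.0, 4, 2)
print(bp.shape)  # (12, 2)
```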

These sensors take measurements at the times $0 = t_{0} < t_{1} < \cdots < t_{N_{t}} \leq T$, in which $t_{i}$, $i > 0$, are the same times considered for the sensor data.

Therefore, the boundary sample has size $(N_{b}^{d} - (N_{b} - 2)^{d}) \times (N_{t} + 1)$ and its points are of the form

$$\left((x_{i_{1}}^{(1,b)},\dots,x_{i_{d}}^{(d,b)}),t_{i_{t}},y_{i_{1},\dots,i_{d},i_{t}}^{(b)}\right)$$

with $i_{j} \in \{0,\dots,N_{b}-1\}$ for $j \in \{1,\dots,d\}$, $i_{j} \in \{0, N_{b}-1\}$ for at least one $j$, $i_{t} \in \{0,\dots,N_{t}\}$ and

$$y_{i_{1},\dots,i_{d},i_{t}}^{(b)} = u\left((x_{i_{1}}^{(1,b)},\dots,x_{i_{d}}^{(d,b)}),t_{i_{t}}\right) + ϵ_{i_{1},\dots,i_{d},i_{t}}^{(b)}$$

in which $ϵ_{i_{1},\dots,i_{d},i_{t}}^{(b)}$ is a mean-zero Normal random variable with standard deviation `sigmab`. Again, the Normal random variables for distinct indices are independent.
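Given the boundary positions, pairing each of them with the $N_{t} + 1$ times and adding noise can be sketched as below. The helper `boundary_sample` is hypothetical, and, as in the earlier sketches, `u` is assumed to take an `(n, d)` array of positions and an `(n,)` array of times.

```python
import numpy as np


def boundary_sample(u, xb, ts, sigma_b, rng):
    """Pair boundary positions with the times 0 = t_0 < t_1 < ... < t_Nt and add noise (hypothetical helper)."""
    times = np.concatenate([[0.0], ts])  # prepend t_0 = 0 to the sensor times
    # Row k pairs position xb[k // len(times)] with time times[k % len(times)]
    x = np.repeat(xb, len(times), axis=0)
    t = np.tile(times, len(xb))
    # Independent mean-zero Gaussian noise with standard deviation sigma_b
    y = u(x, t) + sigma_b * rng.standard_normal(len(t))
    return np.column_stack([x, t]), y


rng = np.random.default_rng(2)
xb = np.array([[0.0], [1.0]])  # boundary of Omega = (0, 1) when d = 1
ts = np.array([0.25, 0.5, 0.75, 1.0])  # N_t = 4 sensor times
pts, y = boundary_sample(lambda x, t: x.sum(axis=1) + t, xb, ts, sigma_b=0.0, rng=rng)
print(pts.shape)  # (10, 2)
```

The row count $2 \times (4 + 1) = 10$ matches $(N_{b}^{d} - (N_{b} - 2)^{d}) \times (N_{t} + 1)$ for $d = 1$.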

#### **More details: Collocation points**

We consider $N_{c}^{d}$ spatial collocation points, each taken at $N_{ct}$ times. If `posc = 'grid'`, then the $N_{c}^{d}$ points are on a grid of $(x_{l},x_{u})^{d}$; otherwise, if `posc = 'random'`, they are uniformly sampled from $(x_{l},x_{u})^{d}$. Likewise, if `posct = 'grid'`, the $N_{ct}$ times are on a grid of $[0,T]$; if `posct = 'random'`, they are uniformly sampled from $[0,T]$.
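A sketch of the collocation set under explicit assumptions: the helper `collocation_points` is hypothetical, the spatial grid is assumed to follow the same interior rule as the sensor grid, the time grid is assumed to include both $0$ and $T$ (via `numpy.linspace`), and every spatial point is paired with every time. JINNAX may make different choices on each of these points.

```python
import numpy as np
from itertools import product


def collocation_points(xlo, xup, Nc, T, Ntc, d=1, posc="grid", posct="grid", rng=None):
    """Collocation points (x, t) in (xlo, xup)^d x [0, T] (hypothetical helper)."""
    rng = np.random.default_rng() if rng is None else rng
    if posc == "grid":
        # Interior grid, assumed to follow the same rule as the sensor grid
        axis = xlo + (xup - xlo) / (Nc + 1) * np.arange(1, Nc + 1)
        xs = np.array(list(product(axis, repeat=d)))
    else:
        xs = rng.uniform(xlo, xup, size=(Nc**d, d))
    # Grid of [0, T] including both endpoints, or uniform samples from [0, T)
    ts = np.linspace(0.0, T, Ntc) if posct == "grid" else rng.uniform(0.0, T, size=Ntc)
    # Pair every spatial collocation point with every collocation time
    x = np.repeat(xs, Ntc, axis=0)
    t = np.tile(ts, len(xs))
    return np.column_stack([x, t])


cp = collocation_points(0.0, 1.0, 4, 1.0, 4, d=2)
print(cp.shape)  # (64, 3)
```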

#### **Example**

The code below presents examples of data generated for one- and two-dimensional domains.

```python
#One dimensional in a grid
dat = jd.generate_PINNdata(
    u=lambda x, t: x + t,
    xlo=0, xup=1, tlo=0, tup=1,
    Ns=4, Nt=4, Nb=4, Nc=4, Ntc=4, d=1,
    poss='grid', post='grid', posc='grid', posct='grid',
    sigmas=0, sigmab=0, sigmai=0,
)
for key in ['sensor', 'boundary', 'initial', 'collocation']:
    print(key.capitalize())
    print(dat[key].shape)
    print(dat[key])
```

```python
#One dimensional random
dat = jd.generate_PINNdata(
    u=lambda x, t: x + t,
    xlo=0, xup=1, tlo=0, tup=1,
    Ns=4, Nt=4, Nb=4, Nc=4, Ntc=4, d=1,
    poss='random', post='random', posc='random', posct='random',
    sigmas=0, sigmab=0, sigmai=0,
)
for key in ['sensor', 'boundary', 'initial', 'collocation']:
    print(key.capitalize())
    print(dat[key].shape)
    print(dat[key])
```

```python
#Two dimensional in a grid
dat = jd.generate_PINNdata(
    u=lambda x, t: x + t,
    xlo=0, xup=1, tlo=0, tup=1,
    Ns=4, Nt=4, Nb=4, Nc=4, Ntc=4, d=2,
    poss='grid', post='grid', posc='grid', posct='grid',
    sigmas=0, sigmab=0, sigmai=0,
)
for key in ['sensor', 'boundary', 'initial', 'collocation']:
    print(key.capitalize())
    print(dat[key].shape)
    print(dat[key])
```