<a href="https://colab.research.google.com/github/ikbenali/ReproducibilityProject_DL/blob/main/Poisson1DNTK.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import sympy as sm
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

import matplotlib.pyplot as plt

# Floating point precision shared by every tensor/layer created in this notebook
dtype = torch.float32

# Pick the compute backend in priority order: Apple MPS, then CUDA, then CPU.
_backend = (
    "mps" if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)
device = torch.device(_backend)

# Maps the function names sympy emits onto their torch implementations, so that
# lambdified expressions operate on tensors.
sympyTorchmodules = dict(sin=torch.sin, cos=torch.cos)


### 1D Poisson PDE Class

\begin{align}
u_{xx} &= f(x), \hspace{1.3cm} x \in \Omega \\
u(x) &= g(x), \hspace{1.3cm} x \in \partial \Omega
\end{align}

where the source term $f$ and the boundary data $g$ are given by:

\begin{align}
    f(x) &= -a^{2}\pi^{2}\sin(a\pi x), \hspace{0.2cm} x \in [0,1] \\
    g(x) &= 0, \hspace{2.3cm} x = 0,1
\end{align}

In [None]:
class Poisson1D:
    """Symbolic description of the 1D Poisson problem u_xx = f(x), u = g(x) on the boundary.

    The equations are built with sympy and then lambdified into callables
    (`pde_residual`, `bc_residual`) that evaluate the residuals on torch tensors.

    Attributes set by the constructor:
        a            -- the coefficient handed to the constructor (None if omitted)
        PDE_eqn      -- sympy equation  u_xx = f
        BC_eqn       -- sympy equation  u = g
        x, U, f, g   -- sympy symbols/functions; U = [u, u_x, u_xx]
        pde_residual -- callable (x, U, f) -> u_xx - f
        bc_residual  -- callable (x, U, g) -> u - g
    """

    def __init__(self, a=None):
        # Store the coefficient exactly as given. Previously any non-None value
        # was replaced by the constant 1 and the attribute was missing for a=None.
        self.a = a

        self.setup_equations()
        self.setup_residuals()

    def setup_equations(self, f_eqn=None, g_eqn=None):
        """Create the symbolic PDE and boundary-condition equations.

        `f_eqn` and `g_eqn` are accepted for interface compatibility but are
        not used by this implementation.
        """
        # Independent variable (domain coordinate)
        x   = sm.symbols('x')

        # Unknown solution and its first two derivatives
        u   = sm.symbols('u', cls=sm.Function)(x)
        ux  = u.diff(x)
        uxx = ux.diff(x)

        # Forcing/source function and boundary-condition function
        f   = sm.symbols('f', cls=sm.Function)(x)
        g   = sm.symbols('g', cls=sm.Function)(x)

        # PDE:  u_xx = f
        self.PDE_eqn = sm.Eq(uxx, f)

        # Boundary condition:  u = g
        self.BC_eqn  = sm.Eq(u, g)

        # Keep the symbols for reuse in setup_residuals
        self.x = x
        self.U = [u, ux, uxx]
        self.f = f
        self.g = g

    def setup_residuals(self):
        """Lambdify the residuals (lhs - rhs) of the PDE and boundary equations.

        Both callables take three arguments: the coordinate, a container that
        unpacks into (u, u_x, u_xx), and the right-hand-side values.
        """
        pde_residual = self.PDE_eqn.lhs - self.PDE_eqn.rhs
        bc_residual  = self.BC_eqn.lhs  - self.BC_eqn.rhs

        self.pde_residual = sm.lambdify([self.x, self.U, self.f], pde_residual, modules=sympyTorchmodules)
        self.bc_residual  = sm.lambdify([self.x, self.U, self.g], bc_residual,  modules=sympyTorchmodules)

    def compute_gradient(self, u, x, t=None):
        """Differentiate the prediction `u` twice with respect to `x` via autograd.

        Parameters:
            u -- network output; must be part of an autograd graph that contains `x`
            x -- input coordinates the derivatives are taken with respect to
            t -- unused, accepted for interface compatibility

        Returns:
            The transpose of hstack([u, u_x, u_xx]); for column inputs of shape
            (N, 1) this is a (3, N) tensor whose rows are u, u_x and u_xx
            (assumes column-shaped inputs -- confirm against callers).
        """
        # create_graph=True keeps the derivatives differentiable, which is needed
        # both for the second derivative and for back-propagating a loss through them.
        ux   = torch.autograd.grad(u,  x, grad_outputs=torch.ones_like(u),  retain_graph=True, create_graph=True)[0]
        uxx  = torch.autograd.grad(ux, x, grad_outputs=torch.ones_like(ux), create_graph=True)[0]

        return torch.hstack([u, ux, uxx]).T



##### Define exact, source and boundary condition functions

In [None]:
def f_u_exact(a,x):
    """
    Analytical solution u(x) = sin(a*pi*x) of the Poisson problem.

    a : frequency coefficient (scalar)
    x : tensor of evaluation points

    Returns a tensor shaped like x.
    """
    return torch.sin(a*torch.pi*x)

def f_x(a, x):
    """
    Source/forcing term f(x) = -a^2 * pi^2 * sin(a*pi*x).

    a : frequency coefficient (scalar)
    x : tensor of evaluation points

    Returns a tensor shaped like x.
    """
    # Scalar prefactor is evaluated first, exactly as in the closed-form expression
    amplitude = -(a**2)*(torch.pi**2)
    return amplitude*torch.sin(a*torch.pi*x)

def g_x(x, xb):
    """
    Boundary condition values g(x) for the sampled boundary points.

    x  : tensor of boundary sample coordinates
    xb : pair [left, right] with the coordinates of the two domain boundaries

    Returns a tensor with the same size as x holding the prescribed boundary
    value (0 at both boundaries, and 0 everywhere else as well).
    """
    # Allocate on the same device as x: the index tensors computed below live on
    # x's device, and indexing a CPU tensor with accelerator-resident indices fails.
    ub = torch.zeros(x.size(), dtype=dtype, device=x.device)

    # Row indices of the samples lying exactly on the left / right boundary
    xb1_idx = torch.where(x == xb[0])[0]
    xb2_idx = torch.where(x == xb[1])[0]

    # Homogeneous Dirichlet data on both ends
    ub[xb1_idx] = 0
    ub[xb2_idx] = 0

    return ub

### PINN Class

In [None]:
class PINN(nn.Module):
    """Physics-informed network with one hidden tanh layer.

    The PDE object supplies the residual callables (`pde_residual`,
    `bc_residual`, optionally `ic_residual`) and a `compute_gradient`
    function that returns the prediction together with its derivatives.

    Besides the usual forward pass the class offers:
        backward(...) -- assembles the physics loss into `self.loss`
                         (it does NOT call autograd; the caller has to run
                         `self.loss.backward()` itself)
        NTK(...)      -- computes the neural tangent kernel blocks and their
                         eigenvalues for two sets of points
    """

    def __init__(self, input_size, output_size, neurons, PDE):
        super(PINN, self).__init__()

        # initialize values for nn
        self.xin        = input_size
        self.xout       = output_size
        self.neurons    = neurons

        # Define layers of network
        self.layer1     = nn.Linear(input_size, neurons, dtype=dtype)
        self.layer2     = nn.Linear(neurons, output_size, dtype=dtype)

        # Plain list used for iteration in forward(); the layers are registered
        # as sub-modules through the layer1/layer2 attributes above.
        self.layers = [self.layer1, self.layer2]

        self.activation = nn.Tanh()

        # Copy whichever residual callables the PDE object provides
        if hasattr(PDE, 'pde_residual'):
            self.pde_residual = PDE.pde_residual
        if hasattr(PDE, 'bc_residual'):
            self.bc_residual = PDE.bc_residual
        if hasattr(PDE, 'ic_residual'):
            self.ic_residual = PDE.ic_residual

        self.apply(self._init_weights)

        # copy gradient computation
        self.compute_pde_gradient = PDE.compute_gradient

    def _init_weights(self, module):
        """Glorot (Xavier) normal weights and standard-normal biases for linear layers."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight.data)
            if module.bias is not None:
                nn.init.normal_(module.bias.data)

    def forward(self, x):
        """Apply every layer but the last followed by tanh, then the last layer linearly."""
        for layer in self.layers[:-1]:
            x = layer(x)
            x = self.activation(x)

        x = self.layers[-1](x)

        return x

    def backward(self, X, U, f=None, g=None, h=None):
        """Assemble the physics-informed loss and store it in `self.loss`.

        Parameters:
            X -- sample points; if it has two columns, column 0 holds the
                 interior points and column 1 the boundary points, otherwise
                 the same points are used for both
            U -- output of `compute_pde_gradient`; a 3-D tensor is read as
                 U[0] = interior values, U[1] = boundary values
            f -- right-hand side of the PDE (enables the PDE loss term)
            g -- boundary values (enables the boundary loss term)
            h -- initial values (enables the initial-condition loss term)

        Each enabled term is the mean squared residual; the terms are summed.
        The individual terms are stored as `pde_loss`, `bc_loss`, `ic_loss`.
        """
        if X.shape[1] == 2:
            xr = X[:, 0].view(-1,1)
            xb = X[:, 1].view(-1,1)
        else:
            xr = xb = X

        if len(U.shape) == 3:
            U_x = U[0]
            U_b = U[1]
        else:
            U_x = U
            U_b = U

        loss = []

        if hasattr(self, 'pde_residual') and f is not None:
            residual        = self.pde_residual(xr, U_x, f).T
            self.pde_loss   = torch.mean(residual**2)

            loss.append(self.pde_loss)

        if hasattr(self, 'bc_residual') and g is not None:
            residual        = self.bc_residual(xb, U_b, g).T
            self.bc_loss    = torch.mean(residual**2)

            loss.append(self.bc_loss)

        if hasattr(self, 'ic_residual') and h is not None:
            residual        = self.ic_residual(X, U, h).T
            self.ic_loss    = torch.mean(residual**2)
            loss.append(self.ic_loss)

        loss = torch.stack(loss, dim=0).sum()

        self.loss = loss

    @staticmethod
    def _jacobian(outputs, params):
        """Return the (N, P) Jacobian of the N entries of `outputs` w.r.t. all parameters.

        Row i holds the gradient of outputs[i] with respect to every parameter,
        flattened and concatenated in the order of `params`.
        """
        outputs = outputs.reshape(-1)

        rows = []
        for out_i in outputs:
            # allow_unused: some outputs do not depend on every parameter
            # (e.g. u_xx is independent of the output bias); those get zeros.
            grads = torch.autograd.grad(out_i, params, retain_graph=True, allow_unused=True)
            rows.append(torch.cat([
                (grad if grad is not None else torch.zeros_like(theta)).reshape(-1)
                for grad, theta in zip(grads, params)
            ]))

        if not rows:
            return outputs.new_zeros((0, sum(theta.numel() for theta in params)))

        return torch.stack(rows, dim=0)

    def _residual_jacobian(self, residual_fn, x, params):
        """Evaluate `residual_fn` at `x` with a zero right-hand side and differentiate it w.r.t. `params`."""
        u_hat    = self(x)
        U        = self.compute_pde_gradient(u_hat, x)

        # The right-hand side does not depend on the parameters, so zero is used.
        # It is sized from the points actually being evaluated.
        rhs      = torch.zeros_like(x).T
        residual = residual_fn(x, U, rhs).T

        return self._jacobian(residual, params)

    def NTK(self, X1, X2):
        """Compute the neural tangent kernel between the point sets X1 and X2.

        Parameters:
            X1, X2 -- sample points that require grad; with two columns,
                      column 0 is used for the PDE residual and column 1 for
                      the boundary residual, otherwise the same points serve both.
                      Both sets need the same number of rows because the
                      eigenvalues of the (square) kernel matrices are computed.

        Results are stored on the instance:
            K_rr, lambda_rr -- kernel of the PDE residual and its eigenvalues
            K_uu, lambda_uu -- kernel of the boundary residual and its eigenvalues
            K,    lambda_K  -- full kernel [[K_uu, K_ur], [K_ru, K_rr]] and its
                               eigenvalues (only if both residuals are available)
        """
        # Trainable parameters of the network
        params = [theta for theta in self.parameters() if theta.requires_grad]

        if X1.shape[1] == 2 and X2.shape[1] == 2:
            xr1 = X1[:, 0].view(-1,1);   xb1 = X1[:, 1].view(-1,1)
            xr2 = X2[:, 0].view(-1,1);   xb2 = X2[:, 1].view(-1,1)
        else:
            xr1 = xb1 = X1
            xr2 = xb2 = X2

        J_r1 = J_r2 = None
        J_u1 = J_u2 = None

        if hasattr(self, 'pde_residual'):
            # Jacobians of the PDE operator applied to the network, at X and X'
            J_r1 = self._residual_jacobian(self.pde_residual, xr1, params)
            J_r2 = self._residual_jacobian(self.pde_residual, xr2, params)

            self.K_rr       = J_r1 @ J_r2.T
            self.lambda_rr  = torch.linalg.eigvals(self.K_rr)

        if hasattr(self, 'bc_residual'):
            # Jacobians of the network output at the boundary points, at X and X'
            J_u1 = self._residual_jacobian(self.bc_residual, xb1, params)
            J_u2 = self._residual_jacobian(self.bc_residual, xb2, params)

            self.K_uu       = J_u1 @ J_u2.T
            self.lambda_uu  = torch.linalg.eigvals(self.K_uu)

        if J_r1 is not None and J_u1 is not None:
            # Full kernel: stack boundary rows on top of residual rows
            K1 = torch.vstack((J_u1,   J_r1))
            K2 = torch.hstack((J_u2.T, J_r2.T))

            self.K = K1 @ K2

            self.lambda_K = torch.linalg.eigvals(self.K)



## Run model

In [None]:
from torch.utils.data import DataLoader, RandomSampler

#### Setup PDE Equation

In [None]:
## Setup PDE Equation
# Coefficient `a`; also passed to f_x / f_u_exact in later cells
a   = 2
PDE = Poisson1D(a)

# Define PDE domain
# Left and right end of the interval; X_bc lists the two boundary coordinates
X_0,X_N = 0.,1.
X_bc  = [X_0, X_N]

# Number of sample points (also reused for the plotting grid later on)
N  = 500
# NOTE(review): dx is not referenced anywhere else in the visible notebook
dx = (X_N - X_0) / N

# Interior points: N equally spaced values in [X_0, X_N], as a column (N, 1)
Xr = torch.linspace(X_0, X_N, N, dtype=dtype, device=device, requires_grad=True).view(-1,1)
# Boundary points: each entry is randomly 0 or 1 (stored as float), shape (N, 1)
Xb = torch.randint(0, 2, (N,1),  dtype=dtype, device=device, requires_grad=True)

# Column 0 = interior point, column 1 = boundary point
X  = torch.hstack((Xr, Xb))

#### Setup PINN 

In [None]:
## Setup PINN parameters

# Nr is used as the batch size of the data loader below.
# NOTE(review): Nb is not referenced anywhere else in the visible notebook.
Nr      = 100
Nb      = 100
# Rows of X are drawn at random with replacement and served in batches of Nr
rand_sampler = RandomSampler(X, replacement=True)
XTrain       = DataLoader(X, Nr ,sampler=rand_sampler)

# NOTE(review): `size` is not referenced anywhere else in the visible notebook
size          = len(XTrain.dataset)
learning_rate = 1e-3
epochs        = int(10e3)

# net parameters
input_size  = 1
output_size = 1
neurons     = 100
net         = PINN(input_size, output_size, neurons, PDE); net.to(device)

# NOTE(review): loss_fn is not used by the visible training loop, which takes
# its loss from net.backward(...) instead
loss_fn   = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), learning_rate)
# optimizer = optim.Adam(net.parameters(), learning_rate)

#### Run PINN

In [None]:
%%time

### TRAIN LOOP
# Mean loss per epoch
train_losses = []

# NTK computation
compute_NTK          = False
compute_NTK_interval = 1
store_NTK    = True
eig_K        = []
eig_K_uu     = []
eig_K_rr     = []

# Auto Mixed Precision settings
use_amp = False
scaler  = torch.cuda.amp.GradScaler(enabled=use_amp)

for epoch in range(epochs+1):
    net.train()

    epoch_loss   = 0.0

    for i, x in enumerate(XTrain):

        # reset gradients
        optimizer.zero_grad()

        # column 0: interior points, column 1: boundary points
        xr = x[:,0].view(-1,1).to(device); xb = x[:,1].view(-1,1).to(device)

        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):

            ### INTERIOR DOMAIN
            # make prediction w.r.t. interior points
            u_hat_x   = net(xr)

            # determine gradients w.r.t interior points
            U_x       =  net.compute_pde_gradient(u_hat_x, xr)

            ### BOUNDARY DOMAIN
            u_hat_xb    = net(xb)

            # determine gradients w.r.t boundary points
            U_xb       =  net.compute_pde_gradient(u_hat_xb, xb)

            # Compute forcing/source function
            fx = f_x(a, xr).T.to(device)

            # compute boundary condition
            gx = g_x(xb, X_bc).T.to(device)

            # Stack interior and boundary values: U[0] interior, U[1] boundary
            U = torch.stack((U_x, U_xb), dim=0)

            ## Assemble the loss (net.backward only stores it in net.loss,
            ## it does not run autograd)
            net.backward(x, U, fx, gx)
            epoch_loss += net.loss.item()
            # remember the last batch of the epoch for the NTK computation
            if i == len(XTrain) - 1:
                x_prime  = x.clone()

        # Back-propagate and do optimisation step
        if use_amp:
            scaler.scale(net.loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # Without this call the parameter gradients are never populated and
            # optimizer.step() leaves the network untouched.
            net.loss.backward()
            optimizer.step()

    ### END Batch loop

    # Compute NTK
    if epoch > 0:
        if (epoch % compute_NTK_interval == 0 or epoch == epochs - 1) and compute_NTK:

            net.eval()

            # fresh leaf copies of the last batch so that the NTK graph is
            # independent of the training graph
            x_          = x.detach().clone().requires_grad_()
            x_prime_    = x_prime.detach().clone().requires_grad_()

            net.NTK(x_, x_prime_)

            if store_NTK:
                eig_K.append(net.lambda_K)
                eig_K_uu.append(net.lambda_uu)
                eig_K_rr.append(net.lambda_rr)

    train_losses.append(epoch_loss / len(XTrain))

    if epoch % 100 == 0 or epoch == epochs - 1:
        print(f"Epoch: {epoch:4d}     loss: {train_losses[-1]:5f}")
### END training loop

if compute_NTK:
    # reformat eigenvalue of NTK matrices: last axis indexes the stored epochs
    eig_K       = torch.stack(eig_K, dim=-1)
    eig_K_uu    = torch.stack(eig_K_uu, dim=-1)
    eig_K_rr    = torch.stack(eig_K_rr, dim=-1)


#### Results

In [None]:
# Evaluation grid: N equally spaced points over the domain, as a column on the compute device
xplot = torch.linspace(X_0, X_N, N, requires_grad=True, dtype=dtype)
xplot = xplot.view(-1,1).to(device)

# Exact solution and network prediction on the grid
u_exact = f_u_exact(a, xplot)
u_pred  = net(xplot)

# Move everything to host memory as numpy arrays for plotting
xplot, u_exact, u_pred = (
    tensor.cpu().detach().numpy() for tensor in (xplot, u_exact, u_pred)
)

### PLOT Prediction accuracy and training loss

fig, axs = plt.subplots(1,2, figsize=(23,6))
ax_solution, ax_loss = axs

# left panel: exact solution against prediction
for curve, curve_label in ((u_exact, '$u_{exact}$'), (u_pred, '$u_{pred}$')):
    ax_solution.plot(xplot, curve, label=curve_label)
ax_solution.legend()
ax_solution.set_ylabel(r'$u$')
ax_solution.set_xlabel(r'$x$')

# right panel: loss history on a logarithmic axis
ax_loss.semilogy(train_losses)
ax_loss.set_ylabel(r'loss per epoch')
ax_loss.set_xlabel(r'$Epoch$')


if compute_NTK:
    # Real parts of the stored eigenvalues as numpy arrays
    eig_K_plot, eig_K_uu_plot, eig_K_rr_plot = (
        np.real(eigs.detach().cpu().numpy()) for eigs in (eig_K, eig_K_uu, eig_K_rr)
    )

    ### PLOT Eigenvalue of NTK matrices
    fig, axs = plt.subplots(1,3, figsize=(23,6))

    panels = (
        (eig_K_plot,    r'$\lambda_{K}$',  'Eigenvalue of K'),
        (eig_K_uu_plot, r'$\lambda_{uu}$', 'Eigenvalue of {}'.format(r"$K_{uu}$")),
        (eig_K_rr_plot, r'$\lambda_{rr}$', 'Eigenvalue of {}'.format(r"$K_{rr}$")),
    )

    for ax, (spectrum, curve_label, title) in zip(axs, panels):
        # last column = eigenvalues from the most recently stored NTK
        ax.semilogx(spectrum[:,-1], label=curve_label)
        ax.set_title(title)
        ax.legend()
        ax.ticklabel_format(axis='y', style='sci', scilimits=(0,0))
        ax.set_ylabel(r'$\lambda$')
        ax.set_xlabel(r'$Index$')

plt.show()
