In [None]:
# -----------------------------------------------------------------------------
# Portions of this code are adapted from:
#   - https://github.com/TamaraGrossmann/FEM-vs-PINNs.git
#   - Grossmann, T. G., Komorowska, U. J., Latz, J., & Schönlieb, C.-B. (2023).
#     Can Physics-Informed Neural Networks beat the Finite Element Method?
#     arXiv:2302.04107.
# -----------------------------------------------------------------------------

In [None]:
# (for Google Colab)
!pip install pyDOE

In [None]:
import jax, flax, optax, time, pickle
import os
import jax.numpy as np
import numpy as onp
from functools import partial
from pyDOE import lhs
from typing import Sequence
import json
from tensorflow_probability.substrates import jax as tfp

In [None]:
# Must set CUDA_VISIBLE_DEVICES before importing JAX or any other library that initializes GPUs.
# Otherwise, the environment variable change might be ignored.
# "0, 1": first two GPUs / "": no GPU (CPU instead)
# NOTE(review): `jax` is already imported in the imports cell above, so this assignment can only
# take effect if JAX has not initialized its backend yet (it appears to do so lazily, on first
# use). Run this cell before any JAX computation, or move it above the imports — TODO confirm.

# Run on the first GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from jax.extend.backend import get_backend
# Print the platform name of the default backend, to confirm which device JAX picked up.
print(get_backend().platform)

# Hyperparameters

In [None]:
# Fully-connected architectures to benchmark: 1 to 5 hidden layers of width 20 or 60
# (the narrow variant first at each depth), plus one 5 x 120 network.
# The trailing 1 in every entry is the scalar output layer.
architecture_list = [[width] * depth + [1] for depth in range(1, 6) for width in (20, 60)]
architecture_list.append([120] * 5 + [1])

lr = 1e-3  # Adam learning rate
num_epochs = 20000  # number of Adam training steps

# NN Architecture

In [None]:
# Define NN architecture
class PDESolution(flax.linen.Module):
    """Fully-connected network u_theta(x, y) used as the PINN ansatz.

    Attributes:
        features: layer widths, e.g. [20, 20, 1]. Every entry except the last
            is a hidden Dense layer followed by tanh; the last entry is the
            width of the final, purely linear Dense layer.
    """
    features: Sequence[int]

    @flax.linen.compact  # submodules are declared inline inside __call__
    def __call__(self, x):
        hidden_widths = self.features[:-1]
        h = x
        for width in hidden_widths:
            h = flax.linen.tanh(flax.linen.Dense(width)(h))
        output_width = self.features[-1]
        # No activation on the output layer.
        return flax.linen.Dense(output_width)(h)

# Analytical Solution

In [None]:
@partial(jax.vmap, in_axes=0, out_axes=0)
@jax.jit
def analytic_sol(xs, ys):
    """Exact solution u(x, y) = x^2 (x-1)^2 y (y-1)^2 of the 2D Poisson benchmark.

    jit compiles the pointwise formula and vmap maps it over paired 1-D arrays
    of x- and y-coordinates, giving one value per (xs[i], ys[i]) pair.
    """
    x_factor = (xs**2) * (xs-1)**2
    return x_factor * ys * (ys-1)**2

# Loss Function

In [None]:
# Derivatives for the Neumann B.C.
@partial(jax.vmap, in_axes=(None, 0, 0), out_axes=(0, 0, 0))
@jax.jit
def neumann_derivatives(params, xs, ys):
    """Normal derivatives of the network on the three Neumann edges.

    Vmapped over xs / ys (params shared). Each derivative is a forward-mode
    jvp with a unit tangent along x or y:
        du/dx at (0, ys), du/dx at (1, ys), du/dy at (xs, 1).
    The network is evaluated through the module-level `model`.
    """
    def net(x, y):
        return model.apply(params, np.stack((x, y)))

    def directional_derivative(point, direction):
        return jax.jvp(net, point, direction)[1]

    along_x, along_y = (1., 0.), (0., 1.)
    du_dx_left = directional_derivative((0., ys), along_x)   # du/dx(0, ys)
    du_dx_right = directional_derivative((1., ys), along_x)  # du/dx(1, ys)
    du_dy_top = directional_derivative((xs, 1.), along_y)    # du/dy(xs, 1)
    return du_dx_left, du_dx_right, du_dy_top

# PDE residual
@partial(jax.vmap, in_axes=(None, 0, 0), out_axes=0)
@partial(jax.jit, static_argnums=(0,))  # the decorator closest to the function is applied first
def residual(u, x, y):
    """Pointwise residual  d2u/dx2 + d2u/dy2 - f(x, y)  of the Poisson equation.

    `u` is a callable u(x, y); it is static for jit and shared across the vmap,
    while x and y are mapped, so the result has one entry per (x[i], y[i]).
    f is the forcing term for which analytic_sol is an exact solution.
    """
    d2u_dx2 = jax.hessian(u, argnums=0)(x, y)
    d2u_dy2 = jax.hessian(u, argnums=1)(x, y)
    laplacian = d2u_dx2 + d2u_dy2
    forcing = 2*((x**4)*(3*y-2) + (x**3)*(4-6*y) + (x**2)*(6*(y**3)-12*(y**2)+9*y-2) - 6*x*((y-1)**2)*y + ((y-1)**2)*y)
    return laplacian - forcing

# Loss functionals: each is a mean-squared violation, evaluated with the module-level `model`.
@jax.jit
def pde_residual(params, points):
    """Mean squared PDE residual over interior points (column 0: x, column 1: y)."""
    def network(x, y):
        return model.apply(params, np.stack((x, y)))
    pointwise = residual(network, points[:, 0], points[:, 1])
    return np.mean(pointwise ** 2)

@jax.jit
def dirichlet_residual(params, points):
    """Mean squared network output on the edge y = 0 (target u = 0).

    Only the x-column of `points` is used; y is replaced by zeros.
    """
    xs = points[:, 0]
    edge_points = np.stack((xs, np.zeros_like(points[:, 1])), axis=1)
    return np.mean(model.apply(params, edge_points) ** 2)

@partial(jax.jit, static_argnums=0)
def neumann_residual(neumann_derivatives, params, points):
    """Mean of the summed squared normal derivatives on the Neumann edges.

    Targets du/dx(0, y) = du/dx(1, y) = du/dy(x, 1) = 0, with x taken from
    points[:, 0] and y from points[:, 1].
    """
    du_dx_0, du_dx_1, du_dy_1 = neumann_derivatives(params, points[:, 0], points[:, 1])
    return np.mean((du_dx_0**2) + (du_dx_1**2) + (du_dy_1**2))

# Training Loop

In [None]:
# Define Training Step
@partial(jax.jit, static_argnums = (1, 4))
def training_step(params, opt, opt_state, key, neumann_derivatives):
    """One full-batch optimizer step on the composite PINN loss.

    Args:
        params: model parameters
        opt: optax optimizer (static argument: hashed by jit)
        opt_state: optimizer state
        key: PRNG key; currently passed through unchanged (the sampling below
            uses pyDOE's NumPy-based `lhs`, not this key)
        neumann_derivatives: function returning the Neumann-edge derivatives
            (static argument: hashed by jit)

    Returns:
        (params, opt_state, key, loss_val) after the update; loss_val is the
        loss evaluated at the incoming params.
    """
    lb = np.array([0., 0.]) # lower bound
    ub = np.array([1., 1.]) # upper bound
    # `lhs` is plain NumPy, so it runs only while jit traces this function: the
    # sampled points become compile-time constants reused on every step until
    # the function is retraced. Samples are scaled from [0, 1] to [lb, ub].
    domain_points = lb + (ub - lb) * lhs(2, 2000)  # 2000 Latin-hypercube interior points
    boundary_points = lb + (ub - lb) * lhs(2, 250)  # 250 points for the boundary terms

    def loss_fn(p):
        # The losses must be computed from the differentiated argument `p`.
        # Computing them outside this function (from the closed-over `params`)
        # makes the gradient identically zero, so the optimizer never moves.
        return (pde_residual(p, domain_points)
                + dirichlet_residual(p, boundary_points)
                + neumann_residual(neumann_derivatives, p, boundary_points))

    loss_val, grad = jax.value_and_grad(loss_fn)(params)

    update, opt_state = opt.update(grad, opt_state, params) # update using "grad"
    params = optax.apply_updates(params, update) # apply updates to "params"
    return params, opt_state, key, loss_val

# Training loop
def train_loop(params, adam, opt_state, key, neumann_derivatives):
    """Run `num_epochs` calls of training_step and record every loss.

    Returns:
        (losses, params, opt_state, key, loss_val): the per-step loss history
        as Python floats, the final state, and the loss from the last step.
        num_epochs must be >= 1, otherwise loss_val is never bound.
    """
    losses = []
    record = losses.append
    for _epoch in range(num_epochs):
        params, opt_state, key, loss_val = training_step(
            params, adam, opt_state, key, neumann_derivatives
        )
        record(loss_val.item())  # .item() copies the scalar back to the host
    return losses, params, opt_state, key, loss_val

# Train PINN & Approximate Solution

In [None]:
# Containers for the results, keyed by architecture index
y_results, domain_pts, times_adam, times_lbfgs, times_total, times_eval, l2_rel, var, arch\
    = dict({}), dict({}), dict({}), dict({}), dict({}), dict({}), dict({}), dict({}), dict({})

num_runs = 10  # training runs per architecture; mean / variance are taken over these
save_dir = './2D_Poisson'


def concat_params(params):
    """Flatten a parameter pytree into one 1-D vector (the form tfp's L-BFGS expects).

    Returns:
        (vector, treedef, leaf shapes); the last two are what unconcat_params
        needs to rebuild the pytree.
    """
    leaves, tree = jax.tree_util.tree_flatten(params)
    shapes = [leaf.shape for leaf in leaves]
    return np.concatenate([leaf.reshape(-1) for leaf in leaves]), tree, shapes


def unconcat_params(params, tree, shapes):
    """Inverse of concat_params: split the 1-D vector back into the original pytree."""
    sizes = [int(onp.prod(shape)) for shape in shapes]
    split_at = onp.cumsum(sizes)[:-1].tolist()  # boundaries between consecutive leaves
    leaves = [chunk.reshape(shape) for chunk, shape in zip(np.split(params, split_at), shapes)]
    return jax.tree_util.tree_unflatten(tree, leaves)


# Pre-specified evaluation points (different from the training points) for measuring error.
# They are the same for every run, so load them and the ground truth once.
with open('2D_Poisson_eval_points.json', 'r') as f:
    eval_points = np.array(json.load(f))
eval_inputs = np.stack((eval_points[:, 0], eval_points[:, 1]), axis=1)
u_true = analytic_sol(eval_points[:, 0], eval_points[:, 1]).squeeze()  # ground truth

for count, architecture in enumerate(architecture_list):
    print('Architecture : %s' %architecture)
    times_adam_temp = []  # per-run measurements for this architecture
    times_lbfgs_temp = []
    times_total_temp = []
    times_eval_temp = []
    accuracy_temp = []
    for run in range(num_runs):
        # Initialize Model. `model` is the global read by the loss functions.
        model = PDESolution(architecture)
        # A different seed per run, so runs start from different initializations
        # and the variance recorded below reflects genuine run-to-run spread.
        key, _ = jax.random.split(jax.random.PRNGKey(run))
        batch_dim = 8  # dummy batch size, only used to build the parameter shapes
        feature_dim = 2  # input dimension: the (x, y) coordinates
        params = model.init(key, np.ones((batch_dim, feature_dim)))

        # Initialize Optimizer
        adam = optax.adam(learning_rate = lr)
        opt_state = adam.init(params)  # internal state of the Adam optimizer

        # Adam phase. NOTE(review): `adam` is rebuilt every run and is a static jit
        # argument of training_step, so adam_time likely includes a recompilation.
        start_time = time.time()
        losses, params, opt_state, key, loss_val = jax.block_until_ready(train_loop(params, adam, opt_state, key, neumann_derivatives))
        adam_time = time.time() - start_time
        times_adam_temp.append(adam_time)
        print('Adam Training Time : %f secs' %adam_time)

        # Fresh Latin-hypercube collocation points in [0, 1]^2 for the L-BFGS phase
        lb = onp.array([0., 0.])
        ub = onp.array([1., 1.])
        domain_points = lb + (ub - lb) * lhs(2, 2000)
        boundary_points = lb + (ub - lb) * lhs(2, 250)

        init_point, tree, shapes = concat_params(params)

        def lbfgs_loss(flat_params):
            """Composite PINN loss as a function of the flattened parameter vector."""
            p = unconcat_params(flat_params, tree, shapes)
            return (pde_residual(p, domain_points)
                    + dirichlet_residual(p, boundary_points)
                    + neumann_residual(neumann_derivatives, p, boundary_points))

        # L-BFGS Optimization
        print('Starting L-BFGS Optimization')
        start_time2 = time.time()
        lbfgs_results = tfp.optimizer.lbfgs_minimize(jax.value_and_grad(lbfgs_loss),
                                                     init_point, max_iterations = 50000,
                                                     num_correction_pairs = 50, # past updates kept for the inverse-Hessian approximation
                                                     f_relative_tolerance = 1.0*np.finfo(float).eps) # stopping criterion
        # JAX dispatches work asynchronously; wait for the result so the timing is real.
        jax.block_until_ready(lbfgs_results.position)
        lbfgs_time = time.time() - start_time2
        times_lbfgs_temp.append(lbfgs_time)
        times_total_temp.append(adam_time + lbfgs_time)

        # Comparison to Ground Truth
        tuned_params = unconcat_params(lbfgs_results.position, tree, shapes)

        start_time3 = time.time()
        u_approx = jax.block_until_ready(model.apply(tuned_params, eval_inputs).squeeze())
        eval_time = time.time() - start_time3
        times_eval_temp.append(eval_time)

        run_accuracy = onp.linalg.norm(u_approx - u_true) / onp.linalg.norm(u_true)  # relative L2 error
        accuracy_temp.append(run_accuracy)

    y_gt = u_true.tolist()  # for storing into dict
    y_results[count] = u_approx.tolist()  # prediction of the last run
    domain_pts[count] = eval_points.tolist()
    times_adam[count] = onp.mean(times_adam_temp)  # mean times across the runs
    times_lbfgs[count] = onp.mean(times_lbfgs_temp)
    times_total[count] = onp.mean(times_total_temp)
    times_eval[count] = onp.mean(times_eval_temp)
    l2_rel[count] = onp.mean(accuracy_temp).tolist()
    var[count] = onp.var(accuracy_temp).tolist()  # variance of the error across the runs
    arch[count] = architecture

    results = dict({'domain_pts': domain_pts,
                    'y_results': y_results,
                    'y_gt': y_gt})

    evaluation = dict({'arch': arch,
                    'times_adam': times_adam,
                    'times_lbfgs': times_lbfgs,
                    'times_total': times_total,
                    'times_eval': times_eval,
                    'l2_rel': l2_rel,
                    'var': var})

    # Save Results & Evaluation after every architecture (acts as a checkpoint)
    os.makedirs(save_dir, exist_ok = True)

    with open(os.path.join(save_dir, 'PINNs_results.json'), 'w') as f:
        json.dump(results, f)

    with open(os.path.join(save_dir, 'PINNs_evaluation.json'), 'w') as f:
        json.dump(evaluation, f)

    print(json.dumps(evaluation, indent = 4))

# Plots

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Load the saved results so this cell also works on its own (e.g. after a kernel restart).
save_dir = './2D_Poisson'  # must match the directory used when saving
with open(os.path.join(save_dir, 'PINNs_results.json'), 'r') as f:
    data_results = json.load(f)

with open(os.path.join(save_dir, 'PINNs_evaluation.json'), 'r') as f:
    data_eval = json.load(f)

# json.dump stored the integer architecture indices as string keys ("0", "1", ...).
arch_indices = sorted(int(k) for k in data_eval['arch'].keys())
arch_indices_str = [str(i) for i in arch_indices]

# Architecture whose (last-run) prediction is compared with the exact solution;
# fall back to the last saved architecture if index 8 is not available.
arch_idx = '8' if '8' in data_results['y_results'] else arch_indices_str[-1]
domain_points = onp.array(data_results['domain_pts'][arch_idx])  # evaluation points, (N, 2)
u_approx = onp.array(data_results['y_results'][arch_idx])   # PINN prediction, (N,)
u_exact = onp.array(data_results['y_gt'])                   # analytic solution, (N,)

X = domain_points[:, 0]
Y = domain_points[:, 1]

# Approximate solution (left) vs. exact solution (right)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
panels = [(ax1, u_approx, f"PINN Approx. Solution (Arch index={arch_idx})"),
          (ax2, u_exact, "Exact Solution")]
for ax, values, title in panels:
    contour = ax.tricontourf(X, Y, values, levels=50, cmap='viridis')
    ax.set_title(title)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    fig.colorbar(contour, ax=ax, shrink=0.5)

fig.tight_layout()
fig.savefig("2D_Poisson_PINNs_approx_vs_exact.png", dpi=300, bbox_inches='tight')
plt.show()

# Per-architecture metrics, in index order
train_times = [data_eval['times_total'][i] for i in arch_indices_str]  # or 'times_adam', 'times_lbfgs'
eval_times = [data_eval['times_eval'][i] for i in arch_indices_str]
rel_errors = [data_eval['l2_rel'][i] for i in arch_indices_str]
arch_labels = [str(data_eval['arch'][i]) for i in arch_indices_str]  # e.g. "[20, 20, 1]"


def _plot_error_vs_time(errors, times, labels, color, ylabel, title, filename):
    """Scatter relative L2 error (x) against a timing metric (y), one labelled point per architecture."""
    fig, ax = plt.subplots(figsize=(7, 5))
    ax.scatter(errors, times, color=color)
    for err, t, label in zip(errors, times, labels):
        ax.text(err, t, label, fontsize=8, ha='left', va='bottom')
    ax.set_xlabel(r"Relative $L_2$ Error")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(filename, dpi=300, bbox_inches='tight')
    plt.show()


# Relative error vs. total training time
_plot_error_vs_time(rel_errors, train_times, arch_labels, 'blue',
                    "Total Training Time (seconds)",
                    "Relative Error vs. Total Training Time (by architecture)",
                    "2D_Poisson_PINNs_L2rel_vs_Train_time.png")

# Relative error vs. evaluation time
_plot_error_vs_time(rel_errors, eval_times, arch_labels, 'red',
                    "Evaluation Time (seconds)",
                    "Relative Error vs. Evaluation Time (by architecture)",
                    "2D_Poisson_PINNs_L2rel_vs_Eval_time.png")