In [None]:
# -----------------------------------------------------------------------------
# Portions of this code are adapted from:
#   - https://github.com/TamaraGrossmann/FEM-vs-PINNs.git
#   - Grossmann, T. G., Komorowska, U. J., Latz, J., & Schönlieb, C.-B. (2023).
#     Can Physics-Informed Neural Networks beat the Finite Element Method?
#     arXiv:2302.04107.
# -----------------------------------------------------------------------------

In [None]:
# (for Google Colab)
!pip install pyDOE

In [None]:
import jax, flax, optax, time, pickle
import os
import jax.numpy as np
import numpy as onp
from functools import partial
from pyDOE import lhs
from typing import Sequence
import json
from tensorflow_probability.substrates import jax as tfp
import matplotlib.pyplot as plt

In [None]:
# Select which GPU(s) are visible to JAX. CUDA reads CUDA_VISIBLE_DEVICES when the GPU backend is
# initialised, so it must be set before the first JAX computation / backend query in this kernel.
# NOTE(review): `jax` is already imported in the cell above; setting the variable here presumably
# still works only because JAX initialises its backends lazily (e.g. at get_backend() below) — confirm.
# "0, 1": first two GPUs / "": no GPU (CPU instead)

# Run on the first GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from jax.extend.backend import get_backend
print(get_backend().platform)  # shows which backend JAX actually selected

# Hyperparameters

In [None]:
# Network architectures to benchmark: three tanh hidden layers of equal width, then a scalar output.
# (Deeper variants with up to 7 hidden layers were previously listed here, commented out.)
architecture_list = [[width] * 3 + [1] for width in (20, 100, 500, 2500)]

lr = 1e-3  # Adam learning rate
num_epochs = 50_000  # Adam epochs on the full (PDE + initial + boundary) loss

# NN Architecture

In [None]:
# Define NN architecture
class PDESolution(flax.linen.Module):
    """Fully connected network used as the PINN ansatz.

    ``features`` gives the width of every Dense layer, e.g. [20, 20, 20, 1]: all entries
    except the last are hidden layers followed by tanh, and the last entry is the width of
    the final (linear) Dense layer.
    """
    features: Sequence[int]  # dataclass field, e.g. [10, 20, 1]

    @flax.linen.compact  # submodules are declared inline inside __call__
    def __call__(self, x):
        hidden_widths, output_width = self.features[:-1], self.features[-1]
        h = x
        for width in hidden_widths:
            h = flax.linen.Dense(width)(h)
            h = flax.linen.tanh(h)
        # Linear read-out layer (no activation)
        return flax.linen.Dense(output_width)(h)

# Loss Function

In [None]:
# PDE residual of the forced heat equation  u_t - (u_xx + u_yy) - f = 0
@partial(jax.vmap, in_axes = (None, 0, 0, 0), out_axes = 0)  # map over points, not over u
@partial(jax.jit, static_argnums = (0,))  # inner decorator: jit with the callable u as a static argument
def residual(u, t, x, y):
    """Pointwise residual u_t - (u_xx + u_yy) - f at (t, x, y).

    Args:
        u: callable u(t, x, y) (static under jit; not mapped by vmap).
        t, x, y: coordinates; the vmap above maps over their leading axis.

    Returns:
        The residual with the same shape as u's output; f = 10 sin(pi x) sin(pi y) cos(2 pi t).
    """
    # Time derivative as a forward-mode directional derivative along the t axis
    _, du_dt = jax.jvp(u, (t, x, y), (1., 0., 0.))
    # Second derivatives in x (argument 1) and y (argument 2)
    d2u_dx2 = jax.hessian(u, argnums = 1)(t, x, y)
    d2u_dy2 = jax.hessian(u, argnums = 2)(t, x, y)
    laplacian = d2u_dx2 + d2u_dy2
    source = 10.0 * np.sin(np.pi*x) * np.sin(np.pi*y) * np.cos(2.0*np.pi*t)  # time-varying forcing
    return du_dt - laplacian - source

# Initial condition: Gaussian bump centred at (0.5, 0.5)
@partial(jax.vmap, in_axes=(0, 0)) # vectorized over "xs" and "ys"
def u_init(xs, ys):
    """Return u(0, x, y) = exp(-50 * ((x - 0.5)^2 + (y - 0.5)^2)) as a length-1 array per point."""
    squared_radius = (xs-0.5)**2 + (ys-0.5)**2
    return np.array([np.exp(-50.0*squared_radius)])

# Loss functionals
@jax.jit
def pde_residual(params, points):
    """Mean squared PDE residual of the global ``model`` over (t, x, y) collocation points.

    Args:
        params: model parameters.
        points: array whose columns 0, 1, 2 are t, x, y.
    """
    def u_net(t, x, y):
        # Scalar-coordinate wrapper around the network, as expected by `residual`
        return model.apply(params, np.stack((t, x, y)))

    t_col, x_col, y_col = points[:, 0], points[:, 1], points[:, 2]
    return np.mean(residual(u_net, t_col, x_col, y_col) ** 2)  # Mean Squared Error

@partial(jax.jit, static_argnums=0)
def init_residual(u_init, params, points):
    """Mean squared error between the network at t = 0 and the initial condition.

    Args:
        u_init: initial-condition function of (xs, ys) (static under jit).
        params: model parameters.
        points: either an (N, 2) array of (x, y) points or an (N, 3) array of (t, x, y)
            points; only the last two columns (x, y) are used and t is set to 0.

    Returns:
        Scalar mean squared error.
    """
    # Callers in this notebook pass 2-column (x, y) points built with lhs(2, ...). Reading
    # columns 1 and 2 went past the last column: depending on JAX's handling of out-of-bounds
    # reads that either errors or (with clamping) reuses y as x, i.e. enforces the initial
    # condition only on the diagonal x == y. The last two columns are (x, y) in both layouts.
    xs, ys = points[:, -2], points[:, -1]
    t0 = np.zeros_like(xs)
    u_pred = model.apply(params, np.stack((t0, xs, ys), axis=1))
    u_target = u_init(xs, ys)
    return np.mean((u_pred - u_target) ** 2)

@partial(jax.jit, static_argnames=('faces',))
def dirichlet_residual_x(params, points, faces=(0., 1.)):
    """Mean squared network output on the x-faces of the domain (homogeneous Dirichlet u = 0).

    Args:
        params: model parameters.
        points: (t, x, y) sample points; the x column is replaced by each face value.
        faces: x-coordinates of the boundary faces to enforce (static). Defaults to both faces
            of the unit square; pass (0.,) to reproduce the former x = 0-only behaviour.

    Returns:
        Scalar mean of u^2 over all faces (each face gets the same number of points).
    """
    # NOTE(review): assumes u = 0 on the whole boundary of [0, 1]^2 — consistent with the source
    # term sin(pi x) sin(pi y) vanishing there; confirm against the FEM reference setup.
    t_col, y_col = points[:, 0], points[:, 2]
    squared = [
        model.apply(params, np.stack((t_col, np.full_like(points[:, 1], x_face), y_col), axis=1)) ** 2
        for x_face in faces
    ]
    return np.mean(np.concatenate(squared))

@partial(jax.jit, static_argnames=('faces',))
def dirichlet_residual_y(params, points, faces=(0., 1.)):
    """Mean squared network output on the y-faces of the domain (homogeneous Dirichlet u = 0).

    Args:
        params: model parameters.
        points: (t, x, y) sample points; the y column is replaced by each face value.
        faces: y-coordinates of the boundary faces to enforce (static). Defaults to both faces
            of the unit square; pass (0.,) to reproduce the former y = 0-only behaviour.

    Returns:
        Scalar mean of u^2 over all faces (each face gets the same number of points).
    """
    # NOTE(review): assumes u = 0 on the whole boundary of [0, 1]^2 — confirm against the FEM setup.
    t_col, x_col = points[:, 0], points[:, 1]
    squared = [
        model.apply(params, np.stack((t_col, x_col, np.full_like(points[:, 2], y_face)), axis=1)) ** 2
        for y_face in faces
    ]
    return np.mean(np.concatenate(squared))

# Training Loop

In [None]:
# Define Training Step
@partial(jax.jit, static_argnums = (1,))
def training_step_ini(params, opt, opt_state, key):
    """One optimizer step on the initial-condition loss only (pre-training phase).

    Rationale (from the original author): fitting the initial condition first gives a starting
    point close to the correct initial profile, which should need fewer epochs and stabilise
    the subsequent training on the full loss.

    Note: ``lhs`` (pyDOE) draws its samples with NumPy while this function is being traced,
    so under ``jax.jit`` the points are constants of the compiled function; they are not
    resampled on every step (``key`` is not used for sampling).

    Args:
        params: model parameters
        opt: optimizer (static)
        opt_state: optimizer state
        key: random key, returned unchanged

    Returns:
        (params, opt_state, key, loss_val) after one update.
    """
    lb = np.array([0., 0., 0.]) # lower bound of (t, x, y)
    ub = np.array([1., 1., 1.]) # upper bound of (t, x, y)
    # 100 Latin-hypercube (x, y) points in [0, 1]^2 at which u(0, x, y) is matched
    init_points = lb[1:] + (ub[1:] - lb[1:]) * lhs(2, 100)

    # The loss must be computed *inside* the differentiated function; a loss precomputed
    # outside is a constant w.r.t. the lambda's argument and yields an all-zero gradient.
    loss_val, grad = jax.value_and_grad(
        lambda p: init_residual(u_init, p, init_points))(params)

    update, opt_state = opt.update(grad, opt_state, params) # update using "grad"
    params = optax.apply_updates(params, update) # apply updates to "params"
    return params, opt_state, key, loss_val

@partial(jax.jit, static_argnums = (1,))
def training_step(params, opt, opt_state, key):
    """One optimizer step on the full loss: PDE residual + initial condition + Dirichlet boundaries.

    Note: ``lhs`` (pyDOE) samples with NumPy at trace time, so under ``jax.jit`` the collocation
    points are fixed for the lifetime of the compiled function (``key`` is not used for sampling).

    Args:
        params: model parameters
        opt: optimizer (static)
        opt_state: optimizer state
        key: random key, returned unchanged

    Returns:
        (params, opt_state, key, loss_val) after one update.
    """
    lb = np.array([0., 0., 0.]) # lower bound of (t, x, y)
    ub = np.array([1., 1., 1.]) # upper bound of (t, x, y)
    domain_points = lb + (ub - lb) * lhs(3, 5000) # 5000 Latin-hypercube (t, x, y) points in [0, 1]^3
    boundary_points = lb + (ub - lb) * lhs(3, 100) # 100 points; the boundary losses overwrite x or y
    init_points = lb[1:] + (ub[1:] - lb[1:]) * lhs(2, 100) # 100 (x, y) points in [0, 1]^2 for t = 0

    def total_loss(p):
        # Every term must depend on `p`; losses evaluated outside this function would be
        # constants here and produce an all-zero gradient.
        return (pde_residual(p, domain_points)
                + init_residual(u_init, p, init_points)
                + dirichlet_residual_x(p, boundary_points)
                + dirichlet_residual_y(p, boundary_points))

    loss_val, grad = jax.value_and_grad(total_loss)(params)

    update, opt_state = opt.update(grad, opt_state, params)
    params = optax.apply_updates(params, update)
    return params, opt_state, key, loss_val

# Training loop
def train_loop(params, adam, opt_state, key, num_ic_epochs=7000, num_total_epochs=None,
               ic_log_every=2000, total_log_every=10000):
    """Two-phase Adam training: initial-condition loss first, then the full loss.

    Args:
        params, adam, opt_state, key: passed through to the training-step functions.
        num_ic_epochs: steps of `training_step_ini` (initial-condition loss only).
        num_total_epochs: steps of `training_step` (full loss); None means the global
            ``num_epochs`` hyperparameter.
        ic_log_every / total_log_every: print the loss every this many steps of each phase.

    Returns:
        (losses, params, opt_state, key, loss_val): per-step loss history (floats, phase 1 then
        phase 2), final state, and the last loss (None if no step was run).
    """
    if num_total_epochs is None:
        num_total_epochs = num_epochs
    losses = []
    loss_val = None  # stays None only if both phases run zero steps

    for i in range(num_ic_epochs):
        params, opt_state, key, loss_val = training_step_ini(params, adam, opt_state, key)
        losses.append(loss_val.item())
        if (i+1) % ic_log_every == 0:
            print(f"[Training with I.C. loss] Epoch {i + 1}: Loss = {loss_val.item()}")

    for j in range(num_total_epochs):
        params, opt_state, key, loss_val = training_step(params, adam, opt_state, key)
        losses.append(loss_val.item())
        if (j+1) % total_log_every == 0:
            print(f"[Training with total loss] Epoch {j + 1}: Loss = {loss_val.item()}")
    return losses, params, opt_state, key, loss_val # return final values

# Helper Functions for L-BFGS Wrapper

In [None]:
# L-BFGS requires the parameters to be a single flattened array!
def concat_params(params):
    """Flatten a parameter pytree into one 1-D vector.

    Returns:
        (vector, treedef, shapes): the concatenated leaves, the pytree structure, and the
        list of per-leaf shapes needed by `unconcat_params` to undo the flattening.
    """
    leaves, treedef = jax.tree_util.tree_flatten(params)
    leaf_shapes = list(map(lambda leaf: leaf.shape, leaves))
    flat_leaves = list(map(lambda leaf: leaf.reshape(-1), leaves))
    return np.concatenate(flat_leaves), treedef, leaf_shapes

def unconcat_params(params, tree, shapes): # unflatten the parameters
    """Inverse of `concat_params`: rebuild the parameter pytree from a flat vector.

    Args:
        params: 1-D vector holding all leaves back to back.
        tree: pytree definition returned by `concat_params`.
        shapes: per-leaf shapes returned by `concat_params`.

    Raises:
        ValueError: if the vector length does not match the total number of elements in `shapes`.
    """
    # int(...) matters for scalar leaves: onp.prod(()) is the float 1.0, which is not a valid split index.
    sizes = [int(onp.prod(shape)) for shape in shapes]
    if params.size != sum(sizes):
        raise ValueError(f"expected {sum(sizes)} parameters, got {params.size}")
    # Split at the interior boundaries only (the final cumulative sum is the end of the vector).
    boundaries = onp.cumsum(sizes)[:-1].tolist()
    pieces = np.split(params, boundaries)
    # reshape(shape) (not reshape(*shape)) also handles the scalar shape ().
    leaves = [piece.reshape(shape) for piece, shape in zip(pieces, shapes)]
    return jax.tree_util.tree_unflatten(tree, leaves)

# Evaluation Points & Ground Truth Solutions

In [None]:
# Evaluation mesh and time stamps used to compare against the FEM reference
with open('2D_Transient_Heat_eval_points.json', 'r') as f:
    data_points = json.load(f)

mesh_coords, dt_coords = data_points["mesh_coord"]["0"], data_points["dt_coord"]["0"]
times = [entry[0] for entry in dt_coords]  # flatten the nested time list, e.g. [[0.0], [0.1], ...] -> [0.0, 0.1, ...]

# Reference solution computed with FEM at the same points and times
with open('2D_Transient_Heat_eval_solutions.json', 'r') as f:
    data_sol = json.load(f)

# Indexed by time step first (the training cell squeezes it to (n_times, n_points));
# the source comment mentions a 100 x 100-cell FEM mesh.
u_true = np.array(data_sol)

for label, coords in (("Evaluation points: ", mesh_coords), ("Evaluation times: ", dt_coords)):
    print(label, np.array(coords).shape)

# PINNs Approximate Solution

In [None]:
# Containers for the results, keyed by architecture index (json.dump turns the int keys into strings)
y_results, times_adam, times_lbfgs, times_total, times_eval, l2_rel, var, arch\
    =  dict({}), dict({}), dict({}), dict({}), dict({}), dict({}), dict({}), dict({})

num_runs = 10 # independent training runs per architecture

# Space-time evaluation grid. It does not depend on the run, so it is built once.
# Rows are time-major: row k * n_points + n is (t_k, x_n, y_n), i.e. all mesh points for
# dt_coords[0], then all mesh points for dt_coords[1], and so on.
mesh_coords_squeeze = np.asarray(mesh_coords).squeeze()
dom_mesh_ = np.tile(mesh_coords_squeeze, (len(dt_coords), 1)) # whole mesh repeated once per time step
dom_ts = np.repeat(np.array(dt_coords), len(mesh_coords)) # each time value repeated once per mesh point
time_space_coords = np.stack((dom_ts, dom_mesh_[:, 0], dom_mesh_[:, 1]), axis=1)

count = 0 # architecture index
for architecture in architecture_list:
    times_adam_temp = [] # per-run measurements for this architecture
    times_lbfgs_temp = []
    times_total_temp = []
    times_eval_temp = []
    l2_errors_total = [] # one time-averaged relative L2 error per run
    for i in range(num_runs):

        # Initialize model. Seed with the run index: a fixed PRNGKey(0) gave every run the
        # same initial weights, so the runs were not independent samples for mean/variance.
        model = PDESolution(architecture)
        key, key2 = jax.random.split(jax.random.PRNGKey(i)) # create two keys for independent use
        batch_dim = 4 # only used to build a dummy input for parameter initialization
        feature_dim = 3 # dimension of input point (t, x, y)
        params = model.init(key2, np.ones((batch_dim, feature_dim)))

        # Initialize Adam optimizer
        adam = optax.adam(lr)
        opt_state = adam.init(params)

        # Training with Adam optimiser
        start_time = time.time()
        losses, params, opt_state, key, loss_val = jax.block_until_ready(train_loop(params, adam, opt_state, key))
        adam_time = time.time() - start_time
        times_adam_temp.append(adam_time)
        print("Adam training time : %.3f secs" % adam_time)

        # Generate collocation points for the L-BFGS phase
        lb = np.array([0., 0., 0.])
        ub = np.array([1., 1., 1.])
        domain_points = lb + (ub - lb) * lhs(3, 5000)
        boundary_points = lb + (ub - lb) * lhs(3, 100)
        init_points = lb[1:] + (ub[1:] - lb[1:]) * lhs(2, 100)
        init_point, tree, shapes = concat_params(params)

        def lbfgs_loss(flat_params):
            """Total loss as a function of the flattened parameter vector (L-BFGS needs a 1-D array)."""
            p = unconcat_params(flat_params, tree, shapes) # unflatten once, reuse for every term
            return (pde_residual(p, domain_points)
                    + init_residual(u_init, p, init_points)
                    + dirichlet_residual_x(p, boundary_points)
                    + dirichlet_residual_y(p, boundary_points))

        # Training with L-BFGS optimiser
        start_time2 = time.time()
        lbfgs_results = tfp.optimizer.lbfgs_minimize(jax.value_and_grad(lbfgs_loss),
                                                     init_point,
                                                     max_iterations=50000,
                                                     num_correction_pairs=50, # past updates kept for the inverse-Hessian approximation
                                                     f_relative_tolerance=1.0 * np.finfo(float).eps) # stopping criterion
        lbfgs_time = time.time() - start_time2
        times_lbfgs_temp.append(lbfgs_time)
        times_total_temp.append(adam_time + lbfgs_time)
        print("L-BFGS training time : %.3f secs" % lbfgs_time)

        tuned_params = unconcat_params(lbfgs_results.position, tree, shapes)

        # Evaluate model
        start_time3 = time.time()
        u_approx = jax.block_until_ready(model.apply(tuned_params, time_space_coords).squeeze())
        eval_time = time.time() - start_time3
        times_eval_temp.append(eval_time)
        print("Evaluation time : %.3f secs" % eval_time)

        # Relative L2 error per evaluation time, on (n_times, n_points) views of both fields.
        # u_approx itself stays flat: that is the layout stored in y_results and reshaped by the plotting cell.
        u_true = u_true.squeeze()
        u_approx_tn = u_approx.reshape(len(dt_coords), len(mesh_coords))
        u_true_tn = u_true.reshape(len(dt_coords), len(mesh_coords))
        l2_errors = [np.linalg.norm(u_approx_tn[j] - u_true_tn[j]) / np.linalg.norm(u_true_tn[j])
                     for j in range(len(dt_coords))]

        # Mean of this run's per-time errors (l2_errors is rebuilt for every run)
        l2_errors_total.append(np.mean(np.array(l2_errors)))

        print(f'Architecture {architecture}, RUN {i}')

    print('Average Training Time : ', onp.mean(times_total_temp))
    print('Average Evaluation Time : ', onp.mean(times_eval_temp))
    print('Average Accuracy : ', onp.mean(np.array(l2_errors_total)).tolist())

    y_gt = u_true.tolist()
    domain_pts = time_space_coords.tolist()
    y_results[count] = u_approx.tolist() # prediction of the last run, flat (time-major)
    times_adam[count] = onp.mean(times_adam_temp)
    times_lbfgs[count] = onp.mean(times_lbfgs_temp)
    times_total[count] = onp.mean(times_total_temp)
    times_eval[count] = onp.mean(times_eval_temp)
    l2_rel[count] = onp.mean(np.array(l2_errors_total)).tolist()
    var[count] = onp.var(np.array(l2_errors_total)).tolist()
    arch[count] = architecture
    count += 1

    # Rewrite the JSON files after every architecture so partial results survive an interruption
    results = dict({'domain_pts': domain_pts,
                    'y_results': y_results,
                    'y_gt': y_gt})

    evaluation = dict({'arch': arch,
                       'times_adam': times_adam,
                       'times_lbfgs': times_lbfgs,
                       'times_total': times_total,
                       'times_eval': times_eval,
                       'l2_rel': l2_rel,
                       'var': var})

    sol_json = 'PINNs_results.json'
    with open(sol_json, "w") as f:
        json.dump(results, f)

    eval_json = 'PINNs_evaluation.json'
    with open(eval_json, "w") as f:
        json.dump(evaluation, f)

# Contour Plot (Solution)

In [None]:
# Load evaluation points
with open('2D_Transient_Heat_eval_points.json', 'r') as f:
    data_points = json.load(f)

# Evaluation coordinates & time
mesh_coords = data_points["mesh_coord"]["0"] # list of mesh points
dt_coords = data_points["dt_coord"]["0"] # nested list of times, e.g. [[0.0], [0.1], ...]
times = [dt_coord[0] for dt_coord in dt_coords]  # flattened: [0.0, 0.1, ...]

# The training cell writes PINNs_results.json / PINNs_evaluation.json to the working
# directory, so read them from there (no `save_dir` is defined in this notebook).
results_dir = '.'

# Load evaluation results (by PINNs)
with open(os.path.join(results_dir, 'PINNs_results.json'), 'r') as f:
    data_results = json.load(f)

with open(os.path.join(results_dir, 'PINNs_evaluation.json'), 'r') as f:
    data_eval = json.load(f)

# Slice results data
domain_pts = data_results['domain_pts']
y_results = data_results['y_results']
y_gt = data_results['y_gt']

arch = data_eval['arch']
times_adam = data_eval['times_adam']
times_lbfgs = data_eval['times_lbfgs']
times_total = data_eval['times_total']
times_eval = data_eval['times_eval']
l2_rel = data_eval['l2_rel']
var = data_eval['var']

# x, y coordinates (squeezed the same way as in the training cell)
mesh_coords = np.asarray(mesh_coords).squeeze()
X = mesh_coords[:, 0]
Y = mesh_coords[:, 1]

num_levels = 80 # number of contour levels

for idx, architecture in arch.items():
    # Stored predictions are flat and time-major; view them as (n_times, n_points)
    u_approx = np.array(y_results[idx]).reshape(len(dt_coords), len(mesh_coords))

    # One colour scale shared by all time steps of this architecture
    u_min = u_approx.min()
    u_max = u_approx.max()
    if u_max <= u_min:
        # Constant field: widen the range so the contour levels are strictly increasing
        u_max = u_min + 1.0
    levels = np.linspace(u_min, u_max, num_levels)

    fig_dir = f'./fig/sol_contour/PINNs/arch_{idx}'
    os.makedirs(fig_dir, exist_ok=True)

    for t in range(len(times)):
        fig, ax = plt.subplots(figsize=(6, 5))
        sc1 = ax.tricontourf(X, Y, u_approx[t, :], levels=levels, cmap='viridis')
        ax.set_title(f"PINNs Solution (Architecture={architecture}, t_step={t})")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        fig.colorbar(sc1, ax=ax, shrink=0.7)
        fig.tight_layout()

        # Save figure and release it (many figures are produced in this loop)
        filename = os.path.join(fig_dir, f'sol_{t:04d}.png')
        fig.savefig(filename, dpi=300, bbox_inches='tight')
        plt.close(fig)