In [None]:
# -----------------------------------------------------------------------------
# Portions of this code are adapted from:
#   - https://github.com/TamaraGrossmann/FEM-vs-PINNs.git
#   - Grossmann, T. G., Komorowska, U. J., Latz, J., & Schönlieb, C.-B. (2023).
#     Can Physics-Informed Neural Networks beat the Finite Element Method?
#     arXiv:2302.04107.
# -----------------------------------------------------------------------------

In [None]:
# (for Google Colab)
!pip install pyDOE

In [None]:
import jax, flax, optax, time, pickle
import os
import jax.numpy as np
import numpy as onp
from functools import partial
from pyDOE import lhs
from typing import Sequence
import json
from tensorflow_probability.substrates import jax as tfp

In [None]:
# Run on the first GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from jax.extend.backend import get_backend
print(get_backend().platform)

# Hyperparameters

In [None]:
# NN architecture list: hidden widths 20 / 100 / 500 for 3 to 7 hidden layers, each followed by a
# scalar output layer. The last combination (7 hidden layers of width 500) is left out -> 14 entries.
architecture_list = [[width] * depth + [1]
                     for depth in range(3, 8)
                     for width in (20, 100, 500)][:-1]
lr = 1e-4  # Adam learning rate
num_epochs = 50000  # number of Adam epochs in the full-loss training phase
eps = 0.01  # epsilon coefficient of the Allen-Cahn residual (eps*u_xx and 2/eps reaction scale)

# NN Architecture

In [None]:
# Define NN architecture
class PDESolution(flax.linen.Module):
    """Fully connected network approximating u(t, x).

    Attributes:
        features: widths of the Dense layers (e.g. [20, 20, 20, 1]). Every layer except the
            last is followed by tanh; the last one is a linear output layer.
    """
    features: Sequence[int]

    @flax.linen.compact
    def __call__(self, x):
        hidden_widths, out_width = self.features[:-1], self.features[-1]
        h = x
        for width in hidden_widths:
            # hidden layer: Dense with the default kernel initializer, then tanh
            h = flax.linen.tanh(flax.linen.Dense(width)(h))
        # linear output layer; Glorot-uniform kernel initialization
        output_layer = flax.linen.Dense(out_width, kernel_init=flax.linen.initializers.glorot_uniform())
        return output_layer(h)

# Loss Function

In [None]:
# PDE residual
@partial(jax.vmap, in_axes = (None, 0, 0, None), out_axes = 0)
@partial(jax.jit, static_argnums = (0,))
def residual(u, t, x, eps):
    """Pointwise Allen-Cahn residual u_t - eps*u_xx + (2/eps)*u*(1-u)*(1-2u) at (t, x).

    Batched over `t` and `x` along axis 0 (vmap); `u` and `eps` are shared by all points.

    Args:
        u: callable u(t, x); static under jit, so it must be hashable.
        t: time coordinate of one point.
        x: space coordinate of one point.
        eps: epsilon coefficient of the equation.
    """
    # time derivative = forward-mode directional derivative along tangent (1, 0)
    du_dt = jax.jvp(u, (t, x), (1., 0.))[1]
    # second derivative with respect to argument 1 (x)
    d2u_dx2 = jax.hessian(u, argnums = 1)(t, x)
    u_val = u(t, x)
    reaction = (2/eps) * u_val * (1 - u_val) * (1 - 2*u_val)
    return du_dt - eps*d2u_dx2 + reaction

# Initial condition
@partial(jax.vmap, in_axes=0) # vectorized over "xs"
def u_init(xs):
    """Initial profile u(0, x) = 0.25*(sin(2*pi*x) + 0.25*sin(16*pi*x)) + 0.5.

    For a 1-D batch of positions, each mapped call returns a length-1 array, so the
    batched result has shape (n, 1).
    """
    slow_mode = np.sin(2*np.pi*xs)
    fast_mode = 0.25*np.sin(16*np.pi*xs)
    return np.array([0.25*(slow_mode + fast_mode) + 0.5])

# Loss functionals
@jax.jit
def pde_residual(params, points):
    """Mean squared PDE residual over interior collocation points.

    Args:
        params: model parameters (the network is the global `model`).
        points: array whose column 0 is t and column 1 is x.
    """
    t_col, x_col = points[:, 0], points[:, 1]

    def u_net(t, x):
        return model.apply(params, np.stack((t, x)))

    squared = residual(u_net, t_col, x_col, eps) ** 2
    return np.mean(squared)  # Mean Squared Error

@partial(jax.jit, static_argnums=0)
def init_residual(u_init, params, xs):
    """Mean squared mismatch between the network at t = 0 and the initial condition.

    Args:
        u_init: initial-condition function applied to xs[:, 0]; static under jit, so it
            must be hashable.
        params: model parameters (the network is the global `model`).
        xs: spatial sample points; only column 0 is read (presumably shape (n, 1)).

    Returns:
        Scalar mean squared error.
    """
    x = xs[:, 0]
    pred = model.apply(params, np.stack((np.zeros_like(x), x), axis=1))  # network at (0, x)
    target = u_init(x)
    # NOTE(review): pred and target must have matching shapes ((n, 1) for a 1-output model
    # and the vmapped u_init); a mismatch would silently broadcast.
    return np.mean((pred - target) ** 2)

@jax.jit
def boundary_residual(params, ts):
    """Mean squared violation of the periodic boundary condition u(t, 0) = u(t, 1).

    Args:
        params: model parameters (the network is the global `model`).
        ts: boundary sample times, shape (n, 1); only column 0 is read.
    """
    t = ts[:, 0]
    at_left = model.apply(params, np.stack((t, np.zeros_like(t)), axis = 1))
    at_right = model.apply(params, np.stack((t, np.ones_like(t)), axis = 1))
    return np.mean((at_left - at_right) ** 2)

# Training Loop

In [None]:
# Define Training Step
def jax_lhs(key, n, dim):
    """Latin hypercube sample of `n` points in the unit cube [0, 1]^dim, drawn with JAX.

    Each axis is cut into `n` equal strata. Every stratum receives exactly one point at a
    uniformly random offset, and the stratum order is shuffled independently per axis.
    Because it only uses `jax.random`, it is traceable and draws fresh points on every call
    of a jitted function (unlike NumPy/pyDOE sampling, which would run once at trace time).

    Args:
        key: JAX PRNG key.
        n: number of points (Python int; a static shape under jit).
        dim: number of dimensions (Python int; a static shape under jit).

    Returns:
        Array of shape (n, dim) with entries in [0, 1].
    """
    key_perm, key_jitter = jax.random.split(key)
    # one independent permutation of the stratum indices 0..n-1 per axis -> (n, dim)
    strata = jax.vmap(lambda k: jax.random.permutation(k, n))(jax.random.split(key_perm, dim)).T
    jitter = jax.random.uniform(key_jitter, (n, dim))
    return (strata + jitter) / n


def sample_collocation_points(key, n_domain=20000, n_boundary=250, n_init=500):
    """Draw fresh LHS collocation points on (t, x) in [0, 0.05] x [0, 1].

    Returns:
        domain_points: (n_domain, 2) interior points (t, x).
        boundary_points: (n_boundary, 1) times t in [0, 0.05] for the periodic boundary term.
        init_points: (n_init, 1) positions x in [0, 1] for the initial-condition term.
    """
    lb = np.array([0., 0.])  # lower bound of (t, x)
    ub = np.array([0.05, 1.])  # upper bound of (t, x)
    key_domain, key_boundary, key_init = jax.random.split(key, 3)
    domain_points = lb + (ub - lb) * jax_lhs(key_domain, n_domain, 2)
    boundary_points = lb[0] + (ub[0] - lb[0]) * jax_lhs(key_boundary, n_boundary, 1)
    init_points = lb[1] + (ub[1] - lb[1]) * jax_lhs(key_init, n_init, 1)
    return domain_points, boundary_points, init_points


@partial(jax.jit, static_argnums = (1,))
def training_step_ini(params, opt, opt_state, key):
    """One optimizer step on the initial-condition loss only.

    In the PDE problem with an initial condition, training the NN on the I.C. first can help because
    - it does not need many training steps (the initial guess is already close to the solution);
    - it stabilizes subsequent training, since the solution then starts near the correct initial profile.

    Args:
        params: model parameters
        opt: optimizer (static under jit, so it must be hashable)
        opt_state: optimizer state
        key: JAX PRNG key; fresh initial-condition points are sampled from it on every call

    Returns:
        (params, opt_state, key, loss_val): updated parameters and optimizer state, the advanced
        key to pass to the next step, and the loss evaluated before the update.
    """
    key, subkey = jax.random.split(key)
    _, _, init_points = sample_collocation_points(subkey)  # only the I.C. points are used here

    loss_val, grad = jax.value_and_grad(lambda p: init_residual(u_init, p, init_points))(params)
    update, opt_state = opt.update(grad, opt_state, params)
    params = optax.apply_updates(params, update)
    return params, opt_state, key, loss_val


@partial(jax.jit, static_argnums = (1,))
def training_step(params, opt, opt_state, key):
    """One optimizer step on the full loss (PDE + weighted I.C. + periodic boundary).

    Same arguments and return values as `training_step_ini`; all three point sets are
    re-sampled from `key` on every call.
    """
    key, subkey = jax.random.split(key)
    domain_points, boundary_points, init_points = sample_collocation_points(subkey)

    def total_loss(p):
        # Weight factor 1000 on the I.C. term was chosen heuristically: it makes the total loss
        # more sensitive to the I.C. mismatch than to the other terms.
        return (pde_residual(p, domain_points)
                + 1000 * init_residual(u_init, p, init_points)
                + boundary_residual(p, boundary_points))

    loss_val, grad = jax.value_and_grad(total_loss)(params)
    update, opt_state = opt.update(grad, opt_state, params)
    params = optax.apply_updates(params, update)
    return params, opt_state, key, loss_val


# Training loop
def train_loop(params, adam, opt_state, key, num_ini_epochs=7000):
    """Adam training: `num_ini_epochs` I.C.-only steps, then `num_epochs` full-loss steps.

    Args:
        params: initial model parameters.
        adam: optax optimizer.
        opt_state: its initial state.
        key: JAX PRNG key, threaded through every step for collocation sampling.
        num_ini_epochs: number of I.C.-only warm-up steps (default 7000).

    Returns:
        (losses, params, opt_state, key, loss_val): per-step losses as Python floats, the final
        parameters/optimizer state/key, and the last loss (None if no step was taken).
    """
    losses = []
    loss_val = None
    for _ in range(num_ini_epochs):
        params, opt_state, key, loss_val = training_step_ini(params, adam, opt_state, key)
        losses.append(loss_val.item())

    for _ in range(num_epochs):
        params, opt_state, key, loss_val = training_step(params, adam, opt_state, key)
        losses.append(loss_val.item())
    return losses, params, opt_state, key, loss_val  # return final values

# Helper Functions for L-BFGS Wrapper

In [None]:
# L-BFGS requires the parameters to be a single flattened array!
def concat_params(params):
    """Flatten a parameter pytree into one 1-D vector.

    Returns:
        flat: 1-D array with all leaves concatenated in `tree_flatten` order.
        tree: PyTreeDef needed to rebuild the original nested structure.
        shapes: list with the shape of each leaf, in the same order.
    """
    leaves, tree = jax.tree_util.tree_flatten(params)
    shapes = [leaf.shape for leaf in leaves]
    flat = np.concatenate([leaf.reshape(-1) for leaf in leaves])
    return flat, tree, shapes


def unconcat_params(params, tree, shapes):
    """Inverse of `concat_params`: split a flat vector back into the original pytree.

    The split offsets are computed with host NumPy (`onp`), so they are concrete integers
    even while the function is being traced (e.g. under `jax.jit`). `jnp.split` requires
    concrete indices; computing them with jax.numpy would produce tracers there.

    Args:
        params: 1-D parameter vector as produced by `concat_params`.
        tree: PyTreeDef returned by `concat_params`.
        shapes: leaf shapes returned by `concat_params`.
    """
    sizes = [int(onp.prod(shape)) for shape in shapes]
    offsets = onp.cumsum(sizes)[:-1]  # boundaries between consecutive leaves (no trailing empty chunk)
    chunks = np.split(params, offsets)
    leaves = [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]
    return jax.tree_util.tree_unflatten(tree, leaves)

# Evaluation Points & Ground Truth Solutions

In [None]:
# Pre-specified evaluation grid, distinct from the training collocation points,
# used to measure the error.
with open('1D_Allen-Cahn_eval_points.json', 'r') as f:
    eval_points = json.load(f)
mesh_coord, dt_coord = eval_points['mesh_coord']['0'], eval_points['dt_coord']['0']

# Pre-specified reference (FEM) solution at the evaluation points.
# NOTE(review): `eval_solution` appears unused below; the error is computed against `FEM`,
# which is loaded from ./Eval_Points/1D_Allen-Cahn/ instead. Confirm which file is intended.
with open('1D_Allen-Cahn_eval_solution.json', 'r') as f:
    eval_solution = np.asarray(json.load(f))


class ImportData:
    """Loader for the reference FEM solution stored as JSON."""

    def __init__(self, save_dir):
        """
        Args:
            save_dir: directory containing 'eval_solution_mat.json'.
        """
        self.save_dir = save_dir

    def get_FEM_results(self):
        """Read 'eval_solution_mat.json' from `save_dir` and return it as a jax.numpy array."""
        with open(os.path.join(self.save_dir, 'eval_solution_mat.json'), 'r') as f:
            eval_solution_mat = json.load(f)
        return np.asarray(eval_solution_mat)

def get_relative_error(u, v):
    """Relative L2 error ||u - v|| / ||u|| of the approximation `v` w.r.t. the reference `u`.

    Args:
        u: reference values (e.g. one time slice of the FEM solution); a zero norm gives inf/nan.
        v: approximation with the same shape as `u`.

    Returns:
        0-d array with the relative error.
    """
    return np.linalg.norm(u - v) / np.linalg.norm(u)

class CompareGT:
    """Compare a trained PINN with the reference FEM solution on the evaluation grid."""

    @staticmethod
    def get_FEM_comparison(mesh_coord, dt_coord, FEM, model, tuned_params):
        """Evaluate the model on the (time x space) grid and compute per-time-step errors.

        Args:
            mesh_coord: spatial evaluation coordinates (squeezed to 1-D).
            dt_coord: evaluation times.
            FEM: reference solution indexed as FEM[time_step, mesh_point].
            model: object with `apply(params, points)`, points given as rows of (t, x).
            tuned_params: trained parameters passed to `model.apply`.

        Returns:
            l2: list with one relative L2 error per time step.
            times_eval: wall-clock seconds spent in the forward pass.
            approx: prediction reshaped to (len(dt_coord), len(mesh_coord)).
            FEM: the reference solution, unchanged.
            domain_pt: (len(dt_coord) * len(mesh_coord), 2) array of (t, x) query points.
        """
        n_t, n_x = len(dt_coord), len(mesh_coord)
        dom_mesh = np.asarray(mesh_coord).squeeze()
        # Time-major grid: each time in dt_coord is paired with every mesh coordinate, so the
        # flat prediction can be reshaped to (n_t, n_x).
        dom_ts = np.repeat(np.asarray(dt_coord), n_x)
        dom_xs = np.tile(dom_mesh, n_t)
        domain_pt = np.stack((dom_ts, dom_xs), axis=1)

        start_time = time.time()
        approx = jax.block_until_ready(model.apply(tuned_params, domain_pt).squeeze())
        times_eval = time.time() - start_time

        approx = approx.reshape(n_t, n_x)
        reference = np.asarray(FEM)[:n_t]
        # relative L2 error ||FEM_k - approx_k|| / ||FEM_k|| for every time step k
        l2 = list(np.linalg.norm(reference - approx, axis=1) / np.linalg.norm(reference, axis=1))

        return l2, times_eval, approx, FEM, domain_pt

# ---------------------------------------------------------------------
# Reference FEM solution on the evaluation grid (ground truth for the PINN error).
GTloader = ImportData(
    './Eval_Points/1D_Allen-Cahn/',  # directory holding eval_solution_mat.json
)
FEM = GTloader.get_FEM_results()

# Train PINN & Approximate Solution

In [None]:
# Train every architecture num_runs times (Adam, then L-BFGS) and compare with the FEM solution.
num_runs = 10  # independent training runs per architecture

# Containers for the results (each dict is keyed by the architecture index)
u_results, domain_pts, times_adam, times_lbfgs, times_total, times_eval, l2_rel, var, arch = (
    {}, {}, {}, {}, {}, {}, {}, {}, {})

for count, architecture in enumerate(architecture_list):  # count: architecture index
    print('Architecture : %s' %architecture)
    times_adam_temp = []  # per-run measurements, averaged over the runs below
    times_lbfgs_temp = []
    times_total_temp = []
    times_eval_temp = []
    accuracy_temp = []
    for run in range(num_runs):
        # Initialize Model. Seed with the run index: a fixed PRNGKey(0) would give every run the
        # same initial weights, so the runs (and the reported variance) would not be independent.
        model = PDESolution(architecture)  # global on purpose: read by the loss functions
        key, key2 = jax.random.split(jax.random.PRNGKey(run))  # key: weight init, key2: training loop
        batch_dim = 4  # dummy batch size, only used for parameter initialization
        feature_dim = 2  # dimension of an input point (t, x)
        params = model.init(key, np.ones((batch_dim, feature_dim)))

        # Initialize Optimizer
        adam = optax.adam(learning_rate = lr)
        opt_state = adam.init(params)

        # Train with Adam (I.C.-only warm-up, then the full loss; see train_loop)
        start_time = time.time()
        losses, params, opt_state, key2, loss_val = jax.block_until_ready(
            train_loop(params, adam, opt_state, key2))
        adam_time = time.time() - start_time
        times_adam_temp.append(adam_time)
        print('Adam Training Time : %f secs' %adam_time)

        # Collocation points for the L-BFGS phase (fixed during the L-BFGS run).
        # pyDOE's lhs presumably draws from NumPy's global RNG, so seed it per run for reproducibility.
        onp.random.seed(run)
        lb = onp.array([0., 0.])
        ub = onp.array([0.05, 1.])
        domain_points = lb + (ub - lb) * lhs(2, 20000)
        boundary_points = lb[0] + (ub[0] - lb[0]) * lhs(1, 250)
        init_points = lb[1] + (ub[1] - lb[1]) * lhs(1, 500)

        init_point, tree, shapes = concat_params(params)

        def lbfgs_objective(flat_params):
            """Same weighted loss as the Adam phase, evaluated on a flat parameter vector."""
            p = unconcat_params(flat_params, tree, shapes)
            return (pde_residual(p, domain_points)
                    + 1000 * init_residual(u_init, p, init_points)
                    + boundary_residual(p, boundary_points))

        # L-BFGS Optimization
        print('Starting L-BFGS Optimization')
        start_time2 = time.time()
        # block_until_ready so the measured time covers the whole (asynchronous) computation
        lbfgs_results = jax.block_until_ready(tfp.optimizer.lbfgs_minimize(
            jax.value_and_grad(lbfgs_objective), init_point, max_iterations = 50000,
            num_correction_pairs = 50,  # number of past updates used to approximate the inverse Hessian
            f_relative_tolerance = 1.0*np.finfo(float).eps))  # stopping criterion
        lbfgs_time = time.time() - start_time2
        times_lbfgs_temp.append(lbfgs_time)
        times_total_temp.append(adam_time + lbfgs_time)
        print('L-BFGS Training Time : %f secs' %lbfgs_time)

        # Comparison to the FEM ground truth on the evaluation grid
        tuned_params = unconcat_params(lbfgs_results.position, tree, shapes)
        l2, eval_time, approx, gt_fem, domain_pt = CompareGT.get_FEM_comparison(
            mesh_coord, dt_coord, FEM, model, tuned_params)
        times_eval_temp.append(eval_time)
        run_accuracy = float(np.mean(np.array(l2)))  # relative L2 error averaged over the time steps
        accuracy_temp.append(run_accuracy)
        print('Relative L2 error : %e' %run_accuracy)

    # Per-architecture summary (solution fields come from the last run)
    y_gt = onp.asarray(gt_fem).tolist()  # FEM reference on the evaluation grid
    u_results[count] = onp.asarray(approx).tolist()
    domain_pts[count] = onp.asarray(domain_pt).tolist()
    times_adam[count] = onp.mean(times_adam_temp)  # mean times across the runs
    times_lbfgs[count] = onp.mean(times_lbfgs_temp)
    times_total[count] = onp.mean(times_total_temp)
    times_eval[count] = onp.mean(times_eval_temp)
    l2_rel[count] = onp.mean(accuracy_temp).tolist()
    var[count] = onp.var(accuracy_temp).tolist()  # variance of the error across the runs
    arch[count] = architecture

    results = dict({'domain_pts': domain_pts,
                    'y_results': u_results,
                    'y_gt': y_gt})

    evaluation = dict({'arch': arch,
                    'times_adam': times_adam,
                    'times_lbfgs': times_lbfgs,
                    'times_total': times_total,
                    'times_eval': times_eval,
                    'l2_rel': l2_rel,
                    'var': var})

    # Save Results & Evaluation (rewritten after every architecture, acting as a checkpoint)
    save_dir = './1D_Allen-Cahn'
    os.makedirs(save_dir, exist_ok = True)

    with open(os.path.join(save_dir, 'PINNs_results.json'), 'w') as f:
        json.dump(results, f)

    with open(os.path.join(save_dir, 'PINNs_evaluation.json'), 'w') as f:
        json.dump(evaluation, f)

    print(json.dumps(evaluation, indent = 4))