In [None]:
# (for Google Colab)
!pip install pyDOE

In [None]:
import os, time, pickle
import jax, flax, optax
import jax.numpy as np
import numpy as onp
from functools import partial
from pyDOE import lhs
from typing import Sequence, Callable
import json
from tensorflow_probability.substrates import jax as tfp
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

In [None]:
# Must set CUDA_VISIBLE_DEVICES before importing JAX or any other library that initializes GPUs.
# Otherwise, the environment variable change might be ignored.
# "0, 1": first two GPUs / "": no GPU (CPU instead)

# Run on the first GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from jax.extend.backend import get_backend
print(get_backend().platform)

# Hyperparameters

In [None]:
# NN architecture list: hidden widths 20 and 60 at depths 1..5 (depth-major order),
# followed by one 5 x 120 network. The trailing 2 is the output width of every network.
architecture_list = [[width] * depth + [2] for depth in range(1, 6) for width in (20, 60)]
architecture_list.append([120] * 5 + [2])
lr = 1e-3 # learning rate
num_epochs = 20000 # number of training epochs

# NN Architecture

In [None]:
# Define NN architecture
class PDESolution(flax.linen.Module): # inherit from Module class
    """Fully connected network used as the PDE solution ansatz.

    Every width in ``features[:-1]`` becomes a Dense layer followed by tanh;
    ``features[-1]`` is the width of the final Dense layer, which has no activation.
    """
    # Layer widths, assigned to "self.features" by flax.linen.Module's dataclass behavior
    features: Sequence[int] # dataclass (e.g. [10, 20, 2])

    @flax.linen.compact # lets the layers be declared inline, in call order
    def __call__(self, x): # makes instances callable like functions
        hidden_widths, output_width = self.features[:-1], self.features[-1]
        activations = x
        for width in hidden_widths:
            pre_activation = flax.linen.Dense(width)(activations)
            activations = flax.linen.tanh(pre_activation)
        # Final Dense layer (linear output)
        return flax.linen.Dense(output_width)(activations)

# Stress Tensor

# Loss Function

In [None]:
# PDE residual
@partial(jax.vmap, in_axes = (None, 0, 0), out_axes = 0)
@partial(jax.jit, static_argnums = (0,)) # decorator closest to the function is applied first
def residual1(u, x, y):
    """First row of the stress-divergence residual: d(sigma_11)/dx + d(sigma_12)/dy.

    Args:
        u: callable ``u(x, y)`` taking two scalars and returning a length-2 array.
           It is a static (hashed) argument of the inner ``jax.jit``.
        x, y: coordinates; the outer ``jax.vmap`` maps over their leading axis,
           so inside this function they are scalars.

    Returns:
        The residual at each point (one value per mapped entry of ``x``/``y``).
    """
    def stress(x, y): # tensor
        # Material constants and the Lame parameters derived from them
        E = 2.35e3
        nu = 0.33
        mu = E / (2. * (1 + nu))
        lmbda = E * nu / ((1 + nu) * (1 - 2*nu))
        # jax.jacobian differentiates w.r.t. argument 0 only by default, which would give
        # a (2,) vector instead of the 2x2 gradient. Differentiate w.r.t. both coordinates
        # and stack so that u_grad[i, j] = d u_i / d x_j.
        du_dx, du_dy = jax.jacobian(u, argnums=(0, 1))(x, y)
        u_grad = np.stack((du_dx, du_dy), axis=-1) # (2, 2)
        epsilon = 0.5 * (u_grad + u_grad.T) # (2, 2) symmetric part of the gradient
        # Stress-Strain relationship
        sigma = lmbda * np.trace(epsilon) * np.eye(2) + 2 * mu * epsilon
        return sigma

    # Derivatives of every stress component w.r.t. each coordinate, each of shape (2, 2)
    jac_wrt_x = jax.jacobian(stress, argnums=0)(x, y)
    jac_wrt_y = jax.jacobian(stress, argnums=1)(x, y)
    dsigma11_dx = jac_wrt_x[0, 0]
    dsigma12_dy = jac_wrt_y[0, 1]
    lhs = dsigma11_dx + dsigma12_dy
    rhs = 0.
    return lhs - rhs

@partial(jax.vmap, in_axes = (None, 0, 0), out_axes = 0)
@partial(jax.jit, static_argnums = (0,))
def residual2(u, x, y):
    """Second row of the stress-divergence residual: d(sigma_21)/dx + d(sigma_22)/dy.

    Args:
        u: callable ``u(x, y)`` taking two scalars and returning a length-2 array.
           It is a static (hashed) argument of the inner ``jax.jit``.
        x, y: coordinates; the outer ``jax.vmap`` maps over their leading axis,
           so inside this function they are scalars.

    Returns:
        The residual at each point (one value per mapped entry of ``x``/``y``).
    """
    def stress(x, y): # tensor
        # Material constants and the Lame parameters derived from them
        E = 2.35e3
        nu = 0.33
        mu = E / (2. * (1 + nu))
        lmbda = E * nu / ((1 + nu) * (1 - 2*nu))
        # jax.jacobian differentiates w.r.t. argument 0 only by default, which would give
        # a (2,) vector instead of the 2x2 gradient. Differentiate w.r.t. both coordinates
        # and stack so that u_grad[i, j] = d u_i / d x_j.
        du_dx, du_dy = jax.jacobian(u, argnums=(0, 1))(x, y)
        u_grad = np.stack((du_dx, du_dy), axis=-1) # (2, 2)
        epsilon = 0.5 * (u_grad + u_grad.T) # (2, 2) symmetric part of the gradient
        # Stress-Strain relationship
        sigma = lmbda * np.trace(epsilon) * np.eye(2) + 2 * mu * epsilon
        return sigma

    # Derivatives of every stress component w.r.t. each coordinate, each of shape (2, 2)
    jac_wrt_x = jax.jacobian(stress, argnums=0)(x, y)
    jac_wrt_y = jax.jacobian(stress, argnums=1)(x, y)
    dsigma21_dx = jac_wrt_x[1, 0]
    dsigma22_dy = jac_wrt_y[1, 1]
    lhs = dsigma21_dx + dsigma22_dy
    rhs = 0.
    return lhs - rhs

# Loss functionals
@jax.jit
def pde_residual1(params, points):
    """Mean of the squared `residual1` values over `points`.

    `points[:, 0]` / `points[:, 1]` are used as the x / y coordinates, and the
    global `model` is applied with `params` to evaluate the solution.
    """
    def solution(x, y):
        return model.apply(params, np.stack((x, y)))

    x_coords, y_coords = points[:, 0], points[:, 1]
    pointwise = residual1(solution, x_coords, y_coords)
    return np.mean(pointwise ** 2) # Mean Squared Error

@jax.jit
def pde_residual2(params, points):
    """Mean of the squared `residual2` values over `points`.

    `points[:, 0]` / `points[:, 1]` are used as the x / y coordinates, and the
    global `model` is applied with `params` to evaluate the solution.
    """
    def solution(x, y):
        return model.apply(params, np.stack((x, y)))

    x_coords, y_coords = points[:, 0], points[:, 1]
    pointwise = residual2(solution, x_coords, y_coords)
    return np.mean(pointwise ** 2) # Mean Squared Error

@jax.jit
def dirichlet_residual1(params, points): # Γ_u = -1 at x = 0
    """Mean squared mismatch between the first model output and -1 on the line x = 0.

    Only the second column of `points` is used; the x coordinate is replaced by zeros.
    """
    y_coords = points[:, 1]
    edge_points = np.stack((np.zeros_like(y_coords), y_coords), axis=1)
    first_component = model.apply(params, edge_points)[:, 0]
    return np.mean((first_component + 1.) ** 2)

@jax.jit
def dirichlet_residual2(params, points): # Γ_u = 1 at x = 40
    """Mean squared mismatch between the first model output and +1 on the line x = 40.

    Only the second column of `points` is used as y; the x coordinate is replaced by 40.
    """
    y_coords = points[:, 1]
    x_coords = 40 * np.ones_like(points[:, 0])
    edge_points = np.stack((x_coords, y_coords), axis=1)
    first_component = model.apply(params, edge_points)[:, 0]
    return np.mean((first_component - 1.) ** 2)


# Training Loop

In [None]:
# Define Training Step
def _latin_hypercube(key, n_dims, n_samples):
    """Latin hypercube sample of shape (n_samples, n_dims) in the unit cube, drawn with JAX's PRNG.

    Each column holds one point per stratum [i/n, (i+1)/n); the columns are shuffled
    independently so the strata are paired at random across dimensions.
    """
    jitter_key, shuffle_key = jax.random.split(key)
    jitter = jax.random.uniform(jitter_key, (n_samples, n_dims))
    strata = (np.arange(n_samples)[:, None] + jitter) / n_samples
    column_keys = jax.random.split(shuffle_key, n_dims)
    columns = [jax.random.permutation(column_keys[j], strata[:, j]) for j in range(n_dims)]
    return np.stack(columns, axis=1)


@partial(jax.jit, static_argnums = (1, 4))
def training_step(params, opt, opt_state, key, d_neumann: Callable = None):
    """
    Run one optimizer update on freshly sampled collocation points.

    Args:
        params: network + geometric parameters (dict with 'network' and 'geometry' entries)
        opt: optimizer (static under jit)
        opt_state: optimizer state
        key: random key for sampling
        d_neumann: currently unused (the Neumann loss term is commented out); static under jit

    Returns:
        (params, opt_state, key, loss_val): updated parameters and optimizer state, the
        advanced random key to pass to the next step, and the loss before the update.
    """
    # Draw the samples with the JAX key. A NumPy-based sampler would run only once, while
    # this function is being traced by jax.jit, and every later step would silently reuse
    # the same points. Sampling happens outside the differentiated function.
    key, boundary_key, angle_key1, angle_key2, radius_key = jax.random.split(key, 5)
    boundary_samples = _latin_hypercube(boundary_key, 2, 250) # (250, 2)
    # 1-D samples are kept as vectors so that `ellipse` returns (n, 2) point arrays
    angle_samples1 = _latin_hypercube(angle_key1, 1, 150)[:, 0]
    angle_samples2 = _latin_hypercube(angle_key2, 1, 2000)[:, 0]
    radius_samples = _latin_hypercube(radius_key, 1, 2000)[:, 0]

    # Total loss functional
    def loss_total(params, boundary_samples, angle_samples1, angle_samples2, radius_samples):
        net_params = params['network']
        geo_params = params['geometry']
        x_cen, y_cen, a, b, gamma = geo_params # unpack geometric parameters

        # Define the domain & external boundary points
        lb = np.array([0., 0.]) # lower bound
        ub = np.array([40., 40.]) # upper bound
        boundary_points = lb + (ub - lb) * boundary_samples
        # 250 sampled points on the edge x = 0
        boundary_points1 = np.column_stack((np.zeros_like(boundary_points[:, 0]), boundary_points[:, 1]))
        # 250 sampled points on the edge x = 40. dirichlet_residual2 keeps only the y column,
        # so y must stay random here; pinning y = 40 would test the single point (40, 40).
        boundary_points2 = np.column_stack((40 * np.ones_like(boundary_points[:, 0]), boundary_points[:, 1]))

        # Define the boundary points of the void
        def ellipse(x_cen, y_cen, a, b, gamma, rotated_ang):
            x = a * np.cos(rotated_ang) * np.cos(gamma) - b * np.sin(rotated_ang) * np.sin(gamma) + x_cen
            y = a * np.cos(rotated_ang) * np.sin(gamma) + b * np.sin(rotated_ang) * np.cos(gamma) + y_cen
            return np.array([x, y]).T
        random_angle1 = 2 * np.pi * angle_samples1
        # Only needed by the Neumann term, which is currently commented out below
        void_boundary_points = ellipse(x_cen, y_cen, a, b, gamma, random_angle1)

        # Define the domain points
        # NOTE(review): these points fill the interior of the ellipse that the comments call
        # "the void" - confirm this is the intended PDE collocation region.
        random_angle2 = 2 * np.pi * angle_samples2
        random_radius = np.sqrt(radius_samples) # sqrt of the sample, used to scale both semi-axes
        domain_points = ellipse(x_cen, y_cen, a * random_radius, b * random_radius, gamma, random_angle2)

        # Define the loss function
        loss_pde1 = pde_residual1({'params': net_params}, domain_points) # parameters to be used in "model.apply" should be in the dict with 'params' key.
        loss_pde2 = pde_residual2({'params': net_params}, domain_points)
        loss_dirichlet1 = dirichlet_residual1({'params': net_params}, boundary_points1)
        loss_dirichlet2 = dirichlet_residual2({'params': net_params}, boundary_points2)
        # loss_neumann = neumann_residual(neumann_derivatives, net_params, void_boundary_points)
        return loss_pde1 + loss_pde2 + loss_dirichlet1 + loss_dirichlet2

    # Evaluate the loss function and its gradient
    loss_val, grad = jax.value_and_grad(lambda x: loss_total(x, boundary_samples, angle_samples1, angle_samples2, radius_samples))(params)

    # Update model parameters
    update, opt_state = opt.update(grad, opt_state, params) # update using "grad"
    params = optax.apply_updates(params, update) # apply updates to "params"
    return params, opt_state, key, loss_val

# Training loop
def train_loop(params, adam, opt_state, key, d_neumann: Callable = None):
    """Run `num_epochs` calls of `training_step` and record the loss of every step.

    Args:
        params: parameters passed to (and updated by) `training_step`
        adam: optimizer handed to `training_step`
        opt_state: optimizer state
        key: random key threaded through `training_step`
        d_neumann: forwarded unchanged to `training_step`

    Returns:
        (losses, params, opt_state, key, loss_val): the per-step losses as Python floats,
        followed by the final values. `loss_val` is None when `num_epochs` is 0.
    """
    losses = []
    loss_val = None # keeps the return well-defined when the loop body never runs
    for _ in range(num_epochs): # "_" is used because the variable is not used in for loop
        # d_neumann is passed on so the argument is not silently dropped
        params, opt_state, key, loss_val = training_step(params, adam, opt_state, key, d_neumann)
        losses.append(loss_val.item())
    return losses, params, opt_state, key, loss_val # return final values


# Helper Functions for L-BFGS Wrapper

In [None]:
# L-BFGS requires the parameters to be a single flattened array!
def concat_params(params): # flatten the parameters
    """Flatten a parameter pytree into one 1-D array.

    Args:
        params: pytree whose leaves are arrays or Python scalars.

    Returns:
        (flat, tree, shapes): the concatenated 1-D array, the tree definition, and the
        shape of every leaf in flattening order (the inputs `unconcat_params` needs).
    """
    leaves, tree = jax.tree_util.tree_flatten(params) # "params" is flattened to a list of leaves
    # Python scalars (e.g. float geometry parameters) have no ".shape"; convert them first
    leaves = [np.asarray(leaf) for leaf in leaves]
    shapes = [leaf.shape for leaf in leaves] # shape of each array in the leaf list
    return np.concatenate([leaf.reshape(-1) for leaf in leaves]), tree, shapes # concat to single 1D array

def unconcat_params(params, tree, shapes): # unflatten the parameters
    """Rebuild the parameter pytree from a flat vector.

    Args:
        params: 1-D array holding all leaves back to back.
        tree: tree definition describing the pytree structure.
        shapes: shape of every leaf, in flattening order.
    """
    leaf_sizes = [onp.prod(shape, dtype=onp.int32) for shape in shapes]
    boundaries = onp.cumsum(leaf_sizes) # running totals mark where each leaf ends in the flat vector
    pieces = np.split(params, boundaries)
    leaves = []
    # zip stops at the shorter sequence, so the extra trailing piece from np.split is ignored
    for piece, shape in zip(pieces, shapes):
        leaves.append(piece.reshape(*shape)) # "*" unpacks the shape tuple into individual arguments
    return jax.tree_util.tree_unflatten(tree, leaves)

# Evaluation Points & Ground Truth Solutions

In [None]:
# Create evaluation points: a 201 x 201 grid on [0, 40] x [0, 40], flattened with 'ij' ordering
x_points = np.linspace(0, 40, 200 + 1)
y_points = np.linspace(0, 40, 200 + 1)
X, Y = (grid.flatten() for grid in np.meshgrid(x_points, y_points, indexing='ij'))
eval_points = np.stack((X, Y), axis=1) # one (x, y) row per grid node

# Load evaluation solutions (by Ground Truth FEM)
with open('data/eval_solutions.json', 'r') as f:
    eval_sol = json.load(f)

# Ground truth solution (200 x 200 cells)
# NOTE(review): assumed to hold one row per evaluation point, in the same order - confirm
u_true = np.array(eval_sol) # shape: (n_points, 2)

print("Evaluation points: ", eval_points.shape)

# Train PINN & Approximate Solution

In [None]:
# Containers for the results (one entry per architecture index)
y_results, domain_pts, times_adam, times_lbfgs, times_total, times_eval, l2_rel, var, arch\
    = dict({}), dict({}), dict({}), dict({}), dict({}), dict({}), dict({}), dict({}), dict({})


def ellipse(x_cen, y_cen, a, b, gamma, rotated_ang):
    """Points of the ellipse with centre (x_cen, y_cen), semi-axes a and b, rotated by gamma.

    With a 1-D `rotated_ang` of length n the result has shape (n, 2).
    """
    x = a * np.cos(rotated_ang) * np.cos(gamma) - b * np.sin(rotated_ang) * np.sin(gamma) + x_cen
    y = a * np.cos(rotated_ang) * np.sin(gamma) + b * np.sin(rotated_ang) * np.cos(gamma) + y_cen
    return np.array([x, y]).T


# Total loss functional (uses the global `model` through the residual functions)
def loss_total(params, boundary_samples, angle_samples1, angle_samples2, radius_samples):
    """Sum of the two PDE residual losses and the two Dirichlet losses.

    `boundary_samples` is (n, 2) in the unit square; the other samples are 1-D in [0, 1].
    """
    net_params = params['network']
    geo_params = params['geometry']
    x_cen, y_cen, a, b, gamma = geo_params # unpack geometric parameters

    # Define the domain & external boundary points
    lb = np.array([0., 0.]) # lower bound
    ub = np.array([40., 40.]) # upper bound
    boundary_points = lb + (ub - lb) * boundary_samples
    # Sampled points on the edge x = 0
    boundary_points1 = np.column_stack((np.zeros_like(boundary_points[:, 0]), boundary_points[:, 1]))
    # Sampled points on the edge x = 40. dirichlet_residual2 keeps only the y column,
    # so y must stay random here; pinning y = 40 would test the single point (40, 40).
    boundary_points2 = np.column_stack((40 * np.ones_like(boundary_points[:, 0]), boundary_points[:, 1]))

    # Define the boundary points of the void
    random_angle1 = 2 * np.pi * angle_samples1
    # Only needed by the Neumann term, which is currently commented out below
    void_boundary_points = ellipse(x_cen, y_cen, a, b, gamma, random_angle1)

    # Define the domain points
    # NOTE(review): these points fill the interior of the ellipse that the comments call
    # "the void" - confirm this is the intended PDE collocation region.
    random_angle2 = 2 * np.pi * angle_samples2
    random_radius = np.sqrt(radius_samples)
    domain_points = ellipse(x_cen, y_cen, a * random_radius, b * random_radius, gamma, random_angle2)

    # Define the loss function
    loss_pde1 = pde_residual1({'params': net_params}, domain_points) # parameters to be used in "model.apply" should be in the dict with 'params' key.
    loss_pde2 = pde_residual2({'params': net_params}, domain_points)
    loss_dirichlet1 = dirichlet_residual1({'params': net_params}, boundary_points1)
    loss_dirichlet2 = dirichlet_residual2({'params': net_params}, boundary_points2)
    # loss_neumann = neumann_residual(neumann_derivatives, net_params, void_boundary_points)
    return loss_pde1 + loss_pde2 + loss_dirichlet1 + loss_dirichlet2


# (Freeze geometric parameters)
def loss_total_frozen(params, boundary_samples, angle_samples1, angle_samples2, radius_samples):
    """Loss and flattened gradient, with the gradient w.r.t. 'geometry' replaced by zeros."""
    loss_val, grad = jax.value_and_grad(lambda x: loss_total(x, boundary_samples, angle_samples1,
                                                             angle_samples2, radius_samples))(params)
    # Create a PyTree with the same structure as 'geometry', but filled with zeros
    grad_frozen = jax.tree_util.tree_map(lambda g: np.zeros_like(g), grad['geometry'])
    # Overwrite the geometry gradients with frozen gradients
    grad['geometry'] = grad_frozen
    # Flatten the gradients
    flat_grad_frozen, _, _ = concat_params(grad)
    return loss_val, flat_grad_frozen


count = 0 # architecture index
for architecture in architecture_list :
    print('Architecture : %s' %architecture)
    times_adam_temp = [] # containers for 10 times training results
    times_lbfgs_temp = []
    times_total_temp = []
    times_eval_temp = []
    accuracy_temp = []
    for run_idx in range(10): # loop over 10 training runs
        # Seed per run: the runs differ from each other but the notebook is reproducible.
        # (A fixed PRNGKey(0) would give every run the same initial weights.)
        onp.random.seed(run_idx) # pyDOE's lhs draws from NumPy's global generator

        # Initialize Model
        model = PDESolution(architecture)
        init_key, key = jax.random.split(jax.random.PRNGKey(run_idx)) # independent keys for init and training
        batch_dim = 8 # dummy number(can be any value). Batch dimension will be reshaped after running "model.init".
        feature_dim = 2 # dimension of input point (x, y coord)
        # "flax.linen.Module.init(key, dummy_input)" triggers "__call__" (like ".apply") and infers the shape of weights & biases
        # where "key" is JAX-made PRNG key and "dummy_input" is the dummy input data.
        net_params = model.init(init_key, np.ones((batch_dim, feature_dim)))['params'] # net_params: weights & biases initialized randomly
        geo_params = (25., 30., 5.0, 2.0, float(onp.pi/3)) # x_cen, y_cen, a, b, gamma; "float(onp)" for preventing the trouble when jax object is inside the tuple.
        params: dict = {'network': net_params, 'geometry': geo_params}
        masks: dict = {'network': True, 'geometry': (False, False, False, False, False)}
        # False : freeze parameter (update will be zeroed out)
        # True  : train parameter (update comes from Adam)
        frozen_masks = jax.tree_util.tree_map(lambda trainable: not trainable, masks)

        # Initialize Adam Optimizer
        adam = optax.adam(learning_rate = lr)
        # optax.masked only restricts where Adam is applied; entries with a False mask keep
        # their raw gradient as the update. Zero those explicitly so they are really frozen.
        adam_masked = optax.chain(
            optax.masked(adam, masks),
            optax.masked(optax.set_to_zero(), frozen_masks),
        )
        opt_state = adam_masked.init(params) # opt_state : internal states of the optimizer

        # Start Training with Adam Optimizer
        start_time = time.time()
        losses, params, opt_state, key, loss_val = jax.block_until_ready(train_loop(params, adam_masked, opt_state, key))
        adam_time = time.time() - start_time
        times_adam_temp.append(adam_time)
        print('Adam Training Time : %f secs' %adam_time)

        # Generate random samples for L-BFGS optimization
        lbfgs_boundary_samples = lhs(2, 250)
        # lhs(1, n) has shape (n, 1); flatten so that `ellipse` returns (n, 2) point arrays
        lbfgs_angle_samples1 = lhs(1, 150).ravel()
        lbfgs_angle_samples2 = lhs(1, 2000).ravel()
        lbfgs_radius_samples = lhs(1, 2000).ravel()

        init_point, tree, shapes = concat_params(params)

        # L-BFGS Optimization
        print('Starting L-BFGS Optimization')
        start_time2 = time.time()
        results = tfp.optimizer.lbfgs_minimize(lambda x: loss_total_frozen(unconcat_params(x, tree, shapes),
                                                                           lbfgs_boundary_samples, lbfgs_angle_samples1,
                                                                           lbfgs_angle_samples2, lbfgs_radius_samples),
                                        init_point,
                                        max_iterations = 50000,
                                        num_correction_pairs = 50, # number of past updates to use for the approximation of the Hessian inverse.
                                        f_relative_tolerance = 1.0*np.finfo(float).eps) # stopping criterion
        lbfgs_time = time.time() - start_time2
        times_lbfgs_temp.append(lbfgs_time)
        times_total_temp.append(adam_time + lbfgs_time)
        print("L-BFGS training time : %.3f secs" % lbfgs_time)

        # Comparison to Ground Truth
        tuned_params = unconcat_params(results.position, tree, shapes)

        start_time3 = time.time()
        # Pass the "eval_points" to the trained model
        u_approx = jax.block_until_ready(model.apply({'params': tuned_params['network']}, np.stack((eval_points[:, 0], eval_points[:, 1]), axis=1)))
        eval_time = time.time() - start_time3
        times_eval_temp.append(eval_time)

        # Relative L2 error against the ground truth
        run_accuracy = (onp.linalg.norm(u_approx - u_true)) / onp.linalg.norm(u_true)
        accuracy_temp.append(run_accuracy)

        # Save domain points (last generated domain points)
        x_cen, y_cen, a, b, gamma = tuned_params['geometry']
        random_angle2 = 2 * np.pi * lbfgs_angle_samples2
        random_radius = np.sqrt(lbfgs_radius_samples)
        final_domain_points = ellipse(x_cen, y_cen, a * random_radius, b * random_radius, gamma, random_angle2)

    # Only the last of the 10 runs is stored for the solution and the domain points
    y_gt = u_true.tolist() # for storing into dict
    y_results[count] = u_approx.tolist()
    domain_pts[count] = final_domain_points.tolist()
    times_adam[count] = onp.mean(times_adam_temp) # mean times across the 10 runs
    times_lbfgs[count] = onp.mean(times_lbfgs_temp)
    times_total[count] = onp.mean(times_total_temp)
    times_eval[count] = onp.mean(times_eval_temp)
    l2_rel[count] = onp.mean(accuracy_temp).tolist()
    var[count] = onp.var(accuracy_temp).tolist() # variance of the error across the 10 runs
    arch[count] = architecture_list[count]
    count += 1

    # The JSON files are rewritten after every architecture, so partial results survive an interruption
    results = dict({'domain_pts': domain_pts,
                    'y_results': y_results,
                    'y_gt': y_gt})

    evaluation = dict({'arch': arch,
                       'times_adam': times_adam,
                       'times_lbfgs': times_lbfgs,
                       'times_total': times_total,
                       'times_eval': times_eval,
                       'l2_rel': l2_rel,
                       'var': var})

    sol_json = 'PINNs_results.json'
    with open(sol_json, "w") as f:
        json.dump(results, f)

    eval_json = 'PINNs_evaluation.json'
    with open(eval_json, "w") as f:
        json.dump(evaluation, f)

# Plots

In [None]:
with open('PINNs_results.json', 'r') as f:
    data_results = json.load(f)

with open('PINNs_evaluation.json', 'r') as f:
    data_eval = json.load(f)

# Slice results data
domain_pts = data_results['domain_pts']
y_results = data_results['y_results']
y_gt = data_results['y_gt']

arch = data_eval['arch']
times_adam = data_eval['times_adam']
times_lbfgs = data_eval['times_lbfgs']
times_total = data_eval['times_total']
times_eval = data_eval['times_eval']
l2_rel = data_eval['l2_rel']
var = data_eval['var']

# Pick one architecture (JSON object keys are strings, hence '8' rather than 8)
arch_idx = '8'
arch_indices = sorted([int(k) for k in data_eval['arch'].keys()])
arch_indices_str = [str(i) for i in arch_indices]
domain_points = onp.asarray(data_results['domain_pts'][arch_idx]).reshape(-1, 2)  # collocation points, shape (M, 2)
u_approx   = onp.asarray(data_results['y_results'][arch_idx])  # model output at the evaluation points
u_exact    = onp.asarray(data_results['y_gt'])                 # ground truth at the evaluation points

# The stored solutions live on the evaluation grid, not on the collocation points, so the
# plot coordinates are rebuilt with the same 201 x 201 grid on [0, 40]^2 ('ij' ordering).
grid_axis = onp.linspace(0, 40, 200 + 1)
X, Y = onp.meshgrid(grid_axis, grid_axis, indexing='ij')
X, Y = X.flatten(), Y.flatten()
assert u_approx.shape[0] == X.size and u_exact.shape[0] == X.size, \
    "stored solutions do not match the 201 x 201 evaluation grid"

# The solution is vector valued; plot a single component
# NOTE(review): component 0 is the output constrained by the Dirichlet losses - confirm it is the one to show
component = 0
z_approx = u_approx.reshape(X.size, -1)[:, component]
z_exact = u_exact.reshape(X.size, -1)[:, component]

# Approximate solution (left)
# (The earlier scatter call passed the solution as the marker-size argument and was
# immediately painted over by the filled contours, so it is dropped.)
fig = plt.figure(figsize=(12, 5))
ax1 = fig.add_subplot(1, 2, 1)
sc1 = ax1.tricontourf(
    X, Y, z_approx,
    levels=50,
    cmap='viridis'
)
ax1.set_title(f"PINN Approx. Solution (Arch index={arch_idx})")
ax1.set_xlabel("x")
ax1.set_ylabel("y")
fig.colorbar(sc1, ax=ax1, shrink=0.5)

# Exact solution (right)
ax2 = fig.add_subplot(1, 2, 2)
sc2 = ax2.tricontourf(
    X, Y, z_exact,
    levels=50,
    cmap='viridis'
)
ax2.set_title("Exact Solution")
ax2.set_xlabel("x")
ax2.set_ylabel("y")
fig.colorbar(sc2, ax=ax2, shrink=0.5)

plt.tight_layout()
plt.savefig("2D_Poisson_PINNs_approx_vs_exact.png", dpi=300, bbox_inches='tight')
plt.show()