In [None]:
# (for Google Colab)
!pip install pyDOE

In [None]:
import jax, flax, optax, time, pickle
import os
import jax.numpy as np
import numpy as onp
from functools import partial
from pyDOE import lhs
from typing import Sequence
import json
from tensorflow_probability.substrates import jax as tfp

In [None]:
# Run on the first GPU: expose only CUDA device 0 to this process.
# The variable is set before the backend is queried below.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from jax.extend.backend import get_backend
# Print the platform of the backend JAX reports, as a quick sanity check.
print(get_backend().platform)

# Hyperparameters

In [None]:
# NN architecture list: every entry is a list of layer widths ending in an
# output width of 1. First the single-hidden-layer networks, then the two- and
# three-hidden-layer networks (equal width in each hidden layer).
architecture_list = (
    [[width, 1] for width in (1, 2, 5, 10, 20, 40)]
    + [[width] * depth + [1] for depth in (2, 3) for width in (5, 10, 20, 40)]
)
lr = 1e-4  # learning rate
num_epochs = 15_000  # number of training epochs

# NN Architecture

In [None]:
# Define NN architecture
class PDESolution(flax.linen.Module): # inherit from Module class
    """Fully connected network: tanh hidden layers followed by a linear output layer.

    Attributes:
        features: width of each Dense layer in order (e.g. [10, 20, 1]).
            All but the last entry are hidden layers with tanh activation;
            the last entry is the width of the final Dense layer.
    """
    features: Sequence[int] # dataclass field (e.g. [10, 20, 1])

    @flax.linen.compact # lets the Dense submodules be declared inline inside __call__
    def __call__(self, x):
        """Apply the network to `x` and return the output of the final Dense layer."""
        # Module fields are instance attributes, so they must be read through `self`.
        for feature in self.features[:-1]:
            x = flax.linen.tanh(flax.linen.Dense(feature)(x))
        # Final Dense layer: no activation, and its output is the return value.
        return flax.linen.Dense(self.features[-1])(x)

# Loss Function

In [None]:
# Hessian-vector product
# (a more general tool than a plain second derivative, although in this 1D problem the two coincide)
def hvp(f, primals, tangents):
    """Forward-over-reverse Hessian-vector product of the first output of `f`.

    Args:
        f: function whose result is indexed with `[0]` to obtain the scalar
            that is differentiated.
        primals: tuple of points at which to evaluate, as expected by `jax.jvp`.
        tangents: tuple of directions, same structure as `primals`.

    Returns:
        The tangent output of `jax.jvp` applied to the gradient of `x -> f(x)[0]`.
    """
    def first_output(x):
        return f(x)[0]

    _, tangent_out = jax.jvp(jax.grad(first_output), primals, tangents)
    return tangent_out

# PDE residual
@partial(jax.vmap, in_axes = (None, 0), out_axes = 0) # map over axis 0 of x, u is shared
@partial(jax.jit, static_argnums = (0,)) # applied first (closest to the function); u is a static argument
def residual(u, x):
    """Pointwise residual between the second derivative of `u` and the source term.

    Args:
        u: candidate solution; passed to `hvp`, which differentiates `u(x)[0]`.
        x: a single sample point (the vmap supplies one slice of axis 0 per call).

    Returns:
        `hvp(u, x, ones) - (-6x + 4x^3) * exp(-x^2)`.
    """
    direction = np.ones(x.shape) # the tangent argument of hvp is mandatory
    second_derivative = hvp(u, (x,), (direction,))
    source_term = np.exp(-x ** 2) * (-6*x + 4*x**3)
    return second_derivative - source_term

# Loss functionals
@jax.jit
def pde_residual(params, points):
    """Mean squared PDE residual of the network over a batch of points.

    Args:
        params: model parameters, forwarded to `model.apply`.
        points: sample points; `residual` is vmapped over their leading axis.

    Returns:
        Scalar mean of the squared pointwise residuals.
    """
    # `np.mean` needs an array argument; calling it with none raises TypeError.
    # NOTE(review): assumes the network is the module-level `model` (a PDESolution
    # instance created outside this cell) -- confirm the name against the training cell.
    return np.mean(residual(lambda x: model.apply(params, x), points) ** 2)

# Training Loop

In [None]:
# Define Training Step
def _latin_hypercube_1d(key, num_points):
    """Draw a 1D latin hypercube sample in [0, 1) with JAX's PRNG.

    The unit interval is split into `num_points` equal strata; one point is
    drawn uniformly inside each stratum and the strata are taken in random order.

    Returns:
        Array of shape (num_points, 1), the same layout as `lhs(1, num_points)`.
    """
    perm_key, jitter_key = jax.random.split(key)
    strata = jax.random.permutation(perm_key, num_points)
    jitter = jax.random.uniform(jitter_key, (num_points,))
    return ((strata + jitter) / num_points).reshape(-1, 1)


@partial(jax.jit, static_argnums = (1,))
def training_step(params, opt, opt_state, key):
    """
    Perform one optimisation step on freshly sampled points.

    Sampling uses `jax.random` driven by `key`. NumPy-based sampling inside a
    jitted function runs only while tracing, so its result would be frozen into
    the compiled step and every call would reuse the same points.

    Args:
        params: model parameters
        opt: optimizer (static argument of the jit)
        opt_state: optimizer state
        key: random key for sampling

    Returns:
        (params, opt_state, key, loss_val): updated parameters and optimizer
        state, the advanced random key to pass to the next step, and the loss
        evaluated before the update.
    """
    lb = np.array(0.) # lower bound
    ub = np.array(1.) # upper bound
    # Advance the key and derive independent sub-keys for the two samples.
    key, domain_key, boundary_key = jax.random.split(key, 3)
    domain_xs = lb + (ub - lb) * _latin_hypercube_1d(domain_key, 256) # 256 latin hypercube points
    boundary_xs = lb + (ub - lb) * _latin_hypercube_1d(boundary_key, 2) # scales the samples from [0, 1] to [lb, ub]

    loss_val, grad = jax.value_and_grad(lambda params: pde_residual(params, domain_xs) +
                                        boundary_residual0(params, boundary_xs) +
                                        boundary_residual1(params, boundary_xs))(params)
    update, opt_state = opt.update(grad, opt_state, params) # update using "grad"
    params = optax.apply_updates(params, update) # apply updates to "params"
    return params, opt_state, key, loss_val

# Training loop
def train_loop(params, adam, opt_state, key):
    """Call `training_step` `num_epochs` times, recording the loss of every step.

    Args:
        params: initial model parameters.
        adam: optimizer handed to `training_step`.
        opt_state: initial optimizer state.
        key: random key threaded through the steps.

    Returns:
        (losses, params, opt_state, key, loss_val): the list of per-step losses
        as Python scalars, followed by the values returned by the last step.
        `loss_val` is only bound if at least one step ran.
    """
    loss_history = []
    for _epoch in range(num_epochs): # the loop index itself is not needed
        step_result = training_step(params, adam, opt_state, key)
        params, opt_state, key, loss_val = step_result
        loss_history.append(loss_val.item()) # .item() converts to a Python scalar
    return loss_history, params, opt_state, key, loss_val # final values

# Helper Functions for L-BFGS Wrapper

In [None]:
# L-BFGS requires the parameters to be a single flattened array!
def concat_params(params): # flatten the parameters
    """Flatten a parameter pytree into one 1D vector.

    Args:
        params: pytree whose leaves are arrays (each must expose `.shape` and `.reshape`).

    Returns:
        (vector, tree, shapes): the NumPy concatenation of all flattened leaves,
        the tree definition describing the original nesting, and the list of
        leaf shapes in flattening order. The last two are what
        `unconcat_params` takes to rebuild the pytree.
    """
    leaves, tree = jax.tree_util.tree_flatten(params)
    shapes = []
    flat_chunks = []
    for leaf in leaves:
        shapes.append(leaf.shape)
        flat_chunks.append(leaf.reshape(-1)) # each leaf as a 1D array
    flat_vector = onp.concatenate(flat_chunks)
    return flat_vector, tree, shapes

def unconcat_params(params, tree, shapes): # unflatten the parameters
    """Rebuild a parameter pytree from a flat vector.

    Args:
        params: 1D array holding all leaves back to back.
        tree: tree definition describing the nesting of the leaves.
        shapes: shape of every leaf, in flattening order; scalar leaves
            (shape `()`) are supported.

    Returns:
        The pytree described by `tree`, with each leaf reshaped to its entry in `shapes`.
    """
    # Cast to int: onp.prod(()) is the float 1.0, and a float boundary would make
    # onp.split fail because slice indices must be integers.
    sizes = [int(onp.prod(shape)) for shape in shapes]
    # Cumulative sizes are the split boundaries. The last one equals the total,
    # so split returns one extra trailing chunk, which zip below discards.
    boundaries = onp.cumsum(sizes, dtype=int)
    chunks = onp.split(params, boundaries)
    # Pass the shape as a tuple: reshape(*shape) receives no argument for shape ().
    leaves = [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]
    return jax.tree_util.tree_unflatten(tree, leaves)

# Train PINN