In [None]:
# (for Google Colab)
!pip install pyDOE

In [None]:
import jax, flax, optax, time, pickle
import os
import jax.numpy as np
import numpy as onp
from functools import partial
from pyDOE import lhs
from typing import Sequence
import json
from tensorflow_probability.substrates import jax as tfp

In [None]:
# Run on the first GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from jax.extend.backend import get_backend
print(get_backend().platform)

# Hyperparameters

In [None]:
# NN architecture list: every combination of hidden width (20, 100, 500) and
# depth (3 to 7 hidden layers), ordered depth-major, each network ending in a
# single output unit. The 7-layer x 500-unit network is not included.
architecture_list = [
    [width] * depth + [1]
    for depth in range(3, 8)
    for width in (20, 100, 500)
    if (depth, width) != (7, 500)
]

lr = 1e-4          # learning rate
num_epochs = 50000  # number of training epochs
# PDE parameter eps: scales the diffusion term (eps * u_xx) and, as 1/eps,
# the reaction term in `residual`.
eps = 0.01

# NN Architecture

In [None]:
# Define NN architecture
class PDESolution(flax.linen.Module):
    """Fully connected network used as the approximate PDE solution.

    Every width in ``features`` except the last becomes a Dense layer followed
    by tanh; the last width is a linear output layer. For example
    ``[10, 20, 1]`` gives tanh hidden layers of 10 and 20 units and a single
    output unit.
    """
    # Dense-layer widths in order. flax.linen.Module subclasses are dataclasses,
    # so this is a constructor field: PDESolution(features=[10, 20, 1]).
    features: Sequence[int]

    @flax.linen.compact  # submodules are declared inline inside __call__
    def __call__(self, x):
        hidden_widths = self.features[:-1]
        output_width = self.features[-1]

        h = x
        for width in hidden_widths:
            # No kernel_init is passed, so Flax's default Dense initializer is used.
            h = flax.linen.tanh(flax.linen.Dense(width)(h))

        # Linear output layer with a Glorot-uniform kernel initializer.
        output_layer = flax.linen.Dense(
            output_width,
            kernel_init=flax.linen.initializers.glorot_uniform(),
        )
        return output_layer(h)

# Loss Function

In [None]:
# Hessian-vector product
# (a more general approach than nesting gradients, even though it makes no difference in this 1D problem)
def hvp(f, primals, tangents):
    """Hessian-vector product of ``f(x)[0]`` via forward-over-reverse AD.

    Args:
        f: callable whose output is indexable; only its first element
            ``f(x)[0]`` is differentiated, so that element must be a scalar.
        primals: tuple holding the point ``x`` at which the Hessian is taken.
        tangents: tuple holding the vector ``v`` (same structure as ``primals``).

    Returns:
        ``H(x) @ v``, the Hessian of ``f(x)[0]`` applied to ``v``. This is the
        forward-mode derivative of the reverse-mode gradient along ``v``.
    """
    def first_output(z):
        return f(z)[0]

    gradient_fn = jax.grad(first_output)
    _, hessian_times_v = jax.jvp(gradient_fn, primals, tangents)
    return hessian_times_v

# PDE residual
@partial(jax.vmap, in_axes = (None, 0, 0, None), out_axes = 0)
@partial(jax.jit, static_argnums = (0,))
def residual(u, t, x, eps):
    """Pointwise residual u_t - eps*u_xx + (1/eps)*2*u*(1-u)*(1-2*u).

    The reaction factor 2*u*(1-u)*(1-2*u) is the derivative of u**2*(1-u)**2.

    Args:
        u: callable ``u(t, x)``. It is static under ``jit``, so it must be
            hashable, and passing a new callable object causes a retrace. It is
            not mapped by ``vmap``.
        t: time coordinates, mapped along axis 0 by ``vmap``. Presumably 1-D so
            that each call sees a scalar t (TODO confirm against the caller).
        x: space coordinates, mapped along axis 0 like ``t``.
        eps: the PDE parameter eps, shared by all points (not mapped). It
            multiplies the diffusion term and its reciprocal scales the
            reaction term.

    Returns:
        The residual at every (t, x) pair, stacked along axis 0.
    """
    # A single forward-mode pass yields both u(t, x) (the primal output) and
    # u_t, the directional derivative along (1, 0). The tangents are built with
    # *_like so their dtype matches t and x exactly, as jax.jvp requires. Plain
    # Python floats break for float64 primals when x64 is enabled, and for
    # float16/bfloat16 primals.
    u_val, u_t = jax.jvp(u, (t, x), (np.ones_like(t), np.zeros_like(x)))
    # Second derivative with respect to the spatial argument (argnums=1).
    u_xx = jax.hessian(u, argnums = 1)(t, x)
    reaction = (1/eps)*2*u_val*(1-u_val)*(1-2*u_val)
    return u_t - eps*u_xx + reaction

# Initial condition
@partial(jax.vmap, in_axes=0)  # mapped over the leading axis of `xs`
def u_init(xs):
    """Initial-condition profile 0.25*(sin(2*pi*x) + 0.25*sin(16*pi*x)) + 0.5.

    Evaluated point by point through ``vmap``. Each point's value is wrapped in
    a length-1 array, so a batch of N points yields an (N, 1) result.
    """
    low_mode = np.sin(2*np.pi*xs)
    high_mode = 0.25*np.sin(16*np.pi*xs)
    return np.array([0.25*(low_mode + high_mode) + 0.5])

# Training Loop

# Helper Functions for L-BFGS Wrapper

# Train PINN & Approximate Solution