In [None]:
# (for Google Colab)
!pip install pyDOE

In [None]:
import jax, flax, optax, time, pickle
import os
import jax.numpy as np
import numpy as onp
from functools import partial
from pyDOE import lhs
from typing import Sequence
import json
from tensorflow_probability.substrates import jax as tfp

In [None]:
# Run on the first GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from jax.extend.backend import get_backend
print(get_backend().platform)

# Hyperparameters

In [None]:
# Candidate network architectures. Each entry lists Dense-layer widths and ends
# with the output width 1; every hidden layer within an entry has the same width.
architecture_list = (
    [[width, 1] for width in (1, 2, 5, 10, 20, 40)]        # 1 hidden layer
    + [[width] * 2 + [1] for width in (5, 10, 20, 40)]     # 2 hidden layers
    + [[width] * 3 + [1] for width in (5, 10, 20, 40)]     # 3 hidden layers
) # NN architecture list
lr = 1e-4 # learning rate
num_epochs = 15_000 # number of training epochs

# NN Architecture

In [None]:
# Define NN architecture
class PDESolution(flax.linen.Module): # inherit from Module class
    """Fully connected tanh network used as the trial solution u(x).

    Attributes:
        features: widths of the Dense layers in order. Every entry except the
            last is a hidden layer followed by tanh; the last entry is the width
            of the final (linear) Dense layer, e.g. ``[10, 20, 1]``.
    """
    features: Sequence[int] # dataclass field (e.g. [10, 20, 1])

    @flax.linen.compact # lets submodules be declared inline inside __call__
    def __call__(self, x): # __call__ makes instances callable like functions
        """Apply the network to ``x`` and return the output of the last Dense layer.

        Raises:
            ValueError: if ``features`` is empty (there is no output layer to build).
        """
        if not self.features:
            raise ValueError("PDESolution.features must contain at least one layer width")
        # Hidden layers: Dense followed by tanh. Module fields must be read via
        # ``self``; a bare ``features`` is an undefined global name.
        for feature in self.features[:-1]:
            x = flax.linen.tanh(flax.linen.Dense(feature)(x))
        # Final Dense layer (linear, no activation)
        x = flax.linen.Dense(self.features[-1])(x)
        return x

# Loss Function

In [None]:
# Hessian-vector product, computed forward-over-reverse. It is more general than a
# plain gradient, even though it gives nothing extra for this 1D problem.
def hvp(f, primals, tangents):
    """Return H(x) @ v for the scalar function ``x -> f(x)[0]``.

    Args:
        f: callable whose output is indexable; only ``f(x)[0]`` is differentiated.
        primals: tuple ``(x,)`` giving the evaluation point.
        tangents: tuple ``(v,)`` giving the direction, same structure as ``primals``.

    Returns:
        The tangent output of ``jax.jvp`` applied to ``jax.grad`` of the first
        component, i.e. the Hessian of ``f(x)[0]`` times ``v``.
    """
    def first_component(z):
        return f(z)[0]

    grad_fn = jax.grad(first_component)  # reverse mode: x -> d f(x)[0] / dx
    _, directional = jax.jvp(grad_fn, primals, tangents)  # forward mode along v
    return directional

# PDE residual r(x) = u''(x) - f(x) for the 1D problem u'' = f
@partial(jax.vmap, in_axes=(None, 0), out_axes=0)
@partial(jax.jit, static_argnums=(0,)) # decorator closest to the function is applied first
def residual(u, x):
    """Pointwise residual of ``u''(x) = (4x^3 - 6x) * exp(-x^2)``.

    Args:
        u: callable treated as a static argument by ``jit`` (so it must be
            hashable); ``hvp`` differentiates ``u(x)[0]``.
        x: collocation points; ``vmap`` maps over the leading axis, so each
            slice is handled separately.

    Returns:
        ``u''(x) - f(x)`` for every point, stacked along axis 0. The forcing term
        is the second derivative of ``x * exp(-x**2)``, so that function gives
        a zero residual.
    """
    # Right-hand side (source term) evaluated at this point.
    forcing = np.exp(-x ** 2) * (-6*x + 4*x**3)
    # With a single coordinate per point, H @ ones is exactly u''(x).
    ones_direction = np.ones(x.shape)
    u_xx = hvp(u, (x,), (ones_direction,))
    return u_xx - forcing

# Loss functionals
@jax.jit
def pde_residual(params, points):
    return np.mean()