In [2]:
"""
ENGD Optimization.
Five dimensional Poisson equation example. Solution given by

u(x) = sum_{i=1}^5 sin(pi * x_i)

"""
import jax
import jax.numpy as jnp
from jax import random, grad, vmap, jit

from ngrad.domains import Hyperrectangle, HypercubeBoundary
from ngrad.models import mlp, init_params
from ngrad.integrators import EvolutionaryIntegrator
from ngrad.utility import laplace, grid_line_search_factory
from ngrad.inner import model_laplace, model_identity
from ngrad.gram import gram_factory, nat_grad_factory


jax.config.update("jax_enable_x64", True)

# random seed
seed = 0

# domains
dim = 3
interior = Hyperrectangle([(0., 1.) for _ in range(0, dim)])
boundary = HypercubeBoundary(dim)

# integrators
interior_integrator = EvolutionaryIntegrator(interior, key=random.PRNGKey(0), N=4000)
boundary_integrator = EvolutionaryIntegrator(boundary, key= random.PRNGKey(1), N=500)
eval_integrator = EvolutionaryIntegrator(interior, key=random.PRNGKey(0), N= 10 * 4000)

# model
activation = lambda x : jnp.tanh(x)
layer_sizes = [dim, 64, 1]
params = init_params(layer_sizes, random.PRNGKey(seed))
model = mlp(activation)
v_model = vmap(model, (None, 0))

# solution
@jit
def u_star(x):
    return (jnp.sum(jnp.sin(jnp.pi * x)))

v_u_star = vmap(u_star, (0))
v_grad_u_star = vmap(
    lambda x: jnp.dot(grad(u_star)(x), grad(u_star)(x))**0.5, (0)
    )

# rhs
@jit
def f(x):
    return jnp.pi**2 * u_star(x)

# gramians
gram_bdry = gram_factory(
    model = model,
    trafo = model_identity,
    integrator = boundary_integrator
)

gram_laplace = gram_factory(
    model = model,
    trafo = model_laplace,
    integrator = interior_integrator
)

@jit
def gram(params):
    return (gram_bdry(params) + gram_laplace(params))

# natural gradient
nat_grad = nat_grad_factory(gram)

# compute residual
laplace_model = lambda params: laplace(lambda x: model(params, x))
residual = lambda params, x: (laplace_model(params)(x) + f(x))**2.
v_residual =  jit(vmap(residual, (None, 0)))

# loss
@jit
def interior_loss(params):
    return 0.5 * interior_integrator(lambda x: v_residual(params, x))

@jit
def boundary_loss(params):
    return (
        0.5 * boundary_integrator(lambda x: (v_model(params, x) - v_u_star(x))**2.)
    )

@jit
def loss(params):
    return interior_loss(params) + boundary_loss(params)

# set up grid line search
grid = jnp.linspace(0, 30, 31)
steps = 0.5**grid
ls_update = grid_line_search_factory(loss, steps)

# errors
error = lambda x: model(params, x) - u_star(x)
v_error = vmap(error, (0))
v_error_abs_grad = vmap(
        lambda x: jnp.dot(grad(error)(x), grad(error)(x))**0.5
        )

def l2_norm(f, integrator):
    return integrator(lambda x: (f(x))**2)**0.5

norm_sol_l2 = l2_norm(v_u_star, eval_integrator)
norm_sol_h1 = norm_sol_l2 + l2_norm(v_grad_u_star, eval_integrator)    


# training loop
for iteration in range(201):
    interior_grads = grad(interior_loss)(params)
    interior_nat_grads = nat_grad(params, interior_grads)
    
    boundary_grads = grad(boundary_loss)(params)
    boundary_nat_grads = nat_grad(params, boundary_grads)
    
    updates = jax.tree_util.tree_map(lambda i, b: (i + b) / 2, interior_nat_grads, boundary_nat_grads)
    
    params, actual_step = ls_update(params, updates)

    if iteration % 10 == 0 and iteration > 0:
        l2_error = l2_norm(v_error, eval_integrator)
        h1_error = l2_error + l2_norm(v_error_abs_grad, eval_integrator)

        print(
            f'ENGD Iteration: {iteration}'
            f'\n  with loss: {interior_loss(params)} + {boundary_loss(params)} = {loss(params)}'
            f'\n  with relative errors L2: {l2_error/norm_sol_l2} and H1: {h1_error/norm_sol_h1}'
            f'\n  with step: {actual_step}'
        )

    # draw new points -- this can slow down the optimization
    if iteration % 1 == 0:
        interior_integrator.new_rand_points()
        boundary_integrator.new_rand_points()

ENGD Iteration: 10
  with loss: 54.68884812344515 + 21.610013825782787 = 76.29886194922793
  with relative error L2: 1.1223948956260013 and error H1: 1.1769769744652028
  with step: 0.125
ENGD Iteration: 20
  with loss: 0.9876317099300519 + 0.22495603033443942 = 1.2125877402644913
  with relative error L2: 0.1370635770226631 and error H1: 0.15297079350272477
  with step: 0.5
ENGD Iteration: 30
  with loss: 0.00023352589897040465 + 4.883026347165704e-05 = 0.0002823561624420617
  with relative error L2: 0.0016446045420251674 and error H1: 0.002677029323014301
  with step: 0.5
ENGD Iteration: 40
  with loss: 6.22047021627388e-07 + 1.484096927285855e-07 = 7.704567143559735e-07
  with relative error L2: 0.00010587898406792481 and error H1: 0.000126925306971915
  with step: 1.0
ENGD Iteration: 50
  with loss: 1.2978669541001043e-10 + 6.506143196306849e-12 = 1.362928386063173e-10
  with relative error L2: 3.812423750593745e-07 and error H1: 2.4856219170818905e-06
  with step: 1.0
ENGD Iterati

KeyboardInterrupt: 