In [4]:
"""
ENGD Optimization.
Two dimensional Poisson equation example. Solution given by

u(x,y) = sin(pi*x) * sin(py*y).

"""
import jax
import jax.numpy as jnp
from jax import random, grad, vmap, jit

from ngrad.models import init_params, mlp
from ngrad.domains import Square, SquareBoundary
from ngrad.integrators import DeterministicIntegrator
from ngrad.utility import laplace, grid_line_search_factory
from ngrad.inner import model_laplace, model_identity
from ngrad.gram import gram_factory, nat_grad_factory

jax.config.update("jax_enable_x64", True)

tau = 1.

# random seed
seed = 0

# domains
interior = Square(1.)
boundary = SquareBoundary(1.)

# integrators
interior_integrator = DeterministicIntegrator(interior, 30)
boundary_integrator = DeterministicIntegrator(boundary, 30)
eval_integrator = DeterministicIntegrator(interior, 200)

# model
activation = lambda x : jnp.tanh(x)
layer_sizes = [2, 32, 1]
params = init_params(layer_sizes, random.PRNGKey(seed))
model = mlp(activation)
# v_model = vmap(model, (None, 0))
v_model = vmap(lambda params, x: model(params, x), (None, 0))

# solution
@jit
def u_star(x):
    return jnp.prod(jnp.sin(jnp.pi * x))

# rhs
@jit
def f(x):
    return 2. * jnp.pi**2 * u_star(x)

# gramians
gram_bdry = gram_factory(
    model = model,
    trafo = model_identity,
    integrator = boundary_integrator
)

gram_laplace = gram_factory(
    model = model,
    trafo = model_laplace,
    integrator = interior_integrator
)

@jit
def gram(params):
    return gram_laplace(params) + gram_bdry(params)

# natural gradient
nat_grad = nat_grad_factory(gram)

# compute residual
laplace_model = lambda params: laplace(lambda x: model(params, x))
residual = lambda params, x: (laplace_model(params)(x) + f(x))**2.
v_residual =  jit(vmap(residual, (None, 0)))

# loss
@jit
def interior_loss(params):
    return interior_integrator(lambda x: v_residual(params, x))

@jit
def boundary_loss(params):
    return tau * boundary_integrator(lambda x: v_model(params, x)**2)

@jit
def loss(params):
    return interior_loss(params) + boundary_loss(params)

# set up grid line search
grid = jnp.linspace(0, 30, 31)
steps = 0.5**grid
ls_update = grid_line_search_factory(loss, steps)

# errors
error = lambda x: model(params, x) - u_star(x)
v_error = vmap(error, (0))
v_error_abs_grad = vmap(
        lambda x: jnp.dot(grad(error)(x), grad(error)(x))**0.5
        )

def l2_norm(f, integrator):
    return integrator(lambda x: (f(x))**2)**0.5    


iterations = 1000
save_freq = 10

import numpy as np
data = np.empty((iterations // save_freq + 1, 5))

# natural gradient descent with line search
alpha = 0.1
wb = 1.
for iteration in range(iterations + 1):
    interior_grads = grad(interior_loss)(params)
    interior_nat_grads = nat_grad(params, interior_grads)
    
    boundary_grads = grad(boundary_loss)(params)
    boundary_nat_grads = nat_grad(params, boundary_grads)
    
    updates = jax.tree_util.tree_map(
        lambda i, b: i + wb * b,
        interior_grads,
        boundary_grads,
    )
    params, actual_step = ls_update(params, updates)
    
    if iteration % save_freq == 0:
        # errors
        l2_error = l2_norm(v_error, eval_integrator)
        h1_error = l2_error + l2_norm(v_error_abs_grad, eval_integrator)
        
        data[iteration // save_freq, :] = [
            iteration,
            interior_loss(params),
            boundary_loss(params),
            l2_error,
            h1_error,
        ]
    
        print(
            f'ENGD Iteration: {iteration}'
            f'\n  with loss: {interior_loss(params)} + {boundary_loss(params)} = {loss(params)}'
            f'\n  with error L2: {l2_error} and error H1: {h1_error}'
            f'\n  with acual step: {actual_step}'
        )
        
    interior_grads_raveled, _ = jax.flatten_util.ravel_pytree(interior_grads)
    boundary_grads_raveled, _ = jax.flatten_util.ravel_pytree(boundary_grads)

    # update loss weights
    wb_hat = len(boundary_grads_raveled) * jnp.max(jnp.abs(interior_grads_raveled)) / (wb * jnp.sum(jnp.abs(boundary_grads_raveled)))
    wb = (1 - alpha) * wb + alpha * wb_hat
        
# jnp.save("data/engd.npy", data)


ENGD Iteration: 0
  with loss: 104.35925919813225 + 0.07558416209379887 = 104.43484336022605
  with error L2: 0.40265589073409175 and error H1: 2.624555747403158
  with acual step: 0.125
ENGD Iteration: 10
  with loss: 34.201900536079016 + 3.223910526887441 = 37.42581106296646
  with error L2: 0.8706637395832425 and error H1: 3.112001644713624
  with acual step: 0.00390625
ENGD Iteration: 20
  with loss: 16.050920153328434 + 0.3867210963871149 = 16.43764124971555
  with error L2: 0.2173899869321513 and error H1: 1.5117851780828009
  with acual step: 0.0009765625
ENGD Iteration: 30
  with loss: 14.490932268531752 + 0.28789844402766135 = 14.778830712559413
  with error L2: 0.16525799999294777 and error H1: 1.3557505354200734
  with acual step: 0.0019531249999999998
ENGD Iteration: 40
  with loss: 12.980381710768512 + 0.22569802377646683 = 13.206079734544979
  with error L2: 0.14545893889233724 and error H1: 1.2117288831908835
  with acual step: 0.00048828125
ENGD Iteration: 50
  with los

KeyboardInterrupt: 