In [35]:
from jax import grad, value_and_grad, jacfwd, jit, vmap
import jax.numpy as jnp
from jax.nn import softplus
from tqdm import tqdm
from jax.experimental.optimizers import adam
from jax import random
import matplotlib.pyplot as plt


In [44]:
def initialize_mlp(sizes, key):
    keys = random.split(key, len(sizes))

    def initialize_layer(m, n, key, scale=1e-2):
        w_key, b_key = random.split(key)
        return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

    return [initialize_layer(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]


def net(parameters, x):
    activations = x
    for w, b in parameters[:-1]:
        activations = softplus(jnp.dot(w, activations) + b)
    
    w, b = parameters[-1]
    u = jnp.sum(jnp.dot(w, activations) + b)
    return u

In [62]:
key = random.PRNGKey(0)
layer_sizes = [1, 32, 1]
params = initialize_mlp(layer_sizes,key)

init_fun, update_fun, get_params = adam(1e-4)
opt_state = init_fun(params)


def forward(params,x,a):
           
    def f(x):
        u, dudx = value_and_grad(net,argnums=1)(params,x)
        return dudx*a

    v, dvdx = value_and_grad(f)(x)
    u = net(params,x)

    return u, v, dvdx

@jit
def update(opt_state):

    params = get_params(opt_state)

    def loss_inlet(params):
        u, v ,dvdx = forward(params, 0.0, 1.0)
        return (v - 1.0)**2

    def loss_outlet(params):
        u, v ,dvdx = forward(params, 1.0, 1.0)
        return (u-0.0)**2

    def loss_interior(params):
        u, v ,dvdx = forward(params, 0.0, 1.0)
        return (dvdx-0.0)**2

    def loss(params):
        l1 = loss_inlet(params)
        l2 = loss_outlet(params)
        l3 = loss_interior(params)
        return l1+l2+l3

    # value_1, grads_1 = value_and_grad(loss_inlet)(params)
    # value_2, grads_2 = value_and_grad(loss_outlet)(params)
    # value_3, grads_3 = value_and_grad(loss_interior)(params)
    
    value, grads = value_and_grad(loss)(params)

    opt_state = update_fun(0, grads,opt_state)
    
    return value, opt_state


for i in tqdm(range(100000)):
    value, opt_state = update(opt_state)

params = get_params(opt_state)

net(params, 0.0)

100%|██████████| 100000/100000 [00:03<00:00, 25842.14it/s]


DeviceArray(-1.0000764, dtype=float32)

In [34]:
init_params, net = serial(
    Dense(1), 
    Softplus,
    Dense(32),
    Softplus,
    Dense(1)
)

key = random.PRNGKey(0)
output_shape, params = init_params(key,(1,))

init_fun, update_fun, get_params = adam(1e-4)
opt_state = init_fun(params)

grad(net)(params, jnp.asarray((0.0, 0,0)))


TypeError: Incompatible shapes for dot: got (3,) and (1, 1).