In [None]:
import os
import time
from functools import partial
from pathlib import Path
from typing import Sequence

import jax
import jax.numpy as jnp
import optax
import numpy as np
import pandas as pd
import tqdm
from flax import linen as nn
from jax import jvp, value_and_grad

In [None]:
class CPPINN(nn.Module):
    """Separable PINN whose output is a rank-R CP (canonical polyadic) decomposition.

    Each of the five coordinate inputs is mapped by its own MLP (tanh hidden layers
    of widths ``features[:-1]`` followed by a linear layer of width
    ``R = features[-1]``).  The field on the tensor grid is

        u[a, b, c, d, e] = sum_r A[a, r] * B[b, r] * C[c, r] * D[d, r] * E[e, r]
    """

    # Layer widths of every per-axis MLP; the last entry is the CP rank R.
    features: Sequence[int]

    @nn.compact
    def __call__(self, a, b, c, d, e):
        """Evaluate the model on the grid spanned by the five coordinate arrays.

        Args:
            a, b, c, d, e: 2-D coordinate arrays of shape ``(N_i, in_dim)``
                (the transpose below requires the MLP output to be 2-D).

        Returns:
            Array of shape ``(N_a, N_b, N_c, N_d, N_e)``.
        """
        kernel_init = nn.initializers.xavier_uniform()

        def axis_factor(x):
            # Independent MLP for one axis, ending in a linear rank-R layer.
            for width in self.features[:-1]:
                x = nn.activation.tanh(nn.Dense(width, kernel_init=kernel_init)(x))
            x = nn.Dense(self.features[-1], kernel_init=kernel_init)(x)
            # (N_i, R) -> (R, N_i) so the rank index leads in the einsum.
            return jnp.transpose(x, (1, 0))

        factors = [axis_factor(x) for x in (a, b, c, d, e)]
        return jnp.einsum("ra,rb,rc,rd,re->abcde", *factors)


class TTPINN(nn.Module):
    """Separable PINN whose output is a tensor-train (TT) decomposition.

    Every coordinate input gets its own tanh MLP (hidden widths ``features[:-1]``).
    - The first and last axes end in a ``Dense`` of width ``R = features[-1]``, giving
      cores of shape ``(N, R)``.
    - The three interior axes end in a ``DenseGeneral`` producing an ``R x R`` matrix
      per point, giving cores of shape ``(N, R, R)``.

    The cores are chained as ``u[a,b,c,d,e] = A[a] @ B[b] @ C[c] @ D[d] @ E[e]``.
    """

    # Layer widths of every per-axis MLP; the last entry is the TT rank R.
    features: Sequence[int]

    @nn.compact
    def __call__(self, a, b, c, d, e):
        """Evaluate the model on the grid spanned by the five coordinate arrays.

        Returns:
            Array of shape ``(N_a, N_b, N_c, N_d, N_e)``.
        """
        kernel_init = nn.initializers.xavier_uniform()
        rank = self.features[-1]
        coords = (a, b, c, d, e)
        last = len(coords) - 1

        cores = []
        for position, x in enumerate(coords):
            for width in self.features[:-1]:
                x = nn.Dense(width, kernel_init=kernel_init)(x)
                x = nn.activation.tanh(x)
            if position in (0, last):
                # Boundary cores: one rank-R vector per point.
                x = nn.Dense(rank, kernel_init=kernel_init)(x)
            else:
                # Interior cores: one R x R matrix per point.
                x = nn.DenseGeneral((rank, rank), kernel_init=kernel_init)(x)
            cores.append(x)

        return jnp.einsum("a1,b12,c23,d34,e4->abcde", *cores)


class TuckerPINN(nn.Module):
    """Separable PINN whose output is a Tucker decomposition.

    A learned core tensor of shape ``(R,) * 5`` (``R = features[-1]``) is contracted
    with one factor matrix per axis.  Each factor matrix is produced by an independent
    tanh MLP applied to that axis' coordinates:

        u[a,b,c,d,e] = sum_{klmno} G[k,l,m,n,o] A[k,a] B[l,b] C[m,c] D[n,d] E[o,e]
    """

    # Layer widths of every per-axis MLP; the last entry is the Tucker rank R.
    features: Sequence[int]

    def setup(self):
        rank = self.features[-1]
        # Learned core G, orthogonally initialised, shape (R, R, R, R, R).
        self.core = self.param("core", nn.initializers.orthogonal(), (rank,) * 5)

    @nn.compact
    def __call__(self, a, b, c, d, e):
        """Evaluate the model on the grid spanned by the five coordinate arrays.

        Args:
            a, b, c, d, e: 2-D coordinate arrays of shape ``(N_i, in_dim)``.

        Returns:
            Array of shape ``(N_a, N_b, N_c, N_d, N_e)``.
        """
        kernel_init = nn.initializers.xavier_normal()

        def factor_matrix(x):
            # Per-axis MLP; the result is transposed to (R, N_i) for the einsum.
            for width in self.features[:-1]:
                x = nn.activation.tanh(nn.Dense(width, kernel_init=kernel_init)(x))
            x = nn.Dense(self.features[-1], kernel_init=kernel_init)(x)
            return jnp.transpose(x, (1, 0))

        factors = [factor_matrix(x) for x in (a, b, c, d, e)]
        return jnp.einsum("klmno,ka,lb,mc,nd,oe->abcde", self.core, *factors)

In [None]:
def hvp_fwdfwd(f, primals, tangents, return_primals=False):
    """Second directional derivative of ``f`` computed by forward-over-forward ``jvp``.

    The inner ``jvp`` gives the first directional derivative ``df[t]``.
    Differentiating that again along the same tangent gives ``d2f[t, t]``.

    Args:
        f: function of a single array argument.
        primals: 1-tuple ``(x,)`` with the evaluation point.
        tangents: 1-tuple ``(t,)`` with the direction (same shape/dtype as ``x``).
        return_primals: when True, also return the first directional derivative.

    Returns:
        ``d2f[t, t]``, or ``(df[t], d2f[t, t])`` if ``return_primals`` is True.
    """
    def first_directional(x):
        _, df_t = jvp(f, (x,), tangents)
        return df_t

    first, second = jvp(first_directional, primals, tangents)
    return (first, second) if return_primals else second

In [None]:
def loss_poisson(apply_fn, *train_data):
    """Build the PINN loss for the 5-D Poisson problem ``-laplace(u) = source``.

    Args:
        apply_fn: callable ``apply_fn(params, a, b, c, d, e)`` returning the model
            prediction on the tensor grid spanned by the five coordinate arrays.
        *train_data: the 12-tuple produced by ``train_generator``:
            ``(ac, bc, cc, dc, ec, source_term, ab, bb, cb, db, eb, ub)``, where
            ``ab..eb``/``ub`` are parallel lists with one entry per boundary face.

    Returns:
        ``loss_fn(params)``, which adds two terms:
        - the mean squared PDE residual on the collocation grid;
        - the sum over all boundary faces of the mean squared boundary mismatch.
    """

    def residual_loss(params, a, b, c, d, e, source):
        coords = (a, b, c, d, e)

        def second_derivative(axis):
            # d^2u/dx_axis^2 via forward-over-forward JVP with an all-ones tangent.
            # This equals the pointwise second derivative when each output point
            # depends on a single row of every coordinate array. The separable
            # models in this notebook have that property; apply_fn is assumed to.
            def u_along_axis(x):
                args = list(coords)
                args[axis] = x
                return apply_fn(params, *args)

            x = coords[axis]
            return hvp_fwdfwd(u_along_axis, (x,), (jnp.ones_like(x),))

        laplacian = sum(second_derivative(axis) for axis in range(len(coords)))
        return jnp.mean((laplacian + source) ** 2)

    def boundary_loss(params, a, b, c, d, e, u):
        # Accumulate over every boundary face; returning only after the loop
        # makes all faces contribute, not just the first one.
        loss = 0.0
        for a_f, b_f, c_f, d_f, e_f, u_f in zip(a, b, c, d, e, u):
            loss += jnp.mean((apply_fn(params, a_f, b_f, c_f, d_f, e_f) - u_f) ** 2)
        return loss

    ac, bc, cc, dc, ec, source_term, ab, bb, cb, db, eb, ub = train_data

    def loss_fn(params):
        return residual_loss(params, ac, bc, cc, dc, ec, source_term) + boundary_loss(params, ab, bb, cb, db, eb, ub)

    return loss_fn


# optimizer step function
@partial(jax.jit, static_argnums=(0,))
def update_model(optim, gradient, params, state):
    """Apply one optax optimizer step.

    Args:
        optim: optax gradient transformation; it is a static ``jit`` argument, so it
            must be hashable.
        gradient: gradient pytree matching ``params``.
        params: current parameter pytree.
        state: optimizer state from ``optim.init`` or a previous step.

    Returns:
        ``(new_params, new_state)``.
    """
    # Pass params so transformations that need them (e.g. weight decay in
    # optax.adamw) work; optax.adam ignores the argument.
    updates, state = optim.update(gradient, state, params)
    params = optax.apply_updates(params, updates)
    return params, state

In [None]:
def poisson_exact(a, b, c, d, e):
    """Exact solution ``u = sum_i sin(pi/2 * x_i)`` over the five coordinates.

    Evaluated elementwise, so the arguments may be scalars or broadcast-compatible
    arrays.
    """
    half_pi = jnp.pi / 2
    return sum(jnp.sin(half_pi * x) for x in (a, b, c, d, e))


def relative_l2(u, u_gt):
    """Relative L2 error ``||u - u_gt|| / ||u_gt||`` (Frobenius norm over all entries)."""
    error_norm = jnp.linalg.norm(u - u_gt)
    reference_norm = jnp.linalg.norm(u_gt)
    return error_norm / reference_norm

In [None]:
def train_generator(nc, key):
    """Sample collocation and boundary training data for the 5-D Poisson problem.

    Args:
        nc: number of uniform random coordinates in [0, 1) drawn per axis.
        key: JAX PRNG key.

    Returns:
        12-tuple ``(ac, bc, cc, dc, ec, source_term, ab, bb, cb, db, eb, ub)``:
        - ``ac..ec``: ``(nc, 1)`` collocation coordinates, one array per axis.
        - ``source_term``: ``(nc,) * 5`` grid of ``(pi^2/4) * sum_i sin(pi/2 * x_i)``.
        - ``ab..eb``: lists of 10 per-axis arrays, one entry per boundary face. The
          faces fix axis a..e to 0 and then to 1. On each face the fixed axis holds
          a ``(1, 1)`` array and the other axes reuse the collocation columns.
        - ``ub``: list of the 10 exact solutions on those face grids.
    """
    # Six keys are split although only five are used. Keep the count at 6: under
    # some JAX PRNG configurations it changes the key values, and hence the points.
    keys = jax.random.split(key, 6)
    samples = [jax.random.uniform(keys[i], (nc,), minval=0.0, maxval=1.0) for i in range(5)]

    grids = jnp.meshgrid(*samples, indexing="ij")
    source_term = sum((jnp.pi * jnp.pi / 4) * jnp.sin((jnp.pi / 2) * g) for g in grids)

    columns = [s.reshape(-1, 1) for s in samples]

    # Ten boundary faces: for each axis in order, pin it to 0 and then to 1.
    faces = []
    for axis in range(5):
        for bound in (0.0, 1.0):
            face = list(columns)
            face[axis] = jnp.array([[bound]])
            faces.append(face)
    ab, bb, cb, db, eb = (list(per_axis) for per_axis in zip(*faces))

    ub = [
        poisson_exact(*jnp.meshgrid(*(x.ravel() for x in face), indexing="ij"))
        for face in faces
    ]

    ac, bc, cc, dc, ec = columns
    return ac, bc, cc, dc, ec, source_term, ab, bb, cb, db, eb, ub


def test_generator(nc_test):
    a = jnp.linspace(0, 1, nc_test)
    b = jnp.linspace(0, 1, nc_test)
    c = jnp.linspace(0, 1, nc_test)
    d = jnp.linspace(0, 1, nc_test)
    e = jnp.linspace(0, 1, nc_test)
    # f = jnp.linspace(0,1,nc_test)

    am, bm, cm, dm, em = jnp.meshgrid(a, b, c, d, e, indexing="ij")

    u_gt = poisson_exact(am, bm, cm, dm, em)

    a = a.reshape(-1, 1)
    b = b.reshape(-1, 1)
    c = c.reshape(-1, 1)
    d = d.reshape(-1, 1)
    e = e.reshape(-1, 1)
    # f = f.reshape(-1,1)

    return a, b, c, d, e, am, bm, cm, dm, em, u_gt

In [None]:
def main(mode, NC, NI, NB, NC_TEST, SEED, LR, EPOCHS, N_LAYERS, FEATURES, LOG_ITER):
    """Train one separable PINN on the 5-D Poisson problem and return its log.

    Args:
        mode: model family, one of ``"CPPINN"``, ``"TTPINN"``, ``"TuckerPINN"``.
        NC: number of random collocation coordinates drawn per axis.
        NI, NB: accepted for interface compatibility; not used here.
        NC_TEST: evenly spaced test coordinates per axis (grid of NC_TEST**5 points).
        SEED: seed for ``jax.random.PRNGKey``.
        LR: Adam learning rate.
        EPOCHS: number of optimisation steps (must be >= 1).
        N_LAYERS: number of Dense layers per axis network (hidden + output).
        FEATURES: width of every layer, which is also the decomposition rank.
        LOG_ITER: log loss/error every LOG_ITER steps (and at step 1).

    Returns:
        List of ``[iteration, loss, relative_l2_error]`` rows. The loss and error
        are Python floats. A final row is appended after training.

    Raises:
        ValueError: if ``mode`` is unknown or ``EPOCHS < 1``.
    """
    if EPOCHS < 1:
        raise ValueError(f"EPOCHS must be >= 1, got {EPOCHS}")

    # force jax to use one device
    # NOTE(review): JAX reads these when it initialises its backend, so setting them
    # here likely has no effect after the first JAX computation in the process;
    # consider moving them to the first notebook cell.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

    # random key
    key = jax.random.PRNGKey(SEED)
    key, subkey = jax.random.split(key, 2)

    # every layer has FEATURES units; the last width is also the decomposition rank
    feat_sizes = tuple(FEATURES for _ in range(N_LAYERS))

    model_classes = {"CPPINN": CPPINN, "TTPINN": TTPINN, "TuckerPINN": TuckerPINN}
    if mode not in model_classes:
        raise ValueError(f"Unknown mode {mode!r}; expected one of {list(model_classes)}")
    model = model_classes[mode](feat_sizes)

    # Dummy (NC, 1) input used only to trace shapes for parameter initialisation.
    dummy = jax.random.uniform(key, (NC, 1))
    params = model.init(subkey, dummy, dummy, dummy, dummy, dummy)

    # optimizer
    optim = optax.adam(LR)
    state = optim.init(params)

    key, subkey = jax.random.split(key, 2)
    train_data = train_generator(NC, subkey)

    a, b, c, d, e, am, bm, cm, dm, em, u_gt = test_generator(NC_TEST)
    logger = []

    apply_fn = jax.jit(model.apply)
    loss_fn = loss_poisson(apply_fn, *train_data)

    @jax.jit
    def train_one_step(params, state):
        # compute loss and gradient, then apply one optimizer step
        loss, gradient = value_and_grad(loss_fn)(params)
        params, state = update_model(optim, gradient, params, state)
        return loss, params, state

    # Timing starts before the first step, so it includes JIT compilation.
    start = time.time()
    error = np.nan

    with tqdm.tqdm(total=EPOCHS) as pbar:
        for iters in range(1, EPOCHS + 1):
            loss, params, state = train_one_step(params, state)

            if iters % LOG_ITER == 0 or iters == 1:
                u = apply_fn(params, a, b, c, d, e)
                error = relative_l2(u, u_gt)
                # Plain floats let pandas build a float column, so to_csv's
                # float_format applies.
                logger.append([iters, float(loss), float(error)])

            pbar.set_postfix({"loss": f"{loss:0.8f}", "error": f"{error:0.8f}"}, refresh=False)
            pbar.update(1)

    # add one last log
    u = apply_fn(params, a, b, c, d, e)
    error = relative_l2(u, u_gt)
    logger.append([iters, float(loss), float(error)])

    end = time.time()
    print(f"Runtime: {((end-start)/EPOCHS*1000):.2f} ms/iter.")
    return logger

In [None]:
# Root folder collecting every per-run CSV written by the sweep below
# (relative to the notebook's working directory).
RESULTS_DIR_NAME = "results"
out_folder = Path(RESULTS_DIR_NAME)
out_folder.mkdir(exist_ok=True)

In [None]:
points = 24

# Settings shared by every run of the sweep; only model, rank and seed vary.
shared_settings = dict(NC=points, NI=points, NB=points, NC_TEST=32, LR=1e-3, EPOCHS=80000, N_LAYERS=4, LOG_ITER=5000)

for model in ["CPPINN", "TTPINN", "TuckerPINN"]:
    model_folder = out_folder / model
    model_folder.mkdir(exist_ok=True)

    for rank in [6, 8, 12]:
        model_folder_rank = model_folder / f"Rank_{rank:02d}"
        model_folder_rank.mkdir(exist_ok=True)

        for run in range(10):
            print(f"Running {model} with rank {rank} and run {run}")
            logs = main(mode=model, SEED=444444 + run, FEATURES=rank, **shared_settings)
            csv_name = f"{model}-Rank_{rank:02d}-Points_{points:02d}-run_{run:02d}.csv"
            out_file = model_folder_rank / csv_name
            history = pd.DataFrame(logs, columns=["Iter", "Loss", "Error"])
            history.to_csv(out_file, index=False, float_format="%.16f")