In [None]:
import os
import time
from functools import partial
from pathlib import Path
from typing import Sequence

import jax
import jax.numpy as jnp
import optax
import numpy as np
import pandas as pd
import tqdm
from flax import linen as nn
from jax import jvp, value_and_grad

import matplotlib.pyplot as plt

In [None]:
def gaussian(alpha):
    """Gaussian radial basis function ``exp(-alpha**2)``, evaluated element-wise.

    Args:
        alpha: Scalar or array of (already scaled) distances.

    Returns:
        ``exp(-alpha**2)`` for every entry of ``alpha``.
    """
    squared = alpha**2
    return jnp.exp(-1 * squared)


def linear(alpha):
    """Identity basis function: hands ``alpha`` back unchanged."""
    phi = alpha
    return phi


# Radial-basis-function layer implemented as a Flax module
class RBF(nn.Module):
    """Layer of radial basis functions with learnable centres and widths.

    Each of the ``out_features`` basis functions owns one centre (parameter
    ``"centres"``, shape ``(out_features, 1)``) and one log-width (parameter
    ``"log_sigmas"``, shape ``(out_features,)``).
    """

    # number of basis functions
    out_features: int
    # callable applied to the width-scaled distances (e.g. ``gaussian``)
    basis_func: callable

    @nn.compact
    def __call__(self, input):
        """Evaluate every basis function at every input row.

        Args:
            input: Array whose rows are the points to evaluate; it gets a new
                axis inserted at position 1 and is broadcast against the
                centres (assumes a 2-D ``(N, 1)`` input -- TODO confirm).

        Returns:
            ``basis_func`` applied to the Euclidean distances between inputs
            and centres, divided by ``exp(log_sigmas)``.
        """
        n_basis = self.out_features

        # Learnable parameters (the names are part of the parameter tree)
        centres = self.param("centres", nn.initializers.normal(), (n_basis, 1))
        log_sigmas = self.param("log_sigmas", nn.initializers.constant(0.001), (n_basis,))

        # Pairwise differences: inputs on a new axis 1, centres on a new axis 0
        diff = jnp.expand_dims(input, axis=1) - jnp.expand_dims(centres, axis=0)
        euclidean = jnp.sqrt(jnp.sum(diff**2, axis=-1))

        # Widths are stored in log space, so exponentiate before scaling
        scaled = euclidean / jnp.exp(log_sigmas)

        return self.basis_func(scaled)

In [None]:
# separable network with a CP (rank-sum) combination of the per-axis factors
class CPPINN(nn.Module):
    """Separable PINN: one MLP per coordinate, merged as a sum over the rank index.

    Every coordinate is sent through its own stack of ``Dense`` + ``tanh``
    layers followed by a linear ``Dense`` layer of width ``features[-1]``.
    The three per-axis outputs are then contracted over that last width.
    """

    # layer widths; all but the last are hidden layers, the last is the rank
    features: Sequence[int]

    @nn.compact
    def __call__(self, x, y, z):
        """Evaluate the network on the grid spanned by ``x``, ``y`` and ``z``.

        Args:
            x, y, z: Per-axis coordinate arrays (assumed shape ``(N, 1)`` each
                -- TODO confirm against the caller).

        Returns:
            3-D array indexed ``[x, y, z]``.
        """
        kernel_init = nn.initializers.xavier_normal()
        hidden_sizes = self.features[:-1]
        rank = self.features[-1]

        factors = []
        for coord in (x, y, z):
            hidden = coord
            for width in hidden_sizes:
                hidden = nn.activation.tanh(nn.Dense(width, kernel_init=kernel_init)(hidden))
            hidden = nn.Dense(rank, kernel_init=kernel_init)(hidden)
            # rank index first, point index second
            factors.append(jnp.transpose(hidden, (1, 0)))

        fx, fy, fz = factors
        pair = jnp.einsum("fx, fy->fxy", fx, fy)
        return jnp.einsum("fxy, fz->xyz", pair, fz)


class TTPINN(nn.Module):
    """Separable PINN whose first axis emits a matrix-valued factor.

    Each coordinate goes through its own ``Dense`` + ``tanh`` stack.  The
    first coordinate ends in a ``DenseGeneral`` head with output features
    ``(r, r)`` (``r = features[-1]``); the other two end in a ``Dense`` head of
    width ``r``.  The three results are contracted with
    ``"xfk,yf,zk->xyz"``.
    """

    # layer widths; all but the last are hidden layers, the last is the rank
    features: Sequence[int]

    @nn.compact
    def __call__(self, x, y, z):
        """Evaluate the network on the grid spanned by ``x``, ``y`` and ``z``.

        Args:
            x, y, z: Per-axis coordinate arrays (assumed shape ``(N, 1)`` each
                -- TODO confirm against the caller).

        Returns:
            3-D array indexed ``[x, y, z]``.
        """
        kernel_init = nn.initializers.xavier_uniform()
        hidden_sizes = self.features[:-1]
        rank = self.features[-1]

        cores = []
        for position, coord in enumerate((x, y, z)):
            hidden = coord
            for width in hidden_sizes:
                hidden = nn.activation.tanh(nn.Dense(width, kernel_init=kernel_init)(hidden))
            if position == 0:
                # only the first axis carries two rank indices
                head = nn.DenseGeneral((rank, rank), kernel_init=kernel_init)
            else:
                head = nn.Dense(rank, kernel_init=kernel_init)
            cores.append(head(hidden))

        first, second, third = cores
        return jnp.einsum("xfk,yf,zk->xyz", first, second, third)


class TuckerPINN(nn.Module):
    """Separable PINN combined through a learnable ``(r, r, r)`` core tensor.

    Each coordinate goes through its own ``Dense`` + ``tanh`` stack ending in a
    ``Dense`` layer of width ``r = features[-1]``; the per-axis factors are
    contracted with the ``"core"`` parameter.

    NOTE(review): the module defines both ``setup`` and a compact ``__call__``;
    confirm the installed Flax version supports that combination.
    """

    # layer widths; all but the last are hidden layers, the last is the rank
    features: Sequence[int]

    def setup(self):
        """Create the core tensor with an orthogonal initializer."""
        rank = self.features[-1]
        self.core = self.param("core", nn.initializers.orthogonal(), (rank, rank, rank))

    @nn.compact
    def __call__(self, x, y, z):
        """Evaluate the network on the grid spanned by ``x``, ``y`` and ``z``.

        Args:
            x, y, z: Per-axis coordinate arrays (assumed shape ``(N, 1)`` each
                -- TODO confirm against the caller).

        Returns:
            3-D array indexed ``[x, y, z]``.
        """
        kernel_init = nn.initializers.xavier_normal()
        hidden_sizes = self.features[:-1]
        rank = self.features[-1]

        factors = []
        for coord in (x, y, z):
            hidden = coord
            for width in hidden_sizes:
                hidden = nn.activation.tanh(nn.Dense(width, kernel_init=kernel_init)(hidden))
            hidden = nn.Dense(rank, kernel_init=kernel_init)(hidden)
            # rank index first, point index second
            factors.append(jnp.transpose(hidden, (1, 0)))

        fx, fy, fz = factors
        return jnp.einsum("klm,kx,ly,mz->xyz", self.core, fx, fy, fz)


class RBFPINN(nn.Module):
    """Separable model built from fixed-centre radial basis functions.

    Every coordinate is expanded in radial basis functions located at fixed,
    evenly spaced centres; the only learnable parameter is ``"log_sigmas"``
    (the log-widths, shared by the three axes).  The per-axis expansions are
    multiplied together and summed over the basis index.
    """

    # Number of log-widths.  They divide distances of shape (N, 64), so this
    # has to broadcast against the 64 fixed centres defined below.
    out_features: int
    # callable applied to the width-scaled distances (e.g. ``gaussian``)
    basis_func: callable
    # Fixed centres, one column vector per axis.
    centers_x = jnp.linspace(-0.1, 10.1, 64).reshape((64, 1))
    centers_y = jnp.linspace(-1.1, 1.1, 64).reshape((64, 1))
    centers_z = jnp.linspace(-1.1, 1.1, 64).reshape((64, 1))
    all_centres = [centers_x, centers_y, centers_z]

    def setup(self):
        """Create the learnable log-widths, initialised to zero (width 1)."""
        self.log_sigmas = self.param("log_sigmas", nn.initializers.constant(0.0), (self.out_features,))

    @nn.compact
    def __call__(self, x, y, z):
        """Evaluate the model on the grid spanned by ``x``, ``y`` and ``z``.

        Args:
            x, y, z: Per-axis coordinate arrays (assumed shape ``(N, 1)`` each
                -- TODO confirm against the caller).

        Returns:
            3-D array indexed ``[x, y, z]``.
        """
        # Widths are shared by all three axes, so exponentiate once.
        widths = jnp.exp(self.log_sigmas)

        outputs = []
        for coords, centres in zip((x, y, z), self.all_centres):
            # Distance between every point (new axis 1) and every centre (new axis 0)
            pts = jnp.expand_dims(coords, axis=1)
            c = jnp.expand_dims(centres, axis=0)
            distances = jnp.sqrt(jnp.sum((pts - c) ** 2, axis=-1)) / widths
            # Apply radial basis function; basis index first, point index second
            basis = self.basis_func(distances)
            outputs.append(jnp.transpose(basis, (1, 0)))

        # Contract in axis order so the result is indexed [x, y, z].  Pairing
        # the x factor with the z factor first would return the grid with its
        # last two axes swapped.
        fx, fy, fz = outputs
        xy = jnp.einsum("fx, fy->fxy", fx, fy)
        return jnp.einsum("fxy, fz->xyz", xy, fz)

    def normalize(self, data, mean, std):
        """Return ``(data - mean) / std``."""
        normalized_data = (data - mean) / std
        return normalized_data


# hessian-vector product
def hvp_fwdfwd(f, primals, tangents, return_primals=False):
    """Second directional derivative of ``f`` via forward-over-forward ``jvp``.

    Args:
        f: Function of a single argument.
        primals: 1-tuple holding the point at which to differentiate.
        tangents: 1-tuple holding the direction, used for both derivatives.
        return_primals: If true, also return the output of the inner ``jvp``,
            i.e. the first directional derivative at ``primals``.

    Returns:
        The second directional derivative, or the pair
        ``(first_derivative, second_derivative)`` when ``return_primals`` is
        true.
    """

    def first_derivative(point):
        # Directional derivative of ``f`` at ``point`` along ``tangents``
        _, directional = jvp(f, (point,), tangents)
        return directional

    primals_out, tangents_out = jvp(first_derivative, primals, tangents)
    return (primals_out, tangents_out) if return_primals else tangents_out


# loss function
def spinn_loss_klein_gordon3d(apply_fn, *train_data):
    """Build the total training loss as a function of the parameters only.

    Args:
        apply_fn: Callable ``apply_fn(params, t, x, y)`` returning the
            prediction on the grid spanned by the three coordinate arrays.
        *train_data: The 12 items
            ``tc, xc, yc, uc, ti, xi, yi, ui, tb, xb, yb, ub``: collocation
            coordinates and source term, initial coordinates and target, and
            per-face lists of boundary coordinates and targets.

    Returns:
        A function ``fn(params)`` giving residual + initial + boundary loss.
    """

    def residual_loss(params, t, x, y, source_term):
        # calculate u
        u = apply_fn(params, t, x, y)
        # 2nd derivatives of u.  Each coordinate gets a unit tangent shaped
        # like itself, so the three axes may hold different numbers of points.
        utt = hvp_fwdfwd(lambda t_: apply_fn(params, t_, x, y), (t,), (jnp.ones_like(t),))
        uxx = hvp_fwdfwd(lambda x_: apply_fn(params, t, x_, y), (x,), (jnp.ones_like(x),))
        uyy = hvp_fwdfwd(lambda y_: apply_fn(params, t, x, y_), (y,), (jnp.ones_like(y),))
        return jnp.mean((utt - uxx - uyy + u**2 - source_term) ** 2)

    def initial_loss(params, t, x, y, u):
        return jnp.mean((apply_fn(params, t, x, y) - u) ** 2)

    def boundary_loss(params, t, x, y, u):
        # Average the mean-squared error over however many faces were supplied.
        n_faces = len(t)
        loss = 0.0
        for i in range(n_faces):
            loss += (1 / n_faces) * jnp.mean((apply_fn(params, t[i], x[i], y[i]) - u[i]) ** 2)
        return loss

    # unpack data
    tc, xc, yc, uc, ti, xi, yi, ui, tb, xb, yb, ub = train_data

    # isolate loss function from redundant arguments
    def fn(params):
        return (
            residual_loss(params, tc, xc, yc, uc)
            + initial_loss(params, ti, xi, yi, ui)
            + boundary_loss(params, tb, xb, yb, ub)
        )

    return fn


# optimizer step function
@partial(jax.jit, static_argnums=(0,))
def update_model(optim, gradient, params, state):
    """Apply one optimizer step.

    Args:
        optim: Optax gradient transformation (static under ``jit``).
        gradient: Gradient pytree matching ``params``.
        params: Current parameters.
        state: Current optimizer state.

    Returns:
        ``(new_params, new_state)``.
    """
    # Hand the current params to the transformation as well: optimizers that
    # depend on them (e.g. ones applying weight decay) need them, and the
    # others ignore the argument.
    updates, new_state = optim.update(gradient, state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_state

## 2. Data generator

In [None]:
# 2d time-dependent klein-gordon exact u
def _klein_gordon3d_exact_u(t, x, y):
    return (x + y) * jnp.cos(2 * t) + (x * y) * jnp.sin(2 * t)


# 2d time-dependent klein-gordon source term
def _klein_gordon3d_source_term(t, x, y):
    """Source term ``u**2 - 4 u`` evaluated on the exact solution ``u``."""
    exact = _klein_gordon3d_exact_u(t, x, y)
    return exact**2 - 4 * exact


# train data
def spinn_train_generator_klein_gordon3d(nc, key):
    """Sample per-axis training points and evaluate the targets on their grids.

    Args:
        nc: Number of points drawn per axis.
        key: JAX PRNG key; it is split into one key per axis.

    Returns:
        ``tc, xc, yc, uc, ti, xi, yi, ui, tb, xb, yb, ub`` where the ``*c``
        entries are collocation data (``uc`` is the source term), the ``*i``
        entries are the initial condition at ``t = 0`` and the ``*b`` entries
        are lists with one item per boundary face.
    """

    def _on_grid(fn, t_pts, x_pts, y_pts):
        # Evaluate ``fn`` on the full tensor grid of the flattened axes
        mesh = jnp.meshgrid(t_pts.ravel(), x_pts.ravel(), y_pts.ravel(), indexing="ij")
        return fn(*mesh)

    key_t, key_x, key_y = jax.random.split(key, 3)

    # collocation points
    tc = jax.random.uniform(key_t, (nc, 1), minval=0.0, maxval=10.0)
    xc = jax.random.uniform(key_x, (nc, 1), minval=-1.0, maxval=1.0)
    yc = jax.random.uniform(key_y, (nc, 1), minval=-1.0, maxval=1.0)
    uc = _on_grid(_klein_gordon3d_source_term, tc, xc, yc)

    # initial points: t fixed to zero, spatial samples shared with collocation
    ti = jnp.zeros((1, 1))
    xi, yi = xc, yc
    ui = _on_grid(_klein_gordon3d_exact_u, ti, xi, yi)

    # boundary points (hard-coded): x = -1, x = 1, y = -1, y = 1
    tb = [tc, tc, tc, tc]
    xb = [jnp.array([[-1.0]]), jnp.array([[1.0]]), xc, xc]
    yb = [yc, yc, jnp.array([[-1.0]]), jnp.array([[1.0]])]
    ub = [_on_grid(_klein_gordon3d_exact_u, t_f, x_f, y_f) for t_f, x_f, y_f in zip(tb, xb, yb)]

    return tc, xc, yc, uc, ti, xi, yi, ui, tb, xb, yb, ub


# test data
def spinn_test_generator_klein_gordon3d(nc_test):
    """Build an evenly spaced evaluation grid and the exact solution on it.

    Args:
        nc_test: Number of points per axis.

    Returns:
        ``t, x, y, u_gt, tm, xm, ym``: the per-axis coordinates as column
        vectors, the exact solution on the grid, and the three mesh arrays.
    """
    axis_ranges = ((0, 10), (-1, 1), (-1, 1))
    t, x, y = (jax.lax.stop_gradient(jnp.linspace(lo, hi, nc_test)) for lo, hi in axis_ranges)

    tm, xm, ym = jnp.meshgrid(t, x, y, indexing="ij")
    u_gt = _klein_gordon3d_exact_u(tm, xm, ym)

    # per-axis coordinates are returned as column vectors
    t, x, y = (axis.reshape(-1, 1) for axis in (t, x, y))
    return t, x, y, u_gt, tm, xm, ym

## 3. Utils

In [None]:
def relative_l2(u, u_gt):
    """Norm of the error ``u - u_gt`` divided by the norm of ``u_gt``."""
    error_norm = jnp.linalg.norm(u - u_gt)
    reference_norm = jnp.linalg.norm(u_gt)
    return error_norm / reference_norm


def plot_klein_gordon3d(t, x, y, u):
    """Show a 3-D scatter of the points ``(t, x, y)`` coloured by ``u``."""
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111, projection="3d")
    points = ax.scatter(t, x, y, c=u, s=0.5, cmap="viridis")

    ax.set_title("U(t, x, y)", fontsize=20)
    # the three axes share one label style
    for set_label, text in ((ax.set_xlabel, "t"), (ax.set_ylabel, "x"), (ax.set_zlabel, "y")):
        set_label(text, fontsize=18, labelpad=10)

    fig.colorbar(points, ax=ax)
    plt.show()

## 4. Main function

In [None]:
def main(mode, NC, NI, NB, NC_TEST, SEED, LR, EPOCHS, N_LAYERS, FEATURES, LOG_ITER):
    """Train one separable PINN on the Klein-Gordon problem and return its log.

    Args:
        mode: Model name: ``"CPPINN"``, ``"TTPINN"`` or ``"TuckerPINN"``.
        NC: Number of training points per axis.
        NI: Accepted for interface compatibility; not used in this function.
        NB: Accepted for interface compatibility; not used in this function.
        NC_TEST: Number of evaluation points per axis.
        SEED: Integer seed for the PRNG key.
        LR: Adam learning rate.
        EPOCHS: Number of optimisation steps.
        N_LAYERS: Number of layers in each per-axis network.
        FEATURES: Width of every layer (the last one is the rank).
        LOG_ITER: Evaluate and log every ``LOG_ITER`` steps (and at step 1).

    Returns:
        List of ``[iteration, loss, error]`` rows with plain Python numbers.

    Raises:
        ValueError: If ``mode`` is not one of the supported model names.
    """
    # force jax to use one device
    # NOTE(review): these variables are presumably only honoured when set
    # before JAX initialises its backend -- confirm they take effect here.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

    # random key
    key = jax.random.PRNGKey(SEED)
    key, subkey = jax.random.split(key, 2)

    # feature sizes
    feat_sizes = tuple(FEATURES for _ in range(N_LAYERS))

    # make & init model
    if mode == "CPPINN":
        model = CPPINN(feat_sizes)
    elif mode == "TTPINN":
        model = TTPINN(feat_sizes)
    elif mode == "TuckerPINN":
        model = TuckerPINN(feat_sizes)
    else:
        # Fail with a clear message instead of an unbound-variable error below.
        raise ValueError(f"Unknown mode {mode!r}; expected 'CPPINN', 'TTPINN' or 'TuckerPINN'.")

    # Initialisation only looks at the shape of the inputs, so a constant dummy
    # is enough and avoids drawing several samples from one PRNG key.
    dummy_input = jnp.ones((NC, 1))
    params = model.init(subkey, dummy_input, dummy_input, dummy_input)

    # optimizer
    optim = optax.adam(LR)
    state = optim.init(params)

    # dataset
    key, subkey = jax.random.split(key, 2)
    train_data = spinn_train_generator_klein_gordon3d(NC, subkey)
    t, x, y, u_gt, tm, xm, ym = spinn_test_generator_klein_gordon3d(NC_TEST)
    logger = []

    # forward & loss function
    apply_fn = jax.jit(model.apply)
    loss_fn = spinn_loss_klein_gordon3d(apply_fn, *train_data)

    @jax.jit
    def train_one_step(params, state):
        # compute loss and gradient
        loss, gradient = value_and_grad(loss_fn)(params)
        # update state
        params, state = update_model(optim, gradient, params, state)
        return loss, params, state

    start = time.time()
    error = np.nan

    # The context manager closes the bar even if a step raises.
    with tqdm.tqdm(total=EPOCHS) as pbar:
        for iters in range(1, EPOCHS + 1):
            # single run
            loss, params, state = train_one_step(params, state)

            if iters % LOG_ITER == 0 or iters == 1:
                u = apply_fn(params, t, x, y)
                error = float(relative_l2(u, u_gt))
                # Store Python floats rather than 0-d device arrays so the log
                # turns into numeric DataFrame columns.
                logger.append([iters, float(loss), error])

            pbar.set_postfix({"loss": f"{loss:0.8f}", "error": f"{error:0.8f}"}, refresh=False)
            pbar.update(1)

    end = time.time()
    # max(..., 1) keeps the report from dividing by zero when EPOCHS is 0
    print(f"Runtime: {((end-start)/max(EPOCHS, 1)*1000):.2f} ms/iter.")

    return logger

In [None]:
# Root directory for all result files; left untouched if it already exists.
out_folder = Path("results")
out_folder.mkdir(parents=False, exist_ok=True)

## 5. Run!

In [None]:
# Sweep: every model x every rank x 10 seeds, one CSV log per run.
points = 64

for model in ("CPPINN", "TTPINN", "TuckerPINN"):
    # results/<model>/
    model_folder = out_folder / model
    model_folder.mkdir(exist_ok=True)

    for rank in (8, 16, 32):
        # results/<model>/Rank_<rank>/
        model_folder_rank = model_folder / f"Rank_{rank:03d}"
        model_folder_rank.mkdir(exist_ok=True)

        for run in range(10):
            print(f"Running {model} with rank {rank} and run {run}")
            logs = main(
                mode=model,
                NC=points,
                NI=points,
                NB=points,
                NC_TEST=32,
                SEED=444444 + run,  # a different seed for every repetition
                LR=1e-3,
                EPOCHS=80000,
                N_LAYERS=4,
                FEATURES=rank,
                LOG_ITER=5000,
            )
            out_file = model_folder_rank / f"{model}-Rank_{rank:02d}-Points_{points:02d}-run_{run:02d}.csv"
            pd.DataFrame(logs, columns=["Iter", "Loss", "Error"]).to_csv(
                out_file,
                index=False,
                float_format="%.16f",
            )