<a href="https://colab.research.google.com/github/jqiu19/ResAscent01/blob/main/poisson_combination.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install comet_ml
!pip install equinox

from comet_ml import Experiment
import jax.numpy as jnp
import equinox as eqx
import optax
import numpy as np
from jax.nn import gelu, tanh
from jax.lax import scan, stop_gradient
from jax import random, jit, vmap, grad
import os
import scipy
#import argparse

#parser = argparse.ArgumentParser(description="PINN")
#parser.add_argument("--data",type=str, default='/', metavar="DIR", help="path to dataset")
#parser.add_argument("--ntrain",type=int, default=500,  help="the number of training dataset")
#parser.add_argument("--ite", type=int,default=10, help="the number of iteration")
#parser.add_argument("--epochs", type=int,default=50000,  help="the number of epochs")
#parser.add_argument("--seed", type=int,default=0,  help="the name")
#parser.add_argument("--device", type=int,default=0,  help="cuda number")
#args = parser.parse_args()

#os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device)


def projection(point, p='inf', space=((-0.5, 1.0), (-0.5, 1.5)), a=8, atol=1e-8):
    '''Project 2-D point(s) back into a rectangular domain.

    :param point: a single point of shape (2,) or a batch of shape (N, 2);
        column 0 is x and column 1 is y.
    :param p: projection norm.
        'inf' (or np.inf): clip every coordinate into the box ``space``, then
        nudge any coordinate that lands on a box edge a random distance inward.
        2: map points lying outside the unit disc centred at the box midpoint
        radially onto that disc; points inside are returned unchanged.
    :param space: ((x_min, x_max), (y_min, y_max)) box of the domain.
    :param a: exponent passed to ``np.random.power``; ``1 - power(a)`` is
        concentrated near 0 for large ``a``, so edge points move only slightly
        inward (the offset is scaled by the box side length).
    :param atol: absolute tolerance for deciding that a coordinate sits on an edge.
    :return: projected point(s), shaped like ``point`` ((2,) in, (2,) out).
    :raises ValueError: if ``p`` is not a supported norm.

    Note: the 'inf' branch draws from NumPy's *global* RNG, so results depend
    on ``np.random.seed``.
    '''
    pts = jnp.atleast_2d(jnp.asarray(point))
    if p == 'inf' or p == np.inf:
        n = pts.shape[0]
        columns = []
        for axis, (lo, hi) in enumerate(space):
            length = hi - lo
            # Positional bounds: the a_min/a_max keywords are deprecated in JAX.
            c = jnp.clip(pts[:, axis], lo, hi)
            # Two draws per axis (lower edge first, then upper edge), in the
            # same order as before, so seeded runs remain reproducible.
            push_up = (1 - np.random.power(a=a, size=(n,))) * length
            push_down = (1 - np.random.power(a=a, size=(n,))) * length
            # Edge tests use the clipped value, before any nudge is applied.
            c = c + push_up * jnp.isclose(c, lo, atol=atol) - push_down * jnp.isclose(c, hi, atol=atol)
            columns.append(c)
        projected = jnp.stack(columns, 1)
    elif p == 2:
        # Centre of the box, one entry per axis.
        mid = jnp.array([(lo + hi) / 2 for lo, hi in space])
        offset = pts - mid
        # Euclidean distance of each point (row) from the centre.
        radius = jnp.linalg.norm(offset, ord=2, axis=-1, keepdims=True)
        projected = jnp.where(radius > 1, offset / radius + mid, pts)
    else:
        raise ValueError(f"unsupported projection norm p={p!r}; expected 'inf' or 2")
    if jnp.ndim(point) == 1:
        projected = projected[0]
    return projected


def analytic_solution_generate(x, y):
    """Exact solution phi(x, y) = exp(-10 (x^2 + y^2)) of the 2-D Poisson problem.

    phi solves  -Laplacian(phi) = f  on [-1, 1] x [-1, 1] with

        f(x, y) = 40 * exp(-10 (x^2 + y^2)) * (1 - 10 x^2 - 10 y^2).

    Sign note: Laplacian(phi) = 40 * exp(...) * (-1 + 10 x^2 + 10 y^2), so
    f is the negative of that expression.  The boundary condition is
    Dirichlet: on the edges of the square, phi is prescribed by this same
    function.

    :param x: x-coordinate(s), scalar or array.
    :param y: y-coordinate(s), broadcastable against ``x``.
    :return: phi evaluated element-wise with ``np.exp``.
    """
    phi = np.exp(-10 * (x**2 + y**2))
    return phi


def generate_data(N_TRAIN, seed):
    """Build boundary, collocation and test data on the square [-1, 1] x [-1, 1].

    ``N_TRAIN`` and ``seed`` are accepted for interface compatibility but are
    not used: collocation points are the full 101 x 101 test grid.

    :return: tuple ``(x_train, y_train, xb_train, yb_train, ub_train,
        x_star, y_star, u_star)``.
        - x_train / y_train: (10201, 1) grid coordinates (same arrays as x_star / y_star).
        - xb_train / yb_train: (400, 1) float32 boundary coordinates, 100 per
          edge; each corner appears exactly once.
        - ub_train: analytic solution at the boundary points.
        - x_star / y_star / u_star: (10201, 1) test grid and analytic solution.
    """
    grid = np.linspace(-1.0, 1.0, 101)
    per_edge = 100
    minus_one = np.array([-1.0] * per_edge)
    plus_one = np.array([1.0] * per_edge)

    # Edges in order: left (x=-1), right (x=+1), bottom (y=-1), top (y=+1).
    edge_x = (minus_one, plus_one, grid[:-1], grid[1:])
    edge_y = (grid[1:], grid[:-1], minus_one, plus_one)
    xb_train = np.concatenate(edge_x, 0).astype("float32")[:, None]
    yb_train = np.concatenate(edge_y, 0).astype("float32")[:, None]
    ub_train = analytic_solution_generate(xb_train, yb_train)

    # Test grid, flattened to column vectors; also reused as collocation points.
    mesh_x, mesh_y = np.meshgrid(np.linspace(-1.0, 1.0, 101), np.linspace(-1.0, 1.0, 101))
    x_star = mesh_x.reshape(-1, 1)
    y_star = mesh_y.reshape(-1, 1)
    u_star = analytic_solution_generate(x_star, y_star)

    return (x_star, y_star, xb_train, yb_train, ub_train, x_star, y_star, u_star)


# the model
class PiNN(eqx.Module):
    """Fully connected network mapping a point (x, y) to ``N_features[-1]`` outputs.

    Widths are ``[N_features[0]] + [N_features[1]] * (N_layers - 1) + [N_features[-1]]``,
    giving ``N_layers`` affine layers with ``tanh`` between them and no
    activation after the last one.  Weights are drawn from a standard normal
    distribution and divided by ``sqrt((f_in + f_out) / 2)``; biases are drawn
    from a standard normal distribution.
    """
    matrices: list  # one (f_in, f_out) weight matrix per layer
    biases: list  # one (f_out,) bias vector per layer

    def __init__(self, N_features, N_layers, key):
        widths = [N_features[0]] + [N_features[1]] * (N_layers - 1) + [N_features[-1]]
        layer_shapes = list(zip(widths[:-1], widths[1:]))
        weight_keys = random.split(key, N_layers + 1)
        self.matrices = [
            random.normal(k, (fan_in, fan_out)) / jnp.sqrt((fan_in + fan_out) / 2)
            for (fan_in, fan_out), k in zip(layer_shapes, weight_keys)
        ]
        # Bias keys come from the one weight key left unused above.
        bias_keys = random.split(weight_keys[-1], N_layers)
        self.biases = [random.normal(k, (fan_out,)) for (_, fan_out), k in zip(layer_shapes, bias_keys)]

    def __call__(self, x, y, B):
        """Evaluate the network at (x, y).

        :param B: optional ``(lowb, upb)``; when given, each coordinate is
            rescaled from [lowb, upb] to [-1, 1] before the first layer.
        """
        h = jnp.stack([x, y])
        if B is not None:
            lowb, upb = B
            h = 2.0 * (h - lowb) / (upb - lowb) - 1.0
        h = h @ self.matrices[0] + self.biases[0]
        for W, b in zip(self.matrices[1:], self.biases[1:]):
            h = tanh(h) @ W + b
        return h


# get approximate solutions of different fidelity, exact solution and problem data
def get_trajectory(key, seed=None):
    """Train a PINN for the 2-D Poisson problem with residual-ascent sampling.

    The network is fit to  Laplacian(u) = 40 exp(-10 r^2) (-1 + 10 r^2)  on the
    box spanned by the boundary data from ``generate_data`` ([-1, 1] x [-1, 1]).
    The exact solution is u = exp(-10 r^2), and Dirichlet boundary samples are
    used.  Every ``100 * ite`` steps the first ``round(alpha(j) * N_train)``
    collocation points are:
    - moved one gradient-ascent step on the squared PDE residual,
    - projected back into the domain,
    - and rotated to the end of the point set.

    :param key: JAX PRNG key, split for network init and collocation sampling.
    :param seed: label used in the Comet run name.  When omitted, a
        module-level ``seed`` (set by the ``__main__`` block) is used if present.
    :return: None; progress is printed and, if Comet is configured, logged.
    """
    if seed is None:
        seed = globals().get("seed", "unknown")

    # SECURITY: credentials must not be stored in source; read the key from the
    # environment and skip Comet logging when it is absent.
    # NOTE(review): the key previously hard-coded here was published with the
    # notebook and should be revoked.
    api_key = os.environ.get("COMET_API_KEY")
    if api_key:
        experiment = Experiment(api_key=api_key, project_name="Pinn_RD" + 'seed')  # +str(seed)
    else:
        experiment = None
        print("COMET_API_KEY is not set; Comet logging is disabled.")
    keys = random.split(key, 3)  # keys[0]: network init, keys[1]: collocation sampling

    # Training schedule.
    N_train = 500  # number of collocation points
    N_epochs = 50000
    ite = 10  # total optimisation steps = ite * N_epochs

    (
        x_train,
        y_train,
        xb_train,
        yb_train,
        ub_train,
        x_star,
        y_star,
        u_star,
    ) = generate_data(N_train, 1234)

    ob_xy = jnp.concatenate([x_train, y_train], -1)
    Xb = np.concatenate([xb_train, yb_train], 1)
    lowb = Xb.min(0)  # per-coordinate minimum of the boundary samples
    upb = Xb.max(0)  # per-coordinate maximum of the boundary samples
    B = [lowb, upb]
    # ((x_min, x_max), (y_min, y_max)) of the computational domain.  Moved
    # points must be projected into THIS box; projection()'s default box
    # ((-0.5, 1.0), (-0.5, 1.5)) does not match this problem.
    domain = tuple((float(lo), float(hi)) for lo, hi in zip(lowb, upb))

    N_features = [2, 50, 1]
    N_layers = 5
    model = PiNN(N_features, N_layers, keys[0])

    # Adam; the learning rate is decayed by a factor gamma per N_drop steps.
    learning_rate = 1e-3
    N_drop = 10000
    gamma = 0.95
    sc = optax.exponential_decay(learning_rate, N_drop, gamma)
    optim = optax.adam(learning_rate=sc)
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    def u_net(model, x, y, B=None):
        """Scalar network prediction u(x, y)."""
        return model(x, y, B)[0]

    def residual(model, coordinates, B=None):
        """PDE residual  Laplacian(u) - 40 e^{-10 r^2} (-1 + 10 r^2)  at one (x, y)."""
        x, y = coordinates
        u_xx = grad(grad(u_net, argnums=1), argnums=1)(model, x, y, B)
        u_yy = grad(grad(u_net, argnums=2), argnums=2)(model, x, y, B)
        source = -40 * jnp.exp(-10 * (x**2 + y**2)) * (-1 + 10 * x**2 + 10 * y**2)
        return source + (u_xx + u_yy)

    def u_convenient(model, coordinates, B=None):
        """u_net taking a stacked (x, y) pair, for use with vmap."""
        x, y = coordinates
        return u_net(model, x, y, B)

    def residual_loss(model, x, y, B=None):
        """Squared residual at (x, y); differentiated w.r.t. x and y in update_x."""
        coordinates = jnp.stack([x, y]).T
        return residual(model, coordinates, B) ** 2

    def update_x(model, coordinates, weight=1, B=None):
        """One gradient-ASCENT step of a single point on the squared residual."""
        x, y = coordinates
        dx = grad(residual_loss, argnums=1)(model, x, y, B)
        dy = grad(residual_loss, argnums=2)(model, x, y, B)
        return jnp.stack([x + weight * dx, y + weight * dy])

    def compute_loss(model, coordinates, initial, uv, B=None):
        """Mean squared PDE residual plus mean squared boundary mismatch."""
        f_u = vmap(lambda c: residual(model, c, B))(coordinates)
        u = vmap(lambda c: u_convenient(model, c, B))(initial)
        return jnp.mean(jnp.square(f_u)) + jnp.mean(jnp.square(u - uv[:, 0]))

    compute_loss_and_grads = eqx.filter_value_and_grad(compute_loss)

    @eqx.filter_jit
    def make_step(model, coordinates, input_point0s, uv, B, optim, opt_state):
        loss, grads = compute_loss_and_grads(model, coordinates, input_point0s, uv, B)
        updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    def evaluate(model, step):
        """Print (and log, if enabled) the relative L2 error of u on the test grid."""
        u_pred = vmap(u_net, in_axes=(None, 0, 0, None))(model, x_star[:, 0], y_star[:, 0], B)
        error_u = jnp.linalg.norm(u_star.reshape(-1) - u_pred.reshape(-1)) / jnp.linalg.norm(u_star)
        print('u: %.3f' % error_u)
        if experiment is not None:
            experiment.log_metrics({'error_u': float(error_u)}, step=step)
        print('++++++++++++++++++++++++')

    width = 30000
    shift = 10

    def alpha(step):
        """Fraction of points moved: tanh ramp from 0 to 1, centred at width * shift."""
        return np.tanh(step / width - shift) / 2 + 0.5

    def weight(step):
        """Ascent step size: starts at 10 and halves about every 1005 steps."""
        return 10 * 2 ** (-0.995 * step / 1000)

    input_points = random.choice(keys[1], ob_xy, shape=(N_train,), replace=False)
    uv = jnp.concatenate([ub_train], -1)
    input_point0s = jnp.concatenate([xb_train, yb_train], -1)

    for j in range(ite * N_epochs):
        loss, model, opt_state = make_step(model, input_points, input_point0s, uv, B, optim, opt_state)
        if j % (100 * ite) == 0:
            idx = round(alpha(j) * N_train)
            if idx != 0:
                move_points = vmap(update_x, (None, 0, None, None))(model, input_points[:idx], weight(j), B)
                move_points = projection(move_points, space=domain)
                # Moved points go to the back so the next move picks fresh ones.
                input_points = stop_gradient(jnp.concatenate([input_points[idx:], move_points], 0))
        if experiment is not None and j % ite == 0:
            experiment.log_metrics({'loss': loss.item()}, step=j)
        if j % N_epochs == 0:
            evaluate(model, j)

    evaluate(model, ite * N_epochs)
    if experiment is not None:
        experiment.set_name('poisson_combination' + str(N_train) + '_' + str(ite * N_epochs) + '_' + str(seed))


if __name__ == "__main__":
    # Kept at module level on purpose: get_trajectory reads `seed` to name the run.
    seed = 1234
    # NumPy's global RNG drives the random inward nudges in projection().
    np.random.seed(seed)
    get_trajectory(random.PRNGKey(seed))

Collecting comet_ml
  Downloading comet_ml-3.38.1-py3-none-any.whl (611 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/611.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.8/611.7 kB[0m [31m4.8 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━[0m [32m542.7/611.7 kB[0m [31m7.8 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m611.7/611.7 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
Collecting python-box<7.0.0 (from comet_ml)
  Downloading python_box-6.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m14.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting requests-toolbelt>=0.8.0 (from comet_ml)
  Downloading requests_toolbelt-1.0.0-py2.py3-none-any.whl (54 kB)
[2K     [90m━━━━━━

[1;38;5;39mCOMET INFO:[0m Couldn't find a Git repository in '/content' nor in any parent directory. Set `COMET_GIT_DIRECTORY` if your Git Repository is elsewhere.
[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/jqiu19/pinn-rdseed/24cbdf353d994d6994e743568eaf9e6f



u: 7.569
++++++++++++++++++++++++
