<a href="https://colab.research.google.com/github/leungronwai/Google_colab/blob/main/Test/jax_01.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install equinox

import jax

import jax.numpy as jnp
import equinox as eqx
import optax

from jax import random, vmap, grad
import os

# Expose only the first CUDA device to this process.
# NOTE(review): `jax` is already imported above; this presumably only takes
# effect if no JAX computation has initialised the GPU backend yet in this
# runtime (e.g. re-running the cell later will not change the visible
# devices). Consider setting it before `import jax` — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def check_device(devices=None):
    """Report whether a GPU is among the JAX devices.

    Args:
        devices: Optional iterable of JAX devices to inspect. Defaults to
            ``jax.local_devices()``.

    Returns:
        ``"Running on GPU"`` if any device has a GPU platform, otherwise
        ``"Running on CPU"``.
    """
    if devices is None:
        devices = jax.local_devices()
    # Check the device's platform rather than str(device): newer JAX versions
    # may render a GPU as e.g. "cuda:0" / "cuda(id=0)", which contains no
    # "gpu" substring and would be misreported as CPU.
    gpu_platforms = {"gpu", "cuda", "rocm"}
    if any(getattr(device, "platform", None) in gpu_platforms for device in devices):
        return "Running on GPU"
    return "Running on CPU"


def load_training_data(path, ite):
    """Load one trajectory from an ``.npz`` archive as float32 arrays.

    Args:
        path: Path to an ``.npz`` file containing the keys ``'ics'``,
            ``'xs'``, ``'t'`` and ``'mus'``.
        ite: Index of the trajectory selected from ``'ics'``, ``'xs'`` and
            ``'mus'``. ``'t'`` is returned whole.

    Returns:
        Tuple ``(ic, t, exact_solution, mu)``, each cast to float32.
    """
    # Use the archive as a context manager so the underlying zip file is
    # closed once the arrays have been read (they stay valid afterwards).
    with jnp.load(path) as data:
        ic = data['ics'][ite]
        exact_solution = data['xs'][ite]
        t = data['t']
        mu = data['mus'][ite]

    return (ic.astype('float32'), t.astype('float32'),
            exact_solution.astype('float32'), mu.astype('float32'))


# the model
class PiNN(eqx.Module):
    """Fully connected tanh network evaluated at a single point (x, y, t).

    Layer widths are ``N_features[0]``, then ``N_features[1]`` for every
    hidden layer, then ``N_features[-1]``. Without Fourier features the input
    width ``N_features[0]`` must be 3. With a projection ``B`` it must equal
    the width of ``[cos(projection), sin(projection)]``.
    """

    matrices: list  # weight matrices, one (f_in, f_out) array per layer
    biases: list  # bias vectors, one (f_out,) array per layer

    def __init__(self, N_features, N_layers, key):
        """Randomly initialise all layers.

        Args:
            N_features: Sequence of widths. Element 0 is the input width,
                element 1 the hidden width and element -1 the output width.
            N_layers: Number of affine layers (``N_layers - 1`` hidden layers).
            key: JAX PRNG key.
        """
        keys = random.split(key, N_layers + 1)
        features = [N_features[0]] + [N_features[1]] * (N_layers - 1) + [N_features[-1]]
        layer_shapes = list(zip(features[:-1], features[1:]))
        # Glorot-normal scaling: std = 1 / sqrt((fan_in + fan_out) / 2).
        self.matrices = [
            random.normal(k, (f_in, f_out)) / jnp.sqrt((f_in + f_out) / 2)
            for (f_in, f_out), k in zip(layer_shapes, keys)
        ]
        # Biases use keys split off the last key, which the weights do not use.
        bias_keys = random.split(keys[-1], N_layers)
        self.biases = [
            random.normal(k, (f_out,))
            for (_, f_out), k in zip(layer_shapes, bias_keys)
        ]

    def __call__(self, x, y, t, B):
        """Evaluate the network at one point.

        Args:
            x, y, t: Scalar coordinates (callers vmap over points).
            B: ``None`` to feed ``(x, y, t)`` directly, or a Fourier
                projection. A 2-D ``B`` of shape ``(3, m)`` gives the features
                ``[cos(v @ B), sin(v @ B)]`` (width ``2 * m``). A scalar or
                1-D ``B`` uses the element-wise product ``B * v``, which must
                broadcast against the length-3 vector ``v = (x, y, t)``.

        Returns:
            Output vector of width ``N_features[-1]``.
        """
        v = jnp.stack([x, y, t])
        if B is not None:
            B = jnp.asarray(B)
            # A (3, m) matrix projects the 3 inputs onto m frequencies. The
            # element-wise form only broadcasts when B has length 1 or 3.
            proj = v @ B if B.ndim == 2 else B * v
            v = jnp.concatenate([jnp.cos(proj), jnp.sin(proj)], 0)
        f = v @ self.matrices[0] + self.biases[0]
        for W, b in zip(self.matrices[1:], self.biases[1:]):
            f = jnp.tanh(f) @ W + b
        return f


# Train a PINN on the Navier-Stokes data in tg_2d.npz, print test errors, and save the collocation-point / residual history
def get_trajectory(key):
    """Train a PINN on the 2-D Navier-Stokes data in tg_2d.npz and evaluate it.

    The loss has two parts: the momentum/continuity residuals at a small set
    of collocation points, and a fit of (u, v) to samples from the first time
    slice. After every optimiser step, each collocation point is moved along
    the gradient of its squared residual. After training, the relative L2
    errors of u, v and p on the last stored time slice are printed. The
    history of collocation points and evaluation residuals is saved to
    ``residual.npz``.

    Args:
        key: JAX PRNG key. It drives weight initialisation and all sampling.
    """
    keys = random.split(key, 2)

    # Load data. The archive is used as a context manager so the underlying
    # zip file is closed once every array has been read.
    with jnp.load(r"/content/drive/MyDrive/tg_2d.npz") as data:
        XX = data['x']  # N x T
        YY = data['y']  # N x T
        TT = data['t']  # time stamps; tiled to N x T below
        UU = data['u']  # N x T
        VV = data['v']  # N x T
        PP = data['p']  # N x T
    Re = 40
    # NOTE(review): coefficient of the viscous term in `residual`; confirm
    # that 2*pi/Re (rather than 1/Re) is the intended scaling for this data.
    mu = 2 * jnp.pi / Re
    N = XX.shape[0]
    TT = jnp.tile(TT, (N, 1))

    # Initial-condition samples: columns (x, y, t, u, v) at the first time slice.
    x0 = XX[:, 0:1]
    y0 = YY[:, 0:1]
    t0 = TT[:, 0:1]
    u0 = UU[:, 0:1]
    v0 = VV[:, 0:1]
    ob_0 = jnp.concatenate([x0, y0, t0, u0, v0], -1)

    N_train = 100  # number of adaptive collocation points
    N0_train = 10000  # initial-condition samples drawn per epoch

    N_epochs = 5000
    ite = 1  # number of outer training rounds

    # Network parameters.
    N_fourier_features = 25
    # B = random.normal(random.split(keys[0], 3)[-1], (N_fourier_features,)) * 10
    B = None
    N_features = [2 * N_fourier_features, 128, 3] if B is not None else [3, 128, 3]
    N_layers = 4

    model = PiNN(N_features, N_layers, keys[0])

    # Optimiser: Lion with learning rate lr * gamma ** (step / N_drop).
    learning_rate = 1e-3
    N_drop = 10000
    gamma = 0.95
    sc = optax.exponential_decay(learning_rate, N_drop, gamma)
    optim = optax.lion(learning_rate=sc)
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    def u_net(model, x, y, t, B=None):
        """Return output 0 (u) of `model` at a single point."""
        u = model(x, y, t, B)[0]
        return u

    def v_net(model, x, y, t, B=None):
        """Return output 1 (v) of `model` at a single point."""
        v = model(x, y, t, B)[1]
        return v

    def p_net(model, x, y, t, B=None):
        """Return output 2 (p) of `model` at a single point."""
        p = model(x, y, t, B)[2]
        return p

    def residual(model, coordinates, B=None):
        """Navier-Stokes residuals at one point ``coordinates = (x, y, t)``.

        f_u = u_t + u u_x + v u_y + p_x - mu (u_xx + u_yy)
        f_v = v_t + u v_x + v v_y + p_y - mu (v_xx + v_yy)
        f_e = u_x + v_y
        """
        x, y, t = coordinates
        u = u_net(model, x, y, t, B)
        v = v_net(model, x, y, t, B)
        p_x = grad(p_net, argnums=1)(model, x, y, t, B)
        p_y = grad(p_net, argnums=2)(model, x, y, t, B)
        u_x = grad(u_net, argnums=1)(model, x, y, t, B)
        v_x = grad(v_net, argnums=1)(model, x, y, t, B)
        u_y = grad(u_net, argnums=2)(model, x, y, t, B)
        v_y = grad(v_net, argnums=2)(model, x, y, t, B)
        u_t = grad(u_net, argnums=3)(model, x, y, t, B)
        v_t = grad(v_net, argnums=3)(model, x, y, t, B)
        u_xx = grad(grad(u_net, argnums=1), argnums=1)(model, x, y, t, B)
        u_yy = grad(grad(u_net, argnums=2), argnums=2)(model, x, y, t, B)
        v_xx = grad(grad(v_net, argnums=1), argnums=1)(model, x, y, t, B)
        v_yy = grad(grad(v_net, argnums=2), argnums=2)(model, x, y, t, B)
        f_u = u_t + (u * u_x + v * u_y) + p_x - mu * (u_xx + u_yy)
        f_v = v_t + (u * v_x + v * v_y) + p_y - mu * (v_xx + v_yy)
        f_e = u_x + v_y
        return f_u, f_v, f_e

    def u_v_p(model, coordinates, B=None):
        """Return (u, v, p) at one point ``coordinates = (x, y, t)``."""
        x, y, t = coordinates
        u = u_net(model, x, y, t, B)
        v = v_net(model, x, y, t, B)
        p = p_net(model, x, y, t, B)
        return u, v, p

    def residual_loss(model, x, y, t, B=None):
        """Squared residual f_u**2 + f_v**2 + f_e**2 at a single point."""
        coordinates = jnp.stack([x, y, t]).T
        f_u, f_v, f_e = residual(model, coordinates, B)
        return f_u ** 2 + f_v ** 2 + f_e ** 2

    @eqx.filter_jit
    def residual_point(model, coordinates, B=None):
        """Squared residual at every row of ``coordinates`` (rows are (x, y, t))."""
        f_u, f_v, f_e = vmap(residual, (None, 0, None))(model, coordinates, B)
        return (f_u ** 2) + (f_v ** 2) + (f_e ** 2)

    def update_x(model, coordinates, weight=1, B=None):
        """Take an ascent step of size ``weight`` on the squared residual in (x, y); t is kept."""
        x, y, t = coordinates
        dx = grad(residual_loss, argnums=1)(model, x, y, t, B)
        dy = grad(residual_loss, argnums=2)(model, x, y, t, B)
        x = x + weight * dx
        y = y + weight * dy
        return jnp.stack([x, y, t]).T

    @eqx.filter_jit
    def move_points(model, coordinates, B=None):
        """Apply `update_x` with step size 10 to every row of ``coordinates``."""
        return vmap(update_x, (None, 0, None, None))(model, coordinates, 10, B)

    def compute_loss(model, coordinates, initial, uv, B=None):
        """Mean squared residuals at ``coordinates`` plus MSE of (u, v) at ``initial``."""
        fun = lambda x: residual(model, x, B)
        f_u, f_v, f_e = vmap(fun)(coordinates)
        fun0 = lambda x: u_v_p(model, x, B)
        u, v, _ = vmap(fun0)(initial)
        return (f_u ** 2).mean() + (f_v ** 2).mean() + (f_e ** 2).mean() + \
            ((u - uv[:, 0]) ** 2).mean() + ((v - uv[:, 1]) ** 2).mean()

    compute_loss_and_grads = eqx.filter_value_and_grad(compute_loss)

    @eqx.filter_jit
    def make_step(model, coordinates, input_point0s, uv, B, optim, opt_state):
        """One optimiser step on `compute_loss`; returns (loss, model, opt_state)."""
        loss, grads = compute_loss_and_grads(model, coordinates, input_point0s, uv, B)
        updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    # Evaluation data: the last stored time slice.
    T = -1
    x_star = XX[:, T]
    y_star = YY[:, T]
    t_star = TT[:, T]

    u_star = UU[:, T]
    v_star = VV[:, T]
    p_star = PP[:, T]
    eval_xyt = jnp.stack([x_star, y_star, t_star]).T  # (N, 3), reused every epoch

    # Separate keys for the initial collocation draw and for the per-epoch
    # keys; a JAX PRNG key must not be consumed twice.
    key, choice_key = random.split(keys[1])
    input_points = random.choice(choice_key, eval_xyt, shape=(N_train,))

    points = []
    R = []
    for _ in range(ite):
        keys = random.split(key, N_epochs + 1)
        for epoch in range(N_epochs):
            print(f'epoch={epoch}')
            sample0 = random.choice(keys[epoch], ob_0, shape=(N0_train,))
            uv = sample0[:, 3:5]
            input_point0s = sample0[:, :3]
            loss, model, opt_state = make_step(model, input_points, input_point0s, uv, B, optim, opt_state)
            input_points = move_points(model, input_points, B)
            points.append(input_points)
            R.append(residual_point(model, eval_xyt, B))
        key = keys[-1]

        u_pred = vmap(u_net, in_axes=(None, 0, 0, 0, None))(model, x_star, y_star, t_star, B)
        v_pred = vmap(v_net, in_axes=(None, 0, 0, 0, None))(model, x_star, y_star, t_star, B)
        p_pred = vmap(p_net, in_axes=(None, 0, 0, 0, None))(model, x_star, y_star, t_star, B)

        error_u = jnp.linalg.norm(u_star - u_pred) / jnp.linalg.norm(u_star)
        print('u: %.3f' % error_u)
        error_v = jnp.linalg.norm(v_star - v_pred) / jnp.linalg.norm(v_star)
        print('v: %.3f' % error_v)
        error_p = jnp.linalg.norm(p_star - p_pred) / jnp.linalg.norm(p_star)
        print('p: %.3f' % error_p)
        print('++++++++++++++++++++++++')
    points = jnp.stack(points)
    R = jnp.stack(R)
    jnp.savez('residual.npz', points=points, R=R)


if __name__ == "__main__":
    # A single fixed seed drives weight initialisation and every sampling
    # step inside get_trajectory, so each run starts from the same state.
    seed = 1234
    get_trajectory(key := random.PRNGKey(seed))

Collecting equinox
  Downloading equinox-0.11.2-py3-none-any.whl (164 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m164.1/164.1 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
Collecting jaxtyping>=0.2.20 (from equinox)
  Downloading jaxtyping-0.2.24-py3-none-any.whl (38 kB)
Collecting typeguard<3,>=2.13.3 (from jaxtyping>=0.2.20->equinox)
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Installing collected packages: typeguard, jaxtyping, equinox
Successfully installed equinox-0.11.2 jaxtyping-0.2.24 typeguard-2.13.3
epoch=0
epoch=1
epoch=2
epoch=3
epoch=4
epoch=5
epoch=6
epoch=7
epoch=8
epoch=9
epoch=10
epoch=11
epoch=12
epoch=13
epoch=14
epoch=15
epoch=16
epoch=17
epoch=18
epoch=19
epoch=20
epoch=21
epoch=22
epoch=23
epoch=24
epoch=25
epoch=26
epoch=27
epoch=28
epoch=29
epoch=30
epoch=31
epoch=32
epoch=33
epoch=34
epoch=35
epoch=36
epoch=37
epoch=38
epoch=39
epoch=40
epoch=41
epoch=42
epoch=43
epoch=44
epoch=45
epoch=46
epoch=47
epoch=48
epoch=49
epoch=50


KeyboardInterrupt: ignored

In [None]:
from google.colab import drive
# Mount Google Drive so /content/drive/MyDrive/tg_2d.npz (read by
# get_trajectory) is reachable. NOTE(review): on a fresh runtime this cell
# must be run before the training cell above.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
