In [None]:
!pip install brainstate brainunit braintools pinnx
!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

In [None]:
import jax

# Display the installed JAX version as this cell's output, to confirm which
# build (e.g. the TPU wheel installed in the previous cell) the kernel loaded.
jax.__version__

In [None]:
import os
import pickle
import sys
import time
from contextlib import contextmanager
from typing import Sequence, Callable

import brainstate as bst
import brainunit as u
import jax
import numpy as np
import pinnx
import brainstate as bst
import brainunit as u
import numpy as np
import pinnx

In [None]:

@contextmanager
def change_stdout():
    """Silence both ``sys.stdout`` and ``sys.stderr`` for the ``with`` body.

    Both streams are pointed at a single handle on ``os.devnull``; the
    previous stream objects are put back on exit, even if the body raises.
    """
    saved_out, saved_err = sys.stdout, sys.stderr
    try:
        with open(os.devnull, 'w') as sink:
            sys.stdout = sys.stderr = sink
            yield
    finally:
        # Restore the exact objects that were active on entry.
        sys.stdout, sys.stderr = saved_out, saved_err


class Trainer(pinnx.Trainer):
    """``pinnx.Trainer`` that records the wall-clock duration of each training step.

    Durations (in seconds) are appended to ``self._train_times`` by ``_train``;
    the benchmarking code in this notebook reads that list after training.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One entry per executed training step, in seconds.
        self._train_times = []

    def _train(self, iterations, display_every, batch_size, callbacks):
        """Run up to ``iterations`` optimisation steps, timing each one.

        Per iteration: fire the begin callbacks, load the next training batch,
        run and time one step, bump the ``epoch``/``step`` counters, evaluate via
        ``self._test()`` every ``display_every`` steps and on the last iteration,
        fire the end callbacks, and stop early if ``self.stop_training`` is set.
        """
        for i in range(iterations):
            callbacks.on_epoch_begin()
            callbacks.on_batch_begin()

            # get data
            self.train_state.set_data_train(*self.problem.train_next_batch(batch_size))

            # Train one batch. JAX dispatches work asynchronously, so block on the
            # step's outputs before reading the clock; otherwise mostly dispatch
            # time is recorded. NOTE(review): if fn_train_step returns nothing,
            # block_until_ready is a no-op — confirm what the step returns.
            t0 = time.perf_counter()
            jax.block_until_ready(
                self.fn_train_step(self.train_state.X_train, self.train_state.y_train, **self.train_state.Aux_train)
            )
            self._train_times.append(time.perf_counter() - t0)

            self.train_state.epoch += 1
            self.train_state.step += 1
            if self.train_state.step % display_every == 0 or i + 1 == iterations:
                self._test()

            callbacks.on_batch_end()
            callbacks.on_epoch_end()

            if self.stop_training:
                break


def _stack_history(history):
    """Stack a list of per-evaluation loss pytrees into one pytree of arrays.

    Each leaf of the result holds that loss term across all recorded
    evaluations (leading axis = evaluation index). Quantities are treated as
    leaves so their units are preserved.
    """
    return jax.tree.map(lambda *xs: u.math.asarray(xs), *history, is_leaf=u.math.is_quantity)


def eval(
    problem,
    n_point,
    unit: bool = True,
    n_train: int = 15000,
    external_trainable_variables: Sequence = None,
    lr: float = 1e-3,
    **kwargs
):
    """Train ``problem`` with Adam and collect timing and loss statistics.

    Note: this function shadows the builtin ``eval`` in the notebook namespace.

    Args:
        problem: The pinnx problem to train.
        n_point: Number of training points; only recorded in the result.
        unit: Whether the problem uses physical units; only used for the
            ``with_unit_or_not`` label.
        n_train: Number of training iterations.
        external_trainable_variables: Extra states (e.g. unknown PDE
            coefficients) optimised alongside the network, or ``None``.
        lr: Adam learning rate (previously hard-coded to 1e-3).
        **kwargs: Accepted and ignored.

    Returns:
        dict with ``n_point``, ``with_unit_or_not``, ``compile_time`` (as
        reported by ``trainer.compile``), ``train_times`` (per-step seconds as
        a NumPy array), stacked ``loss_train``/``loss_test`` histories, and the
        trainer's ``best_loss_train``/``best_loss_test``.
    """
    trainer = Trainer(
        problem,
        external_trainable_variables=external_trainable_variables
    )
    # ``measture_...`` is the keyword spelling passed to pinnx here; do not "fix" it.
    _, compile_time = trainer.compile(bst.optim.Adam(lr), measture_train_step_compile_time=True)
    trainer.train(iterations=n_train)

    history = trainer.loss_history
    return dict(
        n_point=n_point,
        with_unit_or_not='with unit' if unit else 'without unit',
        compile_time=compile_time,
        train_times=np.asarray(trainer._train_times),
        loss_train=_stack_history(history.loss_train),
        loss_test=_stack_history(history.loss_test),
        best_loss_train=trainer.train_state.best_loss_train,
        best_loss_test=trainer.train_state.best_loss_test,
    )


def scaling_experiments(
    name: str,
    solve_with_unit: Callable[[float], dict],
    solve_without_unit: Callable[[float], dict],
    scales: Sequence[float] = (0.1, 0.5, 1.0, 2.0, 5.0, 10.0),
    out_dir: str = 'results',
):
    """Train the with-unit and without-unit variants of a problem at several scales.

    For every ``scale`` each solver is called to build ``eval`` keyword
    arguments; the result dict is pickled to
    ``{out_dir}/{name}_{backend}_scaling={scale}_{with_unit|without_unit}.pkl``
    and a one-line summary (best test loss, compile time, mean step time) is
    printed.

    Args:
        name: Experiment name used as the file-name prefix.
        solve_with_unit: ``scale -> eval kwargs`` for the unit-aware problem.
        solve_without_unit: ``scale -> eval kwargs`` for the unitless problem.
        scales: Collocation-point scale factors to sweep.
        out_dir: Directory receiving the pickles (created if missing).
    """
    platform = jax.default_backend()
    os.makedirs(out_dir, exist_ok=True)

    # (file-name suffix, printed label, problem builder)
    variants = (
        ('with_unit', 'with unit', solve_with_unit),
        ('without_unit', 'without unit', solve_without_unit),
    )

    for scale in scales:
        print(f"scale: {scale}")
        for suffix, label, solve in variants:
            result = eval(**solve(scale))
            path = os.path.join(out_dir, f'{name}_{platform}_scaling={scale}_{suffix}.pkl')
            with open(path, 'wb') as f:
                pickle.dump(result, f)
            print(f'{label}: {result["best_loss_test"]}, {result["compile_time"]}, {result["train_times"].mean()}')
        print()


In [None]:

def solve_problem_with_unit(
    scale: float = 1.0,
    data_path: str = "./dataset/reaction.npz",
    n_train: int = 15000,
):
    """Build the diffusion-reaction inverse problem with physical units attached.

    Concentrations ``ca``/``cb`` (mol/m^3) on x in [0, 1] m, t in [0, 10] s obey

        dca/dt = 1e-3 * D * d2ca/dx2 - kf * ca * cb**2
        dcb/dt = 1e-3 * D * d2cb/dx2 - 2 * kf * ca * cb**2

    with trainable coefficients ``kf`` and ``D`` fitted to observations read
    from ``data_path``.

    Args:
        scale: Multiplier on the baseline point counts (2000 domain,
            100 boundary, 100 initial, 500 test). Must be positive.
        data_path: ``.npz`` file containing arrays ``t``, ``x``, ``Ca``, ``Cb``.
        n_train: Training iterations, returned for ``eval``.

    Returns:
        Keyword arguments for ``eval``: the problem, ``n_point`` (domain +
        boundary + initial points), ``unit=True``, ``n_train`` and the
        trainable ``[kf, D]`` states.

    Raises:
        ValueError: If ``scale`` is not positive.
    """
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale}")

    unit_of_x = u.meter
    unit_of_t = u.second
    unit_of_c = u.mole / u.meter ** 3

    # Unknown coefficients with their initial values; trained as external
    # variables. These units make every PDE term mol / m^3 / s.
    kf = bst.ParamState(0.05 * u.meter ** 6 / u.mole ** 2 / u.second)
    D = bst.ParamState(1.0 * u.meter ** 2 / u.second)

    net = pinnx.nn.Model(
        pinnx.nn.DictToArray(x=unit_of_x, t=unit_of_t),
        pinnx.nn.FNN([2] + [20] * 3 + [2], "tanh"),
        pinnx.nn.ArrayToDict(ca=unit_of_c, cb=unit_of_c),
    )

    def pde(x, y):
        # Residuals of the two species equations.
        jacobian = net.jacobian(x, x='t')
        hessian = net.hessian(x)
        ca, cb = y['ca'], y['cb']
        dca_t = jacobian['ca']['t']
        dcb_t = jacobian['cb']['t']
        dca_xx = hessian['ca']['x']['x']
        dcb_xx = hessian['cb']['x']['x']
        eq_a = dca_t - 1e-3 * D.value * dca_xx + kf.value * ca * cb ** 2
        eq_b = dcb_t - 1e-3 * D.value * dcb_xx + 2 * kf.value * ca * cb ** 2
        return [eq_a, eq_b]

    geom = pinnx.geometry.Interval(0, 1)
    timedomain = pinnx.geometry.TimeDomain(0, 10)
    geomtime = pinnx.geometry.GeometryXTime(geom, timedomain)
    geomtime = geomtime.to_dict_point(x=unit_of_x, t=unit_of_t)

    def fun_bc(x):
        # Both species follow the same linear profile 1 - x on the boundary.
        c = (1 - x['x'] / unit_of_x) * unit_of_c
        return {'ca': c, 'cb': c}

    bc = pinnx.icbc.DirichletBC(fun_bc)

    def fun_init(x):
        # Both species start from the same exp(-20 x) profile.
        profile = u.math.exp(-20 * x['x'] / unit_of_x) * unit_of_c
        return {'ca': profile, 'cb': profile}

    ic = pinnx.icbc.IC(fun_init)

    def gen_traindata():
        # Close the .npz handle once the arrays have been read.
        with np.load(data_path) as npz:
            t, x, ca, cb = npz["t"], npz["x"], npz["Ca"], npz["Cb"]
        # NOTE(review): assumes Ca/Cb are stored as (len(t), len(x)), matching
        # the layout of np.meshgrid(x, t) — confirm against the dataset.
        X, T = np.meshgrid(x, t)
        x = {'x': X.flatten() * unit_of_x, 't': T.flatten() * unit_of_t}
        y = {'ca': ca.flatten() * unit_of_c, 'cb': cb.flatten() * unit_of_c}
        return x, y

    observe_x, observe_y = gen_traindata()
    observe_bc = pinnx.icbc.PointSetBC(observe_x, observe_y)

    num_domain = int(2000 * scale)
    num_boundary = int(100 * scale)
    num_initial = int(100 * scale)
    num_test = int(500 * scale)

    problem = pinnx.problem.TimePDE(
        geomtime,
        pde,
        [bc, ic, observe_bc],
        net,
        num_domain=num_domain,
        num_boundary=num_boundary,
        num_initial=num_initial,
        num_test=num_test,
        anchors=observe_x,
    )

    return {
        'problem': problem,
        'n_point': num_domain + num_boundary + num_initial,
        'unit': True,
        'n_train': n_train,
        'external_trainable_variables': [kf, D]
    }


In [None]:

def solve_problem_without_unit(
    scale: float = 1.0,
    data_path: str = "./dataset/reaction.npz",
    n_train: int = 15000,
):
    """Build the diffusion-reaction inverse problem on plain, unitless arrays.

    Same equations, domain, conditions and point counts as the unit-aware
    variant, but every quantity is a bare number:

        dca/dt = 1e-3 * D * d2ca/dx2 - kf * ca * cb**2
        dcb/dt = 1e-3 * D * d2cb/dx2 - 2 * kf * ca * cb**2

    Args:
        scale: Multiplier on the baseline point counts (2000 domain,
            100 boundary, 100 initial, 500 test). Must be positive.
        data_path: ``.npz`` file containing arrays ``t``, ``x``, ``Ca``, ``Cb``.
        n_train: Training iterations, returned for ``eval``.

    Returns:
        Keyword arguments for ``eval``: the problem, ``n_point`` (domain +
        boundary + initial points), ``unit=False``, ``n_train`` and the
        trainable ``[kf, D]`` states.

    Raises:
        ValueError: If ``scale`` is not positive.
    """
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale}")

    # Unknown coefficients with their initial values; trained as external variables.
    kf = bst.ParamState(0.05)
    D = bst.ParamState(1.0)

    net = pinnx.nn.Model(
        pinnx.nn.DictToArray(x=None, t=None),
        pinnx.nn.FNN([2] + [20] * 3 + [2], "tanh"),
        pinnx.nn.ArrayToDict(ca=None, cb=None),
    )

    def pde(x, y):
        # Residuals of the two species equations.
        jacobian = net.jacobian(x, x='t')
        hessian = net.hessian(x)
        ca, cb = y['ca'], y['cb']
        dca_t = jacobian['ca']['t']
        dcb_t = jacobian['cb']['t']
        dca_xx = hessian['ca']['x']['x']
        dcb_xx = hessian['cb']['x']['x']
        eq_a = dca_t - 1e-3 * D.value * dca_xx + kf.value * ca * cb ** 2
        eq_b = dcb_t - 1e-3 * D.value * dcb_xx + 2 * kf.value * ca * cb ** 2
        return [eq_a, eq_b]

    geom = pinnx.geometry.Interval(0, 1)
    timedomain = pinnx.geometry.TimeDomain(0, 10)
    geomtime = pinnx.geometry.GeometryXTime(geom, timedomain)
    geomtime = geomtime.to_dict_point(x=None, t=None)

    def fun_bc(x):
        # Both species follow the same linear profile 1 - x on the boundary.
        c = (1 - x['x'])
        return {'ca': c, 'cb': c}

    bc = pinnx.icbc.DirichletBC(fun_bc)

    def fun_init(x):
        # Both species start from the same exp(-20 x) profile.
        profile = u.math.exp(-20 * x['x'])
        return {'ca': profile, 'cb': profile}

    ic = pinnx.icbc.IC(fun_init)

    def gen_traindata():
        # Close the .npz handle once the arrays have been read.
        with np.load(data_path) as npz:
            t, x, ca, cb = npz["t"], npz["x"], npz["Ca"], npz["Cb"]
        # NOTE(review): assumes Ca/Cb are stored as (len(t), len(x)), matching
        # the layout of np.meshgrid(x, t) — confirm against the dataset.
        X, T = np.meshgrid(x, t)
        x = {'x': X.flatten(), 't': T.flatten()}
        y = {'ca': ca.flatten(), 'cb': cb.flatten()}
        return x, y

    observe_x, observe_y = gen_traindata()
    observe_bc = pinnx.icbc.PointSetBC(observe_x, observe_y)

    num_domain = int(2000 * scale)
    num_boundary = int(100 * scale)
    num_initial = int(100 * scale)
    num_test = int(500 * scale)

    problem = pinnx.problem.TimePDE(
        geomtime,
        pde,
        [bc, ic, observe_bc],
        net,
        num_domain=num_domain,
        num_boundary=num_boundary,
        num_initial=num_initial,
        num_test=num_test,
        anchors=observe_x,
    )

    return {
        'problem': problem,
        'n_point': num_domain + num_boundary + num_initial,
        # Was ``True``: mislabelled every unitless run as "with unit".
        'unit': False,
        'n_train': n_train,
        'external_trainable_variables': [kf, D]
    }


In [None]:
# Diffusion-reaction inverse problem (1-D space + time). The name prefixes every
# pickle written under results/, so it should not be shared with other
# experiments writing to the same directory (it was previously 'diffusion_2d').
scaling_experiments(
    'diffusion_reaction_inverse',
    solve_with_unit=solve_problem_with_unit,
    solve_without_unit=solve_problem_without_unit,
    scales=(1.0, 2.0, 5.0, 10.0),
)