In [None]:
!pip install brainstate brainunit braintools pinnx
!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

In [None]:
import jax

# Display the version of the JAX package that the kernel actually imported.
jax.__version__

In [None]:
import os
import pickle
import sys
import time
from contextlib import contextmanager
from typing import Sequence, Callable

import brainstate as bst
import brainunit as u
import jax
import numpy as np
import pinnx

In [None]:

@contextmanager
def change_stdout():
    """Temporarily send ``sys.stdout`` and ``sys.stderr`` to ``os.devnull``.

    While the ``with`` body runs, both streams point at one shared write handle
    opened on ``os.devnull``. The stream objects that were installed on entry
    are put back when the body finishes, whether it returns or raises.
    """
    saved_streams = (sys.stdout, sys.stderr)
    try:
        with open(os.devnull, 'w') as sink:
            # Both streams share the same sink for the duration of the body.
            sys.stdout = sys.stderr = sink
            yield
    finally:
        # Runs on success and on error alike, so the redirection never leaks.
        sys.stdout, sys.stderr = saved_streams


class Trainer(pinnx.Trainer):
    """``pinnx.Trainer`` subclass that records how long each training step takes.

    Every call to ``fn_train_step`` made by ``_train`` is timed and the elapsed
    seconds are appended to ``self._train_times`` (one entry per step), so the
    durations can be read back after training.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``pinnx.Trainer`` and start with an empty timing log."""
        super().__init__(*args, **kwargs)
        # Elapsed seconds of each executed training step, in execution order.
        self._train_times = []

    def _train(self, iterations, display_every, batch_size, callbacks):
        """Run up to ``iterations`` training steps, timing each one.

        Args:
            iterations: Maximum number of steps to run.
            display_every: ``self._test()`` is called whenever the global step
                count is a multiple of this value, and also on the last iteration.
            batch_size: Passed to ``self.problem.train_next_batch``.
            callbacks: Object whose ``on_epoch_begin`` / ``on_batch_begin`` /
                ``on_batch_end`` / ``on_epoch_end`` hooks are invoked around
                every step.
        """
        for i in range(iterations):
            callbacks.on_epoch_begin()
            callbacks.on_batch_begin()

            # get data
            self.train_state.set_data_train(*self.problem.train_next_batch(batch_size))

            # train one batch
            # perf_counter is the monotonic, high-resolution clock meant for
            # measuring intervals; time.time() can jump and may be coarse.
            t0 = time.perf_counter()
            step_out = self.fn_train_step(
                self.train_state.X_train,
                self.train_state.y_train,
                **self.train_state.Aux_train
            )
            # JAX dispatches work asynchronously, so wait for the step's outputs
            # before stopping the clock; otherwise only the enqueue cost is timed.
            # NOTE(review): if fn_train_step returns nothing, this is a no-op and
            # the update may still be in flight -- confirm against pinnx.
            jax.block_until_ready(step_out)
            t1 = time.perf_counter()

            self._train_times.append(t1 - t0)

            self.train_state.epoch += 1
            self.train_state.step += 1
            if self.train_state.step % display_every == 0 or i + 1 == iterations:
                self._test()

            callbacks.on_batch_end()
            callbacks.on_epoch_end()

            if self.stop_training:
                break


def eval(
    problem,
    n_point,
    unit: bool = True,
    n_train: int = 15000,
    external_trainable_variables: Sequence = None,
    **kwargs
):
    """Train ``problem`` with the timing ``Trainer`` and collect a summary of the run.

    Note that this function shadows the built-in ``eval`` inside the notebook.

    Args:
        problem: Problem object handed to ``Trainer``.
        n_point: Number of sampled points; copied into the result unchanged.
        unit: Chooses the ``'with unit'`` / ``'without unit'`` label stored in
            the result.
        n_train: Number of iterations passed to ``Trainer.train``.
        external_trainable_variables: Forwarded to ``Trainer``.
        **kwargs: Accepted but not used.

    Returns:
        dict with the keys ``n_point``, ``with_unit_or_not``, ``compile_time``,
        ``train_times``, ``loss_train``, ``loss_test``, ``best_loss_train`` and
        ``best_loss_test``.
    """

    def _stack(records):
        # Combine the per-record loss trees leaf by leaf into arrays,
        # treating quantities as leaves.
        return jax.tree.map(
            lambda *entries: u.math.asarray(entries),
            *records,
            is_leaf=u.math.is_quantity,
        )

    runner = Trainer(problem, external_trainable_variables=external_trainable_variables)

    # compile() yields two values here; the second is kept as the compile time.
    optimizer = bst.optim.Adam(1e-3)
    _, compile_time = runner.compile(optimizer, measture_train_step_compile_time=True)
    runner.train(iterations=n_train)

    history = runner.loss_history
    stacked_train = _stack(history.loss_train)
    stacked_test = _stack(history.loss_test)

    if unit:
        label = 'with unit'
    else:
        label = 'without unit'

    final_state = runner.train_state
    return dict(
        n_point=n_point,
        with_unit_or_not=label,
        compile_time=compile_time,
        train_times=np.asarray(runner._train_times),
        loss_train=stacked_train,
        loss_test=stacked_test,
        best_loss_train=final_state.best_loss_train,
        best_loss_test=final_state.best_loss_test,
    )


def scaling_experiments(
    name: str,
    solve_with_unit: Callable[[float], dict],
    solve_without_unit: Callable[[float], dict],
    scales: Sequence[float] = (0.1, 0.5, 1.0, 2.0, 5.0, 10.0),
):
    """Run the with-unit and without-unit variants at every scale and pickle the results.

    For each entry of ``scales`` the two builders are called with that scale,
    their returned dictionaries are expanded into ``eval`` (the notebook's own
    function, not the built-in), and each result is written to
    ``results/{name}_{platform}_scaling={scale}_with_unit.pkl`` or
    ``..._without_unit.pkl``, where ``platform`` is ``jax.default_backend()``.
    A one-line summary (best test loss, compile time, mean step time) is printed
    for each variant.

    Args:
        name: Prefix used in the result file names.
        solve_with_unit: Maps a scale to the keyword arguments of ``eval`` for
            the unit-aware variant.
        solve_without_unit: Same, for the unit-free variant.
        scales: Scale factors to iterate over.
    """

    def _save(payload, path):
        # Persist one result dictionary as a pickle file.
        with open(path, 'wb') as handle:
            pickle.dump(payload, handle)

    platform = jax.default_backend()
    os.makedirs('results/', exist_ok=True)

    # (change_stdout() can wrap the runs below to silence their console output.)
    for scale in scales:
        print(f"scale: {scale}")

        with_unit = eval(**solve_with_unit(scale))
        _save(with_unit, f'results/{name}_{platform}_scaling={scale}_with_unit.pkl')
        print(f'with unit: {with_unit["best_loss_test"]}, {with_unit["compile_time"]}, {with_unit["train_times"].mean()}')

        without_unit = eval(**solve_without_unit(scale))
        _save(without_unit, f'results/{name}_{platform}_scaling={scale}_without_unit.pkl')
        print(f'without unit: {without_unit["best_loss_test"]}, {without_unit["compile_time"]}, {without_unit["train_times"].mean()}')

        print()


In [None]:
def solve_problem_with_unit(scale=1.0):
    """Build the unit-aware ``pinnx`` time-dependent PDE problem for the benchmark.

    ``x`` is expressed in metres on ``Interval(-1, 1.)``, ``t`` in seconds on
    ``TimeDomain(0, 0.99)``, and the output ``y`` in metres per second. The
    residual is ``dy/dt + y * dy/dx - v * d2y/dx2`` with
    ``v = 0.01 / pi`` m^2/s.

    Args:
        scale: Multiplier applied to the base counts of domain (2540),
            boundary (80) and initial (160) points; each product is truncated
            with ``int``.

    Returns:
        dict with the keys ``'problem'``, ``'n_point'`` (sum of the three point
        counts), ``'unit'`` (``True``) and ``'n_train'`` (15000).
    """
    space_time = pinnx.geometry.GeometryXTime(
        geometry=pinnx.geometry.Interval(-1, 1.),
        timedomain=pinnx.geometry.TimeDomain(0, 0.99)
    )
    geometry = space_time.to_dict_point(x=u.meter, t=u.second)

    uy = u.meter / u.second
    viscosity = 0.01 / u.math.pi * u.meter ** 2 / u.second

    # Boundary value is zero; the initial profile is -sin(pi * x), both carrying uy.
    bc = pinnx.icbc.DirichletBC(lambda x: {'y': 0. * uy})
    ic = pinnx.icbc.IC(lambda x: {'y': -u.math.sin(u.math.pi * x['x'] / u.meter) * uy})

    # Layer sizes [dim, 20, 20, 20, 1] with "tanh"; the wrappers attach the
    # units on the way in and out of the network.
    net = pinnx.nn.Model(
        pinnx.nn.DictToArray(x=u.meter, t=u.second),
        pinnx.nn.FNN([geometry.dim, 20, 20, 20, 1], "tanh"),
        pinnx.nn.ArrayToDict(y=uy)
    )

    def pde(x, y):
        # First and second derivatives of the network output w.r.t. the inputs.
        first = net.jacobian(x)
        second = net.hessian(x)
        advection = y['y'] * first['y']['x']
        diffusion = viscosity * second['y']['x']['x']
        return first['y']['t'] + advection - diffusion

    n_domain, n_boundary, n_initial = (int(base * scale) for base in (2540, 80, 160))

    problem = pinnx.problem.TimePDE(
        geometry,
        pde,
        [bc, ic],
        net,
        num_domain=n_domain,
        num_boundary=n_boundary,
        num_initial=n_initial,
    )

    total_points = n_domain + n_boundary + n_initial
    return {
        'problem': problem,
        'n_point': total_points,
        'unit': True,
        'n_train': 15000
    }

In [None]:
def solve_problem_without_unit(scale=1.0):
    """Build the unit-free ``pinnx`` time-dependent PDE problem for the benchmark.

    Same set-up as the unit-aware variant, but every unit slot is ``None``:
    ``x`` lies on ``Interval(-1, 1.)``, ``t`` on ``TimeDomain(0, 0.99)``, and the
    residual is ``dy/dt + y * dy/dx - v * d2y/dx2`` with ``v = 0.01 / pi``.

    Args:
        scale: Multiplier applied to the base counts of domain (2540),
            boundary (80) and initial (160) points; each product is truncated
            with ``int``.

    Returns:
        dict with the keys ``'problem'``, ``'n_point'`` (sum of the three point
        counts), ``'unit'`` (``False``) and ``'n_train'`` (15000).
    """
    space_time = pinnx.geometry.GeometryXTime(
        geometry=pinnx.geometry.Interval(-1, 1.),
        timedomain=pinnx.geometry.TimeDomain(0, 0.99)
    )
    geometry = space_time.to_dict_point(x=None, t=None)

    viscosity = 0.01 / u.math.pi

    # Boundary value is zero; the initial profile is -sin(pi * x).
    bc = pinnx.icbc.DirichletBC(lambda x: {'y': 0.})
    ic = pinnx.icbc.IC(lambda x: {'y': -u.math.sin(u.math.pi * x['x'])})

    # Layer sizes [dim, 20, 20, 20, 1] with "tanh"; no units on inputs or output.
    net = pinnx.nn.Model(
        pinnx.nn.DictToArray(x=None, t=None),
        pinnx.nn.FNN([geometry.dim, 20, 20, 20, 1], "tanh"),
        pinnx.nn.ArrayToDict(y=None)
    )

    def pde(x, y):
        # First and second derivatives of the network output w.r.t. the inputs.
        first = net.jacobian(x)
        second = net.hessian(x)
        advection = y['y'] * first['y']['x']
        diffusion = viscosity * second['y']['x']['x']
        return first['y']['t'] + advection - diffusion

    n_domain, n_boundary, n_initial = (int(base * scale) for base in (2540, 80, 160))

    problem = pinnx.problem.TimePDE(
        geometry,
        pde,
        [bc, ic],
        net,
        num_domain=n_domain,
        num_boundary=n_boundary,
        num_initial=n_initial,
    )

    total_points = n_domain + n_boundary + n_initial
    return {
        'problem': problem,
        'n_point': total_points,
        'unit': False,
        'n_train': 15000
    }

In [None]:
# Run the with-unit / without-unit pair at every listed scale; each result is
# pickled under results/ using the 'Burger_equation' file-name prefix.
scaling_experiments(
    'Burger_equation',
    solve_with_unit=solve_problem_with_unit,
    solve_without_unit=solve_problem_without_unit,
    scales=(0.1, 0.5, 1.0, 2.0, 5.0, 10.0),
)
