In [1]:
import jax.numpy as jnp
from jax import lax, jit, grad, random, vmap, pmap, local_device_count, tree_map, jacfwd, jacrev
from jax.tree_util import tree_map, tree_reduce, tree_leaves
from functools import partial
from absl import logging
from absl import app
from absl import flags
import os
import time
import ml_collections
import jax
from jax.tree_util import tree_map
from flax.training import train_state
from flax import jax_utils
from typing import Any, Callable, Sequence, Tuple, Optional, Dict
from matplotlib import pyplot as plt
from jaxpi.evaluator import BaseEvaluator
os.environ["TF_CUDNN_DETERMINISTIC"] = "1"  # DETERMINISTIC
from ml_collections import config_flags

import optax

from jaxpi import archs
from jaxpi.utils import flatten_pytree

  from jax import lax, jit, grad, random, vmap, pmap, local_device_count, tree_map, jacfwd, jacrev


In [2]:
u = 1.0
t_0 = 0.0
t_end = 50.0
r = 1000.0
n_samples = 50
c = 0.01

def solution(t):
    return - t / (r*c) + jnp.log(u/r)

def get_dataset():
    t = jnp.linspace(t_0, t_end, n_samples)
    u = solution(t)
    return t,u

In [3]:
#get_dataset()

In [4]:
class UniformSampler():
    #@partial(pmap, static_broadcasted_argnums=(0,))
    def data_generation():
        dom = jnp.array([[0., 50.]])
        dim = dom.shape[0]

        batch = random.uniform(
            random.PRNGKey(1234),
            shape=(5, dim),
            minval=dom[:, 0],
            maxval=dom[:, 1],
        )
        return batch

In [5]:
#data_generation()

In [6]:
class TrainState(train_state.TrainState):
    weights: Dict
    momentum: float

    def apply_weights(self, weights, **kwargs):
        """Updates `weights` using running average  in return value.

        Returns:
          An updated instance of `self` with new weights updated by applying `running_average`,
          and additional attributes replaced as specified by `kwargs`.
        """

        running_average = (
            lambda old_w, new_w: old_w * self.momentum + (1 - self.momentum) * new_w
        )
        weights = tree_map(running_average, self.weights, weights)
        weights = lax.stop_gradient(weights)

        return self.replace(
            step=self.step,
            params=self.params,
            opt_state=self.opt_state,
            weights=weights,
            **kwargs,
        )


def _create_arch(config):
    if config.arch_name == "Mlp":
        arch = archs.Mlp(**config)
    else:
        raise NotImplementedError(f"Arch {config.arch_name} not supported yet!")

    return arch


def _create_optimizer(config):
    if config.optimizer == "Adam":
        lr = optax.exponential_decay(
            init_value=config.learning_rate,
            transition_steps=config.decay_steps,
            decay_rate=config.decay_rate,
        )
        tx = optax.adam(
            learning_rate=lr, b1=config.beta1, b2=config.beta2, eps=config.eps
        )

    else:
        raise NotImplementedError(f"Optimizer {config.optimizer} not supported yet!")

    # Gradient accumulation
    if config.grad_accum_steps > 1:
        tx = optax.MultiSteps(tx, every_k_schedule=config.grad_accum_steps)

    return tx


# Create nn module from config file
def _create_train_state(config):
    # Initialize network
    arch = _create_arch(config.arch) # nn.module
    x = jnp.ones(config.input_dim)
    params = arch.init(random.PRNGKey(config.seed), x)

    # Initialize optax optimizer
    tx = _create_optimizer(config.optim)

    # Convert config dict to dict
    init_weights = dict(config.weighting.init_weights)

    state = TrainState.create(
        apply_fn=arch.apply,
        params=params,
        tx=tx,
        weights=init_weights,
        momentum=config.weighting.momentum,
    )

    return jax_utils.replicate(state)


class PINN:
    def __init__(self, config):
        self.config = config
        self.state = _create_train_state(config)

    def u_net(self, params, *args):
        raise NotImplementedError("Subclasses should implement this!")

    def r_net(self, params, *args):
        raise NotImplementedError("Subclasses should implement this!")

    def losses(self, params, batch, *args):
        raise NotImplementedError("Subclasses should implement this!")

    def compute_diag_ntk(self, params, batch, *args):
        raise NotImplementedError("Subclasses should implement this!")

    @partial(jit, static_argnums=(0,))
    def loss(self, params, weights, batch, *args):
        # Compute losses
        losses = self.losses(params, batch, *args)
        # Compute weighted loss
        weighted_losses = tree_map(lambda x, y: x * y, losses, weights)
        # Sum weighted losses
        loss = tree_reduce(lambda x, y: x + y, weighted_losses)
        return loss

    @partial(jit, static_argnums=(0,))
    def compute_weights(self, params, batch, *args):
        if self.config.weighting.scheme == "grad_norm":
            # Compute the gradient of each loss w.r.t. the parameters
            grads = jacrev(self.losses)(params, batch, *args)

            # Compute the grad norm of each loss
            grad_norm_dict = {}
            for key, value in grads.items():
                flattened_grad = flatten_pytree(value)
                grad_norm_dict[key] = jnp.linalg.norm(flattened_grad)

            # Compute the mean of grad norms over all losses
            mean_grad_norm = jnp.mean(jnp.stack(tree_leaves(grad_norm_dict)))
            # Grad Norm Weighting
            w = tree_map(lambda x: (mean_grad_norm / x), grad_norm_dict)

        return w

    @partial(pmap, axis_name="batch", static_broadcasted_argnums=(0,))
    def update_weights(self, state, batch, *args):
        weights = self.compute_weights(state.params, batch, *args)
        weights = lax.pmean(weights, "batch")
        state = state.apply_weights(weights=weights)
        return state

    @partial(pmap, axis_name="batch", static_broadcasted_argnums=(0,))
    def step(self, state, batch, *args):
        grads = grad(self.loss)(state.params, state.weights, batch, *args)
        grads = lax.pmean(grads, "batch")
        state = state.apply_gradients(grads=grads)
        return state


class ForwardIVP(PINN):
    def __init__(self, config):
        super().__init__(config)

        if config.weighting.use_causal:
            self.tol = config.weighting.causal_tol
            self.num_chunks = config.weighting.num_chunks
            self.M = jnp.triu(jnp.ones((self.num_chunks, self.num_chunks)), k=1).T


class ForwardBVP(PINN):
    def __init__(self, config):
        super().__init__(config)


looks ok until here

In [7]:
class CaseZero(ForwardIVP):
    def __init__(self, config, t_star, u0):
        super().__init__(config)

        self.u0 = u0
        self.t_star = t_star

        self.t0 = t_star[0]
        self.t1 = t_star[-1]

        # Predictions over t
        self.u_pred_fn = vmap(self.u_net, (0, None)) # -------------- DEBUG
        self.r_pred_fn = vmap(self.r_net, (0, None)) # -------------- DEBUG


    # Prediction from net for initial value
    def u_net(self, params, t):
        u = self.state.apply_fn(params, t) # -------------- DEBUG
        return u[0]

    # Gradient of the neural net
    def grad_net(self, params, t):
        u_t = grad(self.u_net, argnums=1)(params, t)
        return u_t

    # Residual of the neural net
    def r_net(self, params, t):
        u_t = self.grad_net(params, t)
        return u_t + 0.1

    @partial(jit, static_argnums=(0,))
    def res_and_w(self, params, batch):
        "Compute residuals and weights for causal training"
        # Sort time coordinates
        t_sorted = batch[:, 0].sort()
        r_pred = vmap(self.r_net, (None, 0))(params, t_sorted) # -------------- DEBUG
        # Split residuals into chunks
        r_pred = r_pred.reshape(self.num_chunks, -1)
        l = jnp.mean(r_pred**2, axis=1)
        w = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ l)))
        return l, w

    @partial(jit, static_argnums=(0,))
    def losses(self, params, batch):
        # Initial condition loss
        u_pred = self.u_net(params, self.t0) # -------------- DEBUG
        ics_loss = jnp.mean((self.u0 - u_pred) ** 2)

        # Residual loss
        if self.config.weighting.use_causal == True:
            l, w = self.res_and_w(params, batch)
            res_loss = jnp.mean(l * w)
        else:
            r_pred = vmap(self.r_net, (None, 0))(params, batch[:, 0]) # -------------- DEBUG
            res_loss = jnp.mean((r_pred) ** 2)

        loss_dict = {"ics": ics_loss, "res": res_loss}
        return loss_dict


    @partial(jit, static_argnums=(0,))
    def compute_l2_error(self, params, u_test):
        u_pred = self.u_pred_fn(params, self.t_star)
        error = jnp.linalg.norm(u_pred - u_test) / jnp.linalg.norm(u_test)
        return error

In [8]:
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):

    # Get dataset
    t_star, u_ref  = get_dataset()
    u0 = u_ref[0]

    t0 = t_star[0]
    t_end = t_star[-1]

    # Define domain
    dom = jnp.array([[t0, t_end]])

    # Define residual sampler
    res_sampler = iter(UniformSampler(dom, config.training.batch_size_per_device))

    # Initialize model
    model = CaseZero(config, t_star, u0)

    print("Waiting for JIT...")
    start_time = time.time()
    for step in range(config.training.max_steps):
        batch = next(res_sampler)
        model.state = model.step(model.state, batch)

        if config.weighting.scheme in ["grad_norm", "ntk"]:
            if step % config.weighting.update_every_steps == 0:
                model.state = model.update_weights(model.state, batch)


    return model

In [None]:

jax.config.update("jax_default_matmul_precision", "highest")


FLAGS = flags.FLAGS

flags.DEFINE_string("workdir", ".", "Directory to store model data.")

config_flags.DEFINE_config_file(
    "config",
    "./configs/default.py",
    "File path to the training hyperparameter configuration.",
    lock_config=True,
)


if __name__ == "__main__":
    train_and_evaluate(FLAGS.config, FLAGS.workdir)

FATAL Flags parsing error: Unknown command line flag 'f'
Pass --helpshort or --helpfull to see help on flags.


AttributeError: 'tuple' object has no attribute 'tb_frame'