In [1]:
import h5py
import jax
import jax.numpy as jnp
import optax
from clu import metrics
from flax import linen as nn
from flax import struct
from flax.training import train_state
from jax import numpy as jnp
from pprint import pprint
from gwkokab.utils import get_key
from absl import logging
from flax import linen as nn
from flax.metrics import tensorboard
from flax.training import train_state
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax

from jaxtyping import Array
from typing import Any, Tuple, Union, Sequence

  from .autonotebook import tqdm as notebook_tqdm

SWIGLAL standard output/error redirection is enabled in IPython.
This may lead to performance penalties. To disable locally, use:

with lal.no_swig_redirect_standard_output_error():
    ...

To disable globally, use:

lal.swig_redirect_standard_output_error(False)

Note however that this will likely lead to error messages from
LAL functions being either misdirected or lost when called from
Jupyter notebooks.


import lal

  import lal
[Loading lalsimutils.py : MonteCarloMarginalization version]
  scipy :  1.13.0
  numpy :  1.26.4


In [2]:
def read_vt_file(file_path: str = "./vt_1_200_1000.hdf5") -> Sequence[Array]:
    """Read the mass coordinate axes and the VT grid from an HDF5 file.

    The file must contain the datasets ``"m1"``, ``"m2"`` and ``"VT"``. The
    ``m1`` axis is the first entry along axis 0 of the ``"m1"`` dataset and
    the ``m2`` axis is column 0 of the ``"m2"`` dataset (so ``"m2"`` must be
    at least two-dimensional). No interpolation is performed here.

    :param file_path: The path to the HDF5 file, defaults to "./vt_1_200_1000.hdf5"
    :return: The tuple ``(m1_coord, m2_coord, VT_grid)``.
    """
    with h5py.File(file_path, "r") as vt_file:
        # Materialise everything while the file is still open.
        vt_values = vt_file["VT"][:]
        m1_axis = vt_file["m1"][:][0]
        m2_axis = vt_file["m2"][:][:, 0]

    return m1_axis, m2_axis, vt_values

In [3]:
class NeuralVT(nn.Module):
    """A neural network that approximates the VT function.

    Dense(32)->softmax->Dense(32)->softmax->Dense(1)
    """

    @nn.compact
    def __call__(self, *args, **kwargs):
        """Run the network on the first positional argument.

        Only ``args[0]`` is used; any other positional or keyword arguments
        are accepted but ignored. The output has size 1 along its last axis.
        """
        hidden = args[0]
        # Two identical hidden stages; softmax is applied over the last axis.
        for width in (32, 32):
            hidden = nn.softmax(nn.Dense(width)(hidden))
        return nn.Dense(1)(hidden)

In [4]:
@jax.jit
def apply_model(state, m1m2, vt):
    """Compute gradients, loss and a regression score for one batch.

    :param state: Train state exposing ``apply_fn`` and ``params``; ``params``
        is the bare parameter collection (without the ``"params"`` wrapper).
    :param m1m2: Batch of model inputs.
    :param vt: Target VT values, broadcast-compatible with the model output.
    :return: ``(grads, loss, accuracy)`` where ``loss`` is the mean squared
        error and ``accuracy`` is the coefficient of determination (R^2) of
        the predictions. Classification accuracy is not meaningful for this
        regression task, so R^2 (1.0 for a perfect fit) is reported instead.
    """

    def loss_fn(params):
        # The state stores only the inner "params" collection, so it has to
        # be wrapped again before being handed to the module's apply function.
        vt_pred = state.apply_fn({"params": params}, m1m2)
        loss = jnp.mean((vt - vt_pred) ** 2)
        # With has_aux=True the function must return a (loss, aux) pair.
        return loss, vt_pred

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, vt_pred), grads = grad_fn(state.params)

    # R^2 = 1 - SS_res / SS_tot, reusing the predictions from the forward pass.
    ss_res = jnp.sum((vt - vt_pred) ** 2)
    ss_tot = jnp.sum((vt - jnp.mean(vt)) ** 2)
    # Constant targets give SS_tot == 0; divide by 1 instead to stay finite.
    accuracy = 1.0 - ss_res / jnp.where(ss_tot > 0, ss_tot, 1.0)
    return grads, loss, accuracy

In [5]:
@jax.jit
def update_model(state, grads):
    """Return the state obtained by applying ``grads`` via ``state.apply_gradients``."""
    updated_state = state.apply_gradients(grads=grads)
    return updated_state

In [6]:
def train_epoch(state, train_ds, batch_size, rng):
    """Train for one epoch over shuffled, fixed-size mini-batches.

    :param state: Current train state, passed to ``apply_model`` and
        ``update_model``.
    :param train_ds: Dict with the inputs under ``"m1m2"`` and the targets
        under ``"label"``, indexed along the first axis.
    :param batch_size: Number of samples per step; a trailing incomplete
        batch is dropped.
    :param rng: PRNG key used to shuffle the sample order.
    :return: ``(state, train_loss, train_accuracy)`` with the updated state
        and the mean of the per-batch loss and accuracy values.
    :raises ValueError: If ``batch_size`` is not positive or exceeds the
        number of samples, since no full batch could then be formed.
    """
    train_ds_size = len(train_ds["m1m2"])
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    steps_per_epoch = train_ds_size // batch_size
    if steps_per_epoch == 0:
        raise ValueError(
            f"batch_size ({batch_size}) is larger than the dataset "
            f"({train_ds_size} samples); no full batch can be formed"
        )

    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[: steps_per_epoch * batch_size]  # skip incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))

    epoch_loss = []
    epoch_accuracy = []

    for perm in perms:
        batch_m1m2 = train_ds["m1m2"][perm, ...]
        batch_vt = train_ds["label"][perm, ...]
        grads, loss, accuracy = apply_model(state, batch_m1m2, batch_vt)
        state = update_model(state, grads)
        epoch_loss.append(loss)
        epoch_accuracy.append(accuracy)

    # jax.numpy reductions reject Python lists, so stack the scalars first.
    train_loss = jnp.mean(jnp.stack(epoch_loss))
    train_accuracy = jnp.mean(jnp.stack(epoch_accuracy))
    return state, train_loss, train_accuracy

In [7]:
def get_datasets():
    """Build shuffled train and test splits from the VT grid file.

    The coordinate axes and VT values come from ``read_vt_file()``. The axes
    are expanded with ``jnp.meshgrid`` so that every grid point becomes one
    ``(m1, m2)`` sample paired with its VT value. Samples are shuffled with a
    key from ``get_key()``; the first 70% form the training set and the
    remainder the test set.

    :return: ``(train_ds, test_ds)``. Each is a dict with the column-stacked
        ``(m1, m2)`` inputs under ``"m1m2"`` and the VT values, reshaped to a
        single column, under ``"label"``.
    """
    m1_axis, m2_axis, vt_grid = read_vt_file()
    m1_mesh, m2_mesh = jnp.meshgrid(m1_axis, m2_axis)

    # One row per grid point: inputs are (m1, m2), the target is a 1-column VT.
    features = jnp.column_stack((m1_mesh.flatten(), m2_mesh.flatten()))
    targets = vt_grid.flatten().reshape(-1, 1)

    train_fraction = 0.7
    n_samples = features.shape[0]
    n_train = int(n_samples * train_fraction)

    # Shuffle inputs and targets with the same permutation, then split.
    order = jax.random.permutation(get_key(), n_samples)
    features = features[order, ...]
    targets = targets[order, ...]

    train_ds = {"m1m2": features[:n_train], "label": targets[:n_train]}
    test_ds = {"m1m2": features[n_train:], "label": targets[n_train:]}
    return train_ds, test_ds

In [8]:
def create_train_state(rng, config):
    """Creates initial `TrainState`.

    Args:
      rng: PRNG key used to initialise the network parameters.
      config: Hyperparameter configuration. ``learning_rate`` and
        ``momentum`` are used when present; otherwise they default to
        ``1e-3`` and ``0.9``.

    Returns:
      A `TrainState` holding the `NeuralVT` apply function, its freshly
      initialised ``"params"`` collection and an SGD-with-momentum optimizer.
    """
    nvt = NeuralVT()
    # The network consumes (m1, m2) pairs, so the dummy init batch must have
    # two features; an image-shaped input would create wrongly sized kernels.
    params = nvt.init(rng, jnp.ones([1, 2]))["params"]
    learning_rate = getattr(config, "learning_rate", 1e-3)
    momentum = getattr(config, "momentum", 0.9)
    tx = optax.sgd(learning_rate=learning_rate, momentum=momentum)
    return train_state.TrainState.create(
        apply_fn=nvt.apply,
        params=params,
        tx=tx,
    )

In [9]:
def train_and_evaluate(
    config: ml_collections.ConfigDict,
    workdir: str,
) -> train_state.TrainState:
    """Execute model training and evaluation loop.

    Each epoch trains on the shuffled training split, evaluates once on the
    full test split, and logs the metrics both through ``absl.logging`` and
    as tensorboard scalars.

    Args:
      config: Hyperparameter configuration for training and evaluation;
        ``num_epochs`` and ``batch_size`` are read here.
      workdir: Directory where the tensorboard summaries are written to.

    Returns:
      The train state (which includes the `.params`).
    """
    train_ds, test_ds = get_datasets()
    rng = get_key()

    summary_writer = tensorboard.SummaryWriter(workdir)
    summary_writer.hparams(dict(config))

    rng, init_rng = jax.random.split(rng)
    state = create_train_state(init_rng, config)

    for epoch in range(1, config.num_epochs + 1):
        rng, input_rng = jax.random.split(rng)
        state, train_loss, train_accuracy = train_epoch(state, train_ds, config.batch_size, input_rng)
        # The datasets store their inputs under "m1m2" (there is no "image" key).
        _, test_loss, test_accuracy = apply_model(state, test_ds["m1m2"], test_ds["label"])

        # Pass the values as arguments so formatting only happens if the
        # record is actually emitted.
        logging.info(
            "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f,"
            " test_accuracy: %.2f",
            epoch,
            train_loss,
            train_accuracy * 100,
            test_loss,
            test_accuracy * 100,
        )

        summary_writer.scalar("train_loss", train_loss, epoch)
        summary_writer.scalar("train_accuracy", train_accuracy, epoch)
        summary_writer.scalar("test_loss", test_loss, epoch)
        summary_writer.scalar("test_accuracy", test_accuracy, epoch)

    summary_writer.flush()
    return state