# IDP NF Code

This notebook contains the unstructured code of the IDP NF project.

Each protein molecule is represented by its dihedral angles $\theta$, where each dihedral angle is defined by four atoms $(i, j, k, l)$; these angles can be transformed into coordinates in a standard 3D coordinate system $(x_1, x_2, x_3)$.

The main structure is a normalizing flow consisting of $K$ transforms $f_k$, which looks like [base] - {[circular shift]-[coupling layer]}$_{k=1}^K$ - [loss (energy + log absolute determinant)].



## Circular Shift

This learned circular shift layer adds a learned offset to every angle coordinate and wraps the result so that all angle coordinates stay within $[0, 2\pi)$.

The formal formula is given by

$$\theta_{ni}\rightarrow(\theta_{ni}+c_i)\text{ mod }2\pi$$

Import necessary libraries

In [3]:
from typing import Union, Sequence, Callable, Mapping, Any, Tuple
import functools
import torch
from torch import Tensor
from nflows.transforms.base import (
    Transform, CompositeTransform
)
from nflows.transforms.linear import NaiveLinear

In [4]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

Define class `CircularShift`

In [None]:
class CircularShift(Transform):
    """Shift inputs by a fixed offset and wrap them into `[lower, upper)`.

    The forward pass computes `wrap(inputs + shift)` and the inverse pass
    computes `wrap(inputs - shift)`, where
    `wrap(x) = remainder(x - lower, upper - lower) + lower`.
    Both directions report a log-abs-det of zeros with the shape of `inputs`.
    """

    def __init__(self,
                 shift: Tensor,
                 lower: Union[float, Tensor],
                 upper: Union[float, Tensor]):
        """Initializes the shift layer.

        Args:
            shift: the offset added to the inputs; it is wrapped into the
                target interval once, at construction time.
            lower: lower end of the wrapping interval.
            upper: upper end of the wrapping interval.

        Raises:
            ValueError: if `lower` and `upper` are both non-tensors and
                `lower >= upper`, or if `upper - lower` raises a `TypeError`.
        """
        # The base-class initializer must run before any attribute is set
        # (nflows `Transform` derives from `torch.nn.Module`).
        super().__init__()

        # Only plain (non-tensor) bounds are validated here.
        if (not torch.is_tensor(lower)) and (not torch.is_tensor(upper)) and (lower >= upper):
            raise ValueError('`lower` must be less than `upper`.')

        try:
            width = upper - lower
        except TypeError as e:
            raise ValueError('`lower` and `upper` must be broadcastable to same '
                            f'shape, but `lower`={lower} and `upper`={upper}') from e

        self._lower = lower
        self._width = width
        self.shift = self.wrap(shift)

    def wrap(self, x):
        """Maps `x` into the interval starting at `lower` with the stored width."""
        return torch.remainder(x - self._lower, self._width) + self._lower

    def forward(self, inputs, context=None):
        """Returns `(wrap(inputs + shift), zeros_like(inputs))`; `context` is unused."""
        outputs = self.wrap(inputs + self.shift)
        logabsdet = torch.zeros_like(inputs)
        return outputs, logabsdet

    def inverse(self, inputs, context=None):
        """Returns `(wrap(inputs - shift), zeros_like(inputs))`; `context` is unused."""
        outputs = self.wrap(inputs - self.shift)
        logabsdet = torch.zeros_like(inputs)
        return outputs, logabsdet

# Coupling Flows
## Step 1: Rewrite embeddings.circular

In [None]:
def circular(x: Tensor,
             lower: float,
             upper: float,
             num_frequencies: int) -> Tensor:
    """Maps angles to points on the unit circle.

    The mapping is such that the interval [lower, upper] is mapped to a full
    circle starting and ending at (1, 0). For num_frequencies > 1, the mapping
    also includes higher frequencies which are multiples of 2 pi/(upper-lower)
    so that [lower, upper] wraps around the unit circle multiple times.

    Args:
        x: array of shape [..., D].
        lower: lower limit, angles equal to this will be mapped to (1, 0).
        upper: upper limit, angles equal to this will be mapped to (1, 0).
        num_frequencies: number of frequencies to consider in the embedding.

    Returns:
        An array of shape [..., 2*num_frequencies*D]: all cosines first,
        followed by all sines.
    """
    base_frequency = 2. * torch.pi / (upper - lower)
    # Build the frequency multipliers on the same device as `x` so that the
    # embedding also works for inputs that do not live on the CPU.
    frequencies = base_frequency * torch.arange(
        1, num_frequencies + 1, device=x.device)
    angles = frequencies * (x[..., None] - lower)
    # Reshape from [..., D, num_frequencies] to [..., D*num_frequencies].
    angles = angles.reshape(*x.shape[:-1], -1)
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)


## Step 2: Rewrite conditioner

In [None]:
def _reshape_last(x: Tensor, ndims: int, new_shape: Sequence[int]) -> Tensor:
    """Reshapes the last `ndims` dimensions of `x` to shape `new_shape`."""
    if ndims <= 0:
        raise ValueError(
            f'Number of dimensions to reshape must be positive, got {ndims}.')
    return torch.reshape(x, x.shape[:-ndims] + tuple(new_shape))


def make_equivariant_conditioner(
    num_bijector_params: int,
    lower: float,
    upper: float,
    embedding_size: int,
    conditioner_constructor: Callable[..., Any],
    conditioner_kwargs: Mapping[str, Any],
    num_frequencies: int,
) -> CompositeTransform:
    """Make a permutation-equivariant conditioner for the coupling flow.

    The returned composite chains, in order: the `circular` embedding of the
    input angles, a `NaiveLinear(embedding_size)` layer, the conditioner built
    by `conditioner_constructor(**conditioner_kwargs)`, a
    `NaiveLinear(num_bijector_params)` layer, and a reshape of the last axis
    to `(num_bijector_params,)`.

    Args:
        num_bijector_params: number of parameters produced per element.
        lower: lower limit of the angle range, forwarded to `circular`.
        upper: upper limit of the angle range, forwarded to `circular`.
        embedding_size: feature size given to the first `NaiveLinear`.
        conditioner_constructor: callable building the inner conditioner.
        conditioner_kwargs: keyword arguments for `conditioner_constructor`.
        num_frequencies: number of frequencies used by `circular`.

    Returns:
        A `CompositeTransform` wrapping the steps listed above.
    """
    # NOTE(review): `CompositeTransform` is given `functools.partial` objects
    # alongside transform instances here; confirm that it accepts plain
    # callables, and that `NaiveLinear` is the intended projection layer.
    conditioner = conditioner_constructor(**conditioner_kwargs)
    return CompositeTransform([
        functools.partial(
            circular, lower=lower, upper=upper,
            num_frequencies=num_frequencies),
        NaiveLinear(embedding_size),
        conditioner,
        NaiveLinear(num_bijector_params),
        # `new_shape` must be a sequence: `(n)` is just the int `n`, which
        # `_reshape_last` cannot convert with `tuple(...)`.
        functools.partial(
            _reshape_last, ndims=1, new_shape=(num_bijector_params,)),
    ])


## Step 3: Rewrite 
- distrax.SplitCoupling

In [None]:
class SplitCoupling(Transform):
  """Coupling bijector with an arbitrary conditioner and inner bijector.

  Adapted from `distrax.SplitCoupling`. In this version the conditioner is not
  applied to a slice of the input; instead it is applied to the fixed `angles`
  tensor given at construction time. Let `f` be the conditional inner bijector
  and `g` the conditioner:
  - Forward:
    ```
    y = f(x; g(angles))
    ```
  - Forward Jacobian log determinant:
    ```
    log|det J(x)| = log|det df/dx(x; g(angles))|
    ```
  - Inverse:
    ```
    x = f^{-1}(y; g(angles))
    ```
  - Inverse Jacobian log determinant:
    ```
    log|det J(y)| = log|det df^{-1}/dy(y; g(angles))|
    ```
  The log determinant is returned exactly as produced by the inner bijector.
  """

  def __init__(self,
               angles: Tensor,
               conditioner: Callable[..., Any],
               bijector: Callable[..., Any],):
    """Initializes a SplitCoupling bijector.
    Args:
      angles: the tensor fed to `conditioner` on every forward/inverse call.
      conditioner: a function that computes the parameters of the inner bijector
        from `angles`. The output of the conditioner will be passed to
        `bijector` in order to obtain the inner bijector.
      bijector: a callable that returns the inner bijector used to transform
        the input. The input to `bijector` is a set of parameters that can be
        used to configure the inner bijector; the returned object must expose
        `forward` and `inverse` methods returning `(output, logdet)`.
    """
    # The base-class initializer must run before attributes (in particular
    # sub-modules such as the conditioner) are assigned.
    super().__init__()
    self._angles = angles
    self._conditioner = conditioner
    self._bijector = bijector

  @property
  def bijector(self) -> Callable[..., Any]:
    """The callable that returns the inner bijector of `SplitCoupling`."""
    return self._bijector

  @property
  def conditioner(self) -> Callable[..., Any]:
    """The conditioner function."""
    return self._conditioner

  def forward(self, x: Tensor, context=None) -> Tuple[Tensor, Tensor]:
    """Computes y = f(x) and log|det J(f)(x)|; `context` is unused."""
    params = self._conditioner(self._angles)
    y, logdet = self._bijector(params).forward(x)
    return y, logdet

  def inverse(self, y: Tensor, context=None) -> Tuple[Tensor, Tensor]:
    """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|; `context` is unused."""
    params = self._conditioner(self._angles)
    x, logdet = self._bijector(params).inverse(y)
    return x, logdet

- distrax.RationalQuadraticSpline

In [83]:
# TODO
"""Rational-quadratic spline bijector."""
def _normalize_bin_sizes(unnormalized_bin_sizes: Tensor,
                         total_size: float,
                         min_bin_size: float) -> Tensor:
  """Make bin sizes sum to `total_size` and be no less than `min_bin_size`."""
  num_bins = unnormalized_bin_sizes.shape[-1]
  if num_bins * min_bin_size > total_size:
    raise ValueError(
        f'The number of bins ({num_bins}) times the minimum bin size'
        f' ({min_bin_size}) cannot be greater than the total bin size'
        f' ({total_size}).')
  bin_sizes = torch.softmax(unnormalized_bin_sizes, axis=-1)
  return bin_sizes * (total_size - num_bins * min_bin_size) + min_bin_size


def _normalize_knot_slopes(unnormalized_knot_slopes: Tensor,
                           min_knot_slope: float) -> Tensor:
  """Make knot slopes be no less than `min_knot_slope`."""
  # The offset is such that the normalized knot slope will be equal to 1
  # whenever the unnormalized knot slope is equal to 0.
  if min_knot_slope >= 1.:
    raise ValueError(f'The minimum knot slope must be less than 1; got'
                     f' {min_knot_slope}.')
  min_knot_slope = torch.tensor(
      min_knot_slope, dtype=unnormalized_knot_slopes.dtype)
  offset = torch.log(torch.exp(1. - min_knot_slope) - 1.)
  softplus = torch.nn.Softplus()
  return softplus(unnormalized_knot_slopes + offset) + min_knot_slope


def _rational_quadratic_spline_fwd(x: Tensor,
                                   x_pos: Tensor,
                                   y_pos: Tensor,
                                   knot_slopes: Tensor) -> Tuple[Tensor, Tensor]:
  """Applies a rational-quadratic spline to a scalar.
  Args:
    x: a scalar (0-dimensional array). The scalar `x` can be any real number; it
      will be transformed by the spline if it's in the closed interval
      `[x_pos[0], x_pos[-1]]`, and it will be transformed linearly if it's
      outside that interval.
    x_pos: array of shape [num_bins + 1], the bin boundaries on the x axis.
    y_pos: array of shape [num_bins + 1], the bin boundaries on the y axis.
    knot_slopes: array of shape [num_bins + 1], the slopes at the knot points.
  Returns:
    A tuple of two scalars: the output of the transformation and the log of the
    absolute first derivative at `x`.
  """
  # Search to find the right bin. NOTE: The bins are sorted, so we could use
  # binary search, but this is more GPU/TPU friendly.
  # The following implementation avoids indexing for faster TPU computation.
  below_range = x <= x_pos[0]
  above_range = x >= x_pos[-1]
  correct_bin = torch.logical_and(x[..., None] >= x_pos[:-1][None, None, ...], 
                                  x[..., None] < x_pos[1:][None, None, ...])
  any_bin_in_range = torch.any(correct_bin, dim=2)
  first_bin = torch.concat([torch.tensor([1], dtype=bool),
                            torch.zeros(correct_bin.shape[-1]-1, dtype=bool)])
  # If y does not fall into any bin, we use the first spline in the following
  # computations to avoid numerical issues.
  correct_bin[~any_bin_in_range] = first_bin
  # Dot product of each parameter with the correct bin mask.
  params = torch.stack([x_pos, y_pos, knot_slopes], axis=1)
  
  params_bin_left = torch.sum(correct_bin[..., None] * params[:-1], axis=2)
  params_bin_right = torch.sum(correct_bin[..., None] * params[1:], axis=2)

  x_pos_bin = (params_bin_left[..., 0], params_bin_right[..., 0])
  y_pos_bin = (params_bin_left[..., 1], params_bin_right[..., 1])
  knot_slopes_bin = (params_bin_left[..., 2], params_bin_right[..., 2])

  bin_width = x_pos_bin[1] - x_pos_bin[0]
  bin_height = y_pos_bin[1] - y_pos_bin[0]
  bin_slope = bin_height / bin_width

  z = (x - x_pos_bin[0]) / bin_width
  # `z` should be in range [0, 1] to avoid NaNs later. This can happen because
  # of small floating point issues or when x is outside of the range of bins.
  # To avoid all problems, we restrict z in [0, 1].
  z = torch.clip(z, 0., 1.)
  sq_z = z * z
  z1mz = z - sq_z  # z(1-z)
  sq_1mz = (1. - z) ** 2
  slopes_term = knot_slopes_bin[1] + knot_slopes_bin[0] - 2. * bin_slope
  numerator = bin_height * (bin_slope * sq_z + knot_slopes_bin[0] * z1mz)
  denominator = bin_slope + slopes_term * z1mz
  y = y_pos_bin[0] + numerator / denominator

  # Compute log det Jacobian.
  # The logdet is a sum of 3 logs. It is easy to see that the inputs of the
  # first two logs are guaranteed to be positive because we ensured that z is in
  # [0, 1]. This is also true of the log(denominator) because:
  # denominator
  # == bin_slope + (knot_slopes_bin[1] + knot_slopes_bin[0] - 2 * bin_slope) *
  # z*(1-z)
  # >= bin_slope - 2 * bin_slope * z * (1-z)
  # >= bin_slope - 2 * bin_slope * (1/4)
  # == bin_slope / 2
  logdet = 2. * torch.log(bin_slope) + torch.log(
      knot_slopes_bin[1] * sq_z + 2. * bin_slope * z1mz +
      knot_slopes_bin[0] * sq_1mz) - 2. * torch.log(denominator)

  # If x is outside the spline range, we default to a linear transformation.
  y = torch.where(below_range, (x - x_pos[0]) * knot_slopes[0] + y_pos[0], y)
  y = torch.where(above_range, (x - x_pos[-1]) * knot_slopes[-1] + y_pos[-1], y)
  logdet = torch.where(below_range, torch.log(knot_slopes[0]), logdet)
  logdet = torch.where(above_range, torch.log(knot_slopes[-1]), logdet)
  return y, logdet


def _safe_quadratic_root(a: Tensor, b: Tensor, c: Tensor) -> Tensor:
  """Implement a numerically stable version of the quadratic formula."""
  # This is not a general solution to the quadratic equation, as it assumes
  # b ** 2 - 4. * a * c is known a priori to be positive (and which of the two
  # roots is to be used, see https://arxiv.org/abs/1906.04032).
  # There are two sources of instability:
  # (a) When b ** 2 - 4. * a * c -> 0, sqrt gives NaNs in gradient.
  # We clip sqrt_diff to have the smallest float number.
  sqrt_diff = b ** 2 - 4. * a * c
  safe_sqrt = torch.sqrt(torch.clip(sqrt_diff, torch.finfo(sqrt_diff.dtype).tiny))
  # If sqrt_diff is non-positive, we set sqrt to 0. as it should be positive.
  safe_sqrt = torch.where(sqrt_diff > 0., safe_sqrt, 0.)
  # (b) When 4. * a * c -> 0. We use the more stable quadratic solution
  # depending on the sign of b.
  # See https://people.csail.mit.edu/bkph/articles/Quadratics.pdf (eq 7 and 8).
  # Solution when b >= 0
  numerator_1 = 2. * c
  denominator_1 = -b - safe_sqrt
  # Solution when b < 0
  numerator_2 = - b + safe_sqrt
  denominator_2 = 2 * a
  # Choose the numerically stable solution.
  numerator = torch.where(b >= 0, numerator_1, numerator_2)
  denominator = torch.where(b >= 0, denominator_1, denominator_2)
  return numerator / denominator


def _rational_quadratic_spline_inv(y: Tensor,
                                   x_pos: Tensor,
                                   y_pos: Tensor,
                                   knot_slopes: Tensor) -> Tuple[Tensor, Tensor]:
  """Applies the inverse of a rational-quadratic spline elementwise.
  Args:
    y: array of inputs of any shape. Each element can be any real number; it
      will be transformed by the spline if it's in the closed interval
      `[y_pos[..., 0], y_pos[..., -1]]`, and it will be transformed linearly if
      it's outside that interval.
    x_pos: array of shape [..., num_bins + 1], the bin boundaries on the x axis.
      Leading dimensions, if any, are broadcast against `y`.
    y_pos: array of shape [..., num_bins + 1], the bin boundaries on the y axis.
    knot_slopes: array of shape [..., num_bins + 1], the slopes at the knot
      points.
  Returns:
    A tuple of two arrays: the output of the inverse transformation and the log
    of the absolute first derivative of the inverse at each element of `y`.
  """
  # Search to find the right bin. NOTE: The bins are sorted, so we could use
  # binary search, but a dense mask avoids data-dependent indexing.
  y_min, y_max = y_pos[..., 0], y_pos[..., -1]
  below_range = y <= y_min
  above_range = y >= y_max
  # `correct_bin` has shape [..., num_bins]; the knot axis is always the last
  # one, so the same code handles shared (1-D) and per-element knot arrays.
  correct_bin = torch.logical_and(y[..., None] >= y_pos[..., :-1],
                                  y[..., None] < y_pos[..., 1:])
  # Reduce over the bin axis only: each element needs its own "found a bin"
  # flag, otherwise out-of-range elements keep an empty mask and produce a
  # zero bin width (0/0) further down.
  any_bin_in_range = torch.any(correct_bin, dim=-1)
  # If y does not fall into any bin, we use the first spline in the following
  # computations to avoid numerical issues. Rows without a match are all
  # False, so switching on their first entry selects exactly the first bin.
  correct_bin[..., 0] = correct_bin[..., 0] | ~any_bin_in_range
  # Dot product of each parameter with the correct bin mask.
  params = torch.stack([x_pos, y_pos, knot_slopes], dim=-1)
  bin_mask = correct_bin[..., None]
  params_bin_left = torch.sum(bin_mask * params[..., :-1, :], dim=-2)
  params_bin_right = torch.sum(bin_mask * params[..., 1:, :], dim=-2)

  # These are the parameters for the corresponding bin.
  x_pos_bin = (params_bin_left[..., 0], params_bin_right[..., 0])
  y_pos_bin = (params_bin_left[..., 1], params_bin_right[..., 1])
  knot_slopes_bin = (params_bin_left[..., 2], params_bin_right[..., 2])

  bin_width = x_pos_bin[1] - x_pos_bin[0]
  bin_height = y_pos_bin[1] - y_pos_bin[0]
  bin_slope = bin_height / bin_width
  w = (y - y_pos_bin[0]) / bin_height
  w = torch.clip(w, 0., 1.)  # Ensure w is in [0, 1].
  # Compute quadratic coefficients: az^2 + bz + c = 0
  slopes_term = knot_slopes_bin[1] + knot_slopes_bin[0] - 2. * bin_slope
  c = - bin_slope * w
  b = knot_slopes_bin[0] - slopes_term * w
  a = bin_slope - b

  # Solve quadratic to obtain z and then x.
  z = _safe_quadratic_root(a, b, c)
  z = torch.clip(z, 0., 1.)  # Ensure z is in [0, 1].
  x = bin_width * z + x_pos_bin[0]

  # Compute log det Jacobian (the negative of the forward log det at x).
  sq_z = z * z
  z1mz = z - sq_z  # z(1-z)
  sq_1mz = (1. - z) ** 2
  denominator = bin_slope + slopes_term * z1mz
  logdet = - 2. * torch.log(bin_slope) - torch.log(
      knot_slopes_bin[1] * sq_z + 2. * bin_slope * z1mz +
      knot_slopes_bin[0] * sq_1mz) + 2. * torch.log(denominator)

  # If y is outside the spline range, we default to a linear transformation
  # whose slope is the knot slope at the nearest boundary.
  slope_lo, slope_hi = knot_slopes[..., 0], knot_slopes[..., -1]
  x = torch.where(below_range, (y - y_min) / slope_lo + x_pos[..., 0], x)
  x = torch.where(above_range, (y - y_max) / slope_hi + x_pos[..., -1], x)
  logdet = torch.where(below_range, - torch.log(slope_lo), logdet)
  logdet = torch.where(above_range, - torch.log(slope_hi), logdet)
  return x, logdet


In [None]:
class RationalQuadraticSpline(Transform):
  """A rational-quadratic spline bijector.
  Implements the spline bijector introduced by:
  > Durkan et al., Neural Spline Flows, https://arxiv.org/abs/1906.04032, 2019.
  This bijector is a monotonically increasing spline operating on an interval
  [a, b], such that f(a) = a and f(b) = b. Outside the interval [a, b], the
  bijector defaults to a linear transformation whose slope matches that of the
  spline at the nearest boundary (either a or b). The range boundaries a and b
  are hyperparameters passed to the constructor.
  The spline on the interval [a, b] consists of `num_bins` segments, on each of
  which the spline takes the form of a rational quadratic (ratio of two
  quadratic polynomials). The first derivative of the bijector is guaranteed to
  be continuous on the whole real line. The second derivative is generally not
  continuous at the knot points (bin boundaries).
  The spline is parameterized by the bin sizes on the x and y axis, and by the
  slopes at the knot points. All spline parameters are passed to the constructor
  as an unconstrained array `params` of shape `[..., 3 * num_bins + 1]`. The
  spline parameters are extracted from `params`, and are reparameterized
  internally as appropriate. The number of bins is a hyperparameter, and is
  implicitly defined by the last dimension of `params`.
  This bijector is applied elementwise. Given some input `x`, the parameters
  `params` and the input `x` are intended to be broadcast against each other.
  For example, suppose `x` is of shape `[N, D]`. Then:
  - If `params` is of shape `[3 * num_bins + 1]`, the same spline is identically
    applied to each element of `x`.
  - If `params` is of shape `[D, 3 * num_bins + 1]`, the same spline is applied
    along the first axis of `x` but a different spline is applied along the
    second axis of `x`.
  - If `params` is of shape `[N, D, 3 * num_bins + 1]`, a different spline is
    applied to each element of `x`.
  - If `params` is of shape `[M, N, D, 3 * num_bins + 1]`, `M` different splines
    are applied to each element of `x`, and the output is of shape `[M, N, D]`.
  NOTE(review): whether batched `params` actually broadcast depends on the
  module-level spline helpers called by `forward`/`inverse` - confirm there.
  """

  def __init__(self,
               params: Tensor,
               range_min: float,
               range_max: float,
               boundary_slopes: str = 'unconstrained',
               min_bin_size: float = 1e-4,
               min_knot_slope: float = 1e-4):
    """Initializes a RationalQuadraticSpline bijector.
    Args:
      params: array of shape `[..., 3 * num_bins + 1]`, the unconstrained
        parameters of the bijector. The number of bins is implicitly defined by
        the last dimension of `params`. The parameters can take arbitrary
        unconstrained values; the bijector will reparameterize them internally
        and make sure they obey appropriate constraints. If `params` is the
        all-zeros array, the bijector becomes the identity function everywhere
        on the real line.
      range_min: the lower bound of the spline's range. Below `range_min`, the
        bijector defaults to a linear transformation.
      range_max: the upper bound of the spline's range. Above `range_max`, the
        bijector defaults to a linear transformation.
      boundary_slopes: controls the behaviour of the slope of the spline at the
        range boundaries (`range_min` and `range_max`). It is used to enforce
        certain boundary conditions on the spline. Available options are:
        - 'unconstrained': no boundary conditions are imposed; the slopes at the
          boundaries can vary freely.
        - 'lower_identity': the slope of the spline is set equal to 1 at the
          lower boundary (`range_min`). This makes the bijector equal to the
          identity function for values less than `range_min`.
        - 'upper_identity': similar to `lower_identity`, but now the slope of
          the spline is set equal to 1 at the upper boundary (`range_max`). This
          makes the bijector equal to the identity function for values greater
          than `range_max`.
        - 'identity': combines the effects of 'lower_identity' and
          'upper_identity' together. The slope of the spline is set equal to 1
          at both boundaries (`range_min` and `range_max`). This makes the
          bijector equal to the identity function outside the interval
          `[range_min, range_max]`.
        - 'circular': makes the slope at `range_min` and `range_max` be the
          same. This implements the "circular spline" introduced by:
          > Rezende et al., Normalizing Flows on Tori and Spheres,
          > https://arxiv.org/abs/2002.02428, 2020.
          This option should be used when the spline operates on a circle
          parameterized by an angle in the interval `[range_min, range_max]`,
          where `range_min` and `range_max` correspond to the same point on the
          circle.
      min_bin_size: The minimum bin size, in either the x or the y axis. Should
        be a small positive number, chosen for numerical stability. Guarantees
        that no bin in either the x or the y axis will be less than this value.
      min_knot_slope: The minimum slope at each knot point. Should be a small
        positive number, chosen for numerical stability. Guarantees that no knot
        will have a slope less than this value.
    Raises:
      ValueError: if the last dimension of `params` is not `3 * num_bins + 1`
        with `num_bins >= 1`, if `range_min >= range_max`, if `min_bin_size` or
        `min_knot_slope` is not positive, or if `boundary_slopes` is unknown.
    """
    super().__init__()
    if params.shape[-1] % 3 != 1 or params.shape[-1] < 4:
      raise ValueError(f'The last dimension of `params` must have size'
                       f' `3 * num_bins + 1` and `num_bins` must be at least 1.'
                       f' Got size {params.shape[-1]}.')
    if range_min >= range_max:
      raise ValueError(f'`range_min` must be less than `range_max`. Got'
                       f' `range_min={range_min}` and `range_max={range_max}`.')
    if min_bin_size <= 0.:
      raise ValueError(f'The minimum bin size must be positive; got'
                       f' {min_bin_size}.')
    if min_knot_slope <= 0.:
      raise ValueError(f'The minimum knot slope must be positive; got'
                       f' {min_knot_slope}.')
    self._dtype = params.dtype
    self._num_bins = (params.shape[-1] - 1) // 3
    # Extract unnormalized parameters: `num_bins` widths, `num_bins` heights
    # and `num_bins + 1` knot slopes, in that order.
    unnormalized_bin_widths = params[..., :self._num_bins]
    unnormalized_bin_heights = params[..., self._num_bins : 2 * self._num_bins]
    unnormalized_knot_slopes = params[..., 2 * self._num_bins:]
    # Normalize bin sizes and compute bin positions on the x and y axis.
    range_size = range_max - range_min
    bin_widths = _normalize_bin_sizes(unnormalized_bin_widths, range_size,
                                      min_bin_size)
    bin_heights = _normalize_bin_sizes(unnormalized_bin_heights, range_size,
                                       min_bin_size)
    x_pos = range_min + torch.cumsum(bin_widths[..., :-1], dim=-1)
    y_pos = range_min + torch.cumsum(bin_heights[..., :-1], dim=-1)
    # The outermost knots are pinned to the range boundaries. The padding is
    # created on the device of `params` so concatenation does not mix devices.
    pad_shape = tuple(params.shape[:-1]) + (1,)
    pad_below = torch.full(pad_shape, range_min, dtype=self._dtype,
                           device=params.device)
    pad_above = torch.full(pad_shape, range_max, dtype=self._dtype,
                           device=params.device)
    self._x_pos = torch.concat([pad_below, x_pos, pad_above], dim=-1)
    self._y_pos = torch.concat([pad_below, y_pos, pad_above], dim=-1)
    # Normalize knot slopes and enforce requested boundary conditions.
    knot_slopes = _normalize_knot_slopes(unnormalized_knot_slopes,
                                         min_knot_slope)
    if boundary_slopes == 'unconstrained':
      self._knot_slopes = knot_slopes
    elif boundary_slopes in ('lower_identity', 'upper_identity', 'identity'):
      # `dtype` must be passed by keyword: the positional arguments of
      # `torch.ones` are all interpreted as sizes.
      ones = torch.ones(pad_shape, dtype=self._dtype, device=params.device)
      if boundary_slopes == 'lower_identity':
        self._knot_slopes = torch.concat([ones, knot_slopes[..., 1:]], dim=-1)
      elif boundary_slopes == 'upper_identity':
        self._knot_slopes = torch.concat(
            [knot_slopes[..., :-1], ones], dim=-1)
      else:
        self._knot_slopes = torch.concat(
            [ones, knot_slopes[..., 1:-1], ones], dim=-1)
    elif boundary_slopes == 'circular':
      # Reuse the first slope as the last one so both range ends agree.
      self._knot_slopes = torch.concat(
          [knot_slopes[..., :-1], knot_slopes[..., :1]], dim=-1)
    else:
      raise ValueError(f'Unknown option for boundary slopes:'
                       f' `{boundary_slopes}`.')

  @property
  def num_bins(self) -> int:
    """The number of segments on the interval."""
    return self._num_bins

  @property
  def knot_slopes(self) -> Tensor:
    """The slopes at the knot points."""
    return self._knot_slopes

  @property
  def x_pos(self) -> Tensor:
    """The bin boundaries on the `x`-axis."""
    return self._x_pos

  @property
  def y_pos(self) -> Tensor:
    """The bin boundaries on the `y`-axis."""
    return self._y_pos

  def forward(self, x: Tensor, context=None) -> Tuple[Tensor, Tensor]:
    """Computes y = f(x) and log|det J(f)(x)|; `context` is unused."""
    y, logdet = _rational_quadratic_spline_fwd(
      x, self._x_pos, self._y_pos, self._knot_slopes)
    return y, logdet

  def inverse(self, y: Tensor, context=None) -> Tuple[Tensor, Tensor]:
    """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|; `context` is unused."""
    x, logdet = _rational_quadratic_spline_inv(
        y, self._x_pos, self._y_pos, self._knot_slopes)
    return x, logdet


## Step 4: Rewrite coupling layers as follows
[$x_1$] - [conditioner ($x_2$) ($C$)] - [rational quadratic spline ($G$ paramed by $C$)]

[$x_2$]

In [None]:
def make_split_coupling_flow(
    angles: Tensor,
    lower: float,
    upper: float,
    num_layers: int,
    num_bins: int,
    conditioner: Mapping[str, Any],
    use_circular_shift: bool,
    circular_shift_init: Callable[..., Tensor] = torch.zeros,
) -> CompositeTransform:
  """Create a flow that consists of a sequence of split coupling layers.

  All coupling layers use rational-quadratic splines with circular boundary
  slopes. Each layer of the flow is an optional circular shift followed by one
  `SplitCoupling` bijector.

  The flow maps to and from the range `[lower, upper]`, obeying periodic
  boundary conditions.

  Args:
    angles: the Tensor handed to every `SplitCoupling` layer as the input of
      its conditioner.
    lower: lower range of the flow.
    upper: upper range of the flow.
    num_layers: the number of layers to use.
    num_bins: number of bins to use in the rational-quadratic splines.
    conditioner: a Mapping containing 'constructor' and 'kwargs' keys that
      configures the conditioner used in the coupling layers.
    use_circular_shift: if True, add a learned circular shift in front of the
      coupling bijector of every layer.
    circular_shift_init: initializer for the circular shifts; called with the
      shape `1` and expected to return a Tensor (e.g. `torch.zeros`).

  Returns:
    The flow, a `CompositeTransform` of per-layer `CompositeTransform`s.
  """

  def bijector_fn(params: Tensor):
    # The minimum bin size scales with the width of the range.
    return RationalQuadraticSpline(
        params,
        range_min=lower,
        range_max=upper,
        boundary_slopes='circular',
        min_bin_size=(upper - lower) * 1e-4)

  layers = []
  for _ in range(num_layers):
    sublayers = []

    # Circular shift.
    if use_circular_shift:
      # `torch.nn.Parameter` wraps an existing tensor, so the initializer is
      # called first to create a one-element shift.
      # NOTE(review): this parameter is not registered on any module here, so
      # an optimizer will only see it if it is collected explicitly.
      shift = torch.nn.Parameter(circular_shift_init(1))
      shift_layer = CircularShift(
          (upper - lower) * shift, lower, upper)
      sublayers.append(shift_layer)

    # Coupling layer.
    coupling_layer = SplitCoupling(
        angles=angles,
        bijector=bijector_fn,
        conditioner=conditioner['constructor'](
            num_bijector_params=3 * num_bins + 1,
            lower=lower,
            upper=upper,
            **conditioner['kwargs'])
        )
    sublayers.append(coupling_layer)
    layers.append(CompositeTransform(sublayers))

  return CompositeTransform(layers)


## Step 5: Rewrite rdkit setDihedralRad

In [5]:
from rdkit.Chem.rdchem import Mol
from torch import Tensor
from torch.nn import Module
from rdkit import Chem

class Dihedral2Coord(Module):
    """Transform dihedral angles of a batch of conformers into 3D coordinates.

    The layer re-implements RDKit's ``SetDihedralRad`` with torch operations,
    without in-place writes on tensors that autograd needs, so the returned
    coordinates can be differentiated with respect to the dihedral values.
    """

    def __init__(self, mol: "Mol", angles: Tensor):
        """
        Initialization of D2C layer.

        Args:
            mol (Mol): N molecular conformation with the same backbone and possibly different dihedral angles.
            angles (Tensor): a Tensor of shape (K, 4) where K is the number of dihedral angles for a conformer, 4 is (iAtomId, jAtomId, kAtomId, lAtomId).
        """
        # Module state must exist before the layer can be called or nested.
        super().__init__()
        self.mol = mol
        self.angles = angles
        # Maps (jAtomId, kAtomId) -> sorted indices of the atoms that move
        # when the torsion around the j-k bond is changed.
        self.alist = {}
        self.toBeMovedIdxList()


    def toBeMovedIdxList(self):
        """
        An implementation of toBeMovedIdxList from rdkit.
        See https://github.com/rdkit/rdkit/blob/master/Code/GraphMol/MolTransforms/MolTransforms.cpp#L426

        For every torsion (i, j, k, l) in ``self.angles`` this fills
        ``self.alist[(j, k)]`` with all atoms reachable from k without
        passing through j (k itself included, j excluded).
        """
        nAtoms = self.mol.GetNumAtoms()
        for quad in torch.as_tensor(self.angles).tolist():
            axisBegin, axisEnd = int(quad[1]), int(quad[2])
            key = (axisBegin, axisEnd)
            if key in self.alist:
                continue
            visited = [False] * nAtoms
            # Marking the axis start as visited keeps the walk on the k side.
            visited[axisBegin] = True
            visited[axisEnd] = True
            stack = [axisEnd]
            while stack:
                atom = self.mol.GetAtomWithIdx(stack.pop())
                for neighbor in atom.GetNeighbors():
                    nbrIdx = neighbor.GetIdx()
                    if not visited[nbrIdx]:
                        visited[nbrIdx] = True
                        stack.append(nbrIdx)
            self.alist[key] = [idx for idx in range(nAtoms)
                               if visited[idx] and idx != axisBegin]


    @staticmethod
    def _rotationMatrix(angle: Tensor, axis: Tensor) -> Tensor:
        """Build the (N, 3, 3) axis-angle rotation matrices used by rdkit's SetRotation.

        Args:
            angle (Tensor): a Tensor of shape (N) with the rotation angles in radians.
            axis (Tensor): a Tensor of shape (N, 3) with unit-length rotation axes.
        """
        cosT = angle.cos()
        sinT = angle.sin()
        t = 1 - cosT
        X, Y, Z = axis.unbind(dim=1)
        rows = [
            torch.stack([t * X * X + cosT, t * X * Y - sinT * Z, t * X * Z + sinT * Y], dim=1),
            torch.stack([t * X * Y + sinT * Z, t * Y * Y + cosT, t * Y * Z - sinT * X], dim=1),
            torch.stack([t * X * Z - sinT * Y, t * Y * Z + sinT * X, t * Z * Z + cosT], dim=1),
        ]
        return torch.stack(rows, dim=1)


    def transformPoint(self, pt: Tensor, angle: Tensor, axis: Tensor):
        """
        An implementation of differentiable SetRotation and TransformPoint from rdkit.
        See https://github.com/rdkit/rdkit/blob/master/Code/Geometry/Transform3D.cpp

        The rotated coordinates are written back into ``pt`` (as rdkit's
        TransformPoint does) and ``pt`` is also returned.

        Args:
            pt (Tensor): a Tensor of shape (N, 3) where N is the batch size, 3 is 3D coordinates.
            angle (Tensor): a Tensor of shape (N) where N is the batch size, 1 is the rotation angle.
            axis (Tensor): a Tensor of shape (N, 3) where N is the batch size, 3 is 3D coordinates of the axis.

        Returns:
            pt (Tensor): the same tensor object, now holding the rotated points.
        """
        rot = self._rotationMatrix(angle, axis)
        # Rotate a copy: autograd keeps the operand of the product, so it must
        # not be the tensor that is overwritten below.
        rotated = torch.bmm(rot, pt.clone().unsqueeze(-1)).squeeze(-1)
        pt.copy_(rotated)
        return pt


    def setDihedralRad(self, input: Tensor, angle: Tensor,
                       pos: Union[Tensor, None] = None) -> Tensor:
        """
        An implementation of differentiable setDihedralRad from rdkit.
        Note: This version has eliminated all fault checks temporarily. Add them if needed from the link below.
        In particular, collinear atoms give a zero-length normal and therefore NaN values.
        See https://github.com/rdkit/rdkit/blob/master/Code/GraphMol/MolTransforms/MolTransforms.cpp#L612

        Args:
            input (Tensor): a Tensor of shape (N) where N is the batch size, holding the target dihedral value in radians.
            angle (Tensor): a Tensor of shape (4) where 4 is (iAtomId, jAtomId, kAtomId, lAtomId).
            pos (Tensor, optional): a Tensor of shape (N, M, 3) with the coordinates to start from.
                It is not modified. When omitted, the coordinates are read from the conformers of ``self.mol``.

        Returns:
            output (Tensor): a Tensor of shape (N, M, 3) where N is the batch size, M is the number of atoms, 3 is the 3D coordinates (x, y, z).
        """
        iAtomId, jAtomId, kAtomId, lAtomId = (int(a) for a in angle)
        if pos is None:
            pos = self._conformerPositions(input)
        values = input.reshape(-1).to(dtype=pos.dtype, device=pos.device)

        rIJ = pos[:, jAtomId] - pos[:, iAtomId]
        rJK = pos[:, kAtomId] - pos[:, jAtomId]
        rKL = pos[:, lAtomId] - pos[:, kAtomId]
        nIJK = torch.cross(rIJ, rJK, dim=1)
        nJKL = torch.cross(rJK, rKL, dim=1)
        m = torch.cross(nIJK, rJK, dim=1)
        sinTerm = (m * nJKL).sum(dim=1) / (m.norm(dim=1) * nJKL.norm(dim=1))
        cosTerm = (nIJK * nJKL).sum(dim=1) / (nIJK.norm(dim=1) * nJKL.norm(dim=1))
        # The current dihedral is -atan2(sinTerm, cosTerm); only the difference
        # between the target and the current value has to be rotated.
        delta = values + torch.atan2(sinTerm, cosTerm)

        # The rotation axis is the (j, k) bond, normalised to unit length.
        rotAxisBegin = pos[:, jAtomId].unsqueeze(1)
        rotAxis = pos[:, kAtomId] - pos[:, jAtomId]
        rotAxis = rotAxis / rotAxis.norm(dim=1, keepdim=True)
        rot = self._rotationMatrix(delta, rotAxis)

        moved = self.alist[(jAtomId, kAtomId)]
        # Translate to the axis origin, rotate, translate back.
        local = pos[:, moved] - rotAxisBegin
        rotated = torch.einsum('nij,nmj->nmi', rot, local) + rotAxisBegin
        output = pos.clone()
        output[:, moved] = rotated
        return output


    def _conformerPositions(self, like: Tensor) -> Tensor:
        """Stack the conformer coordinates of ``self.mol`` into an (N, M, 3) Tensor on the device of ``like``."""
        dtype = like.dtype if like.is_floating_point() else torch.float32
        coords = [torch.as_tensor(conf.GetPositions(), dtype=dtype, device=like.device)
                  for conf in self.mol.GetConformers()]
        return torch.stack(coords)


    def forward(self, input: Tensor) -> Tensor:
        """
        Apply every dihedral value to the conformers and return the resulting coordinates.
        TODO: This version has eliminated all fault checks temporarily. Add them if needed from the link below.
        See https://github.com/rdkit/rdkit/blob/master/Code/GraphMol/MolTransforms/MolTransforms.cpp#L612

        The torsions are applied one after another to the same running
        coordinates. Afterwards the conformers stored on ``self.mol`` are
        updated to the same dihedral values through RDKit, so that code
        reading the molecule sees the new geometry.

        Args:
            input (Tensor): a Tensor of shape (N, K) where N is the batch size, K is the number of dihedral angles for a conformer; values are in radians.

        Returns:
            output (Tensor): a Tensor of shape (N, M, 3) where N is the batch size, M is the number of atoms, 3 is the 3D coordinates (x, y, z).

        Raises:
            ValueError: if N differs from the number of conformers of ``self.mol``.
        """
        from rdkit.Chem import rdMolTransforms

        N, K = input.shape
        numConfs = self.mol.GetNumConformers()
        if N != numConfs:
            raise ValueError(f'Batch size {N} does not match the number of '
                             f'conformers {numConfs}.')

        quads = torch.as_tensor(self.angles).tolist()
        output = self._conformerPositions(input)
        for k in range(K):
            output = self.setDihedralRad(input[:, k], quads[k], output)

        # Mirror the new torsions onto the RDKit conformers. The values are
        # radians, hence SetDihedralRad rather than SetDihedralDeg.
        values = input.detach().cpu().tolist()
        confs = self.mol.GetConformers()
        for n in range(N):
            for k in range(K):
                iAtomId, jAtomId, kAtomId, lAtomId = quads[k]
                rdMolTransforms.SetDihedralRad(confs[n], iAtomId, jAtomId,
                                               kAtomId, lAtomId, values[n][k])
        return output

## Step 6: Write energy layer

In [None]:
import rdkit.Chem.AllChem as Chem2


class Energy(Module):
    """Energy loss with forward and backward pass.

    The energy is the MMFF force-field energy of the conformers stored on
    ``mol``, averaged over conformers. It is computed by RDKit, so the value
    returned by ``forward`` is a plain tensor that is not connected to the
    autograd graph; ``backward`` exposes the force-field gradients instead.
    NOTE(review): ``backward`` is an ordinary method here; nothing in this
    class registers it with autograd.
    """

    def __init__(self, mol: Mol):
        """
        Args:
            mol (Mol): molecule whose conformers are evaluated.
        """
        super().__init__()
        self.mol = mol
        # Force fields built by the most recent forward pass, one per conformer.
        self.ff_list = []

    def forward(self, input: Tensor):
        """Forward pass of energy. Only mol matters here.

        Args:
            input (Tensor): ignored; the conformers of ``self.mol`` are used.

        Returns:
            A scalar float Tensor with the mean MMFF energy over all conformers.
        """
        Chem2.MMFFSanitizeMolecule(self.mol)
        mmff_props = Chem2.MMFFGetMoleculeProperties(self.mol)
        num_confs = self.mol.GetNumConformers()
        # Rebuild the list on every call so that `backward` only ever sees the
        # force fields of the latest forward pass.
        self.ff_list = [
            Chem2.MMFFGetMoleculeForceField(self.mol, mmff_props, confId=i)
            for i in range(num_confs)
        ]
        # A float accumulator is required: adding a float in place to an
        # integer tensor raises.
        energy = torch.tensor(0.0)
        for ff in self.ff_list:
            energy += ff.CalcEnergy()
        return energy / num_confs


    def backward(self, input: Tensor):
        """Backward pass of energy. Only ff_list from forward matters here.

        Args:
            input (Tensor): ignored.

        Returns:
            A Tensor of shape (N, 1, M, 3) holding, for each of the N force
            fields, the gradient reported by ``CalcGrad`` reshaped to 3 columns.
            NOTE(review): the singleton second axis comes from stacking
            (1, M, 3) tensors; confirm callers expect it.

        Raises:
            RuntimeError: if ``forward`` has not produced any force field yet.
        """
        if not self.ff_list:
            raise RuntimeError('`forward` must be called before `backward`.')
        grad_list = [torch.tensor(ff.CalcGrad()).reshape(1, -1, 3)
                     for ff in self.ff_list]
        return torch.stack(grad_list)


## Step 7: Write base layer

In [None]:
from nflows.distributions import StandardNormal
from rdkit.Chem import AllChem as Chem2
from rdkit.Chem import TorsionFingerprints
import numpy as np
import random

def generate_branched_alkane(num_atoms: int) -> Chem.Mol:
    """Grow a random branched alkane starting from butane.

    Carbons are attached one at a time to a randomly chosen existing carbon.
    Carbons that already have more than two neighbours are never chosen, and
    carbons with exactly two neighbours are accepted only half of the time.

    Parameters
    ----------
    num_atoms : int
        Number of carbon atoms the grown skeleton should reach. The starting
        butane already has four, so smaller values leave it unchanged.

    Returns
    -------
    The sanitized molecule with explicit hydrogens added.
    """
    builder = Chem.RWMol(Chem.MolFromSmiles('CCCC'))
    while builder.GetNumAtoms() < num_atoms:
        anchor_idx = np.random.randint(builder.GetNumAtoms())
        degree = builder.GetAtomWithIdx(anchor_idx).GetDegree()
        # The random draw happens only for degree-2 anchors.
        accepted = degree < 2 or (degree == 2 and random.random() > 0.5)
        if not accepted:
            continue
        new_idx = builder.AddAtom(Chem.rdchem.Atom(6))
        builder.AddBond(new_idx, anchor_idx, Chem.rdchem.BondType.SINGLE)

    Chem.SanitizeMol(builder)
    return Chem.rdmolops.AddHs(builder.GetMol())


def get_torsion_tuples(mol):
    """Collect the atom-index quadruples of the non-ring torsions of `mol`.

    Every atom of `mol` is tagged with an ``original_index`` property as a
    side effect, which is used to translate indices to the hydrogen-stripped
    copy of the molecule.

    Parameters
    ----------
    mol : RDKit molecule
        Molecule for which torsion angles are to be extracted

    Returns
    -------
    tuples_original, tuples_reindexed : list[list[int]]
        Quadruples of atom indices, one per torsion. The first list uses the
        indices of `mol`; the second uses the indices of the copy of `mol`
        with hydrogens removed.
    """
    # Tag atoms first so the index mapping survives hydrogen removal.
    for atom_idx in range(mol.GetNumAtoms()):
        mol.GetAtomWithIdx(atom_idx).SetProp("original_index", str(atom_idx))
    heavy_mol = Chem2.rdmolops.RemoveHs(mol)

    nonring, _ = TorsionFingerprints.CalculateTorsionLists(mol)
    nonring_original = []
    for quads, _reference_angle in nonring:
        nonring_original.append(list(quads[0]))

    stripped_index = {}
    for new_idx in range(heavy_mol.GetNumAtoms()):
        tagged = heavy_mol.GetAtomWithIdx(new_idx).GetProp("original_index")
        stripped_index[int(tagged)] = new_idx

    nonring_reindexed = []
    for quad in nonring_original:
        nonring_reindexed.append([stripped_index[old_idx] for old_idx in quad])

    return nonring_original, nonring_reindexed


class Base(StandardNormal):
    """Conformer initialization.

    Builds a random branched alkane, embeds and MMFF-optimizes a batch of
    conformers for it, and acts as a standard-normal base distribution with
    one dimension per torsion angle.
    """

    def __init__(self, num_atoms: int, batch_size: int):
        """Initialization of base distribution.

        Args:
            num_atoms (int): an integer of number of atoms.
            batch_size (int): an integer of batch size; one conformer is
                requested per batch element.
        """
        mol = generate_branched_alkane(num_atoms)
        Chem.AllChem.EmbedMultipleConfs(mol, numConfs=batch_size)
        Chem.rdForceFieldHelpers.MMFFOptimizeMoleculeConfs(mol, nonBondedThresh=10., )
        torsion_angles, _ = get_torsion_tuples(mol)
        # (K, 4) table of atom indices; the reshape keeps that layout even
        # when the molecule has no torsions.
        torsion_angles = torch.tensor(torsion_angles, dtype=torch.long).reshape(-1, 4)
        # The distribution is over one value per torsion, not over the
        # (K, 4) index table.
        super().__init__(shape=[torsion_angles.shape[0]])
        self.mol = mol
        self.torsion_angles = torsion_angles


## Step 8: Combine flow model

In [None]:
from nflows.flows import Flow
from torch.nn import Sequential
def make_model(num_atoms: int,
               lower: float,
               upper: float,
               bijector: Mapping[str, Any],
               base: Mapping[str, Any], 
               coord_trans: Mapping[str, Any],
               energy_layer: Mapping[str, Any]
               ) -> Tuple[Flow, Sequential]:
  """Constructs the flow over torsion angles and its energy function.

  The model is assembled as follows:
  1. The base distribution is created; it owns the molecule (`mol`) and the
     table of torsion atom indices (`torsion_angles`).
  2. The bijector is created for those torsions and combined with the base
     distribution into a flow.
  3. The coordinate transform and the energy layer are created for the same
     molecule and chained into the energy function.

  Args:
    num_atoms: number of atoms, forwarded to the base constructor as
      `num_atoms` (it overrides an entry of the same name in the base kwargs).
    lower: lower range of the flow.
    upper: upper range of the flow.
    bijector: configures the bijector that transforms the torsion angles.
      Expected to have the following keys:
      * 'constructor': a callable that creates the bijector; it is called with
        `angles`, `lower`, `upper` and the kwargs below.
      * 'kwargs': keyword arguments to pass to the constructor.
    base: configures the base distribution. Expected to have the following keys:
      * 'constructor': a callable that creates the base distribution.
      * 'kwargs': keyword arguments to pass to the constructor.
    coord_trans: configures the layer mapping torsion values to coordinates.
      Its 'constructor' is called with `mol` and `angles`.
    energy_layer: configures the energy layer. Its 'constructor' is called
      with `mol`.

  Returns:
    A tuple `(model, energy_fn)`: the flow, and the coordinate transform
    followed by the energy layer.
  """
  # Merge instead of passing `num_atoms` next to **kwargs: the config may
  # already carry the same key, which would be a duplicate keyword argument.
  base_kwargs = dict(base['kwargs'])
  base_kwargs['num_atoms'] = num_atoms
  base_model = base['constructor'](**base_kwargs)
  bij = bijector['constructor'](
      angles=base_model.torsion_angles,
      lower=lower,
      upper=upper,
      **bijector['kwargs'])

  model = Flow(bij, base_model)
  
  trans = coord_trans['constructor'](
      mol=base_model.mol,
      angles=base_model.torsion_angles)
  energy = energy_layer['constructor'](
      mol=base_model.mol)
  
  energy_fn = Sequential(trans, energy)
  
  return model, energy_fn

## Step 9: Write training/testing code

In [None]:
from ml_collections import config_dict
from torch.nn import Transformer
# Number of frequencies handed to the conditioner, keyed by number of atoms.
FREQUENCIES = {atom_count: 8 for atom_count in (4, 8, 16, 32)}

def get_config(num_atoms: int):
  """Builds the experiment configuration for a system of `num_atoms` atoms.

  Args:
    num_atoms: number of atoms; must be a key of `FREQUENCIES`.

  Returns:
    A `ConfigDict` with the sections `state`, `model`, `train` and `test`.
  """
  num_frequencies = FREQUENCIES[num_atoms]
  train_batch_size = 128

  conditioner = {
      'constructor': make_equivariant_conditioner,
      'kwargs': {
          'embedding_size': 256,
          'num_frequencies': num_frequencies,
          'conditioner_constructor': Transformer,
          'conditioner_kwargs': {
              'nhead': 2,
              'num_encoder_layers': 2,
              'num_decoder_layers': 2,
              'dropout': 0.,
          },
      },
  }
  bijector = {
      'constructor': make_split_coupling_flow,
      'kwargs': {
          'num_layers': 24,
          'num_bins': 16,
          'conditioner': conditioner,
          'use_circular_shift': True,
      },
  }
  base = {
      'constructor': Base,
      'kwargs': {
          'num_atoms': num_atoms,
          'batch_size': train_batch_size,
      },
  }

  config = config_dict.ConfigDict()
  config.state = {
      'num_atoms': num_atoms,
      'beta': 0.5,
      'lower': -torch.pi,
      'upper': torch.pi,
  }
  config.model = {
      'constructor': make_model,
      'kwargs': {
          'bijector': bijector,
          'base': base,
          'coord_trans': {'constructor': Dihedral2Coord},
          'energy_layer': {'constructor': Energy},
      },
  }
  config.train = {
      'batch_size': train_batch_size,
      'learning_rate': 7e-5,
      'learning_rate_decay_steps': [250000, 500000],
      'learning_rate_decay_factor': 0.1,
      'seed': 42,
      'max_gradient_norm': 10000.,
  }
  config.test = {
      'test_every': 500,
      'batch_size': 2048,
  }
  return config


In [None]:
#!/usr/bin/python
#
# Copyright 2022 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Energy-based training of a flow model on an atomistic system."""

from typing import Callable, Dict, Tuple, Union

from absl import app
from absl import flags

# The allowed systems mirror the atom counts that FREQUENCIES has an entry
# for; the default has to be one of the allowed values.
flags.DEFINE_enum('system', '_16',
                  ['_4', '_8', '_16', '_32'
                  ], 'System and number of atoms to train.')
flags.DEFINE_integer('num_iterations', int(10**6), 'Number of training steps.')


FLAGS = flags.FLAGS


def _num_particles(system: str) -> int:
  return int(system.split('_')[-1])


def _get_loss(
    model: Flow,
    energy_fn: Callable,
    beta: Tensor,
    num_samples: int) -> Tuple[Tensor, Dict[str, Tensor]]:
  """Draws samples from the model and evaluates the energy-based loss.

  The loss is the mean of `beta * energy + log_prob` over the drawn samples.

  Args:
    model: the flow; `sample_and_log_prob` is called on it.
    energy_fn: callable mapping the samples to their energies.
    beta: factor multiplying the energies.
    num_samples: number of samples to draw.

  Returns:
    The loss, and a dict with the entries 'energy', 'model_log_prob' and
    'target_log_prob' (the latter being `-beta * energy`).
  """
  drawn, model_log_prob = model.sample_and_log_prob(num_samples=num_samples)
  energies = energy_fn(drawn)
  loss = (beta * energies + model_log_prob).mean()
  return loss, {
      'energy': energies,
      'model_log_prob': model_log_prob,
      'target_log_prob': -beta * energies,
  }


def main(_):
  """Trains the flow for the system selected by the `system` flag.

  NOTE(review): the optimisation loop below is still written against
  JAX/Haiku/Optax (`jax`, `hk`, `optax`, `optimizer`) and `obs_utils`, none of
  which is imported or defined in the visible part of this file, while the
  model itself is a PyTorch module. The loop has to be ported before this
  function can run.
  """
  system = FLAGS.system
  config = get_config(_num_particles(system))

  state = config.state
  

  def create_model():
    """Builds the `(flow, energy_fn)` pair described by the config."""
    # `config.state` stores the atom count as `num_atoms`, and that is also
    # the keyword `make_model` accepts.
    return config.model['constructor'](
        num_atoms=state.num_atoms,
        lower=state.lower,
        upper=state.upper,
        **config.model['kwargs'])

  def loss_fn():
    """Loss function for training."""
    model, energy_fn = create_model()

    loss, stats = _get_loss(
        model=model,
        energy_fn=energy_fn,
        beta=state.beta,
        num_samples=config.train.batch_size,
        )

    metrics = {
        'loss': loss,
        'energy': torch.mean(stats['energy']),
        'model_entropy': -torch.mean(stats['model_log_prob']),
    }
    return loss, metrics

  def eval_fn():
    """Evaluation function."""
    model, energy_fn = create_model()
    loss, stats = _get_loss(
        model=model,
        energy_fn=energy_fn,
        beta=state.beta,
        num_samples=config.test.batch_size,
        )
    log_probs = {
        'model_log_probs': stats['model_log_prob'],
        'target_log_probs': stats['target_log_prob'],
    }
    metrics = {
        'loss': loss,
        'energy': torch.mean(stats['energy']),
        'model_entropy': -torch.mean(stats['model_log_prob']),
        'ess': obs_utils.compute_ess(**log_probs),
        'logz': obs_utils.compute_logz(**log_probs),
        'logz_per_particle':
            obs_utils.compute_logz(**log_probs) / state.num_atoms,
    }
    return metrics

  print(f'Initialising system {system}')
  # NOTE(review): everything from here on is the unported JAX training loop.
  rng_key = jax.random.PRNGKey(config.train.seed)
  init_fn, apply_fn = hk.transform(loss_fn)
  _, apply_eval_fn = hk.transform(eval_fn)

  rng_key, init_key = jax.random.split(rng_key)
  params = init_fn(init_key)
  opt_state = optimizer.init(params)

  def _loss(params, rng):
    loss, metrics = apply_fn(params, rng)
    return loss, metrics
  jitted_loss = jax.jit(jax.value_and_grad(_loss, has_aux=True))
  jitted_eval = jax.jit(apply_eval_fn)

  step = 0
  print('Beginning of training.')
  while step < FLAGS.num_iterations:
    # Training update.
    rng_key, loss_key = jax.random.split(rng_key)
    (_, metrics), g = jitted_loss(params, loss_key)
    if (step % 50) == 0:
      print(f'Train[{step}]: {metrics}')
    updates, opt_state = optimizer.update(g, opt_state, params)
    params = optax.apply_updates(params, updates)

    if (step % config.test.test_every) == 0:
      rng_key, val_key = jax.random.split(rng_key)
      metrics = jitted_eval(params, val_key)
      print(f'Valid[{step}]: {metrics}')

    step += 1

  print('Done')


# Script entry point: hand `main` to absl's `app.run`.
if __name__ == '__main__':
  app.run(main)
