This Colab is an attempt to implement SEDD (Score Entropy Discrete Diffusion) from the paper "Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution".

The paper is available on arXiv: https://arxiv.org/pdf/2310.16834

In [None]:
from jax import numpy as jnp
# pylint: disable=g-multiple-import, g-importing-member
from jaxtyping import Array, Float, Int32
from typing import Protocol

import jax
import numpy as np

# Implement key algorithmic pieces.

## Diffusion intensity
Here we implement objects that govern how the diffusion intensity changes over time.

In [None]:
# Basic type aliases shared by the annotations in this notebook.
#
# A float array whose leading axis is the batch axis (any trailing axes), or a
# bare jaxtyping `Float`.
# NOTE(review): a bare jaxtyping `Float` is a dtype marker, not Python's
# `float`; if plain scalar times such as 0.5 are meant to be covered, `float`
# may be the intended union member -- confirm.
FloatNdArrayOrSclar = (
    Float[Array, 'B ...']
    | Float
)
# An int32 array whose leading axis is the batch axis (any trailing axes).
IntNdArray = Int32[
    Array, 'B ...'
]


class DiffusionIntenistyInterface(Protocol):
  """Structural interface for how the diffusion intensity evolves over time.

  Unlike the paper, time is always assumed to start at 0 and end at 1.

  Forward process: an example x_0 drawn from the data distribution p_data is
  gradually corrupted by the discrete diffusion until it becomes complete
  noise.

  Backward process: the noisy sample is gradually denoised to recover the
  original example.
  """

  def sigma_t(self, time: FloatNdArrayOrSclar) -> FloatNdArrayOrSclar:
    """Returns sigma(t) from the paper.

    This is the instantaneous rate of change of the diffusion probabilities
    at `time`.
    """
    ...

  def cum_sigma_t(self, time: FloatNdArrayOrSclar) -> FloatNdArrayOrSclar:
    """Returns bar{sigma}(t) from the paper.

    This is sigma_t accumulated (integrated) from time 0 up to `time`.
    """
    ...

In [None]:
class LinearDiffusionDensity(DiffusionIntenistyInterface):
  """Constant-rate schedule: sigma(t) = strength, bar{sigma}(t) = strength * t.

  Only `cum_sigma_t` clamps its input to [0, 1]; `sigma_t` returns `strength`
  for every input time, including times outside [0, 1].
  """

  def __init__(self, strength=1.0):
    # The constant diffusion rate.
    self._strength = strength

  def sigma_t(self, time: FloatNdArrayOrSclar) -> FloatNdArrayOrSclar:
    # Same shape as `time`, every entry equal to the strength.
    return self._strength * jnp.ones_like(time)

  def cum_sigma_t(self, time: FloatNdArrayOrSclar) -> FloatNdArrayOrSclar:
    # Times outside [0, 1] are clamped to the nearest end point first.
    clamped_time = jnp.clip(time, 0.0, 1.0)
    return self._strength * clamped_time

### unit-tests for diffusion intensity objects
Add a bunch of tests. TODO(yonghui): convert them to unit-tests

In [None]:
# Test the implementation of LinearDiffusionDensity (default strength 1.0).
# The expected outputs are asserted rather than only printed.
x = LinearDiffusionDensity()
time = np.array([-1.0, 0.2, 0.7, 1.2])
# sigma_t does not clamp time, so the expected output is [1, 1, 1, 1].
print(x.sigma_t(time))
np.testing.assert_allclose(
    x.sigma_t(time), [1.0, 1.0, 1.0, 1.0], rtol=1e-6, atol=1e-6
)
# cum_sigma_t clamps time to [0, 1], so the expected output is
# [0, 0.2, 0.7, 1.0].
print(x.cum_sigma_t(time))
np.testing.assert_allclose(
    x.cum_sigma_t(time), [0.0, 0.2, 0.7, 1.0], rtol=1e-6, atol=1e-6
)

# Python scalars are accepted as well.
time = 0.5
print(x.sigma_t(time))
print(x.cum_sigma_t(time))
np.testing.assert_allclose(x.sigma_t(time), 1.0, rtol=1e-6, atol=1e-6)
np.testing.assert_allclose(x.cum_sigma_t(time), 0.5, rtol=1e-6, atol=1e-6)

## Diffusion Matrices

They govern the diffusion process, i.e. how probability mass flows from one node to others.

In [None]:
# The transition (rate) matrix is a key concept of the discrete diffusion
# process: it governs how probability mass moves between vocab entries.
class TransitionMatrix(Protocol):
  r"""Protocol defining the transition matrix Q.

  A transition matrix, denoted as Q in the paper, is of size
  [vocab_size, vocab_size], where vocab_size is the size of the vocab.

  Column j of Q (see `q_column`) describes the flow of probability mass out
  of node j. For i != j: given X_t = j, the probability of
  X_{t+\delta t} = i equals sigma(t) * Q(i, j) * \delta t, where sigma(t) is
  the diffusion intensity. Every column of Q sums up to 0.

  Subclasses provide `q_column`, `q_row`, `batch_q_row`, `exp_q`,
  `batch_exp_q`, `batch_exp_q_row` and `adjust_prob_ratio`; `q` and
  `q_posterior` are built on top of them.
  """

  def __init__(self, rank: int, diff_density: DiffusionIntenistyInterface):
    # The transition matrix is of size rank * rank (rank is the vocab size).
    self._rank = rank
    # 'diff_density' governs how diffusion density changes over time.
    self._diff_density = diff_density

  def q(self) -> Float[Array, 'N N']:
    """Returns the dense [rank, rank] Q matrix.

    This materializes all rank**2 entries, so it is only useful for testing
    purposes.
    """
    # Column i of Q is q_column(i).
    return jnp.stack([self.q_column(i) for i in range(self._rank)], axis=1)

  def q_column(self, col: int) -> Float[Array, 'N']:
    """Returns 'col'-th column vector of the Q matrix, of dim 'self._rank'.

    Entry i of the column is Q(i, col), the rate at which mass flows from node
    `col` into node i.

    Args:
      col: the column to return.

    Returns:
      A vector of dim 'self._rank'.

    The returned vector should sum up to 0.
    """
    ...

  def q_row(self, row: int) -> Float[Array, 'N']:
    """Returns 'row'-th row vector of the Q matrix, of dim 'self._rank'.

    Entry j of the row is Q(row, j), the rate at which mass flows from node j
    into node `row`.

    Args:
      row: the row to return.

    Returns:
      A vector of dim 'self._rank'.
    """
    ...

  def batch_q_row(self, rows: Int32[Array, 'B']) -> Float[Array, 'B N']:
    """The batched version of q_row for more efficient training / sampling.

    Args:
      rows: a [B] vector of row indices.

    Returns:
      A [B, N] matrix whose b-th row is q_row(rows[b]).
    """
    ...

  def q_posterior(
      self, cols: Int32[Array, 'B'], prob_ratio: Float[Array, 'B N']
  ) -> Float[Array, 'B N']:
    r"""Returns the posterior (reverse-time) transition matrix.

    q_posterior(i, j) = q(j, i) * p(i) / p(j)

    This is easy to see:
    p(i | j) = p(j | i) * p(i) / p(j)

    In the context of discrete diffusion process:
    p(x_{t - \delta t} | x_t) = p(x_t | x_{t - \delta t}) *
                                p(x_{t - \delta t}) / p(x_t)
    where p(x_t | x_{t - \delta t}) is the forward transition matrix.

    Args:
      cols: a [B] vector of indices of the node x_t. It is also the column
        indices into the posterior matrix.
      prob_ratio: [B, N] prob ratios of all other nodes wrt node x_t.

    Returns:
      A [B, N] matrix of transition rates from node x_t to x_{t - \delta t}.
      For y != x_t, entry (b, y) is Q(x_t, y) * prob_ratio[b, y]; entry
      (b, x_t) is minus the sum of the other entries, so each row sums to 0.

    Please refer to equation 3 in the paper for details.
    """
    # of shape [B, N]
    post_p = self.batch_q_row(cols) * prob_ratio
    # masks out x_t themselves.
    mask = 1.0 - jax.nn.one_hot(cols, self._rank)
    post_p = post_p * mask
    # of shape [B, 1]
    post_p_sum = jnp.sum(post_p, axis=1, keepdims=True)
    # Put -sum(off-diagonal) on the x_t entry so every row sums to 0.
    return post_p - post_p_sum * (1 - mask)

  def exp_q(self, node: int, time: float) -> Float[Array, 'N']:
    """Returns the cumulative transition prob from 'node' to all other nodes.

    Args:
      node: the index (into the vocab) of the current node. This is the state of
        the variable at time 0.
      time: the time at which we want to compute the cumulative transition prob
        for.

    Returns:
      A float array of the transition probs, of size [N], where N is the size of
      the vocab. Given X_0 == node (the state of the random variable at time 0),
      this is the marginal distribution of X_t at time 'time'. The returned
      vector should sum up to 1.0.
    """
    ...

  def batch_exp_q(
      self, nodes: Int32[Array, 'B'], time: Float[Array, 'B']
  ) -> Float[Array, 'B N']:
    """A batched version of exp_q above.

    Args:
      nodes: a vector of [B] of the current nodes
      time: a vector of [B] of the times for which to compute the cumulative
        transition probs for. 'time' can be different for different nodes.

    Returns:
      A matrix of size [B, N], where B is the number of nodes, and N is the
      vocab size. Row b is exp_q(nodes[b], time[b]).
    """
    ...

  def batch_exp_q_row(
      self, nodes: Int32[Array, 'B'], time: Float[Array, 'B']
  ) -> Float[Array, 'B N']:
    """Returns selected rows of the cumulative transition matrix; useful for sampling.

    Conceptually, what this function does is the following

    rows = []
    for node, t in zip(nodes, time):
      # the cumulative transition matrix from time 0 to time t
      exp_q = expm(cum_sigma_t(t) * self.q())
      rows.append(exp_q[node])
    return jnp.stack(rows)

    Args:
      nodes: a vector of [B] of the current nodes
      time: a vector of [B] of the times for which to compute the cumulative
        transition probs for. 'time' can be different for different nodes.

    Returns:
      A matrix of size [B, N], where B is the number of nodes, and N is the
      vocab size. See above for the semantics of the matrix.
    """
    ...

  def adjust_prob_ratio(
      self, prob_ratio: Float[Array, 'B N'], delta_ts: Float[Array, 'B']
  ) -> Float[Array, 'B N']:
    """Maps prob ratios at time 'ts' back to time 'ts - delta_ts'.

    Mathematically, here is what we compute:

    given:
      p_t = P_{transition} p_0

    we would like to compute:
      p_0 = P_{transition}^(-1) p_t

    where p_0 and p_t could be probabilities or probability ratios.

    NOTE(review): the subclasses in this notebook use
    P_{transition} = expm(cum_sigma_t(delta_ts) * Q), which equals the true
    transition from ts - delta_ts to ts only for a constant-rate schedule
    (e.g. LinearDiffusionDensity) -- confirm before using other schedules.

    Args:
      prob_ratio: the probability ratio at time t
      delta_ts: we would like to estimate the probability ratio at time t -
        delta_ts

    Returns:
      Probability or prob ratio at time t - delta_ts
    """
    ...

In [None]:
class QAbsorb(TransitionMatrix):
  """Absorbing-state transition matrix.

  Every node leaks probability mass into a single sink node, which is always
  the last entry of the vocab (index rank - 1); the sink itself never leaks.
  Diffusing any distribution long enough therefore concentrates all of its
  mass on the sink node.
  """

  def q_column(self, col: int) -> Float[Array, 'N']:
    """Returns the 'col'-th column vector of the Q matrix, of dim 'self._rank'.

    Args:
      col: the column to return.

    Returns:
      A vector of dim 'self._rank' that sums up to 0: -1 at `col` and +1 at
      the sink. For the sink column the two cancel, giving all zeros.
    """
    # probability mass flows out of the current node.
    cur_node = -1.0 * jax.nn.one_hot(col, self._rank)
    # probability mass flows into the sink node. The sink node is always at
    # the end (the last entry in the vocab).
    sink_node = 1.0 * jax.nn.one_hot(self._rank - 1, self._rank)
    return cur_node + sink_node

  def q_row(self, row: int) -> Float[Array, 'N']:
    """Returns 'row'-th row vector of the Q matrix, of dim 'self._rank'.

    A non-sink row only has -1 on its own diagonal entry. The sink row is 1
    everywhere (every node flows into the sink) except 0 on its own diagonal.

    Args:
      row: the row to return.

    Returns:
      A vector of dim 'self._rank'.
    """
    row_vec_normal = -1.0 * jax.nn.one_hot(row, self._rank)
    row_vec_sink = 1.0 + row_vec_normal
    return jnp.where(row < self._rank - 1, row_vec_normal, row_vec_sink)

  def batch_q_row(self, rows: Int32[Array, 'B']) -> Float[Array, 'B N']:
    """The batched version of q_row for more efficient training / sampling.

    Args:
      rows: a [B] vector of row indices.

    Returns:
      A [B, N] matrix whose b-th row is q_row(rows[b]).
    """
    assert rows.ndim == 1
    # of shape [B, N]
    row_vec_normal = -1.0 * jax.nn.one_hot(rows, self._rank)
    row_vec_sink = 1.0 + row_vec_normal
    # Now choose between row_vec_normal and row_vec_sink based on the node_id.
    return jnp.where(
        rows[:, jnp.newaxis] < self._rank - 1, row_vec_normal, row_vec_sink
    )

  def exp_q(self, node: int, time: float) -> Float[Array, 'N']:
    """Returns the cumulative transition probability from 'node' to others.

    Args:
      node: the index (into the vocab) of the current node.
      time: the time at which we want to compute the cumulative transition prob
        for.

    Returns:
      A float array of the transition probs, of size [N], where N is the size of
      the vocab. Given X_0 == node (the state of the random variable at time 0),
      this is the marginal distribution of X_t at time 'time': exp(-cum_sigma)
      stays at `node` and the rest sits on the sink. The returned vector sums
      up to 1.0.
    """
    # Prob mass that still remains at the current node.
    node_prob = jnp.exp(-1.0 * self._diff_density.cum_sigma_t(time))
    sink_prob = 1.0 - node_prob
    # This is the probability mass that still remain at the current node.
    cur_node = node_prob * jax.nn.one_hot(node, self._rank)
    # This is the probability mass that is flowing into the sink node.
    sink_node = sink_prob * jax.nn.one_hot(self._rank - 1, self._rank)
    return cur_node + sink_node

  def batch_exp_q(
      self, nodes: Int32[Array, 'B'], time: Float[Array, 'B']
  ) -> Float[Array, 'B N']:
    """A batched version of exp_q above.

    Args:
      nodes: a vector of [B] of the current nodes
      time: a vector of [B] of the times for which to return the cumulative
        transition probs for. 'time' can be different for different nodes.

    Returns:
      A matrix of size [B, N], where B is the number of nodes, and N is the
      vocab size. Row b equals exp_q(nodes[b], time[b]).
    """
    assert nodes.ndim == 1
    assert time.ndim == 1
    assert nodes.shape == time.shape

    # Prob mass that still remains at the current node, of shape [B, 1].
    node_prob = jnp.exp(-1.0 * self._diff_density.cum_sigma_t(time))
    node_prob = node_prob[:, jnp.newaxis]
    # Prob mass that has flowed into the sink node, of shape [B, 1].
    sink_prob = 1.0 - node_prob
    # [B, N]: the mass still at each starting node.
    cur_node = node_prob * jax.nn.one_hot(nodes, self._rank)
    # [B, N]: the mass on the sink; the [N] one-hot broadcasts over the batch.
    sink_node = sink_prob * jax.nn.one_hot(self._rank - 1, self._rank)
    return cur_node + sink_node

  def batch_exp_q_row(
      self, nodes: Int32[Array, 'B'], time: Float[Array, 'B']
  ) -> Float[Array, 'B N']:
    """Returns rows `nodes` of the cumulative transition matrix exp(tQ).

    Conceptually:

    rows = []
    for node, t in zip(nodes, time):
      # the cumulative transition matrix from time 0 to time t
      exp_q = expm(cum_sigma_t(t) * self.q())
      rows.append(exp_q[node])
    return jnp.stack(rows)

    Args:
      nodes: a vector of [B] of the current nodes
      time: a vector of [B] of the times for which to compute the cumulative
        transition probs for. 'time' can be different for different nodes.

    Returns:
      A matrix of size [B, N], where B is the number of nodes, and N is the
      vocab size. See above for the semantics of the matrix.
    """
    assert nodes.ndim == 1
    assert time.ndim == 1
    assert nodes.shape == time.shape

    # Prob mass that still remains at the current node, of shape [B, 1].
    node_prob = jnp.exp(-1.0 * self._diff_density.cum_sigma_t(time))
    node_prob = node_prob[:, jnp.newaxis]
    # of shape [B, 1].
    sink_prob = 1.0 - node_prob
    # A non-sink row r of exp(tQ) is node_prob on its diagonal, 0 elsewhere.
    row_vec_normal = node_prob * jax.nn.one_hot(nodes, self._rank)
    # The sink row of exp(tQ) is sink_prob for every other node and
    # sink_prob + node_prob (== 1) on the sink itself.
    sink_one_hot = jax.nn.one_hot(self._rank - 1, self._rank)
    row_vec_sink = sink_prob + row_vec_normal * sink_one_hot

    return jnp.where(
        nodes[:, jnp.newaxis] < self._rank - 1, row_vec_normal, row_vec_sink
    )

  def adjust_prob_ratio(
      self, prob_ratio: Float[Array, 'B N'], delta_ts: Float[Array, 'B']
  ) -> Float[Array, 'B N']:
    """Adjust prob ratio for QAbsorb by applying exp(-tQ).

    With a = exp(-cum_sigma_t(delta_ts)), the forward matrix exp(tQ) has a on
    the diagonal of every non-sink row, and its sink row is (1 - a) for every
    non-sink node and 1 on the sink. Its inverse exp(-tQ) therefore has 1 / a
    on the diagonal of every non-sink row, and its sink row is (1 - 1 / a) for
    every non-sink node and 1 on the sink. The original derivation notes are at
    https://chatgpt.com/share/67928ded-f5cc-800f-a944-caa36da2e830

    A naive p_0 = exp(-tQ) p_t costs O(N^2), which is prohibitively high for
    large N (e.g. 256k for large vocabs). Exploiting the sparsity of exp(-tQ)
    brings this down to O(N).
    """
    # 1 / a: the inverse of the forward diagonal entry, of shape [B, 1].
    one_over_q_x = jnp.exp(self._diff_density.cum_sigma_t(delta_ts))[
        :, jnp.newaxis
    ]
    rank = prob_ratio.shape[1]
    # [N] boolean indicator of the sink node (the last vocab entry).
    is_sink = jax.nn.one_hot(rank - 1, rank) > 0
    # Non-sink rows of exp(-tQ) are diagonal, of shape [B, N].
    p_0_normal_node = prob_ratio * one_over_q_x
    # The sink row of exp(-tQ): (1 - 1 / a) off the sink, exactly 1 on it.
    # of shape [B, N]
    exp_q_inverse_last_row = jnp.where(is_sink, 1.0, 1.0 - one_over_q_x)
    # of shape [B]
    p_0_sink_node = jnp.sum(exp_q_inverse_last_row * prob_ratio, axis=1)
    # Sink entry from the sink row, everything else from the diagonal part.
    # of shape [B, N]
    return jnp.where(is_sink, p_0_sink_node[:, jnp.newaxis], p_0_normal_node)

In [None]:
class QUniform(TransitionMatrix):
  """A transition matrix that diffuses into all nodes uniformly.

  Q = J / N - I, where J is the all-ones matrix and N = rank. Q is symmetric
  (Q == Q^T), so its rows and columns coincide.
  """

  def q_column(self, col: int) -> Float[Array, 'N']:
    """Returns 'col'-th column vector of the Q matrix, of dim 'self._rank'.

    Args:
      col: the column to return.

    Returns:
      A vector of dim 'self._rank'.

    The returned vector sums up to 0.
    """
    # Probability mass flows out of the current node at uniform speed.
    cur_node = -1.0 * jax.nn.one_hot(col, self._rank)
    # Probability mass flows into every node at equal rate 1 / N.
    other_nodes = jnp.ones([self._rank]) / self._rank
    return cur_node + other_nodes

  def q_row(self, row: int) -> Float[Array, 'N']:
    """Returns 'row'-th row vector of the Q matrix, of dim 'self._rank'.

    Args:
      row: the row to return.

    Returns:
      A vector of dim 'self._rank'.
    """
    # Q is symmetric, so row `row` equals column `row`.
    return self.q_column(row)

  def batch_q_row(self, rows: Int32[Array, 'B']) -> Float[Array, 'B N']:
    """The batched version of q_row for more efficient training / sampling.

    Args:
      rows: a [B] vector of row indices.

    Returns:
      A [B, N] matrix whose b-th row is q_row(rows[b]).
    """
    assert rows.ndim == 1
    # Of shape [B, N]
    cur_node = -1.0 * jax.nn.one_hot(rows, self._rank)
    # Probability mass flows into every node at rate 1 / N; this [N] vector
    # broadcasts over the batch.
    other_nodes = jnp.ones([self._rank]) / self._rank
    return cur_node + other_nodes

  def exp_q(self, node: int, time: float) -> Float[Array, 'N']:
    """Returns the cumulative transition probability from 'node' to others.

    Args:
      node: the index (into the vocab) of the current node.
      time: the time at which we want to compute the cumulative transition prob
        for.

    Returns:
      A float array of the transition probs, of size [N], where N is the size of
      the vocab. Given X_0 == node (the state of the random variable at time 0),
      this is the marginal distribution of X_t at time 'time':
      a * one_hot(node) + (1 - a) / N with a = exp(-cum_sigma_t(time)). The
      returned vector sums up to 1.0.
    """
    # prob mass still remain at the current node
    node_prob = jnp.exp(-1.0 * self._diff_density.cum_sigma_t(time))
    other_prob = (1.0 - node_prob) / self._rank
    cur_node = node_prob * jax.nn.one_hot(node, self._rank)
    other_nodes = other_prob * jnp.ones([self._rank])
    return cur_node + other_nodes

  def batch_exp_q(
      self, nodes: Int32[Array, 'B'], time: Float[Array, 'B']
  ) -> Float[Array, 'B N']:
    """A batched version of exp_q above.

    Args:
      nodes: a vector of [B] of the current nodes
      time: a vector of [B] of the times for which to return the cumulative
        transition probs for. 'time' can be different for different nodes.

    Returns:
      A matrix of size [B, N], where B is the number of nodes, and N is the
      vocab size. Row b equals exp_q(nodes[b], time[b]).
    """
    assert nodes.ndim == 1
    assert time.ndim == 1
    assert nodes.shape == time.shape
    # prob mass still remain at the current node
    node_prob = jnp.exp(-1.0 * self._diff_density.cum_sigma_t(time))
    other_prob = (1.0 - node_prob) / self._rank
    # Both of shape [B, 1].
    node_prob = node_prob[:, jnp.newaxis]
    other_prob = other_prob[:, jnp.newaxis]
    cur_node = node_prob * jax.nn.one_hot(nodes, self._rank)
    # [B, 1] * [N] broadcasts to [B, N].
    other_nodes = other_prob * jnp.ones([self._rank])
    return cur_node + other_nodes

  def batch_exp_q_row(
      self, nodes: Int32[Array, 'B'], time: Float[Array, 'B']
  ) -> Float[Array, 'B N']:
    """Returns rows `nodes` of the cumulative transition matrix exp(tQ).

    Conceptually:

    rows = []
    for node, t in zip(nodes, time):
      # the cumulative transition matrix from time 0 to time t
      exp_q = expm(cum_sigma_t(t) * self.q())
      rows.append(exp_q[node])
    return jnp.stack(rows)

    Args:
      nodes: a vector of [B] of the current nodes
      time: a vector of [B] of the times for which to compute the cumulative
        transition probs for. 'time' can be different for different nodes.

    Returns:
      A matrix of size [B, N], where B is the number of nodes, and N is the
      vocab size. See above for the semantics of the matrix.
    """
    assert nodes.ndim == 1
    assert time.ndim == 1
    assert nodes.shape == time.shape
    # Q is symmetric, hence so is exp(tQ): its rows equal its columns.
    return self.batch_exp_q(nodes, time)

  def adjust_prob_ratio(
      self, prob_ratio: Float[Array, 'B N'], delta_ts: Float[Array, 'B']
  ) -> Float[Array, 'B N']:
    """Adjust prob ratio for QUniform by applying exp(-tQ).

    With a = exp(-cum_sigma_t(delta_ts)), the forward matrix is
    exp(tQ) = a * I + (1 - a) * J / N, whose inverse is
    exp(-tQ) = (1 / a) * I - (1 - a) / (a * N) * J. The original derivation
    notes are at https://chatgpt.com/share/67933f48-51ec-800f-9088-bf7dfd3aa723

    A naive p_0 = exp(-tQ) p_t costs O(N^2), which is prohibitively high for
    large N (e.g. 256k for large vocabs). Using the I / J structure of
    exp(-tQ) brings this down to O(N).
    """
    rank = prob_ratio.shape[1]
    # a = (1 - x), of shape [B, 1].
    one_minus_x = jnp.exp(-1.0 * self._diff_density.cum_sigma_t(delta_ts))
    one_minus_x = one_minus_x[:, jnp.newaxis]
    x = 1.0 - one_minus_x
    # Coefficient of J in exp(-tQ), of shape [B, 1].
    j_coeff = (-1.0 * x) / (rank * one_minus_x)
    # Coefficient of I in exp(-tQ), of shape [B, 1].
    i_coeff = 1.0 / one_minus_x
    # I * p_t, of shape [B, N].
    i_pt = i_coeff * prob_ratio
    # J * p_t: the same value for every entry of a row, of shape [B, 1].
    j_pt = jnp.sum(j_coeff * prob_ratio, axis=1, keepdims=True)
    return i_pt + j_pt

### unit-tests for diffusion matrices
Add a bunch of unit-tests to make sure the diffusion matrices are correctly implemented.

In [None]:
# Test the implementation of diffusion matrices: print a few columns and
# distributions, then check the documented invariants (every column of Q sums
# to 0, every exp_q distribution sums to 1).
diff_density = LinearDiffusionDensity()
rank = 8
batch_nodes = jnp.array([1, 2, rank - 1])
batch_times = jnp.array([0.5, 0.75, 1.0])

q_absorb = QAbsorb(rank, diff_density)
print(q_absorb.q_column(0))
print(q_absorb.q_column(2))
print(q_absorb.q_column(rank - 1))
print(q_absorb.exp_q(1, time=0.5))
print(q_absorb.batch_exp_q(batch_nodes, time=batch_times))

q_uniform = QUniform(rank, diff_density)
print(q_uniform.q_column(0))
print(q_uniform.q_column(2))
print(q_uniform.q_column(rank - 1))
print(q_uniform.exp_q(1, time=0.5))
print(q_uniform.batch_exp_q(batch_nodes, time=batch_times))

for matrix in (q_absorb, q_uniform):
  for col in (0, 2, rank - 1):
    np.testing.assert_allclose(jnp.sum(matrix.q_column(col)), 0.0, atol=1e-6)
  np.testing.assert_allclose(
      jnp.sum(matrix.exp_q(1, time=0.5)), 1.0, atol=1e-6
  )
  np.testing.assert_allclose(
      jnp.sum(matrix.batch_exp_q(batch_nodes, time=batch_times), axis=1),
      1.0,
      atol=1e-6,
  )

In [None]:
# TODO(yonghui): Convert this to a unit-test.
# TODO(yonghui): Add more unit-test for the corner cases.
#
# Consistency check between q_column and exp_q() for QAbsorb: exp_q() is the
# closed-form solution of dp = Q p d(cum_sigma), so Euler-integrating that ODE
# with the dense Q matrix must reproduce it.
diff_density = LinearDiffusionDensity(strength=2.0)
rank = 8
q_absorb = QAbsorb(rank, diff_density)

# Start from a point mass at x_0. exp_q() always starts at time 0, so the
# integration must start at t_0 = 0.0 for the comparison to be meaningful.
x_0 = 2
t_0 = 0.0
t_end = 0.5
# NOTE(review): unless jax_enable_x64 is set, JAX keeps float32 here.
q_matrix = jnp.astype(q_absorb.q(), jnp.float64)
print('q_matrix', q_matrix)

x_t_end = q_absorb.exp_q(x_0, time=t_end)
print('x_t', x_t_end)

# Initial distribution as a [rank, 1] column vector.
x_t = jax.nn.one_hot(x_0, rank, dtype=jnp.float64)[:, jnp.newaxis]
print(x_t)
num_iterations = 1000
t_bucket_size = (t_end - t_0) / num_iterations
for i in range(num_iterations):
  t_i_begin = t_0 + i * t_bucket_size
  t_i_end = t_0 + (i + 1) * t_bucket_size
  # Diffusion intensity accumulated over this step.
  delta_t = (
      diff_density.cum_sigma_t(t_i_end) - diff_density.cum_sigma_t(t_i_begin)
  )
  # Explicit Euler step of the forward ODE.
  x_t = x_t + jnp.matmul(q_matrix, x_t) * delta_t
x_t_end_integrated = jnp.reshape(x_t, [-1])
print('x_t_end_integrated', x_t_end_integrated)
assert np.max(np.abs(x_t_end_integrated - x_t_end)) < 0.001

In [None]:
# TODO(yonghui): Convert this to a unit-test.
# TODO(yonghui): Add more tests for the corner cases.
#
# Consistency check between q_column and exp_q() for QUniform: exp_q() is the
# closed-form solution of dp = Q p d(cum_sigma), so Euler-integrating that ODE
# with the dense Q matrix must reproduce it.
diff_density = LinearDiffusionDensity(strength=2.0)
rank = 8
q_uniform = QUniform(rank, diff_density)

# Start from a point mass at x_0. exp_q() always starts at time 0, so the
# integration must start at t_0 = 0.0 for the comparison to be meaningful.
x_0 = 2
t_0 = 0.0
t_end = 0.5
# NOTE(review): unless jax_enable_x64 is set, JAX keeps float32 here.
q_matrix = jnp.astype(q_uniform.q(), jnp.float64)
print('q_matrix', q_matrix)

x_t_end = q_uniform.exp_q(x_0, time=t_end)
print('x_t', x_t_end)

# Initial distribution as a [rank, 1] column vector.
x_t = jax.nn.one_hot(x_0, rank, dtype=jnp.float64)[:, jnp.newaxis]
print(x_t)
num_iterations = 1000
t_bucket_size = (t_end - t_0) / num_iterations
for i in range(num_iterations):
  t_i_begin = t_0 + i * t_bucket_size
  t_i_end = t_0 + (i + 1) * t_bucket_size
  # Diffusion intensity accumulated over this step.
  delta_t = (
      diff_density.cum_sigma_t(t_i_end) - diff_density.cum_sigma_t(t_i_begin)
  )
  # Explicit Euler step of the forward ODE.
  x_t = x_t + jnp.matmul(q_matrix, x_t) * delta_t
x_t_end_integrated = jnp.reshape(x_t, [-1])
print('x_t_end_integrated', x_t_end_integrated)
assert np.max(np.abs(x_t_end_integrated - x_t_end)) < 0.001

In [None]:
# Batched and per-node exp_q must agree exactly for QAbsorb.
diff_density = LinearDiffusionDensity()
rank = 8
q_absorb = QAbsorb(rank, diff_density)
nodes = [1, 2, rank - 1]
ts = [0.5, 0.75, 1.0]

xt_lists = [q_absorb.exp_q(node, time=t) for node, t in zip(nodes, ts)]
xt_non_batch = jnp.stack(xt_lists, axis=0)

xt_batch = q_absorb.batch_exp_q(jnp.array(nodes), time=jnp.array(ts))
print('xt_non_batch', xt_non_batch)
print('xt_batch', xt_batch)
assert np.all(xt_non_batch == xt_batch)

In [None]:
# Batched and per-node exp_q must agree exactly for QUniform.
diff_density = LinearDiffusionDensity()
rank = 8
q_uniform = QUniform(rank, diff_density)
nodes = [1, 2, rank - 1]
ts = [0.5, 0.75, 1.0]

xt_lists = [q_uniform.exp_q(node, time=t) for node, t in zip(nodes, ts)]
xt_non_batch = jnp.stack(xt_lists, axis=0)

xt_batch = q_uniform.batch_exp_q(jnp.array(nodes), time=jnp.array(ts))
print('xt_non_batch', xt_non_batch)
print('xt_batch', xt_batch)
assert np.all(xt_non_batch == xt_batch)

In [None]:
# For QUniform, stacking q_row(r) for every r must rebuild q() exactly.
diff_density = LinearDiffusionDensity()
rank = 8
q_uniform = QUniform(rank, diff_density)

rows = [q_uniform.q_row(row_id) for row_id in range(rank)]
q_matrix = jnp.stack(rows, axis=0)
print('q_matrix', q_matrix)
assert np.all(q_matrix == q_uniform.q())

In [None]:
# For QAbsorb, stacking q_row(r) for every r must rebuild q() exactly.
diff_density = LinearDiffusionDensity()
rank = 8
q_absorb = QAbsorb(rank, diff_density)

rows = [q_absorb.q_row(row_id) for row_id in range(rank)]
q_matrix = jnp.stack(rows, axis=0)
print('q_matrix', q_matrix)
assert np.all(q_matrix == q_absorb.q())

In [None]:
# For QAbsorb, batch_q_row over every node id must equal stacking q_row one
# node at a time.
diff_density = LinearDiffusionDensity()
rank = 8
q_absorb = QAbsorb(rank, diff_density)

rows = [q_absorb.q_row(row_id) for row_id in range(rank)]
q_matrix = jnp.stack(rows, axis=0)
batch_q_rows = q_absorb.batch_q_row(jnp.arange(rank))

print('q_matrix', q_matrix)
print('batch_q_rows', batch_q_rows)
assert np.all(q_matrix == batch_q_rows)

In [None]:
# For QUniform, batch_q_row over every node id must equal stacking q_row one
# node at a time.
diff_density = LinearDiffusionDensity()
rank = 8
q_uniform = QUniform(rank, diff_density)

rows = [q_uniform.q_row(row_id) for row_id in range(rank)]
q_matrix = jnp.stack(rows, axis=0)
batch_q_rows = q_uniform.batch_q_row(jnp.arange(rank))

print('q_matrix', q_matrix)
print('batch_q_rows', batch_q_rows)
assert np.all(q_matrix == batch_q_rows)

In [None]:
# For QUniform, the rows returned by batch_exp_q_row must be the transpose of
# the per-node distributions returned by batch_exp_q (its columns).
diff_density = LinearDiffusionDensity()
rank = 8
q_uniform = QUniform(rank, diff_density)
nodes = jnp.arange(rank)
ts = jnp.full([rank], 0.5)

columns = q_uniform.batch_exp_q(nodes, time=ts)
rows = q_uniform.batch_exp_q_row(nodes, time=ts)

print('columns', columns)
print('rows', rows)
assert np.all(np.transpose(columns) == rows)

In [None]:
# For QAbsorb, the rows returned by batch_exp_q_row must be the transpose of
# the per-node distributions returned by batch_exp_q (its columns).
diff_density = LinearDiffusionDensity()
rank = 8
q_absorb = QAbsorb(rank, diff_density)
nodes = jnp.arange(rank)
ts = jnp.full([rank], 0.5)

columns = q_absorb.batch_exp_q(nodes, time=ts)
rows = q_absorb.batch_exp_q_row(nodes, time=ts)

print('columns', columns)
print('rows', rows)
assert np.all(np.transpose(columns) == rows)

In [None]:
# adjust_prob_ratio must undo the forward transition for QAbsorb: push two
# distributions p0 forward through exp(tQ) and map them back.
diff_density = LinearDiffusionDensity(strength=2.0)
rank = 8
q_absorb = QAbsorb(rank, diff_density)
nodes = jnp.arange(rank)
ts = jnp.full([rank], 0.5)

# All rows together form the dense forward matrix exp(tQ).
rows = q_absorb.batch_exp_q_row(nodes, time=ts)
print('rows', rows)

p0 = jnp.array([
    [0.1, 0.1, 0.1, 0.1, 0.2, 0.0, 0.3, 0.1],
    [0.1, 0.2, 0.1, 0.1, 0.1, 0.0, 0.3, 0.1],
])

# p1[b] = exp(tQ) @ p0[b] for the whole batch at once.
p1 = jnp.transpose(jnp.matmul(rows, jnp.transpose(p0)))

print('p1', p1)
print(jnp.sum(p1, axis=1))

p0_inversed = q_absorb.adjust_prob_ratio(p1, jnp.array([0.5, 0.5]))
print('p0_inversed', p0_inversed)

print('delta p0 and p0_inversed', jnp.abs(p0 - p0_inversed))
assert np.all(np.abs(p0 - p0_inversed) < 2e-3)

In [None]:
# adjust_prob_ratio must undo the forward transition for QUniform: push two
# distributions p0 forward through exp(tQ) and map them back.
diff_density = LinearDiffusionDensity(strength=2.0)
rank = 8
q_uniform = QUniform(rank, diff_density)
nodes = jnp.arange(rank)
ts = jnp.full([rank], 0.7)

# All rows together form the dense forward matrix exp(tQ).
rows = q_uniform.batch_exp_q_row(nodes, time=ts)
print('rows', rows)

p0 = jnp.array([
    [0.1, 0.1, 0.1, 0.1, 0.2, 0.0, 0.3, 0.1],
    [0.1, 0.2, 0.1, 0.1, 0.1, 0.0, 0.3, 0.1],
])

# p1[b] = exp(tQ) @ p0[b] for the whole batch at once.
p1 = jnp.transpose(jnp.matmul(rows, jnp.transpose(p0)))

print('p1', p1)
print(jnp.sum(p1, axis=1))
p0_inversed = q_uniform.adjust_prob_ratio(p1, jnp.array([0.7, 0.7]))
print('p0_inversed', p0_inversed)

print('delta p0 and p0_inversed', jnp.abs(p0 - p0_inversed))
assert np.all(np.abs(p0 - p0_inversed) < 1e-3)

# Algorithm pieces needed for training

In [None]:
def _gumbel_max_sample(prob_ts, random_seed):
  """Draws one token per row of `prob_ts` with the Gumbel-max trick.

  Args:
    prob_ts: a 2-D [B, N] array, one categorical distribution per row.
    random_seed: the JAX PRNG key used to draw the Gumbel noise.

  Returns:
    (samples, log_prob_ratio): `samples` is [B] with the sampled token ids;
    `log_prob_ratio` is [B, N] and holds log p(y) - log p(sampled) per row.
  """
  assert prob_ts.ndim == 2
  # log(0) = -inf and NaN are replaced by -100.0; exp(-100.0) is small enough
  # to be ignored.
  # Note(yonghui): maybe we don't need to replace -inf by -100.0
  floor = -100.0
  logp = jnp.nan_to_num(jnp.log(prob_ts), nan=floor, neginf=floor)
  # Gumbel-max: argmax(log p + Gumbel noise) is distributed as Categorical(p).
  perturbed = logp + jax.random.gumbel(random_seed, shape=logp.shape)
  picked = jnp.argmax(perturbed, axis=-1)
  picked_logp = jnp.take_along_axis(logp, picked[:, jnp.newaxis], axis=-1)
  return picked, logp - picked_logp


def _forced_sample(prob_ts, samples):
  """Computes log-prob ratios relative to caller-provided (forced) samples.

  This is the deterministic counterpart of `_gumbel_max_sample`: nothing is
  sampled; `samples` is taken as the drawn tokens.

  Args:
    prob_ts: a 2-D [B, N] array, one categorical distribution per row.
    samples: a 1-D [B] array of token ids, one per row of `prob_ts`.

  Returns:
    (samples, log_prob_ratio): `samples` unchanged, and a [B, N] array with
    log p(y) - log p(samples[b]) for every token y of row b.
  """
  assert prob_ts.ndim == 2
  assert samples.ndim == 1
  # Without this check a [1]-shaped `samples` would silently broadcast in
  # take_along_axis against every row of `prob_ts`.
  assert samples.shape[0] == prob_ts.shape[0]

  # log(0) = -inf and NaN are replaced by -100.0; exp(-100.0) is small enough
  # to be ignored.
  min_log_prob = -100.0
  # Note(yonghui): maybe we don't need to replace -inf by -100.0
  log_prob_ts = jnp.nan_to_num(
      jnp.log(prob_ts), nan=min_log_prob, neginf=min_log_prob
  )
  # [B, 1] log-prob of each forced sample.
  sampled_log_probs = jnp.take_along_axis(
      log_prob_ts, samples[:, jnp.newaxis], axis=-1
  )
  log_prob_ratio = log_prob_ts - sampled_log_probs
  return samples, log_prob_ratio


# Samples noisified examples together with the ground-truth log-prob ratios.
#
# Please refer to algorithm 1 in the paper https://arxiv.org/pdf/2310.16834 for
# details.
def DiffuseAndSampleXt(
    x0: Int32[Array, 'B T'],
    ts: Float[Array, 'B'],
    transition_matrix: TransitionMatrix,
    random_seed: jax.random.PRNGKey,
    forced_samples: Int32[Array, 'B T'] | None = None,
):
  """Given x0, samples noisified examples x_t at times ts.

  Every token of example b is diffused independently from time 0 to ts[b]
  with `transition_matrix.batch_exp_q`, and x_t is drawn from the resulting
  per-token marginal.

  Args:
    x0: the input batch, of shape [B, T]. Here we assume x0 is not packed: each
      row contains one single example.
    ts: [B] times at which to sample X_t, one per example.
    transition_matrix: the transition matrix.
    random_seed: the JAX PRNG key used for sampling (unused when
      `forced_samples` is given).
    forced_samples: optional [B, T] samples. If present, nothing is sampled and
      only the log-prob ratios with respect to these samples are computed. This
      is mostly useful for testing purposes.

  Returns:
    A tuple (noisified_example, log_prob_ratio). noisified_example is the
    [B, T] sampled x_t; log_prob_ratio is [B, T, N] and holds, for every
    position, log p_t(y) - log p_t(x_t) for every vocab entry y (log(0) is
    floored at -100 by the sampling helpers).
  """
  assert x0.ndim == 2
  batch_size, seq_length = x0.shape
  assert ts.ndim == 1
  assert ts.shape[0] == batch_size

  # Flatten to [B * T]; every token of example b shares the time ts[b].
  nodes_1d = jnp.reshape(x0, [-1])
  ts_1d = jnp.repeat(ts, seq_length)
  # prob_ts is the probability distribution at time t for all the nodes.
  # of shape [B * T, N]
  prob_ts = transition_matrix.batch_exp_q(nodes_1d, ts_1d)
  if forced_samples is not None:
    assert forced_samples.shape == x0.shape
    forced_samples = jnp.reshape(forced_samples, [-1])
    samples, log_prob_ratio = _forced_sample(prob_ts, forced_samples)
  else:
    samples, log_prob_ratio = _gumbel_max_sample(prob_ts, random_seed)
  samples = jnp.reshape(samples, [batch_size, seq_length])
  log_prob_ratio = jnp.reshape(log_prob_ratio, [batch_size, seq_length, -1])
  return samples, log_prob_ratio

In [None]:
# Sanity-check the two sampling helpers on a small batch of distributions.
prob_ts = jnp.array([[0.1, 0.2, 0.3, 0.4], [0.25, 0, 0.25, 0.5]])

freq = np.zeros_like(prob_ts)
num_trials = 1000
row_ids = np.arange(prob_ts.shape[0])

for i in range(num_trials):
  s, x = _gumbel_max_sample(prob_ts, random_seed=jax.random.PRNGKey(i))
  # One draw per row: bump the counter of the token drawn in each row.
  freq[row_ids, np.asarray(s)] += 1

empirical_probs = freq / num_trials
print(empirical_probs)
# With 1000 trials the standard error per entry is at most ~0.016, so the
# empirical frequencies should be close to prob_ts. The zero-probability
# token must never be drawn.
assert np.allclose(empirical_probs, np.asarray(prob_ts), atol=0.08)
assert freq[1, 1] == 0

s, x = _gumbel_max_sample(prob_ts, random_seed=jax.random.PRNGKey(100))
print(s)
print(jnp.exp(x))

# Forcing the very same samples must reproduce the same log-prob ratios.
s_forced, x_new = _forced_sample(prob_ts, s)
print(s_forced)
print(jnp.exp(x_new))
assert jnp.array_equal(s_forced, s)
assert jnp.allclose(x_new, x)

In [None]:
# Exercise DiffuseAndSampleXt with the absorbing and the uniform transition
# matrices.
vocab_size = 16
diff_density = LinearDiffusionDensity(strength=2.0)
q_absorb = QAbsorb(vocab_size, diff_density)
q_uniform = QUniform(vocab_size, diff_density)

x0 = jnp.array([[1, 2], [3, 4]])
ts = jnp.array([0.1, 1.0])

sample, log_prob_ratio = DiffuseAndSampleXt(x0, ts, q_absorb,
                                            jax.random.PRNGKey(100))
print(sample)
print(log_prob_ratio)
# The ratio of each sampled token against itself is 1, i.e. a log ratio of 0.
assert sample.shape == x0.shape
assert jnp.allclose(
    jnp.take_along_axis(log_prob_ratio, sample[..., jnp.newaxis], axis=-1),
    0.0,
)

sample, log_prob_ratio = DiffuseAndSampleXt(x0, ts, q_uniform,
                                            jax.random.PRNGKey(100))
print(sample)
print(log_prob_ratio)
assert sample.shape == x0.shape
assert jnp.allclose(
    jnp.take_along_axis(log_prob_ratio, sample[..., jnp.newaxis], axis=-1),
    0.0,
)

In [None]:
# This is the loss function used in SEDD.
def compute_score_entropy(
    sigma_t: Float[Array, 'B T'],
    noisified_samples: Int32[Array, 'B T'],
    valid_tokens: Float[Array, 'B T'],
    log_prob_ratio: Float[Array, 'B T N'],
    predicted_log_prob_ratio: Float[Array, 'B T N'],
):
  """Computes the score entropy loss.

  In a real training loop, noisified_samples is the input to the transformer
  block. In addition to taking noisified_samples as input, the transformer
  block also takes ts (time at each token) as input. The transformer block
  predicts predicted_log_prob_ratio for each possible token in the vocab. It is
  basically the linear output before the softmax layer.

  For a position whose noisified token is x, and for every other token y, let
  r = exp(log_prob_ratio[..., y]) and p = predicted_log_prob_ratio[..., y].
  The element loss is

    exp(p) - r * p - (r - r * log(r)).

  As a function of p this is minimized at p = log(r), where it equals 0. So
  every element is non-negative and vanishes for a perfect prediction. The
  entry y == x does not contribute.

  Args:
    sigma_t: the diffusion strength at time t.
    noisified_samples: the noisified samples at time t.
    valid_tokens: a mask of valid tokens. It is 0.0 if the token at the pos is
      invalid (e.g. padded token), and it is 1.0 otherwise.
    log_prob_ratio: the log probability ratio between all other tokens and the
      noisified tokens at time t.
    predicted_log_prob_ratio: the predicted log probability ratio between all
      other tokens and the noisified tokens at time t.

  Returns:
    A tuple (final_loss, summaries):
    - final_loss: the sigma_t-weighted, masked sum of per-token losses divided
      by the number of valid tokens. It is 0.0 when there are no valid tokens.
    - summaries: a dict of intermediate arrays, useful for debugging.
  """
  vocab_size = log_prob_ratio.shape[-1]
  assert predicted_log_prob_ratio.shape == log_prob_ratio.shape
  # One-hot position of the noisified token x_t in the vocab.
  xt_indices = jax.nn.one_hot(noisified_samples, vocab_size)
  is_xt = xt_indices > 0
  # The x_t entry must not contribute. Replace the prediction there by 0
  # *before* exponentiating. Otherwise a large logit overflows exp() to inf,
  # and masking afterwards by multiplication (0 * inf) yields nan in both the
  # loss and its gradient.
  masked_pred = jnp.where(is_xt, 0.0, predicted_log_prob_ratio)
  ratio = jnp.exp(log_prob_ratio)
  loss = jnp.exp(masked_pred) - ratio * masked_pred
  # Normalization constant: the minimum over p of the expression above. After
  # subtracting it, every element is >= 0 and exactly 0 when the prediction
  # matches the ground-truth ratio.
  k_norm = ratio - ratio * log_prob_ratio
  loss = loss - k_norm
  # Exclude the x_t entries.
  loss = jnp.where(is_xt, 0.0, loss)
  per_token_loss = jnp.sum(loss, -1)
  assert sigma_t.shape == per_token_loss.shape == valid_tokens.shape
  # Despite the name this is still [B, T]: the per-token loss weighted by
  # sigma_t and masked by valid_tokens.
  per_sequence_loss = sigma_t * per_token_loss * valid_tokens
  summaries = {
      'per_token_element_loss': loss,
      'per_token_loss': per_token_loss,
      'per_sequence_loss': per_sequence_loss,
      'nosied_samples': noisified_samples,
      'sigma_t': sigma_t,
  }
  num_valid_tokens = jnp.sum(valid_tokens)
  # Guard against an all-padding batch: 0 / 0 would otherwise produce nan.
  safe_num_valid_tokens = jnp.where(num_valid_tokens > 0, num_valid_tokens, 1)
  final_loss = jnp.sum(per_sequence_loss) / safe_num_valid_tokens
  return final_loss, summaries

## Example of computing the training loss (a simulated training step)

In [None]:
# Test out the loss function.
vocab_size = 16
diff_density = LinearDiffusionDensity(strength=2.0)
q_uniform = QUniform(vocab_size, diff_density)

x0 = jnp.array([[1, 2, 3], [4, 5, 6]])
ts = jnp.array([0.1, 0.5])

samples, log_prob_ratio = DiffuseAndSampleXt(
    x0, ts, q_uniform, jax.random.PRNGKey(100)
)

# The training algorithm will produce an estimate of log_prob_ratio.
# Here, we simulate training by producing the estimated log prob ratio
# (despite its name, `estimated_prob_ratio` holds log ratios) from slightly
# different times, for the same noisified samples.
ts_prime = jnp.array([0.2, 0.4])
_, estimated_prob_ratio = DiffuseAndSampleXt(
    x0, ts_prime, q_uniform, jax.random.PRNGKey(100), forced_samples=samples
)

print('x0', x0)
print('ts', ts)
print('samples', samples)
print('log_prob_ratio', log_prob_ratio)
print('estimated_prob_ratio', estimated_prob_ratio)

seq_len = samples.shape[1]
sigma_t = diff_density.sigma_t(ts)
sigma_t = jnp.repeat(sigma_t[:, jnp.newaxis], seq_len, axis=1)

# No padding in this toy batch. The loss expects a float mask.
valid_tokens = jnp.ones_like(samples, dtype=jnp.float32)

final_loss, summaries = compute_score_entropy(
    sigma_t,
    samples,
    valid_tokens,
    log_prob_ratio,
    predicted_log_prob_ratio=estimated_prob_ratio,
)

print(final_loss)
print(summaries)

# The normalized score entropy is non-negative, and it vanishes when the
# prediction equals the ground truth.
assert final_loss > -1e-5
perfect_loss, _ = compute_score_entropy(
    sigma_t,
    samples,
    valid_tokens,
    log_prob_ratio,
    predicted_log_prob_ratio=log_prob_ratio,
)
assert jnp.abs(perfect_loss) < 1e-3

# Algorithm pieces needed for sampling

In [None]:
# Sanity-check q_posterior for the absorbing and the uniform transition
# matrices: every row of the returned posterior must sum to (numerically) 0.
vocab_size = 16
diff_density = LinearDiffusionDensity(strength=2.0)
q_absorb = QAbsorb(vocab_size, diff_density)
q_uniform = QUniform(vocab_size, diff_density)

x0 = jnp.array([[1, 2], [3, 4]])
ts = jnp.array([0.1, 1.0])

_posterior_cases = (
    ('Test absorb transition matrix', 'q_absorb posterior_probs', q_absorb),
    ('Test uniform transition matrix', 'q_uniform posterior_probs', q_uniform),
)
for header, posterior_label, q_matrix in _posterior_cases:
  print(header)
  sample, log_prob_ratio = DiffuseAndSampleXt(
      x0, ts, q_matrix, jax.random.PRNGKey(100)
  )
  print('sample', sample)
  print('log_prob_ratio', log_prob_ratio)

  # Flatten to one row per token before calling q_posterior.
  sample_1d = sample.reshape(-1)
  log_prob_ratio_1d = log_prob_ratio.reshape(sample_1d.shape[0], -1)
  posterior_probs = q_matrix.q_posterior(sample_1d, jnp.exp(log_prob_ratio_1d))
  print(posterior_label, posterior_probs)

  row_sums = jnp.sum(posterior_probs, -1)
  print('posterior_probs sum', row_sums)
  assert jnp.all(jnp.abs(row_sums) < 1e-6)

# TODO(yonghui): Add a bunch more tests to assert the validity of the
# implementation.

## Example implementation of the Euler sampling algorithm

In [None]:
# An example implementation of the Euler sampling algorithm
def TestEulerSampler(q_matrix_type: str):
  """Runs the Euler reverse sampler from t = 1 to t = 0 and checks it on x0.

  The prob ratios come from the known x0 (via DiffuseAndSampleXt with forced
  samples) instead of from a trained model. With these oracle ratios the
  sampler is expected to recover x0 exactly, and the function asserts that.

  Args:
    q_matrix_type: 'absorb' to use QAbsorb, or 'uniform' to use QUniform.

  Raises:
    ValueError: if q_matrix_type is neither 'absorb' nor 'uniform'.
  """
  vocab_size = 16

  # For this density cum_sigma_t(1.0) == 4.0, so tokens sampled at t = 1 are
  # heavily (though not necessarily completely) noised.
  diff_density = LinearDiffusionDensity(strength=4.0)

  if q_matrix_type == 'absorb':
    q_matrix = QAbsorb(vocab_size, diff_density)
  elif q_matrix_type == 'uniform':
    q_matrix = QUniform(vocab_size, diff_density)
  else:
    raise ValueError(f'Unknown q_matrix_type {q_matrix_type}')

  # This is the state that we would like to recover.
  x0 = jnp.array([[1, 2, 3], [4, 5, 6]])

  # Starting point of the reverse process: a sample of x_t at t = 1.
  samples, _ = DiffuseAndSampleXt(
      x0, jnp.array([1.0, 1.0]), q_matrix, jax.random.PRNGKey(100)
  )
  print('samples', samples)

  def estimate_prob_ratio(ts, xt, x0=x0):
    # Estimate the prob ratio at times 'ts' conditioned on x_t being "xt".
    # Here we cheat by providing the groundtruth prob ratios.
    # In the real sampling algorithm, the prob ratio should come from our model
    # predictions.
    assert xt.shape == x0.shape
    assert ts.ndim == 1
    assert ts.shape[0] == xt.shape[0]
    _, estimated_prob_ratio = DiffuseAndSampleXt(
        x0, ts, q_matrix, None, forced_samples=xt
    )
    return jnp.exp(estimated_prob_ratio)

  num_iterations = 100
  step_size = 1.0 / num_iterations

  xt = samples
  batch_size = samples.shape[0]

  # Walk time backwards from t = 1 to t = 0 in num_iterations equal steps.
  for i in range(num_iterations, 0, -1):
    t_begin = (i - 1) * step_size
    t_end = i * step_size
    # Noise accumulated over [t_begin, t_end].
    t_delta = diff_density.cum_sigma_t(t_end) - diff_density.cum_sigma_t(
        t_begin
    )
    prob_ratio = estimate_prob_ratio(jnp.zeros([batch_size]) + t_end, xt)

    xt_1d = jnp.reshape(xt, [-1])
    prob_ratio_1d = jnp.reshape(prob_ratio, [xt_1d.shape[0], -1])
    posterior = q_matrix.q_posterior(xt_1d, prob_ratio_1d)
    xt_one_hot = jax.nn.one_hot(xt_1d, vocab_size)
    # Take one Euler step backward in time:
    #   P(x_{t - \delta t} = y | x_t) ~= 1[y == x_t]
    #                                   + t_delta * posterior(x_t, y)
    xt_minus_delta = xt_one_hot + t_delta * posterior
    # If any entry ends up negative (e.g. with a large step), clip it so it is
    # treated as a zero probability by the sampler.
    xt_minus_delta = jnp.maximum(xt_minus_delta, 0.0)
    # Now we sample the new xt.
    xt_minus_delta_sample, _ = _gumbel_max_sample(
        xt_minus_delta, jax.random.PRNGKey(100 + i)
    )
    xt = jnp.reshape(xt_minus_delta_sample, xt.shape)
    print('xt', xt)

  x_final = xt

  # Here we assert we fully recover the original x0.
  assert jnp.all(x_final == x0)

In [None]:
# Euler sampler with the absorbing transition matrix.
TestEulerSampler(q_matrix_type='absorb')

In [None]:
# Euler sampler with the uniform transition matrix.
TestEulerSampler(q_matrix_type='uniform')

In [None]:
# An example implementation of the Tweedie sampling algorithm.
#
# The prob ratios are ground truth (oracle) rather than model predictions. They
# are adjusted with q_matrix.adjust_prob_ratio before forming the posterior.
def TestTweedieSampler(q_matrix_type: str):
  """Runs the Tweedie reverse sampler from t = 1 to t = 0 and checks it on x0.

  The prob ratios come from the known x0 (via DiffuseAndSampleXt with forced
  samples) instead of from a trained model. With these oracle ratios the
  sampler is expected to recover x0 exactly, and the function asserts that.

  Args:
    q_matrix_type: 'absorb' to use QAbsorb, or 'uniform' to use QUniform.

  Raises:
    ValueError: if q_matrix_type is neither 'absorb' nor 'uniform'.
  """
  vocab_size = 16

  # For this density cum_sigma_t(1.0) == 4.0, so tokens sampled at t = 1 are
  # heavily (though not necessarily completely) noised.
  diff_density = LinearDiffusionDensity(strength=4.0)

  if q_matrix_type == 'absorb':
    q_matrix = QAbsorb(vocab_size, diff_density)
  elif q_matrix_type == 'uniform':
    q_matrix = QUniform(vocab_size, diff_density)
  else:
    raise ValueError(f'Unknown q_matrix_type {q_matrix_type}')

  # This is the state that we would like to recover.
  x0 = jnp.array([[1, 2, 3], [4, 5, 6]])

  # Starting point of the reverse process: a sample of x_t at t = 1.
  samples, _ = DiffuseAndSampleXt(
      x0, jnp.array([1.0, 1.0]), q_matrix, jax.random.PRNGKey(100)
  )
  print('samples', samples)

  def estimate_prob_ratio(ts, xt, x0=x0):
    # Estimate the prob ratio at times 'ts' conditioned on x_t being "xt".
    # Here we cheat by providing the groundtruth prob ratios.
    # In the real sampling algorithm, the prob ratio should come from our model
    # predictions.
    assert xt.shape == x0.shape
    assert ts.ndim == 1
    assert ts.shape[0] == xt.shape[0]
    _, estimated_prob_ratio = DiffuseAndSampleXt(
        x0, ts, q_matrix, None, forced_samples=xt
    )
    return jnp.exp(estimated_prob_ratio)

  num_iterations = 100
  step_size = 1.0 / num_iterations

  xt = samples
  batch_size = samples.shape[0]
  num_tokens = samples.shape[0] * samples.shape[1]
  # Loop-invariant: the time argument passed to adjust_prob_ratio.
  # NOTE(review): adjust_prob_ratio receives step_size here, while
  # batch_exp_q_row below receives t_delta (a cum_sigma_t difference). The
  # Tweedie denoiser uses the same noise increment in both factors, so confirm
  # which units each method expects.
  adjust_times = jnp.zeros([num_tokens]) + step_size

  # Walk time backwards from t = 1 to t = 0 in num_iterations equal steps.
  for i in range(num_iterations, 0, -1):
    t_begin = (i - 1) * step_size
    t_end = i * step_size
    # Noise accumulated over [t_begin, t_end].
    t_delta = diff_density.cum_sigma_t(t_end) - diff_density.cum_sigma_t(
        t_begin
    )
    prob_ratio = estimate_prob_ratio(jnp.zeros([batch_size]) + t_end, xt)

    xt_1d = jnp.reshape(xt, [-1])
    prob_ratio_1d = jnp.reshape(prob_ratio, [xt_1d.shape[0], -1])
    batch_t_delta = jnp.ones([xt_1d.shape[0]]) * t_delta

    # Adjust the prob ratio, i.e. the exp(-tQ) part in equation 18 of the
    # paper.
    prob_ratio_1d_adjusted = q_matrix.adjust_prob_ratio(
        prob_ratio_1d, adjust_times
    )

    # Unnormalized posterior, normalized per row below.
    posterior = (
        q_matrix.batch_exp_q_row(xt_1d, batch_t_delta) * prob_ratio_1d_adjusted
    )
    posterior_sum = jnp.sum(posterior, axis=-1, keepdims=True)
    normed_posterior = posterior / posterior_sum

    # Now we sample the new xt.
    xt_minus_delta_sample, _ = _gumbel_max_sample(
        normed_posterior, jax.random.PRNGKey(100 + i)
    )
    xt = jnp.reshape(xt_minus_delta_sample, xt.shape)
    print('xt', xt)

  x_final = xt

  # Here we assert we fully recover the original x0.
  assert jnp.all(x_final == x0)

In [None]:
# Tweedie sampler with the absorbing transition matrix.
TestTweedieSampler(q_matrix_type='absorb')

In [None]:
# Tweedie sampler with the uniform transition matrix.
TestTweedieSampler(q_matrix_type='uniform')