In [None]:
%env JAX_PLATFORMS=cpu
import jax
import jax.numpy as jnp

In [None]:
# Toy problem dimensions.
n_classes = 10  # token ids drawn below lie in [0, n_classes)
seq_len = 16
batch = 32

# Clean integer tokens, one row per example: shape (batch, seq_len).
# `maxval` is exclusive, so the value `n_classes` itself never appears here.
data = jax.random.randint(
    key=jax.random.PRNGKey(0),
    shape=(batch, seq_len),
    minval=0,
    maxval=n_classes,
)
# Ported from:
# https://github.com/neverix/Score-Entropy-Discrete-Diffusion/blob/f7221e3b835045f75444c7429955aa420111cc7d/noise_lib.py#L56
noise_eps = 1e-4


def noise_schedule(t):
    """Evaluate the noise schedule at time ``t``.

    The schedule is chosen so that ``1 - exp(-total_noise)`` equals
    ``(1 - noise_eps) * t``; ``noise_eps`` keeps the logarithm finite
    at ``t == 1``.

    Args:
        t: scalar or array of times.

    Returns:
        A ``(total_noise, rate_noise)`` pair, each shaped like ``t``:
        ``total_noise = -log(1 - (1 - noise_eps) * t)`` and ``rate_noise``
        is its derivative with respect to ``t``.
    """
    scale = 1 - noise_eps
    scaled_t = scale * t
    rate_noise = scale / (1 - scaled_t)
    # log1p keeps precision when scaled_t is close to zero.
    total_noise = -jnp.log1p(-scaled_t)
    return total_noise, rate_noise

# https://github.com/neverix/Score-Entropy-Discrete-Diffusion/blob/f7221e3b835045f75444c7429955aa420111cc7d/graph_lib.py#L228C1-L232C22
def sample_transition(key, i, sigma, mask_index=None):
    """Randomly replace tokens with the absorbing (mask) token.

    Every entry of ``i`` is replaced independently with probability
    ``1 - exp(-sigma)``; entries that are not replaced are returned
    unchanged.

    Args:
        key: JAX PRNG key consumed by the Bernoulli draw.
        i: integer array of token ids.
        sigma: total noise level; must broadcast against ``i.shape``
            (e.g. shape ``(batch, 1)`` for ``i`` of shape ``(batch, seq)``).
        mask_index: id written at replaced positions. Defaults to the
            module-level ``n_classes``.

    Returns:
        Array shaped like ``i`` holding the perturbed token ids.
    """
    if mask_index is None:
        mask_index = n_classes
    move_chance = 1 - jnp.exp(-sigma)
    move_indices = jax.random.bernoulli(key, move_chance, i.shape)
    return jnp.where(move_indices, mask_index, i)

# https://github.com/neverix/Score-Entropy-Discrete-Diffusion/blob/f7221e3b835045f75444c7429955aa420111cc7d/graph_lib.py#L244
def score_entropy(score, sigma, x, x0, mask_index=None):
    """Per-token score-entropy loss for the absorbing-state process.

    Only positions where ``x`` equals the mask token contribute; every
    other position gets exactly zero.

    The computation is dense (no boolean-mask indexing), so all shapes are
    static and the function can be wrapped in ``jax.jit``.

    Args:
        score: log-score array; its leading dimensions match ``x`` and its
            last axis indexes tokens. The last channel is excluded from the
            positive term (in the referenced implementation that channel is
            the absorbing state), so ``score`` is expected to carry
            ``mask_index + 1`` channels -- TODO confirm against the caller.
        sigma: total noise level; must broadcast against ``x.shape``
            (e.g. shape ``(batch, 1)``).
        x: perturbed integer token ids.
        x0: clean integer token ids, same shape as ``x``; used to index the
            last axis of ``score``.
        mask_index: id of the absorbing token. Defaults to the module-level
            ``n_classes``.

    Returns:
        Array with the shape of ``x`` and the dtype of ``score``.
    """
    if mask_index is None:
        mask_index = n_classes
    is_masked = x == mask_index

    esigm1 = jnp.where(
        sigma < 0.5,
        jnp.expm1(sigma),
        jnp.exp(sigma) - 1
    )
    esigm1 = jnp.broadcast_to(esigm1, x.shape)
    # Unmasked positions are discarded below, but an inf/NaN there (e.g. when
    # sigma == 0) would still leak into gradients through jnp.where, so give
    # them a harmless denominator.
    ratio = 1 / jnp.where(is_masked, esigm1, 1)

    # negative term: score assigned to the true token, weighted by the ratio
    true_score = jnp.take_along_axis(score, x0[..., None], axis=-1).squeeze(-1)
    neg_term = ratio * true_score

    # positive term: sum of exp(score) over every channel except the last
    pos_term = jnp.exp(score[..., :-1]).sum(axis=-1)

    # constant term: does not depend on the score
    const = ratio * (jnp.log(ratio) - 1)

    entropy = jnp.where(is_masked, pos_term - neg_term + const, 0)
    return entropy.astype(score.dtype)

# Demo: evaluate the score-entropy loss once on random tokens and random scores.
key = jax.random.PRNGKey(0)
# One independent subkey per random draw; a JAX key must not be consumed twice.
noise_key, transition_key, logits_key = jax.random.split(key, 3)

# One diffusion time per example, kept as a column so it broadcasts over seq_len.
t = jax.random.uniform(noise_key, (batch, 1))
total_noise, rate_noise = noise_schedule(t)
data_perturbed = sample_transition(transition_key, data, total_noise)

# Stand-in for a model's output. score_entropy drops the last channel in its
# positive term, so the scores need n_classes real-token channels plus one
# trailing channel for the mask token.
logits = jax.random.gumbel(logits_key, (batch, seq_len, n_classes + 1))
# Weight each token's loss by the noise rate, then sum over the sequence axis.
loss = (score_entropy(logits, total_noise, data_perturbed, data) * rate_noise).sum(1)