In [None]:
%env JAX_PLATFORMS=cpu
import jax
import jax.numpy as jnp

In [None]:
# Toy problem size: token ids are drawn from [0, n_classes).
n_classes = 10
seq_len = 4
batch = 32
# Fixed-seed integer batch of shape (batch, seq_len) used as the clean data.
data = jax.random.randint(
    jax.random.PRNGKey(0),
    shape=(batch, seq_len),
    minval=0,
    maxval=n_classes,
)
# https://github.com/neverix/Score-Entropy-Discrete-Diffusion/blob/f7221e3b835045f75444c7429955aa420111cc7d/noise_lib.py#L56
noise_eps = 1e-4  # keeps 1 - (1 - noise_eps) * t strictly positive at t = 1
def noise_schedule(t):
    """Return ``(rate_noise, total_noise)`` of the schedule at time ``t``.

    ``total_noise`` is ``-log(1 - (1 - noise_eps) * t)`` and ``rate_noise`` is
    its derivative with respect to ``t``. Both are computed elementwise, so
    ``t`` may be a scalar or an array.
    """
    scaled_t = (1 - noise_eps) * t
    total_noise = -jnp.log1p(-scaled_t)
    rate_noise = (1 - noise_eps) / (1 - scaled_t)
    return rate_noise, total_noise

# https://github.com/neverix/Score-Entropy-Discrete-Diffusion/blob/f7221e3b835045f75444c7429955aa420111cc7d/graph_lib.py#L228C1-L232C22
def sample_transition(key, i, sigma):
    """Randomly replace entries of ``i`` with the index ``n_classes``.

    Each entry is replaced independently with probability
    ``1 - exp(-sigma)``; all other entries are returned unchanged.
    ``sigma`` must broadcast against ``i.shape``.

    NOTE(review): ``n_classes`` presumably plays the role of the absorbing
    (mask) state from the referenced graph_lib code — confirm.
    """
    stay_prob = jnp.exp(-sigma)
    absorbed = jax.random.bernoulli(key, 1 - stay_prob, i.shape)
    return jnp.where(absorbed, n_classes, i)

# Draw one time per batch row, map it through the noise schedule, and corrupt
# `data` with the resulting noise level.
key = jax.random.PRNGKey(0)
noise_key, transition_key = jax.random.split(key)
# NOTE: despite its name, `sigma` is the uniform time in [0, 1) that is fed to
# the schedule as `t`; the name is kept so code reading `sigma` keeps working.
# Shape (batch, 1) so the value broadcasts across the sequence axis.
sigma = jax.random.uniform(noise_key, (batch, 1))
rate_noise, total_noise = noise_schedule(sigma)
# Perturb with the schedule's cumulative noise rather than the raw time:
# sample_transition uses move chance 1 - exp(-noise), so passing `total_noise`
# gives a move chance of (1 - noise_eps) * t. Passing the raw time would cap
# it at 1 - exp(-1) and leave the schedule output unused.
data_perturbed = sample_transition(transition_key, data, total_noise)
data_perturbed