In [None]:
%env JAX_PLATFORMS=cpu
import jax
import jax.numpy as jnp

In [None]:
# Vocabulary size. Real tokens are 0..n_classes-1; the extra index `n_classes`
# itself is used by the functions below as the absorbing ("mask") state.
n_classes = 10
# Number of tokens per sequence.
seq_len = 4
# Number of distinct sequences in the toy dataset.
batch = 2
# Toy dataset: a (batch, seq_len) integer array drawn once with a fixed seed,
# values in [0, n_classes) (randint's upper bound is exclusive).
data = jax.random.randint(jax.random.PRNGKey(0), (batch, seq_len), 0, n_classes)
# https://github.com/neverix/Score-Entropy-Discrete-Diffusion/blob/f7221e3b835045f75444c7429955aa420111cc7d/noise_lib.py#L56
# Keeps (1 - noise_eps) * t strictly below 1 for t <= 1, so the log and the
# division in the schedule stay finite.
noise_eps = 1e-6


def noise_schedule(t):
    """Log-linear noise schedule.

    Args:
        t: array of diffusion times (any shape); the formulas are finite for
           t <= 1.

    Returns:
        (total_noise, rate_noise), both shaped like `t`:
          total_noise = -log(1 - (1 - noise_eps) * t)
          rate_noise  = d(total_noise)/dt = (1 - noise_eps) / (1 - (1 - noise_eps) * t)
    """
    shrink = 1 - noise_eps
    scaled_t = shrink * t
    sigma = -jnp.log1p(-scaled_t)
    dsigma_dt = shrink / (1 - scaled_t)
    return sigma, dsigma_dt

# https://github.com/neverix/Score-Entropy-Discrete-Diffusion/blob/f7221e3b835045f75444c7429955aa420111cc7d/graph_lib.py#L228C1-L232C22
def sample_transition(key, i, sigma):
    """Forward (noising) step of the absorbing process.

    Each entry of `i` is independently replaced by the absorbing index
    `n_classes` with probability 1 - exp(-sigma); other entries are kept.

    Args:
        key: JAX PRNG key.
        i: integer token array.
        sigma: total noise; must broadcast against `i.shape`.

    Returns:
        Array shaped like `i` containing the perturbed tokens.
    """
    absorb_prob = 1 - jnp.exp(-sigma)
    absorbed = jax.random.bernoulli(key, absorb_prob, i.shape)
    return jnp.where(absorbed, n_classes, i)

# https://github.com/neverix/Score-Entropy-Discrete-Diffusion/blob/f7221e3b835045f75444c7429955aa420111cc7d/graph_lib.py#L244
def score_entropy(score, sigma, x, x0):
    """Per-token score-entropy loss for the absorbing process.

    Args:
        score: log-scores; last axis indexes the n_classes + 1 states
               (it is exponentiated here for the positive term).
        sigma: total noise with a trailing axis of length 1 (it is repeated
               x.shape[-1] times along that axis to line up with `x`).
        x: perturbed tokens.
        x0: clean tokens, same shape as `x`.

    Returns:
        Array shaped like `x` with dtype `score.dtype`; zero wherever `x` is
        not the absorbing index `n_classes`.
    """
    # Only absorbed positions contribute to the loss.
    absorbed = x == n_classes

    # e^sigma - 1, taking the expm1 form for small sigma.
    exp_sigma_m1 = jnp.where(sigma < 0.5, jnp.expm1(sigma), jnp.exp(sigma) - 1)
    inv = 1 / jnp.repeat(exp_sigma_m1, x.shape[-1], -1)

    # Negative term: weighted log-score assigned to the true clean token.
    clean_log_score = jnp.take_along_axis(score, x0[..., None], -1).squeeze(-1)
    negative = inv * clean_log_score

    # Positive term: sum of scores over all non-absorbing states.
    positive = jnp.exp(score[..., :-1]).sum(axis=-1)

    # Constant term (independent of the network output).
    constant = inv * (jnp.log(inv) - 1)

    nothing = jnp.zeros(x.shape, score.dtype)
    return jnp.where(absorbed, positive - negative + constant, nothing)

# https://github.com/neverix/Score-Entropy-Discrete-Diffusion/blob/f7221e3b835045f75444c7429955aa420111cc7d/graph_lib.py#L234C1-L239C21
def staggered_score(score, dsigma):
    """Rescale `score` for a noise decrement of `dsigma`.

    Every entry along the last axis is multiplied by exp(dsigma); the last
    entry additionally receives (1 - exp(dsigma)) * sum(score, axis=-1).

    Args:
        score: array whose last axis indexes states.
        dsigma: array shaped like `score.shape[:-1]`.

    Returns:
        New array shaped like `score` (the input is not modified).
    """
    growth = jnp.exp(dsigma)
    last_correction = (1 - growth) * score.sum(axis=-1)
    scaled = score * growth[..., None]
    return scaled.at[..., -1].add(last_correction)

# https://github.com/neverix/Score-Entropy-Discrete-Diffusion/blob/f7221e3b835045f75444c7429955aa420111cc7d/graph_lib.py#L218
def transp_transition(i, sigma):
    """Transition weights into state `i` for noise amount `sigma`.

    Args:
        i: integer token array.
        sigma: noise amount; trailing singleton axes are appended until it has
               i.ndim + 1 dimensions.

    Returns:
        Array of shape (*i.shape, n_classes + 1): exp(-sigma) at index `i`,
        plus 1 - exp(-sigma) added to every entry where `i` is the absorbing
        index `n_classes`.
    """
    n_extra_axes = i.ndim + 1 - sigma.ndim
    sigma = sigma.reshape(*sigma.shape, *((1,) * n_extra_axes))

    keep = jnp.exp(-sigma)
    stay_part = keep * jax.nn.one_hot(i, num_classes=n_classes + 1)
    # Non-zero only where the current token is the absorbing state.
    absorb_part = jnp.where(i == n_classes, 1 - keep.squeeze(-1), 0)
    return stay_part + absorb_part[..., None]

def sample_limit(dims):
    """Return an array of shape `dims` filled with the absorbing index `n_classes`."""
    absorbing_index = n_classes
    return jnp.full(dims, absorbing_index)

In [None]:
import math


# https://github.com/neverix/Score-Entropy-Discrete-Diffusion/blob/f7221e3b835045f75444c7429955aa420111cc7d/model/transformer.py#L80
def timestep_embedding(t, dim, max_period=10000):
    """Sinusoidal embedding of scalar timesteps.

    Args:
        t: JAX array of timesteps. `freqs` is broadcast with a leading axis,
           so a 0-d `t` yields an output with a leading axis of length 1.
        dim: size of the embedding (last) axis; may be odd.
        max_period: controls the lowest frequency.

    Returns:
        float32 array whose last axis has length `dim`: `dim // 2` cosines,
        then `dim // 2` sines, then one zero column when `dim` is odd.
    """
    # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    half = dim // 2
    freqs = jnp.exp(
        -math.log(max_period) * jnp.arange(0, half, dtype=jnp.float32) / half
    )
    args = t[..., None].astype(jnp.float32) * freqs[None]
    embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
    if dim % 2:
        # Pad to the requested odd width. `jnp.cat(..., dim=-1)` is torch API
        # and does not exist in jax.numpy; `[..., :1]` (instead of `[:, :1]`)
        # keeps this working for inputs of any rank.
        pad = jnp.zeros_like(embedding[..., :1])
        embedding = jnp.concatenate([embedding, pad], axis=-1)
    return embedding


In [None]:
from functools import partial
import optax
from tqdm import trange
# One root key, split into an independent key per parameter tensor.
key_params = jax.random.key(3)
key_w1, key_w2, key_w3, key_b1, key_b2, key_b3 = jax.random.split(key_params, 6)
# Width of the sinusoidal time embedding fed to the network.
n_time = 32
# Network input: one (n_classes + 1)-way one-hot per token, flattened, plus the time embedding.
in_dim = (n_classes + 1) * seq_len + n_time
# Hidden width.
d = 128
# Network output: one (n_classes + 1)-way logit vector per token, flattened.
out_dim = (n_classes + 1) * seq_len
scale_in = math.sqrt(in_dim)
scale_out = math.sqrt(out_dim)
# Gaussian init. w1/w2 are divided by sqrt(in_dim).
# NOTE(review): w3 is divided by sqrt(out_dim) although its input width is d — confirm this is intended.
# NOTE(review): w2/b2 are only referenced by the commented-out gated layer in run_net.
w1, w2, w3 = jax.random.normal(key_w1, (in_dim, d)) / scale_in, jax.random.normal(key_w2, (in_dim, d)) / scale_in, jax.random.normal(key_w3, (d, out_dim)) / scale_out
b1, b2, b3 = jax.random.normal(key_b1, (d,)) * 0.1, jax.random.normal(key_b2, (d,)) * 0.1, jax.random.normal(key_b3, (out_dim,)) * 0.1
# Parameter pytree; run_net unpacks it in exactly this order.
params = w1, w2, w3, b1, b2, b3

def run_net(params, data_perturbed, total_noise):
    """One-hidden-layer MLP producing per-token logits over n_classes + 1 states.

    Args:
        params: 6-tuple (w1, w2, w3, b1, b2, b3). The second weight/bias pair
                is unpacked but not used by this forward pass.
        data_perturbed: integer tokens; the last axis is the sequence axis.
        total_noise: noise level; only index 0 of its last axis is read.

    Returns:
        Array of shape (*data_perturbed.shape, n_classes + 1).
    """
    w_in, _, w_out, b_in, _, b_out = params
    lead_shape = data_perturbed.shape[:-1]

    # Tokens as one-hot vectors rescaled from {0, 1} to {-1, +1}, flattened over the sequence.
    token_feats = jax.nn.one_hot(data_perturbed, n_classes + 1).reshape(*lead_shape, -1) * 2 - 1
    # Sinusoidal embedding of the noise level.
    time_feats = timestep_embedding(total_noise[..., 0], n_time).reshape(*lead_shape, -1)
    features = jnp.concatenate([token_feats, time_feats], -1)

    hidden = jax.nn.relu(features @ w_in + b_in)
    flat_logits = hidden @ w_out + b_out
    return flat_logits.reshape(*data_perturbed.shape, n_classes + 1)

def get_loss(params, data, key):
    """Monte-Carlo estimate of the score-entropy training loss.

    Draws one uniform time per row of `data`, perturbs the row at the
    corresponding noise level, and weights the per-token score entropy by the
    noise rate.

    Args:
        params: network parameters, passed through to run_net.
        data: clean integer tokens; axis 0 is the batch, axis 1 the sequence.
        key: JAX PRNG key.

    Returns:
        Scalar: loss summed over axis 1 and averaged over the batch.
    """
    n_rows = data.shape[0]
    time_key, perturb_key = jax.random.split(key)

    t = jax.random.uniform(time_key, (n_rows, 1))
    sigma, sigma_rate = noise_schedule(t)

    noisy = sample_transition(perturb_key, data, sigma)
    log_score = run_net(params, noisy, sigma)

    weighted = score_entropy(log_score, sigma, noisy, data) * sigma_rate
    per_row = weighted.sum(1)
    return per_row.mean()

@partial(jax.jit, donate_argnums=(0, 3))
def update_step(params, data, key, opt_state):
    """One jitted optimisation step.

    `params` and `opt_state` are donated to the computation, so callers must
    use the returned values and not the arguments they passed in.

    Returns:
        (new_params, new_opt_state, loss)
    """
    loss_value, grads = jax.value_and_grad(get_loss)(params, data, key)
    deltas, next_opt_state = optimizer.update(grads, opt_state)
    next_params = optax.apply_updates(params, deltas)
    return next_params, next_opt_state, loss_value

# Un-jitted loss/gradient function; only referenced by the commented-out lines in the loop below.
loss_and_grad = jax.value_and_grad(get_loss)
# Clip the global gradient norm to 1.0, then Adam with learning rate 3e-4.
# update_step reads this global when it is called.
optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(3e-4))
# Exponential moving average of the weights. Start from an independent copy,
# because update_step donates the buffers of the `params` it is given.
ema_params = jax.tree.map(jnp.copy, params)
opt_state = optimizer.init(params)
# Per-step training loss, appended as returned by update_step.
losses = []
for i in (bar := trange(4096)):
    # Fresh, reproducible key per step, derived from the step index.
    key = jax.random.PRNGKey(i)
    # loss, grad = loss_and_grad(params, data.repeat(16, 0), key)
    # updates, opt_state = optimizer.update(grad, opt_state, params)
    # params = optax.apply_updates(params, updates)
    # Each row of `data` is repeated 128 times so every step sees many
    # independent (t, perturbation) draws per sequence.
    params, opt_state, loss = update_step(params, data.repeat(128, 0), key, opt_state)
    # EMA update with decay 0.99.
    ema_params = jax.tree.map(lambda x, y: x * 0.99 + y * 0.01, ema_params, params)
    bar.set_postfix(loss=loss)
    losses.append(loss)

In [None]:
from matplotlib import pyplot as plt
# Training loss per optimiser step on log-log axes.
# Steps are numbered from 1: a log x-axis cannot display x = 0, so plotting
# against the default 0-based index would silently drop the first loss value.
plt.plot(range(1, len(losses) + 1), losses)
plt.yscale("log")
plt.xscale("log")
plt.xlabel("training step")
plt.ylabel("score-entropy loss")
plt.title("Training loss");

In [None]:
import random


# @jax.jit
def sample(score_fn, key, n_steps, denoise=True, projector=lambda x: x):
    """Draw one sequence of length `seq_len` by running the reverse process.

    Starts with every position at the absorbing index and takes `n_steps`
    reverse steps from t = 1 down to t = noise_eps, optionally followed by a
    final deterministic denoising step.

    Args:
        score_fn: callable `(x, sigma) -> log-scores` with a trailing axis of
                  n_classes + 1 states; the result is exponentiated here.
        key: JAX PRNG key.
        n_steps: number of stochastic reverse steps.
        denoise: if True, finish with an argmax step over the non-absorbing
                 states, so the result never contains the absorbing index.
        projector: applied to `x` before every step (identity by default).

    Returns:
        Integer array of shape (seq_len,).
    """
    # https://github.com/neverix/Score-Entropy-Discrete-Diffusion/blob/f7221e3b835045f75444c7429955aa420111cc7d/sampling.py#L78
    def update_fn(score_fn, x, t, step_size):
        """Unnormalised reverse-transition weights for moving from t to t - step_size."""
        curr_sigma = noise_schedule(t)[0]
        next_sigma = noise_schedule(t - step_size)[0]
        dsigma = curr_sigma - next_sigma

        score = jnp.exp(score_fn(x, curr_sigma))
        # The score of the current state relative to itself is 1 by definition.
        # score_entropy never trains that entry (it only reads the non-absorbing
        # columns and the clean-token column), yet it determines the weight of
        # "stay where you are", so pin it instead of trusting the raw output.
        own_state = x[..., None] == jnp.arange(score.shape[-1])
        score = jnp.where(own_state, 1.0, score)

        stag_score = staggered_score(score, dsigma)
        return stag_score * transp_transition(x, dsigma)

    x = sample_limit((seq_len,))
    timesteps = jnp.linspace(1, noise_eps, n_steps + 1)
    dt = (1 - noise_eps) / max(n_steps, 1)

    for i in trange(n_steps):
        key, subkey = jax.random.split(key)
        t = timesteps[i] * jnp.ones(x.shape)
        x = projector(x)
        probs = update_fn(score_fn, x, t, dt)
        # jax.random.categorical expects (unnormalised) LOG-probabilities.
        # Passing `probs` directly would sample from softmax(probs) instead.
        # Weights can dip below zero for coarse steps; clamp them so they get
        # probability zero (log 0 = -inf) rather than NaN.
        logits = jnp.log(jnp.maximum(probs, 0))
        x = jax.random.categorical(subkey, logits)

    if denoise:
        # Denoising step: remove all remaining noise, i.e. step from the final
        # time to t = 0. Using `dt` here would evaluate the schedule at a
        # negative time and produce a negative sigma.
        x = projector(x)
        t = timesteps[-1] * jnp.ones(x.shape)
        probs = update_fn(score_fn, x, t, t)
        # Drop the absorbing state so the output contains only real tokens.
        x = probs[..., :-1].argmax(-1)
    return x

# Draw one sequence from the EMA weights with 1000 reverse steps and compare it
# against every row of `data` (the (seq_len,) sample broadcasts over the batch axis).
# The displayed value is the best per-row fraction of matching tokens; 1.0 means
# the sample reproduces one training sequence exactly.
# The seed comes from Python's `random`, so the result differs between runs.
(sample(lambda x, t: run_net(ema_params, x, t), jax.random.key(random.randrange(0, 100)), 1000) == data).mean(1).max()

In [None]:
# Show the training sequences for visual comparison with generated samples.
data