In [1]:
import einops
import jax
from jax import numpy as jnp
import optax
from fused_ce_loss import fused_ce_loss_fwd

In [16]:
from jax import numpy as jnp
from jax import random as jr

vocab_size = 512
embed_size = 128
batch_size = 4
seq_len = 32

xs = jr.normal(jr.PRNGKey(0), (batch_size, seq_len, embed_size), dtype=jnp.bfloat16)
ys = jr.randint(jr.PRNGKey(1), (batch_size, seq_len), minval=0, maxval=vocab_size)
vocab = jr.normal(jr.PRNGKey(2), (vocab_size, embed_size), dtype=jnp.bfloat16)

In [17]:
out = fused_ce_loss_fwd(xs, vocab, ys)

In [18]:
def ref_ce_loss(xs, vocab, ys):
    logits = jnp.einsum('bse,ve->bsv', xs, vocab)
    return optax.softmax_cross_entropy_with_integer_labels(logits, ys)

In [19]:
ref_ce_loss(xs, vocab, ys) - out

Array([[0, 0, 0, 0, 0.125, 0, 0, 0, -0.125, 0.0625, -0.25, 0, -0.25, 0,
        0, 0, 0, 0, 0, -0.25, 0, 0, 0, 0, -0.25, 0, -0.125, -0.25,
        -0.125, -0.125, 0, 0],
       [-0.125, 0, 0.125, 0, -0.00323486, 0, 0, 0, 0, 0, -0.25, -0.25, 0,
        0, -0.125, 0, 0, 0.25, 0, 0, 0, 0.25, -0.125, 0, 0, -0.125, 0, 0,
        0, 0, -0.125, 0.25],
       [0, 0, -0.125, -0.125, 0, 0, -0.25, -0.25, 0, 0, 0, 0.125, -0.125,
        -0.125, 0, 0, 0, 0, 0, 0, 0.125, 0.25, -0.25, 0, 0.125, 0.25,
        -0.25, 0, 0, 0, 0, 0],
       [-0.25, 0, 0, 0, 0, 0, -0.25, 0, -0.125, 0, 0, 0.0625, 0, 0.125,
        0, -0.25, 0, 0.125, 0, 0, 0.25, -0.5, 0.25, 0, -0.25, 0, 0.25,
        -0.125, 0.125, -0.125, 0, 0]], dtype=bfloat16)

In [4]:
out

Array([[36, 26.75, 12.4375, 35.25, 29, 36.75, 39, 30.5, 25.125, 8.6875,
        55, 37.5, 38.5, 33.5, 47, 43, 45.5, 48.25, 48.75, 45.5, 35.25,
        23, 47.25, 36, 45.75, 28.125, 27.25, 45.75, 30.75, 14.875, 31.75,
        38.5],
       [25.75, 28.625, 31, 53.5, 0.00323486, 38.5, 28.875, 24.75, 38.5,
        41.25, 54.75, 37, 33.25, 31.125, 23.375, 24.375, 66, 32.25, 23.5,
        19.5, 25.5, 31.25, 27.75, 33.75, 32, 18, 51, 31.625, 46.5, 51.25,
        24.375, 43],
       [19, 56, 23, 23.125, 42.75, 28.75, 36.75, 58.5, 40.5, 43, 35.5,
        23.625, 19.875, 20.5, 55, 26.125, 13.875, 42, 41, 40.5, 24.375,
        37, 32.25, 28.75, 18, 33, 38, 11.5, 26.75, 22.875, 32.75, 32.75],
       [56.75, 33.5, 21.5, 19.25, 26, 47.75, 39, 13.8125, 21.75, 34,
        26.125, 11, 39.75, 28.25, 54, 54.25, 24.875, 26.875, 38.25,
        40.25, 39.5, 65.5, 40, 36.75, 45, 50, 38.5, 25.875, 17, 30, 40,
        39]], dtype=bfloat16)

In [6]:
xs = einops.rearrange(xs, 'b s e -> (b s) e')
ys = einops.rearrange(ys, 'b s -> (b s)')
logits = einops.einsum(xs, vocab, 'b e, v e->b v')

In [7]:
xs.shape

(128, 128)

In [8]:
ys[0]

Array(410, dtype=int32)

In [9]:
logits[:, 64:]

Array([[-8.3125, 3.92188, 7.90625, ..., -11.5625, -20.875, 0.746094],
       [-12.25, 5.5, -27.375, ..., -15, -14.375, 8.9375],
       [-2.84375, 3.73438, 10.8125, ..., 3.14062, 1.03125, -3.42188],
       ...,
       [-4.34375, -16.5, -24.375, ..., 7, 3.01562, -19.5],
       [0.188477, 8, -15.0625, ..., -14.25, 11.875, -2.51562],
       [-10.875, -17.5, -19.75, ..., 2.67188, -9.9375, 15.3125]],      dtype=bfloat16)

In [10]:
logits_max = jnp.max(logits, axis=-1, keepdims=True)
logits_max[0]

Array([31.5], dtype=bfloat16)

In [11]:
label_logits = jnp.take_along_axis(logits, ys[..., None], axis=-1)[..., 0]
label_logits[0]

Array(-4.28125, dtype=bfloat16)

In [12]:
norm_logits = logits - jax.lax.stop_gradient(logits_max)

In [13]:
denom = jnp.sum(jnp.exp(norm_logits), axis=-1)
denom[0]

Array(1.38281, dtype=bfloat16)

In [14]:
log_normalizers = jnp.log(denom)
ref_ce_loss = -label_logits + logits_max[..., 0] + log_normalizers

In [15]:
ref_ce_loss

Array([36, 26.75, 12.4375, 35.25, 29.125, 36.75, 39, 30.5, 25, 8.75,
       54.75, 37.5, 38.25, 33.5, 47, 43, 45.5, 48.25, 48.75, 45.25, 35.25,
       23, 47.25, 36, 45.5, 28.125, 27.125, 45.5, 30.625, 14.75, 31.75,
       38.5, 25.625, 28.625, 31.125, 53.5, 0, 38.5, 28.875, 24.75, 38.5,
       41.25, 54.5, 36.75, 33.25, 31.125, 23.25, 24.375, 66, 32.5, 23.5,
       19.5, 25.5, 31.5, 27.625, 33.75, 32, 17.875, 51, 31.625, 46.5,
       51.25, 24.25, 43.25, 19, 56, 22.875, 23, 42.75, 28.75, 36.5, 58.25,
       40.5, 43, 35.5, 23.75, 19.75, 20.375, 55, 26.125, 13.875, 42, 41,
       40.5, 24.5, 37.25, 32, 28.75, 18.125, 33.25, 37.75, 11.5, 26.75,
       22.875, 32.75, 32.75, 56.5, 33.5, 21.5, 19.25, 26, 47.75, 38.75,
       13.8125, 21.625, 34, 26.125, 11.0625, 39.75, 28.375, 54, 54,
       24.875, 27, 38.25, 40.25, 39.75, 65, 40.25, 36.75, 44.75, 50,
       38.75, 25.75, 17.125, 29.875, 40, 39], dtype=bfloat16)

In [16]:
out

Array([[36, 26.75, 12.4375, 35.25, 29, 36.75, 39, 30.5, 25.125, 8.6875,
        55, 37.5, 38.5, 33.5, 47, 43, 45.5, 48.25, 48.75, 45.5, 35.25,
        23, 47.25, 36, 45.75, 28.125, 27.25, 45.75, 30.75, 14.875, 31.75,
        38.5],
       [25.75, 28.625, 31, 53.5, 0.00323486, 38.5, 28.875, 24.75, 38.5,
        41.25, 54.75, 37, 33.25, 31.125, 23.375, 24.375, 66, 32.25, 23.5,
        19.5, 25.5, 31.25, 27.75, 33.75, 32, 18, 51, 31.625, 46.5, 51.25,
        24.375, 43],
       [19, 56, 23, 23.125, 42.75, 28.75, 36.75, 58.5, 40.5, 43, 35.5,
        23.625, 19.875, 20.5, 55, 26.125, 13.875, 42, 41, 40.5, 24.375,
        37, 32.25, 28.75, 18, 33, 38, 11.5, 26.75, 22.875, 32.75, 32.75],
       [56.75, 33.5, 21.5, 19.25, 26, 47.75, 39, 13.8125, 21.75, 34,
        26.125, 11, 39.75, 28.25, 54, 54.25, 24.875, 26.875, 38.25,
        40.25, 39.5, 65.5, 40, 36.75, 45, 50, 38.5, 25.875, 17, 30, 40,
        39]], dtype=bfloat16)

In [14]:

logits_max = jnp.max(logits, axis=-1, keepdims=True)
norm_logit = logits - jax.lax.stop_gradient(logits_max)
label_logits = jnp.take_along_axis(logits, ys[..., None], axis=-1)[..., 0]
log_normalizers = jnp.log(jnp.sum(jnp.exp(logits), axis=-1))
ref_ce_loss = log_normalizers - label_logits

In [8]:
ref_ce_loss

Array([36, 26.75, 12.4375, 35.25, 29.125, 36.75, 39, 30.5, 25, 8.75,
       54.75, 37.5, 38.25, 33.5, 47, 43, 45.5, 48.25, 48.75, 45.25, 35.25,
       23, 47.25, 36, 45.5, 28.125, 27.125, 45.5, 30.625, 14.75, 31.75,
       38.5, 25.625, 28.625, 31.125, 53.5, 0, 38.5, 28.875, 24.75, 38.5,
       41.25, 54.5, 36.75, 33.25, 31.125, 23.25, 24.375, 66, 32.5, 23.5,
       19.5, 25.5, 31.5, 27.625, 33.75, 32, 17.875, 51, 31.625, 46.5,
       51.25, 24.25, 43.25, 19, 56, 22.875, 23, 42.75, 28.75, 36.5, 58.25,
       40.5, 43, 35.5, 23.75, 19.75, 20.375, 55, 26.125, 13.875, 42, 41,
       40.5, 24.5, 37.25, 32, 28.75, 18.125, 33.25, 37.75, 11.5, 26.75,
       22.875, 32.75, 32.75, 56.5, 33.5, 21.5, 19.25, 26, 47.75, 38.75,
       13.8125, 21.625, 34, 26.125, 11.0625, 39.75, 28.375, 54, 54,
       24.875, 27, 38.25, 40.25, 39.75, 65, 40.25, 36.75, 44.75, 50,
       38.75, 25.75, 17.125, 29.875, 40, 39], dtype=bfloat16)

In [18]:
logits

Array([[-6.875, -11.5, -44, ..., -43, -52.5, -30.75],
       [-38.75, -45, -23.75, ..., -53.75, -53, -29.75],
       [-28.625, -46, -19, ..., -28.75, -30.875, -35.25],
       ...,
       [-49, -54, -43.75, ..., -35.5, -39.5, -62],
       [-47.75, -54, -29.75, ..., -48.25, -22.125, -36.5],
       [-32.25, -30.125, -37.25, ..., -29.875, -42.5, -17.25]],      dtype=bfloat16)

In [17]:
logits.shape

(128, 512)

In [7]:
logits.shape

(4, 32, 512)

In [14]:
label_logits = jnp.take_along_axis(logits, ys[..., None], axis=-1)[..., 0]

In [15]:
label_logits[0]

Array([-4.28125, 12.0625, 20.25, -7.8125, 15.625, -5.75, -5.125,
       -0.310547, 9.875, 29.875, -19.875, -1.34375, -4.25, 11.5, -14.1875,
       -10.8125, 1.90625, -14.1875, -11.3125, -9.8125, 10.8125, 6.8125,
       -9.4375, 12.25, -11.375, 0.484375, 3.35938, -9, 7.6875, 19.25,
       -1.78906, -2.17188], dtype=bfloat16)

In [8]:
logits_max[0]

Array([[31.5],
       [38.75],
       [31.875],
       [26],
       [44.75],
       [31],
       [33.25],
       [29.875],
       [34.75],
       [38.5],
       [34.75],
       [36],
       [34],
       [45],
       [32.75],
       [32],
       [47.5],
       [34],
       [37.5],
       [35.5],
       [46],
       [29],
       [37.5],
       [48.25],
       [33.75],
       [28.125],
       [29.5],
       [36.25],
       [38.25],
       [34],
       [29.375],
       [36.25]], dtype=bfloat16)

In [6]:
ref_ce_loss

Array([[36, 26.75, 12.4375, 35.25, 29.125, 36.75, 39, 30.5, 25, 8.75,
        54.75, 37.5, 38.25, 33.5, 47, 43, 45.5, 48.25, 48.75, 45.25,
        35.25, 23, 47.25, 36, 45.5, 28.125, 27.125, 45.5, 30.625, 14.75,
        31.75, 38.5],
       [25.625, 28.625, 31.125, 53.5, 0, 38.5, 28.875, 24.75, 38.5,
        41.25, 54.5, 36.75, 33.25, 31.125, 23.25, 24.375, 66, 32.5, 23.5,
        19.5, 25.5, 31.5, 27.625, 33.75, 32, 17.875, 51, 31.625, 46.5,
        51.25, 24.25, 43.25],
       [19, 56, 22.875, 23, 42.75, 28.75, 36.5, 58.25, 40.5, 43, 35.5,
        23.75, 19.75, 20.375, 55, 26.125, 13.875, 42, 41, 40.5, 24.5,
        37.25, 32, 28.75, 18.125, 33.25, 37.75, 11.5, 26.75, 22.875,
        32.75, 32.75],
       [56.5, 33.5, 21.5, 19.25, 26, 47.75, 38.75, 13.8125, 21.625, 34,
        26.125, 11.0625, 39.75, 28.375, 54, 54, 24.875, 27, 38.25, 40.25,
        39.75, 65, 40.25, 36.75, 44.75, 50, 38.75, 25.75, 17.125, 29.875,
        40, 39]], dtype=bfloat16)

In [10]:
logits[0]

Array([[-6.875, -11.5, -44, ..., -43, -52.5, -30.75],
       [-38.75, -45, -23.75, ..., -53.75, -53, -29.75],
       [-28.625, -46, -19, ..., -28.75, -30.875, -35.25],
       ...,
       [-20.5, -29.125, -22.25, ..., -27.5, -17.625, -41],
       [-37.5, -18, -60.75, ..., -34.5, -45.75, -16.625],
       [-33.25, -48, -21.25, ..., -30.375, -41.75, -44]], dtype=bfloat16)

In [9]:
jnp.sum(jnp.exp(logits), axis=-1)[0]

Array([1.38281, 1, 2.21875, 4.25, 1, 1.01562, 1.78906, 1.42188, 1.16406,
       1.125, 1.25, 1.28125, 1.125, 1, 1.03125, 1.23438, 1, 1.04688,
       1.00781, 1.11719, 1, 2.14062, 1.36719, 1, 1.75, 1.64062, 2.76562,
       1.42188, 1.20312, 1.00781, 1.9375, 1.00781], dtype=bfloat16)

In [11]:
log_normalizers[0]

Array([0.324219, 0, 0.796875, 1.44531, 0, 0.0155029, 0.582031, 0.351562,
       0.152344, 0.117676, 0.222656, 0.248047, 0.117676, 0, 0.0307617,
       0.210938, 0, 0.0458984, 0.00778198, 0.11084, 0, 0.761719, 0.3125,
       0, 0.558594, 0.494141, 1.01562, 0.351562, 0.18457, 0.00778198,
       0.660156, 0.00778198], dtype=bfloat16)

In [27]:
label_logits

Array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=bfloat16)

In [22]:
ref_ce_loss(xs, vocab, ys)

Array([[6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25,
        6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25,
        6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25],
       [6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25,
        6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25,
        6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25],
       [6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25,
        6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25,
        6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25],
       [6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25,
        6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25,
        6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25, 6.25]],      dtype=bfloat16)

In [11]:
assert jnp.allclose(ref_ce_loss(xs, ys, vocab), fused_ce_loss_fwd(xs, ys, vocab))

AssertionError: 