In [1]:
from jax import numpy as jnp
import optax
from fused_ce_loss import fused_ce_loss_fwd

In [2]:
from jax import numpy as jnp
from jax import random as jr

vocab_size = 50257
embed_size = 768
batch_size = 32
seq_len = 128

xs = jr.normal(jr.PRNGKey(0), (batch_size, seq_len, embed_size), dtype=jnp.bfloat16)
ys = jr.randint(jr.PRNGKey(1), (batch_size, seq_len), minval=0, maxval=vocab_size)
vocab = jr.normal(jr.PRNGKey(2), (vocab_size, embed_size), dtype=jnp.bfloat16)

2024-06-24 02:33:33.556622: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [3]:
xs = jnp.ones((batch_size, seq_len, embed_size), dtype=jnp.bfloat16)
ys = jr.randint(jr.PRNGKey(1), (batch_size, seq_len), minval=0, maxval=vocab_size)
vocab = jnp.ones((vocab_size, embed_size), dtype=jnp.bfloat16)

In [4]:
fused_ce_loss_fwd(xs, vocab, ys);

 gX : gmem_ptr[16b](0x7fbfb4fa2900) o (_128,_8,96)

:(768,_1,_8)
 gV : gmem_ptr[16b](0x7fbfb55a2900) o (_128,_8,393,96):(768,_1,98304,_8)
 tVgV : gmem_ptr[16b](0x7fbfb55a2900) o ((_1,_1),_4,_1,393,96):((_0,_0),24576,_0,98304,_8)
 tXgX : gmem_ptr[16b](0x7fbfb4fa2900) o ((_1,_1),_4,_1,96):((_0,_0),24576,_0,_8)
 sC : smem_ptr[16b](0x7fc451001000) o (_128,_128):(_128,_1)
 tCsC : smem_ptr[16b](0x7fc451001000) o (_1,_8,_8):(_0,_2048,_16)
 tCrC : ptr[16b](0x7fc44ffffc60) o (_1,_8,_8):(_0,_1,_8)
 reduce_coord : (0,0)
 tRsC : smem_ptr[16b](0x7fc451001000) o (_1,_64):(_0,_2)
 tRsN : smem_ptr[16b](0x7fc451009000) o ((_1)):((_0))
 tRsM : smem_ptr[16b](0x7fc451009100) o ((_1)):((_0))
 get<0>(reduce_tiler) : _128:_1
 select<0>(reduce_tiler) : (_128):(_1)
 size(tRsC) _64
 V_BLOCK_MAX : 393
 E_TILE_MAX : 96
 tCrC(0) : 8.000000
 tCrC(0) + 1 : 9.000000
 tXsX(2) : 1.000000
 tVsV(2) : 1.000000
 tCrC2(0) : 8.000000
 tCrC(0) : 16.000000
 tCrC(0) + 1 : 17.000000
 tXsX(2) : 1.000000
 tVsV(2) : 1.000000
 tCrC2(0) : 8.000000
 tCrC(0) : 24.000000
 tCrC(0) + 1 : 2

In [5]:
def ref_ce_loss(xs, ys, vocab):
    logits = jnp.einsum('b s e, v e->b s v', xs, vocab)
    return optax.losses.softmax_cross_entropy_with_integer_labels(logits, ys)

):(768,_1,98304,_8)
 tVgV : gmem_ptr[16b](0x7f68e95a2900) o ((_1,_1),_4,_1,393,96):((_0,_0),24576,_0,98304,_8)
 tXgX : gmem_ptr[16b](0x7f68e8fa2900) o ((_1,_1),_4,_1,96):((_0,_0),24576,_0,_8)
 sC : smem_ptr[16b](0x7f6d84001000) o (_128,_128):(_128,_1)
 tCsC : smem_ptr[16b](0x7f6d84001000) o (_1,_8,_8):(_0,_2048,_16)
 tCrC : ptr[16b](0x7f6d82fffc60) o (_1,_8,_8):(_0,_1,_8)
 reduce_coord : (0,0)
 tRsC : smem_ptr[16b](0x7f6d84001000) o (_1,_64):(_0,_2)
 tRsN : smem_ptr[16b](0x7f6d84009000) o ((_1)):((_0))
 tRsM : smem_ptr[16b](0x7f6d84009100) o ((_1)):((_0))
 get<0>(reduce_tiler) : _128:_1
 select<0>(reduce_tiler) : (_128):(_1)
atomic_maximum: 
atomic_maximum: 
atomic_maximum: 
atomic_maximum: 
atomic_maximum: 
atomic_maximum: 
atomic_maximum: 
atomic_maximum: 
atomic_maximum: 
atomic_maximum: 
atomic_maximum: 
atomic_maximum: 
atomic_maximum: 
atomic_maximum: 
atomic_maximum: 
atomic_maximum: 
atomic_maximum: 
atomic_maximum: 
atomic_maximum: 
atomic_maximum: 
atomic_maximum: 
atomic_max

In [11]:
assert jnp.allclose(ref_ce_loss(xs, ys, vocab), fused_ce_loss_fwd(xs, ys, vocab))

AssertionError: 