In [2]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx

In [3]:
import os

# Hardware setup
print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())

jax.config.update("jax_platform_name", "gpu") # Make sure we're using the GPU
#jax.config.update("jax_enable_x64", True) # Make sure the highest precision is enabled in case we need
jax.config.update("jax_default_matmul_precision", "bfloat16") # Set the default precision for matrix multiplication

#os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
#os.environ["JAX_ENABLE_X64"] = "False"

print("Using device:", jax.default_backend())  # Should print 'gpu'

A = jnp.array(np.random.normal(size=(4096, 4096)), dtype=jnp.float32) # Makes sure the matmul is fast

%timeit (A@A).block_until_ready()

JAX version: 0.4.33
Available devices: [CudaDevice(id=0)]
Using device: gpu
1.24 ms ± 3.92 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [4]:
n_features = 4096

class Net(nnx.Module):

  def __init__(self, dtype: jnp.dtype, rngs: nnx.Rngs):
    self.layer1 = nnx.Linear(n_features, n_features, dtype=dtype, rngs=rngs)
    self.layer2 = nnx.Linear(n_features, n_features, dtype=dtype, rngs=rngs)
    self.layer3 = nnx.Linear(n_features, n_features, dtype=dtype, rngs=rngs)
    self.out = nnx.Linear(n_features, 2, dtype=dtype, rngs=rngs)


  def __call__(self, x):
    x = nnx.gelu(self.layer1(x), approximate=True)
    x = nnx.gelu(self.layer2(x), approximate=True)
    x = nnx.gelu(self.layer3(x), approximate=True)
    y = self.out(x)
    return y


In [5]:
key = jax.random.PRNGKey(0)
rngs = nnx.Rngs(key)
m = Net(dtype=jnp.float32, rngs=rngs)

In [13]:
B = 32
N = B*10000

X = jax.random.normal(key=key, shape=(N, n_features), dtype=jnp.float32).reshape(N // B, B, n_features)
Y = jax.random.randint(key=key, shape=(N, 1), minval=0, maxval=2, dtype=jnp.int8).reshape( N // B, B, 1)
print(X.shape, Y.shape)
print(X[0,0,0], Y[0,:5, :])

(10000, 32, 4096) (10000, 32, 1)
-0.05067226 [[0]
 [1]
 [1]
 [1]
 [1]]


In [14]:
optimizer = nnx.Optimizer(m, optax.adamw(3e-4))

In [21]:
@nnx.jit(donate_argnums=(0,1))
def train_step(model, optimizer, X, Y, i):
    x, y = X[i, ...], Y[i, ...]
    def loss_fn(model, x, y):
        logits = model(x)
        loss = optax.softmax_cross_entropy(logits, y).mean()
        return loss

    loss, grads =  nnx.value_and_grad(loss_fn)(model, x, y)
    optimizer.update(grads)
    return loss


In [23]:
import time
from IPython.display import clear_output

num_epochs = 1

avg_iter_time = -1
for e in range(num_epochs):
  for i in range( N // B ):
    start = time.time()
    loss = train_step(m, optimizer, X, Y, i)
    jax.block_until_ready(loss)
    if avg_iter_time == -1:
      avg_iter_time = (time.time() - start)*1000
    else:
      avg_iter_time = (avg_iter_time * i + (time.time() - start)*1000) / (i + 1)
    print(f"Epoch: {e}, Iter: {i}, Loss: {loss:0.4f}, Iter time: {avg_iter_time:0.4f} ms")
    clear_output(wait=True)

Epoch: 0, Iter: 9999, Loss: 0.7365, Iter time: 2.9160 ms
