<a href="https://colab.research.google.com/github/novastar53/jaxpt/blob/dev/notebooks/GPU_Performance_Tuning_(Jax).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import optax
import jax.numpy as jnp
import numpy as np
from flax import nnx

In [2]:
import os

# Hardware setup
print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())

jax.config.update("jax_platform_name", "gpu") # Make sure we're using the GPU
#jax.config.update("jax_enable_x64", True) # Make sure the highest precision is enabled in case we need
jax.config.update("jax_default_matmul_precision", "float32") # Set the default precision for matrix multiplication

#os.environ["NVIDIA_TF32_OVERRIDE"] = "1"
#os.environ["JAX_ENABLE_X64"] = "False"

print("Using device:", jax.default_backend())  # Should print 'gpu'

A = jnp.array(np.random.normal(size=(4096, 4096)), dtype=jnp.float32) # Makes sure the matmul is fast

%timeit (A@A).block_until_ready()

JAX version: 0.4.33
Available devices: [CudaDevice(id=0)]
Using device: gpu
9.53 ms ± 36.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [3]:
n_features = 8196

class Net(nnx.Module):

  def __init__(self, dtype: jnp.dtype, rngs: nnx.Rngs):
    self.layer1 = nnx.Linear(n_features, n_features, dtype=dtype, rngs=rngs)
    self.layer2 = nnx.Linear(n_features, n_features, dtype=dtype, rngs=rngs)
    self.layer3 = nnx.Linear(n_features, n_features, dtype=dtype, rngs=rngs)
    self.out = nnx.Linear(n_features, 2, dtype=dtype, rngs=rngs)


  def __call__(self, x):
    x = nnx.gelu(self.layer1(x), approximate=True)
    x = nnx.gelu(self.layer2(x), approximate=True)
    x = nnx.gelu(self.layer3(x), approximate=True)
    y = self.out(x)
    return y


In [4]:
key = jax.random.PRNGKey(0)
rngs = nnx.Rngs(key)
m = Net(dtype=jnp.float32, rngs=rngs)

In [5]:
B = 32
N = B*10000

X = jax.random.normal(key=key, shape=(N, n_features), dtype=jnp.float32).reshape(N // B, B, n_features)
Y = jax.random.randint(key=key, shape=(N, 1), minval=0, maxval=2, dtype=jnp.int8).reshape( N // B, B, 1)
print(X.shape, Y.shape)
print(X[0,0,0], Y[0,:5, :])

(10000, 32, 8196) (10000, 32, 1)
-0.44350547 [[0]
 [1]
 [1]
 [1]
 [1]]


In [6]:
optimizer = nnx.Optimizer(m, optax.adamw(3e-4))

In [7]:
@nnx.jit
def train_step(model, optimizer, batch, targets):

    def loss_fn(model, batch, targets):
        logits = model(batch)
        loss = optax.softmax_cross_entropy(logits, targets).mean()
        return loss

    loss, grads =  nnx.value_and_grad(loss_fn)(model, batch, targets)
    optimizer.update(grads)
    return loss


In [8]:
import time
from IPython.display import clear_output

num_epochs = 1

avg_iter_time = -1
for e in range(num_epochs):
  for i in range( N // B ):
    start = time.time()
    x, y = X[i,...], Y[i,...]
    loss = train_step(m, optimizer, x, y)
    jax.block_until_ready(loss)
    if avg_iter_time == -1:
      avg_iter_time = (time.time() - start)*1000
    else:
      avg_iter_time = (avg_iter_time * i + (time.time() - start)*1000) / (i + 1)
    print(f"Epoch: {e}, Iter: {i}, Loss: {loss:0.4f}, Iter time: {avg_iter_time:0.4f} ms")
    clear_output(wait=True)

Epoch: 0, Iter: 9999, Loss: 0.7368, Iter time: 8.9720 ms
