# Testing Distributed GPU Training with `jax.pmap`

In [1]:
import os

import jax
import jax.numpy as jnp

# Hardware setup: report the JAX version, every device visible to this
# process, and the default backend.
#
# NOTE: startup options must be configured *before* JAX initializes its
# backends, i.e. before the `jax.devices()` call below:
#   - `jax.config.update("jax_platforms", "cuda")` (replaces the older
#     `jax_platform_name` option) and `jax.config.update("jax_enable_x64", True)`
#     belong above `jax.devices()`.
#   - Environment variables such as `JAX_ENABLE_X64` or `NVIDIA_TF32_OVERRIDE`
#     must be set before `import jax` (top of the imports cell) to take effect.
print("JAX version:", jax.__version__)
devices = jax.devices()
print("Available devices:")
for device in devices:
    print(device)

print("Using device:", jax.default_backend())  # expected 'gpu' on a CUDA machine

JAX version: 0.4.30
Available devices:
cuda:0
cuda:1
cuda:2
cuda:3
cuda:4
cuda:5
cuda:6
cuda:7
Using device: gpu


## Distributed Computation

In [2]:
def f(x):
    """Return ``x`` squared element-wise."""
    return x ** 2

# jax.pmap maps over the leading axis, whose size must not exceed the number of
# local devices; derive it from the hardware instead of hard-coding 8 so the
# cell also runs on machines with fewer accelerators.
n_devices = jax.local_device_count()
x = jnp.arange(n_devices * 4).reshape(n_devices, 4)  # one row of 4 values per device
pmap_f = jax.pmap(f)

y = pmap_f(x)
print(y)

[[  0   1   4   9]
 [ 16  25  36  49]
 [ 64  81 100 121]
 [144 169 196 225]
 [256 289 324 361]
 [400 441 484 529]
 [576 625 676 729]
 [784 841 900 961]]


## Distributed Training: Forward Pass

In [3]:
@jax.pmap
def forward_pass(x):
    """Apply ReLU element-wise; under pmap each device receives one row of the input."""
    return jax.nn.relu(x)

# One 128-feature example per device: pmap splits the leading axis across
# devices, so its size must not exceed jax.local_device_count().
n_devices = jax.local_device_count()
x = jnp.ones((n_devices, 128))
y = forward_pass(x)
print(y.shape)

(8, 128)
