# JAX Key Aspects

Some of the example are taken from 
https://jax.readthedocs.io/en/latest/

In [1]:
#Suppress warnings
import os
import absl.logging

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
absl.logging.set_verbosity('error')
absl.logging.set_verbosity(absl.logging.ERROR)
absl.logging.set_stderrthreshold(absl.logging.ERROR)

In [2]:
import numpy as np
import tensorflow as tf
import time
import torch

## Automatic detection of the hardware back-end

In [3]:
import jax
import jax.numpy as jnp

If you have correctly installed a GPU, it will automatically detect and use for your computations.

In [4]:
x = jnp.arange(5)

#If you want to know if your array is assigned to what devices.
x.devices()

{cuda(id=0)}

If you want to force for your pourposes the CPU utilization as default don't forget to specify this.

Pay attention you have to restart the notebook if you have already execute the previous cells.

In [2]:
os.environ['JAX_PLATFORM_NAME'] = 'cpu'

In [3]:
import jax
import jax.numpy as jnp
y = jnp.arange(5)

#If you want to know if your array is assigned to what devices.
y.devices()

{CpuDevice(id=0)}

Alternatively you are forced to specify it manually. Like in this example.

In [3]:
from jax import devices
import jax
import jax.numpy as jnp

z = jax.device_put(jnp.arange(5),device=devices('cpu')[0])
w = jax.device_put(jnp.arange(5),device=devices('gpu')[0])

print("Array z:", z)
print("Array z is on device:", z.devices())

print("Array w:", w)
print("Array w is on device:", w.devices())

Array z: [0 1 2 3 4]
Array z is on device: {CpuDevice(id=0)}
Array w: [0 1 2 3 4]
Array w is on device: {cuda(id=0)}


## Similarity with NumPy

Remember that NumPy only supports CPU

### Basic NumPy Operations

In [6]:
#Basic Array Operations 

arr = np.array([1, 2, 3])
zeros = np.zeros((2, 3))
ones = np.ones((2, 3))
arange = np.arange(10)
linspace = np.linspace(0, 1, 5)
eye = np.eye(3)

np.random.seed(0)
random = np.random.rand(2, 3)

#Aritmetical operations
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])

add = a + b
subtract = a - b
multiply = a * b
divide = a / b
dot = np.dot(a, b)


#Aggregate operations
arr2 = np.array([1, 2, 3, 4, 5])

sum_arr = np.sum(arr2)
mean_arr = np.mean(arr2)
max_arr = np.max(arr2)
min_arr = np.min(arr2)

#Indexing slicing
sliced = arr2[1:4]
indexed = arr2[2]
indexed2= arr2[-1]

#Reshaping
arr3 = np.array([1, 2, 3, 4, 5, 6])

reshaped = arr3.reshape((2, 3))

#Linear Algebra
a1 = np.array([[1, 2], [3, 4]])
b1 = np.array([[5, 6], [7, 8]])

matmul = np.matmul(a1, b1)
det = np.linalg.det(a1)
inv = np.linalg.inv(a1)
eigvals, eigvecs = np.linalg.eig(a1)



print(f"arr: {arr}")
print(f"\nzeros: \n{zeros}")
print(f"\nones: \n{ones}")
print(f"\narange: {arange}")
print(f"\nlinspace: {linspace}")
print(f"\neye: \n{eye}")
print(f"\nrandom: \n{random}")

print(f"\nA: {a}")
print(f"\nB: {b}")

print(f"\nA+B: {add}")
print(f"\nA-B: {subtract}")
print(f"\nA*B: {multiply}")
print(f"\nA/B: {divide}")
print(f"\ndot: {dot}")

print(f"\narr2: {arr2}")
print(f"\nSUM (arr): {sum_arr}")
print(f"\nMEAN (arr): {mean_arr}")
print(f"\nMAX (arr): {max_arr}")
print(f"\nMIN (arr): {min_arr}")


print (f"\narr2[1:4]: {sliced}")

print(f"\narr3:{arr3}")
print(f"\nreshaped:\n{reshaped}")

print(f"\na1: \n{a1}")
print(f"\nb1: \n{b1}")
print(f"\nMatMul (a1,b1): \n{matmul}")
print(f"\nDet (a1): {det}")
print(f"\nInv (a1): \n{inv}")
print(f"\nEigenvalues (a1): {eigvals} \nEigenvectors(a1): \n{eigvecs}\n ")

arr: [1 2 3]

zeros: 
[[0. 0. 0.]
 [0. 0. 0.]]

ones: 
[[1. 1. 1.]
 [1. 1. 1.]]

arange: [0 1 2 3 4 5 6 7 8 9]

linspace: [0.   0.25 0.5  0.75 1.  ]

eye: 
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]

random: 
[[0.5488135  0.71518937 0.60276338]
 [0.54488318 0.4236548  0.64589411]]

A: [1 2 3]

B: [4 5 6]

A+B: [5 7 9]

A-B: [-3 -3 -3]

A*B: [ 4 10 18]

A/B: [0.25 0.4  0.5 ]

dot: 32

arr2: [1 2 3 4 5]

SUM (arr): 15

MEAN (arr): 3.0

MAX (arr): 5

MIN (arr): 1

arr2[1:4]: [2 3 4]

arr3:[1 2 3 4 5 6]

reshaped:
[[1 2 3]
 [4 5 6]]

a1: 
[[1 2]
 [3 4]]

b1: 
[[5 6]
 [7 8]]

MatMul (a1,b1): 
[[19 22]
 [43 50]]

Det (a1): -2.0000000000000004

Inv (a1): 
[[-2.   1. ]
 [ 1.5 -0.5]]

Eigenvalues (a1): [-0.37228132  5.37228132] 
Eigenvectors(a1): 
[[-0.82456484 -0.41597356]
 [ 0.56576746 -0.90937671]]
 


### Basic equivalent Jax Operations

In [12]:
from jax import random

#Basic Array Operations 

arr = jnp.array([1, 2, 3])
zeros = jnp.zeros((2, 3))
ones = jnp.ones((2, 3))
arange = jnp.arange(10)
linspace = jnp.linspace(0, 1, 5)
eye = jnp.eye(3)
rng = random.PRNGKey(0)
random_arr = random.uniform(rng, (2, 3))

#Aritmetical operations
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])

add = a + b
subtract = a - b
multiply = a * b
divide = a / b
dot = jnp.dot(a, b)

#Aggregate operations
arr2 = jnp.array([1, 2, 3, 4, 5])

sum_arr = jnp.sum(arr2)
mean_arr = jnp.mean(arr2)
max_arr = jnp.max(arr2)
min_arr = jnp.min(arr2)

#Indexing slicing
sliced = arr2[1:4]
indexed = arr2[2]
indexed2= arr2[-1]


#Reshaping
arr3 = jnp.array([1, 2, 3, 4, 5, 6])

reshaped = arr3.reshape((2, 3))

#Linear Algebra
a1 = jnp.array([[1, 2], [3, 4]])
b1 = jnp.array([[5, 6], [7, 8]])

matmul = jnp.matmul(a1, b1)
det = jnp.linalg.det(a1)
inv = jnp.linalg.inv(a1)
eigvals, eigvecs = jnp.linalg.eig(a1)


print(f"arr: {arr}")
print(f"\nzeros: \n{zeros}")
print(f"\nones: \n{ones}")
print(f"\narange: {arange}")
print(f"\nlinspace: {linspace}")
print(f"\neye: \n{eye}")
print(f"\nrandom: \n{random_arr}")

print(f"\nA: {a}")
print(f"\nB: {b}")

print(f"\nA+B: {add}")
print(f"\nA-B: {subtract}")
print(f"\nA*B: {multiply}")
print(f"\nA/B: {divide}")
print(f"\ndot: {dot}")

print(f"\narr: {arr2}")
print(f"\nSUM (arr): {sum_arr}")
print(f"\nMEAN (arr): {mean_arr}")
print(f"\nMAX (arr): {max_arr}")
print(f"\nMIN (arr): {min_arr}")

print(f"\narr3:{arr3}")
print(f"\nreshaped:\n{reshaped}")

print(f"\na1: \n{a1}")
print(f"\nb1: \n{b1}")
print(f"\nMatMul (a1,b1): \n{matmul}")
print(f"\nDet (a1): {det}")
print(f"\nInv (a1): \n{inv}")
print(f"\nEigenvalues (a1): {eigvals} \nEigenvectors(a1): \n{eigvecs}\n ")


arr: [1 2 3]

zeros: 
[[0. 0. 0.]
 [0. 0. 0.]]

ones: 
[[1. 1. 1.]
 [1. 1. 1.]]

arange: [0 1 2 3 4 5 6 7 8 9]

linspace: [0.   0.25 0.5  0.75 1.  ]

eye: 
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]

random: 
[[0.57450044 0.09968603 0.7419659 ]
 [0.8941783  0.59656656 0.45325184]]

A: [1 2 3]

B: [4 5 6]

A+B: [5 7 9]

A-B: [-3 -3 -3]

A*B: [ 4 10 18]

A/B: [0.25 0.4  0.5 ]

dot: 32

arr: [1 2 3 4 5]

SUM (arr): 15

MEAN (arr): 3.0

MAX (arr): 5

MIN (arr): 1

arr3:[1 2 3 4 5 6]

reshaped:
[[1 2 3]
 [4 5 6]]


### Example nr. 1 - Matrix multiplication

#### NumPy

In [4]:
np.random.seed(0)

# Generate two random matrices
size = 10000
A = np.random.rand(size, size).astype(np.float32)
B = np.random.rand(size, size).astype(np.float32)
print (np.dot(A, B))

[[2496.2744 2466.5544 2470.9575 ... 2494.8926 2525.964  2467.1694]
 [2461.0415 2469.6577 2484.0947 ... 2477.7363 2494.6555 2485.727 ]
 [2487.7373 2482.263  2493.5962 ... 2511.0745 2535.1755 2494.8848]
 ...
 [2518.0334 2511.9395 2514.2825 ... 2530.6328 2551.318  2508.673 ]
 [2512.0613 2488.3813 2483.8801 ... 2516.4438 2516.2053 2489.3577]
 [2460.9507 2452.2944 2478.3572 ... 2476.3423 2494.7036 2460.0671]]


In [5]:
%timeit np.dot(A, B)

3.53 s ± 168 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


#### JAX using CPU

In [4]:
#jax.config.update("jax_enable_x64", True)  # Optional: if using float64

JAX_A = jax.device_put(jnp.array(A),device=devices('cpu')[0]) 
JAX_B = jax.device_put(jnp.array(B),device=devices('cpu')[0]) 

print (jnp.dot(JAX_A, JAX_B).block_until_ready())

[[2496.2742 2466.5542 2470.9575 ... 2494.8923 2525.9636 2467.1692]
 [2461.0415 2469.6575 2484.095  ... 2477.737  2494.6555 2485.727 ]
 [2487.7368 2482.2634 2493.5952 ... 2511.0747 2535.175  2494.885 ]
 ...
 [2518.033  2511.9395 2514.2822 ... 2530.6328 2551.3179 2508.6733]
 [2512.06   2488.3813 2483.8801 ... 2516.4438 2516.206  2489.358 ]
 [2460.9504 2452.2944 2478.3574 ... 2476.3418 2494.7039 2460.0674]]


In [7]:
%timeit jnp.dot(JAX_A, JAX_B).block_until_ready()

4.22 s ± 52.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


#### JAX using GPU

In [None]:
JAX_A = jax.device_put(jnp.array(A),device=devices('gpu')[0]) 
JAX_B = jax.device_put(jnp.array(B),device=devices('gpu')[0]) 
print (jnp.dot(JAX_A, JAX_B).block_until_ready())

In [6]:
%timeit jnp.dot(JAX_A, JAX_B).block_until_ready()

484 ms ± 2.97 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


As we can see matrix multiplication performs better on GPU and it is achiavable only with JAX

### Example nr. 2 - Computation on specific functions

#### NumPy

In [9]:
x = np.ones((10000, 10000))
y = np.arange(10000)

z = np.sin(x) + np.cos(y)

In [2]:
%timeit np.sin(x) + np.cos(y)

819 ms ± 36.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


#### JAX using CPU

In [3]:
x=jax.device_put(jnp.ones((10000, 10000)),device=devices('cpu')[0]) 
y=jax.device_put(jnp.arange(10000),device=devices('cpu')[0]) 

z = jnp.sin(x) + jnp.cos(y) 

In [5]:
%timeit jnp.sin(x) + jnp.cos(y) 

180 ms ± 7.61 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


#### JAX using GPU

In [6]:
x=jax.device_put(jnp.ones((10000, 10000)),device=devices('gpu')[0]) 
y=jax.device_put(jnp.arange(10000),device=devices('gpu')[0]) 

z = jnp.sin(x) + jnp.cos(y) 

In [8]:
%timeit jnp.sin(x) + jnp.cos(y) 

10.7 ms ± 19.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In this case also on CPU the operation executed with Jax performs better.

## Differences with NumPy

In JAX, data structures are **immutable** by default, meaning they cannot be modified once created. Instead, when an operation that would typically modify an array is performed, JAX returns a new array with the modifications, leaving the original array unchanged.

In [4]:
x = jnp.array([1, 2, 3, 4, 5])

#Let's try to modify an element inside the array
try:
    x[0] = 10
except Exception as e:
    print(f"Errore: {e}")

Errore: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html


In [5]:
#The correct way to do

#Creating a new array with the desired modification
y = x.at[0].set(10)

print("Original Array (x):", x)
print("Modified Array (y):", y)

Original Array (x): [1 2 3 4 5]
Modified Array (y): [10  2  3  4  5]


## Just in time compilation (JIT)

Let's take as example the computation of the softmax function. Here is reported a variant that garantees numerical stability.

### JAX jit

In [4]:
from jax import jit

# Define the softmax function
def softmax(logits):
    exp_logits = jnp.exp(logits - jnp.max(logits, axis=-1, keepdims=True))
    return exp_logits / jnp.sum(exp_logits, axis=-1, keepdims=True)

In [5]:
# Generate random logits for testing
key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (1000, 10))  # 1000 samples, 10 classes

In [6]:
# Measure time without JIT
%timeit softmax(logits)

563 µs ± 46.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [7]:
jit_softmax = jit(softmax)

In [8]:
# Measure time with JIT
%timeit jit_softmax(logits)

111 µs ± 4.82 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


### Tensorflow jit

In [3]:
# Define the softmax function
@tf.function
def softmax(logits):
    # Subtract the max logit for numerical stability
    exp_logits = tf.exp(logits - tf.reduce_max(logits, axis=-1, keepdims=True))
    # Normalize to get probabilities
    return exp_logits / tf.reduce_sum(exp_logits, axis=-1, keepdims=True)

In [4]:
# Generate random logits for testing
logits = tf.random.normal((100000, 10))  # 1000 samples, 10 classes

In [5]:
%timeit softmax(logits)

697 µs ± 19.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [6]:
# Measure time with JIT
@tf.function(jit_compile=True) 
def softmax_with_jit(logits):
    return softmax(logits)

In [7]:
%timeit softmax_with_jit(logits)

443 µs ± 63 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


I0000 00:00:1717694244.790568    2604 service.cc:145] XLA service 0x5593cf978a30 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1717694244.790679    2604 service.cc:153]   StreamExecutor device (0): NVIDIA GeForce GTX 1060 6GB, Compute Capability 6.1
I0000 00:00:1717694244.984311    2604 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


### PyTorch jit

In [16]:
import torch

# Define the softmax function
def softmax(logits):
    exp_logits = torch.exp(logits - torch.max(logits))
    return exp_logits / torch.sum(exp_logits)

In [17]:
# Generate random logits
logits = torch.randn(1000, 10)

In [18]:
# Measure time without JIT
%timeit softmax(logits)

26.4 µs ± 1.2 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [19]:
@torch.jit.script
def softmax_jit(logits):
    exp_logits = torch.exp(logits - torch.max(logits))
    return exp_logits / torch.sum(exp_logits)

In [20]:
%timeit softmax_jit(logits)

26.4 µs ± 1.4 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


## Automatic Differentiation

Let's consider a simple polinomial function: x^2+2x+1

In [4]:
def f(x):
    return x**2 + 2*x + 1

### Jax Autograd

In [5]:
import jax

x = 3.0
grad_fn = jax.grad(f)
grad = grad_fn(x)

print("Gradient of f(x) in Jax:",grad)

Gradient of f(x) in Jax: 8.0


### Tensorflow GradientTape

In [6]:
x = tf.Variable(3.0)

with tf.GradientTape() as tape:
    y = f(x)
grad = tape.gradient(y, x)

print("Gradient of f(x) in TensorFlow:",grad.numpy())

Gradient of f(x) in TensorFlow: 8.0


### PyTorch Backward

In [7]:
x = torch.tensor(3.0, requires_grad=True)
y = f(x)
y.backward()

print("Gradient of f(x) in PyTorch:",x.grad.item())

Gradient of f(x) in PyTorch: 8.0


In [8]:
# Measure time for JAX - without using jit
def jax_time():
    x = 3.0
    grad_fn = jax.grad(f)
    start_time = time.time()
    with jax.default_device(jax.devices('cpu')[0]):
    #with jax.default_device(jax.devices('gpu')[0]):
        grad = grad_fn(x)
    end_time = time.time()
    return end_time - start_time

time_jax = jax_time()

# Measure time for TensorFlow
def tensorflow_time():
    x = tf.Variable(3.0)
    start_time = time.time()
    with tf.device('/CPU:0'):
    #with tf.device('/GPU:0'):
        with tf.GradientTape() as tape:
            y = f(x)
    grad = tape.gradient(y, x)
    end_time = time.time()
    return end_time - start_time
    
time_tf=tensorflow_time()

# Measure time for PyTorch
def pytorch_time():
    device = torch.device('cpu')
    #device = torch.device('cuda')
    x = torch.tensor(3.0, requires_grad=True,device=device)
    start_time = time.time()
    y = f(x)
    y.backward()
    end_time = time.time()
    return end_time - start_time

time_pt=pytorch_time()

print("Time elapsed in JAX:", time_jax)
print("Time elapsed in TensorFlow:", time_tf)
print("Time elapsed in PyTorch:", time_pt)

Time elapsed in JAX: 0.13792634010314941
Time elapsed in TensorFlow: 0.00521397590637207
Time elapsed in PyTorch: 0.00021648406982421875


## Automatic Vectorization VMap

It allows you to apply a function to each element in a batch of inputs without manually writing loops. This leads to more concise and often more efficient code by leveraging JAX's underlying compilation and optimization.

Basic usage considering simple sum of arrays 2D.

In [4]:
import jax.numpy as jnp

#Definition of samples arrays
a=jnp.array(([1,3],[4,-1]))
b=jnp.array(([11,7],[-2,5]))
a,b

(Array([[ 1,  3],
        [ 4, -1]], dtype=int32),
 Array([[11,  7],
        [-2,  5]], dtype=int32))

In [5]:
#Classical add operation
jnp.add(a,b)

Array([[12, 10],
       [ 2,  4]], dtype=int32)

In [6]:
#Add operation using vmap row by row and the output is given by row
jax.vmap(jnp.add,in_axes=(0,0),out_axes=0)(a,b)

## [ [ (1+11) , (3+7) ]
##   [ (4-2), (-1+5) ]
## ]

Array([[12, 10],
       [ 2,  4]], dtype=int32)

In [7]:
#Add operation using vmap row by row and the output is given by column
jax.vmap(jnp.add,in_axes=(0,0),out_axes=1)(a,b)

## [ [ (1+11) , (4-2) ] 
##   [ (3+7), (-1+5) ]
## ]

Array([[12,  2],
       [10,  4]], dtype=int32)

In [8]:
jax.vmap(jnp.add,in_axes=(1,0),out_axes=0)(a,b)

a=jnp.array(([1,3],[4,-1]))
b=jnp.array(([11,7],[-2,5]))

## [ [ (1+11) , (4+7) ] #Col of a + corr. row of b
##   [ (3+7), (-1+5) ]
## ]

### A simple example

In this example, it calculates the square of each element in the array passed as argument.

In [9]:
from jax import vmap

def square(x):
  return jnp.square(x)
    
def square_numpy(arr):
    return np.square(arr)

In [10]:
numPy_numbers = np.arange(1000)
jax_numbers=jnp.arange(1000)

In [11]:
# Apply the square function using a standard loop: classical way.
start_time = time.time()
squared_numbers_loop = [square(x) for x in jax_numbers]
loop_time = time.time() - start_time
print(f"Time taken using standard loop: {loop_time:.5f} seconds")

#Apply the square function using numpy function.
start_time = time.time()
result = square_numpy(numPy_numbers)
end_time=time.time()
numpy_end_time=end_time-start_time
print(f"Time taken using numpy: {numpy_end_time:.5f} seconds")


# Apply the vectorized function
start_time = time.time()
# Vectorize the square function using vmap
vectorized_square = jax.vmap(square)
vmap_time = time.time() - start_time
print(f"Time taken using vmap: {vmap_time:.5f} seconds")

Time taken using standard loop: 0.37405 seconds
Time taken using numpy: 0.00028 seconds
Time taken using vmap: 0.00008 seconds


### A other example : matrix-vector multiplication

In [4]:
# Define matrix-vector multiplication function
@jax.jit
def matrix_vector_multiplication(matrix, vector):
    return jnp.dot(matrix, vector)

def matrix_vector_multiplication_numpy(matrix, vector):
    return np.dot(matrix, vector)

In [5]:
# Define the matrix size
matrix_size = 10000  # For example, change this to your desired size

# Create an incremental matrix  and vector for jax
matrix_jax = jnp.arange(matrix_size * matrix_size).reshape((matrix_size, matrix_size))
vector_jax = jnp.arange(matrix_size).reshape((matrix_size, 1))

matrix_np=np.arange(matrix_size * matrix_size).reshape((matrix_size, matrix_size))
vector_np=np.arange(matrix_size).reshape((matrix_size,1))

print("Incremental Matrix:")
print(matrix_jax)

print("\nIncremental Vector:")
print(vector_jax)

In [6]:
# Use vmap to perform batched matrix-vector multiplication
matvec_vmap = jax.vmap(matrix_vector_multiplication, in_axes=(0, None))

start_time=time.time()
result = matvec_vmap(matrix_jax, vector_jax)
end_time=time.time()

tot_time=end_time-start_time
print(f"Time taken using vmap: {tot_time}")

Time taken using vmap: 0.045346975326538086


In [8]:
# Apply the function directly
start_time=time.time()
result = matrix_vector_multiplication_numpy(matrix_np, vector_np)
end_time=time.time()
tot_time=end_time-start_time

print(f"Time taken using numpy: {tot_time}")

Time taken using numpy: 0.07039618492126465


## Automatic Parallelization PMap

Pmap is another transformation that enables to replicate the computation into multiple cores or devices and execute them in parallel(p in pmap stands for parallel).

To have a real benefit from this functionality you need multiple devices. 
Unless you have multiple GPUs you should restart the notebook an run the following snippets - to use your CPU multiple cores.
Suggestion: don't put num_core>your_max_hardware_num_core

In [1]:
import os
os.environ['JAX_PLATFORM_NAME'] = 'cpu'
num_core= os.cpu_count()
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count='+str(num_core)
import jax
jax.devices()

[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7),
 CpuDevice(id=8),
 CpuDevice(id=9),
 CpuDevice(id=10),
 CpuDevice(id=11)]

In [2]:
from jax import pmap
import numpy as np
import jax.numpy as jnp
import time

### A simple operation

In [4]:
x = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2))
y = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2

In [4]:
start_time=time.time()
out = pmap(jnp.dot)(x, y)  
end_time=time.time()

tot_time=end_time-start_time

print(f"Time taken using pmap: {tot_time}")

Time taken using pmap: 0.030210494995117188


In [5]:
start_time=time.time()
out = jnp.dot(x, y)  
end_time=time.time()

tot_time=end_time-start_time

print(f"Time taken without using pmap: {tot_time}")

Time taken without using pmap: 0.26958250999450684


### Matrix-matrix multiplication

In [3]:
#Define two matrices in this way

#Matrix used: change dimension to see difference in time
a = jnp.arange(9, dtype=jnp.float32).reshape(3, 3)
b = jnp.arange(9, dtype=jnp.float32).reshape(3, 3)

#a= jnp.arange(10000 * 10000, dtype=jnp.float32).reshape(10000, 10000)
#b= jnp.arange(10000 * 10000, dtype=jnp.float32).reshape(10000, 10000)
a,b

(Array([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]], dtype=float32),
 Array([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]], dtype=float32))

In [4]:
# Define a function for matrix multiplication
def matmul(a, b):
    return jnp.dot(a, b)

@jax.jit
def parallel_matmul(a, b):
    return matmul(a, b)

In [5]:
#matmul(a,b)

Array([[7.8120318e+16, 7.8120318e+16, 7.8120318e+16, ..., 7.8125017e+16,
        7.8125017e+16, 7.8125017e+16],
       [1.9530312e+17, 1.9530312e+17, 1.9530312e+17, ..., 1.9531719e+17,
        1.9531719e+17, 1.9531719e+17],
       [3.1248591e+17, 3.1248595e+17, 3.1248595e+17, ..., 3.1250938e+17,
        3.1250938e+17, 3.1250938e+17],
       ...,
       [1.7574687e+21, 1.7574687e+21, 1.7574687e+21, ..., 1.7576093e+21,
        1.7576093e+21, 1.7576093e+21],
       [1.7575859e+21, 1.7575859e+21, 1.7575859e+21, ..., 1.7577266e+21,
        1.7577266e+21, 1.7577266e+21],
       [1.7577031e+21, 1.7577031e+21, 1.7577031e+21, ..., 1.7578436e+21,
        1.7578436e+21, 1.7578436e+21]], dtype=float32)

In [5]:
# Prepare the matrices for pmap by adding a leading axis. 
# This op. is needed because pmap maps the function over leading axis on the first axes, 
# and you have to mantain an extra axis to make it work
a_parallel = jnp.array([a])  # Shape (1, 3, 3)
b_parallel = jnp.array([b])  # Shape (1, 3, 3)

# Apply pmap to the parallel matrix multiplication function
matmul_parallel = pmap(parallel_matmul)

# Warm-up to ensure JAX is initialized
dummy_result = matmul_parallel(a_parallel, b_parallel)

# Measure time for parallel computation using pmap
start_time = time.time()
result_parallel = matmul_parallel(a_parallel, b_parallel)

# Remove the extra axis
result_parallel = result_parallel[0]  

end_time = time.time()
tot_time_parallel = end_time - start_time

print(f"Time taken using pmap: {tot_time_parallel} seconds")

Time taken using pmap: 0.000310 seconds


In [6]:
# Measure time for non-parallel computation
start_time = time.time()
result_no_pmap = matmul(a, b)
end_time = time.time()
tot_time_no_pmap = end_time - start_time

print(f"Time taken without using pmap: {tot_time_no_pmap} seconds")

Time taken without using pmap: 0.014169692993164062 seconds


In [8]:
#Check results
print("Matrix a:\n", a)
print("Matrix b:\n", b)
print("Result of matrix mult. using pmap:\n", result_parallel)
print("Result of matrix mult. without using pmap:\n", result_no_pmap)

Matrix a:
 [[0. 1. 2.]
 [3. 4. 5.]
 [6. 7. 8.]]
Matrix b:
 [[0. 1. 2.]
 [3. 4. 5.]
 [6. 7. 8.]]
Result of matrix mult. using pmap:
 [[ 15.  18.  21.]
 [ 42.  54.  66.]
 [ 69.  90. 111.]]
Result of matrix mult. without using pmap:
 [[ 15.  18.  21.]
 [ 42.  54.  66.]
 [ 69.  90. 111.]]
