In [2]:
import jax.numpy as jnp
from jax import grad,jit,vmap
from jax import random
import numpy as np
import jax

In [3]:
# List the devices JAX can see in this kernel and keep the list around as `devices`.
print(devices := jax.devices())

[cuda(id=0)]


In [4]:
# test 1
def fn(x):
    """Evaluate the polynomial x^4 + x^3 + x^2 + x using only multiplications.

    Works on anything that supports ``*`` and ``+`` (Python scalars, NumPy
    arrays, JAX arrays); for arrays the result is computed element-wise.
    """
    squared = x * x
    cubed = squared * x
    fourth = cubed * x
    return fourth + cubed + squared + x


# A single random value of shape (1, 1) in float32, kept both as a NumPy
# array (x_np) and as a JAX array built from it (x_jnp).
x_np = np.random.randn(1, 1).astype('float32')
x_jnp = jnp.array(x_np)

In [5]:
%timeit fn(x_np) #numpy on cpu
%timeit jit(fn)(x_np).block_until_ready() #jit on gpu

6.4 µs ± 67.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
694 µs ± 70.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
# test 2
def matmul(a, b):
    """Return the product of ``a`` and ``b`` as computed by ``jnp.dot``."""
    product = jnp.dot(a, b)
    return product

# Generate random matrices
key = random.PRNGKey(0)
a = random.normal(key, (5000, 5000), dtype=jnp.float32)
b = random.normal(key, (5000, 5000), dtype=jnp.float32)

# JIT compile the matrix multiplication function
matmul_jit = jit(matmul)

# Timing with JAX on GPU
print("Matrix Multiplication (JAX, JIT):")
%timeit matmul_jit(a, b).block_until_ready()

# Timing with NumPy on CPU
a_np = a
b_np = b
print("Matrix Multiplication (numpy):")
%timeit np.dot(a_np, b_np)


Matrix Multiplication (JAX, JIT):
145 ms ± 407 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Matrix Multiplication (numpy):
1.67 s ± 49.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
# test 3
def elementwise_fn(x):
    """Evaluate x^4 + x^3 + x^2 + x with the power operator.

    For array inputs the polynomial is applied to every element.
    """
    quartic = x ** 4
    cubic = x ** 3
    square = x ** 2
    return quartic + cubic + square + x

# Generate a large random array
x_np = np.random.randn(50000000).astype(dtype='float32')
x_jnp = jnp.array(x_np)

# JIT compile the element-wise function
elementwise_fn_jit = jit(elementwise_fn)

# Timing with JAX on GPU
print("Element-wise Function (JAX, JIT):")
%timeit elementwise_fn_jit(x_jnp).block_until_ready()

# Timing with NumPy on CPU
print("Element-wise Function (numpy):")
%timeit elementwise_fn(x_np)


Element-wise Function (JAX, JIT):
6.75 ms ± 46.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Element-wise Function (numpy):
1.78 s ± 123 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [17]:
# test 4
def sort_fn(x):
    """Return the result of ``jnp.sort`` applied to ``x``."""
    ordered = jnp.sort(x)
    return ordered

# Generate a large random array
x_np = np.random.randn(10000000).astype(dtype='float32')
x_jnp = jnp.array(x_np)

# Timing with JAX on CPU
print("Sorting (JAX, CPU):")
%timeit sort_fn(x_jnp).block_until_ready()

# Timing with NumPy on CPU
print("Sorting (numpy):")
%timeit np.sort(x_np)


Sorting (JAX, CPU):
211 ms ± 3.46 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Sorting (numpy):
1.03 s ± 8.11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
