In [2]:
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jax[cuda12_pip]
  Using cached jax-0.4.26-py3-none-any.whl (1.9 MB)
Collecting ml-dtypes>=0.2.0
  Downloading ml_dtypes-0.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m79.3 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cudnn-cu12<9.0,>=8.9.2.26
  Downloading nvidia_cudnn_cu12-8.9.7.29-py3-none-manylinux1_x86_64.whl (704.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m704.7/704.7 MB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting nvidia-nvjitlink-cu12>=12.1.105
  Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.1/21.1 MB[0m [31m66.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting nvidia-cublas-cu12>=12.1.3.1
  Do

In [3]:
!nvidia-smi

Wed May  1 06:30:17 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A4000    On   | 00000000:00:05.0 Off |                  Off |
| 41%   35C    P8    17W / 140W |      1MiB / 16376MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [4]:
import jax
import jax.numpy as jnp

import numpy as np

In [2]:
L = [0, 1, 2, 3]
x_np = np.array(L, dtype=np.int32)
x_jnp = jnp.array(L, dtype=jnp.int32)

x_np, x_jnp

2024-04-30 10:55:07.121811: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


(array([0, 1, 2, 3], dtype=int32), Array([0, 1, 2, 3], dtype=int32))

In [3]:
x1 = x_jnp*2
x2 = x_jnp+1
x3 = x1 + x2

x1, x2, x3

(Array([0, 2, 4, 6], dtype=int32),
 Array([1, 2, 3, 4], dtype=int32),
 Array([ 1,  4,  7, 10], dtype=int32))

In [4]:
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
size = 5000

x = random.normal(key, (size, size)).astype(jnp.float32)
%time x_jax = jax.device_put(x)
%time jnp.dot(x_jax, x_jax.T).block_until_ready()
%timeit jnp.dot(x_jax, x_jax.T).block_until_ready()

CPU times: user 41 µs, sys: 28 µs, total: 69 µs
Wall time: 74.6 µs
CPU times: user 139 ms, sys: 42.2 ms, total: 181 ms
Wall time: 216 ms
9.3 ms ± 43.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [5]:
from jax import grad, jit

key = random.PRNGKey(0)

def selu_np(x, alpha=1.67, lmbda=1.05):
  return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)

def selu_jax(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))

selu_jax_jit = jit(selu_jax)
%time x_jax = jax.device_put(x) 
%time selu_jax_jit(x_jax).block_until_ready() 
%timeit selu_jax_jit(x_jax).block_until_ready()

CPU times: user 45 µs, sys: 27 µs, total: 72 µs
Wall time: 78 µs
CPU times: user 68.1 ms, sys: 91 µs, total: 68.2 ms
Wall time: 83.9 ms
88.1 µs ± 3.71 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [7]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np

def fn(x):
    return x + x*x + x*x*x + x*x*x*x

x_np = np.random.randn(5000,5000).astype(dtype='float32')
x_jnp = jnp.array(x_np)  #numpy- >jax DeviceArray

%timeit fn(x_np)
%timeit jit(fn)(x_jnp).block_until_ready()

#Microsecond is one millionth of a second. Millisecond is one thousandth of a second.


CUDA backend failed to initialize: Unable to use CUDA because of the following issues with CUDA components:
Outdated cuDNN installation found.
Version JAX was built against: 8906
Minimum supported: 8900
Installed version: 8302
The local installation version must be no lower than 8900. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


240 ms ± 1.85 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
29.5 ms ± 475 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [5]:
import torch

In [12]:
import time
import jax.numpy as jnp
from jax import jit
import torch

# Define JAX matrix multiplication function
def jax_matmul(A, B):
    return jnp.dot(A, B)

# Add JIT compilation for performance
jax_matmul_jit = jit(jax_matmul)

# Define PyTorch matrix multiplication function
def torch_matmul(A, B):
    return torch.matmul(A, B)

# Generate large matrices
matrix_size = 1000
#A_jax = jnp.random.randn(matrix_size, matrix_size)
A_np = np.random.randn(matrix_size,matrix_size).astype(dtype='float32')
A_jax = jnp.array(x_np)

#B_jax = jnp.random.randn(matrix_size, matrix_size)
B_np = np.random.randn(matrix_size,matrix_size).astype(dtype='float32')
B_jax = jnp.array(x_np)

A_torch = torch.randn(matrix_size, matrix_size)
B_torch = torch.randn(matrix_size, matrix_size)

# Measure execution time for JAX
start_time = time.time()
result_jax = jax_matmul_jit(A_jax, B_jax)
jax_execution_time = time.time() - start_time

# Measure execution time for PyTorch
start_time = time.time()
result_torch = torch_matmul(A_torch, B_torch)
torch_execution_time = time.time() - start_time

# Measure execution time for numpy
start_time = time.time()
result_numpy = np.dot(A_np, B_np)
numpy_execution_time = time.time() - start_time

print("JAX execution time:", jax_execution_time, "seconds")
print("PyTorch execution time:", torch_execution_time, "seconds")
print("numpy execution time:", numpy_execution_time, "seconds")




JAX execution time: 0.7215843200683594 seconds
PyTorch execution time: 0.01124262809753418 seconds
numpy execution time: 0.015349864959716797 seconds


In [16]:
import time
import jax.numpy as jnp
from jax import jit, random
import torch


In [15]:

# Define JAX matrix multiplication function
def jax_matmul(A, B):
    return jnp.dot(A, B)

# Add JIT compilation for performance
jax_matmul_jit = jit(jax_matmul)

# Define PyTorch matrix multiplication function
def torch_matmul(A, B):
    return torch.matmul(A, B)

# Generate large matrices
matrix_size = 1000
key = random.PRNGKey(0)
A_jax = random.normal(key, (matrix_size, matrix_size))
B_jax = random.normal(key, (matrix_size, matrix_size))
A_torch = torch.randn(matrix_size, matrix_size)
B_torch = torch.randn(matrix_size, matrix_size)

# Warm-up runs
for _ in range(10):
    jax_matmul_jit(A_jax, B_jax)
    torch_matmul(A_torch, B_torch)

# Measure execution time for JAX
start_time = time.time()
result_jax = jax_matmul_jit(A_jax, B_jax).block_until_ready()
jax_execution_time = time.time() - start_time

# Measure execution time for PyTorch
start_time = time.time()
result_torch = torch_matmul(A_torch, B_torch)
torch_execution_time = time.time() - start_time

print("JAX execution time:", jax_execution_time, "seconds")
print("PyTorch execution time:", torch_execution_time, "seconds")


JAX execution time: 0.00592041015625 seconds
PyTorch execution time: 0.017140865325927734 seconds


In [19]:
import time
import jax.numpy as jnp
from jax import grad, jit

# Define the function to minimize
def f(x):
    return x**2 + 3*x + 5

# Define the gradient of the function
grad_f = grad(f)

# Initial guess for the minimum
x = 0.0

# Learning rate
learning_rate = 0.1

# Warm-up run
x -= learning_rate * grad_f(x)

# Perform gradient descent and measure time
start_time = time.time()
for i in range(100):
    x -= learning_rate * grad_f(x)
jax_execution_time = time.time() - start_time

print("Minimum (JAX):", x)
print("Execution time (JAX):", jax_execution_time, "seconds")



Minimum (JAX): -1.4999998
Execution time (JAX): 0.22281455993652344 seconds


In [20]:
import time
import torch

# Define the function to minimize
def f(x):
    return x**2 + 3*x + 5

# Convert the function to a PyTorch tensor
x = torch.tensor([0.0], requires_grad=True)

# Learning rate
learning_rate = 0.1

# Warm-up run
y = f(x)
y.backward()
with torch.no_grad():
    x -= learning_rate * x.grad
x.grad.zero_()

# Perform gradient descent and measure time
start_time = time.time()
for i in range(100):
    y = f(x)
    y.backward()
    with torch.no_grad():
        x -= learning_rate * x.grad
    x.grad.zero_()
torch_execution_time = time.time() - start_time

print("Minimum (PyTorch):", x.item())
print("Execution time (PyTorch):", torch_execution_time, "seconds")



Minimum (PyTorch): -1.499999761581421
Execution time (PyTorch): 0.007425785064697266 seconds


In [21]:
f = lambda x:x**2
f(3)
jax.grad(f)(3.0)

Array(6., dtype=float32, weak_type=True)

In [24]:
x = torch.tensor(
    3.0,
    requires_grad = True
)

y = x**2
y.backward()
x.grad

tensor(6.)

In [25]:
import jax.numpy as jnp
from jax import grad

# Define the function to differentiate
def f(x):
    return x**2 + 3*x + 5

# Define the derivative of the function using JAX's grad function
df_dx = grad(f)

# Test the derivative at a specific point
x_value = 2.0
derivative_value = df_dx(x_value)
print("Derivative (JAX) at x =", x_value, ":", derivative_value)


Derivative (JAX) at x = 2.0 : 7.0


In [26]:
import torch

# Define the function to differentiate
def f(x):
    return x**2 + 3*x + 5

# Convert the function to a PyTorch tensor
x = torch.tensor([2.0], requires_grad=True)

# Calculate the derivative using PyTorch's autograd mechanism
y = f(x)
y.backward()
derivative_value = x.grad.item()
print("Derivative (PyTorch) at x =", x.item(), ":", derivative_value)


Derivative (PyTorch) at x = 2.0 : 7.0
