To check the CUDA toolkit version, run `nvcc --version` in a terminal; to check the JAX version, run `python -c "import jax; print(jax.__version__)"`.

## JAX Example: Multiplying Matrices

In [1]:
import numpy as np
import jax.numpy as jnp
from jax import device_put
from jax import grad, jit, vmap
from jax import random

In [2]:
# Draw ten standard-normal samples. JAX's PRNG is explicit and stateless, so a
# fixed key makes the printed values reproducible across kernel restarts.
key = random.PRNGKey(0)
x = random.normal(key=key, shape=(10,))
print(x)



[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [3]:
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

271 ms ± 9.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

252 ms ± 2.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

232 ms ± 5.33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
