In [None]:
import os
# Set both variables before JAX is imported so they are in place when the GPU backend initializes.
# Expose only the GPU with device index 5 to this process (CUDA indices are zero-based, so this is not the fifth device).
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
# 'platform' makes JAX allocate exactly what is needed on demand and deallocate memory that is no longer needed.
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'

`XLA_PYTHON_CLIENT_ALLOCATOR=platform`
This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed (note that this is the only configuration that will deallocate GPU memory, instead of reusing it). This is very slow, so is not recommended for general use, but may be useful for running with the minimal possible GPU memory footprint or debugging OOM failures.

More on JAX GPU memory allocation and preallocation: https://kolonist26-jax-kr.readthedocs.io/en/latest/gpu_memory_allocation.html

In [2]:
import numpy as np

# Seed the generator so that a fresh "Restart & Run All" reproduces the same matrix.
rng = np.random.default_rng(0)
# 2000 x 2000 NumPy array of float64 samples drawn uniformly from [0, 1).
x = rng.random((2000, 2000))
print(x)

[[0.99012811 0.24143892 0.4357765  ... 0.33194329 0.85512127 0.5159837 ]
 [0.31044338 0.61214593 0.51972409 ... 0.09486678 0.13285493 0.79452527]
 [0.22408187 0.72155976 0.54728003 ... 0.24928772 0.20035903 0.46038213]
 ...
 [0.74549526 0.98701231 0.28679011 ... 0.0438974  0.22062398 0.95927464]
 [0.76143866 0.13367578 0.38313419 ... 0.34832691 0.67910519 0.043902  ]
 [0.29180837 0.15810088 0.15138866 ... 0.26024079 0.47435797 0.07594437]]


In [3]:
import jax.numpy as jnp

# Copy the NumPy matrix into a JAX array (the cell outputs report dtype float32 for it).
y = jnp.array(x)
# Show the values, a separator line, and the concrete array class, one per line.
print(y, '*' * 30, type(y), sep='\n')

2026-01-03 06:04:54.701339: W external/xla/xla/service/gpu/nvptx_compiler.cc:930] The NVIDIA driver's CUDA version is 12.8 which is older than the PTX compiler version 12.9.86. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


[[0.9901281  0.24143893 0.4357765  ... 0.33194327 0.85512125 0.5159837 ]
 [0.31044337 0.6121459  0.5197241  ... 0.09486678 0.13285494 0.79452527]
 [0.22408187 0.72155976 0.54728    ... 0.24928772 0.20035903 0.46038213]
 ...
 [0.74549526 0.9870123  0.2867901  ... 0.0438974  0.22062398 0.95927465]
 [0.76143867 0.13367578 0.3831342  ... 0.34832692 0.67910516 0.043902  ]
 [0.29180837 0.15810087 0.15138866 ... 0.2602408  0.47435796 0.07594437]]
******************************
<class 'jaxlib.xla_extension.ArrayImpl'>


In [4]:
# Elementwise scaling by a Python scalar.
y * 2

Array([[1.9802562 , 0.48287785, 0.871553  , ..., 0.66388655, 1.7102425 ,
        1.0319674 ],
       [0.62088674, 1.2242918 , 1.0394481 , ..., 0.18973356, 0.26570988,
        1.5890505 ],
       [0.44816375, 1.4431195 , 1.09456   , ..., 0.49857545, 0.40071806,
        0.92076427],
       ...,
       [1.4909905 , 1.9740247 , 0.5735802 , ..., 0.0877948 , 0.44124797,
        1.9185493 ],
       [1.5228773 , 0.26735157, 0.7662684 , ..., 0.69665384, 1.3582103 ,
        0.087804  ],
       [0.58361673, 0.31620175, 0.30277732, ..., 0.5204816 , 0.9487159 ,
        0.15188874]], dtype=float32)

In [6]:
# Indexing a single element yields a zero-dimensional Array rather than a plain Python float.
y[0, 0]

Array(0.9901281, dtype=float32)

In [7]:
# Subtract the mean taken over axis 0 (one mean per column), broadcast across all rows.
y - jnp.mean(y, axis=0)

Array([[ 0.48538   , -0.26910245, -0.05847731, ..., -0.16983068,
         0.3532893 ,  0.01842991],
       [-0.19430473,  0.10160452,  0.02547026, ..., -0.40690717,
        -0.368977  ,  0.29697147],
       [-0.28066623,  0.21101838,  0.0530262 , ..., -0.25248623,
        -0.3014729 , -0.03717166],
       ...,
       [ 0.24074715,  0.47647095, -0.20746371, ..., -0.45787656,
        -0.28120798,  0.46172085],
       [ 0.25669056, -0.3768656 , -0.11111963, ..., -0.15344703,
         0.17727321, -0.4536518 ],
       [-0.21293974, -0.3524405 , -0.34286517, ..., -0.24153316,
        -0.02747399, -0.42160943]], dtype=float32)

In [8]:
# Matrix product of y with itself.
jnp.dot(y, y)

Array([[513.2312 , 518.92615, 505.31046, ..., 512.65497, 513.17426,
        508.769  ],
       [515.13104, 517.30444, 498.23804, ..., 509.46625, 506.96686,
        501.28946],
       [496.98447, 504.31183, 483.8814 , ..., 497.3886 , 498.64545,
        492.8675 ],
       ...,
       [500.3337 , 505.61398, 488.80203, ..., 490.40857, 502.7337 ,
        498.05475],
       [501.70673, 509.26157, 493.6596 , ..., 499.99893, 498.32538,
        491.46646],
       [499.55164, 503.30927, 492.6638 , ..., 509.51663, 504.45865,
        494.64453]], dtype=float32)

In [9]:
%timeit np.dot(x, x)
%timeit jnp.dot(y, y)


24.8 ms ± 250 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
56.8 μs ± 13.4 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


## JIT compilation

In [13]:
def f(x):
    """Shrink ``x`` by 10% ten times in a row and return the result.

    The update is written as ``x = x - 0.1 * x`` rather than ``x -= 0.1 * x``:
    the augmented form modifies NumPy arrays in place, so calling ``f`` on a
    NumPy array (for example inside ``%timeit``) would overwrite the caller's
    data on every call. JAX arrays are immutable, so they give the same
    result with either spelling.

    Args:
        x: A scalar, NumPy array, or JAX array.

    Returns:
        A new value equal to ``x`` after ten successive 10% reductions; the
        argument itself is left untouched.
    """
    for _ in range(10):
        x = x - 0.1 * x
    return x

# Plain (non-jitted) call of f on the JAX array.
f(y)

Array([[0.34523636, 0.08418455, 0.15194586, ..., 0.11574148, 0.29816237,
        0.17991239],
       [0.1082449 , 0.21344209, 0.18121658, ..., 0.033078  , 0.04632366,
        0.27703387],
       [0.07813253, 0.25159234, 0.19082475, ..., 0.08692125, 0.06986087,
        0.16052532],
       ...,
       [0.25993812, 0.34414992, 0.09999753, ..., 0.01530608, 0.07692683,
        0.33447838],
       [0.26549724, 0.04660987, 0.13359062, ..., 0.12145408, 0.23678933,
        0.01530768],
       [0.1017473 , 0.05512637, 0.05278597, ..., 0.09074036, 0.16539839,
        0.02648016]], dtype=float32)

In [16]:
%timeit f(x)
%timeit f(y)

28.1 ms ± 285 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
1.27 ms ± 5.29 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [17]:
from jax import jit # just-in-time compilation
# Wrap f so that JAX traces it and compiles the whole function with XLA.
g = jit(f)
# First call of g on y; compilation happens lazily, so this call includes the compile step.
g(y)

Array([[0.34523636, 0.08418455, 0.15194586, ..., 0.11574148, 0.29816237,
        0.17991239],
       [0.1082449 , 0.21344209, 0.18121658, ..., 0.033078  , 0.04632366,
        0.27703387],
       [0.07813253, 0.25159234, 0.19082475, ..., 0.08692125, 0.06986087,
        0.16052532],
       ...,
       [0.25993812, 0.34414992, 0.09999753, ..., 0.01530608, 0.07692683,
        0.33447838],
       [0.26549724, 0.04660987, 0.13359062, ..., 0.12145408, 0.23678933,
        0.01530768],
       [0.1017473 , 0.05512637, 0.05278597, ..., 0.09074036, 0.16539839,
        0.02648016]], dtype=float32)

In [18]:
%timeit g(y)

48.6 μs ± 295 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


## Automatic differentiation

In [22]:
def f(x):
    """Evaluate x * sin(x) with jax.numpy.

    NOTE: this rebinds the name ``f`` that the JIT section used for a different function.
    """
    sine = jnp.sin(x)
    return x * sine

f(4)

Array(-3.02721, dtype=float32, weak_type=True)

In [None]:
def grad_f(x):
    """Hand-derived derivative of x * sin(x): sin(x) + x * cos(x) (product rule)."""
    sin_term = jnp.sin(x)
    cos_term = x * jnp.cos(x)
    return sin_term + cos_term

grad_f(4)

Array(-3.3713772, dtype=float32, weak_type=True)

In [29]:
from jax import grad

# grad(f) builds a new function that evaluates the derivative of f with respect to its argument.
grad_f_jax = grad(f)
# Pass a float here: jax.grad cannot differentiate with respect to an integer input such as 4.
grad_f_jax(4.0)

Array(-3.3713772, dtype=float32, weak_type=True)

## Vectorization

In [31]:
def square(x):
    """Return the sum of the squared entries of ``x`` (one scalar, not an elementwise square)."""
    squared = x ** 2
    return jnp.sum(squared)

square(jnp.arange(10))

Array(285, dtype=int32)

In [32]:
# NOTE: this rebinds x, which earlier held the random NumPy matrix.
x = jnp.reshape(jnp.arange(100), (10, 10))
# Baseline: a Python-level loop that calls square once per row.
list(map(square, x))

[Array(285, dtype=int32),
 Array(2185, dtype=int32),
 Array(6085, dtype=int32),
 Array(11985, dtype=int32),
 Array(19885, dtype=int32),
 Array(29785, dtype=int32),
 Array(41685, dtype=int32),
 Array(55585, dtype=int32),
 Array(71485, dtype=int32),
 Array(89385, dtype=int32)]

In [33]:
from jax import vmap

# Map square over the leading axis of x in one batched call instead of a Python loop.
vmap(square, in_axes=0)(x)

Array([  285,  2185,  6085, 11985, 19885, 29785, 41685, 55585, 71485,
       89385], dtype=int32)

In [2]:
# Annotated integer variable, echoed to stdout.
i: int = 5
print(i)

5
