In [1]:
from timeit import timeit

In [2]:
import jax
import jax.numpy as jnp

In [3]:
import einops

### Devices

In [4]:
devices = jax.devices()
devices

[cuda(id=0), cuda(id=1), cuda(id=2), cuda(id=3)]

### Array

In [5]:
x = jnp.zeros((1, 1920, 1920))

2024-12-19 21:22:44.149787: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.6.85). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [6]:
x.devices()

{cuda(id=0)}

In [7]:
x_device1 = jax.device_put(x, devices[1])

In [8]:
x_device1.devices()

{cuda(id=1)}

### Random

In [9]:
seed = 0
key = jax.random.PRNGKey(seed)

In [10]:
xrand = jax.random.normal(key, (1, 1920, 1920))

In [11]:
xrand.devices()

{cuda(id=0)}

### JIT

In [12]:
def squared_sum(x):
    return jnp.sum(x**2)

@jax.jit
def squared_sum_jit(x):
    return jnp.sum(x**2)

In [13]:
timeit('squared_sum(xrand)', globals=globals(), number=1)

0.06594208301976323

In [14]:
timeit('squared_sum_jit(xrand)', globals=globals(), number=1)

0.04803331685252488

### Grad

In [15]:
xrand = jax.random.normal(key, (10, 1, 1920, 1920))

In [16]:
squared_sum_grad_jit = jax.grad(squared_sum_jit)

In [17]:
xrand_grad = squared_sum_grad_jit(xrand)
xrand_grad.shape

(10, 1, 1920, 1920)

In [18]:
batch_squared_sum_grad_jit = jax.vmap(squared_sum_grad_jit, in_axes=0)

In [19]:
xrand_grad_vmap = batch_squared_sum_grad_jit(xrand)

In [20]:
xrand_grad_vmap.devices()

{cuda(id=0)}

In [21]:
xrand_grad_vmap.shape

(10, 1, 1920, 1920)

In [22]:
(xrand_grad_vmap == xrand_grad).all()

Array(True, dtype=bool)

### vmap

In [58]:
# Basic vectorization along first axis (default)
f = lambda x: x + 1
vectorized_f = jax.vmap(f)  # equivalent to in_axes=0
vectorized_f(jnp.array([1, 2, 3]))  # -> [2, 3, 4]

# Add arguments with different axes
g = lambda x, y: x + y
vectorized_g = jax.vmap(g, in_axes=(0, 1))  # vectorize x along 0, y along 1

x1 = jax.random.normal(key, (1,5))
x2 = jax.random.normal(key, (5,1))

print(x1.shape)
print(x2.shape)

r = vectorized_g(x1, x2)
print(r)

(1, 5)
(5, 1)
[[ 0.3756877  -2.5666852  -0.5421834   2.4981186   0.48894006]]


In [61]:
r2 = x1+x2
r2.shape

(5, 5)

In [64]:
x1 = jax.random.normal(key, (1,5))
x2 = jax.random.normal(key, (5,))
h = lambda x, y: x + y
vectorized_h = jax.vmap(h, in_axes=(0, None))
vectorized_h(x1,x2).shape

(1, 5)

In [66]:
(x1+x2).shape

(1, 5)

In [92]:
xrand.shape

(10, 1, 1920, 1920)

In [93]:
einops.rearrange(xrand, 'b c h w -> (b c) h w').shape

(10, 1920, 1920)