In [1]:
import jax
import jax.numpy as jnp

from einops import rearrange, repeat, reduce, einsum

In [39]:
from flax import nnx

In [3]:
# Root PRNG seed/key; the random tensors in the cells below are drawn from it.
SEED = 0
key = jax.random.PRNGKey(SEED)

2024-12-20 15:45:55.829035: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.6.85). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


### Patchify

In [9]:
# Random test input with axes (B, C, H, W) = (batch, channels, height, width);
# a single channel (C = 1) of 1920 x 1920 pixels per image.
IMAGE_SHAPE = (10, 1, 1920, 1920)
B, C, H, W = IMAGE_SHAPE
x1 = jax.random.normal(key, IMAGE_SHAPE)
x1.shape

(10, 1, 1920, 1920)

In [10]:
# Cut each H x W image into non-overlapping patch_size x patch_size tiles.
# In the pattern, `h` / `w` are the number of patches along each spatial axis
# (H // patch_size, W // patch_size), not the pixel sizes H / W.
patch_size = 128
rearrange(
    x1,
    'b c (h p1) (w p2) -> b c h w p1 p2',
    p1=patch_size,
    p2=patch_size,
).shape

(10, 1, 15, 15, 128, 128)

### Multi-head attention

In [21]:
B, T, E = 10, 1024, 128  # batch size, sequence length, embedding dim
N = 4  # number of heads
assert E % N == 0, "embedding dim E must be divisible by the number of heads N"
# q, k and v are packed side by side along the last axis, hence its size 3*E.
# Use a key derived with fold_in instead of reusing `key`, which already produced
# x1: a JAX PRNG key should not be reused for independent random draws.
qkv = jax.random.normal(jax.random.fold_in(key, 1), (B, T, 3*E))

In [23]:
# Before the scaled dot product, split the packed projection into per-head q, k, v:
# [B, T, 3*E] -> 3 x [B, N, T, D], where N is the number of heads and D = E // N
# is the head size (inferred by einops). `three` is the outermost factor of the
# last axis, so q, k and v are the first, second and third E-wide slices of qkv.
q, k, v = rearrange(qkv, 'B T (three N D) -> three B N T D', N=N, three=3)

In [24]:
# All three projections share the per-head layout (B, N, T, D) with D = E // N.
assert q.shape == k.shape == v.shape == (B, N, T, E // N)
q.shape

(10, 4, 1024, 32)

In [38]:
# Scaled dot product: raw (pre-softmax) attention logits of shape (B, N, T, T).
# The 1/sqrt(D) factor uses the head size read from q itself, so it cannot drift
# from the globals E and N. It is applied to q (T*D elements per head) instead of
# the T*T score matrix: cheaper, and equal up to floating-point rounding.
scale = q.shape[-1] ** -0.5
atten_score = einsum(q * scale, k, 'B N T1 D, B N T2 D -> B N T1 T2')