In [2]:
import functools
from typing import Callable

import jax
from jax.experimental import pallas as pl
# from jax.experimental.pallas import tpu as pltpu
from jax import random
import jax.numpy as jnp
import numpy as np
import timeit

I'm running these on an RTX 3090, where 99 KB (101,376 bytes) of shared memory is available to each thread block.

### Constant Add Block

In [9]:
def constant_vector_add_kernel(x_ref, y_ref, out_ref):
    """Store the elementwise sum of the ``x_ref`` and ``y_ref`` blocks in ``out_ref``."""
    lhs = x_ref[...]
    rhs = y_ref[...]
    out_ref[...] = lhs + rhs


@functools.partial(jax.jit, static_argnames=("b",))
def constant_vector_add(
    x: jax.Array,
    y: jax.Array,
    *,
    b: int = 128,
) -> jax.Array:
    """Add two equal-length 1-D arrays with a blocked Pallas kernel.

    Args:
        x: 1-D array.
        y: 1-D array with the same shape as ``x``.
        b: Number of elements handled by each program instance. It is a
            static ``jit`` argument, so each new value triggers a retrace.

    Returns:
        Array with the shape and dtype of ``x`` holding ``x + y``.

    Raises:
        AssertionError: If the shapes differ or the inputs are not 1-D.
        ValueError: If ``b`` is not positive or does not divide ``len(x)``.
    """
    assert x.shape == y.shape
    assert len(x.shape) == 1
    # Executes only while tracing, i.e. once per distinct static value of `b`.
    print(b)
    m = x.shape[0]
    if b <= 0 or m % b != 0:
        # The grid below uses floor division; without this check the trailing
        # `m % b` elements would not be covered by any block.
        raise ValueError(
            f"block size b={b} must be positive and evenly divide the input length {m}"
        )

    def block_spec():
        # Keyword arguments keep the call valid on Pallas releases that take
        # (index_map, block_shape) as well as those that take
        # (block_shape, index_map) positionally.
        return pl.BlockSpec(index_map=lambda i: (i,), block_shape=(b,))

    return pl.pallas_call(
        constant_vector_add_kernel,
        out_shape=jax.ShapeDtypeStruct((m,), x.dtype),
        in_specs=[block_spec(), block_spec()],
        out_specs=block_spec(),
        grid=(m // b,),
    )(x, y)


In [10]:
# Sanity check: the Pallas kernel (default block size) must agree with plain `x + y`.
N = 2**20 # 1M elements
x = jnp.ones(N, dtype=jnp.float16)
y = jnp.ones(N, dtype=jnp.float16)

out1 = constant_vector_add(x, y)
out2 = x + y
# Absolute tolerance only (rtol=0); the result is displayed as the cell output.
jnp.allclose(out1, out2, atol=1e-2, rtol=0)

128


Array(True, dtype=bool)

In [None]:
# Time the Pallas kernel with a block size of 2**12 = 4096 elements
# (for reference, 2**15 = 32768). block_until_ready() waits for the result
# so the measurement covers the computation itself.
%timeit constant_vector_add(x, y, b=2**12).block_until_ready()

In [48]:
# Baseline: the built-in elementwise add on the same inputs.
%timeit (x + y).block_until_ready()

44.3 µs ± 3.06 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


No matter the block size, the addition kernel never triggers the error below, which I hit when running a matmul kernel:

λ ~/code/learnjax: python3 pallas.py 4096 --block_size 512
4096 4096 4096
(8, 8, 8)
Traceback (most recent call last):
  File "/home/dom/code/learnjax/pallas.py", line 72, in <module>
    fire.Fire(main)
  File "/home/dom/.local/lib/python3.11/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dom/.local/lib/python3.11/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/dom/.local/lib/python3.11/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dom/code/learnjax/pallas.py", line 58, in main
    result = matmul(
             ^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Shared memory size limit exceeded: requested 1115136, available: 101376

    Shared memory size limit exceeded
    


In [8]:
# Try a very large block (2**18 = 262144 elements per program instance).
constant_vector_add(x, y, b=2**18)

262144


Array([2., 2., 2., ..., 2., 2., 2.], dtype=float16)

I'm not sure why this doesn't seem to be a problem for the addition kernel.

OK, after a bit of digging, it seems that there are two kinds of shared memory: static and dynamic. A matrix multiplication kernel uses only static shared memory, so exceeding the limit throws an error. The addition kernel uses dynamic shared memory, so it will just stall if the block is too large.

### Puzzle 3: Outer Vector Add

Add two vectors.

Uses one program block axis. Block size B0 is always the same as vector x length N0. Block size B1 is always the same as vector y length N1.

$z_{ij} = x_i + y_j$

where $i$ indexes rows and $j$ indexes columns.

In [15]:
# Column vector (10, 1) and row vector (1, 10) for a small broadcasting demo.
x = jnp.ones(10).reshape(10, 1)
y = jnp.ones(10).reshape(1, 10)

In [16]:
# Broadcasting (10, 1) + (1, 10) yields the (10, 10) outer sum.
x + y

Array([[2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
       [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.]], dtype=float32)

In [111]:
def outer_vector_add_kernel(x_ref, y_ref, out_ref):
    """Fill ``out_ref`` with the outer sum of the ``x_ref`` and ``y_ref`` blocks."""
    column = x_ref[...][:, None]  # x block as a column: (len, 1)
    row = y_ref[...][None, :]  # y block as a row: (1, len)
    out_ref[...] = jnp.add(column, row)


@functools.partial(jax.jit, static_argnames=("bm", "bn"))
def outer_vector_add(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bn: int = 128,
) -> jax.Array:
    """Compute the outer sum ``z[i, j] = x[i] + y[j]`` with a blocked Pallas kernel.

    Args:
        x: 1-D array of length ``m``; it indexes the rows of the result.
        y: 1-D array of length ``n``; it indexes the columns of the result.
        bm: Block size along the rows (static ``jit`` argument).
        bn: Block size along the columns (static ``jit`` argument).

    Returns:
        Array of shape ``(m, n)`` with the dtype of ``x``.

    Raises:
        ValueError: If a block size is not positive or does not evenly divide
            the corresponding input length.
    """
    m = x.shape[0]
    n = y.shape[0]
    if bm <= 0 or bn <= 0 or m % bm != 0 or n % bn != 0:
        # The grid uses floor division, so a ragged tail would otherwise be
        # left without any block covering it.
        raise ValueError(
            f"block sizes (bm={bm}, bn={bn}) must be positive and evenly divide "
            f"the input lengths ({m}, {n})"
        )
    grid = (m // bm, n // bn)
    return pl.pallas_call(
        outer_vector_add_kernel,
        out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
        # Keyword arguments keep BlockSpec valid on Pallas releases with
        # either positional order of (index_map, block_shape).
        in_specs=[
            pl.BlockSpec(index_map=lambda i, j: (i,), block_shape=(bm,)),
            pl.BlockSpec(index_map=lambda i, j: (j,), block_shape=(bn,)),
        ],
        out_specs=pl.BlockSpec(index_map=lambda i, j: (i, j), block_shape=(bm, bn)),
        grid=grid,
        # interpret=True,  # uncomment to run the kernel in interpreter mode
    )(x, y)

# x is all 3s and y is all 5s, so every entry of the outer sum is 8.
x = jnp.ones(4096) + 2
y = jnp.ones(4096) + 4
outer_vector_add(x, y)

Array([[8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       ...,
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.],
       [8., 8., 8., ..., 8., 8., 8.]], dtype=float32)

In [124]:
# Plain jax.numpy outer product of the current x and y via broadcasting, for reference.
jnp.multiply(x[:, None], y[None, :]).block_until_ready()

Array([[15., 15., 15., ..., 15., 15., 15.],
       [15., 15., 15., ..., 15., 15., 15.],
       [15., 15., 15., ..., 15., 15., 15.],
       ...,
       [15., 15., 15., ..., 15., 15., 15.],
       [15., 15., 15., ..., 15., 15., 15.],
       [15., 15., 15., ..., 15., 15., 15.]], dtype=float32)

In [114]:
# Baseline timing: broadcasting outer sum in plain jax.numpy.
%timeit jnp.add(x[:, None], y[None, :]).block_until_ready()

370 µs ± 7.98 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [121]:
# Timing of the Pallas outer-sum kernel with the default 128x128 blocks.
%timeit outer_vector_add(x, y).block_until_ready()

142 µs ± 2.68 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [122]:
# Inputs of different lengths: the result has shape (len(x), len(y)).
outer_vector_add(x, jnp.ones(2048)).shape

(4096, 2048)

In [145]:
### Puzzle 5: Outer Vector Multiplication with Activation

def p5_kernel(x_ref, y_ref, z_ref, *, activation):
    """Write ``activation`` of the outer product of the ``x_ref`` and ``y_ref`` blocks to ``z_ref``."""
    column = x_ref[...][:, None]  # x block as a column: (len, 1)
    row = y_ref[...][None, :]  # y block as a row: (1, len)
    outer_product = jnp.multiply(column, row)
    z_ref[...] = activation(outer_product)


@functools.partial(jax.jit, static_argnames=("bm", "bn", "activation"))
def p5(
    x: jax.Array,
    y: jax.Array,
    *,
    bm: int = 128,
    bn: int = 128,
    activation=jax.nn.relu,
) -> jax.Array:
    """Compute ``activation(x[i] * y[j])`` for all ``i, j`` with a blocked Pallas kernel.

    Args:
        x: 1-D array of length ``m``; it indexes the rows of the result.
        y: 1-D array of length ``n``; it indexes the columns of the result.
        bm: Block size along the rows (static ``jit`` argument).
        bn: Block size along the columns (static ``jit`` argument).
        activation: Callable applied elementwise to each block of the outer
            product. It is a static ``jit`` argument, so it must be hashable.

    Returns:
        Array of shape ``(m, n)`` with the dtype of ``x``.

    Raises:
        ValueError: If a block size is not positive or does not evenly divide
            the corresponding input length.
    """
    m = x.shape[0]
    n = y.shape[0]
    if bm <= 0 or bn <= 0 or m % bm != 0 or n % bn != 0:
        # The grid uses floor division, so a ragged tail would otherwise be
        # left without any block covering it.
        raise ValueError(
            f"block sizes (bm={bm}, bn={bn}) must be positive and evenly divide "
            f"the input lengths ({m}, {n})"
        )
    grid = (m // bm, n // bn)
    return pl.pallas_call(
        # Bind the activation here so the kernel keeps the (x_ref, y_ref, z_ref) signature.
        functools.partial(p5_kernel, activation=activation),
        out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
        # Keyword arguments keep BlockSpec valid on Pallas releases with
        # either positional order of (index_map, block_shape).
        in_specs=[
            pl.BlockSpec(index_map=lambda i, j: (i,), block_shape=(bm,)),
            pl.BlockSpec(index_map=lambda i, j: (j,), block_shape=(bn,)),
        ],
        out_specs=pl.BlockSpec(index_map=lambda i, j: (i, j), block_shape=(bm, bn)),
        grid=grid,
        # interpret=True,  # uncomment to run the kernel in interpreter mode
    )(x, y)

# Negative products: x is all 3s and y is all -3s.
x = jnp.ones(4096) + 2
y = jnp.ones(4096) - 4
res = p5(x, y)
# all 0 since 3 * -3 is -9 and relu(-9) = 0
print(res)

# Positive products: x is all 3s and y is all 5s, so relu leaves 3 * 5 = 15 unchanged.
x = jnp.ones(4096) + 2
y = jnp.ones(4096) + 4
res = p5(x, y)
print(res)

[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
[[15. 15. 15. ... 15. 15. 15.]
 [15. 15. 15. ... 15. 15. 15.]
 [15. 15. 15. ... 15. 15. 15.]
 ...
 [15. 15. 15. ... 15. 15. 15.]
 [15. 15. 15. ... 15. 15. 15.]
 [15. 15. 15. ... 15. 15. 15.]]


### Puzzle 7: Long Sum

Sum of a batch of numbers.

Uses one program block axis. Block size B0 represents a range of batches of x of length N0. Each element is of length T. Process it B1 < T elements at a time.

In [327]:
def p7_kernel(x_ref, z_ref):
    """Accumulate the sum of the current ``x_ref`` block into ``z_ref``.

    ``z_ref`` is reset to zero on the first step along grid axis 1; each step
    then adds the sum of its block to it.
    """
    is_first_step = pl.program_id(axis=1) == 0

    @pl.when(is_first_step)
    def _reset_accumulator():
        z_ref[...] = jnp.zeros_like(z_ref)

    block_total = jnp.sum(x_ref[...])
    z_ref[...] = z_ref[...] + block_total


@functools.partial(jax.jit, static_argnames=("block_size",))
def p7(
    x: jax.Array,
    *,
    block_size: int = 128,
) -> jax.Array:
    """Sum each row of a 2-D array, ``block_size`` elements at a time.

    Args:
        x: Array of shape ``(b, m)``.
        block_size: Number of row elements reduced per grid step (static
            ``jit`` argument).

    Returns:
        Array of shape ``(b,)`` with the dtype of ``x`` holding the row sums.

    Raises:
        ValueError: If ``block_size`` is not positive or does not evenly
            divide the row length ``m``.
    """
    b, m = x.shape
    if block_size <= 0 or m % block_size != 0:
        # The grid uses floor division, so a ragged tail of each row would
        # otherwise be left out of the sum.
        raise ValueError(
            f"block_size={block_size} must be positive and evenly divide the row length {m}"
        )
    grid = (b, m // block_size)
    # Executes only while tracing, i.e. once per distinct input shape / block_size.
    print(grid)
    return pl.pallas_call(
        p7_kernel,
        # Keyword arguments keep BlockSpec valid on Pallas releases with
        # either positional order of (index_map, block_shape).
        in_specs=[
            # `None` in the block shape squeezes the row axis, so the kernel
            # sees a 1-D slice of `block_size` elements.
            pl.BlockSpec(index_map=lambda i, j: (i, j), block_shape=(None, block_size)),
        ],
        # Every step j of row i maps to the same output element, which the
        # kernel uses as a running total.
        out_specs=pl.BlockSpec(index_map=lambda i, j: (i,), block_shape=(1,)),
        grid=grid,
        out_shape=jax.ShapeDtypeStruct((b,), x.dtype),
        # NOTE(review): presumably kept in interpreter mode because the
        # accumulation relies on grid steps running one after another - confirm
        # before switching to a compiled backend.
        interpret=True,
    )(x)

# Four rows of 512 ones: every row sum should be 512.
x = jnp.ones((4, 512))
res = p7(x)
print(res)
# Reference row sums from plain jax.numpy.
print(jnp.sum(x, axis=1))

(4, 4)
[512. 512. 512. 512.]
[512. 512. 512. 512.]


In [330]:
# fori_loop demo: starting from 0, add each index in [0, 10) to the carry, giving 45.
jax.lax.fori_loop(0, 10, lambda i, val: i + val, 0)

Array(45, dtype=int32, weak_type=True)