In [1]:
import torch

# Shared Memory computation

## forward pass

In [2]:
# forward pass


def forward_shared_mem_in_bytes(
    q_tile_dim: int,
    kv_tile_dim: int,
    head_dim: int,
    shared_mem_padding_tile: int,
    shared_mem_padding_chunk: int,
    sof_bf16_bytes: int = 2,
    sof_fp32_bytes: int = 4,
    c_tile_dtype: str = "bf16",
    d_tile_dtype: str = "bf16",
) -> int:
    """Add up the byte sizes of every buffer counted for the forward pass.

    The total is made of:

    * 2 q tiles and 1 kv tile, each ``rows x (head_dim + tile padding)``
      bf16 elements (rows = ``q_tile_dim`` / ``kv_tile_dim`` respectively),
    * 1 c tile and 1 d tile, each ``q_tile_dim x (kv_tile_dim + tile padding)``
      elements in the dtype selected by ``c_tile_dtype`` / ``d_tile_dtype``,
    * 1 i chunk (``kv_tile_dim`` rows, bf16), 1 f chunk (``q_tile_dim`` rows,
      bf16), 1 f tile-column chunk (``q_tile_dim`` rows, fp32) and 6 mnl
      chunks (``q_tile_dim`` rows, bf16), each row being
      ``1 + chunk padding`` elements wide.

    NOTE(review): the buffer multiplicities (2 q tiles, 6 mnl chunks) and the
    bf16 element size used for the mnl chunks presumably mirror the kernel's
    actual layout -- confirm against the kernel source.

    Args:
        q_tile_dim: Number of rows of a q tile.
        kv_tile_dim: Number of rows of a kv tile.
        head_dim: Unpadded row length of the q and kv tiles.
        shared_mem_padding_tile: Extra elements appended to every tile row.
        shared_mem_padding_chunk: Extra elements appended to every chunk row.
        sof_bf16_bytes: Size of one bf16 element in bytes.
        sof_fp32_bytes: Size of one fp32 element in bytes.
        c_tile_dtype: ``"fp32"`` to size the c tile with fp32 elements; any
            other string is sized as bf16.
        d_tile_dtype: Same rule as ``c_tile_dtype``, applied to the d tile.

    Returns:
        The summed size of all buffers, in bytes.
    """
    # Padded row lengths shared by several buffers.
    qkv_row_len = head_dim + shared_mem_padding_tile
    cd_row_len = kv_tile_dim + shared_mem_padding_tile
    chunk_row_len = 1 + shared_mem_padding_chunk

    def _cd_tile_bytes(dtype: str) -> int:
        # Only the exact string "fp32" selects the fp32 element size; every
        # other value falls back to the bf16 size.
        elem_bytes = sof_fp32_bytes if dtype == "fp32" else sof_bf16_bytes
        return elem_bytes * q_tile_dim * cd_row_len

    # A bf16 chunk with one row per q-tile row (used by f and mnl chunks).
    q_chunk_bf16_bytes = sof_bf16_bytes * q_tile_dim * chunk_row_len

    buffer_bytes = [
        2 * (sof_bf16_bytes * q_tile_dim * qkv_row_len),  # 2x q tile
        sof_bf16_bytes * kv_tile_dim * qkv_row_len,  # 1x kv tile
        sof_bf16_bytes * kv_tile_dim * chunk_row_len,  # 1x i chunk
        q_chunk_bf16_bytes,  # 1x f chunk
        sof_fp32_bytes * q_tile_dim * chunk_row_len,  # 1x f tile-column chunk
        6 * q_chunk_bf16_bytes,  # 6x mnl chunk
        _cd_tile_bytes(c_tile_dtype),  # c tile
        _cd_tile_bytes(d_tile_dtype),  # d tile
    ]
    return sum(buffer_bytes)

In [3]:
# Candidate tile sizes along the q and kv dimensions.
Q_TILE_DIMS = [64, 128]
KV_TILE_DIMS = [64, 128]

# (q_tile_dim, kv_tile_dim) combinations iterated by the sweep cells.
Q_KV_TILE_DIM_TUPLES = [
    (64, 64),
    (128, 64),
    (128, 128),
]

# Head dimensions 64, 96, ..., 256 (steps of 32).
HEAD_DIMS = list(range(64, 257, 32))

# Per-row padding, in elements, passed to forward_shared_mem_in_bytes as
# shared_mem_padding_tile and shared_mem_padding_chunk respectively.
SHMEM_PADDING_TILE = 16
SHMEM_PADDING_CHUNK = 4

In [4]:
# Sweep every (q tile, kv tile) pair against every head dimension with both
# the c tile and the d tile sized as bf16, printing one line per configuration.
for q_tile_dim, kv_tile_dim, head_dim in (
    (q, kv, hd) for q, kv in Q_KV_TILE_DIM_TUPLES for hd in HEAD_DIMS
):
    shmem_bytes = forward_shared_mem_in_bytes(
        q_tile_dim=q_tile_dim,
        kv_tile_dim=kv_tile_dim,
        head_dim=head_dim,
        shared_mem_padding_tile=SHMEM_PADDING_TILE,
        shared_mem_padding_chunk=SHMEM_PADDING_CHUNK,
        c_tile_dtype="bf16",
        d_tile_dtype="bf16",
    )
    print(
        f"cdTile(bf16, bf16): q_tile_dim={q_tile_dim}, kv_tile_dim={kv_tile_dim}, "
        f"head_dim={head_dim}, shmem_bytes={shmem_bytes}"
    )

cdTile(bf16, bf16): q_tile_dim=64, kv_tile_dim=64, head_dim=64, shmem_bytes=57600
cdTile(bf16, bf16): q_tile_dim=64, kv_tile_dim=64, head_dim=96, shmem_bytes=69888
cdTile(bf16, bf16): q_tile_dim=64, kv_tile_dim=64, head_dim=128, shmem_bytes=82176
cdTile(bf16, bf16): q_tile_dim=64, kv_tile_dim=64, head_dim=160, shmem_bytes=94464
cdTile(bf16, bf16): q_tile_dim=64, kv_tile_dim=64, head_dim=192, shmem_bytes=106752
cdTile(bf16, bf16): q_tile_dim=64, kv_tile_dim=64, head_dim=224, shmem_bytes=119040
cdTile(bf16, bf16): q_tile_dim=64, kv_tile_dim=64, head_dim=256, shmem_bytes=131328
cdTile(bf16, bf16): q_tile_dim=128, kv_tile_dim=64, head_dim=64, shmem_bytes=104320
cdTile(bf16, bf16): q_tile_dim=128, kv_tile_dim=64, head_dim=96, shmem_bytes=124800
cdTile(bf16, bf16): q_tile_dim=128, kv_tile_dim=64, head_dim=128, shmem_bytes=145280
cdTile(bf16, bf16): q_tile_dim=128, kv_tile_dim=64, head_dim=160, shmem_bytes=165760
cdTile(bf16, bf16): q_tile_dim=128, kv_tile_dim=64, head_dim=192, shmem_bytes=186240

In [5]:
# Same sweep as the bf16/bf16 case, but with the c tile sized as fp32 while
# the d tile stays bf16.
for q_tile_dim, kv_tile_dim, head_dim in (
    (q, kv, hd) for q, kv in Q_KV_TILE_DIM_TUPLES for hd in HEAD_DIMS
):
    shmem_bytes = forward_shared_mem_in_bytes(
        q_tile_dim=q_tile_dim,
        kv_tile_dim=kv_tile_dim,
        head_dim=head_dim,
        shared_mem_padding_tile=SHMEM_PADDING_TILE,
        shared_mem_padding_chunk=SHMEM_PADDING_CHUNK,
        c_tile_dtype="fp32",
        d_tile_dtype="bf16",
    )
    print(
        f"cdTile(fp32, bf16): q_tile_dim={q_tile_dim}, kv_tile_dim={kv_tile_dim}, "
        f"head_dim={head_dim}, shmem_bytes={shmem_bytes}"
    )

cdTile(fp32, bf16): q_tile_dim=64, kv_tile_dim=64, head_dim=64, shmem_bytes=67840
cdTile(fp32, bf16): q_tile_dim=64, kv_tile_dim=64, head_dim=96, shmem_bytes=80128
cdTile(fp32, bf16): q_tile_dim=64, kv_tile_dim=64, head_dim=128, shmem_bytes=92416
cdTile(fp32, bf16): q_tile_dim=64, kv_tile_dim=64, head_dim=160, shmem_bytes=104704
cdTile(fp32, bf16): q_tile_dim=64, kv_tile_dim=64, head_dim=192, shmem_bytes=116992
cdTile(fp32, bf16): q_tile_dim=64, kv_tile_dim=64, head_dim=224, shmem_bytes=129280
cdTile(fp32, bf16): q_tile_dim=64, kv_tile_dim=64, head_dim=256, shmem_bytes=141568
cdTile(fp32, bf16): q_tile_dim=128, kv_tile_dim=64, head_dim=64, shmem_bytes=124800
cdTile(fp32, bf16): q_tile_dim=128, kv_tile_dim=64, head_dim=96, shmem_bytes=145280
cdTile(fp32, bf16): q_tile_dim=128, kv_tile_dim=64, head_dim=128, shmem_bytes=165760
cdTile(fp32, bf16): q_tile_dim=128, kv_tile_dim=64, head_dim=160, shmem_bytes=186240
cdTile(fp32, bf16): q_tile_dim=128, kv_tile_dim=64, head_dim=192, shmem_bytes=206720

In [6]:
# Same sweep again, this time with both the c tile and the d tile sized as
# fp32.
for q_tile_dim, kv_tile_dim, head_dim in (
    (q, kv, hd) for q, kv in Q_KV_TILE_DIM_TUPLES for hd in HEAD_DIMS
):
    shmem_bytes = forward_shared_mem_in_bytes(
        q_tile_dim=q_tile_dim,
        kv_tile_dim=kv_tile_dim,
        head_dim=head_dim,
        shared_mem_padding_tile=SHMEM_PADDING_TILE,
        shared_mem_padding_chunk=SHMEM_PADDING_CHUNK,
        c_tile_dtype="fp32",
        d_tile_dtype="fp32",
    )
    print(
        f"cdTile(fp32, fp32): q_tile_dim={q_tile_dim}, kv_tile_dim={kv_tile_dim}, "
        f"head_dim={head_dim}, shmem_bytes={shmem_bytes}"
    )

cdTile(fp32, fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=64, shmem_bytes=78080
cdTile(fp32, fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=96, shmem_bytes=90368
cdTile(fp32, fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=128, shmem_bytes=102656
cdTile(fp32, fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=160, shmem_bytes=114944
cdTile(fp32, fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=192, shmem_bytes=127232
cdTile(fp32, fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=224, shmem_bytes=139520
cdTile(fp32, fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=256, shmem_bytes=151808
cdTile(fp32, fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=64, shmem_bytes=145280
cdTile(fp32, fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=96, shmem_bytes=165760
cdTile(fp32, fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=128, shmem_bytes=186240
cdTile(fp32, fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=160, shmem_bytes=206720
cdTile(fp32, fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=192, shmem_bytes=227200