In [1]:
import torch

# Shared Memory computation

## forward pass

In [2]:
# forward pass


def forward_shared_mem_in_bytes(
    q_tile_dim: int,
    kv_tile_dim: int,
    head_dim: int,
    shared_mem_padding_tile: int,
    shared_mem_padding_chunk: int,
    sof_bf16_bytes: int = 2,
    sof_fp32_bytes: int = 4,
) -> int:
    """Estimate the shared-memory footprint of the forward pass, in bytes.

    Each buffer is sized as ``element_bytes * rows * (cols + padding)``, where
    the padding is added to the column count (i.e. it is expressed in
    elements, not bytes).

    Args:
        q_tile_dim: Number of rows of the q tile (also rows of the cd tile and
            of the f / mnl chunks).
        kv_tile_dim: Number of rows of the kv tile and of the i chunk; also the
            unpadded column count of the cd tile.
        head_dim: Unpadded column count of the q and kv tiles.
        shared_mem_padding_tile: Extra columns appended to every tile row.
        shared_mem_padding_chunk: Extra columns appended to every chunk row
            (a chunk has a single unpadded column).
        sof_bf16_bytes: Bytes per bf16 element.
        sof_fp32_bytes: Bytes per fp32 element.

    Returns:
        Total bytes for: 2 q tiles, 1 kv tile, 2 cd tiles, 1 i chunk,
        1 f chunk (bf16), 1 f tile-column chunk (fp32) and 6 mnl chunks.
    """
    # Padded row widths, in elements.
    padded_head_cols = head_dim + shared_mem_padding_tile
    padded_kv_cols = kv_tile_dim + shared_mem_padding_tile
    padded_chunk_cols = 1 + shared_mem_padding_chunk

    # Per-buffer sizes in bytes.
    q_bytes = sof_bf16_bytes * q_tile_dim * padded_head_cols
    kv_bytes = sof_bf16_bytes * kv_tile_dim * padded_head_cols
    cd_bytes = sof_bf16_bytes * q_tile_dim * padded_kv_cols
    i_bytes = sof_bf16_bytes * kv_tile_dim * padded_chunk_cols
    f_bytes = sof_bf16_bytes * q_tile_dim * padded_chunk_cols
    f_tilecol_bytes = sof_fp32_bytes * q_tile_dim * padded_chunk_cols
    mnl_bytes = sof_bf16_bytes * q_tile_dim * padded_chunk_cols

    # (number of copies held in shared memory, bytes per copy)
    buffer_plan = (
        (2, q_bytes),
        (1, kv_bytes),
        (2, cd_bytes),
        (1, i_bytes),
        (1, f_bytes),
        (1, f_tilecol_bytes),
        (6, mnl_bytes),
    )

    return sum(copies * nbytes for copies, nbytes in buffer_plan)

In [3]:
# Candidate tile sizes (not referenced by the sweep loop in this cell).
Q_TILE_DIMS = [64, 128]
KV_TILE_DIMS = [64, 128]

# (q_tile_dim, kv_tile_dim) combinations that the sweep below evaluates.
Q_KV_TILE_DIM_TUPLES = [(64, 64), (128, 64), (128, 128)]

# Head dimensions 64..256 in steps of 32.
HEAD_DIMS = list(range(64, 257, 32))

# Padding, in elements, added to each tile row / chunk row.
SHMEM_PADDING_TILE = 16
SHMEM_PADDING_CHUNK = 4

# Print the forward-pass shared-memory requirement for every configuration.
for q_tile_dim, kv_tile_dim in Q_KV_TILE_DIM_TUPLES:
    for head_dim in HEAD_DIMS:
        shmem_bytes = forward_shared_mem_in_bytes(
            q_tile_dim,
            kv_tile_dim,
            head_dim,
            SHMEM_PADDING_TILE,
            SHMEM_PADDING_CHUNK,
        )
        print(f"q_tile_dim={q_tile_dim}, kv_tile_dim={kv_tile_dim}, head_dim={head_dim}, shmem_bytes={shmem_bytes}")

q_tile_dim=64, kv_tile_dim=64, head_dim=64, shmem_bytes=57600
q_tile_dim=64, kv_tile_dim=64, head_dim=96, shmem_bytes=69888
q_tile_dim=64, kv_tile_dim=64, head_dim=128, shmem_bytes=82176
q_tile_dim=64, kv_tile_dim=64, head_dim=160, shmem_bytes=94464
q_tile_dim=64, kv_tile_dim=64, head_dim=192, shmem_bytes=106752
q_tile_dim=64, kv_tile_dim=64, head_dim=224, shmem_bytes=119040
q_tile_dim=64, kv_tile_dim=64, head_dim=256, shmem_bytes=131328
q_tile_dim=128, kv_tile_dim=64, head_dim=64, shmem_bytes=104320
q_tile_dim=128, kv_tile_dim=64, head_dim=96, shmem_bytes=124800
q_tile_dim=128, kv_tile_dim=64, head_dim=128, shmem_bytes=145280
q_tile_dim=128, kv_tile_dim=64, head_dim=160, shmem_bytes=165760
q_tile_dim=128, kv_tile_dim=64, head_dim=192, shmem_bytes=186240
q_tile_dim=128, kv_tile_dim=64, head_dim=224, shmem_bytes=206720
q_tile_dim=128, kv_tile_dim=64, head_dim=256, shmem_bytes=227200
q_tile_dim=128, kv_tile_dim=128, head_dim=64, shmem_bytes=147968
q_tile_dim=128, kv_tile_dim=128, head_di