In [10]:
import torch

# Shared Memory computation

## forward pass

In [11]:
# forward pass

# Shared-memory estimate where cTile and hTile share one common buffer.
# NOTE: this variant is wrong: cTile and hTile are not allowed to overwrite each other's shared memory.
def forward_shared_mem_in_bytes_common_c_h_tile(
    q_tile_dim: int,
    kv_tile_dim: int,
    head_dim: int,
    shared_mem_padding_tile: int,
    shared_mem_padding_chunk: int,
    sof_bf16_bytes: int = 2,
    sof_fp32_bytes: int = 4,
    c_tile_dtype: str = "bf16",
    d_tile_dtype: str = "bf16",
) -> int:
    """Estimate forward-pass shared memory in bytes when cTile and hTile share one buffer.

    Each tile row is widened by ``shared_mem_padding_tile`` elements. Each chunk
    entry is widened by ``shared_mem_padding_chunk`` elements.

    Args:
        q_tile_dim: Rows of the q, d and c tiles and length of the f/mnl chunks.
        kv_tile_dim: Rows of the kv tile, columns of the d tile, length of the i chunk.
        head_dim: Columns of the q and kv tiles.
        shared_mem_padding_tile: Padding elements appended to every tile row.
        shared_mem_padding_chunk: Padding elements appended to every chunk entry.
        sof_bf16_bytes: Size in bytes of one bf16 element.
        sof_fp32_bytes: Size in bytes of one fp32 element.
        c_tile_dtype: Storage dtype of the shared c/h tile, ``"bf16"`` or ``"fp32"``.
        d_tile_dtype: Storage dtype of the d tile, ``"bf16"`` or ``"fp32"``.

    Returns:
        Total shared memory in bytes.

    Raises:
        ValueError: If ``c_tile_dtype`` or ``d_tile_dtype`` is neither ``"bf16"`` nor
            ``"fp32"``. Previously any other string was silently sized as bf16.
    """
    elem_bytes = {"bf16": sof_bf16_bytes, "fp32": sof_fp32_bytes}
    for arg_name, dtype in (("c_tile_dtype", c_tile_dtype), ("d_tile_dtype", d_tile_dtype)):
        if dtype not in elem_bytes:
            raise ValueError(f"{arg_name} must be 'bf16' or 'fp32', got {dtype!r}")

    # Width of one chunk entry in elements: the value itself plus padding.
    chunk_width = 1 + shared_mem_padding_chunk

    # 1. q_tile (bf16): q_tile_dim x head_dim (+pad). It is counted twice in the total.
    #    NOTE(review): the reason for the factor 2 is not visible here; confirm.
    q_tile_size = sof_bf16_bytes * q_tile_dim * (head_dim + shared_mem_padding_tile)

    # 2. kv_tile (bf16): kv_tile_dim x head_dim (+pad).
    kv_tile_size = sof_bf16_bytes * kv_tile_dim * (head_dim + shared_mem_padding_tile)

    # 3. d_tile: q_tile_dim x kv_tile_dim (+pad).
    d_tile_size = elem_bytes[d_tile_dtype] * q_tile_dim * (kv_tile_dim + shared_mem_padding_tile)

    # 4. c_tile, shared with h_tile, so it must fit kv_tile_dim or head_dim columns.
    n_col_c_tile = max(kv_tile_dim, head_dim)
    c_tile_size = elem_bytes[c_tile_dtype] * q_tile_dim * (n_col_c_tile + shared_mem_padding_tile)

    # 5. i_chunk (bf16): kv_tile_dim entries.
    i_chunk_size = sof_bf16_bytes * kv_tile_dim * chunk_width

    # 6. f_chunk (bf16) and f_tilecol_chunk (fp32): q_tile_dim entries each.
    f_chunk_size = sof_bf16_bytes * q_tile_dim * chunk_width
    f_tilecol_chunk_size = sof_fp32_bytes * q_tile_dim * chunk_width

    # 7. mnl_chunk (bf16): q_tile_dim entries. Six such buffers are counted.
    mnl_chunk_size = sof_bf16_bytes * q_tile_dim * chunk_width

    return (
        2 * q_tile_size
        + kv_tile_size
        + d_tile_size
        + c_tile_size
        + i_chunk_size
        + f_chunk_size
        + f_tilecol_chunk_size
        + 6 * mnl_chunk_size
    )

In [12]:
def forward_shared_mem_in_bytes(
    q_tile_dim: int,
    kv_tile_dim: int,
    head_dim: int,
    shared_mem_padding_tile: int,
    shared_mem_padding_chunk: int,
    sof_bf16_bytes: int = 2,
    sof_fp32_bytes: int = 4,
    c_tile_dtype: str = "bf16",
    d_tile_dtype: str = "bf16",
    h_tile_dtype: str = "fp32",
) -> int:
    """Estimate forward-pass shared memory in bytes with separate c, d and h tiles.

    Each tile row is widened by ``shared_mem_padding_tile`` elements. Each chunk
    entry is widened by ``shared_mem_padding_chunk`` elements.

    Args:
        q_tile_dim: Rows of the q, d and c tiles and length of the f/mnl chunks.
        kv_tile_dim: Rows of the kv and h tiles, columns of the d/c tiles,
            and length of the i chunk.
        head_dim: Columns of the q, kv and h tiles.
        shared_mem_padding_tile: Padding elements appended to every tile row.
        shared_mem_padding_chunk: Padding elements appended to every chunk entry.
        sof_bf16_bytes: Size in bytes of one bf16 element.
        sof_fp32_bytes: Size in bytes of one fp32 element.
        c_tile_dtype: Storage dtype of the c tile, ``"bf16"`` or ``"fp32"``.
        d_tile_dtype: Storage dtype of the d tile, ``"bf16"`` or ``"fp32"``.
        h_tile_dtype: Storage dtype of the h tile, ``"bf16"`` or ``"fp32"``.

    Returns:
        Total shared memory in bytes.

    Raises:
        ValueError: If any tile dtype is neither ``"bf16"`` nor ``"fp32"``.
            Previously any other string was silently sized as bf16.
    """
    elem_bytes = {"bf16": sof_bf16_bytes, "fp32": sof_fp32_bytes}
    for arg_name, dtype in (
        ("c_tile_dtype", c_tile_dtype),
        ("d_tile_dtype", d_tile_dtype),
        ("h_tile_dtype", h_tile_dtype),
    ):
        if dtype not in elem_bytes:
            raise ValueError(f"{arg_name} must be 'bf16' or 'fp32', got {dtype!r}")

    # Padded column counts shared by several tiles.
    head_cols = head_dim + shared_mem_padding_tile
    kv_cols = kv_tile_dim + shared_mem_padding_tile
    # Width of one chunk entry in elements: the value itself plus padding.
    chunk_width = 1 + shared_mem_padding_chunk

    # q_tile (bf16): q_tile_dim x head_dim (+pad).
    q_tile_size = sof_bf16_bytes * q_tile_dim * head_cols

    # kv_tile (bf16): kv_tile_dim x head_dim (+pad). A single buffer is counted.
    kv_tile_size = sof_bf16_bytes * kv_tile_dim * head_cols

    # h_tile: kv_tile_dim x head_dim (+pad).
    # NOTE(review): forward_shared_mem_in_bytes_common_c_h_tile sizes the buffer
    # shared by cTile and hTile with q_tile_dim rows. This function uses
    # kv_tile_dim rows, so the two agree only when q_tile_dim == kv_tile_dim.
    # The recorded outputs use this kv_tile_dim sizing; confirm against the kernel.
    h_tile_size = elem_bytes[h_tile_dtype] * kv_tile_dim * head_cols

    # d_tile and c_tile: q_tile_dim x kv_tile_dim (+pad) each.
    d_tile_size = elem_bytes[d_tile_dtype] * q_tile_dim * kv_cols
    c_tile_size = elem_bytes[c_tile_dtype] * q_tile_dim * kv_cols

    # i_chunk (bf16): kv_tile_dim entries.
    i_chunk_size = sof_bf16_bytes * kv_tile_dim * chunk_width

    # f_chunk (bf16) and f_tilecol_chunk (fp32): q_tile_dim entries each.
    f_chunk_size = sof_bf16_bytes * q_tile_dim * chunk_width
    f_tilecol_chunk_size = sof_fp32_bytes * q_tile_dim * chunk_width

    # mnl_chunk (bf16): q_tile_dim entries. Six such buffers are counted.
    mnl_chunk_size = sof_bf16_bytes * q_tile_dim * chunk_width

    return (
        q_tile_size
        + kv_tile_size
        + h_tile_size
        + d_tile_size
        + c_tile_size
        + i_chunk_size
        + f_chunk_size
        + f_tilecol_chunk_size
        + 6 * mnl_chunk_size
    )

In [19]:
# Candidate tile sizes along the query and key/value dimensions.
Q_TILE_DIMS = [64, 128]
KV_TILE_DIMS = [64, 128]

# (q_tile_dim, kv_tile_dim) combinations iterated over by the sweep cells.
Q_KV_TILE_DIM_TUPLES = [
    (64, 64),
    (128, 64),
    (128, 128),
]

# Head dimensions 64, 96, ..., 256 in steps of 32.
HEAD_DIMS = list(range(64, 256 + 1, 32))

# Padding in elements: added to each tile row and to each chunk entry, respectively.
SHMEM_PADDING_TILE = 16
SHMEM_PADDING_CHUNK = 4

In [20]:
# Sweep all tile/head combinations with c_tile=bf16, d_tile=bf16, h_tile=fp32.
dtype_kwargs = dict(c_tile_dtype="bf16", d_tile_dtype="bf16", h_tile_dtype="fp32")
for q_tile_dim, kv_tile_dim in Q_KV_TILE_DIM_TUPLES:
    for head_dim in HEAD_DIMS:
        shmem_bytes = forward_shared_mem_in_bytes(
            q_tile_dim=q_tile_dim,
            kv_tile_dim=kv_tile_dim,
            head_dim=head_dim,
            shared_mem_padding_tile=SHMEM_PADDING_TILE,
            shared_mem_padding_chunk=SHMEM_PADDING_CHUNK,
            **dtype_kwargs,
        )
        print(f"cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim={q_tile_dim}, kv_tile_dim={kv_tile_dim}, head_dim={head_dim}, shmem_bytes={shmem_bytes}")

cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=64, shmem_bytes=67840
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=96, shmem_bytes=84224
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=128, shmem_bytes=100608
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=160, shmem_bytes=116992
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=192, shmem_bytes=133376
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=224, shmem_bytes=149760
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=256, shmem_bytes=166144
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=64, shmem_bytes=104320
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=96, shmem_bytes=124800
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=128, shmem_bytes=145280
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=

In [21]:
# Sweep with c_tile=fp32, d_tile=bf16, h_tile=fp32.
# The printed label is built from the dtypes actually passed. The previous
# hard-coded label wrongly reported c=bf16.
c_dtype, d_dtype, h_dtype = "fp32", "bf16", "fp32"
for q_tile_dim, kv_tile_dim in Q_KV_TILE_DIM_TUPLES:
    for head_dim in HEAD_DIMS:
        shmem_bytes = forward_shared_mem_in_bytes(
            q_tile_dim=q_tile_dim,
            kv_tile_dim=kv_tile_dim,
            head_dim=head_dim,
            shared_mem_padding_tile=SHMEM_PADDING_TILE,
            shared_mem_padding_chunk=SHMEM_PADDING_CHUNK,
            c_tile_dtype=c_dtype,
            d_tile_dtype=d_dtype,
            h_tile_dtype=h_dtype,
        )
        print(f"cdTile(c={c_dtype}, d={d_dtype}, h={h_dtype}): q_tile_dim={q_tile_dim}, kv_tile_dim={kv_tile_dim}, head_dim={head_dim}, shmem_bytes={shmem_bytes}")

cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=64, shmem_bytes=78080
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=96, shmem_bytes=94464
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=128, shmem_bytes=110848
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=160, shmem_bytes=127232
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=192, shmem_bytes=143616
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=224, shmem_bytes=160000
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=256, shmem_bytes=176384
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=64, shmem_bytes=124800
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=96, shmem_bytes=145280
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=128, shmem_bytes=165760
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=

In [22]:
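# NOTE(review): these saved outputs do not match the current forward_shared_mem_in_bytes results (e.g. head_dim=96: 80128 here vs 84224 from the bf16 sweep). They are presumably from an earlier formula; confirm or delete.
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=64, shmem_bytes=67840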
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=96, shmem_bytes=80128
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=128, shmem_bytes=92416
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=160, shmem_bytes=104704
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=192, shmem_bytes=116992
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=224, shmem_bytes=129280
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=256, shmem_bytes=141568
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=64, shmem_bytes=114560
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=96, shmem_bytes=130944
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=128, shmem_bytes=147328
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=160, shmem_bytes=163712
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=192, shmem_bytes=180096
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=224, shmem_bytes=196480
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=256, shmem_bytes=212864
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=128, head_dim=64, shmem_bytes=184832
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=128, head_dim=96, shmem_bytes=209408
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=128, head_dim=128, shmem_bytes=233984
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=128, head_dim=160, shmem_bytes=258560
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=128, head_dim=192, shmem_bytes=283136
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=128, head_dim=224, shmem_bytes=307712
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=128, head_dim=256, shmem_bytes=332288

In [23]:
# Sweep with c_tile=fp32, d_tile=fp32, h_tile=fp32.
# The printed label is built from the dtypes actually passed. The previous
# hard-coded label wrongly reported c=bf16, d=bf16.
c_dtype, d_dtype, h_dtype = "fp32", "fp32", "fp32"
for q_tile_dim, kv_tile_dim in Q_KV_TILE_DIM_TUPLES:
    for head_dim in HEAD_DIMS:
        shmem_bytes = forward_shared_mem_in_bytes(
            q_tile_dim=q_tile_dim,
            kv_tile_dim=kv_tile_dim,
            head_dim=head_dim,
            shared_mem_padding_tile=SHMEM_PADDING_TILE,
            shared_mem_padding_chunk=SHMEM_PADDING_CHUNK,
            c_tile_dtype=c_dtype,
            d_tile_dtype=d_dtype,
            h_tile_dtype=h_dtype,
        )
        print(f"cdTile(c={c_dtype}, d={d_dtype}, h={h_dtype}): q_tile_dim={q_tile_dim}, kv_tile_dim={kv_tile_dim}, head_dim={head_dim}, shmem_bytes={shmem_bytes}")

cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=64, shmem_bytes=88320
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=96, shmem_bytes=104704
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=128, shmem_bytes=121088
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=160, shmem_bytes=137472
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=192, shmem_bytes=153856
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=224, shmem_bytes=170240
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=256, shmem_bytes=186624
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=64, shmem_bytes=145280
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=96, shmem_bytes=165760
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=128, shmem_bytes=186240
cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim

In [24]:
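# NOTE(review): these saved outputs are labelled c=bf16, but the head_dim=64 value (78080) equals the c_tile_dtype="fp32" run. The remaining values differ from the current output, so they are presumably stale; confirm or delete.
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=64, shmem_bytes=78080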
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=96, shmem_bytes=90368
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=128, shmem_bytes=102656
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=160, shmem_bytes=114944
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=192, shmem_bytes=127232
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=224, shmem_bytes=139520
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=64, kv_tile_dim=64, head_dim=256, shmem_bytes=151808
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=64, shmem_bytes=135040
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=96, shmem_bytes=151424
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=128, shmem_bytes=167808
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=160, shmem_bytes=184192
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=192, shmem_bytes=200576
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=224, shmem_bytes=216960
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=64, head_dim=256, shmem_bytes=233344
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=128, head_dim=64, shmem_bytes=221696
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=128, head_dim=96, shmem_bytes=246272
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=128, head_dim=128, shmem_bytes=270848
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=128, head_dim=160, shmem_bytes=295424
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=128, head_dim=192, shmem_bytes=320000
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=128, head_dim=224, shmem_bytes=344576
# cdTile(c=bf16, d=bf16, h=fp32): q_tile_dim=128, kv_tile_dim=128, head_dim=256, shmem_bytes=369152