In [1]:
import os
import json
import time
from copy import deepcopy

import torch
import torch.nn.functional as F
from torch.testing import assert_close

from ref.modeling import (
    MLPActivationType,
    AttnQKVPackFormat,
    AttnQKVLayout,
    TransformerConfig,
    TransformerDecoderKVCache,
    TransformerDecoderLayer,
    TransformerDecoderBlock,
)

#### Generate toy test cases

##### task1 - case1

In [2]:
def construct_transformer_decoder_kvcache_args(
    b: int,
    nh: int,
    hd: int,
    qkv_layout: AttnQKVLayout,
    ops: list,
    dtype: torch.dtype = torch.float32,
    device: str = "cpu",
    seed: int = 42,
):
    input_tensors = []
    
    for i, op in enumerate(ops):
        if op['op'] in ("set", "append"):
            s, seqlens = op['s'], op['seqlens']
            
            torch.manual_seed(seed + i)
            k = torch.randn(b, s, nh, hd, dtype=dtype, device=device)
            v = torch.randn_like(k)
            cu_seqlens = None
            
            match qkv_layout:
                case AttnQKVLayout.SBHD:
                    k, v = [x.transpose(0, 1) for x in (k, v)]
                case AttnQKVLayout.THD:
                    assert b == 1, "b should be equal to 1 when qkv_layout is THD"
                    assert seqlens is not None, "seqlens must be given when qkv_layout is THD"
                    k, v = [x.squeeze(0) for x in (k, v)]
                    cu_seqlens = torch.concat([
                            torch.zeros(1, dtype=torch.int32, device=device),
                            torch.tensor(seqlens, dtype=torch.int32, device=device).cumsum(dim=0)
                    ], dim=0)
                    assert cu_seqlens[-1] == (t:=b*s), f"The sum of seqlens ({cu_seqlens[-1]}) != length ({t})"
            input_tensors.append((k, v, cu_seqlens))
        else:
            input_tensors.append(None)
        
    return input_tensors
           
b, nh, hd = 1, 1, 4
layout = AttnQKVLayout.BSHD
num_layers = 2
ops = [
    {
        "op": "has",
        "layer_idx": 0,
    },
    {
        "op": "set",
        "layer_idx": 1,
        "s": 3,
        "seqlens": None,
    },
    {
        "op": "set",
        "layer_idx": 0,
        "s": 2,
        "seqlens": None,
    },
    {
        "op": "has",
        "layer_idx": 1,
    },
    {
        "op": "get",
        "layer_idx": 1,
    },
    {
        "op": "reset",
    },
    {
        "op": "has",
        "layer_idx": 1,
    },
    {
        "op": "append",
        "layer_idx": 0,
        "s": 1,
        "seqlens": None,
    },
    {
        "op": "append",
        "layer_idx": 0,
        "s": 2,
        "seqlens": None,
    },
    {
        "op": "has",
        "layer_idx": 0,
    },
    {
        "op": "get",
        "layer_idx": 0,
    },
]

input_tensors = construct_transformer_decoder_kvcache_args(
    b, nh, hd, 
    layout,
    ops,
)

kv_cache = TransformerDecoderKVCache(
    qkv_layout=layout,
    num_layers=num_layers,
)

outputs_ref = []

for i, (op, input_tensor) in enumerate(zip(ops, input_tensors)):
    match op['op']:
        case "reset":
            kv_cache.reset()
            outputs_ref.append(None)
        case "has":
            layer_idx = op['layer_idx']
            outputs_ref.append(kv_cache.has(layer_idx))
        case "get":
            layer_idx = op['layer_idx']
            outputs_ref.append(kv_cache.get(layer_idx))
        case "set":
            layer_idx = op['layer_idx']
            k, v, cu_seqlens = input_tensor
            kv_cache.set(layer_idx, k, v, cu_seqlens=cu_seqlens)
            outputs_ref.append(None)
        case "append":
            layer_idx = op['layer_idx']
            k, v, cu_seqlens = input_tensor
            kv_cache.append(layer_idx, k, v, cu_seqlens=cu_seqlens)
            outputs_ref.append(None)
        case _:
            raise ValueError(f"Unknown operation: {op['op']}")

outputs_ref

[False,
 None,
 None,
 True,
 (tensor([[[[-0.6484, -0.7058,  0.6432,  1.4788]],
  
           [[ 1.1918, -0.1446,  0.4847,  0.6921]],
  
           [[-1.3929,  0.7623,  0.8387, -1.0450]]]]),
  tensor([[[[ 1.1097,  0.3953,  1.1804, -0.8989]],
  
           [[-0.8313,  0.4680,  2.2700,  0.0743]],
  
           [[-0.8931, -0.9201, -0.0213,  1.7711]]]]),
  None),
 None,
 False,
 None,
 None,
 True,
 (tensor([[[[-2.0157,  2.0106,  0.0583,  0.0656]],
  
           [[ 0.4625, -0.1692,  0.3719,  1.4709]],
  
           [[-0.1568, -2.8720,  1.9054, -0.1457]]]]),
  tensor([[[[-1.6534,  2.2517,  0.9501,  2.2385]],
  
           [[-1.8826, -1.0217, -0.2169, -1.0115]],
  
           [[ 0.1614, -0.0939,  1.7723, -0.0284]]]]),
  None)]

##### task1 - case2

In [3]:
def construct_transformer_decoder_kvcache_args(
    b: int,
    nh: int,
    hd: int,
    qkv_layout: AttnQKVLayout,
    ops: list,
    dtype: torch.dtype = torch.float32,
    device: str = "cpu",
    seed: int = 42,
):
    input_tensors = []
    
    for i, op in enumerate(ops):
        if op['op'] in ("set", "append"):
            s, seqlens = op['s'], op['seqlens']
            
            torch.manual_seed(seed + i)
            k = torch.randn(b, s, nh, hd, dtype=dtype, device=device)
            v = torch.randn_like(k)
            cu_seqlens = None
            
            match qkv_layout:
                case AttnQKVLayout.SBHD:
                    k, v = [x.transpose(0, 1) for x in (k, v)]
                case AttnQKVLayout.THD:
                    assert b == 1, "b should be equal to 1 when qkv_layout is THD"
                    assert seqlens is not None, "seqlens must be given when qkv_layout is THD"
                    k, v = [x.squeeze(0) for x in (k, v)]
                    cu_seqlens = torch.concat([
                            torch.zeros(1, dtype=torch.int32, device=device),
                            torch.tensor(seqlens, dtype=torch.int32, device=device).cumsum(dim=0)
                    ], dim=0)
                    assert cu_seqlens[-1] == (t:=b*s), f"The sum of seqlens ({cu_seqlens[-1]}) != length ({t})"
            input_tensors.append((k, v, cu_seqlens))
        else:
            input_tensors.append(None)
        
    return input_tensors
           
b, nh, hd = 1, 1, 4
layout = AttnQKVLayout.THD
num_layers = 2
ops = [
    {
        "op": "has",
        "layer_idx": 0,
    },
    {
        "op": "set",
        "layer_idx": 1,
        "s": 5,
        "seqlens": [2,2,1],
    },
    {
        "op": "append",
        "layer_idx": 0,
        "s": 4,
        "seqlens": [1,1,2],
    },
    {
        "op": "has",
        "layer_idx": 0,
    },
    {
        "op": "get",
        "layer_idx": 0,
    },
    {
        "op": "append",
        "layer_idx": 1,
        "s": 3,
        "seqlens": [1,1,1],
    },
    {
        "op": "has",
        "layer_idx": 1,
    },
    {
        "op": "get",
        "layer_idx": 1,
    },
]

input_tensors = construct_transformer_decoder_kvcache_args(
    b, nh, hd, 
    layout,
    ops,
)

kv_cache = TransformerDecoderKVCache(
    qkv_layout=layout,
    num_layers=num_layers,
)

outputs_ref = []

for i, (op, input_tensor) in enumerate(zip(ops, input_tensors)):
    match op['op']:
        case "reset":
            kv_cache.reset()
            outputs_ref.append(None)
        case "has":
            layer_idx = op['layer_idx']
            outputs_ref.append(kv_cache.has(layer_idx))
        case "get":
            layer_idx = op['layer_idx']
            outputs_ref.append(kv_cache.get(layer_idx))
        case "set":
            layer_idx = op['layer_idx']
            k, v, cu_seqlens = input_tensor
            kv_cache.set(layer_idx, k, v, cu_seqlens=cu_seqlens)
            outputs_ref.append(None)
        case "append":
            layer_idx = op['layer_idx']
            k, v, cu_seqlens = input_tensor
            kv_cache.append(layer_idx, k, v, cu_seqlens=cu_seqlens)
            outputs_ref.append(None)
        case _:
            raise ValueError(f"Unknown operation: {op['op']}")

outputs_ref

[False,
 None,
 None,
 True,
 (tensor([[[ 1.5862,  1.1253,  1.8306,  0.1129]],
  
          [[ 0.4976,  1.5010, -0.1413, -0.3522]],
  
          [[-0.1643, -1.1651, -0.4089, -0.5252]],
  
          [[-1.3153,  0.6031, -0.8124,  0.5920]]]),
  tensor([[[-1.2266, -0.9598,  1.7118, -0.0146]],
  
          [[ 0.4252, -1.3446,  1.6114,  0.5914]],
  
          [[ 0.1644,  1.2514,  0.5173, -0.8078]],
  
          [[-2.0788,  0.6370,  1.3824, -0.9156]]]),
  tensor([0, 1, 2, 4])),
 None,
 True,
 (tensor([[[-0.0166, -0.4668,  2.0909,  0.6149]],
  
          [[ 0.3083, -0.2947, -0.7662, -0.9962]],
  
          [[-1.4624,  0.7523, -1.7173,  0.5757]],
  
          [[-0.2345, -0.5367,  1.1296,  0.1054]],
  
          [[-0.3630,  1.5822, -0.4430,  1.8462]],
  
          [[ 0.6040,  1.1914,  0.3525,  0.2941]],
  
          [[-0.4772, -1.8291, -0.6145,  1.0282]],
  
          [[ 0.5197, -0.1634, -0.0875,  0.6146]]]),
  tensor([[[-0.7771, -0.4484, -1.1668,  0.5006]],
  
          [[ 0.0139,  0.6564,  0.4

##### task23 - case1

In [4]:
SEED = 42
PARAM_DEVICE = "cpu"
PARAM_DTYPE = torch.float32

b, s, h, v = 1, 8, 8, 10  # batch, new sequence length, hidden size, vocab size
layout = AttnQKVLayout.BSHD
pack_format = AttnQKVPackFormat.Q_K_V

past_seqlen_kv = 0  # tokens pre-filled into the kv cache (0 -> no cache)

# `h`, `v`, `layout` and `pack_format` are wired into the config so the model
# can never disagree with the inputs / kv cache built below.
config = TransformerConfig(
    num_layers=1,
    hidden_size=h,
    ffh_size=16,
    max_seq_len=8,
    param_dtype=PARAM_DTYPE,
    param_device=PARAM_DEVICE,
    init_base_seed=SEED,

    vocab_size=v,
    vocab_init_mean=0.1,
    vocab_init_std=1.1,

    rope_base=10000,
    rope_ratio=1,
    rope_dynamic=False,

    group_size=None,
    eps=1e-5,
    norm_init_range=(-1.1, 1.1),

    proj_init_seed=SEED,
    proj_init_mean=0.1,
    proj_init_std=1.1,
    lm_head_tied=False,

    online_attn_block_size=4,
    head_dim=4,
    num_q_head=2,
    num_kv_head=1,
    qkv_pack_format=pack_format,
    qkv_layout=layout,
    window_size=None,
    causal=True,
    softmax_dropout_rate=0.,
    softmax_dropout_seed=SEED,
    softmax_scale=None,
    softmax_cap=None,
    softmax_temp=1.,
    softmax_clip_range=(0., 1.),
    apply_qk_norm=False,
    qk_norm_group_size=None,

    activation_type=MLPActivationType.SILU,
    lora_rank=0,
    lora_alpha=None,
    lora_dropout_rate=0.,
    lora_dropout_seed=SEED,
    lora_init_base_seed=SEED,

    num_experts=None,
    moe_topk=1,
    gate_init_mean=0.,
    gate_init_std=1.,
)

torch.manual_seed(42)
input = torch.randn(b, s, h, dtype=torch.bfloat16, device="cpu")
input_ids = torch.randint(0, v, (b, s), dtype=torch.int32, device="cpu")

kv_cache = None
if past_seqlen_kv > 0:
    kv_cache = TransformerDecoderKVCache(
        qkv_layout=layout,
        num_layers=config.num_layers,
    )
    for layer_idx in range(config.num_layers):
        # one seed per layer, same scheme as the other task23 cases
        torch.manual_seed(42 + layer_idx)
        past_k = torch.randn(b, past_seqlen_kv, config.num_kv_head, config.head_dim, dtype=config.param_dtype, device=config.param_device)
        past_v = torch.randn_like(past_k)
        past_cu_seqlens = None

        if layout is AttnQKVLayout.SBHD:
            past_k, past_v = [x.transpose(0, 1) for x in (past_k, past_v)]
        elif layout is AttnQKVLayout.THD:
            assert past_seqlen_kv >= 9, (
                f"past_seqlen_kv ({past_seqlen_kv}) must be >= 9 for cu_seqlens [0, 5, 9, past_seqlen_kv]"
            )
            past_k, past_v = [x.squeeze(0) for x in (past_k, past_v)]
            # the cache holds past_seqlen_kv tokens, so the last boundary is past_seqlen_kv (not s)
            past_cu_seqlens = torch.tensor([0, 5, 9, past_seqlen_kv]).long()

        kv_cache.set(layer_idx, past_k, past_v, cu_seqlens=past_cu_seqlens)

if layout is AttnQKVLayout.THD:
    # with a cache: three sequences of 1 new token each;
    # without: three packed sequences of lengths 5, 4 and s - 9
    cu_seqlens = torch.tensor([0, 1, 2, 3] if past_seqlen_kv > 0 else [0, 5, 9, s]).long()
    assert int(cu_seqlens[-1]) == s and bool((cu_seqlens.diff() >= 0).all()), (
        f"cu_seqlens {cu_seqlens.tolist()} is not a valid packing of s={s} tokens"
    )
else:
    cu_seqlens = None

print(input, cu_seqlens, kv_cache.get(0) if kv_cache is not None else None)

tensor([[[-0.8086, -1.5312,  0.4062,  0.1719, -0.2471,  0.2041, -0.8789,
          -0.3867],
         [ 0.6016,  0.2676, -0.8516, -0.2891,  1.0000, -0.7812,  1.3750,
           0.1187],
         [ 0.8633, -0.7656,  0.8242, -1.1875, -0.0330, -0.0801,  0.0781,
          -0.6484],
         [-0.4746, -0.6680,  1.0547, -0.0359, -1.3203, -0.6719,  0.0415,
          -0.6445],
         [ 0.2637,  0.7070, -0.2188,  2.8906,  1.3672, -0.1084,  0.2402,
           1.8359],
         [ 1.7500, -0.0459, -0.3516, -0.0962, -1.7656, -0.9102, -0.5977,
           0.6602],
         [ 0.0393,  1.6094, -1.5078,  1.2656,  0.2227, -1.0312,  0.6055,
          -0.4707],
         [ 1.6406, -0.3086, -0.0693, -1.0469,  0.6211, -0.1455,  0.3242,
           2.3906]]], dtype=torch.bfloat16) None None


In [5]:
layer = TransformerDecoderLayer(config)
print(layer)

# Run on copies: the layer receives a cache it may update, and `kv_cache` /
# `cu_seqlens` are reused by the block cell below (same pattern as the later cases).
output_layer = layer(
    input.clone(),
    cu_seqlens=cu_seqlens.clone() if cu_seqlens is not None else None,
    kv_cache=deepcopy(kv_cache) if kv_cache is not None else None,
)
output_layer.shape, output_layer

TransformerDecoderLayer(
  (attn_pre_norm): GroupRMSNorm()
  (rope): NTKAwareRoPE()
  (attn): OnlineSlidingWindowAttn(
    (softmax_dropout): Dropout(p=0.0, inplace=False)
  )
  (mlp_pre_norm): GroupRMSNorm()
  (mlp): DenseMLPWithLoRA()
)


(torch.Size([1, 8, 8]),
 tensor([[[ -1.6328, -14.3125,  -4.7188,  -5.0938,   0.3770,  -1.6016,  -5.3750,
             0.8789],
          [  6.3438,  11.1250,  -1.8672,   6.3750,  -0.5703,   5.6250,  -0.2734,
            -7.4688],
          [  0.5391,  -7.1875,   1.2656,  -3.7812,   8.4375,  -4.0625,  -3.5312,
             1.9922],
          [  2.7031,  -0.1162,   3.5625,  -5.4062,   2.9531,  -1.2500,  -3.1094,
             4.8125],
          [ -3.8125,  -6.7188,   2.9688,   0.1011,   4.1562,  -9.5625,   4.6875,
             5.2188],
          [  5.3438,  -1.9375,   1.1562,  -7.5312,   2.4062,  -0.8555,  -6.0938,
             8.6250],
          [ -3.8906, -11.8125,   4.1875,  -6.4688,   6.9062, -16.2500,   4.6250,
             6.8750],
          [ -4.1562,   3.1719,  -1.3906,  -0.3086, -19.8750,   3.0156,   2.8125,
             4.3125]]], dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>))

In [6]:
block = TransformerDecoderBlock(config)
# train mode iff online attention is configured (same rule as the later cases)
block = block.train() if config.online_attn_block_size is not None else block.eval()
if kv_cache is not None:
    # give the block its own copy so the shared `kv_cache` stays untouched
    # and re-running this cell starts from the same cache
    block.set_kv_cache(kv_cache=deepcopy(kv_cache))
print(block)

output_block = block(input_ids.clone(), cu_seqlens=cu_seqlens.clone() if cu_seqlens is not None else None)
output_block.shape, output_block

TransformerDecoderBlock(
  (vocab_emb): ParallelVocabEmbedding()
  (layers): ModuleList(
    (0): TransformerDecoderLayer(
      (attn_pre_norm): GroupRMSNorm()
      (rope): NTKAwareRoPE()
      (attn): OnlineSlidingWindowAttn(
        (softmax_dropout): Dropout(p=0.0, inplace=False)
      )
      (mlp_pre_norm): GroupRMSNorm()
      (mlp): DenseMLPWithLoRA()
    )
  )
  (kv_cache): TransformerDecoderKVCache()
  (final_norm): GroupRMSNorm()
  (lm_head): Linear(in_features=8, out_features=10, bias=False)
)


(torch.Size([1, 8, 10]),
 tensor([[[-1.2925,  2.5357, -0.7369, -0.7364, -0.5004,  0.9618, -0.5549,
            1.1369, -0.5655,  0.4905],
          [-3.4057,  1.3570, -1.5379, -0.7549,  1.6519,  1.2669,  0.6092,
            1.7056, -1.2644,  0.6840],
          [-1.7036,  3.1648, -0.5522,  0.9891, -0.1064,  1.8245, -0.3872,
            2.3809,  0.6620,  1.5979],
          [ 3.0575,  2.4082, -1.1623, -0.3622, -1.9437,  1.4024, -0.7102,
           -1.4185, -0.5156,  1.2056],
          [-3.4312,  1.0153, -1.4652, -1.0645,  1.6408,  1.0096,  0.6544,
            1.4626, -1.4687,  0.4301],
          [-3.1556,  0.6621, -1.4333, -1.3881,  1.5312,  0.7911,  0.6917,
            1.0304, -1.7093,  0.1825],
          [-5.0665,  0.8959, -2.1492,  0.0602,  3.3647,  2.4823,  1.7351,
            2.9604, -1.4029,  1.7156],
          [ 1.0452,  3.8170, -1.2443,  1.9853, -0.2522,  2.6731, -0.7634,
            1.0764,  0.5081,  1.9799]]], grad_fn=<UnsafeViewBackward0>))

##### task23 - case2

In [7]:
SEED = 42
PARAM_DEVICE = "cpu"
PARAM_DTYPE = torch.float32

b, s, h, v = 1, 1, 8, 10  # batch, new sequence length, hidden size, vocab size
layout = AttnQKVLayout.SBHD
pack_format = AttnQKVPackFormat.Q_KV

past_seqlen_kv = 5  # tokens pre-filled into the kv cache (0 -> no cache)

# `h`, `v`, `layout` and `pack_format` are wired into the config so the model
# can never disagree with the inputs / kv cache built below.
config = TransformerConfig(
    num_layers=2,
    hidden_size=h,
    ffh_size=16,
    max_seq_len=16,
    param_dtype=PARAM_DTYPE,
    param_device=PARAM_DEVICE,
    init_base_seed=SEED,

    vocab_size=v,
    vocab_init_mean=0.1,
    vocab_init_std=1.1,

    rope_base=10000,
    rope_ratio=1,
    rope_dynamic=False,

    group_size=None,
    eps=1e-5,
    norm_init_range=(-1.1, 1.1),

    proj_init_seed=SEED,
    proj_init_mean=0.1,
    proj_init_std=1.1,
    lm_head_tied=False,

    online_attn_block_size=None,
    head_dim=4,
    num_q_head=2,
    num_kv_head=1,
    qkv_pack_format=pack_format,
    qkv_layout=layout,
    window_size=None,
    causal=True,
    softmax_dropout_rate=0.,
    softmax_dropout_seed=SEED,
    softmax_scale=None,
    softmax_cap=None,
    softmax_temp=1.,
    softmax_clip_range=(0., 1.),
    apply_qk_norm=False,
    qk_norm_group_size=None,

    activation_type=MLPActivationType.SILU,
    lora_rank=0,
    lora_alpha=None,
    lora_dropout_rate=0.,
    lora_dropout_seed=SEED,
    lora_init_base_seed=SEED,

    num_experts=None,
    moe_topk=1,
    gate_init_mean=0.,
    gate_init_std=1.,
)

torch.manual_seed(42)
input = torch.randn(b, s, h, dtype=torch.bfloat16, device="cpu")
input_ids = torch.randint(0, v, (b, s), dtype=torch.int32, device="cpu")

kv_cache = None
if past_seqlen_kv > 0:
    kv_cache = TransformerDecoderKVCache(
        qkv_layout=layout,
        num_layers=config.num_layers,
    )
    for layer_idx in range(config.num_layers):
        # one seed per layer so each layer gets a distinct past k/v
        torch.manual_seed(42 + layer_idx)
        past_k = torch.randn(b, past_seqlen_kv, config.num_kv_head, config.head_dim, dtype=config.param_dtype, device=config.param_device)
        past_v = torch.randn_like(past_k)
        past_cu_seqlens = None

        if layout is AttnQKVLayout.SBHD:
            past_k, past_v = [x.transpose(0, 1) for x in (past_k, past_v)]
        elif layout is AttnQKVLayout.THD:
            # the fixed boundaries 5 and 9 only make sense if the cache holds >= 9 tokens
            assert past_seqlen_kv >= 9, (
                f"past_seqlen_kv ({past_seqlen_kv}) must be >= 9 for cu_seqlens [0, 5, 9, past_seqlen_kv]"
            )
            past_k, past_v = [x.squeeze(0) for x in (past_k, past_v)]
            past_cu_seqlens = torch.tensor([0, 5, 9, past_seqlen_kv]).long()

        kv_cache.set(layer_idx, past_k, past_v, cu_seqlens=past_cu_seqlens)

if layout is AttnQKVLayout.THD:
    # with a cache: three sequences of 1 new token each;
    # without: three packed sequences of lengths 5, 4 and s - 9
    cu_seqlens = torch.tensor([0, 1, 2, 3] if past_seqlen_kv > 0 else [0, 5, 9, s]).long()
    assert int(cu_seqlens[-1]) == s and bool((cu_seqlens.diff() >= 0).all()), (
        f"cu_seqlens {cu_seqlens.tolist()} is not a valid packing of s={s} tokens"
    )
else:
    cu_seqlens = None

print(input, input_ids, cu_seqlens, kv_cache.get(0) if kv_cache is not None else None)

tensor([[[ 0.3359,  0.1289,  0.2344,  0.2305, -1.1250, -0.1865,  2.2031,
          -0.6367]]], dtype=torch.bfloat16) tensor([[5]], dtype=torch.int32) None (tensor([[[[ 1.9269,  1.4873,  0.9007, -2.1055]]],


        [[[-0.7581,  1.0783,  0.8008,  1.6806]]],


        [[[ 0.3559, -0.6866, -0.4934,  0.2415]]],


        [[[-0.2316,  0.0418, -0.2516,  0.8599]]],


        [[[-0.3097, -0.3957,  0.8034, -0.6216]]]]), tensor([[[[ 0.3189, -0.4245,  0.3057, -0.7746]]],


        [[[-0.8371, -0.9224,  1.8113,  0.1606]]],


        [[[ 0.3672,  0.1754,  1.3852, -0.4459]]],


        [[[-1.2024,  0.7078, -1.0759,  0.5357]]],


        [[[ 1.1754,  0.5612, -0.4527, -0.7718]]]]), None)


In [8]:
layer = TransformerDecoderLayer(config)
print(layer)

# Run on copies: the layer receives a cache it may update, and `kv_cache` /
# `cu_seqlens` are reused by the block cell below. Use an explicit None check
# rather than the truthiness of the cache object.
output_layer = layer(
    input.clone(),
    cu_seqlens=cu_seqlens.clone() if cu_seqlens is not None else None,
    kv_cache=deepcopy(kv_cache) if kv_cache is not None else None,
)
output_layer.shape, output_layer

TransformerDecoderLayer(
  (attn_pre_norm): GroupRMSNorm()
  (rope): NTKAwareRoPE()
  (attn): OfflineSlidingWindowAttn(
    (softmax_dropout): Dropout(p=0.0, inplace=False)
  )
  (mlp_pre_norm): GroupRMSNorm()
  (mlp): DenseMLPWithLoRA()
)


(torch.Size([1, 1, 8]),
 tensor([[[ 2.1250, -6.6875,  3.5625, -2.6719,  3.8750, -7.8438,  2.1875,
           -0.7852]]], dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>))

In [9]:
block = TransformerDecoderBlock(config)
# train mode only when online attention is configured, eval mode otherwise
if config.online_attn_block_size is None:
    block = block.eval()
else:
    block = block.train()
print(block.training)

if kv_cache is not None:
    # hand the block a private copy so the shared `kv_cache` is left untouched
    block.set_kv_cache(kv_cache=deepcopy(kv_cache))
print(block)

output_block = block(
    input_ids.clone(),
    cu_seqlens=None if cu_seqlens is None else cu_seqlens.clone(),
)
output_block.shape, output_block

False
TransformerDecoderBlock(
  (vocab_emb): ParallelVocabEmbedding()
  (layers): ModuleList(
    (0-1): 2 x TransformerDecoderLayer(
      (attn_pre_norm): GroupRMSNorm()
      (rope): NTKAwareRoPE()
      (attn): OfflineSlidingWindowAttn(
        (softmax_dropout): Dropout(p=0.0, inplace=False)
      )
      (mlp_pre_norm): GroupRMSNorm()
      (mlp): DenseMLPWithLoRA()
    )
  )
  (kv_cache): TransformerDecoderKVCache()
  (final_norm): GroupRMSNorm()
  (lm_head): Linear(in_features=8, out_features=10, bias=False)
)


(torch.Size([1, 1, 10]),
 tensor([[[-0.0713, -0.7954,  1.0491, -4.1084, -1.7625, -1.1286, -0.6247,
           -1.1881, -2.9059, -3.9017]]], grad_fn=<UnsafeViewBackward0>))

##### task23 - case3

In [10]:
SEED = 42
PARAM_DEVICE = "cpu"
PARAM_DTYPE = torch.float32

b, s, h, v = 1, 3, 8, 10  # batch, new sequence length, hidden size, vocab size
layout = AttnQKVLayout.THD
pack_format = AttnQKVPackFormat.Q_K_V

past_seqlen_kv = 12  # tokens pre-filled into the kv cache (0 -> no cache)

# `h`, `v`, `layout` and `pack_format` are wired into the config so the model
# can never disagree with the inputs / kv cache built below.
config = TransformerConfig(
    num_layers=3,
    hidden_size=h,
    ffh_size=16,
    max_seq_len=16,
    param_dtype=PARAM_DTYPE,
    param_device=PARAM_DEVICE,
    init_base_seed=SEED,

    vocab_size=v,
    vocab_init_mean=0.1,
    vocab_init_std=1.1,

    rope_base=10000,
    rope_ratio=1,
    rope_dynamic=False,

    group_size=None,
    eps=1e-5,
    norm_init_range=(-1.1, 1.1),

    proj_init_seed=SEED,
    proj_init_mean=0.1,
    proj_init_std=1.1,
    lm_head_tied=False,

    online_attn_block_size=None,
    head_dim=4,
    num_q_head=2,
    num_kv_head=1,
    qkv_pack_format=pack_format,
    qkv_layout=layout,
    window_size=None,
    causal=True,
    softmax_dropout_rate=0.,
    softmax_dropout_seed=SEED,
    softmax_scale=None,
    softmax_cap=None,
    softmax_temp=1.,
    softmax_clip_range=(0., 1.),
    apply_qk_norm=False,
    qk_norm_group_size=None,

    activation_type=MLPActivationType.SILU,
    lora_rank=0,
    lora_alpha=None,
    lora_dropout_rate=0.,
    lora_dropout_seed=SEED,
    lora_init_base_seed=SEED,

    num_experts=None,
    moe_topk=1,
    gate_init_mean=0.,
    gate_init_std=1.,
)

torch.manual_seed(42)
input = torch.randn(b, s, h, dtype=torch.bfloat16, device="cpu")
input_ids = torch.randint(0, v, (b, s), dtype=torch.int32, device="cpu")

kv_cache = None
if past_seqlen_kv > 0:
    kv_cache = TransformerDecoderKVCache(
        qkv_layout=layout,
        num_layers=config.num_layers,
    )
    for layer_idx in range(config.num_layers):
        # one seed per layer so each layer gets a distinct past k/v
        torch.manual_seed(42 + layer_idx)
        past_k = torch.randn(b, past_seqlen_kv, config.num_kv_head, config.head_dim, dtype=config.param_dtype, device=config.param_device)
        past_v = torch.randn_like(past_k)
        past_cu_seqlens = None

        if layout is AttnQKVLayout.SBHD:
            past_k, past_v = [x.transpose(0, 1) for x in (past_k, past_v)]
        elif layout is AttnQKVLayout.THD:
            # the fixed boundaries 5 and 9 only make sense if the cache holds >= 9 tokens
            assert past_seqlen_kv >= 9, (
                f"past_seqlen_kv ({past_seqlen_kv}) must be >= 9 for cu_seqlens [0, 5, 9, past_seqlen_kv]"
            )
            past_k, past_v = [x.squeeze(0) for x in (past_k, past_v)]
            past_cu_seqlens = torch.tensor([0, 5, 9, past_seqlen_kv]).long()

        kv_cache.set(layer_idx, past_k, past_v, cu_seqlens=past_cu_seqlens)

if layout is AttnQKVLayout.THD:
    # with a cache: three sequences of 1 new token each;
    # without: three packed sequences of lengths 5, 4 and s - 9
    cu_seqlens = torch.tensor([0, 1, 2, 3] if past_seqlen_kv > 0 else [0, 5, 9, s]).long()
    assert int(cu_seqlens[-1]) == s and bool((cu_seqlens.diff() >= 0).all()), (
        f"cu_seqlens {cu_seqlens.tolist()} is not a valid packing of s={s} tokens"
    )
else:
    cu_seqlens = None

print(input, cu_seqlens, kv_cache.get(0) if kv_cache is not None else None)

tensor([[[-0.8086, -1.5312,  0.4062,  0.1719, -0.2471,  0.2041, -0.8789,
          -0.3867],
         [ 0.5664,  0.2363,  0.4863,  1.1719,  1.4531, -0.8906,  0.1543,
           0.8242],
         [-2.1719,  1.3516,  0.2754, -0.1128, -0.7969,  1.3438,  0.3750,
          -1.1328]]], dtype=torch.bfloat16) tensor([0, 1, 2, 3]) (tensor([[[ 1.9269,  1.4873,  0.9007, -2.1055]],

        [[ 0.6784, -1.2345, -0.0431, -1.6047]],

        [[-0.7521,  1.6487, -0.3925, -1.4036]],

        [[-0.7279, -0.5594, -0.7688,  0.7624]],

        [[ 1.6423, -0.1596, -0.4974,  0.4396]],

        [[-0.7581,  1.0783,  0.8008,  1.6806]],

        [[ 1.2791,  1.2964,  0.6105,  1.3347]],

        [[-0.2316,  0.0418, -0.2516,  0.8599]],

        [[-1.3847, -0.8712, -0.2234,  1.7174]],

        [[ 0.3189, -0.4245,  0.3057, -0.7746]],

        [[-1.5576,  0.9956, -0.8798, -0.6011]],

        [[-1.2742,  2.1228, -1.2347, -0.4879]]]), tensor([[[-0.9138, -0.6581,  0.0780,  0.5258]],

        [[-0.4880,  1.1914, -0.8140, 

In [11]:
layer = TransformerDecoderLayer(config)
print(layer)

# Run on copies: the layer receives a cache it may update, and `kv_cache` /
# `cu_seqlens` are reused by the block cell below. Use an explicit None check
# rather than the truthiness of the cache object.
output_layer = layer(
    input.clone(),
    cu_seqlens=cu_seqlens.clone() if cu_seqlens is not None else None,
    kv_cache=deepcopy(kv_cache) if kv_cache is not None else None,
)
output_layer.shape, output_layer

TransformerDecoderLayer(
  (attn_pre_norm): GroupRMSNorm()
  (rope): NTKAwareRoPE()
  (attn): OfflineSlidingWindowAttn(
    (softmax_dropout): Dropout(p=0.0, inplace=False)
  )
  (mlp_pre_norm): GroupRMSNorm()
  (mlp): DenseMLPWithLoRA()
)


(torch.Size([1, 3, 8]),
 tensor([[[-1.8906, -9.5625, -4.8125, -1.9688,  4.5312, -2.0938, -4.7812,
            0.2119],
          [ 0.5859,  5.4062, -0.5000,  3.5938, -1.8828,  1.2734,  1.0469,
           -3.1250],
          [-6.5000, -8.9375,  4.3750, -0.1699, 11.8125, -7.8750,  0.5078,
            2.1250]]], dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>))

In [12]:
block = TransformerDecoderBlock(config)
# train mode only when online attention is configured, eval mode otherwise
if config.online_attn_block_size is None:
    block = block.eval()
else:
    block = block.train()
print(block.training)

if kv_cache is not None:
    # hand the block a private copy so the shared `kv_cache` is left untouched
    block.set_kv_cache(kv_cache=deepcopy(kv_cache))
print(block)

output_block = block(
    input_ids.clone(),
    cu_seqlens=None if cu_seqlens is None else cu_seqlens.clone(),
)
output_block.shape, output_block

False
TransformerDecoderBlock(
  (vocab_emb): ParallelVocabEmbedding()
  (layers): ModuleList(
    (0-2): 3 x TransformerDecoderLayer(
      (attn_pre_norm): GroupRMSNorm()
      (rope): NTKAwareRoPE()
      (attn): OfflineSlidingWindowAttn(
        (softmax_dropout): Dropout(p=0.0, inplace=False)
      )
      (mlp_pre_norm): GroupRMSNorm()
      (mlp): DenseMLPWithLoRA()
    )
  )
  (kv_cache): TransformerDecoderKVCache()
  (final_norm): GroupRMSNorm()
  (lm_head): Linear(in_features=8, out_features=10, bias=False)
)


(torch.Size([1, 3, 10]),
 tensor([[[-5.7833e+00, -4.1559e+00, -7.2531e-01, -1.2458e+00,  5.3105e+00,
           -1.1368e+00,  2.0874e+00,  3.9418e-01, -3.1351e+00, -1.5997e+00],
          [ 5.1554e-01, -2.8597e+00,  7.3401e-01,  7.7610e-01,  5.3917e-01,
           -1.0834e+00,  1.0781e+00, -7.7173e-01,  1.2294e+00,  1.4850e-01],
          [ 8.8537e-01,  1.7028e+00, -5.3950e-03,  4.7097e-01, -1.2100e+00,
            1.1996e+00, -4.4219e-01,  2.1394e-01, -2.2099e-02,  1.5710e+00]]],
        grad_fn=<UnsafeViewBackward0>))

#### Transformer decoder and KV cache

In [13]:
l, v = 2, 10                # number of layers, vocab size
b, s, h, ffh = 1, 6, 8, 16  # batch, sequence length, hidden size, feed-forward hidden size
hd = 4                      # head dim
online_block_size = None # None to use offline attn

layout = AttnQKVLayout.SBHD
pack_format = AttnQKVPackFormat.QKV

config = TransformerConfig(
    num_layers=l,
    hidden_size=h,
    ffh_size=ffh,
    max_seq_len=s,
    vocab_size=v,
    head_dim=hd,
    num_q_head=h//hd,
    num_kv_head=1,
    qkv_layout=layout,
    qkv_pack_format=pack_format,
    online_attn_block_size=online_block_size,
)

if layout is AttnQKVLayout.THD:
    # three packed sequences of lengths 5, 4 and s - 9: only valid when s >= 9
    cu_seqlens = torch.tensor([0, 5, 9, s])
    assert bool((cu_seqlens.diff() >= 0).all()), (
        f"cu_seqlens {cu_seqlens.tolist()} is not non-decreasing; the THD split needs s >= 9 (got s={s})"
    )
else:
    cu_seqlens = None

torch.manual_seed(42)
input = torch.randn(b, s, h)
input_ids = torch.randint(0, v, (b, s))
labels = torch.randint(0, v, (b, s))
input, input_ids, labels

(tensor([[[ 1.9269,  1.4873,  0.9007, -2.1055,  0.6784, -1.2345, -0.0431,
           -1.6047],
          [-0.7521,  1.6487, -0.3925, -1.4036, -0.7279, -0.5594, -0.7688,
            0.7624],
          [ 1.6423, -0.1596, -0.4974,  0.4396, -0.7581,  1.0783,  0.8008,
            1.6806],
          [ 1.2791,  1.2964,  0.6105,  1.3347, -0.2316,  0.0418, -0.2516,
            0.8599],
          [-1.3847, -0.8712, -0.2234,  1.7174,  0.3189, -0.4245,  0.3057,
           -0.7746],
          [-1.5576,  0.9956, -0.8798, -0.6011, -1.2742,  2.1228, -1.2347,
           -0.4879]]]),
 tensor([[3, 7, 0, 9, 0, 9]]),
 tensor([[6, 9, 5, 4, 8, 8]]))

In [14]:
print(layer := TransformerDecoderLayer(config))

# one decoder layer applied to the float `input` (no kv cache passed)
output = layer(input, cu_seqlens=cu_seqlens)
(output.shape, output)

TransformerDecoderLayer(
  (attn_pre_norm): GroupRMSNorm()
  (rope): NTKAwareRoPE()
  (attn): OfflineSlidingWindowAttn(
    (softmax_dropout): Dropout(p=0.0, inplace=False)
  )
  (mlp_pre_norm): GroupRMSNorm()
  (mlp): DenseMLPWithLoRA()
)


(torch.Size([1, 6, 8]),
 tensor([[[ 4.1074,  0.3974, -2.1693, -2.3000,  3.7595,  2.1975, -0.6886,
           -5.9105],
          [-0.4448, -1.6368, -4.2078, -0.8758,  1.5075,  1.8875, -3.5614,
           -0.5663],
          [ 1.1051, -0.8874,  0.0672,  0.7068, -2.6837,  2.0918,  0.9942,
            3.3176],
          [ 1.6730,  0.2681,  0.7090,  1.6265, -1.8539, -0.6725, -3.3719,
            2.6219],
          [-2.3194,  1.2133,  1.3484,  2.6591,  0.3236, -0.9400,  1.0793,
           -0.3700],
          [-1.9524, -3.0155, -3.0551, -2.8790, -3.2100,  2.8289, -4.8605,
            0.1226]]], grad_fn=<AddBackward0>))

In [15]:
print(block := TransformerDecoderBlock(config))

# full decoder block on token ids; the recorded run yields vocab-sized outputs
# of shape (b, s, v)
output = block(input_ids, cu_seqlens=cu_seqlens)
(output.shape, output)

TransformerDecoderBlock(
  (vocab_emb): ParallelVocabEmbedding()
  (layers): ModuleList(
    (0-1): 2 x TransformerDecoderLayer(
      (attn_pre_norm): GroupRMSNorm()
      (rope): NTKAwareRoPE()
      (attn): OfflineSlidingWindowAttn(
        (softmax_dropout): Dropout(p=0.0, inplace=False)
      )
      (mlp_pre_norm): GroupRMSNorm()
      (mlp): DenseMLPWithLoRA()
    )
  )
  (kv_cache): TransformerDecoderKVCache()
  (final_norm): GroupRMSNorm()
  (lm_head): Linear(in_features=8, out_features=10, bias=False)
)


(torch.Size([1, 6, 10]),
 tensor([[[-3.9713,  0.0218,  0.5872,  2.9859,  2.2806,  0.0770,  0.2449,
            2.5042,  1.5278,  1.5588],
          [-7.4429, -1.5518,  0.5519,  0.6168,  4.0759, -0.4247,  1.1969,
            3.2529, -0.6715, -0.7215],
          [ 5.8300,  2.5819,  0.0981, -1.9704, -4.6511,  0.2711, -1.7529,
           -2.3114,  0.1161, -1.5156],
          [ 5.1277,  3.0218,  0.2595, -1.5033, -4.5942,  0.2958, -2.0625,
           -1.4238,  0.5330, -1.2860],
          [ 5.4769,  1.9154, -0.4239, -1.9101, -3.6879,  0.3206, -1.1926,
           -2.5624, -0.2108, -1.1517],
          [ 5.1345,  2.0056,  0.1201, -1.4434, -4.0380, -0.0994, -1.7106,
           -1.7382,  0.4382, -0.8064]]], grad_fn=<UnsafeViewBackward0>))

In [16]:
# KV-cache API walk-through in BSHD layout: only layer 1 gets populated.
# NOTE(review): the tensors below are unseeded, so their values depend on the
# global RNG state left by earlier cells. They are also 3-D, whereas BSHD
# inputs built by construct_transformer_decoder_kvcache_args are 4-D
# (b, s, nh, hd). Confirm this is intended.
kv_cache = TransformerDecoderKVCache(qkv_layout=AttnQKVLayout.BSHD, num_layers=3)
kv_cache.set(1, torch.randn(1, 5, 4), torch.randn(1, 5, 4))

print(*(kv_cache.has(idx) for idx in range(3)))

# layer 0 is empty and layer 3 is out of range: both are expected to fail,
# so show the error message instead of raising
try:
    print(kv_cache.get(0))
except Exception as e:
    print(e)
print(kv_cache.get(1))
try:
    print(kv_cache.get(3))
except Exception as e:
    print(e)

# grow layer 1 by one more token
kv_cache.append(1, torch.randn(1, 1, 4), torch.randn(1, 1, 4))
kv_cache.get(1)

False True False
'Cache for layer 0 does not exist.'
(tensor([[[-0.0166, -0.4668,  2.0909,  0.6149],
         [ 0.3083, -0.2947, -0.7662, -0.9962],
         [-0.2345, -0.5367,  1.1296,  0.1054],
         [-0.3630,  1.5822, -0.4430,  1.8462],
         [-0.4772, -1.8291, -0.6145,  1.0282]]]), tensor([[[-0.7771, -0.4484, -1.1668,  0.5006],
         [ 0.0139,  0.6564,  0.4846, -0.2549],
         [-0.7603, -1.6943, -0.2596,  0.8847],
         [-0.8256,  0.7988, -0.3005, -0.3062],
         [ 0.4163, -0.5947, -0.2367, -1.8343]]]), None)
Layer index must be less than 3 and greater than or equal to 0, but got 3.


(tensor([[[-0.0166, -0.4668,  2.0909,  0.6149],
          [ 0.3083, -0.2947, -0.7662, -0.9962],
          [-0.2345, -0.5367,  1.1296,  0.1054],
          [-0.3630,  1.5822, -0.4430,  1.8462],
          [-0.4772, -1.8291, -0.6145,  1.0282],
          [ 1.6024, -0.0655,  1.1773,  0.2308]]]),
 tensor([[[-0.7771, -0.4484, -1.1668,  0.5006],
          [ 0.0139,  0.6564,  0.4846, -0.2549],
          [-0.7603, -1.6943, -0.2596,  0.8847],
          [-0.8256,  0.7988, -0.3005, -0.3062],
          [ 0.4163, -0.5947, -0.2367, -1.8343],
          [ 0.6967, -1.3385, -1.3070, -0.4712]]]),
 None)

In [17]:
# KV-cache API walk-through in packed THD layout: only layer 1 gets populated,
# with three sequences of lengths 2, 1 and 4 (cu_seqlens [0, 2, 3, 7]).
# NOTE(review): the tensors below are unseeded, so their values depend on the
# global RNG state left by earlier cells. They are also 2-D, whereas THD
# inputs elsewhere in this notebook are 3-D (t, nh, hd). Confirm this is
# intended.
kv_cache = TransformerDecoderKVCache(qkv_layout=AttnQKVLayout.THD, num_layers=3)
kv_cache.set(1, torch.randn(7, 4), torch.randn(7, 4), cu_seqlens=torch.tensor([0, 2, 3, 7]))

print(*(kv_cache.has(idx) for idx in range(3)))

# layer 0 is empty and layer 3 is out of range: both are expected to fail,
# so show the error message instead of raising
try:
    print(kv_cache.get(0))
except Exception as e:
    print(e)
print(kv_cache.get(1))
try:
    print(kv_cache.get(3))
except Exception as e:
    print(e)

# append one new token to each of the three sequences
kv_cache.append(1, torch.randn(3, 4), torch.randn(3, 4), cu_seqlens=torch.tensor([0, 1, 2, 3]))
kv_cache.get(1)

False True False
'Cache for layer 0 does not exist.'
(tensor([[-2.5979, -0.0450, -0.7643, -0.4835],
        [-0.7352,  0.6240,  2.3729,  1.3071],
        [-0.4575, -0.2238,  0.9910,  1.3958],
        [-0.0791, -0.2089, -0.3442,  1.8142],
        [ 1.5057,  0.9148, -0.8651, -1.2858],
        [-0.3012,  0.3881,  0.6736, -2.2534],
        [-0.7496,  1.4628, -2.7743,  1.0224]]), tensor([[-2.7623,  0.8921,  0.3671,  0.5021],
        [-0.9558, -2.3803, -0.5387,  1.0196],
        [-0.9873, -0.6070, -0.9646,  0.9868],
        [ 0.1422,  0.2733, -0.4433, -0.5858],
        [-0.5670,  1.2877,  0.2881, -0.0710],
        [ 0.6117,  0.5042, -0.3404, -1.3312],
        [ 1.3698,  0.3584, -0.5757, -0.2979]]), tensor([0, 2, 3, 7]))
Layer index must be less than 3 and greater than or equal to 0, but got 3.


(tensor([[-2.5979, -0.0450, -0.7643, -0.4835],
         [-0.7352,  0.6240,  2.3729,  1.3071],
         [-1.8024,  1.1348,  0.9362,  1.0182],
         [-0.4575, -0.2238,  0.9910,  1.3958],
         [-0.7009,  0.2459, -0.1850,  0.6825],
         [-0.0791, -0.2089, -0.3442,  1.8142],
         [ 1.5057,  0.9148, -0.8651, -1.2858],
         [-0.3012,  0.3881,  0.6736, -2.2534],
         [-0.7496,  1.4628, -2.7743,  1.0224],
         [-0.3955,  0.8117,  0.9530, -0.6530]]),
 tensor([[-2.7623,  0.8921,  0.3671,  0.5021],
         [-0.9558, -2.3803, -0.5387,  1.0196],
         [-0.8325, -1.4271, -0.6519, -0.2329],
         [-0.9873, -0.6070, -0.9646,  0.9868],
         [ 0.4037,  0.8213, -0.1352,  0.3902],
         [ 0.1422,  0.2733, -0.4433, -0.5858],
         [-0.5670,  1.2877,  0.2881, -0.0710],
         [ 0.6117,  0.5042, -0.3404, -1.3312],
         [ 1.3698,  0.3584, -0.5757, -0.2979],
         [ 0.2317, -0.9516,  0.9047,  0.2141]]),
 tensor([ 0,  3,  5, 10]))