# Test the model block (aka time+channel mix) code block

Using the reference model, we load the first model layer block and test/validate its forward pass for compilation issues.

In [None]:
# Configure the parent path to be the proj folder
import sys, os, torch, time
sys.path.append('../../')

# # Cuda debugging
# os.environ["TORCH_USE_CUDA_DSA"] = "1"
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Import the block classes
from rwkv_block.v7_qwrky.block.qwrky7_layer_block import Qwrky7LayerBlock

# Reference checkpoint to load (expected under ./.model/)
MODEL_FILENAME="qwrky7-7B.pth"

# Run dtype / tmix backend, and the run device:
# the first CUDA device is used when one is available, otherwise the CPU
RUN_DTYPE=torch.bfloat16
RUN_TMIX_BACKEND="fla"
RUN_DEVICE="cuda:0" if torch.cuda.is_available() else "cpu"

# Stop early if the reference weights have not been downloaded yet
weights_path = f"./.model/{MODEL_FILENAME}"
assert os.path.exists(weights_path), "The reference weights does not exist. Please download it first (00-model-download.ipynb)"

# Load the checkpoint onto the CPU; it is only moved to RUN_DEVICE by the block later
model_weight = torch.load(weights_path, map_location='cpu', weights_only=True, mmap=True)

# Report which checkpoint is in use
print(f"### Model filename: {MODEL_FILENAME}")

# Derive the model dimensions directly from the weight shapes
head_size = model_weight['model.layers.0.self_attn.r_k'].shape[1]
hidden_size = model_weight['model.embed_tokens.weight'].shape[1]
hidden_size_att = model_weight['model.layers.0.self_attn.v_proj.weight'].shape[0]
print(f"### Model hidden_size: {hidden_size}")

# Dump every weight entry with its shape and dtype
print("### model weights keys:")
for key, tensor in model_weight.items():
    print(f"{key}: {tensor.shape} - {tensor.dtype}")

In [None]:
# Dummy inputs and recurrent states used to exercise the layer block
#
# NOTE: The triton kernel minimum chunk size is 16; it falls back to pytorch mode otherwise.
# The token count is intentionally NOT a multiple of 16, so the remainder pytorch code
# path is also exercised when the triton kernel is in use.
IN_TOKENS_LEN=9000

# Shared tensor construction settings and shapes
_ones_kwargs = {"device": RUN_DEVICE, "dtype": RUN_DTYPE}
_x_shape = (1, IN_TOKENS_LEN, hidden_size)
_shift_shape = (1, hidden_size)
_wkv_shape = (1, hidden_size // head_size, head_size, head_size)

x_state_0 = torch.ones(*_x_shape, **_ones_kwargs)
x_state_1 = torch.ones(*_x_shape, **_ones_kwargs)
x_state_2 = torch.ones(*_x_shape, **_ones_kwargs)
tmix_shift_0 = torch.ones(*_shift_shape, **_ones_kwargs)
tmix_shift_1 = torch.ones(*_shift_shape, **_ones_kwargs)
tmix_wkv_0 = torch.ones(*_wkv_shape, **_ones_kwargs)
tmix_wkv_1 = torch.ones(*_wkv_shape, **_ones_kwargs)

# Number of forward passes per timing loop (kept small on CPU)
TEST_STEPS = 5 if RUN_DEVICE == "cpu" else 1000

# Build the layer block for layer 0 and load its weights from the reference checkpoint
block_config = {
    "num_hidden_layers":27,
    "head_size":head_size,
    "hidden_size":hidden_size,
    "hidden_size_att":hidden_size_att,
    "layer_id":0,
    "device":RUN_DEVICE, "dtype":RUN_DTYPE, "tmix_backend":RUN_TMIX_BACKEND
}
block = Qwrky7LayerBlock(block_config)
block.load_from_model_state_dict(model_weight, 0)

# Print every entry of the block state dict with its shape and dtype
block_state = block.state_dict()
print("### block state keys:")
for key, param in block_state.items():
    print(f"block.{key}: {param.shape} - {param.dtype}")
print("----")

In [None]:
### Block - eager (non compiled) forward pass benchmark
def _sync_device():
    # CUDA kernels are launched asynchronously: wait for all queued work so the
    # timers below measure the actual execution, not just the launch overhead.
    if RUN_DEVICE != "cpu":
        torch.cuda.synchronize()

with torch.inference_mode():
    # Input
    block_state_1 = tmix_wkv_1

    # This is a warmup
    out_x = x_state_0
    block_state_0 = tmix_wkv_0
    v_first = x_state_2
    _sync_device()
    t0 = time.perf_counter()
    for i in range(TEST_STEPS):
        out_x, block_state_0, v_first = block(x_state_1, block_state_1, v_first)
    _sync_device()
    t2 = time.perf_counter()
    print(f'1 block forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms')

    # The actual run (inputs are reset, so it starts from the same state as the warmup)
    out_x = x_state_0
    block_state_0 = tmix_wkv_0
    v_first = x_state_2
    _sync_device()
    t1 = time.perf_counter()
    for i in range(TEST_STEPS):
        out_x, block_state_0, v_first = block(x_state_1, block_state_1, v_first)
    _sync_device()
    t2 = time.perf_counter()
    print(f'1 block forward passes (normal): {(t2-t1)*1000/TEST_STEPS} ms')


In [None]:
### Block - forward_with_default_compile benchmark
def _sync_device():
    # CUDA kernels are launched asynchronously: wait for all queued work so the
    # timers below measure the actual execution, not just the launch overhead.
    if RUN_DEVICE != "cpu":
        torch.cuda.synchronize()

with torch.inference_mode():
    # Input
    block_state_1 = tmix_wkv_1

    # This is a warmup (the first calls also include any compilation time)
    # NOTE(review): v_first is passed both as the 3rd and the 6th argument -
    # confirm this matches the forward_with_default_compile signature.
    out_x = x_state_0
    block_state_0 = tmix_wkv_0
    v_first = x_state_2
    _sync_device()
    t0 = time.perf_counter()
    for i in range(TEST_STEPS):
        out_x, block_state_0, v_first = block.forward_with_default_compile(x_state_1, block_state_1, v_first, out_x, block_state_0, v_first)
    _sync_device()
    t2 = time.perf_counter()
    print(f'1 block forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms')

    # The actual run (inputs are reset, so it starts from the same state as the warmup)
    out_x = x_state_0
    block_state_0 = tmix_wkv_0
    v_first = x_state_2
    _sync_device()
    t1 = time.perf_counter()
    for i in range(TEST_STEPS):
        out_x, block_state_0, v_first = block.forward_with_default_compile(x_state_1, block_state_1, v_first, out_x, block_state_0, v_first)
    _sync_device()
    t2 = time.perf_counter()
    print(f'1 block forward passes (compiled): {(t2-t1)*1000/TEST_STEPS} ms')


In [None]:
# Iteration to test (kept small on CPU, where each forward pass is slow)
TEST_STEPS = 5 if RUN_DEVICE == "cpu" else 1000

### Block - forward_with_reduce_compile benchmark
def _sync_device():
    # CUDA kernels are launched asynchronously: wait for all queued work so the
    # timers below measure the actual execution, not just the launch overhead.
    if RUN_DEVICE != "cpu":
        torch.cuda.synchronize()

with torch.inference_mode():
    # Input
    block_state_1 = tmix_wkv_1

    # This is a warmup (the first calls also include any compilation time)
    out_x = x_state_0
    block_state_0 = tmix_wkv_0
    # Initialise v_first here so the cell does not depend on a previous cell's leftovers
    v_first = x_state_2
    _sync_device()
    t0 = time.perf_counter()
    for i in range(TEST_STEPS):
        out_x, block_state_0, v_first = block.forward_with_reduce_compile(x_state_1, block_state_1, v_first)
    _sync_device()
    t2 = time.perf_counter()
    print(f'1 block forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms')

    # The actual run (inputs are reset, so it starts from the same state as the warmup)
    out_x = x_state_0
    block_state_0 = tmix_wkv_0
    v_first = x_state_2
    _sync_device()
    t1 = time.perf_counter()
    for i in range(TEST_STEPS):
        out_x, block_state_0, v_first = block.forward_with_reduce_compile(x_state_1, block_state_1, v_first)
    _sync_device()
    t2 = time.perf_counter()
    print(f'1 block forward passes (compiled): {(t2-t1)*1000/TEST_STEPS} ms')
