# Test the time mix code block

Using the reference model, we load the first time mix block, test/validate the forward pass for compilation issues

In [1]:
# Configure the parent path to be the proj folder
import sys, os, torch, time
sys.path.append('../../')

# # Cuda debugging
# os.environ["TORCH_USE_CUDA_DSA"] = "1"
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Import the block classes
from rwkv_block.v7_qwerky.block.qwerky7_time_mix import Qwerky7TimeMix

# File to load
MODEL_FILENAME="qwerky7-7B.pth"

# Run device, and run dtype to use
RUN_DEVICE="cpu"
RUN_DTYPE=torch.bfloat16
RUN_TMIX_BACKEND="auto"

# Check for cuda device
if torch.cuda.is_available():
    RUN_DEVICE="cuda:0"

# Check if the reference weights exists
assert os.path.exists(f"./.model/{MODEL_FILENAME}"), "The reference weights does not exist. Please download it first (00-model-download.ipynb)"

# Loads the model weights
model_weight = torch.load(f"./.model/{MODEL_FILENAME}", map_location='cpu', weights_only=True, mmap=True)

# Model filename
print(f"### Model filename: {MODEL_FILENAME}")

# Lets get the hidden_size, and setup the test module
head_size = model_weight['model.layers.0.self_attn.r_k'].shape[1]
hidden_size = model_weight['model.embed_tokens.weight'].shape[1]
hidden_size_att = model_weight['model.layers.0.self_attn.v_proj.weight'].shape[0]
print(f"### Model hidden_size: {hidden_size}")

# List the model weights keys, and their shapes
print(f"### model weights keys:")
for key in model_weight:
    print(f"{key}: {model_weight[key].shape} - {model_weight[key].dtype}")

  from .autonotebook import tqdm as notebook_tqdm


### Model filename: qwerky7-7B.pth
### Model hidden_size: 3584
### model weights keys:
model.embed_tokens.weight: torch.Size([152064, 3584]) - torch.bfloat16
model.layers.0.self_attn.w0: torch.Size([1, 1, 3584]) - torch.bfloat16
model.layers.0.self_attn.w1: torch.Size([3584, 96]) - torch.bfloat16
model.layers.0.self_attn.w2: torch.Size([96, 3584]) - torch.bfloat16
model.layers.0.self_attn.a0: torch.Size([1, 1, 3584]) - torch.bfloat16
model.layers.0.self_attn.a1: torch.Size([3584, 96]) - torch.bfloat16
model.layers.0.self_attn.a2: torch.Size([96, 3584]) - torch.bfloat16
model.layers.0.self_attn.v0: torch.Size([1, 1, 3584]) - torch.bfloat16
model.layers.0.self_attn.v1: torch.Size([3584, 64]) - torch.bfloat16
model.layers.0.self_attn.v2: torch.Size([64, 3584]) - torch.bfloat16
model.layers.0.self_attn.g1: torch.Size([3584, 416]) - torch.bfloat16
model.layers.0.self_attn.g2: torch.Size([416, 3584]) - torch.bfloat16
model.layers.0.self_attn.k_k: torch.Size([1, 1, 3584]) - torch.bfloat16
mod

In [2]:
# Initialize the channelmix state, and x state to test
#
# NOTE: The triton kernel minimum chunk size is 16, it fallsback to pytorch mode otherwise
# we intentionally DO not use a unit of 16, so the remainder pytorch code kicks in for triton
IN_TOKENS_LEN=9000
x_state_0 = torch.ones(1, IN_TOKENS_LEN, hidden_size, device=RUN_DEVICE, dtype=RUN_DTYPE)
x_state_1 = torch.ones(1, IN_TOKENS_LEN, hidden_size, device=RUN_DEVICE, dtype=RUN_DTYPE)
x_state_2 = torch.ones(1, IN_TOKENS_LEN, hidden_size, device=RUN_DEVICE, dtype=RUN_DTYPE)
tmix_shift_0 = torch.ones(1, hidden_size, device=RUN_DEVICE, dtype=RUN_DTYPE)
tmix_shift_1 = torch.ones(1, hidden_size, device=RUN_DEVICE, dtype=RUN_DTYPE)
tmix_wkv_0 = torch.ones(1, hidden_size // head_size, head_size, head_size, device=RUN_DEVICE, dtype=torch.float)
tmix_wkv_1 = torch.ones(1, hidden_size // head_size, head_size, head_size, device=RUN_DEVICE, dtype=torch.float)

# Iteration to test
TEST_STEPS = 5
if RUN_DEVICE != "cpu":
    TEST_STEPS=50

# Build the cmix block
tmix = Qwerky7TimeMix({ 
    "num_hidden_layers":27,
    "head_size":head_size,
    "hidden_size":hidden_size, 
    "hidden_size_att":hidden_size_att, 
    "layer_id":0, 
    "device":RUN_DEVICE, "dtype":RUN_DTYPE, "tmix_backend":RUN_TMIX_BACKEND 
})
tmix.load_from_model_state_dict(model_weight, 0)

# Get the named parameters
tmix_params = tmix.named_parameters()
print(f"### tmix named parameters:")
for name, param in tmix_params:
    print(f"{name}: {param.shape} - {param.dtype} - {param.device.type}")

# Log each item shape
tmix_state = tmix.state_dict()
print(f"### tmix state keys:")
for key in tmix_state:
    print(f"tmix.{key}: {tmix_state[key].shape} - {tmix_state[key].dtype} - {param.device.type}")
print("----")

### tmix named parameters:
w0: torch.Size([1, 1, 3584]) - torch.bfloat16 - cuda
w1: torch.Size([3584, 96]) - torch.bfloat16 - cuda
w2: torch.Size([96, 3584]) - torch.bfloat16 - cuda
a0: torch.Size([1, 1, 3584]) - torch.bfloat16 - cuda
a1: torch.Size([3584, 96]) - torch.bfloat16 - cuda
a2: torch.Size([96, 3584]) - torch.bfloat16 - cuda
g1: torch.Size([3584, 416]) - torch.bfloat16 - cuda
g2: torch.Size([416, 3584]) - torch.bfloat16 - cuda
k_k: torch.Size([1, 1, 3584]) - torch.bfloat16 - cuda
k_a: torch.Size([1, 1, 3584]) - torch.bfloat16 - cuda
r_k: torch.Size([28, 128]) - torch.bfloat16 - cuda
q_proj.weight: torch.Size([3584, 3584]) - torch.bfloat16 - cuda
k_proj.weight: torch.Size([512, 3584]) - torch.bfloat16 - cuda
v_proj.weight: torch.Size([512, 3584]) - torch.bfloat16 - cuda
o_proj.weight: torch.Size([3584, 3584]) - torch.bfloat16 - cuda
ln_x.weight: torch.Size([3584]) - torch.bfloat16 - cuda
ln_x.bias: torch.Size([3584]) - torch.bfloat16 - cuda
### tmix state keys:
tmix.w0: torch.

In [3]:
### TMix
with torch.inference_mode():

    # This is a warmup
    t0 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_wkv, v_first = tmix(x_state_1, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

    # The actual run
    t1 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_wkv, v_first = tmix(x_state_1, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix forward passes (normal): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')


Using /home/recursal/.cache/torch_extensions/py312_cu121 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/recursal/.cache/torch_extensions/py312_cu121/state_wind_backstepping/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module state_wind_backstepping...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


[1/2] /home/recursal/miniconda3/envs/py-3-12/bin/nvcc --generate-dependencies-with-compile --dependency-output state_wkv7_cuda.cuda.o.d -ccbin /home/recursal/miniconda3/envs/py-3-12/bin/x86_64-conda-linux-gnu-cc -DTORCH_EXTENSION_NAME=state_wind_backstepping -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/recursal/miniconda3/envs/py-3-12/lib/python3.12/site-packages/torch/include -isystem /home/recursal/miniconda3/envs/py-3-12/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /home/recursal/miniconda3/envs/py-3-12/lib/python3.12/site-packages/torch/include/TH -isystem /home/recursal/miniconda3/envs/py-3-12/lib/python3.12/site-packages/torch/include/THC -isystem /home/recursal/miniconda3/envs/py-3-12/include -isystem /home/recursal/miniconda3/envs/py-3-12/include/python3.12 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -

Loading extension module state_wind_backstepping...


1 tmix forward passes (warmup): 81.73021793365479 ms (cuda:0, torch.bfloat16)
1 tmix forward passes (normal): 20.773215293884277 ms (cuda:0, torch.bfloat16)


In [4]:
### TMix
with torch.inference_mode():

    # This is a warmup
    t0 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_wkv, v_first = tmix.forward_with_default_compile(x_state_1, tmix_wkv_1, v_first, out_x, t_wkv, v_first)
    t2 = time.time()
    print(f'1 tmix forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

    # The actual run
    t1 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_wkv, v_first = tmix.forward_with_default_compile(x_state_1, tmix_wkv_1, v_first, out_x, t_wkv, v_first)
    t2 = time.time()
    print(f'1 tmix forward passes (compiled): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')


1 tmix forward passes (warmup): 61.31931781768799 ms (cuda:0, torch.bfloat16)
1 tmix forward passes (compiled): 18.603901863098145 ms (cuda:0, torch.bfloat16)


In [5]:
### TMix
with torch.inference_mode():

    # This is a warmup
    t0 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_wkv, v_first = tmix.forward_with_reduce_compile(x_state_1, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

    # The actual run
    t1 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_wkv, v_first = tmix.forward_with_reduce_compile(x_state_1, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix forward passes (normal): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')


1 tmix forward passes (warmup): 52.87830352783203 ms (cuda:0, torch.bfloat16)
1 tmix forward passes (normal): 18.80061626434326 ms (cuda:0, torch.bfloat16)


In [6]:
# # Export tmix1 state dict
# tmix_state = tmix.state_dict()

# # Log each item shape
# print(f"### tmix state keys:")
# for key in tmix_state:
#     print(f"tmix.{key}: {tmix_state[key].shape} - {tmix_state[key].dtype}")
# print("----")

# # Build the tmix block
# tmix2 = Qwerky7TimeMix({ "num_hidden_layers":24, "hidden_size":hidden_size, "layer_id":1, "tmix_backend":"torch", "device":RUN_DEVICE, "dtype":RUN_DTYPE })

# # Load the state dict
# tmix2.load_from_model_state_dict(tmix_state, 1)

# # Log each item shape
# print(f"### tmix2 state keys:")
# for key in tmix_state:
#     print(f"tmix.{key}: {tmix_state[key].shape} - {tmix_state[key].dtype}")
# print("----")