## RWKV TMix block python / triton benchmark

Benchmarking the various kernels

In [1]:
# Configure the parent path to be the proj folder
import sys, os, torch, time
sys.path.append('../../')

# # Cuda debugging
# os.environ["TORCH_USE_CUDA_DSA"] = "1"
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Import the block classes
from rwkv_block.v7_qwerky.block.qwerky7_time_mix import Qwerky7TimeMix

# File to load
MODEL_FILENAME="qwerky7-7B.pth"

# Run device, and run dtype to use
RUN_DEVICE="cpu"
RUN_DTYPE=torch.bfloat16
RUN_TMIX_BACKEND="auto"

# Check for cuda device
if torch.cuda.is_available():
    RUN_DEVICE="cuda:0"

# Check if the reference weights exists
assert os.path.exists(f"./.model/{MODEL_FILENAME}"), "The reference weights does not exist. Please download it first (00-model-download.ipynb)"

# Loads the model weights
model_weight = torch.load(f"./.model/{MODEL_FILENAME}", map_location='cpu', weights_only=True, mmap=True)

# Model filename
print(f"### Model filename: {MODEL_FILENAME}")

# Lets get the hidden_size, and setup the test module
head_size = model_weight['model.layers.0.self_attn.r_k'].shape[1]
hidden_size = model_weight['model.embed_tokens.weight'].shape[1]
hidden_size_att = model_weight['model.layers.0.self_attn.v_proj.weight'].shape[0]
print(f"### Model hidden_size: {hidden_size}")

# List the model weights keys, and their shapes
print(f"### model weights keys:")
for key in model_weight:
    print(f"{key}: {model_weight[key].shape} - {model_weight[key].dtype}")

  from .autonotebook import tqdm as notebook_tqdm


### Model filename: qwerky7-7B.pth
### Model hidden_size: 3584
### model weights keys:
model.embed_tokens.weight: torch.Size([152064, 3584]) - torch.bfloat16
model.layers.0.self_attn.w0: torch.Size([1, 1, 3584]) - torch.bfloat16
model.layers.0.self_attn.w1: torch.Size([3584, 96]) - torch.bfloat16
model.layers.0.self_attn.w2: torch.Size([96, 3584]) - torch.bfloat16
model.layers.0.self_attn.a0: torch.Size([1, 1, 3584]) - torch.bfloat16
model.layers.0.self_attn.a1: torch.Size([3584, 96]) - torch.bfloat16
model.layers.0.self_attn.a2: torch.Size([96, 3584]) - torch.bfloat16
model.layers.0.self_attn.v0: torch.Size([1, 1, 3584]) - torch.bfloat16
model.layers.0.self_attn.v1: torch.Size([3584, 64]) - torch.bfloat16
model.layers.0.self_attn.v2: torch.Size([64, 3584]) - torch.bfloat16
model.layers.0.self_attn.g1: torch.Size([3584, 416]) - torch.bfloat16
model.layers.0.self_attn.g2: torch.Size([416, 3584]) - torch.bfloat16
model.layers.0.self_attn.k_k: torch.Size([1, 1, 3584]) - torch.bfloat16
mod

In [2]:
# Initialize the channelmix state, and x state to test
# NOTE: The triton kernel minimum chunk size is 16, it fallsback to pytorch mode otherwise
# we intentionally DO not use a unit of 16, so the remainder pytorch code kicks in for triton
IN_TOKENS_LEN=8192
x_state_0 = torch.ones(1, IN_TOKENS_LEN, hidden_size, device=RUN_DEVICE, dtype=RUN_DTYPE)
x_state_1 = torch.ones(1, IN_TOKENS_LEN, hidden_size, device=RUN_DEVICE, dtype=RUN_DTYPE)
x_state_2 = torch.ones(1, IN_TOKENS_LEN, hidden_size, device=RUN_DEVICE, dtype=RUN_DTYPE)
tmix_shift_0 = torch.ones(1, hidden_size, device=RUN_DEVICE, dtype=RUN_DTYPE)
tmix_shift_1 = torch.ones(1, hidden_size, device=RUN_DEVICE, dtype=RUN_DTYPE)
tmix_wkv_0 = torch.ones(1, hidden_size // head_size, head_size, head_size, device=RUN_DEVICE, dtype=torch.float)
tmix_wkv_1 = torch.ones(1, hidden_size // head_size, head_size, head_size, device=RUN_DEVICE, dtype=torch.float)

# Iteration to test
TEST_STEPS = 10

# Slower reference implementation
tmix_pytorch = Qwerky7TimeMix({ "num_hidden_layers":24, "head_size":head_size, "hidden_size":hidden_size, "hidden_size_att":hidden_size_att, "layer_id":0, "device":RUN_DEVICE, "dtype":RUN_DTYPE, "tmix_backend":"pytorch_ref" })
tmix_pytorch.load_from_model_state_dict(model_weight, 0)

# Improved pytorch implement
tmix_pytorch_chunk = Qwerky7TimeMix({ "num_hidden_layers":24, "head_size":head_size, "hidden_size":hidden_size, "hidden_size_att":hidden_size_att, "layer_id":0, "device":RUN_DEVICE, "dtype":RUN_DTYPE, "tmix_backend":"pytorch" })
tmix_pytorch_chunk.load_from_model_state_dict(model_weight, 0)

tmix_triton = Qwerky7TimeMix({ "num_hidden_layers":24, "head_size":head_size, "hidden_size":hidden_size, "hidden_size_att":hidden_size_att, "layer_id":0, "device":RUN_DEVICE, "dtype":RUN_DTYPE, "tmix_backend":"triton" })
tmix_triton.load_from_model_state_dict(model_weight, 0)

tmix_triton_bighead = Qwerky7TimeMix({ "num_hidden_layers":24, "head_size":head_size, "hidden_size":hidden_size, "hidden_size_att":hidden_size_att, "layer_id":0, "device":RUN_DEVICE, "dtype":RUN_DTYPE, "tmix_backend":"triton_bighead" })
tmix_triton_bighead.load_from_model_state_dict(model_weight, 0)

tmix_cuda = Qwerky7TimeMix({ "num_hidden_layers":24, "head_size":head_size, "hidden_size":hidden_size, "hidden_size_att":hidden_size_att, "layer_id":0, "device":RUN_DEVICE, "dtype":RUN_DTYPE, "tmix_backend":"cuda" })
tmix_cuda.load_from_model_state_dict(model_weight, 0)

tmix_fla = Qwerky7TimeMix({ "num_hidden_layers":24, "head_size":head_size, "hidden_size":hidden_size, "hidden_size_att":hidden_size_att, "layer_id":0, "device":RUN_DEVICE, "dtype":RUN_DTYPE, "tmix_backend":"fla" })
tmix_fla.load_from_model_state_dict(model_weight, 0)

tmix_fla_fused = Qwerky7TimeMix({ "num_hidden_layers":24, "head_size":head_size, "hidden_size":hidden_size, "hidden_size_att":hidden_size_att, "layer_id":0, "device":RUN_DEVICE, "dtype":RUN_DTYPE, "tmix_backend":"fla_fused" })
tmix_fla_fused.load_from_model_state_dict(model_weight, 0)

print(f"### Testing the tmix blocks for {TEST_STEPS} steps")

### Testing the tmix blocks for 10 steps


In [3]:
### TMix
with torch.inference_mode():

    # This is a warmup
    t0 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_wkv, v_first = tmix_pytorch.forward_with_reduce_compile(x_state_1, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_pytorch reduce-compile forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

    # The actual run
    t1 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_wkv, v_first = tmix_pytorch.forward_with_reduce_compile(x_state_1, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_pytorch reduce-compile forward passes (normal): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

1 tmix_pytorch reduce-compile forward passes (warmup): 2672.472071647644 ms (cuda:0, torch.bfloat16)
1 tmix_pytorch reduce-compile forward passes (normal): 2330.1079273223877 ms (cuda:0, torch.bfloat16)


In [4]:
### TMix
with torch.inference_mode():

    # This is a warmup
    t0 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_wkv, v_first = tmix_pytorch_chunk.forward_with_reduce_compile(x_state_1, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_pytorch_chunk reduce-compile forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

    # The actual run
    t1 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_wkv, v_first = tmix_pytorch_chunk.forward_with_reduce_compile(x_state_1, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_pytorch_chunk reduce-compile forward passes (normal): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

skipping cudagraphs due to mutated inputs (1 instances). Found from : 
   File "/home/recursal/rwkv-prj/layerwise-trainer/block/RWKV_block/test/v7_qwirky/../../rwkv_block/v7_goose/block/kernel/rwkv7_attn_pytorch.py", line 200, in torch_dynamo_resume_in_rwkv7_attn_pytorch_v2_chunk_w_compile_break_at_186
    xx[:] = (wkv_xx.to(dtype=xx.dtype) @ r.view(BATCH_SIZE,SEQ_LEN,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,SEQ_LEN,N_HEAD*HEAD_SIZE)



1 tmix_pytorch_chunk reduce-compile forward passes (warmup): 1648.7801551818848 ms (cuda:0, torch.bfloat16)
1 tmix_pytorch_chunk reduce-compile forward passes (normal): 646.4013576507568 ms (cuda:0, torch.bfloat16)


In [5]:
### TMix
with torch.inference_mode():

    # This is a warmup
    t0 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_wkv, v_first = tmix_triton.forward_with_reduce_compile(x_state_1, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_triton reduce-compile forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

    # The actual run
    t1 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_wkv, v_first = tmix_triton.forward_with_reduce_compile(x_state_1, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_triton reduce-compile forward passes (normal): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

1 tmix_triton reduce-compile forward passes (warmup): 210.40878295898438 ms (cuda:0, torch.bfloat16)
1 tmix_triton reduce-compile forward passes (normal): 0.46617984771728516 ms (cuda:0, torch.bfloat16)


In [6]:
### TMix
with torch.inference_mode():

    # This is a warmup
    t0 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_wkv, v_first = tmix_triton_bighead.forward_with_reduce_compile(x_state_1, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_triton_bighead reduce-compile forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

    # The actual run
    t1 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_wkv, v_first = tmix_triton_bighead.forward_with_reduce_compile(x_state_1, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_triton_bighead reduce-compile forward passes (normal): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

1 tmix_triton_bighead reduce-compile forward passes (warmup): 326.89146995544434 ms (cuda:0, torch.bfloat16)
1 tmix_triton_bighead reduce-compile forward passes (normal): 0.5033969879150391 ms (cuda:0, torch.bfloat16)


In [7]:
### TMix
with torch.inference_mode():

    # This is a warmup
    t0 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_wkv, v_first = tmix_cuda.forward_with_reduce_compile(x_state_1, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_cuda reduce-compile forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

    # The actual run
    t1 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_wkv, v_first = tmix_cuda.forward_with_reduce_compile(x_state_1, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_cuda reduce-compile forward passes (normal): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

Using /home/recursal/.cache/torch_extensions/py312_cu121 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/recursal/.cache/torch_extensions/py312_cu121/state_wind_backstepping/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module state_wind_backstepping...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module state_wind_backstepping...


ninja: no work to do.
1 tmix_cuda reduce-compile forward passes (warmup): 145.1594114303589 ms (cuda:0, torch.bfloat16)
1 tmix_cuda reduce-compile forward passes (normal): 0.8263349533081055 ms (cuda:0, torch.bfloat16)


In [8]:
### TMix
with torch.inference_mode():

    # This is a warmup
    t0 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_wkv, v_first = tmix_fla.forward_with_reduce_compile(x_state_1, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_fla reduce-compile forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

    # The actual run
    t1 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_wkv, v_first = tmix_fla.forward_with_reduce_compile(x_state_1, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_fla reduce-compile forward passes (normal): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

1 tmix_fla reduce-compile forward passes (warmup): 462.7376079559326 ms (cuda:0, torch.bfloat16)
1 tmix_fla reduce-compile forward passes (normal): 0.5631685256958008 ms (cuda:0, torch.bfloat16)


In [9]:
### TMix
with torch.inference_mode():

    # This is a warmup
    t0 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_wkv, v_first = tmix_fla_fused.forward_with_reduce_compile(x_state_1, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_fla_fused reduce-compile forward passes (warmup): {(t2-t0)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

    # The actual run
    t1 = time.time()
    out_x = x_state_0
    t_shift = tmix_shift_0
    t_wkv = tmix_wkv_0
    v_first = x_state_2
    for i in range(TEST_STEPS):
        out_x, t_wkv, v_first = tmix_fla_fused.forward_with_reduce_compile(x_state_1, tmix_wkv_1, v_first)
    t2 = time.time()
    print(f'1 tmix_fla_fused reduce-compile forward passes (normal): {(t2-t1)*1000/TEST_STEPS} ms ({RUN_DEVICE}, {RUN_DTYPE})')

1 tmix_fla_fused reduce-compile forward passes (warmup): 181.4242124557495 ms (cuda:0, torch.bfloat16)
1 tmix_fla_fused reduce-compile forward passes (normal): 0.47256946563720703 ms (cuda:0, torch.bfloat16)
