# Test the full model logits

Using the reference model weights, benchmark and validate the full-model forward pass in eager mode and under both compile modes (`default` and `reduce`) to surface compilation issues.

In [1]:
# Configure the parent path to be the proj folder
import sys, os, torch, time
sys.path.append('../../')

# CUDA debugging switches, exported into this process's environment.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is documented to make kernel launches
# synchronous, so timings measured in this notebook are presumably slower than
# a run without it - confirm before quoting the numbers as benchmarks.
os.environ.update(
    TORCH_USE_CUDA_DSA="1",
    CUDA_LAUNCH_BLOCKING="1",
)

# Import the block classes
from rwkv_block.v7_qwerky.model.qwerky7_causal_lm import Qwerky7CausalLM
from rwkv_block.v7_qwerky.model.qwerky7_config_map import Qwerky7ConfigMap

# Checkpoint file name; the loading code looks for it inside ./.model/
MODEL_FILENAME = "qwerky7-7B.pth"

# Dtype and time-mix backend used when building the model config
RUN_DTYPE = torch.bfloat16
RUN_TMIX_BACKEND = "fla"

# Run on the default CUDA device when one is available, otherwise fall back to CPU
RUN_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# # Distributing my notebook GPUs
# if torch.cuda.device_count() >= 8:
#     RUN_DEVICE="cuda:1"

# Location of the reference checkpoint on disk (relative to this notebook)
MODEL_PATH = f"./.model/{MODEL_FILENAME}"

# Check if the reference weights exists
assert os.path.exists(MODEL_PATH), "The reference weights does not exist. Please download it first (00-model-download.ipynb)"

# Load the checkpoint onto the CPU as a memory-mapped, weights-only state dict
model_weight = torch.load(MODEL_PATH, map_location='cpu', weights_only=True, mmap=True)

# Drop model weights that were accidentally saved in the model file.
# pop(..., None) instead of `del` so a checkpoint that does not carry these
# stray tensors still loads instead of raising KeyError.
for stray_key in (
    "model.layers.0.self_attn.v0",
    "model.layers.0.self_attn.v1",
    "model.layers.0.self_attn.v2",
):
    model_weight.pop(stray_key, None)

# Model filename
print(f"### Model filename: {MODEL_FILENAME}")

# Read the main model dimensions straight from the checkpoint tensor shapes
head_size = model_weight['model.layers.0.self_attn.r_k'].shape[1]
hidden_size = model_weight['model.embed_tokens.weight'].shape[1]
hidden_size_att = model_weight['model.layers.0.self_attn.v_proj.weight'].shape[0]
hidden_size_ffn = model_weight['model.layers.0.mlp.gate_proj.weight'].shape[0]
print(f"### Model hidden_size: {hidden_size}")

# List the model weights keys, and their shapes
print("### model weights keys:")
for key in model_weight:
    print(f"{key}: {model_weight[key].shape} - {model_weight[key].dtype}")

  from .autonotebook import tqdm as notebook_tqdm


### Model filename: qwerky7-7B.pth
### Model hidden_size: 3584
### model weights keys:
model.embed_tokens.weight: torch.Size([152064, 3584]) - torch.bfloat16
model.layers.0.self_attn.w0: torch.Size([1, 1, 3584]) - torch.bfloat16
model.layers.0.self_attn.w1: torch.Size([3584, 96]) - torch.bfloat16
model.layers.0.self_attn.w2: torch.Size([96, 3584]) - torch.bfloat16
model.layers.0.self_attn.a0: torch.Size([1, 1, 3584]) - torch.bfloat16
model.layers.0.self_attn.a1: torch.Size([3584, 96]) - torch.bfloat16
model.layers.0.self_attn.a2: torch.Size([96, 3584]) - torch.bfloat16
model.layers.0.self_attn.g1: torch.Size([3584, 416]) - torch.bfloat16
model.layers.0.self_attn.g2: torch.Size([416, 3584]) - torch.bfloat16
model.layers.0.self_attn.k_k: torch.Size([1, 1, 3584]) - torch.bfloat16
model.layers.0.self_attn.k_a: torch.Size([1, 1, 3584]) - torch.bfloat16
model.layers.0.self_attn.r_k: torch.Size([28, 128]) - torch.bfloat16
model.layers.0.self_attn.q_proj.weight: torch.Size([3584, 3584]) - torc

In [2]:
# Shape of the benchmark input: BATCH_SIZE sequences of IN_TOKENS_LEN tokens each
BATCH_SIZE = 1
IN_TOKENS_LEN = 8192

# Number of measured rounds reported after the warmup round
TEST_LOOP = 1
# GPU_COUNT=1

# Forward passes per round: a single pass on CPU, ten on any other device
TEST_COUNT = 1 if RUN_DEVICE == "cpu" else 10


def _time_forward_calls(smodel, compile_type, x_tokens, in_state, out_state, call_count):
    """Run `call_count` back-to-back forward calls and return the elapsed wall-clock seconds.

    `compile_type` selects the entry point: "default" uses
    `forward_with_default_compile`, "reduce" uses `forward_with_reduce_compile`
    (which is called without `out_state`), and any other value uses `forward`.
    """
    if compile_type == "default":
        def run_once():
            smodel.forward_with_default_compile(x_tokens, in_state, out_state)
    elif compile_type == "reduce":
        def run_once():
            smodel.forward_with_reduce_compile(x_tokens, in_state)
    else:
        def run_once():
            smodel.forward(x_tokens, in_state, out_state)

    # CUDA work may be queued asynchronously; drain the device before starting
    # and before stopping the clock so the interval covers exactly these calls.
    on_cuda = x_tokens.device.type == "cuda"
    if on_cuda:
        torch.cuda.synchronize(x_tokens.device)
    start = time.perf_counter()
    for _ in range(call_count):
        run_once()
    if on_cuda:
        torch.cuda.synchronize(x_tokens.device)
    return time.perf_counter() - start


def _print_timing_report(phase, elapsed_s, call_count):
    """Print the four timing lines for one round, labelled with `phase` ("warmup" / "actual")."""
    # Wall time of one forward call over the whole (BATCH_SIZE, IN_TOKENS_LEN) input
    ms_per_call = elapsed_s * 1000 / call_count
    # Same figure divided by the batch size
    ms_per_call_per_seq = ms_per_call / BATCH_SIZE

    print(f"### ({phase}) Avg time per token batch ({BATCH_SIZE}):", ms_per_call, "ms")
    print(f"### ({phase}) Avg tok/s batch ({BATCH_SIZE}) :", 1000/(ms_per_call/IN_TOKENS_LEN), "tok/s")
    print(f"### ({phase}) Avg time per token unbatched :", ms_per_call_per_seq, "ms")
    print(f"### ({phase}) Avg tok/s unbatched :", 1000/(ms_per_call_per_seq/IN_TOKENS_LEN), "tok/s")


@torch.inference_mode()
def testForwardPass(smodel, compile_type=False):
    """Benchmark the model forward pass and print timing statistics.

    Builds two initial states via `smodel.get_init_state(BATCH_SIZE)` and a
    (BATCH_SIZE, IN_TOKENS_LEN) tensor of token id 1 on the device of the
    embedding weights, then times one warmup round followed by TEST_LOOP
    measured rounds of TEST_COUNT forward calls each. The same `in_state` and
    `out_state` objects are passed to every call.

    Note that the "time per token batch" figure is the wall time of one whole
    forward call, not of a single token.

    Args:
        smodel: model exposing `get_init_state`, `model.embed_tokens.weight`,
            `forward`, `forward_with_default_compile` and
            `forward_with_reduce_compile`.
        compile_type: "default" or "reduce" to pick the matching compiled
            entry point; any other value (including the default `False`)
            benchmarks the plain `forward`.

    Returns:
        None. Results are printed.
    """
    # Lets prepare the states accordingly
    in_state = smodel.get_init_state(BATCH_SIZE)
    out_state = smodel.get_init_state(BATCH_SIZE)
    x_tokens = torch.ones(BATCH_SIZE, IN_TOKENS_LEN, device=smodel.model.embed_tokens.weight.device, dtype=torch.long)

    # Warmup round: reported separately from the measured rounds below
    warmup_s = _time_forward_calls(smodel, compile_type, x_tokens, in_state, out_state, TEST_COUNT)

    print("--")
    print(f"### Compile Type: {compile_type}")
    print("--")
    _print_timing_report("warmup", warmup_s, TEST_COUNT)

    # Measured rounds
    for _ in range(TEST_LOOP):
        measured_s = _time_forward_calls(smodel, compile_type, x_tokens, in_state, out_state, TEST_COUNT)
        print("--")
        _print_timing_report("actual", measured_s, TEST_COUNT)

# Derive the model configuration from the checkpoint tensors and the run settings
model_config = Qwerky7ConfigMap.from_model_state_dict(
    model_weight,
    device=RUN_DEVICE,
    dtype=RUN_DTYPE,
    tmix_backend=RUN_TMIX_BACKEND,
)

# Show the resolved configuration (header line, then the config itself)
print("### Model Config:", model_config, sep="\n")

# Build the model, move it to the run device, then load the checkpoint weights
model_inst = Qwerky7CausalLM(model_config)
model_inst = model_inst.to(RUN_DEVICE)
model_inst.load_state_dict(model_weight)
# model_inst.load_from_model_state_dict(model_weight)
model_state = model_inst.state_dict()

# Print every entry of the instantiated model's state dict with its shape and dtype
print(f"### model weights keys:")
for name, tensor in model_state.items():
    print(f"{name}: {tensor.shape} - {tensor.dtype}")


### Model Config:
Qwerky7ConfigMap(num_hidden_layers=28, hidden_size=3584, head_size=128, dropout_rate=0.0, tmix_backend='fla', layer_id=None, device='cuda', dtype=torch.bfloat16, hidden_size_ffn=18944, hidden_size_att=512, rms_norm_eps=1e-06, attention_bias=True, attention_output_bias=False, vocab_size=152064, init_state_wkv=False, forward_chunk_size=4096, padding_idx=151643)
### model weights keys:
model.embed_tokens.weight: torch.Size([152064, 3584]) - torch.bfloat16
model.layers.0.input_layernorm.weight: torch.Size([3584]) - torch.bfloat16
model.layers.0.self_attn.w0: torch.Size([1, 1, 3584]) - torch.bfloat16
model.layers.0.self_attn.w1: torch.Size([3584, 96]) - torch.bfloat16
model.layers.0.self_attn.w2: torch.Size([96, 3584]) - torch.bfloat16
model.layers.0.self_attn.a0: torch.Size([1, 1, 3584]) - torch.bfloat16
model.layers.0.self_attn.a1: torch.Size([3584, 96]) - torch.bfloat16
model.layers.0.self_attn.a2: torch.Size([96, 3584]) - torch.bfloat16
model.layers.0.self_attn.g1: tor

In [3]:
# Eager baseline: time the plain forward() path on the IN_TOKENS_LEN-token input
testForwardPass(model_inst, compile_type=False)

--
### Compile Type: False
--
### (warmup) Avg time per token batch (1): 9097.833919525146 ms
### (warmup) Avg tok/s batch (1) : 900.4341112909186 tok/s
### (warmup) Avg time per token unbatched : 9097.833919525146 ms
### (warmup) Avg tok/s unbatched : 900.4341112909186 tok/s
--
### (actual) Avg time per token batch (1): 1242.3580169677734 ms
### (actual) Avg tok/s batch (1) : 6593.912453669544 tok/s
### (actual) Avg time per token unbatched : 1242.3580169677734 ms
### (actual) Avg tok/s unbatched : 6593.912453669544 tok/s


In [4]:
# Time the forward_with_default_compile() path on the IN_TOKENS_LEN-token input
testForwardPass(model_inst, compile_type="default")

--
### Compile Type: default
--
### (warmup) Avg time per token batch (1): 4855.414056777954 ms
### (warmup) Avg tok/s batch (1) : 1687.1887555221604 tok/s
### (warmup) Avg time per token unbatched : 4855.414056777954 ms
### (warmup) Avg tok/s unbatched : 1687.1887555221604 tok/s
--
### (actual) Avg time per token batch (1): 1050.9202480316162 ms
### (actual) Avg tok/s batch (1) : 7795.072951866419 tok/s
### (actual) Avg time per token unbatched : 1050.9202480316162 ms
### (actual) Avg tok/s unbatched : 7795.072951866419 tok/s


In [5]:
# Time the forward_with_reduce_compile() path on the IN_TOKENS_LEN-token input
testForwardPass(model_inst, compile_type="reduce")

--
### Compile Type: reduce
--
### (warmup) Avg time per token batch (1): 11262.545323371887 ms
### (warmup) Avg tok/s batch (1) : 727.3666622233313 tok/s
### (warmup) Avg time per token unbatched : 11262.545323371887 ms
### (warmup) Avg tok/s unbatched : 727.3666622233313 tok/s
--
### (actual) Avg time per token batch (1): 971.4428663253784 ms
### (actual) Avg tok/s batch (1) : 8432.817084742628 tok/s
### (actual) Avg time per token unbatched : 971.4428663253784 ms
### (actual) Avg tok/s unbatched : 8432.817084742628 tok/s
