# Test the full model logits

Using the reference model, test/validate the full model forward pass for compilation issues

In [1]:
# Configure the parent path to be the proj folder
import sys, os, torch, time
sys.path.append('../../')

# # Cuda debugging
# os.environ["TORCH_USE_CUDA_DSA"] = "1"
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Import the block classes
from rwkv_block.v7_qwrky.model.qwrky7_causal_lm import Qwrky7CausalLM
from rwkv_block.v7_qwrky.model.qwrky7_config_map import Qwrky7ConfigMap

# File to load
MODEL_FILENAME="qwrky7-7B.pth"

# Run device, and run dtype to use
RUN_DEVICE="cpu"
RUN_DTYPE=torch.bfloat16
RUN_TMIX_BACKEND="fla"

# Check for cuda device
if torch.cuda.is_available():
    RUN_DEVICE="cuda:0"

# Check if the reference weights exists
assert os.path.exists(f"./.model/{MODEL_FILENAME}"), "The reference weights does not exist. Please download it first (00-model-download.ipynb)"

# Loads the model weights
model_weight = torch.load(f"./.model/{MODEL_FILENAME}", map_location='cpu', weights_only=True, mmap=True)

# Model filename
print(f"### Model filename: {MODEL_FILENAME}")

# Lets get the hidden_size, and setup the test module
head_size = model_weight['model.layers.0.self_attn.r_k'].shape[1]
hidden_size = model_weight['model.embed_tokens.weight'].shape[1]
hidden_size_att = model_weight['model.layers.0.self_attn.v_proj.weight'].shape[0]
hidden_size_ffn = model_weight['model.layers.0.mlp.gate_proj.weight'].shape[0]
print(f"### Model hidden_size: {hidden_size}")

# List the model weights keys, and their shapes
print(f"### model weights keys:")
for key in model_weight:
    print(f"{key}: {model_weight[key].shape} - {model_weight[key].dtype}")

  from .autonotebook import tqdm as notebook_tqdm


### Model filename: qwrky7-7B.pth
### Model hidden_size: 3584
### model weights keys:
model.embed_tokens.weight: torch.Size([152064, 3584]) - torch.bfloat16
model.layers.0.self_attn.w0: torch.Size([1, 1, 3584]) - torch.bfloat16
model.layers.0.self_attn.w1: torch.Size([3584, 96]) - torch.bfloat16
model.layers.0.self_attn.w2: torch.Size([96, 3584]) - torch.bfloat16
model.layers.0.self_attn.a0: torch.Size([1, 1, 3584]) - torch.bfloat16
model.layers.0.self_attn.a1: torch.Size([3584, 96]) - torch.bfloat16
model.layers.0.self_attn.a2: torch.Size([96, 3584]) - torch.bfloat16
model.layers.0.self_attn.v0: torch.Size([1, 1, 3584]) - torch.bfloat16
model.layers.0.self_attn.v1: torch.Size([3584, 64]) - torch.bfloat16
model.layers.0.self_attn.v2: torch.Size([64, 3584]) - torch.bfloat16
model.layers.0.self_attn.g1: torch.Size([3584, 416]) - torch.bfloat16
model.layers.0.self_attn.g2: torch.Size([416, 3584]) - torch.bfloat16
model.layers.0.self_attn.k_k: torch.Size([1, 1, 3584]) - torch.bfloat16
mode

In [2]:
BATCH_SIZE=1
TEST_LOOP=1
IN_TOKENS_LEN=8192
# GPU_COUNT=1

# Iteration to test
TEST_COUNT=1
if RUN_DEVICE != "cpu":
    TEST_COUNT=10


@torch.inference_mode()
def testForwardPass(smodel, compile_type=False):
    # Lets prepare the states accordingly
    in_state = smodel.get_init_state(BATCH_SIZE)
    out_state = smodel.get_init_state(BATCH_SIZE)
    x_tokens = torch.ones(BATCH_SIZE, IN_TOKENS_LEN, device=smodel.emb.weight.device, dtype=torch.long)
    # out_emb = torch.zeros(BATCH_SIZE, IN_TOKENS_LEN, hidden_size, device=smodel.emb.weight.device, dtype=smodel.emb.weight.dtype)

    # Lets test more aggressively
    time0 = time.time()
    if compile_type == "default":
        for i in range(TEST_COUNT):
            smodel.forward_with_default_compile(x_tokens, in_state, out_state)
    elif compile_type == "reduce":
        for i in range(TEST_COUNT):
            smodel.forward_with_reduce_compile(x_tokens, in_state)
    else:
        for i in range(TEST_COUNT):
            smodel.forward(x_tokens, in_state, out_state)
    time1 = time.time()

    print("--")
    print(f"### Compile Type: {compile_type}")
    print("--")
    print(f"### (warmup) Avg time per token batch ({BATCH_SIZE}):", (time1-time0)*1000/TEST_COUNT, "ms")
    print(f"### (warmup) Avg tok/s batch ({BATCH_SIZE}) :", 1000/((time1-time0)*1000/TEST_COUNT/IN_TOKENS_LEN), "tok/s")
    print(f"### (warmup) Avg time per token unbatched :", (time1-time0)*1000/TEST_COUNT/BATCH_SIZE, "ms")
    print(f"### (warmup) Avg tok/s unbatched :", 1000/((time1-time0)*1000/TEST_COUNT/BATCH_SIZE/IN_TOKENS_LEN), "tok/s")
    # print(f"### (warmup) Avg tok/s unbatched / gpu :", (1000/((time1-time0)*1000/TEST_COUNT/BATCH_SIZE))/GPU_COUNT, "tok/s")

    for i in range(TEST_LOOP):
        time0 = time.time()
        if compile_type == "default":
            for i in range(TEST_COUNT):
                smodel.forward_with_default_compile(x_tokens, in_state, out_state)
        elif compile_type == "reduce":
            for i in range(TEST_COUNT):
                smodel.forward_with_reduce_compile(x_tokens, in_state)
        else:
            for i in range(TEST_COUNT):
                smodel.forward(x_tokens, in_state, out_state)
        time1 = time.time()
        print("--")
        print(f"### (actual) Avg time per token batch ({BATCH_SIZE}):", (time1-time0)*1000/TEST_COUNT, "ms")
        print(f"### (actual) Avg tok/s batch ({BATCH_SIZE}) :", 1000/((time1-time0)*1000/TEST_COUNT/IN_TOKENS_LEN), "tok/s")
        print(f"### (actual) Avg time per token unbatched :", (time1-time0)*1000/TEST_COUNT/BATCH_SIZE, "ms")
        print(f"### (actual) Avg tok/s unbatched :", 1000/((time1-time0)*1000/TEST_COUNT/BATCH_SIZE/IN_TOKENS_LEN), "tok/s")
        # print(f"### (actual) Avg tok/s unbatched / gpu :", (1000/((time1-time0)*1000/TEST_COUNT/BATCH_SIZE))/GPU_COUNT, "tok/s")

# Get the config
model_config = Qwrky7ConfigMap.from_model_state_dict(model_weight, device=RUN_DEVICE, dtype=RUN_DTYPE)

# Log the config
print("### Model Config:")
print(model_config)

# Initialize the model instance
model_inst = Qwrky7ConfigMap(model_config)
model_inst.load_state_dict(model_weight)
# model_inst.load_from_model_state_dict(model_weight)
model_state = model_inst.state_dict()

# List the model weights keys, and their shapes
print(f"### model weights keys:")
for key in model_state:
    print(f"{key}: {model_state[key].shape} - {model_state[key].dtype}")


### Model Config:
Qwrky7ConfigMap(num_hidden_layers=0, hidden_size=3584, head_size=128, dropout_rate=0.0, tmix_backend='auto', layer_id=None, device='cuda:0', dtype=torch.bfloat16, hidden_size_ffn=18944, hidden_size_att=512, rms_norm_eps=1e-06, attention_bias=True, attention_output_bias=False, vocab_size=152064, init_state_wkv=False, forward_chunk_size=4096, padding_idx=151643)


TypeError: RWKV7BlockConfigMap.__init__() missing 2 required positional arguments: 'num_hidden_layers' and 'hidden_size'

In [3]:
# Test the single token forward pass
testForwardPass(model_inst)

--
### Compile Type: False
--
### (warmup) Avg time per token batch (1): 462.12620735168457 ms
### (warmup) Avg tok/s batch (1) : 17726.75920490649 tok/s
### (warmup) Avg time per token unbatched : 462.12620735168457 ms
### (warmup) Avg tok/s unbatched : 17726.75920490649 tok/s
--
### (actual) Avg time per token batch (1): 382.0483446121216 ms
### (actual) Avg tok/s batch (1) : 21442.312512352357 tok/s
### (actual) Avg time per token unbatched : 382.0483446121216 ms
### (actual) Avg tok/s unbatched : 21442.312512352357 tok/s


In [4]:
# Test the single token forward pass
testForwardPass(model_inst, "default")

--
### Compile Type: default
--
### (warmup) Avg time per token batch (1): 4773.983526229858 ms
### (warmup) Avg tok/s batch (1) : 1715.9673792317085 tok/s
### (warmup) Avg time per token unbatched : 4773.983526229858 ms
### (warmup) Avg tok/s unbatched : 1715.9673792317085 tok/s
--
### (actual) Avg time per token batch (1): 328.5087585449219 ms
### (actual) Avg tok/s batch (1) : 24936.930255026324 tok/s
### (actual) Avg time per token unbatched : 328.5087585449219 ms
### (actual) Avg tok/s unbatched : 24936.930255026324 tok/s


In [5]:
# Test the single token forward pass
testForwardPass(model_inst, "reduce")

--
### Compile Type: reduce
--
### (warmup) Avg time per token batch (1): 4587.719798088074 ms
### (warmup) Avg tok/s batch (1) : 1785.636516731908 tok/s
### (warmup) Avg time per token unbatched : 4587.719798088074 ms
### (warmup) Avg tok/s unbatched : 1785.636516731908 tok/s
--
### (actual) Avg time per token batch (1): 294.07570362091064 ms
### (actual) Avg tok/s batch (1) : 27856.77258995937 tok/s
### (actual) Avg time per token unbatched : 294.07570362091064 ms
### (actual) Avg tok/s unbatched : 27856.77258995937 tok/s
