In [1]:
import os
import math
import fla
from transformers import GenerationConfig
import torch
import argparse
import random
import re
import numpy as np
from numpy import random
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns

model_path = "/workspace/RWKV-block/test/v7_goose/.hf_build/v7-1B5-world/"

  warn(


In [2]:
class Args:
    def __init__(self):
        self.base_model = "fla-hub/rwkv7-1.5B-world"
        self.cache_dir = "./cache"
        self.min_tokens = 16384
        self.max_tokens = 20480
        self.interval = 1024
        self.num_tests = 5
        self.max_depth = 1.0

In [3]:
def get_gpu_memory():
    """Returns the current GPU memory usage in MB."""
    torch.cuda.synchronize()
    return torch.cuda.memory_allocated() / 1024 / 1024


def generate_prompt_landmark(tokenizer, pass_key, context_length, depth, final_context_length_buffer=250):
    needle = f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key. "
    task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there. "
    garbage = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. "
    question = "What is the pass key? The pass key is"
    
    tokens_in_garbage = len(tokenizer.encode(garbage))
    multiplier = math.ceil((context_length - len(tokenizer.encode(task_description)) - 25) / tokens_in_garbage)
    context = garbage * multiplier
    
    tokens_task = tokenizer.encode(task_description)
    tokens_needle = tokenizer.encode(needle)
    tokens_context = tokenizer.encode(context)
    tokens_question = tokenizer.encode(question)
    
    # Reduce context length by buffer
    context_length = context_length - final_context_length_buffer - len(tokens_task) - len(tokens_question)
    
    # Truncate context if needed
    if len(tokens_context) + len(tokens_task) + len(tokens_needle) + len(question) > context_length:
        tokens_context = tokens_context[:context_length - len(tokens_needle)]
    
    if depth >= 1:
        tokens_new_context = tokens_task + tokens_context + tokenizer.encode("\n") + tokens_needle + tokenizer.encode("\n") + tokens_question

    elif depth == 0: 
        tokens_new_context = tokens_task + tokens_needle + tokenizer.encode("\n") + tokens_context + tokenizer.encode("\n") + tokens_question

    else:
        insertion_point = int(len(tokens_context) * depth)
        tokens_new_context = tokens_context[:insertion_point]
        
        # Find sentence break
        period_tokens = tokenizer.encode('.')
        while tokens_new_context and tokens_new_context[-1] not in period_tokens:
            insertion_point -= 1
            tokens_new_context = tokens_context[:insertion_point]
        
        tokens_new_context = tokens_task + tokens_new_context + tokenizer.encode("\n") + tokens_needle + tokenizer.encode("\n") + tokens_context[insertion_point:] + tokens_question
    
    # print("Total Tokens in Context: ", len(tokens_new_context))
    new_context = tokenizer.decode(tokens_new_context)
    return new_context

def passkey_retrieval_test(model, tokenizer, device, context_length, depth, seed=666):
    # Generate random pass key
    rnd_state = random.get_state()
    random.seed(seed)
    pass_key = random.randint(1, 50000)
    random.set_state(rnd_state)
    
    prompt = generate_prompt_landmark(tokenizer, pass_key, context_length=context_length, depth=depth)
    answer = str(pass_key)
    
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    input_ids = input_ids.to(device)
    len_token = input_ids.shape[-1]

    # print(f"VRAM usage before generation: {get_gpu_memory():.2f} MB")

    answer_ids = tokenizer(answer, return_tensors="pt").input_ids
    past_key_values = None
    chunk_input_ids = input_ids[:, :-1]
    with torch.no_grad():
        outputs = model(input_ids[:, :-1])
        current_mem = torch.cuda.memory_allocated(device) / 1024**2
        max_mem = torch.cuda.max_memory_allocated(device) / 1024**2
        # print(f"Memory usage after context processing: {current_mem:.2f}MB / {max_mem:.2f}MB")

        # Generate the answer
        generation_output = model.generate(
            input_ids=input_ids,
            max_new_tokens=answer_ids.shape[-1] + 16,
            use_cache=True,
            generation_config=GenerationConfig(do_sample=False, use_cache=True),
        )
        current_mem = torch.cuda.memory_allocated(device) / 1024**2
        max_mem = torch.cuda.max_memory_allocated(device) / 1024**2
        # print(f"Memory usage after generate: {current_mem:.2f}MB / {max_mem:.2f}MB")
    
    model_output = tokenizer.decode(generation_output[0].cpu())
    
    # Find the number after "The pass key is"
    matches = re.findall(r"What is the pass key\? The pass key is (\d+)", model_output)
    model_answer = matches[-1] if matches else None
    is_correct = (model_answer == answer)
    print(f"Model's output: {model_output}")
    print(f"Found answer: {model_answer}")
    print(f"Correct answer: {answer}")
    print(f"Is correct: {is_correct}\n")
    
    return is_correct, len_token

In [4]:
def run_experiment(args, tmix_backend):
    device = "cuda:0"
    torch.cuda.set_device(device)
    torch.set_float32_matmul_precision('high')
    
    print(f"\nInitializing model with {tmix_backend} backend...")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True, 
        tmix_backend=tmix_backend,
        cache_dir=args.cache_dir  # Using the cache dir from args
    )
    model = model.to('cuda')
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    model.eval()
    
    total_test_points = (args.max_tokens - args.min_tokens) // args.interval + 1
    all_accuracies = []
    
    print(f"Running tests with {tmix_backend} backend:")
    print(f"Token range: {args.min_tokens} to {args.max_tokens} with interval {args.interval}")
    print(f"Number of tests per configuration: {args.num_tests}")
    
    for i in range(total_test_points):
        current_tokens = args.min_tokens + (i * args.interval)
        depth_steps = np.linspace(0, args.max_depth, 10)
        
        for depth in depth_steps:
            passed_tests = 0
            total_tokens = 0
            
            # print(f"\nTesting token length {current_tokens}, depth {depth:.2f}")
            current_mem = torch.cuda.memory_allocated(device) / 1024**2
            # print(f"Current GPU memory usage: {current_mem:.2f} MB")
            
            for k in range(args.num_tests):
                is_correct, len_tokens = passkey_retrieval_test(
                    model, tokenizer, device, 
                    context_length=current_tokens,
                    depth=depth,
                    seed=k
                )
                passed_tests += is_correct
                total_tokens += len_tokens
                
            avg_tokens = total_tokens // args.num_tests
            accuracy = float(passed_tests) / args.num_tests
            
            result = {
                "Backend": tmix_backend,
                "Context Length": avg_tokens,
                "Document Depth": round(depth * 100, -1),
                "Score": passed_tests,
                "Accuracy": accuracy
            }
            all_accuracies.append(result)
            
            print(f"{tmix_backend} - Length: {avg_tokens}, Depth: {depth:.2f}, "
                  f"Accuracy: {accuracy:.2f}, Passed: {passed_tests}/{args.num_tests}")
    
    # Create summary DataFrame and visualization
    df_summary = pd.DataFrame(all_accuracies)
    
    # Create heatmap
    plt.figure(figsize=(17.5, 8))
    pivot_table = pd.pivot_table(
        df_summary,
        values='Score',
        index='Document Depth',
        columns='Context Length',
        aggfunc='mean'
    )
    
    sns.heatmap(
        pivot_table,
        fmt="g",
        cmap=LinearSegmentedColormap.from_list("custom_cmap", 
                                              ["#F0496E", "#EBB839", "#0CD79F"]),
        cbar_kws={'label': 'Score'}
    )
    
    plt.xlabel('Token Limit')
    plt.ylabel('Depth Percent')
    plt.title(f'NIAH Accuracy - {tmix_backend} Kernel')
    plt.tight_layout()
    
    plt.savefig(f"data/heatmap_tokenized_{args.max_tokens}_rwkv7_1b5_{tmix_backend}.png")
    plt.show()
    plt.close()
    
    # Clean up
    del model
    torch.cuda.empty_cache()
    gc.collect()
    
    return df_summary

In [None]:
args = Args()

# Run experiments with different backends
backends = ["cuda", "triton", "triton_bighead", "fused_fla", "fla"]
results = {}

for backend in backends:
    try:
        results[backend] = run_experiment(args, backend)
        time.sleep(5)  # Cool-down period
        
        # Clear output for cleaner notebook
        from IPython.display import clear_output
        clear_output(wait=True)
        
    except Exception as e:
        print(f"Error running {backend} backend: {str(e)}")
        continue

# Compare results if we have multiple backends
if len(results) > 1:
    plt.figure(figsize=(15, 10))
    for backend, df in results.items():
        sns.lineplot(
            data=df,
            x='Context Length',
            y='Accuracy',
            label=f'{backend} backend'
        )
    plt.title('Backend Comparison')
    plt.grid(True)
    plt.savefig(f"data/backend_comparison_{args.max_tokens}_rwkv7_1b5.png")
    plt.show()


Initializing model with cuda backend...
Running tests with cuda backend:
Token range: 16384 to 20480 with interval 1024
Number of tests per configuration: 5


Using /root/.cache/torch_extensions/py310_cu124 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/py310_cu124/state_wind_backstepping/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module state_wind_backstepping...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module state_wind_backstepping...


ninja: no work to do.


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Model's output: There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there. The pass key is 2733. Remember it. 2733 is the pass key. 
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. H

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Model's output: There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there. The pass key is 33004. Remember it. 33004 is the pass key. 
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow.

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Model's output: There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there. The pass key is 23721. Remember it. 23721 is the pass key. 
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow.

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Model's output: There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there. The pass key is 5995. Remember it. 5995 is the pass key. 
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. H

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Model's output: There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there. The pass key is 17531. Remember it. 17531 is the pass key. 
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow.

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Model's output: There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The s

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Model's output: There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The s

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Model's output: There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The s

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Model's output: There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The s

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Model's output: There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. The grass is green. The s

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
