In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import wandb
import gc

import os

# Authenticate without embedding a secret in the notebook. The key is read from the
# WANDB_API_KEY environment variable; when it is unset this passes None and wandb uses
# its own credential lookup (e.g. an existing netrc entry or an interactive prompt).
wandb.login(key=os.environ.get("WANDB_API_KEY"))

# Start the run; everything a reader may want to tune lives in this config dict and is
# read back below through `config`.
wandb.init(
    project="kv-cache-benchmark",
    name="transformer-kv-cache-comparison",
    config={
        "model_name": "facebook/opt-125m",
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "input_lengths": [10, 50, 100, 200],
        "output_tokens": 100,
        "batch_sizes": [1, 4, 8, 16],
        "n_repeats": 5
    }
)

config = wandb.config

# Load the checkpoint named in the run config and move the model to the benchmark device.
model = AutoModelForCausalLM.from_pretrained(config.model_name).to(config.device)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

# Batched generation with a decoder-only model should pad on the left, so that each
# sequence's final position is a real prompt token rather than padding. All prompts in
# this benchmark currently share one length, so this only matters if they become uneven.
tokenizer.padding_side = "left"

# Tokenizing with padding=True requires a pad token; reuse EOS if the checkpoint has none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def run_inference(prompt, output_length, use_kv_cache=True, batch_size=1):
    """Run one timed ``model.generate`` call and report latency, throughput and peak memory.

    Args:
        prompt: A string or list of strings handed to the tokenizer; a list is padded
            into a single batch.
        output_length: Upper bound on new tokens per sequence (``max_new_tokens``).
        use_kv_cache: Forwarded to ``generate`` as ``use_cache``.
        batch_size: Kept for backward compatibility only. The number of sequences is
            read from the generated tensor, so this argument no longer affects the result.

    Returns:
        ``(latency_seconds, throughput_tokens_per_second, peak_memory_gb)``.
        Throughput counts the new positions actually produced, which can be fewer than
        ``output_length`` if generation stops early. Peak memory is 0 on non-CUDA devices.
    """
    inputs = tokenizer(prompt, padding=True, return_tensors="pt").to(config.device)
    on_cuda = config.device == "cuda"

    if on_cuda:
        # Reset BEFORE generating so the peak reflects this call only, not model loading
        # or anything else allocated since the previous reset.
        torch.cuda.reset_peak_memory_stats()
        # CUDA launches are asynchronous: drain pending work so it is not timed here.
        torch.cuda.synchronize()

    start_time = time.perf_counter()
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            max_new_tokens=output_length,
            attention_mask=inputs.attention_mask,
            pad_token_id=tokenizer.pad_token_id,
            use_cache=use_kv_cache
        )
    if on_cuda:
        # Make sure all generation kernels have finished before stopping the clock.
        torch.cuda.synchronize()
    latency = time.perf_counter() - start_time

    # For a decoder-only model `outputs` holds prompt + continuation, so the extra columns
    # are the new positions per sequence. Sequences that finish early are filled with
    # pad_token_id up to the longest one.
    new_tokens_per_sequence = outputs.shape[1] - inputs.input_ids.shape[1]
    throughput = (new_tokens_per_sequence * outputs.shape[0]) / latency

    if on_cuda:
        memory_usage = torch.cuda.max_memory_allocated() / (1024 ** 3)  # GB
    else:
        memory_usage = 0

    return latency, throughput, memory_usage

columns = ["batch_size", "input_length", "kv_cache", "latency", "throughput", "memory_usage", "repeat"]
table = wandb.Table(columns=columns)

# (use_cache flag passed to generate, label used in the table and in wandb metric names)
KV_MODES = [(True, "with_kv_cache"), (False, "without_kv_cache")]
# New tokens generated by each untimed warmup call.
WARMUP_TOKENS = 10


def _free_memory():
    """Collect unreachable Python objects, then return cached CUDA blocks to the driver."""
    if config.device == "cuda":
        # Order matters: tensors released by gc.collect() can only be handed back by
        # empty_cache() after they have been collected.
        gc.collect()
        torch.cuda.empty_cache()


# One untimed call per mode so that one-off first-call costs (e.g. lazy kernel/library
# initialisation) are not recorded as the first measurement.
for use_kv_cache, _ in KV_MODES:
    run_inference(["test"], WARMUP_TOKENS, use_kv_cache=use_kv_cache, batch_size=1)

for batch_size in config.batch_sizes:
    print(f"Testing batch size: {batch_size}")

    for input_length in tqdm(config.input_lengths):
        # Identical prompts of `input_length` words. Token count is not exactly equal to
        # the word count, since the tokenizer may add special tokens.
        prompts = [" ".join(["test"] * input_length)] * batch_size

        for use_kv_cache, mode in KV_MODES:
            for repeat in range(config.n_repeats):
                _free_memory()

                latency, throughput, memory = run_inference(
                    prompts, config.output_tokens, use_kv_cache=use_kv_cache, batch_size=batch_size
                )

                table.add_data(batch_size, input_length, mode,
                               latency, throughput, memory, repeat)

                wandb.log({
                    f"latency/{mode}": latency,
                    f"throughput/{mode}": throughput,
                    f"memory/{mode}": memory,
                    "batch_size": batch_size,
                    "input_length": input_length,
                    "repeat": repeat
                })

wandb.log({"results_table": table})

# For each batch size, plot repeat-averaged latency, throughput and peak memory against
# input length, with and without the KV cache.

# (index into the averaged (latency, throughput, memory) tuple, panel title, y-axis label)
_COMPARISON_PANELS = [
    (0, "Latency", "Latency (seconds)"),
    (1, "Throughput", "Throughput (tokens/second)"),
    (2, "Memory Usage", "Memory Usage (GB)"),
]
# (kv_cache label stored in the table, legend label)
_COMPARISON_MODES = [("with_kv_cache", "With KV Cache"), ("without_kv_cache", "Without KV Cache")]
# NaN (not 0) for a missing combination, so matplotlib leaves a gap instead of drawing a
# fake zero-valued point.
_MISSING_AVERAGES = (np.nan, np.nan, np.nan)

for batch_size in config.batch_sizes:
    # Group repeats by (input_length, kv_cache label). Rows follow `columns`:
    # [batch_size, input_length, kv_cache, latency, throughput, memory_usage, repeat].
    data_grouped = {}
    for row in table.data:
        if row[0] == batch_size:
            data_grouped.setdefault((row[1], row[2]), []).append(row[3:6])

    # Mean of each metric across repeats -> (latency, throughput, memory).
    avg_data = {key: tuple(np.mean(values, axis=0)) for key, values in data_grouped.items()}
    input_lengths = sorted({key[0] for key in avg_data})

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    for ax, (metric_index, title, ylabel) in zip(axes, _COMPARISON_PANELS):
        for mode, legend_label in _COMPARISON_MODES:
            values = [avg_data.get((il, mode), _MISSING_AVERAGES)[metric_index] for il in input_lengths]
            ax.plot(input_lengths, values, 'o-', label=legend_label)
        ax.set_title(f'{title} (Batch Size: {batch_size})')
        ax.set_xlabel('Input Length (tokens)')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    fig.tight_layout()

    wandb.log({f"comparison_batch_{batch_size}": wandb.Image(fig)})
    plt.close(fig)


# Summarise each (batch_size, input_length) pair as two ratios of repeat-averaged values:
#   speedup_ratio         = mean latency without KV cache / mean latency with KV cache
#   memory_increase_ratio = mean peak memory with KV cache / mean peak memory without
speedup_data = []
memory_increase_data = []

# One pass over the results table instead of rescanning it for every pair. Rows follow
# `columns`: [batch_size, input_length, kv_cache, latency, throughput, memory_usage, repeat].
latency_memory_by_run = {}
for row in table.data:
    latency_memory_by_run.setdefault((row[0], row[1], row[2]), []).append((row[3], row[5]))


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is not positive.

    NaN keeps an undefined ratio (e.g. memory on a non-CUDA device, which run_inference
    reports as 0) from being logged and plotted as if it were a measured value.
    """
    return numerator / denominator if denominator > 0 else float("nan")


for batch_size in config.batch_sizes:
    for input_length in config.input_lengths:
        with_kv = latency_memory_by_run.get((batch_size, input_length, "with_kv_cache"))
        without_kv = latency_memory_by_run.get((batch_size, input_length, "without_kv_cache"))
        if not with_kv or not without_kv:
            continue

        # Column-wise means over repeats -> (mean latency, mean memory).
        avg_with_latency, avg_with_memory = np.mean(with_kv, axis=0)
        avg_without_latency, avg_without_memory = np.mean(without_kv, axis=0)

        speedup = _safe_ratio(avg_without_latency, avg_with_latency)
        memory_increase = _safe_ratio(avg_with_memory, avg_without_memory)

        speedup_data.append({
            "batch_size": batch_size,
            "input_length": input_length,
            "speedup_ratio": speedup
        })

        memory_increase_data.append({
            "batch_size": batch_size,
            "input_length": input_length,
            "memory_increase_ratio": memory_increase
        })

        wandb.log({
            "speedup_ratio": speedup,
            "memory_increase_ratio": memory_increase,
            "batch_size": batch_size,
            "input_length": input_length
        })

def _plot_ratio_by_batch(ax, records, ratio_key, title, ylabel):
    """Draw `records[ratio_key]` against input length on `ax`, one line per configured batch size."""
    for bs in config.batch_sizes:
        points = sorted(
            ((rec["input_length"], rec[ratio_key]) for rec in records if rec["batch_size"] == bs),
            key=lambda point: point[0],
        )
        if not points:
            continue
        xs, ys = zip(*points)
        ax.plot(xs, ys, 'o-', label=f'Batch {bs}')

    ax.set_title(title)
    ax.set_xlabel('Input Length (tokens)')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True)


# Two panels: latency speedup from the KV cache, and the matching memory ratio.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
_plot_ratio_by_batch(ax1, speedup_data, "speedup_ratio",
                     'Speedup Ratio (Without KV / With KV)', 'Speedup Ratio')
_plot_ratio_by_batch(ax2, memory_increase_data, "memory_increase_ratio",
                     'Memory Usage Increase (With KV / Without KV)', 'Memory Increase Ratio')

plt.tight_layout()
wandb.log({"summary_ratios": wandb.Image(fig)})
plt.close(fig)

# Close the wandb run.
wandb.finish()


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/boboxa/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mivanboboshko888[0m ([33mivanboboshko888-hse-university4375[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Testing batch size: 1


100%|██████████| 4/4 [01:00<00:00, 15.08s/it]


Testing batch size: 4


100%|██████████| 4/4 [01:13<00:00, 18.45s/it]


Testing batch size: 8


100%|██████████| 4/4 [01:40<00:00, 25.20s/it]


Testing batch size: 16


100%|██████████| 4/4 [02:25<00:00, 36.40s/it]


0,1
batch_size,▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▄▄▄▄▄▄▄█████████▁▁▂█
input_length,▁▁▂▂▄██▁▁▂▂▂▂▂▄████▁▁▂▂▂▄██▁▁▁▂▄▄▄▄██▁▂▄
latency/with_kv_cache,▄▁▁▁▁▁▁▁▁█▇▄▁▁█▁▁▁▁▁▁▁▁▁▁█▂▇▂▆▂▇▂▁▁▁▂▂▂▂
latency/without_kv_cache,▁▁▁▁▁▂▂▁▂▁▁▂▁▂▁▃▃▃▂▂▃▂▃▃▃▄▅▄▃▃▃▄▄▄▄▅▆▅▅█
memory/with_kv_cache,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▃▃▁▁▁▁▂▂▂▂▃▄▂▂▂▃▃▄▄██
memory/without_kv_cache,▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▃▃▃▂▂▂▂▃▃▃▃▄▄▃▄▄▄▆▆████
memory_increase_ratio,█▇██▆▅▇▇▄▄▅▇▁▁▃▇
repeat,█▁▃▆▅▅█▁█▅▁▆▃▆▃▅▅█▆▁▁▅▅▁▆█▁▅▆▅▆██▅█▁▅▆▁█
speedup_ratio,▁▁▁▁▁▁▂▃▁▂▂▄▂▄▄█
throughput/with_kv_cache,▁▁▁▁▁▁▁▁▁▁▁▂▃▂▃▃▂▃▃▃▃▃▃▄▄▃▄▄▄▃██████████

0,1
batch_size,16.0
input_length,200.0
latency/with_kv_cache,0.79712
latency/without_kv_cache,8.51889
memory/with_kv_cache,1.30931
memory/without_kv_cache,1.39153
memory_increase_ratio,0.94091
repeat,4.0
speedup_ratio,10.59228
throughput/with_kv_cache,2007.22518


In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import wandb
import gc

In [None]:
import os

# Authenticate without embedding a secret in the notebook. The key is read from the
# WANDB_API_KEY environment variable; when it is unset this passes None and wandb uses
# its own credential lookup (e.g. an existing netrc entry or an interactive prompt).
wandb.login(key=os.environ.get("WANDB_API_KEY"))

# Start the run; everything a reader may want to tune lives in this config dict and is
# read back below through `config`.
wandb.init(
    project="kv-cache-benchmark",
    name="transformer-kv-cache-comparison",
    config={
        "model_name": "facebook/opt-125m",
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "input_lengths": [10, 50, 100, 200],
        "output_tokens": 100,
        "batch_sizes": [1, 4, 8, 16],
        "n_repeats": 5
    }
)

config = wandb.config

In [None]:
# Load the checkpoint named in the run config and move the model to the benchmark device.
model = AutoModelForCausalLM.from_pretrained(config.model_name).to(config.device)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

# Batched generation with a decoder-only model should pad on the left, so that each
# sequence's final position is a real prompt token rather than padding. All prompts in
# this benchmark currently share one length, so this only matters if they become uneven.
tokenizer.padding_side = "left"

# Tokenizing with padding=True requires a pad token; reuse EOS if the checkpoint has none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token