In [None]:
!pip install -q transformers accelerate bitsandbytes
!pip install -q einops

In [None]:
import os

from huggingface_hub import login

# Never paste an access token into the notebook: it ends up in version control
# and in shared copies. Read it from the HF_TOKEN environment variable instead;
# when the variable is unset, `token=None` lets `login` ask for it interactively.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Checkpoint used by every benchmark cell in this notebook.
model_id = "meta-llama/Llama-2-7b-chat-hf"

# 4-bit NF4 weights with double quantization; compute dtype is bfloat16.
_quantization_kwargs = {
    "load_in_4bit": True,
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": torch.bfloat16,
}
bnb_config = BitsAndBytesConfig(**_quantization_kwargs)

# device_map="auto" delegates device placement to the library.
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map="auto")

tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
import os
import wandb

# Authenticate once up front. The benchmark cells below each open their own run
# with wandb.init(...), so starting an anonymous run here would only leave an
# empty entry in the project that is never logged to.
wandb.login()

In [None]:
import time

# One W&B run for this manual benchmark; reinit=True allows re-running the cell
# in the same kernel.
run = wandb.init(
    project="hpml-final-project",
    group="summarization",
    name="run5",
    reinit=True
)

# Reuse EOS as the padding token and pad on the left.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

messages = [
    {"role": "user", "content": "You are an expert summarizer. Your goal is to write a single-paragraph, abstractive summary of the provided text, focusing on the main argument and conclusion. The summary must be brief, no more than 75 words. Use this article: https://en.wikipedia.org/wiki/Graphics_processing_unit."}
]

# Render the chat into the model's prompt format as plain text.
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

# The rendered chat template already contains the special tokens the model
# expects, so they must not be added a second time when the text is encoded
# (otherwise the prompt starts with a duplicated BOS token).
input_ids = tokenizer.encode(
    prompt,
    add_special_tokens=False,
    return_tensors="pt"
).to("cuda")

# Decoding state consumed by the generation loop that follows.
max_new_tokens = 100
current_input_ids = input_ids
attention_mask = torch.ones_like(input_ids).long().to("cuda")

past_key_values = None
start_time = None
first_token_time = None
output_ids = input_ids.clone()
generated_tokens = 0

ttft = 0

# Hourly GPU price used for the cost-per-token estimate.
GPU_COST_PER_HOUR = 2.93

# Per-step KV-cache size, in bytes and in MiB.
kv_cache_bytes_list = []
kv_cache_mib_list = []

model.eval()

def calculate_tensor_size(tensor):
    """Return the number of bytes taken by the elements of ``tensor``.

    ``None`` is accepted and counts as zero bytes.
    """
    return 0 if tensor is None else tensor.numel() * tensor.element_size()

# Hand-rolled greedy decoding loop, instrumented for time-to-first-token,
# throughput and KV-cache size.
with torch.no_grad():
    for i in range(max_new_tokens):

        if i == 0:
            # Wait for queued GPU work so the clock starts on an idle device.
            torch.cuda.synchronize()
            start_time = time.time()

        outputs = model(
            current_input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            use_cache=True
        )

        # Greedy choice: highest-scoring token at the last position.
        logits = outputs.logits[:, -1, :]
        next_token_id = torch.argmax(logits, dim=-1).unsqueeze(-1)

        # Add up the bytes held by each layer's key and value tensors.
        total_kv_cache_bytes = 0

        if outputs.past_key_values is not None:
            for layer_cache in outputs.past_key_values:
                # Both a key and a value entry are indexed below, so require
                # at least two elements (a 1-tuple would raise IndexError).
                if isinstance(layer_cache, tuple) and len(layer_cache) >= 2:
                    key_tensor, value_tensor = layer_cache[0], layer_cache[1]

                    total_kv_cache_bytes += calculate_tensor_size(key_tensor)
                    total_kv_cache_bytes += calculate_tensor_size(value_tensor)

        total_kv_cache_mib = total_kv_cache_bytes / (1024 * 1024)

        kv_cache_bytes_list.append(total_kv_cache_bytes)
        kv_cache_mib_list.append(total_kv_cache_mib)

        if i == 0:
            # TTFT covers the first forward pass over the whole prompt.
            torch.cuda.synchronize()
            first_token_time = time.time()
            ttft = first_token_time - start_time
            print(f"**Time to First Token (TTFT): {ttft:.4f} seconds**")

        # Stop before appending EOS so it is neither decoded nor counted.
        if next_token_id.item() == tokenizer.eos_token_id:
            break

        output_ids = torch.cat([output_ids, next_token_id], dim=-1)

        # With a KV cache only the new token is fed to the next step; the
        # attention mask still has to span the full sequence.
        current_input_ids = next_token_id

        new_attention_mask = torch.ones((1, 1), dtype=torch.long, device='cuda')
        attention_mask = torch.cat([attention_mask, new_attention_mask], dim=1)

        past_key_values = outputs.past_key_values
        generated_tokens += 1

torch.cuda.synchronize()
end_time = time.time()
# Time spent after the first token was available.
total_gen_time = end_time - first_token_time

throughput = generated_tokens / total_gen_time if generated_tokens > 0 else 0
overall_throughput = generated_tokens / (end_time - start_time)

cost_per_second = GPU_COST_PER_HOUR / 3600
cost_per_token = cost_per_second / throughput if throughput > 0 else float("inf")
# With zero generated tokens cost_per_token is inf, and inf * 0 would be nan;
# an empty sequence costs nothing.
sequence_cost = cost_per_token * generated_tokens if generated_tokens > 0 else 0.0

max_kv_cache_mib = max(kv_cache_mib_list) if kv_cache_mib_list else 0

response = tokenizer.decode(output_ids[0], skip_special_tokens=True)

print("\n--- Model Response ---")
print(response)
print(f"Total tokens generated: {generated_tokens}")
print(f"Total generation time: {total_gen_time:.4f}s")
print(f"Overall throughput (incl TTFT): {overall_throughput:.2f} tokens/s")
print(f"Steady-state throughput: {throughput:.2f} tokens/s")
print(f"Cost per token: ${cost_per_token:.8f}")
print(f"Max KV Cache Size: {max_kv_cache_mib:.2f} MiB")
print(f"Total sequence cost: ${sequence_cost:.8f}")

In [None]:
import matplotlib.pyplot as plt

def plot_kv_cache_growth(kv_cache_mib_list, max_new_tokens, run_name):
    """Plot KV-cache size per decoding step and save the figure as a PNG.

    Args:
        kv_cache_mib_list: KV-cache size in MiB, one entry per decoding step.
            May be empty, in which case an empty chart is saved.
        max_new_tokens: Accepted for call compatibility; not used by the plot.
        run_name: Label used in the chart title and in the output file name.

    Returns:
        The file name the plot was written to
        (``kv_cache_growth_<run_name>.png``, relative to the working directory).
    """
    y_axis = list(kv_cache_mib_list)
    x_axis = list(range(1, len(y_axis) + 1))

    fig, ax = plt.subplots(figsize=(10, 6))

    ax.plot(x_axis, y_axis, marker='o', linestyle='-', color='teal', markersize=4)
    ax.set_xlabel('Token Position (Index)')
    ax.set_ylabel('KV Cache Size (MiB)')
    ax.set_title(f'KV Cache Growth during Generation ({run_name})')

    ax.grid(True, linestyle='--', alpha=0.6)

    # max() raises on an empty sequence, so only annotate when there is data.
    if y_axis:
        max_mib = max(y_axis)
        # Point the arrow at the step where the peak actually occurs (the last
        # one on ties) instead of assuming it is always the final step.
        peak_index = len(y_axis) - 1 - y_axis[::-1].index(max_mib)
        ax.annotate(
            f'Max: {max_mib:.2f} MiB',
            xy=(x_axis[peak_index], max_mib),
            xytext=(-50, 10),
            textcoords='offset points',
            arrowprops=dict(facecolor='black', shrink=0.05, width=1)
        )

    plot_filename = f"kv_cache_growth_{run_name}.png"
    fig.savefig(plot_filename)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    print(f"KV Cache utilization plot saved to {plot_filename}")
    return plot_filename

In [None]:
# Label the plot with the active run's name, then log everything in one call.
run_name = run.name
plot_file = plot_kv_cache_growth(kv_cache_mib_list, max_new_tokens, run_name)

summary_metrics = {
    "throughput": throughput,
    "cost_per_token": cost_per_token,
    "ttft": ttft,
    "max_kv_cache_mib": max_kv_cache_mib,
}
summary_metrics["kv_cache_growth_plot"] = wandb.Image(plot_file)
wandb.log(summary_metrics)

# Close the run so the next benchmark starts from a clean state.
run.finish()

LATENCY-BOUND TESTING

In [None]:
import time
import torch
import csv
from datetime import datetime

def calculate_tensor_size(tensor: torch.Tensor) -> int:
    """Return the size in bytes of the elements of ``tensor`` (0 for ``None``)."""
    if tensor is not None:
        return tensor.element_size() * tensor.nelement()
    return 0

def calculate_kv_cache_mib(past_key_values: tuple) -> float:
    """Return the combined size, in MiB, of the key/value tensors in a cache.

    Only entries that are tuples with at least two elements are counted; the
    first two elements are taken as the key and value tensors. ``None`` yields
    ``0.0``.
    """
    if past_key_values is None:
        return 0.0

    bytes_per_mib = 1024 * 1024
    layer_bytes = (
        calculate_tensor_size(entry[0]) + calculate_tensor_size(entry[1])
        for entry in past_key_values
        if isinstance(entry, tuple) and len(entry) >= 2
    )
    return sum(layer_bytes) / bytes_per_mib

def latency_bound_test(model, tokenizer, message, max_new_tokens=512):
    """Greedy-decode exactly ``max_new_tokens`` tokens for one chat and time it.

    Decoding is done one token at a time with a KV cache. Whenever the model
    picks an EOS/PAD/BOS token it is replaced by a filler token, so the loop
    always runs for the full ``max_new_tokens`` steps and every call measures
    the same amount of work.

    Args:
        model: Causal LM called as ``model(input_ids, past_key_values=...,
            attention_mask=..., use_cache=True)``.
        tokenizer: Tokenizer providing ``apply_chat_template``, ``encode`` and
            ``decode``.
        message: Chat messages (list of ``{"role", "content"}`` dicts).
        max_new_tokens: Number of decoding steps; must be positive.

    Returns:
        dict with ``generated_tokens``, ``ttft`` (seconds for the first forward
        pass), ``gen_time`` (seconds after the first token), ``throughput``
        (``generated_tokens / gen_time``), ``output`` (decoded text),
        ``kv_cache_mib_list`` (per-step cache size) and ``max_kv_cache_mib``.

    Raises:
        ValueError: If ``max_new_tokens`` is not positive.
    """
    # Without at least one step no timestamp is ever taken and the timing
    # arithmetic below would fail with an opaque TypeError.
    if max_new_tokens <= 0:
        raise ValueError("max_new_tokens must be a positive integer")

    prompt = tokenizer.apply_chat_template(
        message,
        tokenize=False,
        add_generation_prompt=True
    )

    # The rendered chat template already contains the special tokens the model
    # expects; adding them again here would duplicate the BOS token.
    input_ids = tokenizer.encode(
        prompt,
        add_special_tokens=False,
        return_tensors="pt"
    ).to("cuda")
    attention_mask = torch.ones_like(input_ids).long().to("cuda")

    current_input_ids = input_ids
    past_key_values = None

    output_ids = input_ids.clone()
    generated_tokens = 0

    model.eval()

    ttft = None
    start_time = None
    first_token_time = None

    kv_cache_mib_list = []

    STOP_TOKENS = {
        tokenizer.eos_token_id,
        tokenizer.pad_token_id,
        tokenizer.bos_token_id,
    }

    # Token substituted for any stop token; looked up once instead of on every
    # step that hits a stop token.
    filler_token_id = tokenizer.encode("a", add_special_tokens=False)[0]

    with torch.no_grad():
        for i in range(max_new_tokens):

            if i == 0:
                # Wait for queued GPU work so the clock starts on an idle device.
                torch.cuda.synchronize()
                start_time = time.time()

            outputs = model(
                current_input_ids,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                use_cache=True,
            )

            kv_cache_mib = calculate_kv_cache_mib(outputs.past_key_values)
            kv_cache_mib_list.append(kv_cache_mib)

            # Greedy choice: highest-scoring token at the last position.
            logits = outputs.logits[:, -1, :]
            next_token_id = torch.argmax(logits, dim=-1).unsqueeze(-1)

            if i == 0:
                torch.cuda.synchronize()
                first_token_time = time.time()
                ttft = first_token_time - start_time

            # Keep decoding past a stop token so the step count stays fixed.
            if next_token_id.item() in STOP_TOKENS:
                next_token_id = torch.tensor([[filler_token_id]], device="cuda")

            output_ids = torch.cat([output_ids, next_token_id], dim=-1)
            generated_tokens += 1

            # With a KV cache only the new token is fed to the next step; the
            # attention mask still has to span the full sequence.
            current_input_ids = next_token_id

            new_attention_mask = torch.ones((1, 1), dtype=torch.long, device="cuda")
            attention_mask = torch.cat([attention_mask, new_attention_mask], dim=1)

            past_key_values = outputs.past_key_values

    torch.cuda.synchronize()
    end_time = time.time()

    total_gen_time = end_time - first_token_time
    throughput = generated_tokens / total_gen_time if total_gen_time > 0 else 0.0

    return {
        "generated_tokens": generated_tokens,
        "ttft": ttft,
        "gen_time": total_gen_time,
        "throughput": throughput,
        "output": tokenizer.decode(output_ids[0], skip_special_tokens=True),
        "kv_cache_mib_list": kv_cache_mib_list,
        "max_kv_cache_mib": max(kv_cache_mib_list) if kv_cache_mib_list else 0
    }

In [None]:
# Latency scenarios: each entry carries a short `name` and the chat `messages`
# sent to the model. Only the summarization scenario is currently enabled.
messages = [
    # {"name": "simple_qa", "messages": [
    #     {"role": "user", "content": "What is the capital of Australia? Answer with only the city name."}
    # ]},
    # {"name": "reasoning", "messages": [
    #     {"role": "user", "content": "Let's think step-by-step. If John is taller than Mark, and Mark is shorter than Sue, is John definitely taller than Sue? Answer 'Yes', 'No', or 'Cannot determine'."}
    # ]},
    # {"name": "sentiment_analysis", "messages": [
    #     {"role": "user", "content": "Classify the sentiment of the text as 'Positive', 'Negative', or 'Neutral'. Text: The service was quick and the food was delicious. Sentiment: Positive. Text: The package arrived late and the box was damaged. Sentiment: Negative. Text: The meeting ended on time. Sentiment: Neutral. Text: I finished the book but found the ending disappointing.Sentiment: [FILL IN HERE]"}
    # ]},
    {"name": "summarization", "messages": [
        {"role": "user", "content": "You are an expert summarizer. Your goal is to write a single-paragraph, abstractive summary of the provided text, focusing on the main argument and conclusion. The summary must be brief, no more than 75 words. Use this article: https://en.wikipedia.org/wiki/Graphics_processing_unit"}
    ]},
]

for message in messages:
    # Group each run under its own scenario name so that enabling another
    # scenario above does not file its run under "summarization".
    run = wandb.init(
        project="hpml-final-project",
        group=message["name"],
        name="latency_run3",
        reinit=True
    )

    r = latency_bound_test(model, tokenizer, message["messages"], max_new_tokens=256)

    print(f"Generated: {r['generated_tokens']} Tokens")
    print(f"TTFT: {r['ttft']:.4f}s")
    print(f"Generation Time: {r['gen_time']:.4f}s")
    print(f"Throughput: {r['throughput']:.2f} tokens/sec")
    print(f"Max KV Cache Utilization: {r['max_kv_cache_mib']:.2f} MiB")
    print(f"Output: {r['output']}")

In [None]:
# Label the plot with the active run's name.
run_name = run.name

plot_file = plot_kv_cache_growth(r['kv_cache_mib_list'], 256, run_name)

# Guard on THIS run's throughput. The bare name `throughput` is a leftover
# global from the earlier summarization cell, so testing it could let a zero
# latency throughput through to the division.
cost_per_token = cost_per_second / r['throughput'] if r['throughput'] > 0 else float("inf")

wandb.log({
    "throughput": r['throughput'],
    "cost_per_token": cost_per_token,
    "ttft": r['ttft'],
    "max_kv_cache_mib": r['max_kv_cache_mib'],
    "kv_cache_growth_plot": wandb.Image(plot_file)
})

run.finish()

THROUGHPUT-BOUND TESTS

In [None]:
import time
import torch
from transformers import PreTrainedTokenizer
from typing import List

def prepare_batched_inputs(tokenizer, messages: List[List[dict]]):
    """Render each chat with the tokenizer's chat template and tokenize the batch.

    Args:
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``__call__``.
        messages: One list of chat-message dicts per batch element.

    Returns:
        The tokenizer's batch output (padded, PyTorch tensors) moved to
        ``"cuda"``.
    """
    prompts = [
        tokenizer.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
        for msg in messages
    ]

    # The rendered chat template already contains the special tokens the model
    # expects; adding them again here would duplicate the BOS token.
    inputs = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=False,
    ).to("cuda")

    return inputs

def throughput_bound_test(model, tokenizer: PreTrainedTokenizer, messages: List[List[dict]], max_new_tokens: int = 256):
    """Time one batched greedy ``model.generate`` call and report tokens/second.

    Note: this sets ``tokenizer.padding_side`` to ``"left"`` and, if the
    tokenizer has no pad token, sets it to the EOS token.

    Args:
        model: Model exposing ``generate``.
        tokenizer: Tokenizer used to render and batch the chats.
        messages: One list of chat-message dicts per batch element.
        max_new_tokens: Generation budget passed to ``generate``.

    Returns:
        dict with ``batch_size`` and ``throughput_tps`` (new tokens across the
        whole batch divided by wall-clock seconds of the ``generate`` call).
    """
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    tokenizer.padding_side = "left"

    inputs = prepare_batched_inputs(tokenizer, messages)
    batch_size = inputs['input_ids'].shape[0]

    model.eval()

    # Synchronize on both sides of the call so the timer brackets exactly the
    # GPU work issued by generate().
    torch.cuda.synchronize()
    start_time = time.time()

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=False,
            use_cache=True,
        )

    torch.cuda.synchronize()
    end_time = time.time()

    total_time = end_time - start_time

    # The output contains the prompt tokens, so subtract the (padded) prompt
    # length. Every row of output_ids has the same length, so each sequence is
    # counted at that common length.
    new_tokens_per_sequence = output_ids.shape[1] - inputs['input_ids'].shape[1]
    total_tokens = batch_size * new_tokens_per_sequence

    # Divide by the time measured in this call. `total_gen_time` is not defined
    # in this function and would resolve to a leftover notebook global from an
    # unrelated cell (or raise NameError on a fresh kernel).
    throughput = total_tokens / total_time if total_time > 0 else 0.0

    return {
        "batch_size": batch_size,
        "throughput_tps": throughput,
    }

In [None]:
import wandb
import torch
import pandas as pd
from typing import List, Dict, Any

# A single chat; test_throughput repeats it to build a batch of any size.
prompt = [
    [{"role": "user", "content": "You are an expert summarizer. Your goal is to write a single-paragraph, abstractive summary of the provided text, focusing on the main argument and conclusion. The summary must be brief, no more than 75 words. Use this article: https://en.wikipedia.org/wiki/Graphics_processing_unit"}]
]

def test_throughput(config=None):
    """Run one throughput measurement inside a W&B run and log the result.

    ``batch_size`` and ``max_new_tokens`` are read from the run's config; the
    module-level ``prompt`` is repeated ``batch_size`` times to form the batch.
    """
    with wandb.init(config=config) as run:
        sweep_params = run.config
        batched_messages = prompt * sweep_params.batch_size

        result = throughput_bound_test(
            model,
            tokenizer,
            batched_messages,
            max_new_tokens=sweep_params.max_new_tokens,
        )

        metrics = {
            "batch_size": result['batch_size'],
            "throughput": result['throughput_tps'],
        }
        wandb.log(metrics)

In [None]:
# Grid sweep over batch sizes 2, 4, ..., 512 with a fixed generation budget.
sweep_config = dict(
    method='grid',
    name='Llama2-7B-Throughput-Bound-Summarization',
    parameters={
        'batch_size': {'values': [2 ** exponent for exponent in range(1, 10)]},
        'max_new_tokens': {'value': 256},
    },
)

# Register the sweep, then let an agent call test_throughput once per grid point.
sweep_id = wandb.sweep(sweep_config, project="hpml-final-project")
wandb.agent(sweep_id, test_throughput)