In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

from dataclasses import dataclass

from transformers import AutoTokenizer, AutoModelForCausalLM
import json, os
import matplotlib.pyplot as plt
import time
import math

  from .autonotebook import tqdm as notebook_tqdm


## Benchmarking with vs without KV caching

1. Use [`meta-llama/Llama-3.1-8B`](https://huggingface.co/meta-llama/Llama-3.1-8B) and sweep `max_new_tokens`=`min_new_tokens` by powers of two from 1 to 512, using the same short prompt ("Once upon a time,") for all output sequence lengths. Plot time vs output sequence length.
2. Now repeat the same with `use_cache=False` in `model.generate()` -- add this plot to the same figure. Make sure to include a legend and descriptive labels/titles. Describe the trends you see in 1-2 sentences. Play around with both log scales and linear scales for a clearer idea of the trends; however, **you will only need to include one version in your report**.
3. Repeat 1-2 with a much longer prompt, read in from `long_prompt.txt`. **For this question, you only need to sweep from 1 to 32 without the KV cache -- with KV cache, you should be able to reach 512 without issues.** In ~2 sentences, (instead of just comparing KV cache vs no KV cache) compare the trends you see with those from your first plot, and provide an explanation for the trends you observe.
4. Repeat 1-2 with [`meta-llama/Llama-3.2-1B`](https://huggingface.co/meta-llama/Llama-3.2-1B). Again, in ~2 sentences, compare the trends you see with those from your first plot, and provide an explanation for the trends you observe.
5. Repeat 1-2 with [`meta-llama/Llama-2-7b`](https://huggingface.co/meta-llama/Llama-2-7b). Once again, write ~2 sentences to compare the trends you see with those from your first plot, and provide an explanation for the trends you observe.

In total, you will include four (4) figures, one (1) explanation of the general kv vs no kv trend, and reflections for three (3) pairwise comparisons in your report. You may need to restart your kernel each time you want to use a new model.

In [2]:
# Short prompt used for the output-length sweeps.
short_prompt = "Once upon a time,"

# Long prompt: the full contents of long_prompt.txt with leading/trailing
# whitespace removed.
with open("long_prompt.txt") as prompt_file:
    long_prompt = prompt_file.read().strip()

### One generation example

The code below should run without modification in an environment that contains the necessary libraries. A TA used 1 L40S GPU to run these exact cells.

You can adapt it to complete the questions above. Again, **you may need to restart your kernel between loading different models**.

In [3]:

# Single source of truth for the checkpoint, so the tokenizer and the model are
# always loaded from the same repository. To benchmark a different model, change
# this one line (you may need to restart the kernel first).
MODEL_ID = 'meta-llama/Llama-3.1-8B'

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Load the weights in bfloat16 and move the model to the GPU.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype='bfloat16').to('cuda')

Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 57.08it/s]


In [9]:
# Tokenize the prompt and move the input tensors to the GPU.
prompt_toks = tokenizer(short_prompt, return_tensors="pt").to('cuda')

# Wait for any GPU work still queued from earlier cells BEFORE starting the
# clock, so that it is not counted as part of this generation.
torch.cuda.synchronize()
start = time.perf_counter()  # monotonic, high-resolution clock meant for timing
output = model.generate(**prompt_toks, max_new_tokens=256, use_cache=True) # true by default
# Wait for the generation's GPU work to finish before stopping the clock.
torch.cuda.synchronize()
end = time.perf_counter()

# Number of new tokens = total output length minus prompt length.
# Two decimals are kept: rounding to whole seconds would report short runs as 0.
print("Generation of", output.shape[1]-prompt_toks['input_ids'].shape[1], "new tokens with KV cache took", round(end-start, 2), "seconds.")
print("-----")
print("Output:", tokenizer.batch_decode(output))

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Generation of 256 new tokens with KV cache took 6 seconds.
-----
Output: ['<|begin_of_text|>Once upon a time, there was a man who worked as a bus driver. One day, he was driving his bus when he saw a young boy running along the side of the road. The boy was trying to catch up to the bus, but he was having a hard time keeping up. The bus driver felt sorry for the boy and decided to stop the bus so that he could help him get on board.\nThe boy was very grateful and thanked the bus driver for his kindness. He told the bus driver that he was running late for school and that he had been trying to catch the bus for the past few minutes. The bus driver assured the boy that he would get him to school on time.\nThe boy sat down in the seat next to the bus driver and they began to chat. The boy told the bus driver that his name was John and that he was in the third grade. The bus driver told John that his name was Mr. Smith and that he was a bus driver.\nJohn asked Mr. Smith if he could ask him 

In [10]:
# Tokenize the prompt and move the input tensors to the GPU.
prompt_toks = tokenizer(short_prompt, return_tensors="pt").to('cuda')

# Wait for any GPU work still queued from earlier cells BEFORE starting the
# clock, so that it is not counted as part of this generation.
torch.cuda.synchronize()
start = time.perf_counter()  # monotonic, high-resolution clock meant for timing
output = model.generate(**prompt_toks, max_new_tokens=256, use_cache=False) # no KV cache
# Wait for the generation's GPU work to finish before stopping the clock.
torch.cuda.synchronize()
end = time.perf_counter()

# Number of new tokens = total output length minus prompt length.
# Two decimals are kept: rounding to whole seconds would report short runs as 0.
print("Generation of", output.shape[1]-prompt_toks['input_ids'].shape[1], "new tokens without KV cache took", round(end-start, 2), "seconds.")
print("-----")
print("Output:", tokenizer.batch_decode(output))

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Generation of 256 new tokens without KV cache took 8 seconds.
-----
Output: ['<|begin_of_text|>Once upon a time, I had a very different life. I was a professional, living in a big city, with a high-powered job, a beautiful house, a fancy car and a lot of money. But I was miserable. I was always stressed out, and I felt like I was running on a hamster wheel, going nowhere. I was so unhappy that I decided to quit my job and move to a small town in the middle of nowhere. I wanted to start over, to find happiness and fulfillment in a new place. And that’s when I discovered the magic of small town life.\nSmall town life is something that many people dream of. It’s a place where everyone knows each other, where the pace of life is slower and more relaxed, and where people are more connected to the land and to each other. It’s a place where you can live a simpler, more authentic life, where you can focus on what’s important, and where you can find a sense of community and belonging.\nSmall to

### Example plotting code

In [None]:
# Containers for the measured generation times that the plotting cell reads:
# one list for runs with the KV cache, one for runs without it.
times_with_kv, times_no_kv = [], []

In [None]:
import matplotlib.pyplot as plt

# Each series gets its x positions from its own length instead of a hard-coded
# range(10), so sweeps of different lengths (e.g. a no-KV-cache sweep that stops
# at 2^5 while the KV-cache sweep reaches 2^9) can share one figure.
# x positions are exponents: index i is labelled 2^i on the axis.
# For a linear x scale, use [2**x for x in range(len(...))] as the x values instead.
fig, ax = plt.subplots()
ax.plot(range(len(times_with_kv)), times_with_kv, label="with kv cache")
ax.plot(range(len(times_no_kv)), times_no_kv, label="no kv cache")

# Tick labels cover the longer of the two series.
num_ticks = max(len(times_with_kv), len(times_no_kv))
ax.set_xticks(list(range(num_ticks)))
ax.set_xticklabels([f"2^{x}" for x in range(num_ticks)])

# NOTE(review): labels assume the lists hold wall-clock seconds per output
# length; adjust them (and add the model / prompt to the title) for each figure.
ax.set_xlabel("Number of new tokens generated")
ax.set_ylabel("Generation time (seconds)")
ax.set_title("Generation time vs. output length, with and without KV cache")
ax.legend()

fig.savefig("[YOUR FILEPATH HERE]", bbox_inches="tight")
plt.show()