In [None]:
# https://github.com/huggingface/transformers/issues/24841

In [None]:
import copy
import getpass
import json
import os

# Hugging Face access token used to download the checkpoint. Never hard-code
# it in the notebook (it ends up in version control and shared copies): take it
# from the environment, or prompt for it interactively when it is not set.
# Set before importing transformers, matching the original ordering.
# NOTE(review): a token was hard-coded in this cell; it has been exposed and
# should be revoked in the Hugging Face account settings.
if not os.environ.get("HF_TOKEN"):
    os.environ["HF_TOKEN"] = getpass.getpass("Hugging Face token (HF_TOKEN): ")

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, StaticCache

In [None]:
# run params
# NOTE(review): temperature/top_p only take effect when sampling is enabled
# (do_sample=True, here or in the model's generation_config); the generate()
# calls in this notebook do not pass do_sample — confirm the intended decoding.
temperature = 1.0
top_p = 0.9
max_new_tokens = 64
use_cache = True
device = "cuda:0"

# model_name = "osunlp/TableLlama"
model_name = "google/gemma-2-2b-it"

# load inputs
file_path = "turl_test_2k_prompts_50.jsonl"
device = torch.device(device)

# One JSON object per line; blank lines are skipped because json.loads("")
# raises instead of yielding a record.
with open(file_path, "r", encoding="utf-8") as f:
    prompts = [json.loads(line) for line in f if line.strip()]

# config = transformers.AutoConfig.from_pretrained(model_name)
# orig_ctx_len = getattr(config, "max_position_embeddings", None)
# model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
# model.resize_token_embeddings(32001)
# tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=orig_ctx_len, padding_side="left", use_fast=False)
# model.eval()

# Load the configured checkpoint onto the configured device, so that editing
# `model_name` / `device` above actually changes what is loaded and where.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map=device, torch_dtype=torch.bfloat16
)

# build prompts
# Two templates: "prompt_input" adds Input and Question sections for examples
# that carry an input segment; "prompt_no_input" holds only the instruction.
PROMPT_DICT = dict(
    prompt_input=(
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Input:\n{input_seg}\n\n"
        "### Question:\n{question}\n\n"
        "### Response:"
    ),
    prompt_no_input=(
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Response:"
    ),
)


def generate_prompt(instruction, question, input_seg=None):
    """Render one example into a prompt string.

    Args:
        instruction: task instruction text.
        question: question text; an answer-format hint (answer with exactly one
            candidate, copied verbatim from the candidate list) is appended.
        input_seg: optional input/context segment. When falsy (None or ""),
            the no-input template is used and `question` is NOT included.

    Returns:
        The formatted prompt string, ending in "### Response:".
    """
    # Hint is appended before branching, exactly as before, even though the
    # no-input template has no slot for the question.
    question += (
        " Answer with just a candidate, selected from the provided referent entity "
        "candidates list, and nothing else. The selected candidate must be reported "
        "verbatim from the list provided as input. Each candidate in the list is "
        "enclosed between < and > and reports [DESC] and [TYPE] information."
    )
    if not input_seg:
        return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)
    return PROMPT_DICT["prompt_input"].format(
        instruction=instruction, question=question, input_seg=input_seg
    )

In [None]:
# Example to profile: index 742 is the shortest prompt, 840 an average-length one.
example_index = 840
p = prompts[example_index]

In [None]:
# Initialize prompt cache (filled by the prefill cell that follows)
prompt_cache = DynamicCache()

# Full prompt: the text the generate() calls complete.
prompt = generate_prompt(p["instruction"], p["question"], p["input"])
inputs = tokenizer(prompt, return_tensors="pt").to(device)

# Prefix to pre-compute into the cache: the prompt minus its final character
# (the ":" of "### Response:"), so the full prompt still has tokens left to process.
prompt_ = prompt[:-1]
inputs_ = tokenizer(prompt_, return_tensors="pt").to(device)

# Reusing a cache is only valid if the prefix tokenizes to a strict token
# prefix of the full prompt; otherwise the cached key/value states belong to
# different tokens and generation silently conditions on the wrong context.
_n_prefix = inputs_["input_ids"].shape[-1]
assert _n_prefix < inputs["input_ids"].shape[-1] and torch.equal(
    inputs_["input_ids"][0], inputs["input_ids"][0, :_n_prefix]
), "prefix tokens are not a prefix of the full prompt's tokens; cached states would not match"

In [None]:
inputs_.attention_mask.shape  # last dim = token count of the cached prefix

In [None]:
# Prefill: run the truncated prompt (inputs_) through the model once and keep
# the returned cache, which now holds the prefix's key/value states.
with torch.no_grad():
    pre_output = model(past_key_values=prompt_cache, **inputs_)
prompt_cache = pre_output.past_key_values

In [None]:
# Untimed generation on the full prompt. No past_key_values is passed, so
# generate() manages its own cache internally; the prefilled prompt_cache is
# NOT used here.
post_output = model.generate(
    **inputs,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    top_p=top_p,
    output_scores=True,
    return_dict_in_generate=True,
    output_logits=True,
    use_cache=use_cache,  # config flag instead of a literal
)

In [None]:
%%time 
# Baseline WITHOUT KV caching: use_cache=False and no past_key_values, so no
# key/value states are kept between decoding steps. Timed for comparison with
# the cached runs.
post_output = model.generate(
    **inputs,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    top_p=top_p,
    output_scores=True,
    return_dict_in_generate=True,
    output_logits=True,
    use_cache=False,
)

In [None]:
%%time 
# Reset cache: a fresh, empty StaticCache. It must live on the model's own
# device with the model's dtype (the model is loaded in bfloat16 on the
# configured CUDA device, so a cache created on "mps" does not match it).
max_cache_len = 2**12
# A StaticCache has fixed capacity: prompt plus generated tokens must fit.
assert inputs["input_ids"].shape[-1] + max_new_tokens <= max_cache_len, (
    "prompt + max_new_tokens exceeds the StaticCache capacity"
)

prompt_cache = StaticCache(config=model.config, max_batch_size=1,
                           max_cache_len=max_cache_len,
                           device=model.device, dtype=model.dtype)

# deepcopy leaves prompt_cache itself empty, so the next cell can reuse it.
post_output = model.generate(
    **inputs,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    top_p=top_p,
    output_scores=True,
    return_dict_in_generate=True,
    output_logits=True,
    past_key_values=copy.deepcopy(prompt_cache),
    use_cache=True,
)

In [None]:
%%time 
# Timed repeat with another copy of prompt_cache (presumably still the empty
# StaticCache from the previous cell when cells are run in order); the deepcopy
# keeps prompt_cache itself untouched.
generation_kwargs = dict(
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    top_p=top_p,
    output_scores=True,
    return_dict_in_generate=True,
    output_logits=True,
    use_cache=True,
)
post_output = model.generate(
    **inputs, past_key_values=copy.deepcopy(prompt_cache), **generation_kwargs
)

In [None]:
%%time

# One forward pass over the complete prompt (no generation); only the logits
# are kept, moved to the CPU.
prompt = generate_prompt(p["instruction"], p["question"], p["input"])
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    pre_output = model(**inputs, use_cache=use_cache).logits
pre_output = pre_output.detach().cpu()

In [None]:
inputs["input_ids"].shape  # last dim = token count of the full prompt

In [None]:
1 << 14  # == 2**14 == 16384; NOTE(review): presumably a scratch check of the max_cache_len used for the StaticCache prefill

In [None]:
# Last character of the prompt, i.e. the one dropped when building prompt_.
prompt[len(prompt) - 1]

In [None]:
%%time
# Prefill a StaticCache with the prompt minus its final character, for reuse
# by the generate() call in the next cell. copy/torch/transformers names come
# from the import cell at the top of the notebook.
# NOTE(review): get_backend is not used by any visible cell; kept in case a
# later cell relies on it — remove if not.
from accelerate.test_utils.testing import get_backend

# The cache must match the model's device and dtype (the model is loaded in
# bfloat16 on the configured CUDA device); a cache created on "mps" in float16
# does not match it.
prompt_cache = StaticCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=2**14,
    device=model.device,
    dtype=model.dtype,
)

prompt = generate_prompt(p["instruction"], p["question"], p["input"])
prompt_ = prompt[:-1]

inputs = tokenizer(prompt_, return_tensors="pt").to(device)

with torch.no_grad():
    pre_output = model(**inputs, past_key_values=prompt_cache)
    prompt_cache = pre_output.past_key_values

In [None]:
%%time
# Generate one token from the full prompt on top of the prefilled cache.
# Inputs go to the model's device (not a hard-coded "mps"). The cache is
# deep-copied so re-running this cell starts from the same prefix state rather
# than from a cache already extended by a previous run; note the copy is
# included in the %%time measurement.
new_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **new_inputs, past_key_values=copy.deepcopy(prompt_cache), max_new_tokens=1
)