In [10]:
# https://github.com/huggingface/transformers/issues/24841

1

In [1]:
import torch
import torch.nn.functional as F
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
from tqdm import tqdm
import gc
import os
import time

# run params
temperature=1.0
top_p=0.9
max_new_tokens=64
use_cache=True
device = "mps"

model_name = "osunlp/TableLlama"

# load inputs
file_path = "turl_test_2k_prompts_50.jsonl"
device = torch.device(device)

with open(file_path, "r", encoding="utf-8") as f:
    prompts = [json.loads(line) for line in f]

config = transformers.AutoConfig.from_pretrained(model_name)
orig_ctx_len = getattr(config, "max_position_embeddings", None)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
model.resize_token_embeddings(32001)
tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=orig_ctx_len, padding_side="left", use_fast=False)
model.eval()

# build prompts
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input_seg}\n\n### Question:\n{question}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}

def generate_prompt(instruction, question, input_seg=None):
    question += " Answer with just a candidate, selected from the provided referent entity candidates list, and nothing else. The selected candidate must be reported verbatim from the list provided as input. Each candidate in the list is enclosed between < and > and reports [DESC] and [TYPE] information."
    if input_seg:
        return PROMPT_DICT["prompt_input"].format(instruction=instruction, input_seg=input_seg, question=question)
    else:
        return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# p = prompts[742] # 742 is shortest
p = prompts[840] # average joe

In [6]:
%%time

import copy
import torch
from transformers import StaticCache

# Initialize prompt cache
prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=2**14, device='mps', dtype=torch.float16)

# Generate initial prompt input
prompt = generate_prompt(p["instruction"], p["question"], p["input"])
inputs = tokenizer(prompt, return_tensors="pt").to(device)
prompt_ = prompt[:-1]
inputs_ = tokenizer(prompt_, return_tensors="pt").to(device)

CPU times: user 16.6 ms, sys: 920 ms, total: 937 ms
Wall time: 1.05 s


In [4]:
%%time
# as it were in the script
with torch.no_grad():
    pre_output = model(**inputs, use_cache=True)

CPU times: user 201 ms, sys: 614 ms, total: 815 ms
Wall time: 2.58 s


In [5]:
%%time
# as it were in the script

post_output = model.generate(
    **inputs,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    top_p=top_p,
    output_scores=True,
    return_dict_in_generate=True,
    output_logits=True,
    use_cache=True
)

From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.


CPU times: user 1.81 s, sys: 1.15 s, total: 2.95 s
Wall time: 5.82 s


In [6]:
%%time 
# No cache
post_output = model.generate(
    **inputs,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    top_p=top_p,
    output_scores=True,
    return_dict_in_generate=True,
    output_logits=True,
    use_cache=False 
)

CPU times: user 4.3 s, sys: 6.45 s, total: 10.7 s
Wall time: 1min


In [31]:
2**12

4096

In [8]:
%%time

prompt_cache = StaticCache(config=model.config, max_batch_size=1, 
                           max_cache_len = 2**12, 
                           device='mps', dtype=torch.bfloat16)

# Generate cache
with torch.no_grad():
    pre_output = model(**inputs_, past_key_values=prompt_cache, use_cache=False)
    prompt_cache = pre_output.past_key_values

CPU times: user 128 ms, sys: 656 ms, total: 784 ms
Wall time: 4.55 s


In [11]:
%%time 
# With cache
post_output = model.generate(
    **inputs,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    top_p=top_p,
    output_scores=True,
    return_dict_in_generate=True,
    output_logits=True,
    past_key_values=copy.deepcopy(prompt_cache),
    use_cache=True 
)

CPU times: user 1.32 s, sys: 905 ms, total: 2.23 s
Wall time: 6.64 s


In [10]:
%%time 
# With cache
post_output = model.generate(
    **inputs,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    top_p=top_p,
    output_scores=True,
    return_dict_in_generate=True,
    output_logits=True,
    past_key_values=copy.deepcopy(prompt_cache),
    use_cache=False 
)

CPU times: user 3.39 s, sys: 5.9 s, total: 9.3 s
Wall time: 1min 32s


In [12]:
%%time 
# Reset cache

prompt_cache = StaticCache(config=model.config, max_batch_size=1, 
                           max_cache_len = 2**12, 
                           device='mps', dtype=torch.bfloat16)

post_output = model.generate(
    **inputs,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    top_p=top_p,
    output_scores=True,
    return_dict_in_generate=True,
    output_logits=True,
    past_key_values=copy.deepcopy(prompt_cache),
    use_cache=True 
)

CPU times: user 1.45 s, sys: 1.54 s, total: 2.99 s
Wall time: 10.1 s


In [8]:
%%time 
# With cache
post_output = model.generate(
    **inputs,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    top_p=top_p,
    output_scores=True,
    return_dict_in_generate=True,
    output_logits=True,
    past_key_values=copy.deepcopy(prompt_cache),
    use_cache=True 
)

CPU times: user 1.41 s, sys: 7.7 s, total: 9.11 s
Wall time: 24.1 s


In [4]:
%%time

prompt = generate_prompt(p["instruction"], p["question"], p["input"])
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    pre_output = model(**inputs, use_cache=use_cache)
pre_output = pre_output.logits.cpu().detach()

CPU times: user 308 ms, sys: 650 ms, total: 958 ms
Wall time: 5.64 s


In [5]:
inputs.input_ids.shape

torch.Size([1, 1073])

In [6]:
2**14

16384

In [10]:
prompt[-1]

':'

In [18]:
%%time
import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, StaticCache
from accelerate.test_utils.testing import get_backend

prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=2**14, device='mps', dtype=torch.float16)

prompt = generate_prompt(p["instruction"], p["question"], p["input"])
prompt_ = prompt[:-1]

inputs = tokenizer(prompt_, return_tensors="pt").to(device)

with torch.no_grad():
    pre_output = model(**inputs, past_key_values = prompt_cache)
    prompt_cache = pre_output.past_key_values

CPU times: user 199 ms, sys: 1.16 s, total: 1.36 s
Wall time: 9.59 s


In [19]:
%%time
new_inputs = tokenizer(prompt, return_tensors="pt").to('mps')
outputs = model.generate(**new_inputs, past_key_values=prompt_cache, max_new_tokens=1)

CPU times: user 98.4 ms, sys: 367 ms, total: 466 ms
Wall time: 3.77 s
