In [1]:
import os

# Authentication: huggingface_hub reads HF_TOKEN from the environment (or the
# credentials cached by `huggingface-cli login`), so no token is stored in the notebook.
# SECURITY: a token used to be hardcoded on this line — treat it as leaked and revoke it.
if not os.environ.get("HF_TOKEN"):
    import warnings

    warnings.warn(
        "HF_TOKEN is not set; gated models such as google/gemma-2-2b-it need "
        "`huggingface-cli login` or an HF_TOKEN environment variable."
    )

import copy
import json
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, StaticCache

In [2]:
# run params
# NOTE(review): temperature, top_p, max_new_tokens and use_cache are not passed to the
# model.generate() calls in the benchmark cells (those hardcode max_new_tokens=32 and
# their own use_cache) — wire them in or delete them.
temperature = 1.0
top_p = 0.9
max_new_tokens = 64
use_cache = True
device = "cuda:0"

# model_name = "osunlp/TableLlama"
model_name = "google/gemma-2-2b-it"

# load inputs: one JSON object per line; blank lines (e.g. a trailing newline) are skipped
file_path = "turl_test_2k_prompts_50.jsonl"
device = torch.device(device)

with open(file_path, "r", encoding="utf-8") as f:
    prompts = [json.loads(line) for line in f if line.strip()]

# Load tokenizer and model from the configured name/device, so changing
# `model_name` or `device` above actually takes effect.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map=device, torch_dtype=torch.bfloat16
)
model.eval()

# build prompts
# Alpaca-style prompt templates; the input variant adds a "### Question:" section.
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input_seg}\n\n### Question:\n{question}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}

# Answer-format constraint appended (after a space) to every question by default.
ANSWER_INSTRUCTION = (
    "Answer with just a candidate, selected from the provided referent entity candidates list, "
    "and nothing else. The selected candidate must be reported verbatim from the list provided "
    "as input. Each candidate in the list is enclosed between < and > and reports [DESC] and "
    "[TYPE] information."
)


def generate_prompt(instruction, question, input_seg=None, answer_instruction=ANSWER_INSTRUCTION):
    """Build the full model prompt for one example.

    Args:
        instruction: Task instruction placed under "### Instruction:".
        question: Question placed under "### Question:", followed by a space and
            ``answer_instruction``.
        input_seg: Context placed under "### Input:". When falsy, the no-input
            template is used; that template has no question section, so
            ``question`` is not included in the result.
        answer_instruction: Text appended to the question; pass "" to append nothing.

    Returns:
        The formatted prompt string; both templates end with "### Response:".
    """
    if not input_seg:
        # NOTE(review): the no-input template drops the question entirely — confirm intended.
        return PROMPT_DICT["prompt_no_input"].format(instruction=instruction)
    full_question = f"{question} {answer_instruction}" if answer_instruction else question
    return PROMPT_DICT["prompt_input"].format(
        instruction=instruction, input_seg=input_seg, question=full_question
    )

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# Example to benchmark: index 840 has a typical ("average") prompt length;
# index 742 is the shortest prompt in the file.
example_idx = 840
p = prompts[example_idx]

In [4]:
# Tokenize the full prompt for the selected example.
prompt = generate_prompt(p["instruction"], p["question"], p["input"])
inputs = tokenizer(prompt, return_tensors="pt").to(device)

# Prefix used to pre-fill the KV cache: every token of `inputs` except the last.
# Slicing the token tensors (rather than the prompt text and re-tokenizing)
# guarantees the cached tokens are an exact prefix of `inputs`, which cache reuse
# requires, and leaves one token for generate() to process itself.
inputs_ = {name: tensor[:, :-1] for name, tensor in inputs.items()}
# Text of the cached prefix, kept for inspection only.
prompt_ = tokenizer.decode(inputs_["input_ids"][0], skip_special_tokens=True)

### Generate without a KV cache (baseline)

In [None]:
# Baseline: generate `n_runs` responses with the KV cache disabled.
# Named `n_runs` (not `prompts`) so the dataset list loaded earlier is not overwritten.
n_runs = 5
responses = []
tic = time.perf_counter()
with torch.no_grad():
    for _ in range(n_runs):
        outputs = model.generate(
            **inputs,
            # NOTE(review): matches the cached run's hardcoded 32; the config
            # cell's `max_new_tokens` (64) is unused — unify if intended.
            max_new_tokens=32,
            cache_implementation=None,
            disable_compile=True,
            return_dict_in_generate=True,
            use_cache=False,
        )
        responses.append(tokenizer.batch_decode(outputs["sequences"])[0])
toc = time.perf_counter()
print(f"Elapsed time: {toc - tic:.4f} seconds")

for r in responses:
    print(r)

### Generate with a reused prompt-prefix KV cache

In [5]:
# Fresh, empty KV cache that the next cell fills with the shared prompt prefix.
# To try a preallocated cache instead, use e.g.
#   StaticCache(config=model.config, max_batch_size=1, max_cache_len=4096,
#               device="cuda:0", dtype=torch.bfloat16)
# with max_cache_len >= prefix length + number of generated tokens.
prompt_cache = DynamicCache()

In [6]:
# Prompt-reuse run: pre-fill the KV cache once with the shared prefix (`inputs_`),
# then give each generate() call its own deep copy so the prefix is not recomputed.
# Named `n_runs` (not `prompts`) so the dataset list loaded earlier is not overwritten.
n_runs = 5
responses = []

# The forward pass below fills `prompt_cache` in place, so it must start empty;
# re-running this cell without re-creating the cache would append the prefix twice.
if prompt_cache.get_seq_length() > 0:
    raise RuntimeError(
        "prompt_cache already holds a prefix; re-run the cell that creates it first."
    )

tic = time.perf_counter()  # timing includes the one-off prefix forward pass
# The common prompt prefix is cached without grad so the cache can be deep-copied.
with torch.no_grad():
    prompt_cache = model(
        **inputs_, use_cache=True, past_key_values=prompt_cache
    ).past_key_values
    for _ in range(n_runs):
        # Each generate() call extends the cache it receives, so it gets its own copy.
        past_key_values = copy.deepcopy(prompt_cache)
        outputs = model.generate(
            **inputs,
            past_key_values=past_key_values,
            max_new_tokens=32,
            cache_implementation=None,
            disable_compile=True,
            return_dict_in_generate=True,
            use_cache=True,
        )
        responses.append(tokenizer.batch_decode(outputs["sequences"])[0])
toc = time.perf_counter()
print(f"Elapsed time: {toc - tic:.4f} seconds")

for r in responses:
    print(r)

Elapsed time: 8.5105 seconds
<bos>Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
This is an entity linking task. The goal for this task is to link the selected entity mention in the table cells to the entity in the knowledge base. You will be given a list of referent entities, with each one composed of an entity name, its description and its type. Please choose the correct one from the referent entity candidates. Note that the Wikipedia page, Wikipedia section and table caption (if any) provide important information for choosing the correct referent entity.

### Input:
[TLE] Qatar Stars League. Qatari Stars League member clubs [TAB] col: |club|location|stadium|year formed| row 0: |Al Ahli Sports Club|Doha|Khalifa International Stadium|1950| row 1: |Al-Arabi Doha Sports Club|Doha|Grand Hamad Stadium|1952| row 2: |Al-Gharafa Sports Club|Al-Gharafa, Al-Rayyan|Th