In [1]:
# Environment setup. Statement order matters here: the CUDA device mask is set
# before `import torch`, so do not reorder these lines.
import os
# Expose only GPU 0 to this process (CUDA reads this variable when it initializes).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import warnings
# Silence every Python warning for the rest of the process; this keeps the cell
# output clean but also hides genuine warnings.
warnings.filterwarnings("ignore")
import torch
# NOTE(review): argparse, json, time, re, sys, tqdm and download_url are not used
# in the cells visible here, and `os` is imported a second time below; left in
# place in case other cells rely on them.
import argparse
import json
import os
import time
import re
import sys

from tqdm import tqdm
# Project helpers from the streaming_llm package: `load` returns (model, tokenizer),
# `load_jsonl` reads the prompt file, `enable_streaming_llm` installs the KV cache.
from streaming_llm.utils import load, download_url, load_jsonl
from streaming_llm.enable_streaming_llm import enable_streaming_llm


@torch.no_grad()
def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
    """Greedily decode a reply and stream it to stdout word by word.

    The prompt ``input_ids`` is run through ``model`` once (extending
    ``past_key_values``). Tokens are then picked by argmax and fed back one at
    a time. Decoding stops when ``max_gen_len`` tokens have been produced or
    the tokenizer's EOS token is generated, including when EOS is the very
    first token. At least one token is always produced.

    Text is printed incrementally. After each token the whole generation is
    re-decoded, and every word except the last is flushed, because the last
    word may still be extended by the next token. The remainder is printed
    once decoding stops.

    Args:
        model: Callable causal LM accepting ``input_ids``, ``past_key_values``
            and ``use_cache`` keywords and returning an object with
            ``.logits`` and ``.past_key_values``.
        tokenizer: Tokenizer providing ``decode`` and ``eos_token_id``.
        input_ids: Prompt token ids. Batch size must be 1, since ``.item()``
            is called on the argmax; presumably shaped ``(1, seq_len)``.
        past_key_values: KV cache from earlier turns, or ``None``.
        max_gen_len: Maximum number of tokens to generate. Values below 1
            behave like 1.

    Returns:
        The updated KV cache. It covers the prompt and every generated token
        except the last one, which is never fed back into the model.
    """

    def _decode_words(ids):
        # Decode the whole generation so far and split it into words.
        return (
            tokenizer.decode(
                ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
                spaces_between_special_tokens=False,
            )
            .strip()
            .split(" ")
        )

    outputs = model(
        input_ids=input_ids,
        past_key_values=past_key_values,
        use_cache=True,
    )
    past_key_values = outputs.past_key_values
    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    generated_ids = [pred_token_idx.item()]
    pos = 0  # number of words already printed

    for _ in range(max_gen_len - 1):
        # Stop as soon as EOS has been produced -- checked before feeding the
        # token back, so an EOS as the very first token also ends generation.
        if generated_ids[-1] == tokenizer.eos_token_id:
            break
        outputs = model(
            input_ids=pred_token_idx,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values
        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        generated_ids.append(pred_token_idx.item())

        words = _decode_words(generated_ids)
        # Hold back the last word: the next token may still extend it.
        now = len(words) - 1
        if now > pos:
            print(" ".join(words[pos:now]), end=" ", flush=True)
            pos = now

    # Flush whatever has not been printed yet. Decoding here, instead of
    # reusing a loop-local variable, also works when the loop body never ran.
    print(" ".join(_decode_words(generated_ids)[pos:]), flush=True)
    return past_key_values


@torch.no_grad()
def streaming_inference(model, tokenizer, prompts, kv_cache=None, max_gen_len=1000):
    """Answer ``prompts`` in order as one continuous chat, reusing the KV cache.

    Every prompt is wrapped as ``"USER: <prompt>\\n\\nASSISTANT: "``, echoed to
    stdout and tokenized. It is then passed to ``greedy_generate`` together
    with the cache returned by the previous turn, so later turns are decoded
    on top of earlier ones.

    Args:
        model: Causal LM; prompt ids are moved to ``model.device``.
        tokenizer: Callable tokenizer returning an object with ``input_ids``.
        prompts: Iterable of user-turn strings.
        kv_cache: Optional cache manager (e.g. the object returned by
            ``enable_streaming_llm``). When given,
            ``kv_cache.evict_for_space`` is called before every turn with the
            prompt length plus ``max_gen_len``. Presumably it evicts entries
            so that many new tokens fit; see the streaming_llm package.
        max_gen_len: Maximum number of tokens generated per turn.

    Returns:
        None; the conversation is only printed.
    """
    cache = None
    for user_turn in prompts:
        chat_prompt = "USER: " + user_turn + "\n\nASSISTANT: "
        print("\n" + chat_prompt, end="")
        prompt_ids = tokenizer(chat_prompt, return_tensors="pt").input_ids.to(model.device)

        if kv_cache is not None:
            # Reserve room for this prompt plus the longest possible reply.
            budget = prompt_ids.shape[1] + max_gen_len
            cache = kv_cache.evict_for_space(cache, budget)

        cache = greedy_generate(
            model, tokenizer, prompt_ids, cache, max_gen_len=max_gen_len
        )


In [2]:
# Load Vicuna-13B, read the multi-turn prompts, and answer them as one streaming chat.
# The checkpoint can be overridden with the STREAMING_LLM_MODEL_PATH environment
# variable. The default is a local mirror of the Hugging Face model lmsys/vicuna-13b-v1.5.
MODEL_PATH = os.environ.get(
    "STREAMING_LLM_MODEL_PATH",
    "/data/share_user/zzd/ckpt/Xorbits/vicuna-13b-v1.5",
)
# KV-cache sizes passed to enable_streaming_llm. Presumably the cache keeps the
# first START_SIZE tokens plus the RECENT_SIZE most recent ones (the
# StartRecentKVCache reported in the output); see the streaming_llm package.
START_SIZE = 4
RECENT_SIZE = 2000

model, tokenizer = load(MODEL_PATH)
# Resolved relative to the notebook's current working directory.
test_filepath = "streamingllm_mt_data.jsonl"

list_data = load_jsonl(test_filepath)
# Each sample contributes its "turns"; flatten them into a single prompt stream.
prompts = [turn for sample in list_data for turn in sample["turns"]]

kv_cache = enable_streaming_llm(model, start_size=START_SIZE, recent_size=RECENT_SIZE)
streaming_inference(model, tokenizer, prompts, kv_cache)

Loading checkpoint shards: 100%|██████████| 3/3 [00:15<00:00,  5.26s/it]


StartRecentKVCache: 4, 2000

USER: Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.

ASSISTANT: 🌴🌊🏝️ Hawaii: A Cultural Paradise 🌺🌞🏄‍♀️

Hawaii, the land of Aloha, is a tropical paradise that offers a unique blend of natural beauty, rich culture, and endless adventures. Recently, I had the pleasure of visiting this enchanting archipelago, and I must say, it exceeded all my expectations. From the stunning beaches to the vibrant local traditions, Hawaii is a true gem that every traveler should experience at least once in their lifetime.

🌊 Surf's Up in Waikiki 🌊

No trip to Hawaii is complete without spending some time in Waikiki, the iconic beach town that's home to some of the world's most famous surf spots. I was thrilled to take a surfing lesson with a local company that not only taught me the basics but also gave me a glimpse into the Hawaiian surf culture. It was amazing to see how surfing is not just a 