In [20]:
!pip install gradio torch transformers accelerate -q

import gradio as gr
import torch
import time
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "Qwen/Qwen2.5-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Fall back to EOS for padding only when the tokenizer defines no pad token,
# so a pad token that ships with the checkpoint is not overwritten.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Load the weights in half precision; device_map="auto" delegates device
# placement to the library.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)


Loading checkpoint shards: 100%|██████████| 2/2 [00:12<00:00,  6.32s/steps]


In [21]:

def generate(text, max_tokens, use_cache):
    """Sample a continuation of ``text`` and report latency and peak GPU memory.

    Args:
        text: Prompt string to continue.
        max_tokens: Upper bound on newly generated tokens. Coerced with
            ``int()``, so a float coming from a UI slider is accepted.
        use_cache: Forwarded unchanged to ``model.generate(use_cache=...)``.

    Returns:
        A ``(generation, performance)`` pair of strings. ``generation`` is the
        decoded output with the prompt tokens sliced off and special tokens
        skipped. ``performance`` has the form ``"Time: 1.23s | Mem: 4.56GB"``;
        the memory figure is 0 when CUDA is unavailable. For a blank prompt the
        model is not called and a short hint is returned instead.
    """
    if not text or not text.strip():
        # Nothing to continue from; skip the model call.
        return "", "Enter a prompt to generate."

    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    on_gpu = torch.cuda.is_available()
    if on_gpu:
        torch.cuda.empty_cache()
        # max_memory_allocated() is a running peak. Without a reset, every call
        # after the first would report the largest peak seen so far, hiding the
        # difference between cache on and cache off.
        torch.cuda.reset_peak_memory_stats()
        # Drain any queued GPU work so it is not billed to this run.
        torch.cuda.synchronize()

    start = time.perf_counter()
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),
            do_sample=True,
            temperature=0.7,
            use_cache=use_cache,
            pad_token_id=tokenizer.eos_token_id,
        )
    if on_gpu:
        # CUDA kernels run asynchronously; wait for them before stopping the clock.
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    mem_peak = torch.cuda.max_memory_allocated() / 1e9 if on_gpu else 0
    # Keep only the newly generated tokens (everything after the prompt).
    result = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

    return result, f"Time: {elapsed:.2f}s | Mem: {mem_peak:.2f}GB"

In [22]:
# Web UI around generate(): three inputs in the order of its parameters
# (prompt, token budget, cache toggle) and two text outputs matching its
# (generation, performance) return pair.
demo = gr.Interface(
    title="KV Cache Performance Demo",
    description="Toggle cache to see speedup/memory diff!",
    fn=generate,
    inputs=[
        gr.Textbox(value="Explain KV-cache", label="Prompt"),
        gr.Slider(minimum=10, maximum=100, value=30, step=10, label="Max Tokens"),
        gr.Checkbox(value=True, label="Use KV-Cache"),
    ],
    outputs=[
        gr.Textbox(lines=6, label="Generation"),
        gr.Textbox(label="Performance"),
    ],
)

# share=True also publishes a temporary public URL; debug=True keeps this cell
# running until the server is interrupted.
demo.launch(debug=True, share=True)

* Running on local URL:  http://127.0.0.1:7862
* Running on public URL: https://52c83927767c58aad7.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7862 <> https://52c83927767c58aad7.gradio.live


