# Run an experiment mixing StreamingLLM and Quantization
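- Llama 3.1 8B still generates reasonable text when KV-cache compression (KnormPress, `compression_ratio=0.25`) is combined with a 4-bit quantized cache, though the output degrades slightly.
- Qwen 2.5 1.5B seems to perform badly once the cache is quantized, even without compression (`compression_ratio=0.0`): it produces repetitive, incoherent text.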

```python
import os

# Reduce compilation time on H100: compile CUDA extensions for its architecture (sm_90) only
os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
```
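The cell above hard-codes the H100 architecture (compute capability 9.0). A minimal sketch, assuming only that PyTorch is installed, derives the value from whichever GPU is visible instead; `arch_list_for` is a hypothetical helper, and the variable still has to be set before any CUDA extension is compiled for it to have an effect.

```python
import os

import torch


def arch_list_for(capability: tuple[int, int]) -> str:
    """Format a (major, minor) compute capability as a TORCH_CUDA_ARCH_LIST entry."""
    major, minor = capability
    return f"{major}.{minor}"


# Only touch the variable when a GPU is visible, and never override a value
# the user already exported (e.g. "9.0" for the H100 used in this notebook).
if torch.cuda.is_available():
    os.environ.setdefault(
        "TORCH_CUDA_ARCH_LIST", arch_list_for(torch.cuda.get_device_capability(0))
    )
```

Using `setdefault` keeps an explicitly exported value rather than silently replacing it.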

```python
import torch
from transformers import QuantizedCacheConfig, QuantoQuantizedCache, pipeline

# Importing kvpress also registers the "kv-press-text-generation" pipeline with transformers
from kvpress import KnormPress
```

## Llama 3.1 8B

```python
# Load pipeline
device = "cuda:0"
# ckpt = "Qwen/Qwen2.5-1.5B-Instruct"
ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"
attn_implementation = "flash_attention_2"
pipe = pipeline(
    "kv-press-text-generation",
    model=ckpt,
    device=device,
    torch_dtype="auto",
    model_kwargs={"attn_implementation": attn_implementation},
)
```


```text
Device set to use cuda:0
```


```python
# Load data
context = "In this step-by-step guide, you will learn how to create a new press in kvpress !"
question = "\nWhat is the purpose of this guide?"
tokens = pipe.tokenizer(context, return_tensors="pt").to(device)
```

```python
compression_ratio = 0.25  # prune 25% of the context KV pairs
press = KnormPress(compression_ratio=compression_ratio)

# Forward pass on the context without any press
with torch.no_grad():
    outputs_without_press = pipe.model(**tokens)

# Same forward pass with the press hooks registered on the model
with torch.no_grad(), press(pipe.model):
    output_with_press = pipe.model(**tokens)

# First layer's keys: (batch, num_kv_heads, seq_len, head_dim)
print(f"Cache shape w/o press: {outputs_without_press.past_key_values[0][0].shape}")
print(f"Cache shape w/ press:  {output_with_press.past_key_values[0][0].shape}\n")

# The `KVPressTextGenerationPipeline` simply applies the `press` as above on the context tokens (see `_forward` method for more details).
print(pipe(context, question=question, press=press)["answer"])
```

```text
Cache shape w/o press: torch.Size([1, 8, 21, 128])
Cache shape w/ press:  torch.Size([1, 8, 15, 128])

The purpose of this guide is to walk you through the process of creating a new press in kvpress, which is likely a content management system or a plugin for managing content. The guide will provide step-by-step instructions on how to create a new press
```
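Both shapes are consistent with keeping `int(21 * (1 - 0.25)) = 15` of the 21 context tokens in every KV head. The sketch below reproduces that arithmetic on random tensors with a KnormPress-style rule: score each token by the negative L2 norm of its key and keep the top scorers per head. It is our reading of the technique written with plain `torch` calls; `knorm_prune` is a hypothetical helper, not the kvpress implementation.

```python
import torch


def knorm_prune(keys: torch.Tensor, values: torch.Tensor, compression_ratio: float):
    """Hypothetical KnormPress-style pruning.

    keys, values: (batch, num_kv_heads, seq_len, head_dim)
    """
    seq_len = keys.shape[2]
    n_kept = int(seq_len * (1 - compression_ratio))
    # Score = negative key norm, so tokens with small-norm keys rank highest.
    scores = -keys.norm(dim=-1)                   # (batch, heads, seq_len)
    idx = scores.topk(n_kept, dim=-1).indices     # (batch, heads, n_kept)
    idx = idx.unsqueeze(-1).expand(-1, -1, -1, keys.shape[-1])
    # Gather the same token positions from keys and values, independently per head.
    return keys.gather(2, idx), values.gather(2, idx)


torch.manual_seed(0)
keys = torch.randn(1, 8, 21, 128)    # same shape as the Llama cache printed above
values = torch.randn(1, 8, 21, 128)
kept_keys, kept_values = knorm_prune(keys, values, compression_ratio=0.25)
print(kept_keys.shape)  # torch.Size([1, 8, 15, 128])
```

With `compression_ratio=0.0`, `n_kept` equals the sequence length and the cache keeps every token, which matches the unchanged shapes in the Qwen section.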


```python
# 4-bit quantized KV cache (quanto backend), combined with the same KnormPress
config = QuantizedCacheConfig(nbits=4)
cache = QuantoQuantizedCache(config)
pipe(context, question=question, press=press, cache=cache)["answer"]
```

```text
'The purpose of this guide is to walk you through the process of creating a new press in kvpress, which is a content management system (CMS) used for managing and publishing content on websites. \n\nBy following this guide, you will be able to'
```
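To see why combining the two techniques is attractive, here is back-of-the-envelope arithmetic for the Llama cache above. It assumes 32 decoder layers (not shown in this notebook), the 8 KV heads and head dimension 128 printed above, and 16-bit values for the unquantized cache. It ignores the per-group scales and zero-points a quantized cache must store, and any recent tokens it keeps in full precision, so real savings are smaller. `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(n_tokens: int, bits: int = 16, n_layers: int = 32,
                   n_kv_heads: int = 8, head_dim: int = 128) -> int:
    """Idealized KV-cache size in bytes (assumed Llama 3.1 8B geometry)."""
    # Factor 2 = one key and one value vector per (layer, head, token).
    return 2 * n_layers * n_kv_heads * head_dim * n_tokens * bits // 8


n_context = 21
n_kept = int(n_context * (1 - 0.25))  # 15 tokens left after 25% compression

full = kv_cache_bytes(n_context)              # 16-bit, no compression
pruned = kv_cache_bytes(n_kept)               # 16-bit, compressed
pruned_4bit = kv_cache_bytes(n_kept, bits=4)  # compressed and 4-bit quantized
print(full, pruned, pruned_4bit)              # 2752512 1966080 491520
print(f"{pruned_4bit / full:.3f}")            # 0.179
```

The two reductions multiply: under these assumptions the compressed, quantized cache needs under a fifth of the original bytes.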

## Qwen 2.5 1.5B: does quantization degrade performance?

```python
device = "cuda:0"
ckpt = "Qwen/Qwen2.5-1.5B-Instruct"
attn_implementation = "flash_attention_2"
pipe = pipeline(
    "kv-press-text-generation",
    model=ckpt,
    device=device,
    torch_dtype="auto",
    model_kwargs={"attn_implementation": attn_implementation},
)
```

```text
Device set to use cuda:0
```


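```python
# Re-tokenize the context: `tokens` was created with the Llama tokenizer in the section above
tokens = pipe.tokenizer(context, return_tensors="pt").to(device)

compression_ratio = 0.0  # no compression, to isolate the effect of quantization
press = KnormPress(compression_ratio=compression_ratio)

# Forward pass on the context without any press
with torch.no_grad():
    outputs_without_press = pipe.model(**tokens)

# Same forward pass with the press hooks registered on the model
with torch.no_grad(), press(pipe.model):
    output_with_press = pipe.model(**tokens)

print(f"Cache shape w/o press: {outputs_without_press.past_key_values[0][0].shape}")
print(f"Cache shape w/ press:  {output_with_press.past_key_values[0][0].shape}\n")

# The `KVPressTextGenerationPipeline` simply applies the `press` as above on the context tokens (see `_forward` method for more details).
print(pipe(context, question=question, press=press)["answer"])
```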

```text
Cache shape w/o press: torch.Size([1, 2, 21, 128])
Cache shape w/ press:  torch.Size([1, 2, 21, 128])

The purpose of this step-by-step guide is to provide a comprehensive and easy-to-follow tutorial on how to create a new press in the KVPress platform. The guide is designed to help users understand the process of creating a new press, including the necessary
```


```python
# Same 4-bit quantized cache as for Llama, with compression disabled
config = QuantizedCacheConfig(nbits=4)
cache = QuantoQuantizedCache(config)
print(pipe(context, question=question, press=press, cache=cache)["answer"])
```

```text
. a a a a.... created created created learn learn learn learn learn user user created created created created created created in created in created. is created created created created created created created learn learn user guide is to create created created created created this
```
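We have not diagnosed this failure. One hypothesis worth testing: low-bit KV quantization is sensitive to key channels whose magnitude is far larger than the rest, a pattern reported for several LLMs in the KV-cache quantization literature; whether Qwen 2.5 1.5B has such channels is not checked in this notebook. The sketch below uses synthetic data only. It applies a generic asymmetric min-max quantizer (not quanto's kernel) once with per-token statistics and once with per-channel statistics, to show how a single outlier channel can swamp the 4-bit grid for every other channel. The offset of 200 and the helper `fake_quantize_along` are arbitrary, hypothetical choices.

```python
import torch


def fake_quantize_along(x: torch.Tensor, dim: int, nbits: int = 4) -> torch.Tensor:
    """Asymmetric min-max quantize-then-dequantize, one (scale, zero) pair per slice along `dim`."""
    lo = x.amin(dim=dim, keepdim=True)
    hi = x.amax(dim=dim, keepdim=True)
    levels = 2**nbits - 1
    scale = (hi - lo).clamp(min=1e-8) / levels
    q = ((x - lo) / scale).round().clamp(0, levels)  # integer codes in [0, levels]
    return q * scale + lo


torch.manual_seed(0)
keys = torch.randn(21, 128)   # synthetic (seq_len, head_dim) keys for one head
keys[:, 7] += 200.0           # hypothetical outlier channel

# Per-token: statistics over head_dim, so the outlier sets every token's scale.
per_token = fake_quantize_along(keys, dim=-1)
# Per-channel: statistics over tokens, so the outlier only affects its own channel.
per_channel = fake_quantize_along(keys, dim=0)

normal = torch.ones(128, dtype=torch.bool)
normal[7] = False
err_token = (per_token - keys)[:, normal].abs().mean().item()
err_channel = (per_channel - keys)[:, normal].abs().mean().item()
print(f"mean abs error on normal channels: per-token {err_token:.3f}, per-channel {err_channel:.3f}")
```

On the real model, `outputs_without_press.past_key_values[0][0].abs().amax(dim=(0, 2))` gives the largest first-layer key magnitude per (head, channel), a quick way to look for such channels before blaming the quantizer.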
