12 changes: 6 additions & 6 deletions csrc/activation_kernels.cu
@@ -16,8 +16,8 @@ __global__ void silu_and_mul_kernel(
  scalar_t* __restrict__ out,          // [..., d]
  const scalar_t* __restrict__ input,  // [..., 2, d]
  const int d) {
-  const int token_idx = blockIdx.x;
-  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
    const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = silu(x) * y;
@@ -30,7 +30,7 @@ void silu_and_mul(
  torch::Tensor& out,    // [..., d]
  torch::Tensor& input)  // [..., 2 * d]
{
-  int num_tokens = input.numel() / input.size(-1);
+  int64_t num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1) / 2;

  dim3 grid(num_tokens);
@@ -55,8 +55,8 @@ __global__ void activation_kernel(
  scalar_t* __restrict__ out,          // [..., d]
  const scalar_t* __restrict__ input,  // [..., d]
  const int d) {
-  const int token_idx = blockIdx.x;
-  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = __ldg(&input[token_idx * d + idx]);
    out[token_idx * d + idx] = ACT_FN(x);
  }
@@ -67,7 +67,7 @@ __global__ void activation_kernel(
// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                \
  int d = input.size(-1);                                               \
-  int num_tokens = input.numel() / d;                                  \
+  int64_t num_tokens = input.numel() / d;                              \
  dim3 grid(num_tokens);                                                \
  dim3 block(std::min(d, 1024));                                        \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();         \
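
The widening to `int64_t` above guards against 32-bit overflow in the flattened index arithmetic: `token_idx * 2 * d + idx` is computed in the index type and wraps once a tensor exceeds 2^31 elements. A quick check in Python, with illustrative shapes not taken from this PR:

    INT32_MAX = 2**31 - 1

    # Illustrative shapes: a large token batch times a wide gated MLP.
    num_tokens = 65_536
    d = 22_016  # e.g. the intermediate size of a 7B Llama-style model

    max_offset = num_tokens * 2 * d  # largest flattened index into `input`
    print(max_offset > INT32_MAX)    # True: a 32-bit index would wrap around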
2 changes: 1 addition & 1 deletion csrc/pos_encoding_kernels.cu
@@ -84,7 +84,7 @@ void rotary_embedding(
  int head_size,
  torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
  bool is_neox) {
-  int num_tokens = query.numel() / query.size(-1);
+  int64_t num_tokens = query.numel() / query.size(-1);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(-1) / head_size;
  int num_kv_heads = key.size(-1) / head_size;
6 changes: 6 additions & 0 deletions docs/source/index.rst
@@ -73,3 +73,9 @@ Documentation

   models/supported_models
   models/adding_model

.. toctree::
   :maxdepth: 1
   :caption: Quantization

   quantization/auto_awq
69 changes: 69 additions & 0 deletions docs/source/quantization/auto_awq.rst
@@ -0,0 +1,69 @@
.. _auto_awq:

AutoAWQ
==================

To create a new 4-bit quantized model, you can leverage `AutoAWQ <https://github.com/casper-hansen/AutoAWQ>`_.
Quantizing reduces the model's precision from FP16 to INT4, which effectively shrinks the file size by ~70%.
The main benefits are lower latency and memory usage.
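
As a rough sanity check on that figure (a sketch with illustrative numbers; real checkpoints carry extra overhead for group scales and zero points):

.. code-block:: python

params = 7e9                # a 7B-parameter model
fp16_gb = params * 2 / 1e9  # 2 bytes per FP16 weight -> ~14 GB
int4_gb = params / 2 / 1e9  # 4 bits per INT4 weight  -> ~3.5 GB
print(f"~{1 - int4_gb / fp16_gb:.0%} smaller")  # ~75% before overhead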

You can quantize your own model by installing AutoAWQ, or pick one of the `400+ pre-quantized models on Hugging Face <https://huggingface.co/models?sort=trending&search=awq>`_.

.. code-block:: console

$ pip install autoawq

After installing AutoAWQ, you are ready to quantize a model. Here is an example of how to quantize Vicuna 7B v1.5:

.. code-block:: python

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'lmsys/vicuna-7b-v1.5'
quant_path = 'vicuna-7b-v1.5-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(model_path, **{"low_cpu_mem_usage": True})
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
model.quantize(tokenizer, quant_config=quant_config)

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

To run an AWQ model with vLLM, you can use `TheBloke/Llama-2-7b-Chat-AWQ <https://huggingface.co/TheBloke/Llama-2-7b-Chat-AWQ>`_ with the following command:

.. code-block:: console

$ python examples/llm_engine_example.py --model TheBloke/Llama-2-7b-Chat-AWQ --quantization awq

AWQ models are also supported directly through the LLM entrypoint:

.. code-block:: python

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="TheBloke/Llama-2-7b-Chat-AWQ", quantization="AWQ")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
34 changes: 34 additions & 0 deletions tests/samplers/test_sampler.py
@@ -183,3 +183,37 @@ def test_sampler_mixed(seed: int):
                continue
            for nth_output in sequence_output.samples:
                assert nth_output.output_token in expected_tokens


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_logits_processors(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, worker = _prepare_test(batch_size)

    # This sample logits processor gives an infinite score to the i-th token,
    # where i is the number of tokens generated so far.
    # We therefore expect the output token sequence to be [0, 1, 2, ...].
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    seq_group_metadata_list = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0,
                                               logits_processors=[pick_ith]),
                block_tables={0: [1]},
            ))

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)
    for i, sequence_output in enumerate(sampler_output):
        for idx, nth_output in enumerate(sequence_output.samples):
            assert nth_output.output_token == idx
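
The same hook is exposed to end users through `SamplingParams(logits_processors=...)`. A minimal sketch of the user-facing API this PR adds (the model name is an illustrative placeholder):

    from vllm import LLM, SamplingParams

    def forbid_token_zero(token_ids, logits):
        # A logits processor receives the token ids generated so far and the
        # raw logits row for the next token, and returns the (possibly
        # modified) row.
        logits[0] = float("-inf")
        return logits

    llm = LLM(model="facebook/opt-125m")  # illustrative model choice
    params = SamplingParams(temperature=0.8,
                            logits_processors=[forbid_token_zero])
    outputs = llm.generate(["Hello, my name is"], params)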
3 changes: 3 additions & 0 deletions vllm/config.py
@@ -390,6 +390,9 @@ def _get_and_verify_max_len(
    if rope_scaling is not None:
        assert "factor" in rope_scaling
        scaling_factor = rope_scaling["factor"]
        if rope_scaling["type"] == "yarn":
            derived_max_model_len = rope_scaling[
                "original_max_position_embeddings"]
        derived_max_model_len *= scaling_factor

    if max_model_len is None:
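
For YaRN the usable context is thus derived from the original training length times the scaling factor, rather than from whatever maximum the checkpoint config reports. A worked example with illustrative numbers, not from a specific checkpoint:

    rope_scaling = {"type": "yarn", "factor": 4.0,
                    "original_max_position_embeddings": 4096}
    derived_max_model_len = rope_scaling["original_max_position_embeddings"]
    derived_max_model_len *= rope_scaling["factor"]  # 4096 * 4.0 -> 16384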
19 changes: 17 additions & 2 deletions vllm/model_executor/layers/attention.py
@@ -12,7 +12,7 @@
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.rotary_embedding import (
    DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
-    RotaryEmbedding)
+    RotaryEmbedding, YaRNScalingRotaryEmbedding)

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -156,7 +156,9 @@ def single_query_cached_kv_attention(
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO(woosuk): Tune this heuristic.
-        use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
+        # For context len > 8192, use V2 kernel to avoid shared memory shortage.
+        use_v1 = input_metadata.max_context_len <= 8192 and (
+            max_num_partitions == 1 or num_seqs * num_heads > 512)
        if use_v1:
            # Run PagedAttention V1.
            attention_ops.paged_attention_v1(
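
Restated as a standalone predicate, the kernel-selection heuristic in this hunk reads as follows (a sketch of the logic, not code from the tree):

    def should_use_paged_attention_v1(max_context_len: int,
                                      max_num_partitions: int,
                                      num_seqs: int,
                                      num_heads: int) -> bool:
        # V2 splits the context into partitions and reduces across them, so
        # it needs far less shared memory per block than V1 on long contexts.
        if max_context_len > 8192:
            return False
        # With a single partition, or enough blocks to fill the GPU, V1 is
        # preferable since it skips the reduction pass.
        return max_num_partitions == 1 or num_seqs * num_heads > 512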
@@ -332,6 +334,19 @@ def __init__(
            self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
                head_size, rotary_dim, max_position, base, is_neox_style,
                scaling_factor)
        elif scaling_type == "yarn":
            original_max_position = rope_scaling[
                "original_max_position_embeddings"]
            assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor",
                         "beta_fast", "beta_slow")
            }
            self.rotary_emb = YaRNScalingRotaryEmbedding(
                head_size, rotary_dim, original_max_position, base,
                is_neox_style, scaling_factor, **extra_kwargs)
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

104 changes: 104 additions & 0 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -21,6 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rotary Positional Embeddings."""
import math
from typing import Tuple, Union

import torch
@@ -167,3 +168,106 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache


# Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(num_rotations: int,
                              dim: int,
                              base: float = 10000,
                              max_position_embeddings: int = 2048) -> float:
    return (dim * math.log(max_position_embeddings /
                           (num_rotations * 2 * math.pi))) / (2 *
                                                              math.log(base))


# Find dim range bounds based on rotations
def _yarn_find_correction_range(low_rot: int,
                                high_rot: int,
                                dim: int,
                                base: float = 10000,
                                max_position_embeddings: int = 2048
                                ) -> Tuple[int, int]:
    low = math.floor(
        _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(
        _yarn_find_correction_dim(high_rot, dim, base,
                                  max_position_embeddings))
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
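
Plugging in LLaMA-style defaults makes the ramp boundaries concrete (illustrative values):

    # rotary_dim=128, base=10000, max_position_embeddings=2048,
    # beta_fast=32, beta_slow=1:
    low, high = _yarn_find_correction_range(32, 1, 128, 10000, 2048)
    print(low, high)  # 16 41: dims below 16 are extrapolated, dims above 41
                      # are interpolated, and the ramp blends in between.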


def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
                           dtype: torch.dtype,
                           device: torch.device) -> torch.Tensor:
    if low == high:
        high += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=dtype, device=device) -
                   low) / (high - low)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func


def _yarn_get_mscale(scale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0


class YaRNScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: float = 32,
        beta_slow: float = 1,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation
        self.mscale = float(
            _yarn_get_mscale(self.scaling_factor) * attn_factor)
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style)

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        pos_freqs = self.base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
                                self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                                self.rotary_dim, self.base,
                                                self.max_position_embeddings)
        # Get n-d rotational scaling corrected for extrapolation
        inv_freq_mask = (1 - _yarn_linear_ramp_mask(
            low, high, self.rotary_dim // 2, dtype=torch.float,
            device="cuda")) * self.extrapolation_factor
        inv_freq = inv_freq_interpolation * (
            1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                         device="cuda",
                         dtype=torch.float32)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = (freqs.cos() * self.mscale)
        sin = (freqs.sin() * self.mscale)
        cache = torch.cat((cos, sin), dim=-1)
        return cache
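
A minimal construction sketch for the new class (parameter values are illustrative; the cos/sin cache is built on "cuda", so a GPU is assumed):

    # A model trained at 4096 positions, extended 4x with YaRN to 16384.
    rope = YaRNScalingRotaryEmbedding(
        head_size=128,
        rotary_dim=128,
        max_position_embeddings=4096,  # original (pre-scaling) length
        base=10000,
        is_neox_style=True,
        scaling_factor=4.0,
    )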
24 changes: 24 additions & 0 deletions vllm/model_executor/layers/sampler.py
@@ -47,6 +47,8 @@ def forward(
        logits = _get_logits(hidden_states, embedding, embedding_bias,
                             self.vocab_size)

        # Apply logits processors (if any).
        logits = _apply_logits_processors(logits, input_metadata)
        # Apply presence and frequency penalties.
        output_tokens = _get_output_tokens(input_metadata)
        assert len(output_tokens) == logits.shape[0]
@@ -155,6 +157,28 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
    return output_tokens


def _apply_logits_processors(logits: torch.Tensor,
                             input_metadata: InputMetadata) -> torch.Tensor:
    logits_row_idx = 0
    found_logits_processors = False
    for seq_ids, sampling_params in input_metadata.seq_groups:
        logits_processors = sampling_params.logits_processors
        if logits_processors:
            found_logits_processors = True
            for seq_id in seq_ids:
                logits_row = logits[logits_row_idx]
                token_ids = input_metadata.seq_data[seq_id].output_token_ids
                for logits_processor in logits_processors:
                    logits_row = logits_processor(token_ids, logits_row)
                logits[logits_row_idx] = logits_row
                logits_row_idx += 1
        else:
            logits_row_idx += len(seq_ids)
    if found_logits_processors:
        assert logits_row_idx == logits.shape[0]
    return logits


def _apply_penalties(
    logits: torch.Tensor,
    output_tokens: List[List[int]],
4 changes: 2 additions & 2 deletions vllm/model_executor/model_loader.py
@@ -28,8 +28,8 @@
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
"MistralForCausalLM": MistralForCausalLM,
# transformers's mpt class has lower case
"MptForCausalLM": MPTForCausalLM,
"MPTForCausalLM": MPTForCausalLM,
"MptForCausalLM": MptForCausalLM,
"MPTForCausalLM": MptForCausalLM,
"OPTForCausalLM": OPTForCausalLM,
"QWenLMHeadModel": QWenLMHeadModel,
"RWForCausalLM": FalconForCausalLM,
4 changes: 2 additions & 2 deletions vllm/model_executor/models/__init__.py
@@ -10,7 +10,7 @@
from vllm.model_executor.models.internlm import InternLMForCausalLM
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.mistral import MistralForCausalLM
-from vllm.model_executor.models.mpt import MPTForCausalLM
+from vllm.model_executor.models.mpt import MptForCausalLM
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.models.qwen import QWenLMHeadModel

@@ -26,7 +26,7 @@
"GPTNeoXForCausalLM",
"InternLMForCausalLM",
"LlamaForCausalLM",
"MPTForCausalLM",
"MptForCausalLM",
"OPTForCausalLM",
"QWenLMHeadModel",
"MistralForCausalLM",