diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 581525e970ce..89d1ba2d37dd 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -16,8 +16,8 @@ __global__ void silu_and_mul_kernel( scalar_t* __restrict__ out, // [..., d] const scalar_t* __restrict__ input, // [..., 2, d] const int d) { - const int token_idx = blockIdx.x; - for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { + const int64_t token_idx = blockIdx.x; + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]); const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]); out[token_idx * d + idx] = silu(x) * y; @@ -30,7 +30,7 @@ void silu_and_mul( torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d] { - int num_tokens = input.numel() / input.size(-1); + int64_t num_tokens = input.numel() / input.size(-1); int d = input.size(-1) / 2; dim3 grid(num_tokens); @@ -55,8 +55,8 @@ __global__ void activation_kernel( scalar_t* __restrict__ out, // [..., d] const scalar_t* __restrict__ input, // [..., d] const int d) { - const int token_idx = blockIdx.x; - for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { + const int64_t token_idx = blockIdx.x; + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = __ldg(&input[token_idx * d + idx]); out[token_idx * d + idx] = ACT_FN(x); } @@ -67,7 +67,7 @@ __global__ void activation_kernel( // Launch element-wise activation kernel. #define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ int d = input.size(-1); \ - int num_tokens = input.numel() / d; \ + int64_t num_tokens = input.numel() / d; \ dim3 grid(num_tokens); \ dim3 block(std::min(d, 1024)); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 41001ba64746..0a5ec95f8c0d 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -84,7 +84,7 @@ void rotary_embedding( int head_size, torch::Tensor& cos_sin_cache, // [max_position, rot_dim] bool is_neox) { - int num_tokens = query.numel() / query.size(-1); + int64_t num_tokens = query.numel() / query.size(-1); int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(-1) / head_size; int num_kv_heads = key.size(-1) / head_size; diff --git a/docs/source/index.rst b/docs/source/index.rst index 60a5b07f32fb..eb98aa6049bf 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -73,3 +73,9 @@ Documentation models/supported_models models/adding_model + +.. toctree:: + :maxdepth: 1 + :caption: Quantization + + quantization/auto_awq \ No newline at end of file diff --git a/docs/source/quantization/auto_awq.rst b/docs/source/quantization/auto_awq.rst new file mode 100644 index 000000000000..0a2b44239998 --- /dev/null +++ b/docs/source/quantization/auto_awq.rst @@ -0,0 +1,69 @@ +.. _auto_awq: + +AutoAWQ +================== + +To create a new 4-bit quantized model, you can leverage `AutoAWQ `_. +Quantizing reduces the model's precision from FP16 to INT4 which effectively reduces the file size by ~70%. +The main benefits are lower latency and memory usage. + +You can quantize your own models by installing AutoAWQ or picking one of the `400+ models on Huggingface `_. + +.. code-block:: console + + $ pip install autoawq + +After installing AutoAWQ, you are ready to quantize a model. Here is an example of how to quantize Vicuna 7B v1.5: + +.. 
code-block:: python + + from awq import AutoAWQForCausalLM + from transformers import AutoTokenizer + + model_path = 'lmsys/vicuna-7b-v1.5' + quant_path = 'vicuna-7b-v1.5-awq' + quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" } + + # Load model + model = AutoAWQForCausalLM.from_pretrained(model_path, **{"low_cpu_mem_usage": True}) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + # Quantize + model.quantize(tokenizer, quant_config=quant_config) + + # Save quantized model + model.save_quantized(quant_path) + tokenizer.save_pretrained(quant_path) + +To run an AWQ model with vLLM, you can use `TheBloke/Llama-2-7b-Chat-AWQ `_ with the following command: + +.. code-block:: console + + $ python examples/llm_engine_example.py --model TheBloke/Llama-2-7b-Chat-AWQ --quantization awq + +AWQ models are also supported directly through the LLM entrypoint: + +.. code-block:: python + + from vllm import LLM, SamplingParams + + # Sample prompts. + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + # Create an LLM. + llm = LLM(model="TheBloke/Llama-2-7b-Chat-AWQ", quantization="AWQ") + # Generate texts from the prompts. The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index c4d33711cc9a..eec0d9ff7972 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -183,3 +183,37 @@ def test_sampler_mixed(seed: int): continue for nth_output in sequence_output.samples: assert nth_output.output_token in expected_tokens + + +@pytest.mark.parametrize("seed", RANDOM_SEEDS) +def test_sampler_logits_processors(seed: int): + set_random_seed(seed) + batch_size = random.randint(1, 256) + input_tensor, _, sampler, worker = _prepare_test(batch_size) + + # This sample logits processor gives infinite score to the i-th token, + # where i is the length of the input sequence. + # We therefore expect the output token sequence to be [0, 1, 2, ...] 
+ def pick_ith(token_ids, logits): + logits[len(token_ids)] = float("inf") + return logits + + seq_group_metadata_list = [] + for i in range(batch_size): + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: SequenceData([1, 2, 3])}, + sampling_params=SamplingParams(temperature=0, + logits_processors=[pick_ith]), + block_tables={0: [1]}, + )) + + _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list) + sampler_output = sampler(embedding=None, + hidden_states=input_tensor, + input_metadata=input_metadata) + for i, sequence_output in enumerate(sampler_output): + for idx, nth_output in enumerate(sequence_output.samples): + assert nth_output.output_token == idx diff --git a/vllm/config.py b/vllm/config.py index 6e19491083d4..a9e86c24b273 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -390,6 +390,9 @@ def _get_and_verify_max_len( if rope_scaling is not None: assert "factor" in rope_scaling scaling_factor = rope_scaling["factor"] + if rope_scaling["type"] == "yarn": + derived_max_model_len = rope_scaling[ + "original_max_position_embeddings"] derived_max_model_len *= scaling_factor if max_model_len is None: diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 58f868d407bf..c1259a1b11ea 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -12,7 +12,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.rotary_embedding import ( DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding, - RotaryEmbedding) + RotaryEmbedding, YaRNScalingRotaryEmbedding) _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. @@ -156,7 +156,9 @@ def single_query_cached_kv_attention( # sequences or heads is large, we use V1 since there is enough work # to parallelize. # TODO(woosuk): Tune this heuristic. - use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512 + # For context len > 8192, use V2 kernel to avoid shared memory shortage. + use_v1 = input_metadata.max_context_len <= 8192 and ( + max_num_partitions == 1 or num_seqs * num_heads > 512) if use_v1: # Run PagedAttention V1. attention_ops.paged_attention_v1( @@ -332,6 +334,19 @@ def __init__( self.rotary_emb = DynamicNTKScalingRotaryEmbedding( head_size, rotary_dim, max_position, base, is_neox_style, scaling_factor) + elif scaling_type == "yarn": + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + assert max_position == original_max_position * scaling_factor + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("extrapolation_factor", "attn_factor", + "beta_fast", "beta_slow") + } + self.rotary_emb = YaRNScalingRotaryEmbedding( + head_size, rotary_dim, original_max_position, base, + is_neox_style, scaling_factor, **extra_kwargs) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 4ecde07562fa..2cbd3b584c06 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -21,6 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Rotary Positional Embeddings.""" +import math from typing import Tuple, Union import torch @@ -167,3 +168,106 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: sin = freqs.sin() cache = torch.cat((cos, sin), dim=-1) return cache + + +# Inverse dim formula to find dim based on number of rotations +def _yarn_find_correction_dim(num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048) -> float: + return (dim * math.log(max_position_embeddings / + (num_rotations * 2 * math.pi))) / (2 * + math.log(base)) + + +# Find dim range bounds based on rotations +def _yarn_find_correction_range(low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048) -> int: + low = math.floor( + _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil( + _yarn_find_correction_dim(high_rot, dim, base, + max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask(low: float, high: float, dim: int, + dtype: torch.dtype, + device: torch.device) -> torch.Tensor: + if low == high: + high += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=dtype, device=device) - + low) / (high - low) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def _yarn_get_mscale(scale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + +class YaRNScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: float = 32, + beta_slow: float = 1, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation + self.mscale = float( + _yarn_get_mscale(self.scaling_factor) * attn_factor) + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float, device="cuda") / + self.rotary_dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, + self.rotary_dim, self.base, + self.max_position_embeddings) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = (1 - _yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2, dtype=torch.float, + device="cuda")) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * ( + 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange(self.max_position_embeddings * self.scaling_factor, + device="cuda", + dtype=torch.float32) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = (freqs.cos() * self.mscale) + sin = (freqs.sin() * self.mscale) + cache = torch.cat((cos, sin), dim=-1) + return cache diff --git a/vllm/model_executor/layers/sampler.py 
b/vllm/model_executor/layers/sampler.py index 6a29f1afd368..e0ec42081179 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -47,6 +47,8 @@ def forward( logits = _get_logits(hidden_states, embedding, embedding_bias, self.vocab_size) + # Apply logits processors (if any). + logits = _apply_logits_processors(logits, input_metadata) # Apply presence and frequency penalties. output_tokens = _get_output_tokens(input_metadata) assert len(output_tokens) == logits.shape[0] @@ -155,6 +157,28 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]: return output_tokens +def _apply_logits_processors(logits: torch.Tensor, + input_metadata: InputMetadata) -> torch.Tensor: + logits_row_idx = 0 + found_logits_processors = False + for seq_ids, sampling_params in input_metadata.seq_groups: + logits_processors = sampling_params.logits_processors + if logits_processors: + found_logits_processors = True + for seq_id in seq_ids: + logits_row = logits[logits_row_idx] + token_ids = input_metadata.seq_data[seq_id].output_token_ids + for logits_processor in logits_processors: + logits_row = logits_processor(token_ids, logits_row) + logits[logits_row_idx] = logits_row + logits_row_idx += 1 + else: + logits_row_idx += len(seq_ids) + if found_logits_processors: + assert logits_row_idx == logits.shape[0] + return logits + + def _apply_penalties( logits: torch.Tensor, output_tokens: List[List[int]], diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 9dcfd968b45c..443f24790674 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -28,8 +28,8 @@ "LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-* "MistralForCausalLM": MistralForCausalLM, # transformers's mpt class has lower case - "MptForCausalLM": MPTForCausalLM, - "MPTForCausalLM": MPTForCausalLM, + "MptForCausalLM": MptForCausalLM, + "MPTForCausalLM": MptForCausalLM, "OPTForCausalLM": OPTForCausalLM, "QWenLMHeadModel": QWenLMHeadModel, "RWForCausalLM": FalconForCausalLM, diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 35d72c16307b..c4bba4855ef3 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -10,7 +10,7 @@ from vllm.model_executor.models.internlm import InternLMForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.mistral import MistralForCausalLM -from vllm.model_executor.models.mpt import MPTForCausalLM +from vllm.model_executor.models.mpt import MptForCausalLM from vllm.model_executor.models.opt import OPTForCausalLM from vllm.model_executor.models.qwen import QWenLMHeadModel @@ -26,7 +26,7 @@ "GPTNeoXForCausalLM", "InternLMForCausalLM", "LlamaForCausalLM", - "MPTForCausalLM", + "MptForCausalLM", "OPTForCausalLM", "QWenLMHeadModel", "MistralForCausalLM", diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index ba7441e145b1..4a66c5b5dec6 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn +from transformers import MptConfig from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn @@ -19,7 +20,6 @@ ColumnParallelLinear, RowParallelLinear) from vllm.sequence import SamplerOutput -from vllm.transformers_utils.configs.mpt import MPTConfig KVCache = Tuple[torch.Tensor, 
torch.Tensor] @@ -37,17 +37,17 @@ def _get_alibi_slopes( return slopes -class MPTAttention(nn.Module): +class MptAttention(nn.Module): - def __init__(self, config: MPTConfig): + def __init__(self, config: MptConfig): super().__init__() self.d_model = config.d_model self.total_num_heads = config.n_heads - self.clip_qkv = config.attn_config["clip_qkv"] - self.qk_ln = config.attn_config["qk_ln"] - self.alibi_bias_max = config.attn_config["alibi_bias_max"] - assert not config.attn_config["prefix_lm"] - assert config.attn_config["alibi"] + self.clip_qkv = config.attn_config.clip_qkv + self.qk_ln = config.attn_config.qk_ln + self.alibi_bias_max = config.attn_config.alibi_bias_max + assert not config.attn_config.prefix_lm + assert config.attn_config.alibi self.qkv_proj = ColumnParallelLinear( self.d_model, @@ -105,9 +105,9 @@ def forward( return output -class MPTMLP(nn.Module): +class MptMLP(nn.Module): - def __init__(self, config: MPTConfig): + def __init__(self, config: MptConfig): super().__init__() hidden_size = config.d_model expansion_ratio = config.expansion_ratio @@ -133,15 +133,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class MPTBlock(nn.Module): +class MptBlock(nn.Module): - def __init__(self, config: MPTConfig): + def __init__(self, config: MptConfig): super().__init__() hidden_size = config.d_model self.norm_1 = nn.LayerNorm(hidden_size) - self.attn = MPTAttention(config) + self.attn = MptAttention(config) self.norm_2 = nn.LayerNorm(hidden_size) - self.ffn = MPTMLP(config) + self.ffn = MptMLP(config) def forward( self, @@ -166,9 +166,9 @@ def forward( return hidden_states -class MPTModel(nn.Module): +class MptModel(nn.Module): - def __init__(self, config: MPTConfig): + def __init__(self, config: MptConfig): super().__init__() assert config.embedding_fraction == 1.0 assert config.norm_type == "low_precision_layernorm" @@ -178,7 +178,7 @@ def __init__(self, config: MPTConfig): config.d_model, ) self.blocks = nn.ModuleList( - [MPTBlock(config) for _ in range(config.n_layers)]) + [MptBlock(config) for _ in range(config.n_layers)]) self.norm_f = nn.LayerNorm(config.d_model) if config.no_bias: for module in self.modules(): @@ -213,14 +213,14 @@ def forward( return hidden_states -class MPTForCausalLM(nn.Module): +class MptForCausalLM(nn.Module): - def __init__(self, config: MPTConfig): + def __init__(self, config: MptConfig): super().__init__() self.config = config assert config.tie_word_embeddings - self.transformer = MPTModel(config) + self.transformer = MptModel(config) # TODO(zhuohan): create a new weight after implementing pipeline # parallelism self.lm_head_weight = self.transformer.wte.weight diff --git a/vllm/outputs.py b/vllm/outputs.py index ad6733ff5723..fe54926e06e6 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -53,6 +53,7 @@ class RequestOutput: request_id: The unique ID of the request. prompt: The prompt string of the request. prompt_token_ids: The token IDs of the prompt. + prompt_logprobs: The log probabilities to return per prompt token. outputs: The output sequences of the request. finished: Whether the whole request is finished. 
""" diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 00a9135a5ca7..f8ef9be7b6a6 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -1,7 +1,8 @@ """Sampling parameters for text generation.""" from enum import IntEnum from functools import cached_property -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union +import torch _SAMPLING_EPS = 1e-5 @@ -12,6 +13,12 @@ class SamplingType(IntEnum): BEAM = 2 +LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor] +"""LogitsProcessor is a function that takes a list of previously generated +tokens and a tensor of the logits for the next token, and returns a modified +tensor of logits to sample from.""" + + class SamplingParams: """Sampling parameters for text generation. @@ -73,6 +80,8 @@ class SamplingParams: skip_special_tokens: Whether to skip special tokens in the output. spaces_between_special_tokens: Whether to add spaces between special tokens in the output. Defaults to True. + logits_processors: List of functions that modify logits based on + previously generated tokens. """ def __init__( @@ -96,6 +105,7 @@ def __init__( prompt_logprobs: Optional[int] = None, skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, + logits_processors: Optional[List[LogitsProcessor]] = None, ) -> None: self.n = n self.best_of = best_of if best_of is not None else n @@ -124,7 +134,7 @@ def __init__( self.prompt_logprobs = prompt_logprobs self.skip_special_tokens = skip_special_tokens self.spaces_between_special_tokens = spaces_between_special_tokens - + self.logits_processors = logits_processors self._verify_args() if self.use_beam_search: self._verify_beam_search() diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index fd5618bd81ba..b69e0a1a4385 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -1,11 +1,11 @@ from typing import Optional -from transformers import AutoConfig, PretrainedConfig +from transformers import AutoConfig, MptConfig, PretrainedConfig from vllm.transformers_utils.configs import * # pylint: disable=wildcard-import _CONFIG_REGISTRY = { - "mpt": MPTConfig, + "mpt": MptConfig, "baichuan": BaiChuanConfig, "aquila": AquilaConfig, "qwen": QWenConfig, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 6611697d25ae..f5acb4e07972 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,4 +1,3 @@ -from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.baichuan import BaiChuanConfig from vllm.transformers_utils.configs.aquila import AquilaConfig from vllm.transformers_utils.configs.qwen import QWenConfig @@ -8,7 +7,6 @@ from vllm.transformers_utils.configs.falcon import RWConfig __all__ = [ - "MPTConfig", "BaiChuanConfig", "AquilaConfig", "QWenConfig", diff --git a/vllm/transformers_utils/configs/mpt.py b/vllm/transformers_utils/configs/mpt.py deleted file mode 100644 index 3909f710d44d..000000000000 --- a/vllm/transformers_utils/configs/mpt.py +++ /dev/null @@ -1,74 +0,0 @@ -# Adapted from -# https://huggingface.co/mosaicml/mpt-7b/blob/main/configuration_mpt.py -from typing import Any, Dict, Optional, Union - -from transformers import PretrainedConfig - -_ATTN_CONFIG_DEFAULTS = { - "attn_type": "multihead_attention", - "attn_pdrop": 0.0, - "attn_impl": "triton", - "qk_ln": False, - "clip_qkv": None, - "softmax_scale": 
None, - "prefix_lm": False, - "attn_uses_sequence_id": False, - "alibi": False, - "alibi_bias_max": 8, -} - - -class MPTConfig(PretrainedConfig): - model_type = "mpt" - attribute_map = { - "hidden_size": "d_model", - "num_attention_heads": "n_heads", - "num_hidden_layers": "n_layers", - } - - def __init__( - self, - d_model: int = 2048, - n_heads: int = 16, - n_layers: int = 24, - expansion_ratio: int = 4, - max_seq_len: int = 2048, - vocab_size: int = 50368, - resid_pdrop: float = 0.0, - emb_pdrop: float = 0.0, - learned_pos_emb: bool = True, - attn_config: Optional[Dict[str, Any]] = None, - init_device: str = "cpu", - logit_scale: Optional[Union[float, str]] = None, - no_bias: bool = False, - verbose: int = 0, - embedding_fraction: float = 1.0, - norm_type: str = "low_precision_layernorm", - use_cache: bool = False, - **kwargs, - ) -> None: - self.d_model = d_model - self.n_heads = n_heads - self.n_layers = n_layers - self.expansion_ratio = expansion_ratio - self.max_seq_len = max_seq_len - self.vocab_size = vocab_size - self.resid_pdrop = resid_pdrop - self.emb_pdrop = emb_pdrop - self.learned_pos_emb = learned_pos_emb - if attn_config is None: - self.attn_config = _ATTN_CONFIG_DEFAULTS - else: - self.attn_config = attn_config - self.init_device = init_device - self.logit_scale = logit_scale - self.no_bias = no_bias - self.verbose = verbose - self.embedding_fraction = embedding_fraction - self.norm_type = norm_type - self.use_cache = use_cache - if "name" in kwargs: - del kwargs["name"] - if "loss_fn" in kwargs: - del kwargs["loss_fn"] - super().__init__(**kwargs) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d598a86cf0c1..b2391ae788a8 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -13,7 +13,7 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine -from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes +from vllm.utils import get_gpu_memory class Worker: @@ -141,13 +141,6 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None: self.block_size = cache_config.block_size self.sliding_window = cache_config.sliding_window - if self.sliding_window is None: - max_seq_len = self.scheduler_config.max_model_len - else: - max_seq_len = min(self.scheduler_config.max_model_len, - self.sliding_window) - _check_if_can_support_max_seq_len(max_seq_len, self.block_size) - self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.parallel_config) self.cache_events = self.cache_engine.events @@ -421,26 +414,6 @@ def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: return x + [pad] * (max_len - len(x)) -def _check_if_can_support_max_seq_len(max_seq_len: int, - block_size: int) -> None: - # Follows the logic in - # attention_kernels.cu::single_query_cached_kv_attention_launcher - max_shared_mem = get_max_shared_memory_bytes() - float32_bytes = torch.finfo(torch.float).bits // 8 - padded_max_seq_len = ( - (max_seq_len + block_size - 1) / block_size) * block_size - # padded_max_seq_len + extra buffer - required_shared_mem = (padded_max_seq_len + 512) * float32_bytes - if padded_max_seq_len * float32_bytes > max_shared_mem: - raise RuntimeError( - f"vLLM cannot currently support max_model_len={max_seq_len} " - f"with block_size={block_size} on GPU with compute " - f"capability {torch.cuda.get_device_capability()} " - f"(required shared memory {required_shared_mem} > " - f"available 
shared memory {max_shared_mem}). " - "This will be fixed in a future release.") - - def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype. if torch_dtype == torch.bfloat16:
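
A note on the logits-processor hook introduced above: `SamplingParams` now accepts a `logits_processors` list, and the sampler applies each processor to a sequence's next-token logits together with the tokens generated so far for that sequence. Below is a minimal usage sketch; the model name, prompt, and the `ban_token_42` helper are illustrative placeholders, not part of this patch.

.. code-block:: python

    from typing import List

    import torch

    from vllm import LLM, SamplingParams

    # A LogitsProcessor receives the token IDs generated so far and the logits
    # for the next token, and returns a (possibly modified) logits tensor.
    def ban_token_42(token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        logits[42] = -float("inf")  # forbid sampling token ID 42
        return logits

    llm = LLM(model="facebook/opt-125m")  # example model
    params = SamplingParams(temperature=0.8, logits_processors=[ban_token_42])
    outputs = llm.generate(["Hello, my name is"], params)
    print(outputs[0].outputs[0].text)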
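
As a sanity check on the YaRN scaling math added in `rotary_embedding.py`, the two helper formulas are re-implemented below in isolation (standalone functions for illustration, not imports from vLLM); the rotary dimension, base, and context lengths are example values.

.. code-block:: python

    import math

    def yarn_get_mscale(scale: float = 1.0) -> float:
        # Same magnitude-scaling formula as _yarn_get_mscale in the patch.
        if scale <= 1.0:
            return 1.0
        return 0.1 * math.log(scale) + 1.0

    def yarn_find_correction_dim(num_rotations: float, dim: int,
                                 base: float = 10000.0,
                                 max_position_embeddings: int = 2048) -> float:
        # Dimension index whose wavelength completes `num_rotations` full turns
        # over the original context window.
        return (dim * math.log(max_position_embeddings /
                               (num_rotations * 2 * math.pi))) / (2 * math.log(base))

    # Example: stretching a 4096-token model 4x to 16384 tokens with rotary_dim=128.
    print(yarn_get_mscale(4.0))                              # ~1.14, applied to the cos/sin cache
    print(yarn_find_correction_dim(32, 128, 10000.0, 4096))  # ~20.9 (beta_fast boundary)
    print(yarn_find_correction_dim(1, 128, 10000.0, 4096))   # ~45.0 (beta_slow boundary)
    # Dimensions below ~21 keep their original (extrapolated) frequencies, those
    # above ~45 are fully interpolated, and the ramp mask blends the two in between.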
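
Regarding the `int` to `int64_t` change in the CUDA kernels and launchers: the flattened element offsets (e.g. `token_idx * 2 * d + idx` in `silu_and_mul_kernel`) can exceed `INT_MAX` once the batch is large enough, so the index arithmetic has to happen in 64-bit. A back-of-the-envelope check with example sizes (the 200k-token batch is hypothetical; 11008 is the Llama-7B MLP intermediate size):

.. code-block:: python

    num_tokens = 200_000   # hypothetical number of tokens in flight
    d = 11_008             # per-half width of the [..., 2 * d] input (Llama-7B MLP)

    largest_offset = num_tokens * 2 * d - 1   # last element touched by the kernel
    print(largest_offset)                     # 4_403_199_999
    print(largest_offset > 2**31 - 1)         # True: 32-bit index math would overflow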