7 changes: 7 additions & 0 deletions vllm/model_executor/layers/quantization/kv_cache.py
@@ -45,6 +45,13 @@ def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# skip if there are no weights to process (for example, weight reloading)
if not hasattr(layer, "q_scale"):
assert not hasattr(layer, "k_scale")
assert not hasattr(layer, "v_scale")
assert not hasattr(layer, "prob_scale")
return

# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
# regardless of whether the kv-scale is available in the checkpoint.
# No need to process kv scales after loading if we are going to
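The new guard makes `process_weights_after_loading` idempotent when the scale parameters have already been consumed, e.g. during weight reloading. A minimal standalone sketch of the pattern, using a hypothetical layer rather than vLLM's real attention layer:

```python
import torch

class FakeAttnLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_scale = torch.nn.Parameter(torch.ones(1))

def process_weights_after_loading(layer: torch.nn.Module) -> None:
    # skip if there are no weights to process (for example, weight reloading)
    if not hasattr(layer, "q_scale"):
        return
    layer._q_scale = float(layer.q_scale)  # fold the scale into kernel state
    del layer.q_scale                      # the parameter is gone after this

layer = FakeAttnLayer()
process_weights_after_loading(layer)  # first call consumes q_scale
process_weights_after_loading(layer)  # second call is a safe no-op
```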
36 changes: 5 additions & 31 deletions vllm/model_executor/model_loader/default_loader.py
@@ -277,39 +277,13 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
"0.14.0"
):
self.load_config.safetensors_load_strategy = "torchao"
weights_to_load = {name for name, _ in model.named_parameters()}

# if `model.weight_metadata_and_attr_saved` is not defined and set to
# True, this is either the offline quantization case or the first run
# of online quantization
# see online_quantization.py for detailed notes
offline_quantization_or_first_run_of_online_quantization = not getattr(
model, "weight_metadata_and_attr_saved", False
)

if model_config.quantization is None:
# model is not quantized
loaded_weights = model.load_weights(
self.get_all_weights(model_config, model)
)
elif offline_quantization_or_first_run_of_online_quantization:
# case 1: offline quantized checkpoint
# case 2: Step I1 first run of weight loading with
# online quantization
# see online_quantization.py for detailed notes
loaded_weights = model.load_weights(
self.get_all_weights(model_config, model)
)
else:
# to avoid circular dependency
from vllm.model_executor.model_loader.online_quantization import (
load_weights_and_online_quantize,
)

# subsequent runs of weight loading with online
# quantization
loaded_weights = load_weights_and_online_quantize(self, model, model_config)
# load weights into model
weights_to_load = {name for name, _ in model.named_parameters()}
weights_iterator = self.get_all_weights(model_config, model)
loaded_weights = model.load_weights(weights_iterator)

# logging and validation
self.counter_after_loading_weights = time.perf_counter()
logger.info_once(
"Loading weights took %.2f seconds",
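The retained `weights_to_load` set feeds the validation mentioned in the trailing comment: `model.load_weights` returns the parameter names it populated, and the loader can diff the two sets. A minimal sketch of that check, with a hypothetical helper name:

```python
def validate_loaded_weights(
    weights_to_load: set[str], loaded_weights: set[str] | None
) -> None:
    # some models return None instead of the set of loaded parameter names
    if loaded_weights is None:
        return
    missing = weights_to_load - loaded_weights
    if missing:
        raise ValueError(
            f"weights were not initialized from checkpoint: {sorted(missing)}"
        )
```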
323 changes: 114 additions & 209 deletions vllm/model_executor/model_loader/online_quantization.py
@@ -1,224 +1,129 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import types
"""
Utilities for enabling weight reloading and online quantization
For more information and diagrams, see https://github.com/neuralmagic/vllm/pull/128

## Model Reloading Lifecycle ##
1. Model is loaded for the first time
a. Checkpoint is loaded by `ModelLoader.get_all_weights` into `weights_iterator`
b. `weights_iterator` is loaded into model by `model.load_weights`
c. Model state is captured by `record_weights_for_reloading`
d. `process_weights_after_loading` converts model state into kernel format.
The model is no longer loadable while its weights are in kernel format

2. Model is reloaded via `reload_weights`
a. A `weights_iterator` is provided, which may be async, chunked, or sharded
b. The original model state is restored by `restore_weights_for_reloading`
using metadata information from `record_weights_for_reloading`
c. `weights_iterator` is loaded into model by `model.load_weights`
d. `process_weights_after_loading` converts model state into kernel format.
The model is no longer loadable while its weights are in kernel format

Alternatively, if a user does not want to use `reload_weights`, they can drive
these steps manually:

```python
record_weights_for_reloading(model)

for weights in weights_iterator:  # may be async, chunked, or sharded
model.load_weights(weights)

process_weights_after_loading(model, model_config, device)
```
"""

import torch
from torch import nn

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.utils import process_weights_after_loading

logger = init_logger(__name__)

# Notes for Online Quantization
# How checkpoint state and the quantization config correspond to
# online quantization:
# | Use Case | Checkpoints | model_config.quantization |
# | no quant | high precision | None |
# | offline quant | quantized | fp8, torchao etc. |
# | online quant | high precision | torchao etc. |
#
# The process for loading non-quantized checkpoint
# 1. load non-quantized weights (load_weights)
# 2. do any additional post processing (process_weights_after_loading)
#
# The process for loading offline quantized checkpoint
# 1. load offline-quantized weights (load_weights)
# 2. do any additional post processing (process_weights_after_loading)

# The process for unquantized model reloading
# (repeated run in RL training loop)
# first run
# UI1. load_weights: load bfloat16 weights
# UI2. process_weights_after_loading: any additional post processing
# subsequent run
# UC1: load_weights: load bfloat16 weights
# (shouldn't cause any issues since we didn't change any attributes
# of the weights)
# UC2: process_weights_after_loading: any additional post processing

# The process for weight reloading with online quantization
# (repeated run in RL training loop)
# first run
# I1. load_weights: load bfloat16 weights
# I2. process_weights_after_loading:
# record weight metadata and attributes for R1 and R2
# quantize weights to fp8
# subsequent run
# (beginning model weight is in fp8)
# load_weights:
# R1. restore bfloat16 model weight metadata
# R2. restore the model weight attributes
# R3. reload bfloat16 weights
# R4. quantize weights (by calling process_weights_after_loading),
# also set `process_weights_after_loading_already_called` to
# True to stop it from running again
# process_weights_after_loading (if called):
# this will be skipped since it already ran in
# load_weights
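A compact sketch of the skip flag described in the notes above; `quantize_weights` is a hypothetical placeholder for the real post-processing:

```python
def quantize_weights(model) -> None:
    """Placeholder for the real online quantization of the weights."""

def process_weights_after_loading(model) -> None:
    # R4 sets this flag after quantizing inside load_weights, so a later
    # call from the generic loading path becomes a no-op
    if getattr(model, "process_weights_after_loading_already_called", False):
        return
    quantize_weights(model)
    model.process_weights_after_loading_already_called = True
```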


def maybe_save_metadata_and_attributes_for_weight_reloading(
model: nn.Module, model_config: ModelConfig
):
# the following is to support on-the-fly quantization; currently only
# supported for torchao
if model_config.quantization != "torchao":
return

if getattr(model, "process_weights_after_loading_already_called", False):
# In case `process_weights_after_loading` is called multiple times,
# we skip all calls after the first
logger.warning(
"process_weights_after_loading already called for model %s", model
)
return

from vllm.model_executor.model_loader.weight_utils import get_quant_config

quant_config = get_quant_config(model_config, None)

# If the checkpoint is already torchao serialized, this is the
# pre-quantized case and we skip saving the metadata.
# Otherwise, this is Step I2 of the initialization steps of
# online quantization.
# This step records the weight metadata and attributes so we can
# restore the bfloat16 model weights during the reload steps (R1 and R2),
# see Notes in online_quantization.py for more details
if not (
hasattr(quant_config, "is_checkpoint_torchao_serialized")
and not quant_config.is_checkpoint_torchao_serialized
):
return

# This is the I2 step of online quantization, which saves the
# metadata and attributes of weights so they can be used in the R1 and
# R2 steps; note that we only save these during initialization

# Includes two things
# 1. save floating point metadata (shape, dtype, device) for init
# 2. save weight attributes, e.g. `output_dim`, `weight_loader` for init

if getattr(model, "weight_metadata_and_attr_saved", False):
return

# save the dtype, shape and device of each model parameter, used for
# restoring the model's high precision parameters before
# reloading the weights
assert not hasattr(model, "original_weights_rebuild_keys")
model.original_weights_rebuild_keys = {}
for name, p in model.named_parameters():
model.original_weights_rebuild_keys[name] = {
"shape": p.shape,
"dtype": p.dtype,
"device": p.device,
__all__ = [
"RELOADABLE_QUANT_CONFIGS",
"record_weights_for_reloading",
"restore_weights_for_reloading",
]

# in theory, this implementation of weight recording/restoring
# should support any quantization config
RELOADABLE_QUANT_CONFIGS = {
None,
"torchao",
"fp8",
}
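A hedged example of how a call site might consult this set before attempting a reload; the guard function is hypothetical, not part of this module:

```python
def assert_reloadable(quantization: str | None) -> None:
    # fail early instead of erroring deep inside the reload path
    if quantization not in RELOADABLE_QUANT_CONFIGS:
        raise NotImplementedError(
            f"weight reloading is not validated for quantization={quantization!r}"
        )
```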


def record_weights_for_reloading(model: nn.Module):
# this function should be called before `process_weights_after_loading`
# in practice, this happens at the very start of `process_weights_after_loading`
if not hasattr(model, "weight_loading_metadata"):
model.weight_loading_metadata = {
name: _copy_to_meta_tensor(param)
for name, param in model.named_parameters(remove_duplicate=False)
}

# record the weight attributes (loader functions etc.)
# so these can be recovered later when we reload the weights
# structure: {"weight_name": {"weight_attr_key": attr}}
assert not hasattr(model, "recorded_weight_attr")
model.recorded_weight_attr = {}
for name, param in model.named_parameters():
model.recorded_weight_attr[name] = {}
for key in param.__dict__:
if hasattr(param, key):
attr = getattr(param, key)
if not callable(attr):
model.recorded_weight_attr[name][key] = attr
elif hasattr(attr, "__self__") and param is attr.__self__:
# if attr is a method bound to an instance, and
# attr.__self__ points to the instance (param),
# we record the underlying function object
model.recorded_weight_attr[name][key] = attr.__func__
else:
model.recorded_weight_attr[name][key] = attr
# mark the metadata and attributes as saved so we don't record them again
model.weight_metadata_and_attr_saved = True
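The bound-method branch above matters because storing `param.weight_loader` directly would keep the old parameter alive, and after restoring, the loader would still write into the stale instance. A standalone sketch of the unwrap/re-bind round trip in plain PyTorch (the `weight_loader` attribute here is illustrative):

```python
import types

import torch

param = torch.nn.Parameter(torch.zeros(2))

def weight_loader(self, loaded_weight):
    self.data.copy_(loaded_weight)

param.weight_loader = types.MethodType(weight_loader, param)

# record: unwrap to the underlying function so it is not tied to `param`
recorded_attr = param.weight_loader.__func__

# restore: bind the recorded function onto a freshly created parameter
new_param = torch.nn.Parameter(torch.empty(2))
new_param.weight_loader = types.MethodType(recorded_attr, new_param)
new_param.weight_loader(torch.ones(2))
assert torch.equal(new_param.data, torch.ones(2))
```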


def _bond_method_to_cls(func, obj):
if hasattr(func, "__self__") or not callable(func):
# If func is already bound to an instance (or not callable), return it as is
return func
else:
return types.MethodType(func, obj)


def load_weights_and_online_quantize(
model_loader: DefaultModelLoader, model: nn.Module, model_config: ModelConfig
) -> set[str]:
# online quantization, right now only enabled for
# torchao
# R1, R2, R3, R4 in the Notes

# TODO: Add fp8 support
assert model_config.quantization == "torchao", (
"online quantization is only enabled for torchao currently"

def restore_weights_for_reloading(model: nn.Module):
assert hasattr(model, "weight_loading_metadata")
metadata: dict[str, torch.Tensor] = model.weight_loading_metadata
model_param_names = dict(model.named_parameters(remove_duplicate=False)).keys()

# remove parameters that were not present at load time
params_to_remove = model_param_names - metadata.keys()
for param_fqn in params_to_remove:
module_name, param_name = param_fqn.rsplit(".", 1)
module = model.get_submodule(module_name)

# sometimes modules are shared, as is the case for `shared_experts`
if hasattr(module, param_name):
delattr(module, param_name)

# restore parameters that were present at load time
for param_fqn, meta_tensor in metadata.items():
module_name, param_name = param_fqn.rsplit(".", 1)
module = model.get_submodule(module_name)

# for faster runtime, skip materialization if the tensors match
original_tensor = getattr(module, param_name, None)
if _tensors_alike(original_tensor, meta_tensor):
continue

param = _materialize_meta_tensor(meta_tensor)
setattr(module, param_name, param)
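A usage sketch of the record/restore pair, following the lifecycle in the module docstring; `model`, `model_config`, `device`, and `weights_iterator` are assumed to exist in the caller:

```python
record_weights_for_reloading(model)  # 1c: snapshot meta tensors and attrs
process_weights_after_loading(model, model_config, device)  # 1d: kernel format

restore_weights_for_reloading(model)  # 2b: back to loadable parameters
for weights in weights_iterator:      # 2c: may be async, chunked, or sharded
    model.load_weights(weights)
process_weights_after_loading(model, model_config, device)  # 2d: re-quantize
```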


def _copy_to_meta_tensor(tensor: torch.Tensor) -> torch.Tensor:
meta_tensor = tensor.to("meta")
meta_tensor.__class__ = tensor.__class__
meta_tensor.__dict__ = tensor.__dict__
meta_tensor._original_device = tensor.device

return meta_tensor
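For reference, `Tensor.to("meta")` preserves shape, stride, and dtype but allocates no storage, which is what makes this snapshot cheap even for large models:

```python
import torch

src = torch.empty(4, 8, dtype=torch.bfloat16)
meta = src.to("meta")
assert meta.shape == src.shape and meta.dtype == src.dtype
assert meta.device.type == "meta"  # no real memory backs this tensor
```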


def _tensors_alike(tensor: torch.Tensor | None, meta_tensor: torch.Tensor) -> bool:
if tensor is None:
return False

return (
tensor.device == meta_tensor._original_device
and tensor.dtype == meta_tensor.dtype
and tensor.shape == meta_tensor.shape
and tensor.__dict__ == meta_tensor.__dict__
)
# TODO: use create_weights to restore the weights to original state

# Step R1: First restore the quantized weights to original bfloat16
# weights, with original metadata (shape, dtype, device)
# and attributes, so that bfloat16 weights can be loaded properly
existing_param_names = dict(model.named_parameters(remove_duplicate=False)).keys()
named_modules = dict(model.named_modules(remove_duplicate=False))
model_device = None

# Step R2: recover the parameters to their state before first loading
for name, d in model.original_weights_rebuild_keys.items():
_shape = d["shape"]
_dtype = d["dtype"]
_device = d["device"]
if model_device is not None:
assert model_device == _device, (
"Expecting all weights "
"to be in the same device for now, got both: "
f"{model_device} and {_device}"
)
else:
model_device = _device

if name in existing_param_names:
module_name, weight_name = name.rsplit(".", 1)
module = named_modules[module_name]
setattr(
module,
weight_name,
torch.nn.Parameter(torch.empty(_shape, dtype=_dtype, device=_device)),
)

# recorded_weight_attr is
# {"weight_name": {"weight_attr_key": attr}}
# e.g.
# {
#     "layer.0.weight": {
#         "weight_loader": weight_loader_function_object,
#         "input_dim": 0, ...
#     },
#     "layer.1.weight": ...,
# }
for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
for attr_name, attr in weight_attr_dict.items():
module_name, weight_name = full_weight_name.rsplit(".", 1)
module = named_modules[module_name]
weight = getattr(module, weight_name)
if not hasattr(weight, attr_name):
setattr(weight, attr_name, _bond_method_to_cls(attr, weight))

# Step I1: reload bfloat16 / high precision weights
loaded_weights = model.load_weights(
model_loader.get_all_weights(model_config, model)


def _materialize_meta_tensor(meta_tensor: torch.Tensor) -> torch.Tensor:
tensor = torch.empty_strided(
size=tuple(meta_tensor.size()),
stride=tuple(meta_tensor.stride()),
dtype=meta_tensor.dtype,
device=meta_tensor._original_device,
requires_grad=meta_tensor.requires_grad,
)
tensor.__class__ = meta_tensor.__class__
tensor.__dict__ = meta_tensor.__dict__

# Step I2: online quantize the weights
# manually process weights after loading
model.process_weights_after_loading_already_called = False
process_weights_after_loading(model, model_config, model_device)
model.process_weights_after_loading_already_called = True
return loaded_weights
return tensor
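A round-trip sketch using the two helpers above, assuming `_copy_to_meta_tensor` and `_materialize_meta_tensor` are in scope:

```python
import torch

src = torch.empty_strided((4, 8), (8, 1), dtype=torch.float16)
meta = _copy_to_meta_tensor(src)      # cheap snapshot with no storage
out = _materialize_meta_tensor(meta)  # rebuilt with the original layout
assert out.shape == src.shape and out.stride() == src.stride()
assert out.dtype == src.dtype and out.device == src.device
```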