-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feat]Tensor Model Parallel Support For Inference (#5563)
* tensor parallel support naive source * [fix]precision, model load and refactor the framework * add tp unit test * docstring * fix do_sample
- Loading branch information
Showing
8 changed files
with
634 additions
and
144 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import logging | ||
import os | ||
from functools import reduce | ||
from pathlib import Path | ||
from typing import Optional | ||
|
||
import torch | ||
|
||
from colossalai.checkpoint_io.general_checkpoint_io import GeneralCheckpointIO | ||
from colossalai.checkpoint_io.index_file import CheckpointIndexFile | ||
from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model | ||
from colossalai.cluster import DistCoordinator | ||
from colossalai.interface import ModelWrapper | ||
|
||
# `_EXTRA_STATE_KEY_SUFFIX` is the suffix torch appends to state-dict keys
# produced by `nn.Module.get_extra_state`. Older torch releases do not export
# the constant, so fall back to its known literal value.
try:
    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
    _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
|
||
|
||
class InferCheckpoint_io(GeneralCheckpointIO):
    """
    Checkpoint I/O dedicated to inference model *loading*.

    Most of this code is adapted from
    ``colossalai.checkpoint_io.hybrid_parallel_checkpoint_io.HybridParallelCheckpointIO``.
    The original class contains mixed-precision-training logic; that has been
    removed to keep a relatively clean class specifically for inference.
    Saving is intentionally unsupported (see :meth:`save_sharded_model`).
    """

    def __init__(
        self,
        verbose: bool = True,
    ) -> None:
        """
        Args:
            verbose (bool, optional): Whether to log a success message on the
                master rank after loading. Defaults to True.
        """
        super().__init__()
        self.verbose = verbose
        self.coordinator = DistCoordinator()

    def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
        """
        Load sharded model with the given path to index file of checkpoint folder.

        Args:
            model (ModelWrapper): The boosted model to be loaded.
            checkpoint_index_file (Path): Path to the index file of the checkpoint folder.
            strict (bool, optional): For name matching during loading state_dict.
                Defaults to False. Kept for interface compatibility only — it is
                forcibly reset to False below, since params on the same device
                might be stored in different shard files.

        Raises:
            ImportError: If the checkpoint uses safetensors but the `safetensors`
                library is not installed.
            RuntimeError: If no weight at all was loaded into the model.
        """
        assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
        model = model.unwrap()

        # Check whether the checkpoint uses safetensors.
        use_safetensors = False
        if "safetensors" in checkpoint_index_file.name:
            use_safetensors = True

        if use_safetensors and not is_safetensors_available():
            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")

        # Read checkpoint index file.
        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
        ckpt_root_path = ckpt_index_file.root_path
        weight_map = ckpt_index_file.weight_map
        # Force non-strict matching: shards are loaded file by file, and any
        # single shard is expected to miss keys that live in the other shards.
        strict = False

        # Load params & buffers to model.
        # Keep a record of loaded files so that a file will not be repeatedly loaded.
        loaded_file = set()

        # One entry per loaded shard file: the keys that shard did not provide.
        missing_keys = []
        # Keys absent from the index's weight map entirely.
        missing_file_keys = []

        def _load(name: str):
            # Resolve the shard file holding `name`, then load that whole shard
            # into the model (at most once per file).
            if name not in weight_map:
                missing_file_keys.append(name)
                return
            filename = weight_map[name]

            # If this param/buffer has been loaded before, directly return.
            if filename in loaded_file:
                return

            file_path = os.path.join(ckpt_root_path, filename)
            state_dict = load_shard_state_dict(Path(file_path), use_safetensors)

            load_state_dict_into_model(
                model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
            )
            loaded_file.add(filename)

        # Load parameters.
        for name, _ in model.named_parameters():
            _load(name)

        # Load buffers (non-persistent buffers are not part of the state dict).
        non_persistent_buffers = set()
        for n, m in model.named_modules():
            non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
        for name, buf in model.named_buffers():
            if buf is not None and name not in non_persistent_buffers:
                _load(name)

        # Load extra states, but only when the model overrides get_extra_state.
        extra_state_key = _EXTRA_STATE_KEY_SUFFIX
        if (
            getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
            is not torch.nn.Module.get_extra_state
        ):
            _load(extra_state_key)

        if self.verbose and self.coordinator.is_master():
            logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")

        # `missing_keys` gets one entry per loaded shard; empty means no shard
        # was loaded at all, i.e. nothing went into the model.
        if len(missing_keys) == 0:
            raise RuntimeError(
                "No weight is loaded into the model. Please check the checkpoint files and the model structure."
            )

        # A key is truly missing only if *every* loaded shard missed it
        # (intersection across the per-shard missing lists).
        remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
        remain_keys = remain_keys.union(set(missing_file_keys))
        if len(remain_keys) > 0:
            if strict:
                # Bug fix: error_msgs must be a list — joining a plain string
                # with "\n\t" would interleave a tab between every character.
                error_msgs = [
                    "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys))
                ]
                raise RuntimeError(
                    "Error(s) in loading state_dict for {}:\n\t{}".format(
                        self.__class__.__name__, "\n\t".join(error_msgs)
                    )
                )
            else:
                if self.coordinator.is_master():
                    logging.info(f"The following keys are not loaded from checkpoint: {remain_keys}")

    def save_sharded_model(
        self,
        model: ModelWrapper,
        checkpoint: str,
        gather_dtensor: bool = True,
        prefix: Optional[str] = None,
        size_per_shard: int = 1024,
        use_safetensors: bool = False,
    ) -> None:
        """
        Saving is not supported by this inference-only checkpoint I/O.

        Raises:
            NotImplementedError: Always; this class only loads checkpoints.
        """
        # Bug fix: the original `return NotImplementedError` silently returned
        # the exception *class* instead of signalling the error to the caller.
        raise NotImplementedError
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.