
[Feat] Tensor Model Parallel Support For Inference #5563

Merged
140 changes: 111 additions & 29 deletions colossalai/inference/core/engine.py
@@ -5,14 +5,19 @@
import numpy as np
import torch
import torch.nn as nn
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
from torch import distributed as dist
from transformers import AutoConfig, GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import Sequence
from colossalai.inference.utils import has_index_file
from colossalai.interface import ModelWrapper
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
@@ -24,9 +29,9 @@

PP_AXIS, TP_AXIS = 0, 1

_supported_models = [
"LlamaForCausalLM",
]
_supported_models = {
"LlamaForCausalLM": LlamaForCausalLM,
}

_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]

@@ -37,7 +42,7 @@ class InferenceEngine:
InferenceEngine which manages the inference process.

Args:
model (nn.Module): Path or nn.Module of this model.
model_or_path (nn.Module or str): Path or nn.Module of this model.
tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): The tokenizer to use.
inference_config (Optional[InferenceConfig], optional): Stores the configuration information related to inference.
verbose (bool): Determines whether or not to log the generation process.
@@ -46,7 +51,7 @@

def __init__(
self,
model: nn.Module,
model_or_path: Union[nn.Module, str],
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
inference_config: InferenceConfig,
verbose: bool = False,
@@ -55,16 +60,72 @@ def __init__(
assert inference_config, "Please provide inference_config."
assert tokenizer, "Please provide a tokenizer, either a defined one or str"
self.inference_config = inference_config
self.model_config = model.config
self.device = torch.device("cuda")

self.dtype = inference_config.dtype
torch.set_default_dtype(self.dtype)
self.high_precision = inference_config.high_precision

self.verbose = verbose
if verbose:
self.logger = get_dist_logger(__name__)

# enable memory history, which will
# add tracebacks and event history to snapshots
# torch.cuda.memory._record_memory_history()
self.init_model(model_or_path, model_policy)
# torch.cuda.memory._dump_snapshot(f"my_snapshot_rank_{dist.get_rank()}.pickle")

self.generation_config = inference_config.to_generation_config(self.model_config)

self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self.generation_config = inference_config.to_generation_config(self.model_config)
self.high_precision = inference_config.high_precision

self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
# DISCUSS maybe move this into batch info?

self.counter = count()

self.use_cuda_graph = self.inference_config.use_cuda_graph
if self.use_cuda_graph:
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_memory_pool = None # Set during graph capture.
if verbose:
self.logger.info("Colossal AI CUDA Graph Capture on")

self.capture_model(self.k_cache, self.v_cache)

def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
"""
Shard the model and/or load its weights.

Args:
model_or_path (Union[nn.Module, str]): Path to the checkpoint, or a model in transformers format.
model_policy (Policy): The policy used to shard the model.
"""

if isinstance(model_or_path, str):
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
arch = getattr(hf_config, "architectures", [])[0]
model = _supported_models[arch](hf_config)
else:
model = model_or_path

self.model_config = model.config

torch.cuda.empty_cache()
init_gpu_memory = torch.cuda.mem_get_info()[0]

self.device = get_accelerator().get_current_device()
if self.verbose:
self.logger.info(f"the device is {self.device}")

model = model.eval()
model = model.cuda()
model.to(self.dtype)

if self.verbose:
self.logger.info(
f"Before the shard, Rank: [{dist.get_rank()}], model size: {self.get_model_size(model)} GB, model's device is: {model.device}"
)

if model_policy is None:
if self.inference_config.pad_input:
@@ -73,33 +134,54 @@ def __init__(
model_type = "nopadding_" + self.model_config.model_type
model_policy = model_policy_map[model_type]()

pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)

self.model = self._shardformer(
model,
model_policy,
None,
pg_mesh.get_group_along_axis(TP_AXIS) if inference_config.pp_size * inference_config.tp_size > 1 else None,
tp_group=tp_group,
)

self.verbose = verbose
if verbose:
self.logger = get_dist_logger(__name__)
self.model = ModelWrapper(model).to(self.device)

self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
# DISCUSS maybe move this into batch info?
if self.verbose:
self.logger.info(
f"After the shard, Rank: [{dist.get_rank()}], model size: {self.get_model_size(self.model)} GB, model's device is: {model.device}"
)

self.counter = count()
if isinstance(model_or_path, str):
from colossalai.inference.core.plugin import InferCheckpoint_io

self.use_cuda_graph = self.inference_config.use_cuda_graph
if self.use_cuda_graph:
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_memory_pool = None # Set during graph capture.
if verbose:
self.logger.info("Colossal AI CUDA Graph Capture on")
# from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO

self.capture_model(self.k_cache, self.v_cache)
cpt_io = InferCheckpoint_io()
# cpt_io = HybridParallelCheckpointIO(dp_group=None, pp_group=None, tp_group=pg_mesh.get_group_along_axis(TP_AXIS), zero_stage=0, verbose=True)
if_has_index_file, model_index_file = has_index_file(model_or_path)
assert if_has_index_file, "the model path is invalid"
cpt_io.load_model(self.model, model_index_file)

free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
if self.verbose:
self.logger.info(
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {self.get_model_size(self.model)} GB"
)

def get_model_size(self, model: nn.Module):
"""Calculates the total size of the model weights (including biases) in bytes.

Args:
model: The PyTorch model to analyze.

Returns:
The total size of the model weights in GiB.
"""
total_size = 0
for key, param in model.named_parameters():
total_size += param.element_size() * param.numel()
return total_size / (1024**3)

@torch.inference_mode()
def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
@@ -187,7 +269,7 @@ def _verify_config(self) -> None:
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
)
assert (
self.model.__class__.__name__ in _supported_models
self.model.__class__.__name__ in _supported_models.keys()
), f"Model {self.model.__class__.__name__} is not supported."

def _shardformer(
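With these changes the engine accepts either an `nn.Module` or a checkpoint path, and tensor parallelism is driven entirely by the inference config. A minimal usage sketch under a 2-GPU job is shown below; the `generate` call and the exact `InferenceConfig` fields beyond `tp_size`, `pp_size`, `dtype`, and `max_batch_size` are assumptions based on the surrounding diff, not part of this PR.

```python
# Sketch only: assumes a 2-GPU job launched with torchrun / `colossalai run`,
# so torch.distributed is initialized before the engine builds its ProcessGroupMesh.
import colossalai
from transformers import AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

colossalai.launch_from_torch()  # older Colossal-AI versions may require config={} here

model_path = "meta-llama/Llama-2-7b-hf"  # any supported LlamaForCausalLM checkpoint with an index file
tokenizer = AutoTokenizer.from_pretrained(model_path)

inference_config = InferenceConfig(
    dtype="fp16",
    max_batch_size=4,
    tp_size=2,  # attention heads and MLP shards are split across 2 ranks
    pp_size=1,
)

# Passing a string triggers init_model(): the architecture is looked up in
# _supported_models, an empty model is built from AutoConfig, ShardFormer splits
# it along the TP axis, and InferCheckpoint_io loads the sharded weights.
engine = InferenceEngine(model_path, tokenizer, inference_config, verbose=True)

# Assumed generation API; see the rest of engine.py for the exact signature.
outputs = engine.generate(prompts=["Introduce Beijing in one sentence."])
print(outputs)
```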
135 changes: 135 additions & 0 deletions colossalai/inference/core/plugin.py
@@ -0,0 +1,135 @@
import logging
import os
from functools import reduce
from pathlib import Path
from typing import Optional

import torch

from colossalai.checkpoint_io.general_checkpoint_io import GeneralCheckpointIO
from colossalai.checkpoint_io.index_file import CheckpointIndexFile
from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper

try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"


class InferCheckpoint_io(GeneralCheckpointIO):
def __init__(
self,
verbose: bool = True,
) -> None:
super().__init__()
self.verbose = verbose
self.coordinator = DistCoordinator()

def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
"""
Load sharded model with the given path to index file of checkpoint folder.

Args:
model (nn.Module): The model to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
This argument should be manually set to False since params on the same device might be stored in different files.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
model = model.unwrap()

# Check whether the checkpoint uses safetensors.
use_safetensors = False
if "safetensors" in checkpoint_index_file.name:
use_safetensors = True

if use_safetensors and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")

# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
strict = False

# Load params & buffers to model.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()

missing_keys = []
missing_file_keys = []

def _load(name: str):
if name not in weight_map:
missing_file_keys.append(name)
return
filename = weight_map[name]

# If this param/buffer has been loaded before, directly return.
if filename in loaded_file:
return

file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)

load_state_dict_into_model(
model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
)
loaded_file.add(filename)

# Load parameters.
for name, _ in model.named_parameters():
_load(name)

# Load buffers.
non_persistent_buffers = set()
for n, m in model.named_modules():
non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persistent_buffers:
_load(name)

# Load extra states.
extra_state_key = _EXTRA_STATE_KEY_SUFFIX
if (
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
is not torch.nn.Module.get_extra_state
):
_load(extra_state_key)

if self.verbose and self.coordinator.is_master():
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")

if len(missing_keys) == 0:
raise RuntimeError(
"No weigth is loaded into the model. Please check the checkpoint files and the model structure."
)

remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
remain_keys = remain_keys.union(set(missing_file_keys))
if len(remain_keys) > 0:
if strict:
error_msgs = "Missing key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in missing_keys)
)
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(
self.__class__.__name__, error_msgs
)
)
else:
if self.coordinator.is_master():
logging.info(f"The following keys are not loaded from checkpoint: {remain_keys}")

def save_sharded_model(
self,
model: ModelWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
) -> None:
raise NotImplementedError
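For context, `InferCheckpoint_io` can also be exercised outside the engine, as long as the process group is already initialized (its constructor creates a `DistCoordinator`) and the checkpoint directory contains a sharded index file. A minimal sketch under those assumptions follows; the path and the unsharded model are illustrative, and inside the engine the model has already been split by ShardFormer before loading.

```python
# Illustrative sketch mirroring the load path used in init_model() above.
import colossalai
from transformers import AutoConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from colossalai.inference.core.plugin import InferCheckpoint_io
from colossalai.inference.utils import has_index_file
from colossalai.interface import ModelWrapper

colossalai.launch_from_torch()  # InferCheckpoint_io builds a DistCoordinator internally

checkpoint_dir = "/path/to/llama-checkpoint"  # hypothetical local sharded checkpoint

# load_sharded_model() asserts the model is wrapped in ModelWrapper before loading.
config = AutoConfig.from_pretrained(checkpoint_dir, trust_remote_code=True)
model = ModelWrapper(LlamaForCausalLM(config).half().cuda())

has_index, index_file = has_index_file(checkpoint_dir)
assert has_index, "a *.index.json file is required for the sharded load path"

cpt_io = InferCheckpoint_io()
cpt_io.load_model(model, index_file)  # dispatches to load_sharded_model above
```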
6 changes: 3 additions & 3 deletions colossalai/inference/core/request_handler.py
@@ -136,7 +136,7 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo

fd_inter_tensor.initialize(
max_batch_size=self.max_batch_size,
num_attn_heads=model_config.num_attention_heads,
num_attn_heads=model_config.num_attention_heads // inference_config.tp_size,
kv_max_split_num=kv_max_split_num,
head_dim=head_dim,
dtype=self.dtype,
@@ -146,7 +146,7 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo
# TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
# which may cause bugs and this issue should be fixed later.
self.running_bb = BatchBucket(
num_heads=model_config.num_attention_heads,
num_heads=model_config.num_attention_heads // inference_config.tp_size,
head_dim=head_dim,
max_batch_size=self.max_batch_size,
max_length=inference_config.max_input_len + inference_config.max_output_len,
@@ -157,7 +157,7 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo
device=device,
)
self.prefill_bb = BatchBucket(
num_heads=model_config.num_attention_heads,
num_heads=model_config.num_attention_heads // inference_config.tp_size,
head_dim=head_dim,
max_batch_size=self.max_batch_size,
max_length=inference_config.max_input_len + inference_config.max_output_len,
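The head-count changes above follow the standard tensor-parallel layout: attention heads are partitioned evenly across TP ranks, so the flash-decoding intermediate tensors and both batch buckets only need to hold `num_attention_heads // tp_size` heads per rank, while `head_dim` is unchanged. A small sanity-check sketch with illustrative (Llama-7B-like) numbers:

```python
# Illustrative arithmetic only; the concrete numbers are assumptions, not read from this PR.
num_attention_heads = 32
hidden_size = 4096
tp_size = 2

head_dim = hidden_size // num_attention_heads    # 128, independent of tp_size
heads_per_rank = num_attention_heads // tp_size  # 16 heads held by each TP rank

# Per-rank KV-cache and batch-bucket storage shrinks proportionally with tp_size.
assert heads_per_rank * tp_size == num_attention_heads, "heads must divide evenly across TP ranks"
print(f"each rank allocates buffers for {heads_per_rank} heads of dim {head_dim}")
```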