[Feat] Tensor Model Parallel Support For Inference (#5563)
* naive initial version of tensor parallel support

* [fix] precision, model loading, and a refactor of the framework

* add TP unit test

* docstring

* fix do_sample
LRY89757 committed Apr 18, 2024
1 parent be396ad commit e37ee2f
Showing 8 changed files with 634 additions and 144 deletions.
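
Taken together, the engine now accepts either an nn.Module or a checkpoint path and shards the model across tensor-parallel ranks before loading weights. Below is a minimal usage sketch, not part of the commit: the checkpoint path is hypothetical, the InferenceConfig fields follow the attributes referenced in this diff, and the exact colossalai.launch_from_torch signature varies slightly between releases.

    # Hypothetical usage sketch of the new tensor-parallel inference entry point.
    # Launch one process per GPU, e.g.:  torchrun --nproc_per_node 2 run_infer.py
    import colossalai
    from transformers import AutoTokenizer

    from colossalai.inference.config import InferenceConfig
    from colossalai.inference.core.engine import InferenceEngine

    colossalai.launch_from_torch(config={})  # the config arg is dropped in newer releases

    model_path = "/path/to/llama-checkpoint"           # assumed HF-format checkpoint dir
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    inference_config = InferenceConfig(
        dtype="fp16",
        tp_size=2,              # tensor-parallel degree introduced by this PR
        max_input_len=256,
        max_output_len=128,
    )

    # model_or_path may now be a string; the engine builds, shards, then loads weights.
    engine = InferenceEngine(model_path, tokenizer, inference_config, verbose=True)

    # Generation itself is outside this diff hunk; the keyword below is an assumption.
    outputs = engine.generate(prompts=["An example prompt for tensor-parallel inference."])
    print(outputs)
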
141 changes: 106 additions & 35 deletions colossalai/inference/core/engine.py
@@ -5,15 +5,26 @@
import numpy as np
import torch
import torch.nn as nn
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast

from torch import distributed as dist
from transformers import (
AutoConfig,
AutoModelForCausalLM,
GenerationConfig,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from colossalai.accelerator import get_accelerator
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
from colossalai.inference.utils import get_model_size, has_index_file
from colossalai.interface import ModelWrapper
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
@@ -25,10 +25,10 @@

PP_AXIS, TP_AXIS = 0, 1

_supported_models = [
"LlamaForCausalLM",
"BaichuanForCausalLM",
]
_supported_models = {
"LlamaForCausalLM": LlamaForCausalLM,
"BaichuanForCausalLM": AutoModelForCausalLM,
}

_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]

@@ -39,7 +39,7 @@ class InferenceEngine:
InferenceEngine which manages the inference process.
Args:
model (nn.Module): Path or nn.Module of this model.
model_or_path (nn.Module or str): Path or nn.Module of this model.
tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
verbose (bool): Determine whether or not to log the generation process.
@@ -48,26 +48,40 @@ class InferenceEngine:

def __init__(
self,
model: nn.Module,
model_or_path: Union[nn.Module, str],
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
inference_config: InferenceConfig,
verbose: bool = False,
model_policy: Policy = None,
) -> None:
self.inference_config = inference_config
self.model_config = model.config
self.model = model
self.device = torch.device("cuda")
self.dtype = inference_config.dtype
self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token
self.high_precision = inference_config.high_precision
self._verify_args()

self.verbose = verbose
self.logger = get_dist_logger(__name__)

self.init_model(model_or_path, model_policy)

self.generation_config = inference_config.to_generation_config(self.model_config)
model.eval()
model = model.to(self.dtype)
model = model.to(self.device)

self.tokenizer = tokenizer
self.tokenizer.pad_token = self.tokenizer.eos_token

self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
# DISCUSS maybe move this into batch info?

self.counter = count()

self.use_cuda_graph = self.inference_config.use_cuda_graph
if self.use_cuda_graph:
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_memory_pool = None # Set during graph capture.
if verbose:
self.logger.info("Colossal AI CUDA Graph Capture on")

self.capture_model(self.k_cache, self.v_cache)

# Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
self.use_spec_dec = False
@@ -76,40 +101,83 @@ def __init__(
self.use_glide = False
self.n_spec_tokens = self.inference_config.max_n_spec_tokens

self._verify_args()

def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
"""
Shard model or/and Load weight
Args:
model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
model_policy (Policy): the policy to replace the model
"""

if isinstance(model_or_path, str):
try:
hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
arch = getattr(hf_config, "architectures")[0]
model = _supported_models[arch](hf_config)
except Exception as e:
self.logger.error(
f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
)
else:
model = model_or_path

self.model_config = model.config

torch.cuda.empty_cache()
init_gpu_memory = torch.cuda.mem_get_info()[0]

self.device = get_accelerator().get_current_device()
if self.verbose:
self.logger.info(f"the device is {self.device}")

model = model.to(self.dtype).eval()

if self.verbose:
self.logger.info(
f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
)

if model_policy is None:
if self.inference_config.pad_input:
model_type = "padding_" + self.model_config.model_type
else:
model_type = "nopadding_" + self.model_config.model_type
model_policy = model_policy_map[model_type]()

pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)

self.model = self._shardformer(
model,
model_policy,
None,
pg_mesh.get_group_along_axis(TP_AXIS) if inference_config.pp_size * inference_config.tp_size > 1 else None,
tp_group=tp_group,
)

self.verbose = verbose
if verbose:
self.logger = get_dist_logger(__name__)
self.model = ModelWrapper(model).to(self.device)

self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.k_cache, self.v_cache = self.request_handler.get_kvcache()
# DISCUSS maybe move this into batch info?
if self.verbose:
self.logger.info(
f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
)

self.counter = count()
if isinstance(model_or_path, str):
from colossalai.inference.core.plugin import InferCheckpoint_io

self.use_cuda_graph = self.inference_config.use_cuda_graph
if self.use_cuda_graph:
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
self.graph_memory_pool = None # Set during graph capture.
if verbose:
self.logger.info("Colossal AI CUDA Graph Capture on")
cpt_io = InferCheckpoint_io()
if_has_index_file, model_index_file = has_index_file(model_or_path)
assert if_has_index_file, "the model path is invalid"
cpt_io.load_model(self.model, model_index_file)

self.capture_model(self.k_cache, self.v_cache)
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
peak_memory = init_gpu_memory - free_gpu_memory
if self.verbose:
self.logger.info(
f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
)

@torch.inference_mode()
def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
@@ -194,8 +262,11 @@ def _verify_args(self) -> None:
raise TypeError(
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
)
if self.model.__class__.__name__ not in _supported_models:
raise ValueError(f"Model {self.model.__class__.__name__} is not supported.")
if isinstance(self.model, ModelWrapper):
model = self.model.module
assert (
model.__class__.__name__ in _supported_models.keys()
), f"Model {self.model.__class__.__name__} is not supported."

def _shardformer(
self,
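
A condensed sketch of the dispatch that init_model performs before any weights are read: the checkpoint's config names an architecture, the architecture selects a model class from _supported_models, and an empty model is built from the config alone so that ShardFormer can split it before the much larger weights are loaded. This is not the committed code; the from_config fallback for the Auto class is my addition, since Auto classes cannot be instantiated by calling them on a config.

    from transformers import AutoConfig, AutoModelForCausalLM
    from transformers.models.llama.modeling_llama import LlamaForCausalLM

    SUPPORTED = {
        "LlamaForCausalLM": LlamaForCausalLM,
        "BaichuanForCausalLM": AutoModelForCausalLM,  # remote code, so fall back to Auto*
    }

    def build_skeleton(model_or_path):
        """Return a model instance built from config only; no checkpoint is read here."""
        if not isinstance(model_or_path, str):
            return model_or_path                      # already an nn.Module, use as-is
        hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
        arch = hf_config.architectures[0]
        model_cls = SUPPORTED[arch]
        if model_cls is AutoModelForCausalLM:
            # Auto classes cannot be called directly on a config.
            return AutoModelForCausalLM.from_config(hf_config, trust_remote_code=True)
        return model_cls(hf_config)
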
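For the sharding step itself, the engine lays ranks out on a (pp_size, tp_size) ProcessGroupMesh and hands the tensor-parallel group to ShardFormer. A small illustration of how the two axes map onto ranks, assuming four ranks with pp_size=2 and tp_size=2 (the values are chosen only for the example; this PR exercises the TP axis):

    import colossalai
    from colossalai.cluster import ProcessGroupMesh

    PP_AXIS, TP_AXIS = 0, 1

    colossalai.launch_from_torch(config={})            # 4 ranks, laid out as a 2 x 2 grid
    pg_mesh = ProcessGroupMesh(2, 2)                   # (pp_size, tp_size)
    tp_group = pg_mesh.get_group_along_axis(TP_AXIS)

    # Ranks 0-1 form one TP group and ranks 2-3 the other; ShardFormer splits each
    # attention / MLP weight across the ranks of its TP group, while the PP axis
    # would partition the layer stack (pipeline parallelism is not exercised here).
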
140 changes: 140 additions & 0 deletions colossalai/inference/core/plugin.py
@@ -0,0 +1,140 @@
import logging
import os
from functools import reduce
from pathlib import Path
from typing import Optional

import torch

from colossalai.checkpoint_io.general_checkpoint_io import GeneralCheckpointIO
from colossalai.checkpoint_io.index_file import CheckpointIndexFile
from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper

try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = "_extra_state"


class InferCheckpoint_io(GeneralCheckpointIO):
"""
This class is for inference model loading; most of the code is copied from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io.HybridParallelCheckpointIO.
The original HybridParallelCheckpointIO contains some code for mixed-precision training, which we remove here to build a relatively clean class specifically for inference.
"""

def __init__(
self,
verbose: bool = True,
) -> None:
super().__init__()
self.verbose = verbose
self.coordinator = DistCoordinator()

def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
"""
Load sharded model with the given path to index file of checkpoint folder.
Args:
model (nn.Module): The model to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
This argument should be manually set to False since params on the same device might be stored in different files.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
model = model.unwrap()

# Check whether the checkpoint uses safetensors.
use_safetensors = False
if "safetensors" in checkpoint_index_file.name:
use_safetensors = True

if use_safetensors and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")

# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
strict = False

# Load params & buffers to model.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()

missing_keys = []
missing_file_keys = []

def _load(name: str):
if name not in weight_map:
missing_file_keys.append(name)
return
filename = weight_map[name]

# If this param/buffer has been loaded before, directly return.
if filename in loaded_file:
return

file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)

load_state_dict_into_model(
model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
)
loaded_file.add(filename)

# Load parameters.
for name, _ in model.named_parameters():
_load(name)

# Load buffers.
non_persistent_buffers = set()
for n, m in model.named_modules():
non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persistent_buffers:
_load(name)

# Load extra states.
extra_state_key = _EXTRA_STATE_KEY_SUFFIX
if (
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
is not torch.nn.Module.get_extra_state
):
_load(extra_state_key)

if self.verbose and self.coordinator.is_master():
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")

if len(missing_keys) == 0:
raise RuntimeError(
"No weigth is loaded into the model. Please check the checkpoint files and the model structure."
)

remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
remain_keys = remain_keys.union(set(missing_file_keys))
if len(remain_keys) > 0:
if strict:
error_msgs = "Missing key(s) in state_dict: {}. ".format(
", ".join('"{}"'.format(k) for k in missing_keys)
)
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(
self.__class__.__name__, "\n\t".join(error_msgs)
)
)
else:
if self.coordinator.is_master():
logging.info(f"The following keys are not loaded from checkpoint: {remain_keys}")

def save_sharded_model(
self,
model: ModelWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
) -> None:
return NotImplementedError
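
The loader above walks the model's parameters and buffers and reads each shard file at most once, driven by the checkpoint's index file. Below is a stripped-down sketch of that bookkeeping for a plain, unsharded model; the index file name follows the usual Hugging Face layout and is an assumption, and the real class instead delegates to load_state_dict_into_model, which also handles filling a tensor-parallel slice.

    import json
    import os

    import torch

    def load_from_index(model: torch.nn.Module, checkpoint_dir: str,
                        index_name: str = "pytorch_model.bin.index.json"):
        """Load a sharded checkpoint into an unsharded model, reading each shard once."""
        with open(os.path.join(checkpoint_dir, index_name)) as f:
            weight_map = json.load(f)["weight_map"]    # param/buffer name -> shard file

        loaded_files = set()
        for name, _ in model.named_parameters():
            shard = weight_map.get(name)
            if shard is None or shard in loaded_files:
                continue                               # missing from index, or shard already read
            state_dict = torch.load(os.path.join(checkpoint_dir, shard), map_location="cpu")
            # strict=False: each shard only holds a subset of the model's keys.
            model.load_state_dict(state_dict, strict=False)
            loaded_files.add(shard)
        return model
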
6 changes: 3 additions & 3 deletions colossalai/inference/core/request_handler.py
@@ -140,7 +140,7 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo

fd_inter_tensor.initialize(
max_batch_size=max_n_tokens,
num_attn_heads=model_config.num_attention_heads,
num_attn_heads=model_config.num_attention_heads // inference_config.tp_size,
kv_max_split_num=kv_max_split_num,
head_dim=head_dim,
dtype=self.dtype,
@@ -150,7 +150,7 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo
# TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
# which may cause bugs and this issue should be fixed later.
self.running_bb = BatchBucket(
num_heads=model_config.num_attention_heads,
num_heads=model_config.num_attention_heads // inference_config.tp_size,
head_dim=head_dim,
max_batch_size=self.max_batch_size,
max_length=inference_config.max_input_len + inference_config.max_output_len,
@@ -161,7 +161,7 @@ def __init__(self, inference_config: InferenceConfig, model_config: PretrainedCo
device=device,
)
self.prefill_bb = BatchBucket(
num_heads=model_config.num_attention_heads,
num_heads=model_config.num_attention_heads // inference_config.tp_size,
head_dim=head_dim,
max_batch_size=self.max_batch_size,
max_length=inference_config.max_input_len + inference_config.max_output_len,
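
Dividing num_attention_heads by tp_size matters because, after column-wise sharding of the attention projections, each rank only produces and caches the K/V tensors for its own slice of the heads, so the per-rank KV cache (and the decoding workspace sized from it) shrinks proportionally. A back-of-the-envelope check, assuming the usual layout of one K and one V tensor per layer in fp16:

    def kv_cache_bytes_per_rank(num_layers, num_heads, head_dim, cached_tokens,
                                tp_size, dtype_bytes=2):
        """Rough per-rank KV-cache footprint: 2 (K and V) tensors per layer."""
        heads_per_rank = num_heads // tp_size
        return 2 * num_layers * heads_per_rank * head_dim * cached_tokens * dtype_bytes

    # A Llama-7B-like config (32 layers, 32 heads, head_dim 128) caching 4096 tokens:
    print(kv_cache_bytes_per_rank(32, 32, 128, 4096, tp_size=1) / 2**30)  # 2.0 GiB
    print(kv_cache_bytes_per_rank(32, 32, 128, 4096, tp_size=2) / 2**30)  # 1.0 GiB per rank
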
(Diffs for the remaining 5 changed files are not shown.)
