diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 646b3cede4ae..96c2b15ee16e 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -1,6 +1,6 @@ import time from itertools import count -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Type, Union import numpy as np import torch @@ -64,7 +64,7 @@ def __init__( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], inference_config: InferenceConfig, verbose: bool = False, - model_policy: Policy = None, + model_policy: Union[Policy, Type[Policy]] = None, ) -> None: self.inference_config = inference_config self.dtype = inference_config.dtype @@ -105,7 +105,7 @@ def __init__( self._verify_args() - def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None): + def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None): """ Shard model or/and Load weight @@ -150,11 +150,17 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy ) if model_policy is None: - if self.inference_config.pad_input: - model_type = "padding_" + self.model_config.model_type - else: - model_type = "nopadding_" + self.model_config.model_type - model_policy = model_policy_map[model_type]() + prefix = "nopadding" if not self.inference_config.pad_input else "padding" + model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}" + model_policy = model_policy_map.get(model_policy_key) + + if not isinstance(model_policy, Policy): + try: + model_policy = model_policy() + except Exception as e: + raise ValueError(f"Unable to instantiate model policy: {e}") + + assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}" pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size) tp_group = pg_mesh.get_group_along_axis(TP_AXIS)