Skip to content

Commit

Permalink
[Fix/Inference] Add unsupported auto-policy error message (#5730)
Browse files: browse the repository at this point in the history
* [fix] auto policy error message

* trivial
  • Loading branch information
yuanheng-zhao committed May 20, 2024
1 parent 283c407 commit bdf9a00
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions colossalai/inference/core/engine.py
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
import time
from itertools import count
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Type, Union

import numpy as np
import torch
Expand Down Expand Up @@ -64,7 +64,7 @@ def __init__(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
inference_config: InferenceConfig,
verbose: bool = False,
model_policy: Policy = None,
model_policy: Union[Policy, Type[Policy]] = None,
) -> None:
self.inference_config = inference_config
self.dtype = inference_config.dtype
Expand Down Expand Up @@ -105,7 +105,7 @@ def __init__(

self._verify_args()

def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None):
"""
Shard model or/and Load weight
Expand Down Expand Up @@ -150,11 +150,17 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy
)

if model_policy is None:
if self.inference_config.pad_input:
model_type = "padding_" + self.model_config.model_type
else:
model_type = "nopadding_" + self.model_config.model_type
model_policy = model_policy_map[model_type]()
prefix = "nopadding" if not self.inference_config.pad_input else "padding"
model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
model_policy = model_policy_map.get(model_policy_key)

if not isinstance(model_policy, Policy):
try:
model_policy = model_policy()
except Exception as e:
raise ValueError(f"Unable to instantiate model policy: {e}")

assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
tp_group = pg_mesh.get_group_along_axis(TP_AXIS)

Expand Down

0 comments on commit bdf9a00

Please sign in to comment.