[Fix] Add unsupported auto-policy error message #5730

Merged
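
Summary (inferred from the diff below): this change widens the model_policy parameter of InferenceEngine.__init__ and init_model to accept either a Policy instance or a Policy subclass, looks up auto policies with model_policy_map.get instead of raw indexing, and raises a descriptive ValueError (rather than a bare KeyError) when no policy can be resolved and instantiated for the given model type.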
22 changes: 14 additions & 8 deletions colossalai/inference/core/engine.py
@@ -1,6 +1,6 @@
 import time
 from itertools import count
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -64,7 +64,7 @@ def __init__(
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
         inference_config: InferenceConfig,
         verbose: bool = False,
-        model_policy: Policy = None,
+        model_policy: Union[Policy, Type[Policy]] = None,
     ) -> None:
         self.inference_config = inference_config
         self.dtype = inference_config.dtype
@@ -104,7 +104,7 @@ def __init__(
 
         self._verify_args()
 
-    def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
+    def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Union[Policy, Type[Policy]] = None):
         """
         Shard model or/and Load weight
 
@@ -149,11 +149,17 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy
         )
 
         if model_policy is None:
-            if self.inference_config.pad_input:
-                model_type = "padding_" + self.model_config.model_type
-            else:
-                model_type = "nopadding_" + self.model_config.model_type
-            model_policy = model_policy_map[model_type]()
+            prefix = "nopadding" if not self.inference_config.pad_input else "padding"
+            model_policy_key = f"{prefix}_{getattr(self.model_config, 'model_type', None)}"
+            model_policy = model_policy_map.get(model_policy_key)
+
+        if not isinstance(model_policy, Policy):
+            try:
+                model_policy = model_policy()
+            except Exception as e:
+                raise ValueError(f"Unable to instantiate model policy: {e}")
+
+        assert isinstance(model_policy, Policy), f"Invalid type of model policy: {type(model_policy)}"
 
         pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
         tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
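
For reference, here is a minimal, self-contained sketch (not part of the PR) of the resolution flow the new code implements. Policy, NoPaddingLlamaPolicy, model_policy_map, and resolve_policy are hypothetical stand-ins for the engine's internals, kept just close enough to show why an unsupported model type now fails with a readable error:

from typing import Dict, Type, Union


class Policy:
    """Stand-in for colossalai's Policy base class."""


class NoPaddingLlamaPolicy(Policy):
    """Hypothetical auto policy registered under the key 'nopadding_llama'."""


# Illustrative registry; the real model_policy_map lives in colossalai.
model_policy_map: Dict[str, Type[Policy]] = {"nopadding_llama": NoPaddingLlamaPolicy}


def resolve_policy(
    model_policy: Union[Policy, Type[Policy], None],
    model_type: str,
    pad_input: bool = False,
) -> Policy:
    # Same shape as the new engine logic: look up an auto policy when none is given.
    if model_policy is None:
        prefix = "nopadding" if not pad_input else "padding"
        model_policy = model_policy_map.get(f"{prefix}_{model_type}")
    # A Policy subclass gets instantiated here. For an unsupported model type the
    # lookup returned None, so calling it fails and surfaces as a clear ValueError.
    if not isinstance(model_policy, Policy):
        try:
            model_policy = model_policy()
        except Exception as e:
            raise ValueError(f"Unable to instantiate model policy: {e}")
    return model_policy


print(type(resolve_policy(None, "llama")).__name__)                  # NoPaddingLlamaPolicy
print(type(resolve_policy(NoPaddingLlamaPolicy, "llama")).__name__)  # a class is accepted too
resolve_policy(None, "mamba")  # unsupported model type -> ValueError with a readable message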