diff --git a/engines/python/setup/djl_python/fastertransformer.py b/engines/python/setup/djl_python/fastertransformer.py
index 15afc05703..069da69b7a 100644
--- a/engines/python/setup/djl_python/fastertransformer.py
+++ b/engines/python/setup/djl_python/fastertransformer.py
@@ -57,28 +57,33 @@ def __init__(self) -> None:
         self.model_id_or_path = None
         self.model = None
         self.is_t5 = False
+        self.use_triton = None
 
     def initialize_properties(self, properties):
         self.tensor_parallel_degree = int(
             properties.get("tensor_parallel_degree", 1))
         self.pipeline_parallel_degree = int(
             properties.get("pipeline_parallel_degree", 1))
+        self.dtype = properties.get("data_type", self.dtype)
         self.dtype = properties.get("dtype", "fp32")
         self.model_id_or_path = properties.get("model_id") or properties.get(
             "model_dir")
+        self.use_triton = properties.get("engine", "Python") == "Python"
 
     def initialize(self, properties):
         self.initialize_properties(properties)
         self.model = self.load_model()
-        self.initialized = True
         model_config = AutoConfig.from_pretrained(self.model_id_or_path)
         self.is_t5 = model_config.model_type == "t5"
+        self.initialized = True
 
     def load_model(self):
         logging.info(f"Loading model: {self.model_id_or_path}")
         return ft.init_inference(self.model_id_or_path,
                                  self.tensor_parallel_degree,
-                                 self.pipeline_parallel_degree, self.dtype)
+                                 self.pipeline_parallel_degree,
+                                 self.dtype,
+                                 use_triton=self.use_triton)
 
     @staticmethod
     def param_mapper(parameters: dict):
@@ -120,18 +125,26 @@ def inference(self, inputs: Input):
         )
         parameters = self.param_mapper(parameters)
-        if self.is_t5:
-            result = self.model.pipeline_generate(input_data, **parameters)
+        max_length = parameters.pop("max_length", 50)
+        output_len = parameters.pop("max_seq_len", max_length)
+        if self.use_triton:
+            output_length = [output_len] * len(input_data)
+            result = self.model.pipeline_generate(input_data,
+                                                  output_length,
+                                                  **parameters)
         else:
-            output_len = parameters.pop("max_seq_len", 50)
-            beam_width = parameters.pop("beam_width", 1)
-            # TODO: remove after fixes in FT python package
-            result = self.model.pipeline_generate(
-                input_data,
-                batch_size=len(input_data),
-                output_len=output_len,
-                beam_width=beam_width,
-                **parameters)
+            if self.is_t5:
+                result = self.model.pipeline_generate(
+                    input_data, **parameters)
+            else:
+                beam_width = parameters.pop("beam_width", 1)
+                # TODO: remove after fixes in FT python package
+                result = self.model.pipeline_generate(
+                    input_data,
+                    batch_size=len(input_data),
+                    output_len=output_len,
+                    beam_width=beam_width,
+                    **parameters)
 
         offset = 0
         outputs = Output()
diff --git a/tests/integration/llm/fastertransformer-model.py b/tests/integration/llm/fastertransformer-model.py
index 36d7681e7f..da6e3189d6 100644
--- a/tests/integration/llm/fastertransformer-model.py
+++ b/tests/integration/llm/fastertransformer-model.py
@@ -9,7 +9,7 @@ def load_model(properties):
     tensor_parallel_degree = properties["tensor_parallel_degree"]
     pipeline_parallel_degree = 1  # TODO: add tests for pp_degree > 1
     model_id = properties.get('model_id') or properties.get('model_dir')
-    use_triton = properties.get("use_triton", False)
+    use_triton = properties.get("engine") != "FasterTransformer"
     dtype = properties.get("dtype", "fp32")
     return fastertransformer.init_inference(model_id, tensor_parallel_degree,
diff --git a/tests/integration/llm/prepare.py b/tests/integration/llm/prepare.py
index 165b9c929f..775f80bd5c 100644
--- a/tests/integration/llm/prepare.py
+++ b/tests/integration/llm/prepare.py
@@ -191,19 +191,23 @@
 ft_model_list = {
"t5-small": { + "engine": "FasterTransformer", "option.model_id": "t5-small", "option.tensor_parallel_degree": 4, }, "gpt2-xl": { + "engine": "FasterTransformer", "option.model_id": "gpt2-xl", "option.tensor_parallel_degree": 1, }, "facebook/opt-6.7b": { + "engine": "FasterTransformer", "option.model_id": "s3://djl-llm/opt-6b7/", "option.tensor_parallel_degree": 4, "option.dtype": "fp16", }, "bigscience/bloom-3b": { + "engine": "FasterTransformer", "option.model_id": "s3://djl-llm/bloom-3b/", "option.tensor_parallel_degree": 2, "option.dtype": "fp16", @@ -212,8 +216,7 @@ "nomic-ai/gpt4all-j": { "option.model_id": "s3://djl-llm/gpt4all-j/", "option.tensor_parallel_degree": 4, - "option.dtype": "fp32", - "option.use_triton": True + "option.dtype": "fp32" } } @@ -415,9 +418,9 @@ def build_ft_raw_model(model): f"{model} is not one of the supporting handler {list(ft_model_list.keys())}" ) options = ft_model_list[model] - options["engine"] = "FasterTransformer" - if "option.use_triton" in options and options["option.use_triton"]: + if "engine" not in options: options["engine"] = "Python" + write_properties(options) shutil.copyfile("llm/fastertransformer-model.py", "models/test/model.py") diff --git a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java index 7a660080fc..ef325e9758 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java @@ -42,17 +42,20 @@ public final class LmiUtils { private static final List DEEPSPEED_MODELS = List.of( - "roberta", - "xlm-roberta", - "gpt2", "bert", + "bloom", "gpt_neo", + "gpt_neox", + "gpt2", "gptj", "opt", - "gpt_neox", - "bloom"); + "roberta", + "xlm-roberta"); + + private static final List FASTERTRANSFORMER_MODELS = + List.of("bloom", "gpt2", "t5", "opt"); - private static final List FASTERTRANSFORMER_MODELS = List.of("t5"); + private static final List TRITON_MODELS = List.of("gptj"); private LmiUtils() {} @@ -97,8 +100,12 @@ static String inferLmiEngine(ModelInfo modelInfo) throws ModelException { engineName = "Python"; } else if (isDeepSpeedRecommended(modelType)) { engineName = "DeepSpeed"; + } else if (isTritonRecommended(modelType)) { + engineName = "Python"; + prop.setProperty("option.entryPoint", "djl_python.fastertransformer"); } else if (isFasterTransformerRecommended(modelType)) { engineName = "FasterTransformer"; + prop.setProperty("engine", engineName); // FT handler requires engine property } else { engineName = "Python"; } @@ -158,9 +165,14 @@ private static HuggingFaceModelConfig getHuggingFaceModelConfig(ModelInfo } } + private static boolean isTritonRecommended(String modelType) { + return isPythonDependencyInstalled("/opt/tritonserver/lib/", "tritontoolkit") + && TRITON_MODELS.contains(modelType); + } + private static boolean isFasterTransformerRecommended(String modelType) { return isPythonDependencyInstalled( - "/usr/local/backends/fastertransformer", "fastertransformer") + "/opt/tritonserver/backends/fastertransformer/", "fastertransformer") && FASTERTRANSFORMER_MODELS.contains(modelType); }