diff --git a/engines/python/setup/djl_python/deepspeed.py b/engines/python/setup/djl_python/deepspeed.py
index 279e427b6..8116fefc2 100644
--- a/engines/python/setup/djl_python/deepspeed.py
+++ b/engines/python/setup/djl_python/deepspeed.py
@@ -10,98 +10,261 @@
 # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
-
-import json
 import logging
 import os
-from typing import Optional
-
+import torch
+from transformers import (
+    AutoConfig,
+    PretrainedConfig,
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    AutoModelForSequenceClassification,
+    AutoModelForQuestionAnswering,
+    AutoModelForMaskedLM,
+    AutoModelForTokenClassification,
+    pipeline,
+    Conversation,
+    SquadExample
+)
+import transformers
+from deepspeed.module_inject.replace_policy import (
+    HFBertLayerPolicy,
+    HFGPTNEOLayerPolicy,
+    GPTNEOXLayerPolicy,
+    HFGPTJLayerPolicy,
+    MegatronLayerPolicy,
+    HFGPT2LayerPolicy,
+    BLOOMLayerPolicy,
+    HFOPTLayerPolicy,
+    HFCLIPLayerPolicy,
+)
+import deepspeed
 from djl_python.inputs import Input
 from djl_python.outputs import Output
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+from typing import Optional
 
-import deepspeed
+SUPPORTED_MODEL_TYPES = {
+    "roberta",
+    "gpt2",
+    "bert",
+    "gpt_neo",
+    "gptj",
+    "opt",
+    "gpt-neox",
+    "bloom",
+}
+
+SUPPORTED_TASKS = {
+    "text-generation",
+    "text-classification",
+    "question-answering",
+    "fill-mask",
+    "token-classification",
+    "conversational",
+}
+
+ARCHITECTURES_TO_TASK = {
+    "ForCausalLM": "text-generation",
+    "GPT2LMHeadModel": "text-generation",
+    "ForSequenceClassification": "text-classification",
+    "ForQuestionAnswering": "question-answering",
+    "ForMaskedLM": "fill-mask",
+    "ForTokenClassification": "token-classification",
+    "BloomModel": "text-generation",
+}
+
+TASK_TO_MODEL = {
+    "text-generation": AutoModelForCausalLM,
+    "text-classification": AutoModelForSequenceClassification,
+    "question-answering": AutoModelForQuestionAnswering,
+    "fill-mask": AutoModelForMaskedLM,
+    "token-classification": AutoModelForTokenClassification,
+    "conversational": AutoModelForCausalLM,
+}
+
+MODEL_TYPE_TO_INJECTION_POLICY = {
+    "roberta": {transformers.models.roberta.modeling_roberta.RobertaLayer: HFBertLayerPolicy},
+    "gpt2": {transformers.models.gpt2.modeling_gpt2.GPT2Block: HFGPT2LayerPolicy},
+    "bert": {transformers.models.bert.modeling_bert.BertLayer: HFBertLayerPolicy},
+    "gpt_neo": {transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoBlock: HFGPTNEOLayerPolicy},
+    "gptj": {transformers.models.gptj.modeling_gptj.GPTJBlock: HFGPTJLayerPolicy},
+    "opt": {transformers.models.opt.modeling_opt.OPTDecoderLayer: HFOPTLayerPolicy},
+    "gpt-neox": {transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer: GPTNEOXLayerPolicy},
+    "bloom": {transformers.models.bloom.modeling_bloom.BloomBlock: BLOOMLayerPolicy},
+}
+
+
+def get_torch_dtype_from_str(dtype: str):
+    if dtype == "fp32":
+        return torch.float32
+    elif dtype == "fp16":
+        return torch.float16
+    elif dtype == "bf16":
+        return torch.bfloat16
+    elif dtype == "int8":
+        return torch.int8
+    else:
+        raise ValueError(f"Invalid data type: {dtype}")
 
 
 class DeepSpeedService(object):
 
     def __init__(self):
-        self.predictor = None
-        self.max_new_tokens = 0
+        self.pipeline = None
         self.initialized = False
+        self.ds_config = None
+        self.task = None
+        self.logger = logging.getLogger()
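+        # Populated from the serving properties during initialize()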
+        self.model_dir = None
+        self.model_id = None
+        self.data_type = None
+        self.max_tokens = None
+        self.device = None
+        self.world_size = None
+        self.tensor_parallel_degree = None
+        self.model_config = None
+        self.low_cpu_mem_usage = False
 
     def initialize(self, properties: dict):
-        self.max_new_tokens = int(properties.get("max_new_tokens", "50"))
-        model_dir = properties.get("model_dir")
-        data_type = properties.get("data_type", "fp32")
-        mp_size = int(properties.get("tensor_parallel_degree", "1"))
-        model_id = properties.get("model_id")
-        # LOCAL_RANK env is initialized after constructor
-        device = int(os.getenv('LOCAL_RANK', '0'))
-        if not model_id:
-            model_id = model_dir
-        config_file = os.path.join(model_dir, "config.json")
+        self._parse_properties(properties)
+        self._validate_model_type_and_task()
+        self.create_model_pipeline()
+        self.logger.info(f"Initialized DeepSpeed model with the following configurations: "
+                         f"model: {self.model_id}, "
+                         f"task: {self.task}, "
+                         f"data_type: {self.data_type}, "
+                         f"tensor_parallel_degree: {self.tensor_parallel_degree}")
+        self.initialized = True
+
+    def _parse_properties(self, properties):
+        self.model_dir = properties.get("model_dir")
+        self.model_id = properties.get("model_id")
+        self.task = properties.get("task")
+        self.data_type = get_torch_dtype_from_str(properties.get("data_type", "fp32"))
+        self.max_tokens = int(properties.get("max_tokens", 1024))
+        self.device = int(os.getenv("LOCAL_RANK", 0))
+        self.world_size = int(os.getenv("WORLD_SIZE", 1))
+        self.tensor_parallel_degree = int(properties.get("tensor_parallel_degree", self.world_size))
+        self.low_cpu_mem_usage = properties.get("low_cpu_mem_usage", "true").lower() == "true"
+        self.ds_config = {
+            "replace_with_kernel_inject": True,
+            "dtype": self.data_type,
+            "mp_size": self.tensor_parallel_degree,
+            "mpu": None,
+            "enable_cuda_graph": properties.get("enable_cuda_graph", "false").lower() == "true",
+            "triangular_masking": properties.get("triangular_masking", "true").lower() == "true",
+            "checkpoint": properties.get("checkpoint"),
+            "base_dir": properties.get("base_dir"),
+            "return_tuple": properties.get("return_tuple", "true").lower() == "true",
+            "training_mp_size": int(properties.get("training_mp_size", 1)),
+            "replace_method": "auto",
+            "injection_policy": None,
+            "max_tokens": self.max_tokens,
+        }
+
+    def _validate_model_type_and_task(self):
+        if not self.model_id:
+            self.model_id = self.model_dir
+            config_file = os.path.join(self.model_id, "config.json")
             if not os.path.exists(config_file):
-            raise ValueError(
-                "config.json file is required for DeepSpeed model")
-
-        with open(config_file, "r") as f:
-            config = json.load(f)
-            architectures = config.get("architectures")
-            if not architectures:
-                raise ValueError(
-                    "No architectures found in config.json file")
-            # TODO: check all supported architectures
-
-        logging.info(
-            f"Init: {model_id}, tensor_parallel_degree={mp_size}, data_type={data_type}, "
-            f"device={device}, max_new_tokenx={self.max_new_tokens}")
-
-        model = AutoModelForCausalLM.from_pretrained(model_id,
-                                                     low_cpu_mem_usage=True)
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        if data_type == "fp16":
-            model.half()
-
-        model = deepspeed.init_inference(model,
-                                         mp_size=mp_size,
-                                         dtype=model.dtype,
-                                         replace_method='auto',
-                                         replace_with_kernel_inject=True)
-        self.predictor = pipeline(task='text-generation',
-                                  model=model,
-                                  tokenizer=tokenizer,
-                                  device=device)
+                raise ValueError(f"model_dir: {self.model_id} does not contain a config.json. "
+                                 f"This is required for loading models from local storage")
" + f"This is required for loading models from local storage") + self.model_config = AutoConfig.from_pretrained(config_file) + else: + self.model_config = AutoConfig.from_pretrained(self.model_id) - self.initialized = True + if self.model_config.model_type not in SUPPORTED_MODEL_TYPES: + raise ValueError(f"model_type: {self.model_config.model_type} is not currently supported by DeepSpeed") + + if not self.task: + self.logger.warning("No task provided. Attempting to infer from model architecture") + self.infer_task_from_model_architecture(self.model_config) + if self.task not in SUPPORTED_TASKS: + raise ValueError(f"task: {self.task} is not currently supported by DeepSpeed") + + def infer_task_from_model_architecture(self, config: PretrainedConfig): + architecture = config.architectures[0] + for arch_option in ARCHITECTURES_TO_TASK: + if architecture.endswith(arch_option): + self.task = ARCHITECTURES_TO_TASK[arch_option] + + if not self.task: + raise ValueError(f"Task could not be inferred from model config. " + f"Please manually set `task` in serving.properties") + + def create_model_pipeline(self): + # If a ds checkpoint is provided, we instantiate model with meta tensors. weights loaded when DS engine invoked + if self.ds_config["checkpoint"]: + dtype = torch.float32 if self.data_type == torch.float32 else torch.float16 + with deepspeed.OnDevice(dtype=dtype, device="meta"): + model = TASK_TO_MODEL[self.task].from_config(self.model_config) + else: + model = TASK_TO_MODEL[self.task].from_pretrained(self.model_id, low_cpu_mem_usage=self.low_cpu_mem_usage) + + model.eval() + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.pipeline = pipeline(task=self.task, model=model, tokenizer=tokenizer, device=self.device) + if self.model_config.model_type in MODEL_TYPE_TO_INJECTION_POLICY: + self.ds_config["injection_policy"] = MODEL_TYPE_TO_INJECTION_POLICY[self.model_config.model_type] + engine = deepspeed.init_inference(self.pipeline.model, **self.ds_config) + self.pipeline.model = engine.module + + def format_input_for_task(self, input_values): + if not isinstance(input_values, list): + input_values = [input_values] + + batch_inputs = [] + for val in input_values: + if self.task == "conversational": + current_input = Conversation( + text=val.get("text"), + conversation_id=val.get("conversation_id"), + past_user_inputs=val.get("past_user_inputs", []), + generated_responses=val.get("generated_responses", []) + ) + elif self.task == "question-answering": + current_input = SquadExample( + None, + val.get("context"), + val.get("question"), + None, + None, + None + ) + else: + current_input = val + batch_inputs += [current_input] + return batch_inputs def inference(self, inputs: Input): try: content_type = inputs.get_property("Content-Type") + model_kwargs = {} if content_type is not None and content_type == "application/json": json_input = inputs.get_as_json() if isinstance(json_input, dict): - max_tokens = json_input.pop("max_new_tokens", - self.max_new_tokens) - data = json_input.pop("inputs", json_input) + input_data = self.format_input_for_task(json_input.pop("inputs")) + model_kwargs = json_input else: - max_tokens = self.max_new_tokens - data = json_input + input_data = json_input else: - data = inputs.get_as_string() - max_tokens = self.max_new_tokens + input_data = inputs.get_as_string() - result = self.predictor(data, - do_sample=True, - max_new_tokens=max_tokens) + result = self.pipeline(input_data, **model_kwargs) + if self.task == "conversational": + result = { + 
"generated_text": result.generated_responses[-1], + "conversation": { + "past_user_inputs": result.past_user_inputs, + "generated_responses": result.generated_responses, + }, + } outputs = Output() outputs.add(result) except Exception as e: logging.exception("DeepSpeed inference failed") - # error handling - outputs = Output().error(str(e)) - + outputs = Output().error((str(e))) return outputs @@ -113,7 +276,6 @@ def handle(inputs: Input) -> Optional[Output]: _service.initialize(inputs.get_properties()) if inputs.is_empty(): - # initialization request return None - return _service.inference(inputs) + return _service.inference(inputs) \ No newline at end of file diff --git a/engines/python/setup/djl_python/stable-diffusion.py b/engines/python/setup/djl_python/stable-diffusion.py new file mode 100644 index 000000000..4291c5082 --- /dev/null +++ b/engines/python/setup/djl_python/stable-diffusion.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python +# +# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file +# except in compliance with the License. A copy of the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for +# the specific language governing permissions and limitations under the License. +import logging +import os +import torch +from diffusers import DiffusionPipeline +import deepspeed +from djl_python.inputs import Input +from djl_python.outputs import Output +from typing import Optional + + +class StableDiffusionService(object): + + def __init__(self): + self.pipeline = None + self.initialized = False + self.ds_config = None + self.logger = logging.getLogger() + self.model_dir = None + self.model_id = None + self.data_type = None + self.device = None + self.world_size = None + self.max_tokens = None + self.tensor_parallel_degree = None + self.save_image_dir = None + + def initialize(self, properties: dict): + self.model_dir = properties.get("model_dir") + self.model_id = properties.get("model_id") + self.data_type = properties.get("data_type", "fp32") + self.max_tokens = int(properties.get("max_tokens", "1024")) + self.device = int(os.getenv("LOCAL_RANK", "0")) + self.world_size = int(os.getenv("TENSOR_PARALLEL_DEGREE", "1")) + self.tensor_parallel_degree = int(properties.get("tensor_parallel_degree", self.world_size)) + self.save_image_dir = properties.get("save_image_dir", "output_images") + self.ds_config = { + "replace_with_kernel_inject": True, + # TODO: Figure out why cuda graph doesn't work for stable diffusion via DS + "enable_cuda_graph": False, + "replace_method": "auto", + "dtype": torch.float16 if self.data_type == "fp16" else torch.float32, + "mp_size": self.tensor_parallel_degree + } + + if not self.model_id: + config_file = os.path.join(self.model_dir, "model_index.json") + if not os.path.exists(config_file): + raise ValueError(f"model_dir: {self.model_dir} does not contain a model_index.json." 
+ f"This is required for loading stable diffusion models from local storage") + self.model_id = self.model_dir + + kwargs = {} + if self.data_type == "fp16": + kwargs["torch_dtype"] = torch.float16 + kwargs["revision"] = "fp16" + + torch.set_grad_enabled(False) + pipeline = DiffusionPipeline.from_pretrained(self.model_id, **kwargs) + pipeline.to(f"cuda:{self.device}") + deepspeed.init_distributed() + engine = deepspeed.init_inference(getattr(pipeline, "model", pipeline), config=self.ds_config) + + if hasattr(pipeline, "model"): + pipeline.model = engine + + self.pipeline = pipeline + if not os.path.exists(self.save_image_dir): + os.mkdir(self.save_image_dir) + self.initialized = True + + def inference(self, inputs: Input): + try: + content_type = inputs.get_property("Content-Type") + if content_type is not None and content_type == "application/json": + json_input = inputs.get_as_json() + if isinstance(json_input, dict): + data = json_input.pop("inputs") + if isinstance(data, dict): + prompt = data.pop("query") + else: + prompt = data + else: + prompt = inputs.get_as_string() + else: + prompt = inputs.get_as_string() + + result = self.pipeline(prompt) + + saved_image_names = [] + for idx, img in enumerate(result.images): + save_path = os.path.join(self.save_image_dir, f"img-{idx}.png") + if self.device == 0: + img.save(save_path) + saved_image_names += [save_path] + + outputs = Output() + outputs.add_as_json(saved_image_names, "saved_images") + + except Exception as e: + logging.exception("DeepSpeed inference failed") + outputs = Output().error(str(e)) + return outputs + + + +_service = StableDiffusionService() + + +def handle(inputs: Input) -> Optional[Output]: + if not _service.initialized: + _service.initialize(inputs.get_properties()) + + if inputs.is_empty(): + return None + + return _service.inference(inputs) diff --git a/serving/docker/deepspeed.Dockerfile b/serving/docker/deepspeed.Dockerfile index 589dbe2da..c0b041c0f 100644 --- a/serving/docker/deepspeed.Dockerfile +++ b/serving/docker/deepspeed.Dockerfile @@ -16,6 +16,7 @@ ARG torch_version=1.12.1 ARG accelerate_version=0.13.2 ARG deepspeed_wheel="https://publish.djl.ai/deepspeed/deepspeed-0.7.5%2Bbf16-py2.py3-none-any.whl" ARG transformers_version=4.23.1 +ARG diffusers_version=0.7.2 EXPOSE 8080 @@ -47,6 +48,7 @@ RUN apt-get update && \ pip3 install ${deepspeed_wheel} && \ pip3 install transformers==${transformers_version} && \ pip3 install triton==1.0.0 mpi4py sentencepiece accelerate==${accelerate_version} bitsandbytes && \ + pip3 install diffusers[torch]==${diffusers_version} && \ scripts/patch_oss_dlc.sh python && \ scripts/security_patch.sh deepspeed && \ rm -rf scripts && \