[WIP] Updating deepspeed handler for more models (#359)
* Updating deepspeed handler for more models

* Support additional configs through serving.properties, add initial bloom support

* fail on unsupported dtype, make some methods private
siddvenk committed Dec 7, 2022
1 parent e4a3bce commit b0f7364
Showing 3 changed files with 356 additions and 64 deletions.
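For context, a minimal sketch of the properties dict this handler now consumes in initialize() (key names are taken from _parse_properties in the diff below; the model id and values are illustrative only, and how DJL Serving maps serving.properties entries onto this dict is outside this commit):

    properties = {
        "model_id": "EleutherAI/gpt-j-6B",   # hypothetical model; gptj is in SUPPORTED_MODEL_TYPES
        "task": "text-generation",
        "data_type": "fp16",
        "tensor_parallel_degree": "2",
        "low_cpu_mem_usage": "true",
        "max_tokens": "1024",
    }
    service = DeepSpeedService()
    service.initialize(properties)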
290 changes: 226 additions & 64 deletions engines/python/setup/djl_python/deepspeed.py
@@ -10,98 +10,261 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import json
import logging
import os
from typing import Optional

import torch
from transformers import (
    AutoConfig,
    PretrainedConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoModelForQuestionAnswering,
    AutoModelForMaskedLM,
    AutoModelForTokenClassification,
    pipeline,
    Conversation,
    SquadExample
)
import transformers
from deepspeed.module_inject.replace_policy import (
    HFBertLayerPolicy,
    HFGPTNEOLayerPolicy,
    GPTNEOXLayerPolicy,
    HFGPTJLayerPolicy,
    MegatronLayerPolicy,
    HFGPT2LayerPolicy,
    BLOOMLayerPolicy,
    HFOPTLayerPolicy,
    HFCLIPLayerPolicy,
)
import deepspeed
from djl_python.inputs import Input
from djl_python.outputs import Output
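
# Lookup tables driving the handler: which Hugging Face model types and pipeline tasks are
# accepted, how a task is inferred from the architecture name, which AutoModel class backs
# each task, and which DeepSpeed injection policy is applied per model type.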
SUPPORTED_MODEL_TYPES = {
    "roberta",
    "gpt2",
    "bert",
    "gpt_neo",
    "gptj",
    "opt",
    "gpt-neox",
    "bloom",
}

SUPPORTED_TASKS = {
    "text-generation",
    "text-classification",
    "question-answering",
    "fill-mask",
    "token-classification",
    "conversational",
}

ARCHITECTURES_TO_TASK = {
    "ForCausalLM": "text-generation",
    "GPT2LMHeadModel": "text-generation",
    "ForSequenceClassification": "text-classification",
    "ForQuestionAnswering": "question-answering",
    "ForMaskedLM": "fill-mask",
    "ForTokenClassification": "token-classification",
    "BloomModel": "text-generation",
}

TASK_TO_MODEL = {
    "text-generation": AutoModelForCausalLM,
    "text-classification": AutoModelForSequenceClassification,
    "question-answering": AutoModelForQuestionAnswering,
    "fill-mask": AutoModelForMaskedLM,
    "token-classification": AutoModelForTokenClassification,
    "conversational": AutoModelForCausalLM,
}

MODEL_TYPE_TO_INJECTION_POLICY = {
    "roberta": {transformers.models.roberta.modeling_roberta.RobertaLayer: HFBertLayerPolicy},
    "gpt2": {transformers.models.gpt2.modeling_gpt2.GPT2Block: HFGPT2LayerPolicy},
    "bert": {transformers.models.bert.modeling_bert.BertLayer: HFBertLayerPolicy},
    "gpt_neo": {transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoBlock: HFGPTNEOLayerPolicy},
    "gptj": {transformers.models.gptj.modeling_gptj.GPTJBlock: HFGPTJLayerPolicy},
    "opt": {transformers.models.opt.modeling_opt.OPTDecoderLayer: HFOPTLayerPolicy},
    "gpt-neox": {transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer: GPTNEOXLayerPolicy},
    "bloom": {transformers.models.bloom.modeling_bloom.BloomBlock: BLOOMLayerPolicy},
}


def get_torch_dtype_from_str(dtype: str):
    if dtype == "fp32":
        return torch.float32
    elif dtype == "fp16":
        return torch.float16
    elif dtype == "bf16":
        return torch.bfloat16
    elif dtype == "int8":
        return torch.int8
    else:
        raise ValueError(f"Invalid data type: {dtype}")


class DeepSpeedService(object):

    def __init__(self):
        self.pipeline = None
        self.initialized = False
        self.ds_config = None
        self.task = None
        self.logger = logging.getLogger()
        self.model_dir = None
        self.model_id = None
        self.data_type = None
        self.max_tokens = None
        self.device = None
        self.world_size = None
        self.tensor_parallel_degree = None
        self.model_config = None
        self.low_cpu_mem_usage = False

    def initialize(self, properties: dict):
        self._parse_properties(properties)
        self._validate_model_type_and_task()
        self.create_model_pipeline()
        self.logger.info(f"Initialized DeepSpeed model with the following configurations: "
                         f"model: {self.model_id}, "
                         f"task: {self.task}, "
                         f"data_type: {self.data_type}, "
                         f"tensor_parallel_degree: {self.tensor_parallel_degree}")
        self.initialized = True

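    # Everything gathered into self.ds_config below is forwarded verbatim to
    # deepspeed.init_inference(**self.ds_config) in create_model_pipeline.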
    def _parse_properties(self, properties):
        self.model_dir = properties.get("model_dir")
        self.model_id = properties.get("model_id")
        self.task = properties.get("task")
        self.data_type = get_torch_dtype_from_str(properties.get("data_type", "fp32"))
        self.max_tokens = int(properties.get("max_tokens", 1024))
        self.device = int(os.getenv("LOCAL_RANK", 0))
        self.world_size = int(os.getenv("WORLD_SIZE", 1))
        self.tensor_parallel_degree = int(properties.get("tensor_parallel_degree", self.world_size))
        self.low_cpu_mem_usage = properties.get("low_cpu_mem_usage", "true").lower() == "true"
        self.ds_config = {
            "replace_with_kernel_inject": True,
            "dtype": self.data_type,
            "mp_size": self.tensor_parallel_degree,
            "mpu": None,
            "enable_cuda_graph": properties.get("enable_cuda_graph", "false").lower() == "true",
            "triangular_masking": properties.get("triangular_masking", "true").lower() == "true",
            "checkpoint": properties.get("checkpoint"),
            "base_dir": properties.get("base_dir"),
            "return_tuple": properties.get("return_tuple", "true").lower() == "true",
            "training_mp_size": int(properties.get("training_mp_size", 1)),
            "replace_method": "auto",
            "injection_policy": None,
            "max_tokens": self.max_tokens,
        }

    def _validate_model_type_and_task(self):
        if not self.model_id:
            self.model_id = self.model_dir
            config_file = os.path.join(self.model_id, "config.json")
            if not os.path.exists(config_file):
                raise ValueError(f"model_dir: {self.model_id} does not contain a config.json. "
                                 f"This is required for loading models from local storage")
            self.model_config = AutoConfig.from_pretrained(config_file)
        else:
            self.model_config = AutoConfig.from_pretrained(self.model_id)

        if self.model_config.model_type not in SUPPORTED_MODEL_TYPES:
            raise ValueError(f"model_type: {self.model_config.model_type} is not currently supported by DeepSpeed")

        if not self.task:
            self.logger.warning("No task provided. Attempting to infer from model architecture")
            self.infer_task_from_model_architecture(self.model_config)
        if self.task not in SUPPORTED_TASKS:
            raise ValueError(f"task: {self.task} is not currently supported by DeepSpeed")

    def infer_task_from_model_architecture(self, config: PretrainedConfig):
        architecture = config.architectures[0]
        for arch_option in ARCHITECTURES_TO_TASK:
            if architecture.endswith(arch_option):
                self.task = ARCHITECTURES_TO_TASK[arch_option]

        if not self.task:
            raise ValueError(f"Task could not be inferred from model config. "
                             f"Please manually set `task` in serving.properties")

    def create_model_pipeline(self):
        # If a ds checkpoint is provided, we instantiate model with meta tensors. weights loaded when DS engine invoked
        if self.ds_config["checkpoint"]:
            dtype = torch.float32 if self.data_type == torch.float32 else torch.float16
            with deepspeed.OnDevice(dtype=dtype, device="meta"):
                model = TASK_TO_MODEL[self.task].from_config(self.model_config)
        else:
            model = TASK_TO_MODEL[self.task].from_pretrained(self.model_id, low_cpu_mem_usage=self.low_cpu_mem_usage)

        model.eval()
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.pipeline = pipeline(task=self.task, model=model, tokenizer=tokenizer, device=self.device)
        if self.model_config.model_type in MODEL_TYPE_TO_INJECTION_POLICY:
            self.ds_config["injection_policy"] = MODEL_TYPE_TO_INJECTION_POLICY[self.model_config.model_type]
        engine = deepspeed.init_inference(self.pipeline.model, **self.ds_config)
        self.pipeline.model = engine.module

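    # Conversational and question-answering pipelines expect structured objects
    # (Conversation / SquadExample); all other tasks receive the raw values unchanged.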
    def format_input_for_task(self, input_values):
        if not isinstance(input_values, list):
            input_values = [input_values]

        batch_inputs = []
        for val in input_values:
            if self.task == "conversational":
                current_input = Conversation(
                    text=val.get("text"),
                    conversation_id=val.get("conversation_id"),
                    past_user_inputs=val.get("past_user_inputs", []),
                    generated_responses=val.get("generated_responses", [])
                )
            elif self.task == "question-answering":
                current_input = SquadExample(
                    None,
                    val.get("context"),
                    val.get("question"),
                    None,
                    None,
                    None
                )
            else:
                current_input = val
            batch_inputs += [current_input]
        return batch_inputs

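    # JSON requests: the "inputs" field carries the payload and any remaining top-level
    # fields are forwarded to the pipeline as kwargs; non-JSON requests are treated as plain text.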
    def inference(self, inputs: Input):
        try:
            content_type = inputs.get_property("Content-Type")
            model_kwargs = {}
            if content_type is not None and content_type == "application/json":
                json_input = inputs.get_as_json()
                if isinstance(json_input, dict):
                    input_data = self.format_input_for_task(json_input.pop("inputs"))
                    model_kwargs = json_input
                else:
                    input_data = json_input
            else:
                input_data = inputs.get_as_string()

            result = self.pipeline(input_data, **model_kwargs)
            if self.task == "conversational":
                result = {
                    "generated_text": result.generated_responses[-1],
                    "conversation": {
                        "past_user_inputs": result.past_user_inputs,
                        "generated_responses": result.generated_responses,
                    },
                }

            outputs = Output()
            outputs.add(result)
        except Exception as e:
            logging.exception("DeepSpeed inference failed")
            outputs = Output().error(str(e))
        return outputs


@@ -113,7 +276,6 @@ def handle(inputs: Input) -> Optional[Output]:
        _service.initialize(inputs.get_properties())

    if inputs.is_empty():
        return None

    return _service.inference(inputs)
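
For a model served with task=text-generation, a request body along these lines would exercise the new JSON input path (a sketch only: the prompt is arbitrary, and the extra fields are simply whatever kwargs the underlying Hugging Face pipeline accepts, since the handler forwards them untouched):

    payload = {
        "inputs": "DeepSpeed is",   # consumed by format_input_for_task
        "max_new_tokens": 64,       # forwarded to the pipeline as a kwarg (illustrative)
        "do_sample": True,          # forwarded to the pipeline as a kwarg (illustrative)
    }
    # POST payload as application/json to the model endpoint; inference() returns the pipeline result.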