[WIP] Updating deepspeed handler for more models #359

Merged
merged 3 commits on Dec 7, 2022
290 changes: 226 additions & 64 deletions engines/python/setup/djl_python/deepspeed.py
@@ -10,98 +10,261 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import json
import logging
import os
from typing import Optional

import torch
from transformers import (
    AutoConfig,
    PretrainedConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoModelForQuestionAnswering,
    AutoModelForMaskedLM,
    AutoModelForTokenClassification,
    pipeline,
    Conversation,
    SquadExample
)
import transformers
from deepspeed.module_inject.replace_policy import (
    HFBertLayerPolicy,
    HFGPTNEOLayerPolicy,
    GPTNEOXLayerPolicy,
    HFGPTJLayerPolicy,
    MegatronLayerPolicy,
    HFGPT2LayerPolicy,
    BLOOMLayerPolicy,
    HFOPTLayerPolicy,
    HFCLIPLayerPolicy,
)
import deepspeed
from djl_python.inputs import Input
from djl_python.outputs import Output

SUPPORTED_MODEL_TYPES = {
    "roberta",
    "gpt2",
    "bert",
    "gpt_neo",
    "gptj",
    "opt",
    "gpt-neox",
    "bloom",
}

SUPPORTED_TASKS = {
    "text-generation",
    "text-classification",
    "question-answering",
    "fill-mask",
    "token-classification",
    "conversational",
}

ARCHITECTURES_TO_TASK = {
    "ForCausalLM": "text-generation",
    "GPT2LMHeadModel": "text-generation",
    "ForSequenceClassification": "text-classification",
    "ForQuestionAnswering": "question-answering",
    "ForMaskedLM": "fill-mask",
    "ForTokenClassification": "token-classification",
    "BloomModel": "text-generation",
}

TASK_TO_MODEL = {
    "text-generation": AutoModelForCausalLM,
    "text-classification": AutoModelForSequenceClassification,
    "question-answering": AutoModelForQuestionAnswering,
    "fill-mask": AutoModelForMaskedLM,
    "token-classification": AutoModelForTokenClassification,
    "conversational": AutoModelForCausalLM,
}

MODEL_TYPE_TO_INJECTION_POLICY = {
    "roberta": {transformers.models.roberta.modeling_roberta.RobertaLayer: HFBertLayerPolicy},
    "gpt2": {transformers.models.gpt2.modeling_gpt2.GPT2Block: HFGPT2LayerPolicy},
    "bert": {transformers.models.bert.modeling_bert.BertLayer: HFBertLayerPolicy},
    "gpt_neo": {transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoBlock: HFGPTNEOLayerPolicy},
    "gptj": {transformers.models.gptj.modeling_gptj.GPTJBlock: HFGPTJLayerPolicy},
    "opt": {transformers.models.opt.modeling_opt.OPTDecoderLayer: HFOPTLayerPolicy},
    "gpt-neox": {transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer: GPTNEOXLayerPolicy},
    "bloom": {transformers.models.bloom.modeling_bloom.BloomBlock: BLOOMLayerPolicy},
}
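To make the relationship between these lookup tables concrete, here is a small illustrative sketch (the model id is just an example, and the tables and imports above are assumed to be in scope) of how a model config resolves to a task, an AutoModel class, and an injection policy, mirroring the logic used further down in this handler:

# Illustrative only: resolve task / model class / injection policy for an example model.
config = AutoConfig.from_pretrained("EleutherAI/gpt-j-6B")  # example model id
architecture = config.architectures[0]                      # e.g. "GPTJForCausalLM"

task = next(t for suffix, t in ARCHITECTURES_TO_TASK.items()
            if architecture.endswith(suffix))               # -> "text-generation"
model_class = TASK_TO_MODEL[task]                           # -> AutoModelForCausalLM
injection_policy = MODEL_TYPE_TO_INJECTION_POLICY.get(config.model_type)
print(task, model_class.__name__, injection_policy)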


def get_torch_dtype_from_str(dtype: str):
    if dtype == "fp32":
        return torch.float32
    elif dtype == "fp16":
        return torch.float16
    elif dtype == "bf16":
        return torch.bfloat16
    elif dtype == "int8":
        return torch.int8
    else:
        raise ValueError(f"Invalid data type: {dtype}")


class DeepSpeedService(object):

    def __init__(self):
        self.predictor = None
        self.max_new_tokens = 0
        self.pipeline = None
        self.initialized = False
        self.ds_config = None
        self.task = None
        self.logger = logging.getLogger()
        self.model_dir = None
        self.model_id = None
        self.data_type = None
        self.max_tokens = None

Contributor: We might need min_tokens.

        self.device = None
        self.world_size = None

Contributor: world_size should be identical to tensor_parallel_degree.

        self.tensor_parallel_degree = None
        self.model_config = None
        self.low_cpu_mem_usage = False
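Regarding the inline note above that world_size should be identical to tensor_parallel_degree: a minimal, hypothetical guard (not part of this PR) that would make that assumption explicit once properties are parsed could look like this:

# Hypothetical sketch: fail fast if the launcher's WORLD_SIZE disagrees with the
# configured tensor parallel degree (both values are read in _parse_properties below).
def _validate_parallel_config(world_size: int, tensor_parallel_degree: int) -> None:
    if tensor_parallel_degree != world_size:
        raise ValueError(
            f"tensor_parallel_degree ({tensor_parallel_degree}) must match the number "
            f"of launched processes (WORLD_SIZE={world_size})")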

    def initialize(self, properties: dict):
        self._parse_properties(properties)
        self._validate_model_type_and_task()
        self.create_model_pipeline()
        self.logger.info(f"Initialized DeepSpeed model with the following configurations: "
                         f"model: {self.model_id}, "
                         f"task: {self.task}, "
                         f"data_type: {self.data_type}, "
                         f"tensor_parallel_degree: {self.tensor_parallel_degree}")
        self.initialized = True

    def _parse_properties(self, properties):
        self.model_dir = properties.get("model_dir")
        self.model_id = properties.get("model_id")
        self.task = properties.get("task")
        self.data_type = get_torch_dtype_from_str(properties.get("data_type", "fp32"))
        self.max_tokens = int(properties.get("max_tokens", 1024))
Contributor: What would be a proper default for max_tokens?

Contributor (Author): DeepSpeed uses 1024 as the default, so keeping it the same makes sense. Though task-specific pipelines on the HF side might override that and set other defaults.
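For context on the default discussed above, a rough sketch of the distinction between the two limits (values are illustrative): DeepSpeed's max_tokens sizes the inference kernels for the total sequence (prompt plus generated tokens), while generation length is still controlled per request via the HF pipeline kwargs.

# Illustrative only: engine-level token budget vs. per-request generation length.
ds_config = {"max_tokens": 1024}         # total tokens (prompt + generated) the DS kernels are sized for
request_kwargs = {"max_new_tokens": 64}  # per-request limit forwarded to the HF pipeline

prompt_tokens = 900
assert prompt_tokens + request_kwargs["max_new_tokens"] <= ds_config["max_tokens"], \
    "request would exceed the engine's token budget"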

        self.device = int(os.getenv("LOCAL_RANK", 0))
        self.world_size = int(os.getenv("WORLD_SIZE", 1))
        self.tensor_parallel_degree = int(properties.get("tensor_parallel_degree", self.world_size))
        self.low_cpu_mem_usage = properties.get("low_cpu_mem_usage", "true").lower() == "true"
        self.ds_config = {
            "replace_with_kernel_inject": True,
            "dtype": self.data_type,
            "mp_size": self.tensor_parallel_degree,
            "mpu": None,
            "enable_cuda_graph": properties.get("enable_cuda_graph", "false").lower() == "true",
            "triangular_masking": properties.get("triangular_masking", "true").lower() == "true",
            "checkpoint": properties.get("checkpoint"),
            "base_dir": properties.get("base_dir"),
Contributor: Why does a user need to set base_dir?

Contributor (Author): If base_dir is not set, DS will look for the checkpoint in the current directory. If the checkpoint is in some child dir, then base_dir helps to identify it (though it also works if you specify the full path to the checkpoint and leave base_dir empty). Not sure if we need to expose this, but users coming from the DS world might expect it.
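To illustrate the base_dir/checkpoint interaction described above (paths and file names here are hypothetical), DeepSpeed resolves a relative checkpoint description against base_dir, so a layout like the following works without spelling out the full checkpoint path:

# Hypothetical example: the checkpoint json lives in a child directory of the model dir,
# and "checkpoint" is resolved relative to "base_dir".
ds_config = {
    "replace_with_kernel_inject": True,
    "mp_size": 2,
    "base_dir": "/opt/ml/model",                            # hypothetical model directory
    "checkpoint": "checkpoints/ds_inference_config.json",   # relative to base_dir
}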

"return_tuple": properties.get("return_tuple", "true").lower() == "true",
"training_mp_size": int(properties.get("training_mp_size", 1)),
"replace_method": "auto",
"injection_policy": None,
"max_tokens": self.max_tokens,
}

    def _validate_model_type_and_task(self):
        if not self.model_id:
            self.model_id = self.model_dir
            config_file = os.path.join(self.model_id, "config.json")
            if not os.path.exists(config_file):
                raise ValueError(f"model_dir: {self.model_id} does not contain a config.json. "
                                 f"This is required for loading models from local storage")
            self.model_config = AutoConfig.from_pretrained(config_file)
        else:
            self.model_config = AutoConfig.from_pretrained(self.model_id)

        if self.model_config.model_type not in SUPPORTED_MODEL_TYPES:
            raise ValueError(f"model_type: {self.model_config.model_type} is not currently supported by DeepSpeed")

        if not self.task:
            self.logger.warning("No task provided. Attempting to infer from model architecture")
            self.infer_task_from_model_architecture(self.model_config)
        if self.task not in SUPPORTED_TASKS:
            raise ValueError(f"task: {self.task} is not currently supported by DeepSpeed")

    def infer_task_from_model_architecture(self, config: PretrainedConfig):
        architecture = config.architectures[0]
        for arch_option in ARCHITECTURES_TO_TASK:
            if architecture.endswith(arch_option):
                self.task = ARCHITECTURES_TO_TASK[arch_option]

        if not self.task:
            raise ValueError(f"Task could not be inferred from model config. "
                             f"Please manually set `task` in serving.properties")

    def create_model_pipeline(self):
        # If a DS checkpoint is provided, instantiate the model with meta tensors; weights are loaded when the DS engine is invoked
        if self.ds_config["checkpoint"]:
            dtype = torch.float32 if self.data_type == torch.float32 else torch.float16
            with deepspeed.OnDevice(dtype=dtype, device="meta"):
                model = TASK_TO_MODEL[self.task].from_config(self.model_config)
        else:
            model = TASK_TO_MODEL[self.task].from_pretrained(self.model_id, low_cpu_mem_usage=self.low_cpu_mem_usage)

        model.eval()
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.pipeline = pipeline(task=self.task, model=model, tokenizer=tokenizer, device=self.device)
        if self.model_config.model_type in MODEL_TYPE_TO_INJECTION_POLICY:
            self.ds_config["injection_policy"] = MODEL_TYPE_TO_INJECTION_POLICY[self.model_config.model_type]
        engine = deepspeed.init_inference(self.pipeline.model, **self.ds_config)
        self.pipeline.model = engine.module

    def format_input_for_task(self, input_values):
        if not isinstance(input_values, list):
            input_values = [input_values]

        batch_inputs = []
        for val in input_values:
            if self.task == "conversational":
                current_input = Conversation(
                    text=val.get("text"),
                    conversation_id=val.get("conversation_id"),
                    past_user_inputs=val.get("past_user_inputs", []),
                    generated_responses=val.get("generated_responses", [])
                )
            elif self.task == "question-answering":
                current_input = SquadExample(
                    None,
                    val.get("context"),
                    val.get("question"),
                    None,
                    None,
                    None
                )
            else:
                current_input = val
            batch_inputs += [current_input]
        return batch_inputs
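For reference, example request payloads (field values are made up) matching the keys that format_input_for_task reads for the two tasks it special-cases:

# Illustrative payloads only; values are made up.
conversational_input = {
    "text": "What is the weather like today?",
    "conversation_id": None,
    "past_user_inputs": ["Hi there"],
    "generated_responses": ["Hello! How can I help?"],
}

question_answering_input = {
    "question": "Where is the Eiffel Tower?",
    "context": "The Eiffel Tower is a landmark in Paris, France.",
}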

    def inference(self, inputs: Input):
        try:
            content_type = inputs.get_property("Content-Type")
            model_kwargs = {}
            if content_type is not None and content_type == "application/json":
                json_input = inputs.get_as_json()
                if isinstance(json_input, dict):
                    input_data = self.format_input_for_task(json_input.pop("inputs"))
                    model_kwargs = json_input
                else:
                    input_data = json_input
            else:
                input_data = inputs.get_as_string()

            result = self.pipeline(input_data, **model_kwargs)
            if self.task == "conversational":
                result = {
                    "generated_text": result.generated_responses[-1],
                    "conversation": {
                        "past_user_inputs": result.past_user_inputs,
                        "generated_responses": result.generated_responses,
                    },
                }

            outputs = Output()
            outputs.add(result)
        except Exception as e:
            logging.exception("DeepSpeed inference failed")
            outputs = Output().error(str(e))
        return outputs
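Putting the JSON path above together: the prompt is carried under "inputs", and any remaining top-level keys are forwarded to the HF pipeline as generation kwargs. A sketch of a text-generation request body (the specific kwargs shown are just examples):

# Illustrative request body for the application/json path in inference() above.
request_body = {
    "inputs": "DeepSpeed is",  # popped and formatted per task
    "max_new_tokens": 64,      # example kwargs forwarded to the HF pipeline
    "do_sample": True,
}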


@@ -113,7 +276,6 @@ def handle(inputs: Input) -> Optional[Output]:
        _service.initialize(inputs.get_properties())

    if inputs.is_empty():
        return None

    return _service.inference(inputs)