Support additional configs through serving.properties, add initial bloom support
siddvenk committed Dec 7, 2022
1 parent 7905e8f commit 9448997
Showing 2 changed files with 27 additions and 14 deletions.
40 changes: 26 additions & 14 deletions engines/python/setup/djl_python/deepspeed.py
@@ -24,7 +24,7 @@
AutoModelForTokenClassification,
pipeline,
Conversation,
- SquadExample,
+ SquadExample
)
import transformers
from deepspeed.module_inject.replace_policy import (
@@ -70,6 +70,7 @@
"ForQuestionAnswering": "question-answering",
"ForMaskedLM": "fill-mask",
"ForTokenClassification": "token-classification",
"BloomModel": "text-generation",
}

TASK_TO_MODEL = {
@@ -93,6 +94,16 @@
}


+ def get_torch_dtype_from_str(dtype: str):
+     if dtype == "fp16":
+         return torch.float16
+     if dtype == "bf16":
+         return torch.bfloat16
+     if dtype == "int8":
+         return torch.int8
+     return torch.float32


class DeepSpeedService(object):

def __init__(self):
@@ -109,6 +120,7 @@ def __init__(self):
self.world_size = None
self.tensor_parallel_degree = None
self.model_config = None
+ self.low_cpu_mem_usage = False

def initialize(self, properties: dict):
self.parse_properties(properties)
@@ -125,19 +137,17 @@ def parse_properties(self, properties):
self.model_dir = properties.get("model_dir")
self.model_id = properties.get("model_id")
self.task = properties.get("task")
- self.data_type = properties.get("data_type", "fp32")
+ self.data_type = get_torch_dtype_from_str(properties.get("data_type", "fp32"))
self.max_tokens = int(properties.get("max_tokens", 1024))
self.device = int(os.getenv("LOCAL_RANK", 0))
self.world_size = int(os.getenv("WORLD_SIZE", 1))
self.tensor_parallel_degree = int(properties.get("tensor_parallel_degree", self.world_size))
+ self.low_cpu_mem_usage = properties.get("low_cpu_mem_usage", "true").lower() == "true"
self.ds_config = {
"replace_with_kernel_inject": True,
"dtype": self.data_type,
"tensor_parallel": {
"enabled": True,
"tp_size": self.tensor_parallel_degree,
"mpu": None,
},
"mp_size": self.tensor_parallel_degree,
"mpu": None,
"enable_cuda_graph": properties.get("enable_cuda_graph", "false").lower() == "true",
"triangular_masking": properties.get("triangular_masking", "true").lower() == "true",
"checkpoint": properties.get("checkpoint"),
@@ -146,7 +156,7 @@ def parse_properties(self, properties):
"training_mp_size": int(properties.get("training_mp_size", 1)),
"replace_method": "auto",
"injection_policy": None,
"max_out_tokens": self.max_tokens,
"max_tokens": self.max_tokens,
}

def validate_model_type_and_task(self):
@@ -182,21 +192,19 @@ def infer_task_from_model_architecture(self, config: PretrainedConfig):
def create_model_pipeline(self):
# If a DS checkpoint is provided, instantiate the model with meta tensors; weights are loaded when the DS engine is invoked
if self.ds_config["checkpoint"]:
dtype = torch.float32 if self.data_type == "fp32" else torch.float16
dtype = torch.float32 if self.data_type == torch.float32 else torch.float16
with deepspeed.OnDevice(dtype=dtype, device="meta"):
model = TASK_TO_MODEL[self.task].from_config(self.model_config)
else:
- model = TASK_TO_MODEL[self.task].from_pretrained(self.model_id, low_cpu_mem_usage=True)
+ model = TASK_TO_MODEL[self.task].from_pretrained(self.model_id, low_cpu_mem_usage=self.low_cpu_mem_usage)

model.eval()
tokenizer = AutoTokenizer.from_pretrained(self.model_id)
self.pipeline = pipeline(task=self.task, model=model, tokenizer=tokenizer, device=self.device)
self.logger.info(f"Model before deepspeed: {self.pipeline.model}")
if self.model_config.model_type in MODEL_TYPE_TO_INJECTION_POLICY:
self.ds_config["injection_policy"] = MODEL_TYPE_TO_INJECTION_POLICY[self.model_config.model_type]
engine = deepspeed.init_inference(self.pipeline.model, **self.ds_config)
self.pipeline.model = engine.module
self.logger.info(f"Model after deepspeed: {self.pipeline.model}")

def format_input_for_task(self, input_values):
if not isinstance(input_values, list):
@@ -213,8 +221,12 @@ def format_input_for_task(self, input_values):
)
elif self.task == "question-answering":
current_input = SquadExample(
-     context_text=val.get("context"),
-     question_text=val.get("question")
+     None,
+     val.get("question"),
+     val.get("context"),
+     None,
+     None,
+     None
)
else:
current_input = val
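For reference, all of the new options arrive as strings from serving.properties and are coerced in one place by parse_properties. Below is a minimal standalone sketch of that coercion; the helper is copied from the diff above, and the properties dict and its values are illustrative, with keys matching exactly what parse_properties reads.

import torch

def get_torch_dtype_from_str(dtype: str):
    # Same mapping as the helper added in this commit;
    # anything unrecognized falls back to torch.float32.
    if dtype == "fp16":
        return torch.float16
    if dtype == "bf16":
        return torch.bfloat16
    if dtype == "int8":
        return torch.int8
    return torch.float32

# Illustrative properties as the handler might receive them
# (values are always strings).
properties = {
    "data_type": "bf16",
    "tensor_parallel_degree": "2",
    "low_cpu_mem_usage": "true",
    "max_tokens": "2048",
}

data_type = get_torch_dtype_from_str(properties.get("data_type", "fp32"))
max_tokens = int(properties.get("max_tokens", 1024))
tensor_parallel_degree = int(properties.get("tensor_parallel_degree", 1))
low_cpu_mem_usage = properties.get("low_cpu_mem_usage", "true").lower() == "true"

assert data_type is torch.bfloat16
assert (max_tokens, tensor_parallel_degree, low_cpu_mem_usage) == (2048, 2, True)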
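The question-answering branch now constructs SquadExample positionally instead of by keyword. In transformers, the positional order is (qas_id, question_text, context_text, answer_text, start_position_character, title), so the question must be passed before the context; a quick sanity check (the example strings are illustrative):

from transformers import SquadExample

# Positional order: qas_id, question_text, context_text,
# answer_text, start_position_character, title
example = SquadExample(None, "What does DJL stand for?",
                       "DJL stands for Deep Java Library.", None, None, None)
assert example.question_text == "What does DJL stand for?"
assert example.context_text == "DJL stands for Deep Java Library."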
1 change: 1 addition & 0 deletions serving/docker/deepspeed.Dockerfile
@@ -14,7 +14,8 @@ FROM nvidia/cuda:$version
ARG djl_version=0.20.0~SNAPSHOT
ARG torch_version=1.12.1
ARG accelerate_version=0.13.2
ARG deepspeed_wheel="https://publish.djl.ai/deepspeed/deepspeed-0.7.5%2Bbf16-py2.py3-none-any.whl"
+ ARG deepspeed_version=0.7.5
ARG transformers_version=4.23.1
ARG diffusers_version=0.7.2
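Since these are ordinary build ARGs, any pinned version can be overridden at image build time without editing the Dockerfile; a hypothetical invocation from the repository root (the image tag is illustrative):

docker build -f serving/docker/deepspeed.Dockerfile \
    --build-arg deepspeed_version=0.7.5 \
    -t djl-serving:deepspeed .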
