Skip to content

Commit

Permalink
fail on unsupported dtype, make some methods private
Browse files Browse the repository at this point in the history
  • Loading branch information
siddvenk committed Dec 7, 2022
1 parent 9448997 commit 054016e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
13 changes: 8 additions & 5 deletions engines/python/setup/djl_python/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,16 @@


def get_torch_dtype_from_str(dtype: str):
if dtype == "fp32":
return torch.float32
if dtype == "fp16":
return torch.float16
if dtype == "bf16":
elif dtype == "bf16":
return torch.bfloat16
if dtype == "int8":
elif dtype == "int8":
return torch.int8
return torch.float32
else:
raise ValueError(f"Invalid data type: {dtype}")


class DeepSpeedService(object):
Expand Down Expand Up @@ -133,7 +136,7 @@ def initialize(self, properties: dict):
f"tensor_parallel_degree: {self.tensor_parallel_degree}")
self.initialized = True

def parse_properties(self, properties):
def _parse_properties(self, properties):
self.model_dir = properties.get("model_dir")
self.model_id = properties.get("model_id")
self.task = properties.get("task")
Expand All @@ -159,7 +162,7 @@ def parse_properties(self, properties):
"max_tokens": self.max_tokens,
}

def validate_model_type_and_task(self):
def _validate_model_type_and_task(self):
if not self.model_id:
self.model_id = self.model_dir
config_file = os.path.join(self.model_id, "config.json")
Expand Down
4 changes: 0 additions & 4 deletions serving/docker/deepspeed.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@ FROM nvidia/cuda:$version
ARG djl_version=0.20.0~SNAPSHOT
ARG torch_version=1.12.1
ARG accelerate_version=0.13.2
<<<<<<< HEAD
ARG deepspeed_wheel="https://publish.djl.ai/deepspeed/deepspeed-0.7.5%2Bbf16-py2.py3-none-any.whl"
=======
ARG deepspeed_version=0.7.5
>>>>>>> 03747b9 (Support additional configs through serving.properties, add initial bloom support)
ARG transformers_version=4.23.1
ARG diffusers_version=0.7.2

Expand Down

0 comments on commit 054016e

Please sign in to comment.