Add support for bigscience/bloomz #25
Conversation
@mayank31398 looks great, just a few minor comments
Noticed one more small thing
```python
elif model_name == DS_INFERENCE_BLOOM_INT8:
    return "int8"


def get_torch_dtype(dtype_str: str) -> torch.dtype:
    if dtype_str == "bf16":
```
nit: The type hints in the `get_str_dtype` function below are the wrong way around. It should take a `torch.dtype` and return a `str`.
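A minimal sketch of what the corrected pair of helpers could look like, with `get_str_dtype` taking a `torch.dtype` and returning a `str` as suggested (the dtype branches shown here are assumptions, not the PR's actual implementation):

```python
import torch


def get_torch_dtype(dtype_str: str) -> torch.dtype:
    # Map a config string to the corresponding torch dtype
    if dtype_str == "bf16":
        return torch.bfloat16
    elif dtype_str == "fp16":
        return torch.float16
    raise ValueError(f"unsupported dtype: {dtype_str}")


def get_str_dtype(dtype: torch.dtype) -> str:
    # Corrected hints: takes a torch.dtype, returns a str
    if dtype == torch.bfloat16:
        return "bf16"
    elif dtype == torch.float16:
        return "fp16"
    raise ValueError(f"unsupported dtype: {dtype}")
```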
inference_server/models/utils.py
```python
from .ds_inference import DSInferenceModel
from .ds_zero import DSZeROModel
from .hf_accelerate import HFAccelerateModel
```
Could be good to conditionally import these to avoid depending on non-required libraries when only one of them is used
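One common way to do this (a sketch of the general pattern, not the PR's actual change) is to defer the import into an accessor with `importlib`, so a backend's dependencies are loaded only when that backend is selected. The `lazy_get` helper name is hypothetical; the demonstration uses a stdlib module standing in for one of the model backends:

```python
import importlib


def lazy_get(module_name: str, attr: str):
    # Import the module only at call time, so unused backends'
    # dependencies are never loaded at package-import time.
    module = importlib.import_module(module_name)
    return getattr(module, attr)


# Demonstration with a stdlib module standing in for a backend class:
dumps = lazy_get("json", "dumps")
```

In the PR, the three `Model` classes would be fetched this way based on the chosen deployment framework instead of being imported unconditionally at the top of `utils.py`.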