From 035c760632544cc747bd17e3c12a7e0a0ec14427 Mon Sep 17 00:00:00 2001 From: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com> Date: Tue, 25 Jul 2023 03:20:09 +0000 Subject: [PATCH 1/2] fix(build): running from container choosing models correctly Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com> --- pyproject.toml | 5 +- src/openllm/_llm.py | 133 ++++------------------ src/openllm/bundle/_package.py | 32 ++---- src/openllm/cli.py | 16 +-- src/openllm/serialisation/__init__.py | 13 +-- src/openllm/serialisation/transformers.py | 105 +++++------------ 6 files changed, 66 insertions(+), 238 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6868fbcd7..1c8091b14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -305,10 +305,7 @@ reportUnknownVariableType = "warning" typeCheckingMode = "strict" [tool.mypy] -# TODO: remove all of the disable to ensure strict type -disable_error_code = ["attr-defined", "name-defined", "annotation-unchecked"] -enable_error_code = ["redundant-expr"] -exclude = ["examples/", "tools/", "tests/", "src/openllm/playground/"] +exclude = ["src/openllm/playground/"] files = ["src/openllm", "src/openllm_client"] local_partial_types = true mypy_path = "typings" diff --git a/src/openllm/_llm.py b/src/openllm/_llm.py index 4e29795c7..2f9a5c67f 100644 --- a/src/openllm/_llm.py +++ b/src/openllm/_llm.py @@ -29,14 +29,13 @@ from pathlib import Path import attr +import fs.path import inflection import orjson from huggingface_hub import hf_hub_download import bentoml import openllm -from bentoml._internal.configuration.containers import BentoMLContainer -from bentoml._internal.models import ModelStore from bentoml._internal.models.model import CUSTOM_OBJECTS_FILENAME from bentoml._internal.models.model import ModelSignature @@ -115,14 +114,12 @@ logger = logging.getLogger(__name__) - class ModelSignatureDict(t.TypedDict, total=False): batchable: bool batch_dim: t.Union[t.Tuple[int, int], int] input_spec: NotRequired[t.Union[t.Any, t.Tuple[t.Any]]] output_spec: NotRequired[t.Any] - def normalise_model_name(name: str) -> str: return os.path.basename(resolve_filepath(name)) if validate_is_path(name) else re.sub("[^a-zA-Z0-9]+", "-", name) @functools.lru_cache(maxsize=128) @@ -142,7 +139,6 @@ def generate_hash_from_file(f: str, algorithm: t.Literal["md5", "sha1"] = "sha1" # the below is similar to peft.utils.other.CONFIG_NAME PEFT_CONFIG_NAME = "adapter_config.json" - def resolve_peft_config_type(adapter_map: dict[str, str | None]) -> AdaptersMapping: """Resolve the type of the PeftConfig given the adapter_map. 
@@ -172,18 +168,10 @@ def resolve_peft_config_type(adapter_map: dict[str, str | None]) -> AdaptersMapp resolved[_peft_type] += (_AdaptersTuple((path_or_adapter_id, resolve_name, resolved_config)),) return resolved - _reserved_namespace = {"config_class", "model", "tokenizer", "import_kwargs"} -M = t.TypeVar( - "M", - bound="t.Union[transformers.PreTrainedModel, transformers.Pipeline, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel, vllm.LLMEngine, vllm.AsyncLLMEngine, peft.PeftModel, autogptq.modeling.BaseGPTQForCausalLM]", -) -T = t.TypeVar( - "T", - bound="t.Union[transformers.PreTrainedTokenizerFast, transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerBase]", -) - +M = t.TypeVar("M", bound="t.Union[transformers.PreTrainedModel, transformers.Pipeline, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel, vllm.LLMEngine, vllm.AsyncLLMEngine, peft.PeftModel, autogptq.modeling.BaseGPTQForCausalLM]") +T = t.TypeVar("T", bound="t.Union[transformers.PreTrainedTokenizerFast, transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerBase]") def _default_post_init(self: LLM[t.Any, t.Any]) -> None: self.device = None @@ -191,10 +179,8 @@ def _default_post_init(self: LLM[t.Any, t.Any]) -> None: if self.__llm_implementation__ == "pt" and is_torch_available(): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - class LLMInterface(ABC, t.Generic[M, T]): """This defines the loose contract for all openllm.LLM implementations.""" - @property def import_kwargs(self) -> tuple[DictStrAny, DictStrAny] | None: """The default import kwargs to used when importing the model. @@ -205,7 +191,6 @@ def import_kwargs(self) -> tuple[DictStrAny, DictStrAny] | None: Returns: Optional tuple of model kwargs and tokenizer kwargs """ - def embeddings(self, prompts: list[str]) -> LLMEmbeddings: """The implementation for generating text embeddings from given prompt. @@ -215,7 +200,6 @@ def embeddings(self, prompts: list[str]) -> LLMEmbeddings: The embeddings for the given prompt. """ raise NotImplementedError - @abstractmethod def generate(self, prompt: str, **preprocess_generate_kwds: t.Any) -> t.Any: """The implementation for text generation from given prompt. @@ -224,26 +208,16 @@ def generate(self, prompt: str, **preprocess_generate_kwds: t.Any) -> t.Any: pass it to 'self.model.generate'. """ raise NotImplementedError - - def generate_one( - self, - prompt: str, - stop: list[str], - **preprocess_generate_kwds: t.Any, - ) -> t.Sequence[dict[t.Literal["generated_text"], str]]: + def generate_one(self, prompt: str, stop: list[str], **preprocess_generate_kwds: t.Any) -> t.Sequence[dict[t.Literal["generated_text"], str]]: """The entrypoint for generating one prompt. This provides additional stop tokens for generating per token level. This is useful when running with agents, or initial streaming support. """ raise NotImplementedError - def generate_iterator(self, prompt: str, **attrs: t.Any) -> t.Iterator[t.Any]: - """T iterator version of `generate` function.""" - raise NotImplementedError( - "Currently generate_iterator requires SSE (Server-side events) support, which is not yet implemented." - ) - + """The iterator version of `generate` function.""" + raise NotImplementedError("Currently generate_iterator requires SSE (Server-side events) support, which is not yet implemented.") def sanitize_parameters(self, prompt: str, **attrs: t.Any) -> tuple[str, DictStrAny, DictStrAny]: """This handler will sanitize all attrs and setup prompt text. 
@@ -254,7 +228,6 @@ def sanitize_parameters(self, prompt: str, **attrs: t.Any) -> tuple[str, DictStr - The attributes dictionary that will be passed into `self.postprocess_generate`. """ return prompt, attrs, attrs - def postprocess_generate(self, prompt: str, generation_result: t.Any, **attrs: t.Any) -> t.Any: """This handler will postprocess generation results from LLM.generate and then output nicely formatted results (if the LLM decide to do so.). @@ -263,11 +236,9 @@ def postprocess_generate(self, prompt: str, generation_result: t.Any, **attrs: t NOTE: this will be used from the client side. """ return generation_result - def llm_post_init(self) -> None: """This function can be implemented if you need to initialized any additional variables that doesn't concern OpenLLM internals.""" pass - def import_model(self, *args: t.Any, trust_remote_code: bool, **attrs: t.Any) -> bentoml.Model: """This function can be implemented if default import_model doesn't satisfy your needs. @@ -280,21 +251,18 @@ def import_model(self, *args: t.Any, trust_remote_code: bool, **attrs: t.Any) -> By default, `model_decls` and `model_attrs` is already sanitised and concatenated into `args` and `attrs` """ raise NotImplementedError - def load_model(self, *args: t.Any, **attrs: t.Any) -> M: """This function can be implemented to override the default load_model behaviour. See falcon for example implementation. Tag can be accessed via ``self.tag`` """ raise NotImplementedError - def load_tokenizer(self, tag: bentoml.Tag, **attrs: t.Any) -> T: """This function can be implemented to override how to load the tokenizer. See falcon for example implementation. """ raise NotImplementedError - def save_pretrained(self, save_directory: str | Path, **attrs: t.Any) -> None: """This function defines how this model can be saved to local store. @@ -303,7 +271,6 @@ def save_pretrained(self, save_directory: str | Path, **attrs: t.Any) -> None: This is useful during fine tuning. """ raise NotImplementedError - # NOTE: All fields below are attributes that can be accessed by users. config_class: type[openllm.LLMConfig] """The config class to use for this LLM. If you are creating a custom LLM, you must specify this class.""" @@ -354,9 +321,7 @@ def save_pretrained(self, save_directory: str | Path, **attrs: t.Any) -> None: """A boolean to determine whether models does implement ``LLM.generate_one``.""" __llm_supports_generate_iterator__: bool """A boolean to determine whether models does implement ``LLM.generate_iterator``.""" - if t.TYPE_CHECKING and not MYPY: - def __attrs_init__( self, config: openllm.LLMConfig, @@ -375,36 +340,22 @@ def __attrs_init__( ) -> None: """Generated __attrs_init__ for openllm.LLM.""" - if t.TYPE_CHECKING: _R = t.TypeVar("_R") - class _import_model_wrapper(t.Generic[_R, M, T]): - def __call__(self, llm: LLM[M, T], *decls: t.Any, trust_remote_code: bool, **attrs: t.Any) -> _R: - ... - + def __call__(self, llm: LLM[M, T], *decls: t.Any, trust_remote_code: bool, **attrs: t.Any) -> _R: ... class _load_model_wrapper(t.Generic[M, T]): - def __call__(self, llm: LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M: - ... - + def __call__(self, llm: LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M: ... class _load_tokenizer_wrapper(t.Generic[M, T]): - def __call__(self, llm: LLM[M, T], **attrs: t.Any) -> T: - ... - + def __call__(self, llm: LLM[M, T], **attrs: t.Any) -> T: ... class _llm_post_init_wrapper(t.Generic[M, T]): - def __call__(self, llm: LLM[M, T]) -> T: - ... - + def __call__(self, llm: LLM[M, T]) -> T: ... 
class _save_pretrained_wrapper(t.Generic[M, T]): - def __call__(self, llm: LLM[M, T], save_directory: str | Path, **attrs: t.Any) -> None: - ... - + def __call__(self, llm: LLM[M, T], save_directory: str | Path, **attrs: t.Any) -> None: ... def _wrapped_import_model(f: _import_model_wrapper[bentoml.Model, M, T]): @functools.wraps(f) - def wrapper( - self: LLM[M, T], *decls: t.Any, trust_remote_code: bool | None = None, **attrs: t.Any - ) -> bentoml.Model: + def wrapper(self: LLM[M, T], *decls: t.Any, trust_remote_code: bool | None = None, **attrs: t.Any) -> bentoml.Model: trust_remote_code = first_not_none(trust_remote_code, default=self.__llm_trust_remote_code__) # wrapped around custom init to provide some meta compression # for all decls and attrs @@ -412,13 +363,10 @@ def wrapper( decls = (*model_decls, *decls) attrs = {**model_attrs, **attrs} return f(self, *decls, trust_remote_code=trust_remote_code, **attrs) - return wrapper - _DEFAULT_TOKENIZER = "hf-internal-testing/llama-tokenizer" - @requires_dependencies("vllm", extra="vllm") def get_engine_args(llm: openllm.LLM[M, T], tokenizer: str = _DEFAULT_TOKENIZER) -> vllm.EngineArgs: return vllm.EngineArgs(model=llm._bentomodel.path, tokenizer=tokenizer, tokenizer_mode="auto", tensor_parallel_size=1, dtype="auto", worker_use_ray=False) def _wrapped_load_model(f: _load_model_wrapper[M, T]): @@ -437,26 +385,21 @@ def wrapper(self: LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M | vllm.LLMEngin return f(self, *(*model_decls, *decls), **{**model_attrs, **attrs}) return wrapper - def _wrapped_load_tokenizer(f: _load_tokenizer_wrapper[M, T]): @functools.wraps(f) def wrapper(self: LLM[M, T], **tokenizer_attrs: t.Any) -> T: _, model_tokenizer_attrs = self.llm_parameters tokenizer_attrs = {**model_tokenizer_attrs, **tokenizer_attrs} return f(self, **tokenizer_attrs) - return wrapper - def _wrapped_llm_post_init(f: _llm_post_init_wrapper[M, T]) -> t.Callable[[LLM[M, T]], None]: @functools.wraps(f) def wrapper(self: LLM[M, T]): _default_post_init(self) f(self) - return wrapper - def _wrapped_save_pretrained(f: _save_pretrained_wrapper[M, T]): @functools.wraps(f) def wrapper(self: LLM[M, T], save_directory: str | Path, **attrs: t.Any) -> None: @@ -467,20 +410,17 @@ def wrapper(self: LLM[M, T], save_directory: str | Path, **attrs: t.Any) -> None f(self, save_directory, **attrs) return wrapper - def _update_docstring(cls: LLM[M, T], fn: str) -> AnyCallable: # update docstring for given entrypoint original_fn = getattr(cls, fn, getattr(LLMInterface, fn)) - original_fn.__doc__ = ( - original_fn.__doc__ - or f"""\ + original_fn.__doc__ = original_fn.__doc__ or f"""\ {cls.__name__}'s implementation for {fn}. Note that if LoRA is enabled (via either SDK or CLI), `self.model` will become a `peft.PeftModel` The original model can then be accessed with 'self.model.get_base_model()'. 
""" - ) setattr(cls, fn, original_fn) + return original_fn def _make_assignment_script(cls: type[LLM[M, T]]) -> t.Callable[[type[LLM[M, T]]], None]: @@ -504,10 +444,8 @@ def _make_assignment_script(cls: type[LLM[M, T]]) -> t.Callable[[type[LLM[M, T]] impl_name = f"__wrapped_{func}" globs.update({f"__serialisation_{func}": getattr(openllm.serialisation, func, None), impl_name: impl}) cached_func_name = f"_cached_{cls.__name__}_func" - if func == "llm_post_init": - func_call = f"_impl_{cls.__name__}_{func}={cached_func_name}" - else: - func_call = f"_impl_{cls.__name__}_{func}={cached_func_name} if {cached_func_name} is not _cached_LLMInterface_get('{func}') else __serialisation_{func}" + if func == "llm_post_init": func_call = f"_impl_{cls.__name__}_{func}={cached_func_name}" + else: func_call = f"_impl_{cls.__name__}_{func}={cached_func_name} if {cached_func_name} is not _cached_LLMInterface_get('{func}') else __serialisation_{func}" lines.extend( [ f"{cached_func_name}=cls.{func}", @@ -532,17 +470,13 @@ def _make_assignment_script(cls: type[LLM[M, T]]) -> t.Callable[[type[LLM[M, T]] ] ) anns[key] = interface_anns.get(key) - - return codegen.generate_function( - cls, "__assign_llm_attr", lines, args=("cls", *args), globs=globs, annotations=anns - ) - + return codegen.generate_function(cls, "__assign_llm_attr", lines, args=("cls", *args), globs=globs, annotations=anns) _AdaptersTuple: type[AdaptersTuple] = codegen.make_attr_tuple_class("AdaptersTuple", ["adapter_id", "name", "config"]) - @attr.define(slots=True, repr=False, init=False) class LLM(LLMInterface[M, T], ReprMixin): + if t.TYPE_CHECKING: __name__: str config: openllm.LLMConfig """The config instance to use for this LLM. This will be created based on config_class and available when initialising the LLM.""" @@ -585,7 +519,6 @@ def vllm_generate(self: LLM["vllm.LLMEngine", T], prompt: str, **attrs: t.Any) - return [openllm.unmarshal_vllm_outputs(i) for i in outputs] cls.postprocess_generate = vllm_postprocess_generate cls.generate = vllm_generate - # fmt: off @overload def __getitem__(self, item: t.Literal["trust_remote_code"]) -> bool: ... @@ -615,7 +548,6 @@ def __getitem__(self, item: t.LiteralString | t.Any) -> t.Any: if hasattr(self, internal_attributes): return getattr(self, internal_attributes) elif hasattr(self, item): return getattr(self, item) else: raise KeyError(item) - @classmethod @overload def from_pretrained(cls, model_id: str | None = ..., model_version: str | None = ..., llm_config: openllm.LLMConfig | None = ..., *args: t.Any, runtime: t.Literal["ggml", "transformers"] | None = ..., quantize: t.Literal["int8", "int4"] = ..., bettertransformer: str | bool | None = ..., adapter_id: str | None = ..., adapter_name: str | None = ..., adapter_map: dict[str, str | None] | None = ..., quantization_config: transformers.BitsAndBytesConfig | None = ..., serialisation: t.Literal["safetensors", "legacy"] = ..., **attrs: t.Any) -> LLM[M, T]: ... 
@@ -738,7 +670,6 @@ def from_pretrained( _serialisation_format=serialisation, **attrs, ) - @classmethod @functools.lru_cache def generate_tag(cls, model_id: str, model_version: str | None) -> bentoml.Tag: @@ -748,6 +679,9 @@ def generate_tag(cls, model_id: str, model_version: str | None) -> bentoml.Tag: If model_id contains the revision itself, then it will be -:revision If model_id is a path, then it will be -- """ + # specific branch for running in docker, this is very hacky, needs change upstream + if in_docker() and os.getenv("BENTO_PATH") is not None: return bentoml.Tag.from_taglike(":".join(fs.path.parts(model_id)[-2:])) + model_name = normalise_model_name(model_id) model_id, *maybe_revision = model_id.rsplit(":") if len(maybe_revision) > 0: @@ -762,7 +696,6 @@ def generate_tag(cls, model_id: str, model_version: str | None) -> bentoml.Tag: model_version = getattr(transformers.AutoConfig.from_pretrained(model_id, trust_remote_code=cls.config_class.__openllm_trust_remote_code__, revision=first_not_none(model_version, default="main")), "_commit_hash", None) if model_version is None: raise ValueError(f"Internal errors when parsing config for pretrained {model_id} ('commit_hash' not found)") return bentoml.Tag.from_taglike(f"{tag_name}:{model_version}") - def __init__( self, *args: t.Any, @@ -864,22 +797,12 @@ def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> t.Any: # low_cpu_mem_usage is only available for model # this is helpful on system with low memory to avoid OOM low_cpu_mem_usage = attrs.pop("low_cpu_mem_usage", True) - - if self.__llm_implementation__ == "pt": - attrs.update({"low_cpu_mem_usage": low_cpu_mem_usage, "quantization_config": quantization_config}) - + if self.__llm_implementation__ == "pt": attrs.update({"low_cpu_mem_usage": low_cpu_mem_usage, "quantization_config": quantization_config}) model_kwds, tokenizer_kwds = {}, {} if self.import_kwargs is not None: model_kwds, tokenizer_kwds = self.import_kwargs # parsing tokenizer and model kwargs, as the hierachy is param pass > default normalized_model_kwds, normalized_tokenizer_kwds = normalize_attrs_to_model_tokenizer_pair(**attrs) # NOTE: Save the args and kwargs for latter load - - # specific branch for running in docker, this is very hacky, needs change upstream - if in_docker() and os.getenv("BENTO_PATH") is not None: - BentoMLContainer.model_store.set(ModelStore("/home/bentoml/bento/models")) - os.environ["OPENLLM_USE_LOCAL_LATEST"] = str(True) - _tag = self.generate_tag(model_id, _model_version) - self.__attrs_init__(llm_config, quantization_config, model_id, _runtime, args, {**model_kwds, **normalized_model_kwds}, {**tokenizer_kwds, **normalized_tokenizer_kwds}, _tag, _adapters_mapping, _model_version, _quantize_method, _serialisation_format) # handle trust_remote_code self.__llm_trust_remote_code__ = self._model_attrs.pop("trust_remote_code", self.config["trust_remote_code"]) @@ -890,11 +813,9 @@ def load_model(self, tag: bentoml.Tag, *args: t.Any, **attrs: t.Any) -> t.Any: else: non_intrusive_setattr(self, "bettertransformer", self.config["bettertransformer"]) # If lora is passed, the disable bettertransformer if _adapters_mapping and self.bettertransformer is True: self.bettertransformer = False - def __setattr__(self, attr: str, value: t.Any) -> None: if attr in _reserved_namespace: raise ForbiddenAttributeError(f"{attr} should not be set during runtime as these value will be reflected during runtime. 
Instead, you can create a custom LLM subclass {self.__class__.__name__}.") super().__setattr__(attr, value) - @property def adapters_mapping(self) -> AdaptersMapping | None: return self._adapters_mapping @adapters_mapping.setter @@ -939,11 +860,9 @@ def tokenizer(self) -> T: # NOTE: the signature of load_tokenizer here is the wrapper under _wrapped_load_tokenizer if self.__llm_tokenizer__ is None: self.__llm_tokenizer__ = self.load_tokenizer(**self._tokenizer_attrs) return self.__llm_tokenizer__ - def _default_ft_config(self, _adapter_type: AdapterType, inference_mode: bool) -> FineTuneConfig: strategy = first_not_none(self.config["fine_tune_strategies"].get(_adapter_type), default=FineTuneConfig(adapter_type=t.cast("PeftType", _adapter_type), llm_config_class=self.config_class)) return strategy.eval() if inference_mode else strategy.train() - def _transpose_adapter_mapping( self, inference_mode: bool = True, use_cache: bool = True) -> ResolvedAdaptersMapping: if self._adapters_mapping is None: raise ValueError("LoRA mapping is not set up correctly.") # early out if we already serialized everything. @@ -963,10 +882,8 @@ def _transpose_adapter_mapping( self, inference_mode: bool = True, use_cache: bo name = "default" peft_config = default_config.with_config(**adapter.config).to_peft_config() if name == "default" else FineTuneConfig(adapter_type=t.cast("PeftType", _adapter_type), adapter_config=adapter.config, inference_mode=inference_mode, llm_config_class=self.config_class).to_peft_config() adapter_map[_adapter_type][name] = (peft_config, adapter.adapter_id) - if self.__llm_adapter_map__ is None and use_cache: self.__llm_adapter_map__ = adapter_map return adapter_map - @requires_dependencies("peft", extra="fine-tune") def prepare_for_training(self, adapter_type: AdapterType = "lora", use_gradient_checkpointing: bool = True, **attrs: t.Any) -> tuple[peft.PeftModel, T]: from peft import prepare_model_for_kbit_training @@ -974,7 +891,6 @@ def prepare_for_training(self, adapter_type: AdapterType = "lora", use_gradient_ wrapped_peft = peft.get_peft_model(prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=use_gradient_checkpointing), peft_config) if DEBUG: wrapped_peft.print_trainable_parameters() return wrapped_peft, self.tokenizer - @requires_dependencies("peft", extra="fine-tune") def apply_adapter(self, inference_mode: bool = True, adapter_type: AdapterType = "lora", load_adapters: t.Literal["all"] | list[str] | None = None, use_cache: bool = True) -> peft.PeftModel | M: """Apply given LoRA mapping to the model. @@ -1006,7 +922,6 @@ def apply_adapter(self, inference_mode: bool = True, adapter_type: AdapterType = self.__llm_model__.load_adapter(_peft_model_id, adapter_name=adapter_name, is_trainable=not inference_mode, **dict(_peft_config.to_dict())) return self.__llm_model__ - def _wrap_default_peft_model(self, adapter_mapping: dict[str, tuple[peft.PeftConfig, str]], inference_mode: bool): assert self.__llm_model__ is not None, "Error: Model is not loaded correctly" # noqa: S101 if isinstance(self.__llm_model__, peft.PeftModel): return self.__llm_model__ @@ -1031,7 +946,6 @@ def _wrap_default_peft_model(self, adapter_mapping: dict[str, tuple[peft.PeftCon model = peft_class.from_pretrained(self.__llm_model__, peft_model_id, **kwargs) else: model = peft_class(self.__llm_model__, default_config) # in this case, the given base_model_name_or_path is None. 
This will be hit during training return model - # order of these fields matter here, make sure to sync it with # openllm.models.auto.factory.BaseAutoLLMClass.for_model def to_runner(self, models: list[bentoml.Model] | None = None, max_batch_size: int | None = None, max_latency_ms: int | None = None, scheduling_strategy: type[bentoml.Strategy] | None = None) -> LLMRunner[M, T]: @@ -1079,11 +993,9 @@ def to_runner(self, models: list[bentoml.Model] | None = None, max_batch_size: i method_configs=bentoml_cattr.unstructure({"embeddings": embeddings_sig, "__call__": generate_sig, "generate": generate_sig, "generate_one": generate_sig, "generate_iterator": generate_iterator_sig}), scheduling_strategy=scheduling_strategy, ) - def predict(self, prompt: str, **attrs: t.Any) -> t.Any: """The scikit-compatible API for self(...).""" return self.__call__(prompt, **attrs) - def __call__(self, prompt: str, **attrs: t.Any) -> t.Any: """Returns the generation result and format the result. @@ -1101,7 +1013,6 @@ def __call__(self, prompt: str, **attrs: t.Any) -> t.Any: prompt, generate_kwargs, postprocess_kwargs = self.sanitize_parameters(prompt, **attrs) return self.postprocess_generate(prompt, self.generate(prompt, **generate_kwargs), **postprocess_kwargs) - @overload def Runner(model_name: str, *, model_id: str | None = None, model_version: str | None = ..., init_local: t.Literal[False, True] = ..., **attrs: t.Any) -> LLMRunner[t.Any, t.Any]: ... @overload diff --git a/src/openllm/bundle/_package.py b/src/openllm/bundle/_package.py index 04da2851c..d5b2ea045 100644 --- a/src/openllm/bundle/_package.py +++ b/src/openllm/bundle/_package.py @@ -23,7 +23,6 @@ import fs.copy import fs.errors import orjson -from packaging.version import Version from simple_di import Provide from simple_di import inject @@ -76,23 +75,13 @@ def build_editable(path: str) -> str | None: return builder.build("wheel", path, config_settings={"--global-option": "--quiet"}) raise RuntimeError("Custom OpenLLM build is currently not supported. Please install OpenLLM from PyPI or built it from Git source.") - -def handle_package_version(package: str, has_dockerfile_template: bool, lower_bound: bool = True): - version = Version(pkg.get_pkg_version(package)) - if version.is_devrelease: - if has_dockerfile_template: logger.warning("Installed %s has version %s as a dev release. This means you have a custom build of %s with %s. Make sure to use custom dockerfile templates (--dockerfile-template) to setup %s correctly. See https://docs.bentoml.com/en/latest/guides/containerization.html#dockerfile-template for more information.", package, version, package, "CUDA support" if "cu" in str(version) else "more features", package) - return package - return f"{package}>={importlib.metadata.version(package)}" if lower_bound else package - - def construct_python_options( llm: openllm.LLM[t.Any, t.Any], llm_fs: FS, - has_dockerfile_template: bool, extra_dependencies: tuple[str, ...] 
| None = None, adapter_map: dict[str, str | None] | None = None, ) -> PythonOptions: - packages = ["openllm"] + packages = ["openllm", "scipy"] # apparently bnb misses this one if adapter_map is not None: packages += ["openllm[fine-tune]"] # NOTE: add openllm to the default dependencies # if users has openllm custom built wheels, it will still respect @@ -102,14 +91,13 @@ def construct_python_options( req = llm.config["requirements"] if req is not None: packages.extend(req) - if str(os.environ.get("BENTOML_BUNDLE_LOCAL_BUILD", False)).lower() == "false": packages.append(f"bentoml>={'.'.join([str(i) for i in pkg.pkg_version_info('bentoml')])}") - env: EnvVarMixin = llm.config["env"] + env = llm.config["env"] framework_envvar = env["framework_value"] if framework_envvar == "flax": if not is_flax_available(): raise ValueError(f"Flax is not available, while {env.framework} is set to 'flax'") - packages.extend([handle_package_version("flax", has_dockerfile_template), handle_package_version("jax", has_dockerfile_template), handle_package_version("jaxlib", has_dockerfile_template)]) + packages.extend([importlib.metadata.version("flax"), importlib.metadata.version("jax"), importlib.metadata.version("jaxlib")]) elif framework_envvar == "tf": if not is_tf_available(): raise ValueError(f"TensorFlow is not available, while {env.framework} is set to 'tf'") candidates = ( @@ -127,7 +115,7 @@ def construct_python_options( # For the metadata, we have to look for both tensorflow and tensorflow-cpu for candidate in candidates: try: - pkgver = handle_package_version(candidate, has_dockerfile_template) + pkgver = importlib.metadata.version(candidate) if pkgver == candidate: packages.extend(["tensorflow"]) else: _tf_version = importlib.metadata.version(candidate) @@ -136,14 +124,12 @@ def construct_python_options( except importlib.metadata.PackageNotFoundError: pass else: if not is_torch_available(): raise ValueError("PyTorch is not available. Make sure to have it locally installed.") - packages.extend([handle_package_version("torch", has_dockerfile_template)]) + packages.extend([importlib.metadata.version("torch")]) wheels: list[str] = [] built_wheels = build_editable(llm_fs.getsyspath("/")) if built_wheels is not None: wheels.append(llm_fs.getsyspath(f"/{built_wheels.split('/')[-1]}")) - - return PythonOptions(packages=packages, wheels=wheels, lock_packages=False) - + return PythonOptions(packages=packages, wheels=wheels, lock_packages=False, extra_index_url=["https://download.pytorch.org/whl/cu118"]) def construct_docker_options( llm: openllm.LLM[t.Any, t.Any], @@ -164,7 +150,6 @@ def construct_docker_options( ] _bentoml_config_options += " " if _bentoml_config_options else "" + " ".join(_bentoml_config_options_opts) env: EnvVarMixin = llm.config["env"] - env_dict = { env.framework: env.framework_value, env.config: f"'{llm.config.model_dump_json().decode()}'", @@ -175,7 +160,6 @@ def construct_docker_options( "BENTOML_CONFIG_OPTIONS": f"'{_bentoml_config_options}'", env.model_id: f"/home/bentoml/bento/models/{llm.tag.path()}", # This is the default BENTO_PATH var } - if adapter_map: env_dict["BITSANDBYTES_NOWELCOME"] = os.environ.get("BITSANDBYTES_NOWELCOME", "1") # We need to handle None separately here, as env from subprocess doesn't accept None value. 
@@ -184,7 +168,6 @@ def construct_docker_options( if _env.bettertransformer_value is not None: env_dict[_env.bettertransformer] = str(_env.bettertransformer_value) if _env.quantize_value is not None: env_dict[_env.quantize] = _env.quantize_value env_dict[_env.runtime] = _env.runtime_value - return DockerOptions( cuda_version="11.8.0", env=env_dict, @@ -193,7 +176,6 @@ def construct_docker_options( python_version="3.9", ) - @inject def create_bento( bento_tag: bentoml.Tag, @@ -235,7 +217,7 @@ def create_bento( description=f"OpenLLM service for {llm.config['start_name']}", include=list(llm_fs.walk.files()), exclude=["/venv", "/.venv", "__pycache__/", "*.py[cod]", "*$py.class"], - python=construct_python_options(llm, llm_fs, dockerfile_template is None, extra_dependencies, adapter_map), + python=construct_python_options(llm, llm_fs, extra_dependencies, adapter_map), docker=construct_docker_options(llm, llm_fs, workers_per_resource, quantize, bettertransformer, adapter_map, dockerfile_template, runtime, serialisation_format), models=[llm_spec], ) diff --git a/src/openllm/cli.py b/src/openllm/cli.py index 9655c4298..cc43fcc01 100644 --- a/src/openllm/cli.py +++ b/src/openllm/cli.py @@ -1281,12 +1281,11 @@ def build_command( llm_fs.makedir(src_folder_name, recreate=True) fs.copy.copy_dir(src_fs, _adapter_id, llm_fs, src_folder_name) adapter_map[src_folder_name] = name - except FileNotFoundError: - # this is the remote adapter, then just added back - # note that there is a drawback here. If the path of the local adapter - # path have the same name as the remote, then we currently don't support - # that edge case. - adapter_map[_adapter_id] = name + # this is the remote adapter, then just added back + # note that there is a drawback here. If the path of the local adapter + # path have the same name as the remote, then we currently don't support + # that edge case. + except FileNotFoundError: adapter_map[_adapter_id] = name os.environ["OPENLLM_ADAPTER_MAP"] = orjson.dumps(adapter_map).decode() bento_tag = bentoml.Tag.from_taglike(f"{llm.llm_type}-service:{llm.tag.version}".lower().strip()) try: @@ -1332,8 +1331,6 @@ def build_command( elif not overwrite: _echo(f"'{model_name}' already has a Bento built [{bento}]. 
To overwrite it pass '--overwrite'.", fg="yellow") _echo( "📖 Next steps:\n\n" - + "* Serving BentoLLM locally with 'openllm start':\n" - + f" $ openllm start {bento.tag}\n\n" + "* Push to BentoCloud with 'bentoml push':\n" + f" $ bentoml push {bento.tag}\n\n" + "* Containerize your Bento with 'bentoml containerize':\n" @@ -1350,10 +1347,9 @@ def build_command( if push: BentoMLContainer.bentocloud_client.get().push_bento(bento, context=t.cast(CliContext, ctx.obj).cloud_context) elif containerize: backend = t.cast("DefaultBuilder", os.getenv("BENTOML_CONTAINERIZE_BACKEND", "docker")) - _echo(f"Building {bento} into a LLMContainer using backend '{backend}'", fg="magenta") try: bentoml.container.health(backend) except subprocess.CalledProcessError: raise OpenLLMException(f"Failed to use backend {backend}") from None - bentoml.container.build(bento.tag, backend=backend, features=("grpc",)) + bentoml.container.build(bento.tag, backend=backend, features=("grpc","io")) return bento diff --git a/src/openllm/serialisation/__init__.py b/src/openllm/serialisation/__init__.py index cbcc0321d..736c8e6bf 100644 --- a/src/openllm/serialisation/__init__.py +++ b/src/openllm/serialisation/__init__.py @@ -59,12 +59,7 @@ transformers = LazyLoader("transformers", globals(), "transformers") -def import_model( - llm: openllm.LLM[t.Any, t.Any], - *decls: t.Any, - trust_remote_code: bool, - **attrs: t.Any, -) -> bentoml.Model: +def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, **attrs: t.Any) -> bentoml.Model: if llm.runtime == "transformers": return openllm.transformers.import_model(llm, *decls, trust_remote_code=trust_remote_code, **attrs) elif llm.runtime == "ggml": @@ -73,7 +68,7 @@ def import_model( raise ValueError(f"Unknown runtime: {llm.config['runtime']}") -def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Model: +def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model: if llm.runtime == "transformers": return openllm.transformers.get(llm, auto_import=auto_import) elif llm.runtime == "ggml": @@ -82,7 +77,7 @@ def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Mo raise ValueError(f"Unknown runtime: {llm.config['runtime']}") -def save_pretrained(llm: openllm.LLM[t.Any, t.Any], save_directory: str, **attrs: t.Any) -> None: +def save_pretrained(llm: openllm.LLM[M, T], save_directory: str, **attrs: t.Any) -> None: if llm.runtime == "transformers": return openllm.transformers.save_pretrained(llm, save_directory, **attrs) elif llm.runtime == "ggml": @@ -91,7 +86,7 @@ def save_pretrained(llm: openllm.LLM[t.Any, t.Any], save_directory: str, **attrs raise ValueError(f"Unknown runtime: {llm.config['runtime']}") -def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> M: +def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M: if llm.runtime == "transformers": return openllm.transformers.load_model(llm, *decls, **attrs) elif llm.runtime == "ggml": diff --git a/src/openllm/serialisation/transformers.py b/src/openllm/serialisation/transformers.py index db843f8a0..37e574d95 100644 --- a/src/openllm/serialisation/transformers.py +++ b/src/openllm/serialisation/transformers.py @@ -56,70 +56,39 @@ _object_setattr = object.__setattr__ -def process_transformers_config( - model_id: str, trust_remote_code: bool, **attrs: t.Any -) -> tuple[_transformers.PretrainedConfig, dict[str, t.Any], dict[str, t.Any]]: +def process_transformers_config(model_id: str, trust_remote_code: bool, 
**attrs: t.Any) -> tuple[_transformers.PretrainedConfig, dict[str, t.Any], dict[str, t.Any]]: """Process transformers config and return PretrainedConfig with hub_kwargs and the rest of kwargs.""" config: _transformers.PretrainedConfig | None = attrs.pop("config", None) - # this logic below is synonymous to handling `from_pretrained` attrs. hub_attrs = {k: attrs.pop(k) for k in HUB_ATTRS if k in attrs} if not isinstance(config, _transformers.PretrainedConfig): copied_attrs = copy.deepcopy(attrs) - if copied_attrs.get("torch_dtype", None) == "auto": - copied_attrs.pop("torch_dtype") - config, attrs = t.cast( - "tuple[_transformers.PretrainedConfig, dict[str, t.Any]]", - _transformers.AutoConfig.from_pretrained( - model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs - ), - ) - return config, hub_attrs, attrs - + if copied_attrs.get("torch_dtype", None) == "auto": copied_attrs.pop("torch_dtype") + config, attrs = _transformers.AutoConfig.from_pretrained(model_id, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **hub_attrs, **copied_attrs) + return t.cast("_transformers.PretrainedConfig", config), hub_attrs, t.cast("dict[str, t.Any]", attrs) def infer_tokenizers_class_for_llm(__llm: openllm.LLM[t.Any, T]) -> T: tokenizer_class = __llm.config["tokenizer_class"] - if tokenizer_class is None: - tokenizer_class = "AutoTokenizer" + if tokenizer_class is None: tokenizer_class = "AutoTokenizer" __cls = getattr(_transformers, tokenizer_class) - if __cls is None: - raise ValueError( - f"{tokenizer_class} is not a valid Tokenizer class from 'transformers.' Set '{__llm}.__config__[\"trust_remote_code\"] = True' and try again." - ) + if __cls is None: raise ValueError(f"{tokenizer_class} is not a valid Tokenizer class from 'transformers.' Set '{__llm}.__config__[\"trust_remote_code\"] = True' and try again.") return __cls - -def infer_autoclass_from_llm_config( - llm: openllm.LLM[t.Any, t.Any], config: _transformers.PretrainedConfig -) -> _BaseAutoModelClass: +def infer_autoclass_from_llm_config(llm: openllm.LLM[M, T], config: _transformers.PretrainedConfig) -> _BaseAutoModelClass: if llm.config["trust_remote_code"]: autoclass = "AutoModelForSeq2SeqLM" if llm.config["model_type"] == "seq2seq_lm" else "AutoModelForCausalLM" - if not hasattr(config, "auto_map"): - raise ValueError( - f"Invalid configuraiton for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping" - ) + if not hasattr(config, "auto_map"): raise ValueError(f"Invalid configuraiton for {llm.model_id}. ``trust_remote_code=True`` requires `transformers.PretrainedConfig` to contain a `auto_map` mapping") # in case this model doesn't use the correct auto class for model type, for example like chatglm # where it uses AutoModel instead of AutoModelForCausalLM. 
Then we fallback to AutoModel - if autoclass not in config.auto_map: - autoclass = "AutoModel" + if autoclass not in config.auto_map: autoclass = "AutoModel" return getattr(_transformers, autoclass) else: - if type(config) in _transformers.MODEL_FOR_CAUSAL_LM_MAPPING: - idx = 0 - elif type(config) in _transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING: - idx = 1 - else: - raise OpenLLMException(f"Model type {type(config)} is not supported yet.") - + if type(config) in _transformers.MODEL_FOR_CAUSAL_LM_MAPPING: idx = 0 + elif type(config) in _transformers.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING: idx = 1 + else: raise OpenLLMException(f"Model type {type(config)} is not supported yet.") return getattr(_transformers, FRAMEWORK_TO_AUTOCLASS_MAPPING[llm.__llm_implementation__][idx]) - -def import_model( - llm: openllm.LLM[t.Any, t.Any], - *decls: t.Any, - trust_remote_code: bool, - **attrs: t.Any, -) -> bentoml.Model: +def import_model(llm: openllm.LLM[M, T], *decls: t.Any, trust_remote_code: bool, **attrs: t.Any) -> bentoml.Model: """Auto detect model type from given model_id and import it to bentoml's model store. For all kwargs, it will be parsed into `transformers.AutoConfig.from_pretrained` first, @@ -140,20 +109,13 @@ def import_model( config, hub_attrs, attrs = process_transformers_config(llm.model_id, trust_remote_code, **attrs) _, tokenizer_attrs = llm.llm_parameters quantize_method = llm._quantize_method - safe_serialisation = first_not_none( - attrs.get("safe_serialization"), default=llm._serialisation_format == "safetensors" - ) - if llm.__llm_implementation__ == "vllm": - # Disable safe serialization with vLLM - safe_serialisation = False - metadata: DictStrAny = { - "safe_serialisation": safe_serialisation, - "_quantize": quantize_method if quantize_method is not None else False, - } + safe_serialisation = first_not_none(attrs.get("safe_serialization"), default=llm._serialisation_format == "safetensors") + # Disable safe serialization with vLLM + if llm.__llm_implementation__ == "vllm": safe_serialisation = False + metadata: DictStrAny = {"safe_serialisation": safe_serialisation, "_quantize": quantize_method if quantize_method is not None else False} signatures: DictStrAny = {} if quantize_method == "gptq": - if not is_autogptq_available(): - raise OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'") + if not is_autogptq_available(): raise OpenLLMException("GPTQ quantisation requires 'auto-gptq' (Not found in local environment). Install it with 'pip install \"openllm[gptq]\"'") if llm.config["model_type"] != "causal_lm": raise OpenLLMException(f"GPTQ only support Causal LM (got {llm.__class__} of {llm.config['model_type']})") model = autogptq.AutoGPTQForCausalLM.from_quantized( llm.model_id, @@ -211,11 +173,9 @@ def import_model( # NOTE: We need to free up the cache after importing the model # in the case where users first run openllm start without the model # available locally. - if is_torch_available() and torch.cuda.is_available(): - torch.cuda.empty_cache() - + if is_torch_available() and torch.cuda.is_available(): torch.cuda.empty_cache() -def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Model: +def get(llm: openllm.LLM[M, T], auto_import: bool = False) -> bentoml.Model: """Return an instance of ``bentoml.Model`` from given LLM instance. By default, it will try to check the model in the local store. 
@@ -225,29 +185,17 @@ def get(llm: openllm.LLM[t.Any, t.Any], auto_import: bool = False) -> bentoml.Mo """ try: model = bentoml.models.get(llm.tag) - if model.info.module not in ( - "openllm.serialisation.transformers", - # compat with bentoml.transformers.get - "bentoml.transformers", - "bentoml._internal.frameworks.transformers", - __name__, - ): - raise bentoml.exceptions.NotFound( - f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'." - ) + # compat with bentoml.transformers.get + if model.info.module not in ("openllm.serialisation.transformers", "bentoml.transformers", "bentoml._internal.frameworks.transformers", __name__): + raise bentoml.exceptions.NotFound(f"Model {model.tag} was saved with module {model.info.module}, not loading with 'openllm.serialisation.transformers'.") if "runtime" in model.info.labels and model.info.labels["runtime"] != llm.runtime: - raise OpenLLMException( - f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}." - ) + raise OpenLLMException(f"Model {model.tag} was saved with runtime {model.info.labels['runtime']}, not loading with {llm.runtime}.") return model except bentoml.exceptions.NotFound: - if auto_import: - return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__) + if auto_import: return import_model(llm, trust_remote_code=llm.__llm_trust_remote_code__) raise - - -def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> M: +def load_model(llm: openllm.LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M: """Load the model from BentoML store. By default, it will try to find check the model in the local store. @@ -287,7 +235,6 @@ def load_model(llm: openllm.LLM[M, t.Any], *decls: t.Any, **attrs: t.Any) -> M: model = BetterTransformer.transform(model) return t.cast("M", model) - def save_pretrained( llm: openllm.LLM[M, T], save_directory: str, From 493c18b34a1e1e0c4e0bb2ab4544b25501ce9c55 Mon Sep 17 00:00:00 2001 From: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com> Date: Tue, 25 Jul 2023 03:23:20 +0000 Subject: [PATCH 2/2] chore: add changelog Signed-off-by: aarnphm-ec2-dev <29749331+aarnphm@users.noreply.github.com> --- changelog.d/141.fix.md | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 changelog.d/141.fix.md diff --git a/changelog.d/141.fix.md b/changelog.d/141.fix.md new file mode 100644 index 000000000..b949394db --- /dev/null +++ b/changelog.d/141.fix.md @@ -0,0 +1,3 @@ +Fixes model location resolution when running within a Bento container + +This makes sure that the tag and the model path are inferred correctly based on BENTO_PATH and /.dockerenv
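
For reference, below is a minimal, illustrative sketch (not code from this patch) of what the new container branch added to LLM.generate_tag does: when in_docker() is true and BENTO_PATH is set, the model id already points at the model baked into the Bento (/home/bentoml/bento/models/<name>/<version>, the layout set by construct_docker_options above), so the tag is recovered from the last two path components. The helper name below is hypothetical, and the /.dockerenv check is assumed to be what in_docker() performs.

import os

def tag_from_container_model_path(model_id: str) -> str:
    """Derive '<name>:<version>' from a baked-in model path such as /home/bentoml/bento/models/flan-t5/0123abc."""
    # Assumption: running inside a Bento container is detected via /.dockerenv plus BENTO_PATH,
    # mirroring the in_docker() and BENTO_PATH guard added in generate_tag.
    if not (os.path.exists("/.dockerenv") and os.getenv("BENTO_PATH") is not None):
        raise RuntimeError("not running inside a Bento container")
    # The last two path components are the model name and its version (commit hash).
    name, version = model_id.rstrip("/").split("/")[-2:]
    return f"{name}:{version}"

# Example: tag_from_container_model_path("/home/bentoml/bento/models/flan-t5/0123abc") -> "flan-t5:0123abc",
# which bentoml.Tag.from_taglike can then parse into the same tag the model was saved under.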