diff --git a/src/bentoml/__init__.py b/src/bentoml/__init__.py index 4cae949b3f2..dc7985f3759 100644 --- a/src/bentoml/__init__.py +++ b/src/bentoml/__init__.py @@ -17,11 +17,12 @@ from typing import TYPE_CHECKING +from ._internal.configuration import load_config +from ._internal.configuration import save_config from ._internal.configuration import BENTOML_VERSION as __version__ -from ._internal.configuration import load_global_config # Inject dependencies and configurations -load_global_config() +load_config() # Bento management APIs from .bentos import get @@ -175,4 +176,6 @@ "transformers", "xgboost", "monitor", + "load_config", + "save_config", ] diff --git a/src/bentoml/_internal/configuration/__init__.py b/src/bentoml/_internal/configuration/__init__.py index 63f2de67a3d..84304ec14b2 100644 --- a/src/bentoml/_internal/configuration/__init__.py +++ b/src/bentoml/_internal/configuration/__init__.py @@ -1,11 +1,13 @@ +from __future__ import annotations + import os import re import typing as t import logging from functools import lru_cache -from bentoml.exceptions import BentoMLException -from bentoml.exceptions import BentoMLConfigException +from ...exceptions import BentoMLException +from ...exceptions import BentoMLConfigException try: from ..._version import __version__ @@ -87,14 +89,14 @@ def is_pypi_installed_bentoml() -> bool: return is_tagged and is_clean and not_been_modified -def get_bentoml_config_file_from_env() -> t.Optional[str]: +def get_bentoml_config_file_from_env() -> str | None: if CONFIG_ENV_VAR in os.environ: # User local config file for customizing bentoml return expand_env_var(os.environ.get(CONFIG_ENV_VAR, "")) return None -def get_bentoml_override_config_from_env() -> t.Optional[str]: +def get_bentoml_override_config_from_env() -> str | None: if CONFIG_OVERRIDE_ENV_VAR in os.environ: # User local config options for customizing bentoml return os.environ.get(CONFIG_OVERRIDE_ENV_VAR, None) @@ -129,7 +131,7 @@ def get_quiet_mode() -> bool: 
return False -def load_global_config(bentoml_config_file: t.Optional[str] = None): +def load_config(bentoml_config_file: str | None = None): """Load global configuration of BentoML""" from .containers import BentoMLContainer @@ -141,24 +143,22 @@ def load_global_config(bentoml_config_file: t.Optional[str] = None): if bentoml_config_file: if not bentoml_config_file.endswith((".yml", ".yaml")): raise BentoMLConfigException( - "BentoML config file specified in ENV VAR does not end with `.yaml`: " - f"`BENTOML_CONFIG={bentoml_config_file}`" + f"BentoML config file specified in ENV VAR does not end with either '.yaml' or '.yml': 'BENTOML_CONFIG={bentoml_config_file}'" ) from None if not os.path.isfile(bentoml_config_file): raise FileNotFoundError( - "BentoML config file specified in ENV VAR not found: " - f"`BENTOML_CONFIG={bentoml_config_file}`" + f"BentoML config file specified in ENV VAR not found: 'BENTOML_CONFIG={bentoml_config_file}'" ) from None - bentoml_configuration = BentoMLConfiguration( - override_config_file=bentoml_config_file, - override_config_values=get_bentoml_override_config_from_env(), + BentoMLContainer.config.set( + BentoMLConfiguration( + override_config_file=bentoml_config_file, + override_config_values=get_bentoml_override_config_from_env(), + ).to_dict() ) - BentoMLContainer.config.set(bentoml_configuration.as_dict()) - -def save_global_config(config_file_handle: t.IO[t.Any]): +def save_config(config_file_handle: t.IO[t.Any]): import yaml from ..configuration.containers import BentoMLContainer diff --git a/src/bentoml/_internal/configuration/containers.py b/src/bentoml/_internal/configuration/containers.py index 59f6f2e0e21..34813063590 100644 --- a/src/bentoml/_internal/configuration/containers.py +++ b/src/bentoml/_internal/configuration/containers.py @@ -10,12 +10,7 @@ from dataclasses import dataclass import yaml -from schema import Or -from schema import And -from schema import Use -from schema import Schema -from schema import Optional 
-from schema import SchemaError +import schema as s from simple_di import Provide from simple_di import providers from deepmerge.merger import Merger @@ -23,6 +18,10 @@ from . import expand_env_var from ..utils import split_with_quotes from ..utils import validate_or_create_dir +from .helpers import flatten_dict +from .helpers import load_config_file +from .helpers import get_default_config +from .helpers import import_configuration_spec from ..context import component_context from ..resource import CpuResource from ..resource import system_resources @@ -38,249 +37,50 @@ config_merger = Merger( # merge dicts - [(dict, "merge")], + type_strategies=[(dict, "merge")], # override all other types - ["override"], + fallback_strategies=["override"], # override conflicting types - ["override"], + type_conflict_strategies=["override"], ) logger = logging.getLogger(__name__) -_check_tracing_type: t.Callable[[str], bool] = lambda s: s in ( - "zipkin", - "jaeger", - "otlp", -) -_check_otlp_protocol: t.Callable[[str], bool] = lambda s: s in ( - "grpc", - "http", -) -_larger_than: t.Callable[[int | float], t.Callable[[int | float], bool]] = ( - lambda target: lambda val: val > target -) -_larger_than_zero: t.Callable[[int | float], bool] = _larger_than(0) - - -def _check_sample_rate(sample_rate: float) -> None: - if sample_rate == 0.0: - logger.warning( - "Tracing enabled, but sample_rate is unset or zero. No traces will be collected. Please refer to https://docs.bentoml.org/en/latest/guides/tracing.html for more details." 
- ) - - -def _is_ip_address(addr: str) -> bool: - import socket - - try: - socket.inet_aton(addr) - return True - except socket.error: - return False - - -RUNNER_CFG_KEYS = ["batching", "resources", "logging", "metrics", "timeout"] - -RUNNER_CFG_SCHEMA = { - Optional("batching"): { - Optional("enabled"): bool, - Optional("max_batch_size"): And(int, _larger_than_zero), - Optional("max_latency_ms"): And(int, _larger_than_zero), - }, - # note there is a distinction between being unset and None here; if set to 'None' - # in configuration for a specific runner, it will override the global configuration - Optional("resources"): Or({Optional(str): object}, lambda s: s == "system", None), # type: ignore (incomplete schema typing) - Optional("logging"): { - # TODO add logging level configuration - Optional("access"): { - Optional("enabled"): bool, - Optional("request_content_length"): Or(bool, None), - Optional("request_content_type"): Or(bool, None), - Optional("response_content_length"): Or(bool, None), - Optional("response_content_type"): Or(bool, None), - }, - }, - Optional("metrics"): { - "enabled": bool, - "namespace": str, - }, - Optional("timeout"): And(int, _larger_than_zero), -} - -SCHEMA = Schema( - { - "api_server": { - "workers": Or(And(int, _larger_than_zero), None), - "timeout": And(int, _larger_than_zero), - "backlog": And(int, _larger_than(64)), - Optional("ssl"): { - Optional("certfile"): Or(str, None), - Optional("keyfile"): Or(str, None), - Optional("keyfile_password"): Or(str, None), - Optional("version"): Or(And(int, _larger_than_zero), None), - Optional("cert_reqs"): Or(int, None), - Optional("ca_certs"): Or(str, None), - Optional("ciphers"): Or(str, None), - }, - "metrics": { - "enabled": bool, - "namespace": str, - Optional("duration"): { - Optional("min"): And(float, _larger_than_zero), - Optional("max"): And(float, _larger_than_zero), - Optional("factor"): And(float, _larger_than(1.0)), - }, - }, - "logging": { - # TODO add logging level 
configuration - "access": { - "enabled": bool, - "request_content_length": Or(bool, None), - "request_content_type": Or(bool, None), - "response_content_length": Or(bool, None), - "response_content_type": Or(bool, None), - "format": { - "trace_id": str, - "span_id": str, - }, - }, - }, - "http": { - "host": And(str, _is_ip_address), - "port": And(int, _larger_than_zero), - "cors": { - "enabled": bool, - "access_control_allow_origin": Or(str, None), - "access_control_allow_credentials": Or(bool, None), - "access_control_allow_headers": Or([str], str, None), - "access_control_allow_methods": Or([str], str, None), - "access_control_max_age": Or(int, None), - "access_control_expose_headers": Or([str], str, None), - }, - }, - "grpc": { - "host": And(str, _is_ip_address), - "port": And(int, _larger_than_zero), - "metrics": { - "port": And(int, _larger_than_zero), - "host": And(str, _is_ip_address), - }, - "reflection": {"enabled": bool}, - "channelz": {"enabled": bool}, - "max_concurrent_streams": Or(int, None), - "max_message_length": Or(int, None), - "maximum_concurrent_rpcs": Or(int, None), - }, - "runner_probe": { - "enabled": bool, - "timeout": int, - "period": int, - }, - }, - "runners": { - **RUNNER_CFG_SCHEMA, - Optional(str): RUNNER_CFG_SCHEMA, # type: ignore (incomplete schema typing) - }, - "tracing": { - "type": Or(And(str, Use(str.lower), _check_tracing_type), None), - "sample_rate": Or(And(float, lambda i: i >= 0 and i <= 1), None), - "excluded_urls": Or([str], str, None), - Optional("zipkin"): {"url": Or(str, None)}, - Optional("jaeger"): {"address": Or(str, None), "port": Or(int, None)}, - Optional("otlp"): { - "protocol": Or(And(str, Use(str.lower), _check_otlp_protocol), None), - "url": Or(str, None), - }, - }, - Optional("monitoring"): { - "enabled": bool, - Optional("type"): Or(str, None), - Optional("options"): Or(dict, None), - }, - Optional("yatai"): { - "default_server": Or(str, None), - "servers": { - str: { - "url": Or(str, None), - 
"access_token": Or(str, None), - "access_token_header": Or(str, None), - "tls": { - "root_ca_cert": Or(str, None), - "client_key": Or(str, None), - "client_cert": Or(str, None), - "client_certificate_file": Or(str, None), - }, - }, - }, - }, - } -) - -_WARNING_MESSAGE = ( - "field 'api_server.%s' is deprecated and has been renamed to 'api_server.http.%s'" -) - class BentoMLConfiguration: def __init__( self, - override_config_file: t.Optional[str] = None, - override_config_values: t.Optional[str] = None, + override_config_file: str | None = None, + override_config_values: str | None = None, + *, validate_schema: bool = True, + use_version: int = 1, ): - # Load default configuration - default_config_file = os.path.join( - os.path.dirname(__file__), "default_configuration.yaml" - ) - with open(default_config_file, "rb") as f: - self.config: t.Dict[str, t.Any] = yaml.safe_load(f) - if validate_schema: - try: - SCHEMA.validate(self.config) - except SchemaError as e: - raise BentoMLConfigException( - "Default configuration 'default_configuration.yml' does not" - " conform to the required schema." - ) from e + # Load default configuration with latest version. 
+ self.config = get_default_config(version=use_version) + spec_module = import_configuration_spec(version=use_version) # User override configuration if override_config_file is not None: - logger.info("Applying user config override from %s" % override_config_file) - if not os.path.exists(override_config_file): - raise BentoMLConfigException( - f"Config file {override_config_file} not found" - ) - with open(override_config_file, "rb") as f: - override_config: dict[str, t.Any] = yaml.safe_load(f) - - # compatibility layer with old configuration pre gRPC features - # api_server.[cors|port|host] -> api_server.http.$^ - if "api_server" in override_config: - user_api_config = override_config["api_server"] - # max_request_size is deprecated - if "max_request_size" in user_api_config: - logger.warning( - "'api_server.max_request_size' is deprecated and has become obsolete." - ) - user_api_config.pop("max_request_size") - # check if user are using older configuration - if "http" not in user_api_config: - user_api_config["http"] = {} - # then migrate these fields to newer configuration fields. - for field in ["port", "host", "cors"]: - if field in user_api_config: - old_field = user_api_config.pop(field) - user_api_config["http"][field] = old_field - logger.warning(_WARNING_MESSAGE, field, field) - - config_merger.merge(override_config["api_server"], user_api_config) - - assert all( - key not in override_config["api_server"] - for key in ["cors", "max_request_size", "host", "port"] + logger.info( + "Applying user config override from path: %s" % override_config_file + ) + override = load_config_file(override_config_file) + if "version" not in override: + # If users does not define a version, we then by default assume they are using v1 + # and we will migrate it to latest version + logger.debug( + "User config does not define a version, assuming given config is version %d..." 
+ % use_version ) - - config_merger.merge(self.config, override_config) + current = use_version + else: + current = override["version"] + migration = getattr(import_configuration_spec(current), "migration", None) + # Running migration layer if it exists + if migration is not None: + override = migration(override_config=dict(flatten_dict(override))) + config_merger.merge(self.config, override) if override_config_values is not None: logger.info( @@ -298,68 +98,62 @@ def __init__( split_with_quotes(line, sep="=", quote='"') for line in lines ] } + # Note that this values will only support latest version of configuration, + # as there is no way for us to infer what values user can pass in. + # however, if users pass in a version inside this value, we will that to migrate up + # if possible + if "version" in override_config_map: + override_version = override_config_map["version"] + logger.debug( + "Found defined 'version=%d' in BENTOML_CONFIG_OPTIONS." + % override_version + ) + migration = getattr( + import_configuration_spec(override_version), "migration", None + ) + # Running migration layer if it exists + if migration is not None: + override_config_map = migration(override_config=override_config_map) + # Previous behaviour, before configuration versioning. try: - override_config = unflatten(override_config_map) + override = unflatten(override_config_map) except ValueError as e: raise BentoMLConfigException( - f"Failed to parse config options from the env var: {e}. \n *** Note: You can use '\"' to quote the key if it contains special characters. ***" + f"Failed to parse config options from the env var:\n{e}.\n*** Note: You can use '\"' to quote the key if it contains special characters. 
***" ) from None - config_merger.merge(self.config, override_config) + config_merger.merge(self.config, override) if override_config_file is not None or override_config_values is not None: self._finalize() if validate_schema: try: - SCHEMA.validate(self.config) - except SchemaError as e: + spec_module.SCHEMA.validate(self.config) + except s.SchemaError as e: raise BentoMLConfigException( - "Invalid configuration file was given." - ) from e + f"Invalid configuration file was given:\n{e}" + ) from None def _finalize(self): + RUNNER_CFG_KEYS = ["batching", "resources", "logging", "metrics", "timeout"] global_runner_cfg = {k: self.config["runners"][k] for k in RUNNER_CFG_KEYS} - for key in self.config["runners"]: - if key not in RUNNER_CFG_KEYS: - runner_cfg = self.config["runners"][key] + custom_runners_cfg = dict( + filter( + lambda kv: kv[0] not in RUNNER_CFG_KEYS, + self.config["runners"].items(), + ) + ) + if custom_runners_cfg: + for runner_name, runner_cfg in custom_runners_cfg.items(): # key is a runner name if runner_cfg.get("resources") == "system": runner_cfg["resources"] = system_resources() - self.config["runners"][key] = config_merger.merge( + self.config["runners"][runner_name] = config_merger.merge( deepcopy(global_runner_cfg), runner_cfg, ) - def override(self, keys: t.List[str], value: t.Any): - if keys is None: - raise BentoMLConfigException( - "Configuration override key is None." - ) from None - if len(keys) == 0: - raise BentoMLConfigException( - "Configuration override key is empty." - ) from None - if value is None: - return - - c = self.config - for key in keys[:-1]: - if key not in c: - raise BentoMLConfigException( - "Configuration override key is invalid, %s" % keys - ) from None - c = c[key] - c[keys[-1]] = value - - try: - SCHEMA.validate(self.config) - except SchemaError as e: - raise BentoMLConfigException( - "Configuration after applying override does not conform" - " to the required schema, key=%s, value=%s." 
% (keys, value) - ) from e - - def as_dict(self) -> providers.ConfigDictType: + def to_dict(self) -> providers.ConfigDictType: return t.cast(providers.ConfigDictType, self.config) @@ -429,36 +223,39 @@ def serve_info() -> ServeInfo: return get_serve_info() + cors = http.cors + @providers.SingletonFactory @staticmethod def access_control_options( - allow_origins: str | None = Provide[http.cors.access_control_allow_origin], - allow_credentials: bool - | None = Provide[http.cors.access_control_allow_credentials], - expose_headers: list[str] - | str - | None = Provide[http.cors.access_control_expose_headers], + allow_origin: str | None = Provide[cors.access_control_allow_origin], + allow_origin_regex: str + | None = Provide[cors.access_control_allow_origin_regex], + allow_credentials: bool | None = Provide[cors.access_control_allow_credentials], allow_methods: list[str] | str - | None = Provide[http.cors.access_control_allow_methods], + | None = Provide[cors.access_control_allow_methods], allow_headers: list[str] | str - | None = Provide[http.cors.access_control_allow_headers], - max_age: int | None = Provide[http.cors.access_control_max_age], + | None = Provide[cors.access_control_allow_headers], + max_age: int | None = Provide[cors.access_control_max_age], + expose_headers: list[str] + | str + | None = Provide[cors.access_control_expose_headers], ) -> dict[str, list[str] | str | int]: - kwargs = dict( - allow_origins=allow_origins, - allow_credentials=allow_credentials, - expose_headers=expose_headers, - allow_methods=allow_methods, - allow_headers=allow_headers, - max_age=max_age, - ) - - filtered_kwargs: dict[str, list[str] | str | int] = { - k: v for k, v in kwargs.items() if v is not None + return { + k: v + for k, v in { + "allow_origins": allow_origin, + "allow_origin_regex": allow_origin_regex, + "allow_credentials": allow_credentials, + "allow_methods": allow_methods, + "allow_headers": allow_headers, + "max_age": max_age, + "expose_headers": expose_headers, 
+ }.items() + if v is not None } - return filtered_kwargs api_server_workers = providers.Factory[int]( lambda workers: workers or math.ceil(CpuResource.from_system()), @@ -480,16 +277,18 @@ def metrics_client( return PrometheusClient(multiproc_dir=multiproc_dir) + tracing = config.tracing + @providers.SingletonFactory @staticmethod def tracer_provider( - tracer_type: str = Provide[config.tracing.type], - sample_rate: t.Optional[float] = Provide[config.tracing.sample_rate], - zipkin_server_url: t.Optional[str] = Provide[config.tracing.zipkin.url], - jaeger_server_address: t.Optional[str] = Provide[config.tracing.jaeger.address], - jaeger_server_port: t.Optional[int] = Provide[config.tracing.jaeger.port], - otlp_server_protocol: t.Optional[str] = Provide[config.tracing.otlp.protocol], - otlp_server_url: t.Optional[str] = Provide[config.tracing.otlp.url], + tracer_type: str = Provide[tracing.exporter_type], + sample_rate: float | None = Provide[tracing.sample_rate], + timeout: int | None = Provide[tracing.timeout], + max_tag_value_length: int | None = Provide[tracing.max_tag_value_length], + zipkin: dict[str, t.Any] = Provide[tracing.zipkin], + jaeger: dict[str, t.Any] = Provide[tracing.jaeger], + otlp: dict[str, t.Any] = Provide[tracing.otlp], ): from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.resources import Resource @@ -506,9 +305,11 @@ def tracer_provider( if sample_rate is None: sample_rate = 0.0 - + if sample_rate == 0.0: + logger.debug( + "'tracing.sample_rate' is set to zero. No traces will be collected. Please refer to https://docs.bentoml.org/en/latest/guides/tracing.html for more details." + ) resource = {} - # User can optionally configure the resource with the following environment variables. Only # configure resource if user has not explicitly configured it. 
if ( @@ -523,57 +324,70 @@ def tracer_provider( resource[SERVICE_NAMESPACE] = component_context.bento_name if component_context.bento_version: resource[SERVICE_VERSION] = component_context.bento_version - + # create tracer provider provider = TracerProvider( sampler=ParentBasedTraceIdRatio(sample_rate), resource=Resource.create(resource), ) - - if tracer_type == "zipkin" and zipkin_server_url is not None: + if tracer_type == "zipkin" and any(zipkin.values()): from opentelemetry.exporter.zipkin.json import ZipkinExporter - exporter = ZipkinExporter(endpoint=zipkin_server_url) - provider.add_span_processor(BatchSpanProcessor(exporter)) - _check_sample_rate(sample_rate) - return provider - elif ( - tracer_type == "jaeger" - and jaeger_server_address is not None - and jaeger_server_port is not None - ): - from opentelemetry.exporter.jaeger.thrift import JaegerExporter - + exporter = ZipkinExporter( + endpoint=zipkin.get("endpoint"), + local_node_ipv4=zipkin.get("local_node_ipv4"), + local_node_ipv6=zipkin.get("local_node_ipv6"), + local_node_port=zipkin.get("local_node_port"), + max_tag_value_length=max_tag_value_length, + timeout=timeout, + ) + elif tracer_type == "jaeger" and any(jaeger.values()): + protocol = jaeger.get("protocol") + if protocol == "thrift": + from opentelemetry.exporter.jaeger.thrift import JaegerExporter + elif protocol == "grpc": + from opentelemetry.exporter.jaeger.proto.grpc import JaegerExporter + else: + raise InvalidArgument( + f"Invalid 'api_server.tracing.jaeger.protocol' value: {protocol}" + ) from None exporter = JaegerExporter( - agent_host_name=jaeger_server_address, agent_port=jaeger_server_port + collector_endpoint=jaeger.get("collector_endpoint"), + max_tag_value_length=max_tag_value_length, + timeout=timeout, + **jaeger[protocol], ) - provider.add_span_processor(BatchSpanProcessor(exporter)) - _check_sample_rate(sample_rate) - return provider - elif ( - tracer_type == "otlp" - and otlp_server_protocol is not None - and 
otlp_server_url is not None - ): - if otlp_server_protocol == "grpc": + elif tracer_type == "otlp" and any(otlp.values()): + protocol = otlp.get("protocol") + if protocol == "grpc": from opentelemetry.exporter.otlp.proto.grpc import trace_exporter - - elif otlp_server_protocol == "http": + elif protocol == "http": from opentelemetry.exporter.otlp.proto.http import trace_exporter else: raise InvalidArgument( - f"Invalid otlp protocol: {otlp_server_protocol}" + f"Invalid 'api_server.tracing.jaeger.protocol' value: {protocol}" ) from None - exporter = trace_exporter.OTLPSpanExporter(endpoint=otlp_server_url) - provider.add_span_processor(BatchSpanProcessor(exporter)) - _check_sample_rate(sample_rate) - return provider + exporter = trace_exporter.OTLPSpanExporter( + endpoint=otlp.get("endpoint", None), + compression=otlp.get("compression", None), + timeout=timeout, + **otlp[protocol], + ) + elif tracer_type == "in_memory": + # This will be used during testing, user shouldn't use this otherwise. + # We won't document this in documentation. + from opentelemetry.sdk.trace.export import in_memory_span_exporter + + exporter = in_memory_span_exporter.InMemorySpanExporter() else: return provider + # When exporter is set + provider.add_span_processor(BatchSpanProcessor(exporter)) + return provider @providers.SingletonFactory @staticmethod def tracing_excluded_urls( - excluded_urls: str | list[str] | None = Provide[config.tracing.excluded_urls], + excluded_urls: str | list[str] | None = Provide[tracing.excluded_urls], ): from opentelemetry.util.http import ExcludeList from opentelemetry.util.http import parse_excluded_urls @@ -592,27 +406,26 @@ def tracing_excluded_urls( @providers.SingletonFactory @staticmethod def duration_buckets( - metrics: dict[str, t.Any] = Provide[api_server_config.metrics] + duration: dict[str, t.Any] = Provide[api_server_config.metrics.duration] ) -> tuple[float, ...]: """ Returns a tuple of duration buckets in seconds. 
If not explicitly configured, the Prometheus default is returned; otherwise, a set of exponential buckets generated based on the configuration is returned. """ - from ..utils.metrics import DEFAULT_BUCKET + from ..utils.metrics import INF from ..utils.metrics import exponential_buckets - if "duration" in metrics: - duration: dict[str, float] = metrics["duration"] - if duration.keys() >= {"min", "max", "factor"}: + if "buckets" in duration: + return tuple(duration["buckets"]) + (INF,) + else: + if len(set(duration) - {"min", "max", "factor"}) == 0: return exponential_buckets( duration["min"], duration["factor"], duration["max"] ) raise BentoMLConfigException( - "Keys 'min', 'max', and 'factor' are required for " - f"'duration' configuration, '{duration}'." - ) - return DEFAULT_BUCKET + f"Keys 'min', 'max', and 'factor' are required for 'duration' configuration, '{duration!r}'." + ) from None @providers.SingletonFactory @staticmethod diff --git a/src/bentoml/_internal/configuration/helpers.py b/src/bentoml/_internal/configuration/helpers.py new file mode 100644 index 00000000000..997eeca61ae --- /dev/null +++ b/src/bentoml/_internal/configuration/helpers.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import os +import socket +import typing as t +import logging +from typing import TYPE_CHECKING +from functools import singledispatch + +import yaml +import schema as s + +from ..utils import LazyLoader +from ...exceptions import BentoMLConfigException + +if TYPE_CHECKING: + from types import ModuleType + +logger = logging.getLogger(__name__) + +TRACING_TYPE = ["zipkin", "jaeger", "otlp", "in_memory"] + + +def import_configuration_spec(version: int) -> ModuleType: # pragma: no cover + return LazyLoader( + f"v{version}", + globals(), + f"bentoml._internal.configuration.v{version}", + exc_msg=f"Configuration version %d does not exists." 
% version, + ) + + +@singledispatch +def depth(_: t.Any, _level: int = 0): # pragma: no cover + return _level + + +@depth.register(dict) +def _(d: dict[str, t.Any], level: int = 0, **kw: t.Any): + return max(depth(v, level + 1, **kw) for v in d.values()) + + +def rename_fields( + d: dict[str, t.Any], + current: str, + replace_with: str | None = None, + *, + remove_only: bool = False, +): + # We assume that the given dictionary is already flattened. + # This function will rename the keys in the dictionary. + # If `replace_with` is None, then the key will be removed. + if depth(d) != 1: + raise ValueError( + "Given dictionary is not flattened. Use flatten_dict first." + ) from None + if current in d: + if remove_only: + logger.warning("Field '%s' is deprecated and will be removed." % current) + d.pop(current) + else: + assert replace_with, "'replace_with' must be provided." + logger.warning( + "Field '%s' is deprecated and has been renamed to '%s'" + % (current, replace_with) + ) + d[replace_with] = d.pop(current) + + +punctuation = r"""!"#$%&'()*+,-./:;<=>?@[\]^`{|}~""" + + +def flatten_dict( + d: t.MutableMapping[str, t.Any], + parent: str = "", + sep: str = ".", +) -> t.Generator[tuple[str, t.Any], None, None]: + """Flatten nested dictionary into a single level dictionary.""" + for k, v in d.items(): + k = f'"{k}"' if any(i in punctuation for i in k) else k + nkey = parent + sep + k if parent else k + if isinstance(v, t.MutableMapping): + yield from flatten_dict( + t.cast(t.MutableMapping[str, t.Any], v), parent=nkey, sep=sep + ) + else: + yield nkey, v + + +def load_config_file(path: str) -> dict[str, t.Any]: + """Load configuration from given path.""" + if not os.path.exists(path): + raise BentoMLConfigException( + "Configuration file %s not found." 
% path + ) from None + with open(path, "rb") as f: + config = yaml.safe_load(f) + return config + + +def get_default_config(version: int) -> dict[str, t.Any]: + config = load_config_file( + os.path.join( + os.path.dirname(__file__), f"v{version}", "default_configuration.yaml" + ) + ) + mod = import_configuration_spec(version) + assert hasattr(mod, "SCHEMA"), ( + "version %d does not have a validation schema" % version + ) + try: + mod.SCHEMA.validate(config) + except s.SchemaError as e: + raise BentoMLConfigException( + "Default configuration for version %d does not conform to given schema:\n%s" + % (version, e) + ) from None + return config + + +def validate_tracing_type(tracing_type: str) -> bool: + return tracing_type in TRACING_TYPE + + +def validate_otlp_protocol(protocol: str) -> bool: + return protocol in ["grpc", "http"] + + +def ensure_larger_than(target: int | float) -> t.Callable[[int | float], bool]: + """Ensure that given value is (lower, inf]""" + + def v(value: int | float) -> bool: + return value > target + + return v + + +ensure_larger_than_zero = ensure_larger_than(0) + + +def ensure_range( + lower: int | float, upper: int | float +) -> t.Callable[[int | float], bool]: + """Ensure that given value is within the range of [lower, upper].""" + + def v(value: int | float) -> bool: + return lower <= value <= upper + + return v + + +def ensure_iterable_type(typ_: type) -> t.Callable[[t.MutableSequence[t.Any]], bool]: + """Ensure that given mutable sequence has all elements of given types.""" + + def v(value: t.MutableSequence[t.Any]) -> bool: + return all(isinstance(i, typ_) for i in value) + + return v + + +def is_valid_ip_address(addr: str) -> bool: + """Check if given string is a valid IP address.""" + try: + _ = socket.inet_aton(addr) + return True + except socket.error: + return False diff --git a/src/bentoml/_internal/configuration/v1/__init__.py b/src/bentoml/_internal/configuration/v1/__init__.py new file mode 100644 index 
00000000000..2d9776edf7f --- /dev/null +++ b/src/bentoml/_internal/configuration/v1/__init__.py @@ -0,0 +1,273 @@ +from __future__ import annotations + +import re +import typing as t + +import schema as s + +from ..helpers import depth +from ..helpers import ensure_range +from ..helpers import rename_fields +from ..helpers import ensure_larger_than +from ..helpers import is_valid_ip_address +from ..helpers import ensure_iterable_type +from ..helpers import validate_tracing_type +from ..helpers import validate_otlp_protocol +from ..helpers import ensure_larger_than_zero +from ...utils.metrics import DEFAULT_BUCKET +from ...utils.unflatten import unflatten + +__all__ = ["SCHEMA", "migration"] + +TRACING_CFG = { + "exporter_type": s.Or(s.And(str, s.Use(str.lower), validate_tracing_type), None), + "sample_rate": s.Or(s.And(float, ensure_range(0, 1)), None), + "timeout": s.Or(s.And(int, ensure_larger_than_zero), None), + "max_tag_value_length": s.Or(int, None), + "excluded_urls": s.Or([str], str, None), + "zipkin": { + "endpoint": s.Or(str, None), + "local_node_ipv4": s.Or(s.Or(s.And(str, is_valid_ip_address), int), None), + "local_node_ipv6": s.Or(s.Or(s.And(str, is_valid_ip_address), int), None), + "local_node_port": s.Or(s.And(int, ensure_larger_than_zero), None), + }, + "jaeger": { + "protocol": s.Or( + s.And(str, s.Use(str.lower), lambda d: d in ["thrift", "grpc"]), + None, + ), + "collector_endpoint": s.Or(str, None), + "thrift": { + "agent_host_name": s.Or(str, None), + "agent_port": s.Or(int, None), + "udp_split_oversized_batches": s.Or(bool, None), + }, + "grpc": { + "insecure": s.Or(bool, None), + }, + }, + "otlp": { + "protocol": s.Or(s.And(str, s.Use(str.lower), validate_otlp_protocol), None), + "endpoint": s.Or(str, None), + "compression": s.Or( + s.And(str, lambda d: d in {"gzip", "none", "deflate"}), None + ), + "http": { + "certificate_file": s.Or(str, None), + "headers": s.Or(dict, None), + }, + "grpc": { + "insecure": s.Or(bool, None), + "headers": 
s.Or(lambda d: isinstance(d, t.Sequence), None), + }, + }, +} +_API_SERVER_CONFIG = { + "workers": s.Or(s.And(int, ensure_larger_than_zero), None), + "timeout": s.And(int, ensure_larger_than_zero), + "backlog": s.And(int, ensure_larger_than(64)), + "metrics": { + "enabled": bool, + "namespace": str, + s.Optional("duration"): { + s.Optional("buckets", default=DEFAULT_BUCKET): s.Or( + s.And(list, ensure_iterable_type(float)), None + ), + s.Optional("min"): s.Or(s.And(float, ensure_larger_than_zero), None), + s.Optional("max"): s.Or(s.And(float, ensure_larger_than_zero), None), + s.Optional("factor"): s.Or(s.And(float, ensure_larger_than(1.0)), None), + }, + }, + "logging": { + "access": { + "enabled": bool, + "request_content_length": s.Or(bool, None), + "request_content_type": s.Or(bool, None), + "response_content_length": s.Or(bool, None), + "response_content_type": s.Or(bool, None), + "format": { + "trace_id": str, + "span_id": str, + }, + }, + }, + "http": { + "host": s.And(str, is_valid_ip_address), + "port": s.And(int, ensure_larger_than_zero), + "cors": { + "enabled": bool, + "access_control_allow_origin": s.Or(str, None), + "access_control_allow_origin_regex": s.Or( + s.And(str, s.Use(re.compile)), None + ), + "access_control_allow_credentials": s.Or(bool, None), + "access_control_allow_headers": s.Or([str], str, None), + "access_control_allow_methods": s.Or([str], str, None), + "access_control_max_age": s.Or(int, None), + "access_control_expose_headers": s.Or([str], str, None), + }, + }, + "grpc": { + "host": s.And(str, is_valid_ip_address), + "port": s.And(int, ensure_larger_than_zero), + "metrics": { + "port": s.And(int, ensure_larger_than_zero), + "host": s.And(str, is_valid_ip_address), + }, + "reflection": {"enabled": bool}, + "channelz": {"enabled": bool}, + "max_concurrent_streams": s.Or(int, None), + "max_message_length": s.Or(int, None), + "maximum_concurrent_rpcs": s.Or(int, None), + }, + s.Optional("ssl"): { + "enabled": bool, + 
s.Optional("certfile"): s.Or(str, None), + s.Optional("keyfile"): s.Or(str, None), + s.Optional("keyfile_password"): s.Or(str, None), + s.Optional("version"): s.Or(s.And(int, ensure_larger_than_zero), None), + s.Optional("cert_reqs"): s.Or(int, None), + s.Optional("ca_certs"): s.Or(str, None), + s.Optional("ciphers"): s.Or(str, None), + }, + "runner_probe": { + "enabled": bool, + "timeout": int, + "period": int, + }, +} +_RUNNER_CONFIG = { + s.Optional("batching"): { + s.Optional("enabled"): bool, + s.Optional("max_batch_size"): s.And(int, ensure_larger_than_zero), + s.Optional("max_latency_ms"): s.And(int, ensure_larger_than_zero), + }, + # NOTE: there is a distinction between being unset and None here; if set to 'None' + # in configuration for a specific runner, it will override the global configuration. + s.Optional("resources"): s.Or({s.Optional(str): object}, lambda s: s == "system", None), # type: ignore (incomplete schema typing) + s.Optional("logging"): { + s.Optional("access"): { + s.Optional("enabled"): bool, + s.Optional("request_content_length"): s.Or(bool, None), + s.Optional("request_content_type"): s.Or(bool, None), + s.Optional("response_content_length"): s.Or(bool, None), + s.Optional("response_content_type"): s.Or(bool, None), + }, + }, + s.Optional("metrics"): { + "enabled": bool, + "namespace": str, + }, + s.Optional("timeout"): s.And(int, ensure_larger_than_zero), +} +SCHEMA = s.Schema( + { + s.Optional("version", default=1): s.And(int, lambda v: v == 1), + "api_server": _API_SERVER_CONFIG, + "runners": { + **_RUNNER_CONFIG, + s.Optional(str): _RUNNER_CONFIG, + }, + "tracing": TRACING_CFG, + s.Optional("monitoring"): { + "enabled": bool, + s.Optional("type"): s.Or(str, None), + s.Optional("options"): s.Or(dict, None), + }, + } +) + + +def migration(*, override_config: dict[str, t.Any]): + # We will use a flattened config to make it easier to migrate, + # Then we will convert it back to a nested config. 
+ if depth(override_config) > 1: + raise ValueError("'override_config' must be a flattened dictionary.") from None + + if "version" not in override_config: + override_config["version"] = 1 + + # First we migrate api_server field + # 1. remove api_server.max_request_size (deprecated) + rename_fields( + override_config, current="api_server.max_request_size", remove_only=True + ) + + # 2. migrate api_server.[host|port] -> api_server.http.[host|port] + for f in ["host", "port"]: + rename_fields( + override_config, + current=f"api_server.{f}", + replace_with=f"api_server.http.{f}", + ) + + # 3. migrate api_server.cors.[access_control_*] -> api_server.http.cors.[*] + rename_fields( + override_config, + current="api_server.cors.enabled", + replace_with="api_server.http.cors.enabled", + ) + for f in [ + "allow_origin", + "allow_credentials", + "allow_headers", + "allow_methods", + "max_age", + "expose_headers", + ]: + rename_fields( + override_config, + current=f"api_server.cors.access_control_{f}", + replace_with=f"api_server.http.cors.access_control_{f}", + ) + + # 4. if ssl is present, in version 1 we introduce an api_server.ssl.enabled field to determine + # whether the user wants to enable SSL. + if len([f for f in override_config if f.startswith("api_server.ssl")]) != 0: + override_config["api_server.ssl.enabled"] = True + + # 5. migrate all tracing fields to the new tracing schema + # 5.1. migrate tracing.type -> tracing.exporter_type + rename_fields( + override_config, + current="tracing.type", + replace_with="tracing.exporter_type", + ) + # 5.2. for Zipkin and OTLP, migrate tracing.[exporter].url -> tracing.[exporter].endpoint + for exporter in ["zipkin", "otlp"]: + rename_fields( + override_config, + current=f"tracing.{exporter}.url", + replace_with=f"tracing.{exporter}.endpoint", + ) + # 5.3.
For Jaeger, migrate tracing.jaeger.[address|port] -> tracing.jaeger.thrift.[agent_host_name|agent_port] + rename_fields( + override_config, + current="tracing.jaeger.address", + replace_with="tracing.jaeger.thrift.agent_host_name", + ) + rename_fields( + override_config, + current="tracing.jaeger.port", + replace_with="tracing.jaeger.thrift.agent_port", + ) + # we also need to choose which protocol to use for jaeger. + if ( + len( + [ + f + for f in override_config + if f.startswith("tracing.jaeger.thrift") + ] + ) + != 0 + ): + override_config["tracing.jaeger.protocol"] = "thrift" + # 6. Last but not least, moving logging.formatting.* -> api_server.logging.access.format.* + for f in ["trace_id", "span_id"]: + rename_fields( + override_config, + current=f"logging.formatting.{f}_format", + replace_with=f"api_server.logging.access.format.{f}", + ) + return unflatten(override_config) diff --git a/src/bentoml/_internal/configuration/default_configuration.yaml b/src/bentoml/_internal/configuration/v1/default_configuration.yaml similarity index 63% rename from src/bentoml/_internal/configuration/default_configuration.yaml rename to src/bentoml/_internal/configuration/v1/default_configuration.yaml index 6d019c8c1c4..9ddd2741d50 100644 --- a/src/bentoml/_internal/configuration/default_configuration.yaml +++ b/src/bentoml/_internal/configuration/v1/default_configuration.yaml @@ -1,10 +1,33 @@ +version: 1 api_server: - workers: ~ # When this is set to null the number of available CPU cores is used.
+ workers: ~ # cpu_count() will be used when null timeout: 60 backlog: 2048 metrics: enabled: true namespace: bentoml_api_server + duration: + # https://github.com/prometheus/client_python/blob/f17a8361ad3ed5bc47f193ac03b00911120a8d81/prometheus_client/metrics.py#L544 + buckets: + [ + 0.005, + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.25, + 0.5, + 0.75, + 1.0, + 2.5, + 5.0, + 7.5, + 10.0, + ] + min: ~ + max: ~ + factor: ~ logging: access: enabled: true @@ -16,6 +39,7 @@ api_server: trace_id: 032x span_id: 016x ssl: + enabled: false certfile: ~ keyfile: ~ keyfile_password: ~ @@ -32,6 +56,7 @@ api_server: access_control_allow_credentials: ~ access_control_allow_methods: ~ access_control_allow_headers: ~ + access_control_allow_origin_regex: ~ access_control_max_age: ~ access_control_expose_headers: ~ grpc: @@ -51,13 +76,13 @@ api_server: enabled: true timeout: 1 period: 10 - runners: + resources: ~ + timeout: 300 batching: enabled: true max_batch_size: 100 max_latency_ms: 10000 - resources: ~ logging: access: enabled: true @@ -66,23 +91,38 @@ runners: response_content_length: true response_content_type: true metrics: - enabled: True + enabled: true namespace: bentoml_runner - timeout: 300 - tracing: - type: zipkin + exporter_type: ~ sample_rate: ~ excluded_urls: ~ + timeout: ~ + max_tag_value_length: ~ zipkin: - url: ~ + endpoint: ~ + local_node_ipv4: ~ + local_node_ipv6: ~ + local_node_port: ~ jaeger: - address: ~ - port: ~ + protocol: thrift + collector_endpoint: ~ + thrift: + agent_host_name: ~ + agent_port: ~ + udp_split_oversized_batches: ~ + grpc: + insecure: ~ otlp: protocol: ~ - url: ~ - + endpoint: ~ + compression: ~ + http: + certificate_file: ~ + headers: ~ + grpc: + headers: ~ + insecure: ~ monitoring: enabled: true type: default diff --git a/src/bentoml/_internal/runner/runner_handle/remote.py b/src/bentoml/_internal/runner/runner_handle/remote.py index 8ef5506f301..32a3dbdc5df 100644 --- a/src/bentoml/_internal/runner/runner_handle/remote.py +++ 
b/src/bentoml/_internal/runner/runner_handle/remote.py @@ -22,7 +22,6 @@ if TYPE_CHECKING: import yarl - import aiohttp from aiohttp import BaseConnector from aiohttp.client import ClientSession @@ -114,7 +113,7 @@ def strip_query_params(url: yarl.URL) -> str: trace_configs=[ create_trace_config( # Remove all query params from the URL attribute on the span. - url_filter=strip_query_params, # type: ignore + url_filter=strip_query_params, tracer_provider=BentoMLContainer.tracer_provider.get(), ) ], diff --git a/src/bentoml/_internal/utils/lazy_loader.py b/src/bentoml/_internal/utils/lazy_loader.py index d369c79d61b..2793aa4ac26 100644 --- a/src/bentoml/_internal/utils/lazy_loader.py +++ b/src/bentoml/_internal/utils/lazy_loader.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import sys import types import typing as t @@ -24,18 +26,20 @@ class LazyLoader(types.ModuleType): def __init__( self, local_name: str, - parent_module_globals: t.Dict[str, t.Any], + parent_module_globals: dict[str, t.Any], name: str, - warning: t.Optional[str] = None, - exc_msg: t.Optional[str] = None, + warning: str | None = None, + exc_msg: str | None = None, + exc: type[Exception] = MissingDependencyException, ): self._local_name = local_name self._parent_module_globals = parent_module_globals self._warning = warning self._exc_msg = exc_msg - self._module: t.Optional[types.ModuleType] = None + self._exc = exc + self._module: types.ModuleType | None = None - super(LazyLoader, self).__init__(str(name)) + super().__init__(str(name)) def _load(self) -> types.ModuleType: """Load the module and insert it into the parent's globals.""" @@ -46,7 +50,7 @@ def _load(self) -> types.ModuleType: # The additional add to sys.modules ensures library is actually loaded. 
sys.modules[self._local_name] = module except ModuleNotFoundError as err: - raise MissingDependencyException(f"{self._exc_msg}") from err + raise self._exc(f"{self._exc_msg} (reason: {err})") from None # Emit a warning if one was specified if self._warning: diff --git a/tests/unit/_internal/configuration/conftest.py b/tests/unit/_internal/configuration/conftest.py new file mode 100644 index 00000000000..30943ed4144 --- /dev/null +++ b/tests/unit/_internal/configuration/conftest.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import typing as t +from typing import TYPE_CHECKING + +import pytest + +from bentoml._internal.configuration.containers import BentoMLConfiguration + +if TYPE_CHECKING: + from pathlib import Path + + from simple_di.providers import ConfigDictType + + +@pytest.fixture(scope="function", name="container_from_file") +def fixture_container_from_file(tmp_path: Path) -> t.Callable[[str], ConfigDictType]: + def inner(config: str) -> ConfigDictType: + path = tmp_path / "configuration.yaml" + path.write_text(config) + return BentoMLConfiguration(override_config_file=path.__fspath__()).to_dict() + + return inner + + +@pytest.fixture(scope="function", name="container_from_envvar") +def fixture_container_from_envvar(): + def inner(override: str) -> ConfigDictType: + return BentoMLConfiguration(override_config_values=override).to_dict() + + return inner diff --git a/tests/unit/_internal/test_configuration.py b/tests/unit/_internal/configuration/test_containers.py similarity index 51% rename from tests/unit/_internal/test_configuration.py rename to tests/unit/_internal/configuration/test_containers.py index 534882efacf..253857f8aa3 100644 --- a/tests/unit/_internal/test_configuration.py +++ b/tests/unit/_internal/configuration/test_containers.py @@ -1,113 +1,56 @@ from __future__ import annotations import typing as t -import logging from typing import TYPE_CHECKING import pytest from bentoml.exceptions import BentoMLConfigException -from 
bentoml._internal.configuration.containers import BentoMLConfiguration if TYPE_CHECKING: - from pathlib import Path - - from _pytest.logging import LogCaptureFixture from simple_di.providers import ConfigDictType -@pytest.fixture(scope="function", name="config_cls") -def fixture_config_cls(tmp_path: Path) -> t.Callable[[str], ConfigDictType]: - def inner(config: str) -> ConfigDictType: - path = tmp_path / "configuration.yaml" - path.write_text(config) - return BentoMLConfiguration(override_config_file=path.__fspath__()).as_dict() - - return inner - - -@pytest.mark.usefixtures("config_cls") -def test_backward_configuration( - config_cls: t.Callable[[str], ConfigDictType], caplog: LogCaptureFixture -): - OLD_CONFIG = """\ +@pytest.mark.usefixtures("container_from_file") +def test_validate_configuration(container_from_file: t.Callable[[str], ConfigDictType]): + CONFIG = """\ +version: 1 api_server: - max_request_size: 8624612341 - port: 5000 + http: host: 0.0.0.0 """ - with caplog.at_level(logging.WARNING): - bentoml_cfg = config_cls(OLD_CONFIG) - assert all( - i not in bentoml_cfg["api_server"] for i in ("max_request_size", "port", "host") - ) - assert "cors" not in bentoml_cfg["api_server"] - assert bentoml_cfg["api_server"]["http"]["host"] == "0.0.0.0" - assert bentoml_cfg["api_server"]["http"]["port"] == 5000 - - -@pytest.mark.usefixtures("config_cls") -def test_validate(config_cls: t.Callable[[str], ConfigDictType]): + config = container_from_file(CONFIG) + assert config["api_server"]["http"]["host"] == "0.0.0.0" + INVALID_CONFIG = """\ +version: 1 api_server: - host: localhost + cors: + max_age: 12345 """ with pytest.raises( BentoMLConfigException, match="Invalid configuration file was given:*" ): - config_cls(INVALID_CONFIG) + container_from_file(INVALID_CONFIG) -@pytest.mark.usefixtures("config_cls") -def test_backward_warning( - config_cls: t.Callable[[str], ConfigDictType], caplog: LogCaptureFixture +@pytest.mark.usefixtures("container_from_envvar") +def 
test_containers_from_envvar( + container_from_envvar: t.Callable[[str], ConfigDictType] ): - OLD_HOST = """\ -api_server: - host: 0.0.0.0 -""" - with caplog.at_level(logging.WARNING): - config_cls(OLD_HOST) - assert "field 'api_server.host' is deprecated" in caplog.text - caplog.clear() + envvar = 'api_server.http.host="127.0.0.1" api_server.http.port=5000' + config = container_from_envvar(envvar) + assert config["api_server"]["http"]["host"] == "127.0.0.1" + assert config["api_server"]["http"]["port"] == 5000 - OLD_PORT = """\ -api_server: - port: 4096 -""" - with caplog.at_level(logging.WARNING): - config_cls(OLD_PORT) - assert "field 'api_server.port' is deprecated" in caplog.text - caplog.clear() - OLD_MAX_REQUEST_SIZE = """\ -api_server: - max_request_size: 8624612341 -""" - with caplog.at_level(logging.WARNING): - config_cls(OLD_MAX_REQUEST_SIZE) - assert ( - "'api_server.max_request_size' is deprecated and has become obsolete." - in caplog.text - ) - caplog.clear() - - OLD_CORS = """\ -api_server: - cors: - enabled: false -""" - with caplog.at_level(logging.WARNING): - config_cls(OLD_CORS) - assert "field 'api_server.cors' is deprecated" in caplog.text - caplog.clear() - - -@pytest.mark.usefixtures("config_cls") +@pytest.mark.parametrize("version", [None, 1]) +@pytest.mark.usefixtures("container_from_file") def test_bentoml_configuration_runner_override( - config_cls: t.Callable[[str], ConfigDictType] + container_from_file: t.Callable[[str], ConfigDictType], version: int | None ): - OVERRIDE_RUNNERS = """\ + OVERRIDE_RUNNERS = f"""\ +{'version: %d' % version if version else ''} runners: batching: enabled: False @@ -133,7 +76,7 @@ def test_bentoml_configuration_runner_override( enabled: True """ - bentoml_cfg = config_cls(OVERRIDE_RUNNERS) + bentoml_cfg = container_from_file(OVERRIDE_RUNNERS) runner_cfg = bentoml_cfg["runners"] # test_runner_1 @@ -166,14 +109,16 @@ def test_bentoml_configuration_runner_override( assert test_runner_batching["resources"]["cpu"] 
== 4 # should use global -@pytest.mark.usefixtures("config_cls") -def test_runner_gpu_configuration(config_cls: t.Callable[[str], ConfigDictType]): +@pytest.mark.usefixtures("container_from_file") +def test_runner_gpu_configuration( + container_from_file: t.Callable[[str], ConfigDictType] +): GPU_INDEX = """\ runners: resources: nvidia.com/gpu: [1, 2, 4] """ - bentoml_cfg = config_cls(GPU_INDEX) + bentoml_cfg = container_from_file(GPU_INDEX) assert bentoml_cfg["runners"]["resources"] == {"nvidia.com/gpu": [1, 2, 4]} GPU_INDEX_WITH_STRING = """\ @@ -181,13 +126,13 @@ def test_runner_gpu_configuration(config_cls: t.Callable[[str], ConfigDictType]) resources: nvidia.com/gpu: "[1, 2, 4]" """ - bentoml_cfg = config_cls(GPU_INDEX_WITH_STRING) + bentoml_cfg = container_from_file(GPU_INDEX_WITH_STRING) # this behaviour can be confusing assert bentoml_cfg["runners"]["resources"] == {"nvidia.com/gpu": "[1, 2, 4]"} -@pytest.mark.usefixtures("config_cls") -def test_runner_timeouts(config_cls: t.Callable[[str], ConfigDictType]): +@pytest.mark.usefixtures("container_from_file") +def test_runner_timeouts(container_from_file: t.Callable[[str], ConfigDictType]): RUNNER_TIMEOUTS = """\ runners: timeout: 50 @@ -196,7 +141,7 @@ def test_runner_timeouts(config_cls: t.Callable[[str], ConfigDictType]): test_runner_2: resources: system """ - bentoml_cfg = config_cls(RUNNER_TIMEOUTS) + bentoml_cfg = container_from_file(RUNNER_TIMEOUTS) runner_cfg = bentoml_cfg["runners"] assert runner_cfg["timeout"] == 50 assert runner_cfg["test_runner_1"]["timeout"] == 100 diff --git a/tests/unit/_internal/configuration/test_helpers.py b/tests/unit/_internal/configuration/test_helpers.py new file mode 100644 index 00000000000..1e8c59e31d7 --- /dev/null +++ b/tests/unit/_internal/configuration/test_helpers.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import pytest + +from bentoml.exceptions import BentoMLConfigException +from 
bentoml._internal.configuration.helpers import flatten_dict +from bentoml._internal.configuration.helpers import rename_fields +from bentoml._internal.configuration.helpers import load_config_file +from bentoml._internal.configuration.helpers import is_valid_ip_address + +if TYPE_CHECKING: + from pathlib import Path + + from _pytest.logging import LogCaptureFixture + + +def test_flatten_dict(): + assert dict(flatten_dict({"a": 1, "b": {"c": 2, "d": {"e": 3}}})) == { + "a": 1, + "b.c": 2, + "b.d.e": 3, + } + + assert dict( + flatten_dict({"runners": {"iris_clf": {"nvidia.com/gpu": [0, 1]}}}) + ) == {'runners.iris_clf."nvidia.com/gpu"': [0, 1]} + + assert dict(flatten_dict({"a": 1, "b": 2}, sep="_")) == {"a": 1, "b": 2} + + +def test_rename_fields(caplog: LogCaptureFixture): + # first, if given field is not in the dictionary, nothing will happen + d = {"a": 1, "b": 2} + rename_fields(d, "c", "d") + assert "a" in d + + # second, if given field is in the dictionary, it will be renamed + d = {"api_server.port": 5000} + with caplog.at_level(logging.WARNING): + rename_fields(d, "api_server.port", "api_server.http.port") + assert ( + "Field 'api_server.port' is deprecated and has been renamed to 'api_server.http.port'" + in caplog.text + ) + assert "api_server.http.port" in d and d["api_server.http.port"] == 5000 + caplog.clear() + + # third, if given field is in the dictionary, and remove_only is True, it will be + # removed. + d = {"api_server.port": 5000} + with caplog.at_level(logging.WARNING): + rename_fields(d, "api_server.port", remove_only=True) + + assert "Field 'api_server.port' is deprecated and will be removed." 
in caplog.text + assert len(d) == 0 and "api_server.port" not in d + caplog.clear() + + with pytest.raises(AssertionError, match="'replace_with' must be provided."): + # fourth, if no replace_with field is given, an AssertionError will be raised + d = {"api_server.port": 5000} + rename_fields(d, "api_server.port") + with pytest.raises(ValueError, match="Given dictionary is not flattened. *"): + # fifth, if the given dictionary is not flattened, a ValueError will be raised + d = {"a": 1, "b": {"c": 2}} + rename_fields(d, "b.c", "b.d.c") + + +def test_load_config_file(tmp_path: Path): + config = tmp_path / "configuration.yaml" + config.write_text("api_server:\n port: 5000") + assert load_config_file(config.__fspath__()) == {"api_server": {"port": 5000}} + + with pytest.raises(BentoMLConfigException) as e: + load_config_file("/tmp/nonexistent.yaml") + assert "Configuration file /tmp/nonexistent.yaml not found." in str(e.value) + + +def test_valid_ip_address(): + assert is_valid_ip_address("0.0.0.0") + assert not is_valid_ip_address("asdfadsf:143") diff --git a/tests/unit/_internal/configuration/test_v1_migration.py b/tests/unit/_internal/configuration/test_v1_migration.py new file mode 100644 index 00000000000..ad7eae29294 --- /dev/null +++ b/tests/unit/_internal/configuration/test_v1_migration.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +import typing as t +import logging +from typing import TYPE_CHECKING + +import pytest + +from bentoml.exceptions import BentoMLConfigException + +if TYPE_CHECKING: + from _pytest.logging import LogCaptureFixture + from simple_di.providers import ConfigDictType + + +@pytest.mark.usefixtures("container_from_file") +def test_backward_configuration( + container_from_file: t.Callable[[str], ConfigDictType], + caplog: LogCaptureFixture, +): + OLD_CONFIG = """\ +api_server: + max_request_size: 8624612341 + port: 5000 + host: 0.0.0.0 +""" + with caplog.at_level(logging.WARNING): + bentoml_cfg = container_from_file(OLD_CONFIG) 
+ assert all( + i not in bentoml_cfg["api_server"] for i in ("max_request_size", "port", "host") + ) + assert bentoml_cfg["api_server"]["http"]["host"] == "0.0.0.0" + assert bentoml_cfg["api_server"]["http"]["port"] == 5000 + assert ( + "Field 'api_server.max_request_size' is deprecated and will be removed." + in caplog.text + ) + + +@pytest.mark.usefixtures("container_from_envvar") +def test_backward_from_envvar( + container_from_envvar: t.Callable[[str], ConfigDictType], + caplog: LogCaptureFixture, +): + envvar = 'version=1 tracing.type="jaeger" tracing.jaeger.address="localhost" tracing.jaeger.port=6831 tracing.sample_rate=0.7' + with caplog.at_level(logging.WARNING): + bentoml_cfg = container_from_envvar(envvar) + assert bentoml_cfg["version"] == 1 + assert bentoml_cfg["tracing"]["exporter_type"] == "jaeger" + + assert ( + "Field 'tracing.jaeger.address' is deprecated and has been renamed to 'tracing.jaeger.thrift.agent_host_name'" + in caplog.text + ) + + +@pytest.mark.usefixtures("container_from_file") +def test_validate(container_from_file: t.Callable[[str], ConfigDictType]): + INVALID_CONFIG = """\ +api_server: + host: localhost +""" + with pytest.raises( + BentoMLConfigException, match="Invalid configuration file was given:*" + ): + container_from_file(INVALID_CONFIG) + + +@pytest.mark.usefixtures("container_from_file") +def test_backward_warning( + container_from_file: t.Callable[[str], ConfigDictType], + caplog: LogCaptureFixture, +): + OLD_HOST = """\ +api_server: + host: 0.0.0.0 +""" + with caplog.at_level(logging.WARNING): + container_from_file(OLD_HOST) + assert "Field 'api_server.host' is deprecated" in caplog.text + caplog.clear() + + OLD_PORT = """\ +api_server: + port: 4096 +""" + with caplog.at_level(logging.WARNING): + container_from_file(OLD_PORT) + assert "Field 'api_server.port' is deprecated" in caplog.text + caplog.clear() + + OLD_CORS = """\ +api_server: + cors: + enabled: false +""" + with caplog.at_level(logging.WARNING): + 
container_from_file(OLD_CORS) + assert "Field 'api_server.cors.enabled' is deprecated" in caplog.text + caplog.clear() + + OLD_JAEGER_ADDRESS = """\ +tracing: + type: jaeger + jaeger: + address: localhost +""" + with caplog.at_level(logging.WARNING): + container_from_file(OLD_JAEGER_ADDRESS) + assert "Field 'tracing.jaeger.address' is deprecated" in caplog.text + caplog.clear() + + OLD_JAEGER_PORT = """\ +tracing: + type: jaeger + jaeger: + port: 6881 +""" + with caplog.at_level(logging.WARNING): + container_from_file(OLD_JAEGER_PORT) + assert "Field 'tracing.jaeger.port' is deprecated" in caplog.text + caplog.clear() + + OLD_ZIPKIN_URL = """\ +tracing: + type: zipkin + zipkin: + url: localhost:6881 +""" + with caplog.at_level(logging.WARNING): + container_from_file(OLD_ZIPKIN_URL) + assert ( + "Field 'tracing.zipkin.url' is deprecated and has been renamed to 'tracing.zipkin.endpoint'" + in caplog.text + ) + caplog.clear() + + OLD_OTLP_URL = """\ +tracing: + type: otlp + otlp: + url: localhost:6881 +""" + with caplog.at_level(logging.WARNING): + container_from_file(OLD_OTLP_URL) + assert ( + "Field 'tracing.otlp.url' is deprecated and has been renamed to 'tracing.otlp.endpoint'" + in caplog.text + ) + caplog.clear() diff --git a/typings/schema.pyi b/typings/schema.pyi index 737797ae542..9b07a0fdb06 100644 --- a/typings/schema.pyi +++ b/typings/schema.pyi @@ -1,22 +1,16 @@ from types import BuiltinFunctionType, FunctionType -from typing import Any, Callable, Dict, List -from typing import Literal as LiteralType -from typing import Optional as OptionalType -from typing import Tuple, Union, overload +import typing as t +from typing import overload -GenericType = Union[ - "Schema", "And", "Or", "Use", "Optional", "Regex", "Literal", "Const" -] -AcceptedDictType = Dict[Union[str, GenericType], Any] -_CallableLike = Union[FunctionType, BuiltinFunctionType, Callable[..., Any]] +OpsType = Schema | And | Or | Use | Optional | Regex | Literal | Const +AcceptedDictType = 
dict[str | OpsType, t.Any] +_CallableLike = FunctionType | BuiltinFunctionType | t.Callable[..., t.Any] +_SchemaLike = _CallableLike | OpsType class SchemaError(Exception): - @overload - def __init__(self, autos: str, errors: str = ...) -> None: ... - @overload - def __init__(self, autos: str, errors: List[str] = ...) -> None: ... - @overload - def __init__(self, autos: List[str], errors: None = ...) -> None: ... + def __init__( + self, autos: str | list[str], errors: str | list[str] | None = ... + ) -> None: ... @property def code(self) -> str: ... @@ -26,90 +20,74 @@ class SchemaOnlyOneAllowedError(SchemaError): ... class SchemaForbiddenKeyError(SchemaError): ... class SchemaUnexpectedTypeError(SchemaError): ... -class And: +class OpsMeta(t.Protocol): @overload - def __init__( - self, - *args: _CallableLike, - error: List[str] = ..., - schema: "Schema" = ..., - ignore_extra_keys: bool = ... - ) -> None: ... + def validate(self, data: AcceptedDictType, **kwargs: t.Any) -> AcceptedDictType: ... @overload + def validate(self, data: t.Any, **kwargs: t.Any) -> t.Any: ... + +class And(OpsMeta): def __init__( self, - *args: _CallableLike, - error: List[str] = ..., - schema: None = ..., + *args: _SchemaLike, + error: list[str] = ..., + schema: Schema | None = ..., ignore_extra_keys: bool = ... ) -> None: ... def __repr__(self) -> str: ... @property - def args(self) -> Tuple[_CallableLike, ...]: ... - def validate(self, data: Any) -> Any: ... + def args(self) -> tuple[_CallableLike, ...]: ... class Or(And): - @overload def __init__( self, - *args: _CallableLike, - error: List[str] = ..., + *args: _SchemaLike | t.MutableSequence[t.Any] | None, + error: list[str] = ..., schema: None = ..., ignore_extra_keys: bool = ..., only_one: bool = ... ) -> None: ... - @overload - def __init__( - self, - *args: _CallableLike, - error: List[str] = ..., - schema: "Schema" = ..., - ignore_extra_keys: bool = ..., - only_one: bool = ... - ) -> None: ... def reset(self) -> None: ... 
- def validate(self, data: Any) -> Any: ... -class Regex: - NAMES: List[str] = ... +class Regex(OpsMeta): + NAMES: list[str] = ... + def __init__( self, pattern_str: str, - flags: OptionalType[int] = ..., - error: OptionalType[str] = ..., + flags: int | None = ..., + error: str | None = ..., ) -> None: ... def __repr__(self) -> str: ... @property def pattern_str(self) -> str: ... - def validate(self, data: str) -> str: ... -class Use: - def __init__( - self, callable_: _CallableLike, error: OptionalType[str] = ... - ) -> None: ... +class Use(OpsMeta): + def __init__(self, callable_: _CallableLike, error: str | None = ...) -> None: ... def __repr__(self) -> str: ... - def validate(self, data: Any) -> Any: ... -COMPARABLE = LiteralType["0"] -CALLABLE = LiteralType["1"] -VALIDATOR = LiteralType["2"] -TYPE = LiteralType["3"] -DICT = LiteralType["4"] -ITERABLE = LiteralType["5"] +COMPARABLE = t.Literal["0"] +CALLABLE = t.Literal["1"] +VALIDATOR = t.Literal["2"] +TYPE = t.Literal["3"] +DICT = t.Literal["4"] +ITERABLE = t.Literal["5"] def _priority( - s: object, -) -> Union[CALLABLE, COMPARABLE, VALIDATOR, TYPE, DICT, ITERABLE]: ... -def _invoke_with_optional_kwargs(f: Callable[..., Any], **kwargs: Any) -> Any: ... + s: OpsType, +) -> CALLABLE | COMPARABLE | VALIDATOR | TYPE | DICT | ITERABLE: ... +def _invoke_with_optional_kwargs( + f: t.Callable[..., t.Any], **kwargs: t.Any +) -> t.Any: ... -class Schema: +class Schema(OpsMeta): def __init__( self, - schema: Union[Schema, AcceptedDictType], - error: OptionalType[str] = ..., - ignore_extra_keys: OptionalType[bool] = ..., - name: OptionalType[str] = ..., - description: OptionalType[str] = ..., + schema: _SchemaLike | type | AcceptedDictType, + error: str | None = ..., + ignore_extra_keys: bool | None = ..., + name: str | None = ..., + description: str | None = ..., as_reference: bool = ..., ) -> None: ... def __repr__(self) -> str: ... @@ -122,87 +100,55 @@ class Schema: @property def ignore_extra_keys(self) -> bool: ... 
@staticmethod - def _dict_key_priority(s: Union[GenericType, object]) -> Union[float, int]: ... + def _dict_key_priority(s: OpsType) -> float | int: ... @staticmethod - def _is_optional_type(s: Union[GenericType, object]) -> bool: ... - def is_valid(self, data: Any, **kwargs: Any) -> bool: ... + def _is_optional_type(s: OpsType) -> bool: ... + def is_valid(self, data: t.Any, **kwargs: t.Any) -> bool: ... def _prepend_schema_name(self, message: str) -> str: ... - @overload - def validate(self, data: AcceptedDictType) -> AcceptedDictType: ... - @overload - def validate(self, data: Any) -> Any: ... - def json_schema(self, schema_id: Any, use_refs: bool = ...) -> Dict[str, Any]: ... + def json_schema(self, schema_id: str, use_refs: bool = ...) -> dict[str, t.Any]: ... class Optional(Schema): _MARKER: object = ... default: object key: str - @overload - def __init__( - self, - schema: Union[Schema, AcceptedDictType], - error: Union[str, List[str]] = ..., - ignore_extra_keys: OptionalType[bool] = ..., - name: OptionalType[str] = ..., - description: OptionalType[str] = ..., - as_reference: OptionalType[bool] = ..., - default: OptionalType[Any] = ..., - ) -> None: ... - @overload + def __init__( self, - schema: str, - error: Union[str, List[str]] = ..., - ignore_extra_keys: OptionalType[bool] = ..., - name: OptionalType[str] = ..., - description: OptionalType[str] = ..., - as_reference: OptionalType[bool] = ..., - default: OptionalType[Any] = ..., + schema: _SchemaLike | type | str, + error: str | list[str] | None = ..., + ignore_extra_keys: bool | None = ..., + name: str | None = ..., + description: str | None = ..., + as_reference: bool | None = ..., + default: t.Any = ..., ) -> None: ... def __hash__(self) -> int: ... def __eq__(self, other: Optional) -> bool: ... def reset(self) -> None: ... 
-_HookCallback = Callable[[str, AcceptedDictType, str], SchemaError] +_HookCallback = t.Callable[[str, t.Any, str | list[str]], SchemaError | None | t.Any] class Hook(Schema): key: Schema - def __init__( - self, - schema: Union[Schema, AcceptedDictType], - error: OptionalType[str] = ..., - ignore_extra_keys: OptionalType[bool] = ..., - name: OptionalType[str] = ..., - description: OptionalType[str] = ..., - as_reference: OptionalType[bool] = ..., - handler: OptionalType[_HookCallback] = ..., - ) -> None: ... -class Forbidden(Hook): - handler: Callable[[str, AcceptedDictType, str], SchemaForbiddenKeyError] def __init__( self, - schema: AcceptedDictType, - error: OptionalType[str] = ..., - ignore_extra_keys: OptionalType[bool] = ..., - name: OptionalType[str] = ..., - description: OptionalType[str] = ..., - as_reference: OptionalType[bool] = ..., + schema: _SchemaLike | str, + error: str | list[str] | None = ..., + ignore_extra_keys: bool | None = ..., + name: str | None = ..., + description: str | None = ..., + as_reference: bool | None = ..., + handler: _HookCallback | None = ..., ) -> None: ... - @staticmethod - def _default_function(nkey: str, data: Any, error: Exception) -> None: ... -class Literal(object): - def __init__(self, value: str, description: OptionalType[str] = ...) -> None: ... - def __str__(self) -> str: ... - def __repr__(self) -> str: ... +class Forbidden(Hook): ... + +class Literal: + def __init__(self, value: str, description: str | None = ...) -> None: ... @property def description(self) -> str: ... @property def schema(self) -> str: ... -class Const(Schema): - def validate(self, data: Union[Schema, AcceptedDictType]) -> AcceptedDictType: ... - -def _callable_str(callable_: Callable[..., Any]) -> str: ... -def _plural_s(sized: str) -> str: ... +class Const(Schema): ... 
diff --git a/typings/schema/__init__.pyi b/typings/schema/__init__.pyi deleted file mode 100644 index c97b783d716..00000000000 --- a/typings/schema/__init__.pyi +++ /dev/null @@ -1,26 +0,0 @@ -import typing as t - -_Validator = t.Union[type, t.List[type], t.Callable[[t.Any], bool], "_Operater", None] - -class _Operater: - def __init__(self, *args: _Validator, **kwargs: t.Any) -> None: ... - -class And(_Operater): ... -class Or(_Operater): ... - -class Use(_Operater): - def __init__(self, transformer: t.Callable[[t.Any], t.Any]) -> None: ... - -class Optional: - def __init__(self, key: str, *args: t.Any, **kwargs: t.Any) -> None: ... - -class Schema: - def __init__( - self, - mapping: t.Dict[t.Union[Optional, str], t.Any], - ) -> None: ... - def validate(self, data: t.Any) -> t.Any: ... - -class SchemaError(Exception): ... - -__all__ = ["And", "Or", "Schema", "SchemaError", "Use", "Optional"]