Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: issue overriding default service config from config file #4627

Merged
merged 3 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions src/bentoml/_internal/configuration/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def __init__(
if override_defaults:
if migration is not None:
override_defaults = migration(
default_config=self.config,
override_config=dict(flatten_dict(override_defaults)),
)
config_merger.merge(self.config, override_defaults)
Expand All @@ -85,7 +84,6 @@ def __init__(
# Running migration layer if it exists
if migration is not None:
override = migration(
default_config=self.config,
override_config=dict(flatten_dict(override)),
)
config_merger.merge(self.config, override)
Expand All @@ -97,7 +95,6 @@ def __init__(
# Running migration layer if it exists
if migration is not None:
override_config_json = migration(
default_config=self.config,
override_config=dict(flatten_dict(override_config_json)),
)
config_merger.merge(self.config, override_config_json)
Expand All @@ -122,9 +119,7 @@ def __init__(
}
# Running migration layer if it exists
if migration is not None:
override_config_map = migration(
default_config=self.config, override_config=override_config_map
)
override_config_map = migration(override_config=override_config_map)
# Previous behaviour, before configuration versioning.
try:
override = unflatten(override_config_map)
Expand All @@ -133,6 +128,9 @@ def __init__(
f"Failed to parse config options from the env var:\n{e}.\n*** Note: You can use '\"' to quote the key if it contains special characters. ***"
) from None
config_merger.merge(self.config, override)

if finalize_config := getattr(spec_module, "finalize_config", None):
finalize_config(self.config)
expand_env_var_in_values(self.config)

if validate_schema:
Expand Down
42 changes: 18 additions & 24 deletions src/bentoml/_internal/configuration/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import re
import typing as t
from copy import deepcopy
from numbers import Real

import schema as s
Expand All @@ -13,14 +14,11 @@
from ..helpers import ensure_larger_than
from ..helpers import ensure_larger_than_zero
from ..helpers import ensure_range
from ..helpers import flatten_dict
from ..helpers import is_valid_ip_address
from ..helpers import rename_fields
from ..helpers import validate_otlp_protocol
from ..helpers import validate_tracing_type

__all__ = ["SCHEMA", "migration"]

TRACING_CFG = {
"exporter_type": s.Or(s.And(str, s.Use(str.lower), validate_tracing_type), None),
"sample_rate": s.Or(s.And(float, ensure_range(0, 1)), None),
Expand Down Expand Up @@ -194,7 +192,7 @@
)


def migration(*, default_config: dict[str, t.Any], override_config: dict[str, t.Any]):
def migration(*, override_config: dict[str, t.Any]):
# We will use a flattened config to make it easier to migrate,
# Then we will convert it back to a nested config.
if depth(override_config) > 1:
Expand Down Expand Up @@ -310,8 +308,13 @@ def migration(*, default_config: dict[str, t.Any], override_config: dict[str, t.
replace_with=f"runners.{runner_name}.traffic.timeout",
)

return unflatten(override_config)


def finalize_config(config: dict[str, t.Any]) -> None:
from ..containers import config_merger

# 8. if runner is overriden, set the runner default values
default_runner_config = dict(flatten_dict(default_config["runners"]))
RUNNER_CFG_KEYS = [
"batching",
"resources",
Expand All @@ -320,23 +323,14 @@ def migration(*, default_config: dict[str, t.Any], override_config: dict[str, t.
"traffic",
"workers_per_resource",
]
default_runner_config: dict[str, t.Any] = {}
for runner_name, runner_cfg in default_config["runners"].items():
if runner_name in RUNNER_CFG_KEYS:
default_runner_config[runner_name] = runner_cfg
default_runner_config: dict[str, t.Any] = {
key: value for key, value in config["runners"].items() if key in RUNNER_CFG_KEYS
}

for key in list(override_config):
if key.startswith("runners."):
key_parts = key.split(".")
runner_name = key_parts[1]
if runner_name in RUNNER_CFG_KEYS:
default_runner_config[".".join(key_parts[1:])] = override_config[key]
for i in range(2, len(key_parts)):
if (k := ".".join(key_parts[1:i])) in default_runner_config:
del default_runner_config[k]
else:
if runner_name not in default_config["runners"].keys():
default_config["runners"][runner_name] = unflatten(
default_runner_config
)
return unflatten(override_config)
for runner_name, runner_cfg in config["runners"].items():
if runner_name in RUNNER_CFG_KEYS:
continue
# key is a runner name
config["runners"][runner_name] = config_merger.merge(
deepcopy(default_runner_config), runner_cfg
)
45 changes: 19 additions & 26 deletions src/bentoml/_internal/configuration/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,10 @@
from ..helpers import ensure_larger_than
from ..helpers import ensure_larger_than_zero
from ..helpers import ensure_range
from ..helpers import flatten_dict
from ..helpers import is_valid_ip_address
from ..helpers import validate_otlp_protocol
from ..helpers import validate_tracing_type

__all__ = ["SCHEMA", "migration"]

TRACING_CFG = {
"exporter_type": s.Or(s.And(str, s.Use(str.lower), validate_tracing_type), None),
"sample_rate": s.Or(s.And(float, ensure_range(0, 1)), None),
Expand Down Expand Up @@ -187,7 +184,7 @@
)


def migration(*, default_config: dict[str, t.Any], override_config: dict[str, t.Any]):
def migration(*, override_config: dict[str, t.Any]):
# We will use a flattened config to make it easier to migrate,
# Then we will convert it back to a nested config.
if depth(override_config) > 1:
Expand All @@ -196,6 +193,12 @@ def migration(*, default_config: dict[str, t.Any], override_config: dict[str, t.
if "version" not in override_config:
override_config["version"] = 2

return unflatten(override_config)


def finalize_config(config: dict[str, t.Any]) -> dict[str, t.Any]:
from ..containers import config_merger

SERVICE_CFG_KEYS = [
"batching",
"resources",
Expand All @@ -212,26 +215,16 @@ def migration(*, default_config: dict[str, t.Any], override_config: dict[str, t.
"monitoring",
"tracing",
]
default_service_config: dict[str, t.Any] = {}
for svc, svc_cfg in default_config["services"].items():
if svc in SERVICE_CFG_KEYS:
default_service_config[svc] = svc_cfg
default_service_config = dict(flatten_dict(default_service_config))

for key in list(override_config):
if key.startswith("services."):
# NOTE: We need to remove the quotation in case the runner name includes dashes.
# Since unflatten_dict will include the quotes for given name
key_parts = [s.replace('"', "") for s in key.split(".")]
service_name = key_parts[1]
if service_name in SERVICE_CFG_KEYS:
default_service_config[".".join(key_parts[1:])] = override_config[key]
for i in range(2, len(key_parts)):
if (k := ".".join(key_parts[1:i])) in default_service_config:
del default_service_config[k]
else:
if service_name not in default_config["services"].keys():
default_config["services"][service_name] = unflatten(
default_service_config
)
return unflatten(override_config)
default_service_config = {
key: value
for key, value in config["services"].items()
if key in SERVICE_CFG_KEYS
}

for svc, service_config in config["services"].items():
if svc in SERVICE_CFG_KEYS:
continue
config["services"][svc] = config_merger.merge(
default_service_config, service_config
)
Loading