diff --git a/README.md b/README.md index c53d396..6c6dd31 100644 --- a/README.md +++ b/README.md @@ -56,17 +56,12 @@ Currently, those formats are supported: - TOML - requires the `toml` extra The format is automatically inferred from the config file extension. -You can still specify it manually using `Config.FORMAT`, for custom ones or configuring dump options. -Other formats can be added by subclassing either `Format` or `ReadOnlyFormat`. -To support more formats: -```python -from zenconfig import Config +Other formats can be added by subclassing `Format`. -Config.FORMATS.append(MyFormat) -# or -Config.FORMATS = [MyFormat] -``` +To register more formats: `Config.register_format(MyFormat)`. + +You can also force the format using `Config.FORMAT = MyFormat(...)`. This can be used to disable auto selection, or pass parameters to the format. ## Supported schemas Currently, those schemas are supported: @@ -75,8 +70,13 @@ Currently, those schemas are supported: - pydantic models - requires the `pydantic` extra -The format is automatically inferred from the config class. -You can still specify it manually using `Config.SCHEMA`, for custom ones or configuring dump options. +The schema is automatically inferred from the config class. + +Other schemas can be added by subclassing `Schema`. + +To register more schemas: `Config.register_schema(MySchema)`. + +You can also force the schema using `Config.SCHEMA = MySchema(...)`. This can be used to disable auto selection, or pass parameters to the schema. To use pydantic: ```python @@ -91,14 +91,5 @@ class MyPydanticConfig(Config, BaseModel): > to all class variable you override > otherwise pydantic will treat those as its own fields and complain. -To support more schemas: -```python -from zenconfig import Config - -Config.SCHEMAS.append(MySchema) -# or -Config.SCHEMAS = [MySchema] -``` - ## Contributing See [contributing guide](https://github.com/gpajot/zen-config/blob/main/CONTRIBUTING.md). diff --git a/pyproject.toml b/pyproject.toml index d6fe403..91d0bf8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "zenconfig" -version = "1.0.0" +version = "1.1.0" description = "Simple configuration loader for python." authors = ["Gabriel Pajot "] license = "MIT" diff --git a/tests/test_base.py b/tests/test_base.py index d0c7304..2679697 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -3,9 +3,7 @@ import pytest -from zenconfig.base import BaseConfig -from zenconfig.formats.abc import Format -from zenconfig.schemas.abc import Schema +from zenconfig.base import BaseConfig, Format, Schema class TestBaseConfig: diff --git a/tests/test_config_file.json b/tests/test_config_file.json new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_read.py b/tests/test_read.py index 9ff8d51..14ba90d 100644 --- a/tests/test_read.py +++ b/tests/test_read.py @@ -2,9 +2,8 @@ import pytest -from zenconfig.formats.abc import Format +from zenconfig.base import Format, Schema from zenconfig.read import ReadOnlyConfig -from zenconfig.schemas.abc import Schema class TestReadOnlyConfig: diff --git a/tests/test_write.py b/tests/test_write.py index faf941a..c63bcd1 100644 --- a/tests/test_write.py +++ b/tests/test_write.py @@ -2,8 +2,7 @@ import pytest -from zenconfig.formats.abc import Format -from zenconfig.schemas.abc import Schema +from zenconfig.base import Format, Schema from zenconfig.write import Config diff --git a/zenconfig/__init__.py b/zenconfig/__init__.py index b9fe0e7..5217606 100644 --- a/zenconfig/__init__.py +++ b/zenconfig/__init__.py @@ -1,2 +1,5 @@ +import zenconfig.formats +import zenconfig.schemas +from zenconfig.base import Format, Schema from zenconfig.read import ReadOnlyConfig from zenconfig.write import Config diff --git a/zenconfig/base.py b/zenconfig/base.py index 01c4461..21dad6b 100644 --- a/zenconfig/base.py +++ b/zenconfig/base.py @@ -1,12 +1,40 @@ import os -from abc import ABC +from abc import ABC, abstractmethod from pathlib import Path -from typing import ClassVar, List, Union +from typing import Any, ClassVar, Dict, Generic, List, Type, TypeVar, Union -from zenconfig.formats.abc import Format -from zenconfig.formats.selectors import FormatSelector, format_selectors -from zenconfig.schemas.abc import Schema -from zenconfig.schemas.selectors import SchemaSelector, schema_selectors + +class Format(ABC): + @classmethod + @abstractmethod + def handles(cls, path: Path) -> bool: + """Return whether the format handles the extension.""" + + @abstractmethod + def load(self, path: Path) -> Dict[str, Any]: + """Load the configuration file into a dict.""" + + @abstractmethod + def dump(self, path: Path, config: Dict[str, Any]) -> None: + """Dump in the configuration file.""" + + +C = TypeVar("C") + + +class Schema(ABC, Generic[C]): + @classmethod + @abstractmethod + def handles(cls, config_class: type) -> bool: + """Return whether the schema handles the config.""" + + @abstractmethod + def from_dict(self, cls: Type[C], cfg: Dict[str, Any]) -> C: + """Load the schema based on a dict configuration.""" + + @abstractmethod + def to_dict(self, config: Any) -> Dict[str, Any]: + """Dump the config to dict.""" class BaseConfig(ABC): @@ -14,12 +42,20 @@ class BaseConfig(ABC): PATH: ClassVar[Union[str, None]] = None _PATH: ClassVar[Union[Path, None]] = None - FORMATS: ClassVar[List[FormatSelector]] = format_selectors + FORMATS: ClassVar[List[Type[Format]]] = [] FORMAT: ClassVar[Union[Format, None]] = None - SCHEMAS: ClassVar[List[SchemaSelector]] = schema_selectors + SCHEMAS: ClassVar[List[Type[Schema]]] = [] SCHEMA: ClassVar[Union[Schema, None]] = None + @classmethod + def register_format(cls, format_class: Type[Format]) -> None: + cls.FORMATS.append(format_class) + + @classmethod + def register_schema(cls, schema_class: Type[Schema]) -> None: + cls.SCHEMAS.append(schema_class) + @classmethod def _path(cls) -> Path: if cls._PATH: @@ -41,24 +77,25 @@ def _format(cls) -> Format: if cls.FORMAT: return cls.FORMAT ext = cls._path().suffix - for selector in cls.FORMATS: - fmt = selector(ext) - if fmt: - cls.FORMAT = fmt - return fmt + path = cls._path() + for format_class in cls.FORMATS: + if not format_class.handles(path): + continue + cls.FORMAT = format_class() + return cls.FORMAT raise ValueError( - f"unsupported config file extension {ext} for config {cls.__qualname__}" + f"unsupported config file extension {ext} for config {cls.__qualname__}, maybe you are missing an extra" ) @classmethod def _schema(cls) -> Schema: if cls.SCHEMA: return cls.SCHEMA - for selector in cls.SCHEMAS: - schema = selector(cls) - if schema: - cls.SCHEMA = schema - return schema + for schema_class in cls.SCHEMAS: + if not schema_class.handles(cls): + continue + cls.SCHEMA = schema_class() + return cls.SCHEMA raise ValueError( f"could not infer config schema for config {cls.__qualname__}, maybe you are missing an extra" ) diff --git a/zenconfig/formats/__init__.py b/zenconfig/formats/__init__.py index e69de29..373bebe 100644 --- a/zenconfig/formats/__init__.py +++ b/zenconfig/formats/__init__.py @@ -0,0 +1,8 @@ +import contextlib + +from zenconfig.formats.json import JSONFormat + +with contextlib.suppress(ImportError): + from zenconfig.formats.yaml import YAMLFormat +with contextlib.suppress(ImportError): + from zenconfig.formats.toml import TOMLFormat diff --git a/zenconfig/formats/abc.py b/zenconfig/formats/abc.py deleted file mode 100644 index e0e4d23..0000000 --- a/zenconfig/formats/abc.py +++ /dev/null @@ -1,13 +0,0 @@ -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Any, Dict - - -class Format(ABC): - @abstractmethod - def load(self, path: Path) -> Dict[str, Any]: - """Load the configuration file into a dict.""" - - @abstractmethod - def dump(self, path: Path, config: Dict[str, Any]) -> None: - """Dump in the configuration file.""" diff --git a/zenconfig/formats/json.py b/zenconfig/formats/json.py index b1bb0bb..b37002c 100644 --- a/zenconfig/formats/json.py +++ b/zenconfig/formats/json.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import Any, Dict -from zenconfig.formats.abc import Format +from zenconfig.base import BaseConfig, Format @dataclass @@ -12,6 +12,10 @@ class JSONFormat(Format): sort_keys: bool = True ensure_ascii: bool = False + @classmethod + def handles(cls, path: Path) -> bool: + return path.suffix == ".json" + def load(self, path: Path) -> Dict[str, Any]: return json.loads(path.read_text()) @@ -24,3 +28,6 @@ def dump(self, path: Path, config: Dict[str, Any]) -> None: ensure_ascii=self.ensure_ascii, ), ) + + +BaseConfig.register_format(JSONFormat) diff --git a/zenconfig/formats/selectors.py b/zenconfig/formats/selectors.py deleted file mode 100644 index 275ec10..0000000 --- a/zenconfig/formats/selectors.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import Callable, List, Optional - -from zenconfig.formats.abc import Format - -FormatSelector = Callable[[str], Optional[Format]] - - -def _json_selector(ext: str) -> Optional[Format]: - if ext != ".json": - return None - from zenconfig.formats.json import JSONFormat - - return JSONFormat() - - -def _yaml_selector(ext: str) -> Optional[Format]: - if ext not in {".yml", ".yaml"}: - return None - try: - from zenconfig.formats.yaml import YAMLFormat - - return YAMLFormat() - except ImportError: - raise ValueError("yaml config is not supported, install the yaml extra") - - -def _toml_selector(ext: str) -> Optional[Format]: - if ext != ".toml": - return None - try: - from zenconfig.formats.toml import TOMLFormat - - return TOMLFormat() - except ImportError: - raise ValueError("toml config is not supported, install the toml extra") - - -format_selectors: List[FormatSelector] = [ - _json_selector, - _yaml_selector, - _toml_selector, -] diff --git a/zenconfig/formats/toml.py b/zenconfig/formats/toml.py index 39ccbf5..a024380 100644 --- a/zenconfig/formats/toml.py +++ b/zenconfig/formats/toml.py @@ -5,15 +5,27 @@ import tomli import tomli_w -from zenconfig.formats.abc import Format +from zenconfig.base import BaseConfig, Format @dataclass class TOMLFormat(Format): multiline_strings: bool = True + @classmethod + def handles(cls, path: Path) -> bool: + return path.suffix == ".toml" + def load(self, path: Path) -> Dict[str, Any]: return tomli.loads(path.read_text()) def dump(self, path: Path, config: Dict[str, Any]) -> None: - path.write_text(tomli_w.dumps(config, multiline_strings=self.multiline_strings)) + path.write_text( + tomli_w.dumps( + config, + multiline_strings=self.multiline_strings, + ) + ) + + +BaseConfig.register_format(TOMLFormat) diff --git a/zenconfig/formats/yaml.py b/zenconfig/formats/yaml.py index f85d5a3..d3be8ee 100644 --- a/zenconfig/formats/yaml.py +++ b/zenconfig/formats/yaml.py @@ -4,7 +4,7 @@ import yaml -from zenconfig.formats.abc import Format +from zenconfig.base import BaseConfig, Format @dataclass @@ -12,6 +12,10 @@ class YAMLFormat(Format): indent: int = 2 sort_keys: bool = True + @classmethod + def handles(cls, path: Path) -> bool: + return path.suffix in {".yml", ".yaml"} + def load(self, path: Path) -> Dict[str, Any]: return yaml.safe_load(path.read_text()) @@ -19,3 +23,6 @@ def dump(self, path: Path, config: Dict[str, Any]) -> None: path.write_text( yaml.safe_dump(config, indent=self.indent, sort_keys=self.sort_keys) ) + + +BaseConfig.register_format(YAMLFormat) diff --git a/zenconfig/schemas/__init__.py b/zenconfig/schemas/__init__.py index e69de29..bc2fd69 100644 --- a/zenconfig/schemas/__init__.py +++ b/zenconfig/schemas/__init__.py @@ -0,0 +1,7 @@ +import contextlib + +from zenconfig.schemas.dataclass import DataclassSchema +from zenconfig.schemas.dict import DictSchema + +with contextlib.suppress(ImportError): + from zenconfig.schemas.pydantic import PydanticSchema diff --git a/zenconfig/schemas/abc.py b/zenconfig/schemas/abc.py deleted file mode 100644 index af4d701..0000000 --- a/zenconfig/schemas/abc.py +++ /dev/null @@ -1,14 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, Type, TypeVar - -C = TypeVar("C") - - -class Schema(ABC, Generic[C]): - @abstractmethod - def from_dict(self, cls: Type[C], cfg: Dict[str, Any]) -> C: - """Load the schema based on a dict configuration.""" - - @abstractmethod - def to_dict(self, config: Any) -> Dict[str, Any]: - """Dump the config to dict.""" diff --git a/zenconfig/schemas/dataclass.py b/zenconfig/schemas/dataclass.py index 96a5876..455cd7e 100644 --- a/zenconfig/schemas/dataclass.py +++ b/zenconfig/schemas/dataclass.py @@ -1,12 +1,16 @@ from dataclasses import asdict, fields, is_dataclass from typing import Any, Dict, Type, TypeVar -from zenconfig.schemas.abc import Schema +from zenconfig.base import BaseConfig, Schema C = TypeVar("C") class DataclassSchema(Schema[C]): + @classmethod + def handles(cls, config_class: type) -> bool: + return is_dataclass(config_class) + def from_dict(self, cls: Type[C], cfg: Dict[str, Any]) -> C: return _load_nested(cls, cfg) @@ -20,6 +24,9 @@ def to_dict(self, config: C) -> Dict[str, Any]: return cfg +BaseConfig.register_schema(DataclassSchema) + + def _load_nested(cls: Type[C], cfg: Dict[str, Any]) -> C: """Load nested dataclasses.""" kwargs: Dict[str, Any] = {} diff --git a/zenconfig/schemas/dict.py b/zenconfig/schemas/dict.py index f3308b8..23eadee 100644 --- a/zenconfig/schemas/dict.py +++ b/zenconfig/schemas/dict.py @@ -1,13 +1,20 @@ from typing import Any, Dict, Type, TypeVar -from zenconfig.schemas.abc import Schema +from zenconfig.base import BaseConfig, Schema C = TypeVar("C", bound=dict) class DictSchema(Schema[C]): + @classmethod + def handles(cls, config_class: type) -> bool: + return issubclass(config_class, dict) + def from_dict(self, cls: Type[C], cfg: Dict[str, Any]) -> C: return cls(cfg) def to_dict(self, config: C) -> Dict[str, Any]: return dict(config) + + +BaseConfig.register_schema(DictSchema) diff --git a/zenconfig/schemas/pydantic.py b/zenconfig/schemas/pydantic.py index b8c4dac..886757b 100644 --- a/zenconfig/schemas/pydantic.py +++ b/zenconfig/schemas/pydantic.py @@ -3,7 +3,7 @@ from pydantic import BaseModel -from zenconfig.schemas.abc import Schema +from zenconfig.base import BaseConfig, Schema C = TypeVar("C", bound=BaseModel) @@ -13,6 +13,10 @@ class PydanticSchema(Schema[BaseModel]): exclude_unset: bool = False exclude_defaults: bool = True + @classmethod + def handles(cls, config_class: type) -> bool: + return issubclass(config_class, BaseModel) + def from_dict(self, cls: Type[C], cfg: Dict[str, Any]) -> C: return cls.parse_obj(cfg) @@ -21,3 +25,6 @@ def to_dict(self, config: C) -> Dict[str, Any]: exclude_unset=self.exclude_unset, exclude_defaults=self.exclude_defaults, ) + + +BaseConfig.register_schema(PydanticSchema) diff --git a/zenconfig/schemas/selectors.py b/zenconfig/schemas/selectors.py deleted file mode 100644 index 4b88b28..0000000 --- a/zenconfig/schemas/selectors.py +++ /dev/null @@ -1,41 +0,0 @@ -from dataclasses import is_dataclass -from typing import Callable, List, Optional - -from zenconfig.schemas.abc import Schema - -SchemaSelector = Callable[[type], Optional[Schema]] - - -def _dataclass_selector(cls: type) -> Optional[Schema]: - if not is_dataclass(cls): - return None - from zenconfig.schemas.dataclass import DataclassSchema - - return DataclassSchema() - - -def _pydantic_selector(cls: type) -> Optional[Schema]: - try: - from pydantic import BaseModel - except ImportError: - return None - if not issubclass(cls, BaseModel): - return None - from zenconfig.schemas.pydantic import PydanticSchema - - return PydanticSchema() - - -def _dict_selector(cls: type) -> Optional[Schema]: - if not issubclass(cls, dict): - return None - from zenconfig.schemas.dict import DictSchema - - return DictSchema() - - -schema_selectors: List[SchemaSelector] = [ - _dataclass_selector, - _pydantic_selector, - _dict_selector, -]