From 0355c6adc1ce30abdba90597ff149389bca5827a Mon Sep 17 00:00:00 2001 From: Gabriel Pajot Date: Tue, 4 Apr 2023 08:47:42 +0200 Subject: [PATCH] feat!: improve typing BREAKING CHANGES: - add `handles` method on `Schema` - `register_schema` now only accepts the schema instance as args --- README.md | 14 ++++++-------- pyproject.toml | 2 +- zenconfig/base.py | 19 +++++++++---------- zenconfig/formats/yaml.py | 6 +++++- zenconfig/read.py | 3 ++- zenconfig/schemas/attrs.py | 6 +++++- zenconfig/schemas/dataclass.py | 29 +++++++++++++++++++++++------ zenconfig/schemas/dict.py | 7 ++++++- zenconfig/schemas/pydantic.py | 6 +++++- zenconfig/write.py | 8 ++------ 10 files changed, 64 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 4effabc..8d82b11 100644 --- a/README.md +++ b/README.md @@ -60,9 +60,7 @@ Currently, those formats are supported: The format is automatically inferred from the config file extension. When loading from multiple files, files can be of multiple formats. -Other formats can be added by subclassing `Format`. - -To register more formats: `Config.register_format(MyFormat(...), ".ext1", ".ext2")`. +Other formats can be added by subclassing `Format`: `Config.register_format(MyFormat(...), ".ext1", ".ext2")`. > 💡 You can re-register a format to change dumping options. @@ -71,13 +69,11 @@ Currently, those schemas are supported: - plain dict - dataclasses - pydantic models - requires the `pydantic` extra -- attrs - requires the attrs extra +- attrs - requires the `attrs` extra The schema is automatically inferred from the config class. -Other schemas can be added by subclassing `Schema`. - -To register more schemas: `Config.register_schema(MySchema(...), lambda cls: ...)`. +Other schemas can be added by subclassing `Schema`: `Config.register_schema(MySchema(...))`. You can also force the schema by directly overriding the `SCHEMA` class attribute on your config. This can be used to disable auto selection, or pass arguments to the schema instance. @@ -91,7 +87,9 @@ For all schemas and formats, common built in types are handled [when dumping](ht > ⚠️ Keep in mind that only `attrs` and `pydantic` support casting when loading the config. -You can add custom encoders with `Config.ENCODERS`. For `pydantic`, stick with [the standard way of doing it](https://pydantic-docs.helpmanual.io/usage/exporting_models/#json_encoders). +You can add custom encoders with `Config.ENCODERS`. +For `pydantic`, stick with [the standard way of doing it](https://pydantic-docs.helpmanual.io/usage/exporting_models/#json_encoders). + ## Contributing See [contributing guide](https://github.com/gpajot/zen-config/blob/main/CONTRIBUTING.md). diff --git a/pyproject.toml b/pyproject.toml index daa35ab..efe982e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "zenconfig" -version = "1.7.0" +version = "2.0.0" description = "Simple configuration loader for python." authors = ["Gabriel Pajot "] license = "MIT" diff --git a/zenconfig/base.py b/zenconfig/base.py index 537efd2..71332a8 100644 --- a/zenconfig/base.py +++ b/zenconfig/base.py @@ -3,7 +3,6 @@ from pathlib import Path from typing import ( Any, - Callable, ClassVar, Dict, Generic, @@ -44,6 +43,10 @@ def dump( class Schema(ABC, Generic[C]): """Abstract class for handling different config class types.""" + @abstractmethod + def handles(self, cls: type) -> bool: + """Return if a type is handled by this schema.""" + @abstractmethod def from_dict(self, cls: Type[C], cfg: Dict[str, Any]) -> C: """Load the schema based on a dict configuration.""" @@ -71,7 +74,7 @@ class BaseConfig(ABC): # All formats supported, by extension. __FORMATS: ClassVar[Dict[str, Format]] = {} # All schema classes supported. - __SCHEMAS: ClassVar[List[Tuple[Schema, Callable[[type], bool]]]] = [] + __SCHEMAS: ClassVar[List[Schema]] = [] @classmethod def register_format(cls, fmt: Format, *extensions: str) -> None: @@ -80,13 +83,9 @@ def register_format(cls, fmt: Format, *extensions: str) -> None: cls.__FORMATS[ext] = fmt @classmethod - def register_schema( - cls, - schema: Schema, - handles: Callable[[type], bool], - ) -> None: + def register_schema(cls, schema: Schema[C]) -> None: """Add a schema class to the list of supported ones.""" - cls.__SCHEMAS.append((schema, handles)) + cls.__SCHEMAS.append(schema) @classmethod def _paths(cls) -> Tuple[Path, ...]: @@ -130,8 +129,8 @@ def _schema(cls) -> Schema: """Get the schema instance for this config class.""" if cls.SCHEMA: return cls.SCHEMA - for schema, handles in cls.__SCHEMAS: - if not handles(cls): + for schema in cls.__SCHEMAS: + if not schema.handles(cls): continue cls.SCHEMA = schema return cls.SCHEMA diff --git a/zenconfig/formats/yaml.py b/zenconfig/formats/yaml.py index 290f0fd..584615a 100644 --- a/zenconfig/formats/yaml.py +++ b/zenconfig/formats/yaml.py @@ -21,7 +21,11 @@ def dump( config: Dict[str, Any], ) -> None: path.write_text( - yaml.safe_dump(config, indent=self.indent, sort_keys=self.sort_keys) + yaml.safe_dump( + config, + indent=self.indent, + sort_keys=self.sort_keys, + ), ) diff --git a/zenconfig/read.py b/zenconfig/read.py index 988ca0f..6c6f6ec 100644 --- a/zenconfig/read.py +++ b/zenconfig/read.py @@ -13,6 +13,7 @@ class MergeStrategy(IntEnum): SHALLOW = 1 DEEP = 2 + REPLACE = 3 C = TypeVar("C", bound="ReadOnlyConfig") @@ -39,7 +40,7 @@ def load(cls: Type[C]) -> C: path, ) config = fmt.load(path) - if not dict_config: + if not dict_config or cls.MERGE_STRATEGY is MergeStrategy.REPLACE: dict_config = config elif cls.MERGE_STRATEGY is MergeStrategy.SHALLOW: dict_config.update(config) diff --git a/zenconfig/schemas/attrs.py b/zenconfig/schemas/attrs.py index ea21c7e..ee695de 100644 --- a/zenconfig/schemas/attrs.py +++ b/zenconfig/schemas/attrs.py @@ -1,6 +1,7 @@ from typing import Any, Dict, Type, TypeVar import attrs +from typing_extensions import TypeGuard from zenconfig.base import BaseConfig, Schema from zenconfig.encoder import Encoder @@ -9,6 +10,9 @@ class AttrsSchema(Schema[C]): + def handles(self, cls: type) -> TypeGuard[Type[attrs.AttrsInstance]]: + return attrs.has(cls) + def from_dict(self, cls: Type[C], cfg: Dict[str, Any]) -> C: return _load_nested(cls, cfg) @@ -16,7 +20,7 @@ def to_dict(self, config: C, encoder: Encoder) -> Dict[str, Any]: return encoder(attrs.asdict(config)) -BaseConfig.register_schema(AttrsSchema(), attrs.has) +BaseConfig.register_schema(AttrsSchema()) def _load_nested(cls: Type[C], cfg: Dict[str, Any]) -> C: diff --git a/zenconfig/schemas/dataclass.py b/zenconfig/schemas/dataclass.py index 42bd62d..0ef388d 100644 --- a/zenconfig/schemas/dataclass.py +++ b/zenconfig/schemas/dataclass.py @@ -1,27 +1,44 @@ -from dataclasses import asdict, fields, is_dataclass -from typing import Any, Dict, Type, TypeVar +from dataclasses import Field, asdict, fields, is_dataclass +from typing import ( + Any, + ClassVar, + Dict, + Protocol, + Type, + TypeVar, +) + +from typing_extensions import TypeGuard from zenconfig.base import BaseConfig, Schema from zenconfig.encoder import Encoder -C = TypeVar("C") + +class DataclassInstance(Protocol): + __dataclass_fields__: ClassVar[Dict[str, Field]] + + +C = TypeVar("C", bound=DataclassInstance) class DataclassSchema(Schema[C]): + def handles(self, cls: type) -> TypeGuard[Type[DataclassInstance]]: + return is_dataclass(cls) + def from_dict(self, cls: Type[C], cfg: Dict[str, Any]) -> C: return _load_nested(cls, cfg) def to_dict(self, config: C, encoder: Encoder) -> Dict[str, Any]: - return encoder(asdict(config)) # type: ignore [call-overload] + return encoder(asdict(config)) -BaseConfig.register_schema(DataclassSchema(), is_dataclass) +BaseConfig.register_schema(DataclassSchema()) def _load_nested(cls: Type[C], cfg: Dict[str, Any]) -> C: """Load nested dataclasses.""" kwargs: Dict[str, Any] = {} - for field in fields(cls): # type: ignore [arg-type] + for field in fields(cls): if field.name not in cfg: continue value = cfg[field.name] diff --git a/zenconfig/schemas/dict.py b/zenconfig/schemas/dict.py index 3edae1d..7af408f 100644 --- a/zenconfig/schemas/dict.py +++ b/zenconfig/schemas/dict.py @@ -1,5 +1,7 @@ from typing import Any, Dict, Type, TypeVar +from typing_extensions import TypeGuard + from zenconfig.base import BaseConfig, Schema from zenconfig.encoder import Encoder @@ -7,6 +9,9 @@ class DictSchema(Schema[C]): + def handles(self, cls: type) -> TypeGuard[Type[dict]]: + return issubclass(cls, dict) + def from_dict(self, cls: Type[C], cfg: Dict[str, Any]) -> C: return cls(cfg) @@ -14,4 +19,4 @@ def to_dict(self, config: C, encoder: Encoder) -> Dict[str, Any]: return encoder(config) -BaseConfig.register_schema(DictSchema(), lambda cls: issubclass(cls, dict)) +BaseConfig.register_schema(DictSchema()) diff --git a/zenconfig/schemas/pydantic.py b/zenconfig/schemas/pydantic.py index 60ea4a4..15f589b 100644 --- a/zenconfig/schemas/pydantic.py +++ b/zenconfig/schemas/pydantic.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Type, TypeVar from pydantic import BaseModel +from typing_extensions import TypeGuard from zenconfig.base import BaseConfig, Schema from zenconfig.encoder import Encoder, encode @@ -16,6 +17,9 @@ class PydanticSchema(Schema[C]): exclude_unset: bool = False exclude_defaults: bool = True + def handles(self, cls: type) -> TypeGuard[Type[BaseModel]]: + return issubclass(cls, BaseModel) + def from_dict(self, cls: Type[C], cfg: Dict[str, Any]) -> C: return cls.parse_obj(cfg) @@ -29,7 +33,7 @@ def to_dict(self, config: C, encoder: Encoder) -> Dict[str, Any]: ) -BaseConfig.register_schema(PydanticSchema(), lambda cls: issubclass(cls, BaseModel)) +BaseConfig.register_schema(PydanticSchema()) def _encoder(config: BaseModel) -> Encoder: diff --git a/zenconfig/write.py b/zenconfig/write.py index aa4bc2d..f58e7a0 100644 --- a/zenconfig/write.py +++ b/zenconfig/write.py @@ -1,8 +1,7 @@ import logging -import sys from abc import ABC from functools import partial -from typing import ClassVar, Dict, Optional +from typing import ClassVar, Optional from zenconfig.base import ZenConfigError from zenconfig.encoder import Encoder, Encoders, combine_encoders, encode @@ -45,12 +44,9 @@ def save(self) -> None: def clear(self) -> None: """Delete the config file(s).""" - kwargs: Dict[str, bool] = {} - if sys.version_info[:2] != (3, 7): - kwargs["missing_ok"] = True for path in self._paths(): logger.debug("deleting file at path %s", path) - path.unlink(**kwargs) + path.unlink(missing_ok=True) @classmethod def _encoder(cls) -> Encoder: