diff --git a/gto/_pydantic.py b/gto/_pydantic.py deleted file mode 100644 index c3d380d..0000000 --- a/gto/_pydantic.py +++ /dev/null @@ -1,30 +0,0 @@ -__all__ = [ - "BaseModel", - "BaseSettings", - "ValidationError", - "parse_obj_as", - "validator", - "InitSettingsSource", -] - - -try: - from pydantic.v1 import ( - BaseModel, - BaseSettings, - ValidationError, - parse_obj_as, - validator, - ) - from pydantic.v1.env_settings import InitSettingsSource -except ImportError: - from pydantic import ( # type: ignore[no-redef,assignment] - BaseModel, - BaseSettings, - ValidationError, - parse_obj_as, - validator, - ) - from pydantic.env_settings import ( # type: ignore[no-redef,assignment] - InitSettingsSource, - ) diff --git a/gto/base.py b/gto/base.py index c5871d9..cbb5286 100644 --- a/gto/base.py +++ b/gto/base.py @@ -1,6 +1,7 @@ from datetime import datetime from typing import Any, Dict, FrozenSet, List, Optional, Sequence, Union +from pydantic import BaseModel, ConfigDict from scmrepo.git import Git from gto.config import RegistryConfig @@ -12,7 +13,6 @@ ) from gto.versions import SemVer -from ._pydantic import BaseModel from .exceptions import ( ArtifactNotFound, ManyVersions, @@ -41,7 +41,7 @@ def event(self): return self.__class__.__name__.lower() def dict_state(self, exclude=None): - state = self.dict(exclude=exclude) + state = self.model_dump(exclude=exclude) state["event"] = self.event return state @@ -178,7 +178,7 @@ def ref(self): return self.authoring_event.ref def dict_state(self, exclude=None): - version = self.dict(exclude=exclude) + version = self.model_dump(exclude=exclude) version["is_active"] = self.is_active version["activated_at"] = self.activated_at version["created_at"] = self.created_at @@ -565,9 +565,7 @@ def find_version_at_commit( class BaseRegistryState(BaseModel): artifacts: Dict[str, Artifact] = {} - - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) def add_artifact(self, name): self.artifacts[name] = Artifact(artifact=name, versions=[]) @@ -623,9 +621,7 @@ class BaseManager(BaseModel): scm: Git actions: FrozenSet[Action] config: RegistryConfig - - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) def update_state(self, state: BaseRegistryState) -> BaseRegistryState: raise NotImplementedError diff --git a/gto/cli.py b/gto/cli.py index 3ecf38b..d04a265 100644 --- a/gto/cli.py +++ b/gto/cli.py @@ -810,7 +810,7 @@ def stages( def print_state(repo: str = option_repo): """Technical cmd: Print current registry state.""" state = make_ready_to_serialize( - gto.api._get_state(repo).dict() # pylint: disable=protected-access + gto.api._get_state(repo).model_dump() # pylint: disable=protected-access ) format_echo(state, "json") @@ -833,7 +833,7 @@ def doctor( echo(f"{EMOJI_FAIL} Fail to parse config") echo("---------------------------------") - gto.api._get_state(repo).dict() # pylint: disable=protected-access + gto.api._get_state(repo).model_dump() # pylint: disable=protected-access with cli_echo(): echo(f"{EMOJI_OK} No issues found") diff --git a/gto/config.py b/gto/config.py index 5551da8..62f5f3e 100644 --- a/gto/config.py +++ b/gto/config.py @@ -1,16 +1,22 @@ # pylint: disable=no-self-argument, inconsistent-return-statements, invalid-name, import-outside-toplevel import pathlib -from pathlib import Path from typing import Any, Dict, List, Optional +from pydantic import BaseModel, Field, field_validator +from pydantic_settings import ( + BaseSettings, + PydanticBaseSettingsSource, + SettingsConfigDict, +) +from pydantic_settings import ( + YamlConfigSettingsSource as _YamlConfigSettingsSource, +) from ruamel.yaml import YAML from gto.constants import assert_name_is_valid from gto.exceptions import UnknownStage, UnknownType, WrongConfig from gto.ext import EnrichmentReader, find_enrichment_types, find_enrichments -from ._pydantic import BaseModel, BaseSettings, InitSettingsSource, validator - yaml = YAML(typ="safe", pure=True) yaml.default_flow_style = False @@ -27,45 +33,47 @@ def load(self) -> EnrichmentReader: class NoFileConfig(BaseSettings): # type: ignore[valid-type] INDEX: str = "artifacts.yaml" - TYPES: Optional[List[str]] = None - STAGES: Optional[List[str]] = None + CONFIG_FILE_NAME: Optional[str] = CONFIG_FILE_NAME LOG_LEVEL: str = "INFO" DEBUG: bool = False - ENRICHMENTS: List[EnrichmentConfig] = [] - AUTOLOAD_ENRICHMENTS: bool = True - CONFIG_FILE_NAME: Optional[str] = CONFIG_FILE_NAME EMOJIS: bool = True - class Config: - env_prefix = "gto_" + types: Optional[List[str]] = None + stages: Optional[List[str]] = None + enrichments: List[EnrichmentConfig] = Field(default_factory=list) + autoload_enrichments: bool = True + + model_config = SettingsConfigDict(env_prefix="gto_") def assert_type(self, name): assert_name_is_valid(name) # pylint: disable-next=unsupported-membership-test - if self.TYPES is not None and name not in self.TYPES: - raise UnknownType(name, self.TYPES) + if self.types is not None and name not in self.types: + raise UnknownType(name, self.types) def assert_stage(self, name): assert_name_is_valid(name) # pylint: disable-next=unsupported-membership-test - if self.STAGES is not None and name not in self.STAGES: - raise UnknownStage(name, self.STAGES) + if self.stages is not None and name not in self.stages: + raise UnknownStage(name, self.stages) @property - def enrichments(self) -> Dict[str, EnrichmentReader]: - res = {e.source: e for e in (e.load() for e in self.ENRICHMENTS)} - if self.AUTOLOAD_ENRICHMENTS: + def enrichments_(self) -> Dict[str, EnrichmentReader]: + res = {e.source: e for e in (e.load() for e in self.enrichments)} + if self.autoload_enrichments: return {**find_enrichments(), **res} return res - @validator("TYPES") + @field_validator("types") + @classmethod def types_are_valid(cls, v): # pylint: disable=no-self-use if v: for name in v: assert_name_is_valid(name) return v - @validator("STAGES") + @field_validator("stages") + @classmethod def stages_are_valid(cls, v): # pylint: disable=no-self-use if v: for name in v: @@ -77,61 +85,48 @@ def check_index_exist(self, repo: str): return index.exists() and index.is_file() -def _set_location_init_source(init_source: InitSettingsSource): - def inner(settings: "RegistryConfig"): - if "CONFIG_FILE_NAME" in init_source.init_kwargs: - settings.__dict__["CONFIG_FILE_NAME"] = init_source.init_kwargs[ - "CONFIG_FILE_NAME" - ] - return {} - - return inner +class YamlConfigSettingsSource(_YamlConfigSettingsSource): + def _read_file(self, file_path: pathlib.Path) -> dict[str, Any]: + with open(file_path, encoding=self.yaml_file_encoding) as yaml_file: + return yaml.load(yaml_file) or {} -def config_settings_source(settings: "RegistryConfig") -> Dict[str, Any]: - """ - A simple settings source that loads variables from a yaml file in GTO DIR - """ - - encoding = settings.__config__.env_file_encoding - config_file = getattr(settings, "CONFIG_FILE_NAME", CONFIG_FILE_NAME) - if not isinstance(config_file, Path): - config_file = Path(config_file) - if not config_file.exists(): - return {} - conf = yaml.load(config_file.read_text(encoding=encoding)) - - return {k.upper(): v for k, v in conf.items()} if conf else {} +class RegistryConfig(NoFileConfig): + model_config = SettingsConfigDict(env_prefix="gto_", env_file_encoding="utf-8") + def config_file_exists(self): + config = pathlib.Path(self.CONFIG_FILE_NAME) + return config.exists() and config.is_file() -class RegistryConfig(NoFileConfig): - class Config: - env_prefix = "gto_" - env_file_encoding = "utf-8" +def read_registry_config(config_file_name) -> "RegistryConfig": + class _RegistryConfig(RegistryConfig): @classmethod - def customise_sources( + def settings_customise_sources( cls, - init_settings, - env_settings, - file_secret_settings, + settings_cls: type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, ): + encoding = getattr(settings_cls.model_config, "env_file_encoding", "utf-8") return ( - _set_location_init_source(init_settings), init_settings, env_settings, - config_settings_source, + ( + YamlConfigSettingsSource( + settings_cls, + yaml_file=config_file_name, + yaml_file_encoding=encoding, + ) + ), + dotenv_settings, file_secret_settings, ) - def config_file_exists(self): - config = pathlib.Path(self.CONFIG_FILE_NAME) - return config.exists() and config.is_file() - - -def read_registry_config(config_file_name): try: - return RegistryConfig(CONFIG_FILE_NAME=config_file_name) + return _RegistryConfig(CONFIG_FILE_NAME=config_file_name) except Exception as e: # pylint: disable=bare-except raise WrongConfig(config_file_name) from e diff --git a/gto/constants.py b/gto/constants.py index c0a391f..3b46792 100644 --- a/gto/constants.py +++ b/gto/constants.py @@ -2,9 +2,9 @@ from enum import Enum from typing import Optional -from gto.exceptions import ValidationError +from pydantic import BaseModel -from ._pydantic import BaseModel +from gto.exceptions import ValidationError COMMIT = "commit" REF = "ref" diff --git a/gto/ext.py b/gto/ext.py index 9be4a82..dbdeed2 100644 --- a/gto/ext.py +++ b/gto/ext.py @@ -4,10 +4,9 @@ from typing import Dict, Optional, Type, Union import entrypoints +from pydantic import BaseModel from scmrepo.git import Git -from ._pydantic import BaseModel - ENRICHMENT_ENTRYPOINT = "gto.enrichment" @@ -29,7 +28,7 @@ def get_object(self) -> BaseModel: raise NotImplementedError def get_dict(self): - return self.get_object().dict() + return self.get_object().model_dump() @abstractmethod def get_human_readable(self) -> str: diff --git a/gto/index.py b/gto/index.py index 27d8508..3b780f6 100644 --- a/gto/index.py +++ b/gto/index.py @@ -18,6 +18,13 @@ Union, ) +from pydantic import ( + BaseModel, + ConfigDict, + TypeAdapter, + ValidationError, + field_validator, +) from ruamel.yaml import YAMLError from scmrepo.exceptions import SCMError from scmrepo.git import Git @@ -49,8 +56,6 @@ from gto.git_utils import RemoteRepoMixin from gto.ui import echo -from ._pydantic import BaseModel, ValidationError, parse_obj_as, validator - logger = logging.getLogger("gto") @@ -64,6 +69,7 @@ class Artifact(BaseModel): State = Dict[str, Artifact] +state_adapter = TypeAdapter(State) def not_frozen(func): @@ -110,7 +116,8 @@ class Index(BaseModel): state: State = {} # TODO should not be populated until load() is called frozen: bool = False - @validator("state") + @field_validator("state") + @classmethod def state_is_valid(cls, v): # pylint: disable=no-self-argument, no-self-use for name, artifact in v.items(): assert_name_is_valid(name) @@ -147,8 +154,9 @@ def read_yaml(stream: IO): else: contents = read_yaml(path_or_file) # check yaml contents is a valid State + try: - state = parse_obj_as(State, contents) + state = state_adapter.validate_python(contents) except ValidationError as e: raise WrongArtifactsYaml() from e # validate that specific names conform to the naming convention @@ -159,7 +167,7 @@ def read_yaml(stream: IO): def write_state(self, path_or_file: Union[str, IO]): if isinstance(path_or_file, str): with open(path_or_file, "w", encoding="utf8") as file: - state = self.dict(exclude_defaults=True).get("state", {}) + state = self.model_dump(exclude_defaults=True).get("state", {}) yaml.dump(state, file) @not_frozen @@ -216,7 +224,7 @@ def remove(self, name): class BaseIndexManager(BaseModel, ABC): - current: Optional[Index] + current: Optional[Index] = None config: RegistryConfig @abstractmethod @@ -321,7 +329,7 @@ class RepoIndexManager(FileIndexManager, RemoteRepoMixin): cloned: bool def __init__(self, scm: Git, cloned: bool, config): - super().__init__(scm=scm, cloned=cloned, config=config) # type: ignore[call-arg] + super().__init__(scm=scm, cloned=cloned, config=config, current=None) # type: ignore[call-arg] @classmethod @contextmanager @@ -390,8 +398,7 @@ def index_path(self): # TODO: config should be loaded from repo too return os.path.join(self.scm.root_dir, self.config.INDEX) - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) def get_commit_index( # type: ignore # pylint: disable=arguments-differ self, @@ -474,7 +481,7 @@ def from_scm( yield cls(scm=scm, config=config) def describe(self, name: str, rev: Optional[str] = None) -> List[EnrichmentInfo]: - enrichments = self.config.enrichments + enrichments = self.config.enrichments_ res = [] gto_enrichment = enrichments.pop("gto") gto_info = gto_enrichment.describe(self.scm, name, rev) @@ -572,7 +579,7 @@ def get_object(self) -> BaseModel: return self.artifact def get_human_readable(self) -> str: - return self.artifact.json() + return self.artifact.model_dump_json() def get_path(self): return self.artifact.path diff --git a/gto/registry.py b/gto/registry.py index 88e45c8..44e1474 100644 --- a/gto/registry.py +++ b/gto/registry.py @@ -4,6 +4,7 @@ from typing import List, Optional, TypeVar, cast from funcy import distinct +from pydantic import BaseModel, ConfigDict from scmrepo.git import Git from gto.base import ( @@ -36,8 +37,6 @@ from gto.ui import echo from gto.versions import SemVer -from ._pydantic import BaseModel - TBaseEvent = TypeVar("TBaseEvent", bound=BaseEvent) @@ -49,9 +48,7 @@ class GitRegistry(BaseModel, RemoteRepoMixin): stage_manager: TagStageManager enrichment_manager: EnrichmentManager config: RegistryConfig - - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) @classmethod @contextmanager @@ -556,7 +553,7 @@ def latest(self, name: str, all: bool = False, registered: bool = True): return artifact.get_latest_version(registered_only=registered) def _get_allowed_stages(self): - return self.config.STAGES + return self.config.stages def _get_used_stages(self): return sorted( diff --git a/gto/tag.py b/gto/tag.py index e7aec54..4bc8d38 100644 --- a/gto/tag.py +++ b/gto/tag.py @@ -4,10 +4,10 @@ from enum import Enum from typing import FrozenSet, Iterable, Optional, Union +from pydantic import BaseModel, ConfigDict from scmrepo.exceptions import RevError from scmrepo.git import Git, GitTag -from ._pydantic import BaseModel from .base import ( Artifact, Assignment, @@ -122,13 +122,11 @@ def parse_name_reference(name: str): class Tag(BaseModel): action: Action name: str - version: Optional[str] - stage: Optional[str] created_at: datetime.datetime tag: GitTag - - class Config: - arbitrary_types_allowed = True + version: Optional[str] = None + stage: Optional[str] = None + model_config = ConfigDict(arbitrary_types_allowed=True) def parse_tag(tag: GitTag): diff --git a/gto/utils.py b/gto/utils.py index fd148e7..2ec0129 100644 --- a/gto/utils.py +++ b/gto/utils.py @@ -6,12 +6,11 @@ from enum import Enum import click +from pydantic import BaseModel from tabulate import tabulate from gto.config import yaml -from ._pydantic import BaseModel - def flatten(obj): if isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)): @@ -42,7 +41,7 @@ def make_ready_to_serialize( if data is None: return data if isinstance(data, BaseModel): - return make_ready_to_serialize(data.dict()) + return make_ready_to_serialize(data.model_dump()) raise NotImplementedError( f"Serialisation is not implemented for {data_to_serialize} of type {type(data_to_serialize)}" ) diff --git a/pyproject.toml b/pyproject.toml index 9c0f4c5..fb7a128 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,9 +23,8 @@ dynamic = ["version"] dependencies = [ "entrypoints", "funcy", - # pydantic.v1.parse_obj is broken in ==2.0.0: - # https://github.com/pydantic/pydantic/issues/6361 - "pydantic>=1.9.0,<3,!=2.0.0", + "pydantic>=2", + "pydantic-settings>=2", "rich", "ruamel.yaml", "scmrepo>=3,<4", @@ -49,7 +48,10 @@ tests = [ ] dev = [ "gto[tests]", - "mypy==1.17.1", + "mypy==1.17.1; python_version > '3.9'", + # mypy>=1.11.0 crashes when used with pydantic-settings on Python 3.9, + # see: https://github.com/python/mypy/issues/17535 + "mypy<1.11.0; python_version <= '3.9'", "pylint==3.3.8", "types-PyYAML", "types-filelock", diff --git a/tests/test_api.py b/tests/test_api.py index 1a31e62..5d980b0 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -372,7 +372,7 @@ def test_if_stages_on_remote_git_repo_then_return_expected_stages(): def test_if_describe_on_remote_git_repo_then_return_expected_info(): result = gto.api.describe(repo=tests.resources.SAMPLE_REMOTE_REPO_URL, name="churn") - assert result.dict(exclude_defaults=True) == { + assert result.model_dump(exclude_defaults=True) == { "type": "model", "path": "models/churn.pkl", "virtual": False, diff --git a/tests/test_config.py b/tests/test_config.py index 4ed324c..064429f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -34,12 +34,12 @@ def _init_repo(tmp_dir: TmpDir, scm: Git) -> TmpDir: def test_config_load_index(init_repo: TmpDir): with RepoIndexManager.from_url(init_repo) as index: - assert index.config.TYPES == ["model", "dataset"] + assert index.config.types == ["model", "dataset"] def test_config_load_registry(init_repo: TmpDir): with GitRegistry.from_url(init_repo) as reg: - assert reg.config.TYPES == ["model", "dataset"] + assert reg.config.types == ["model", "dataset"] def test_stages(init_repo: TmpDir): diff --git a/tests/test_registry.py b/tests/test_registry.py index 5f78b0e..fd9a516 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -371,7 +371,7 @@ def iter_over(sequence): @pytest.mark.usefixtures("showcase") def test_registry_state_tag_tag(tmp_dir: TmpDir): with GitRegistry.from_url(tmp_dir) as reg: - appeared_state = reg.get_state().dict() + appeared_state = reg.get_state().model_dump() # TODO: update state exclude: Dict[str, List[str]] = { diff --git a/tests/utils.py b/tests/utils.py index 7c0fdbd..392132a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,8 +3,7 @@ from typing import Any, Dict, Sequence, Set, Union from funcy import omit - -from gto._pydantic import BaseModel +from pydantic import BaseModel def show_difference(left: Dict, right: Dict): @@ -29,7 +28,7 @@ def check_obj( skip_keys: Union[Set[str], Sequence[str]] = (), ): if isinstance(obj, BaseModel): - obj_values = obj.dict(exclude=set(skip_keys)) + obj_values = obj.model_dump(exclude=set(skip_keys)) else: obj_values = omit(obj, skip_keys) values = omit(values, skip_keys)