From 6ad1948a44a97349aa207247e42df342d565440e Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 18 Apr 2024 21:33:54 -0400 Subject: [PATCH 01/25] add InvokeAIAppConfig schema migration system --- .../app/services/config/config_default.py | 169 +++++++++--------- .../app/services/config/config_migrate.py | 129 +++++++++++++ pyproject.toml | 1 + 3 files changed, 212 insertions(+), 87 deletions(-) create mode 100644 invokeai/app/services/config/config_migrate.py diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 54a092d03e7..7734dde828d 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -20,6 +20,8 @@ from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS from invokeai.frontend.cli.arg_parser import InvokeAIArgs +from .config_migrate import ConfigMigrator + INIT_FILE = Path("invokeai.yaml") DB_FILE = Path("invokeai.db") LEGACY_INIT_FILE = Path("invokeai.init") @@ -348,75 +350,6 @@ def settings_customise_sources( return (init_settings,) -def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig: - """Migrate a v3 config dictionary to a current config object. - - Args: - config_dict: A dictionary of settings from a v3 config file. - - Returns: - An instance of `InvokeAIAppConfig` with the migrated settings. - - """ - parsed_config_dict: dict[str, Any] = {} - for _category_name, category_dict in config_dict["InvokeAI"].items(): - for k, v in category_dict.items(): - # `outdir` was renamed to `outputs_dir` in v4 - if k == "outdir": - parsed_config_dict["outputs_dir"] = v - # `max_cache_size` was renamed to `ram` some time in v3, but both names were used - if k == "max_cache_size" and "ram" not in category_dict: - parsed_config_dict["ram"] = v - # `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used - if k == "max_vram_cache_size" and "vram" not in category_dict: - parsed_config_dict["vram"] = v - # autocast was removed in v4.0.1 - if k == "precision" and v == "autocast": - parsed_config_dict["precision"] = "auto" - if k == "conf_path": - parsed_config_dict["legacy_models_yaml_path"] = v - if k == "legacy_conf_dir": - # The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows). - if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion": - # If if the incoming config has the default value, skip - continue - elif Path(v).name == "stable-diffusion": - # Else if the path ends in "stable-diffusion", we assume the parent is the new correct path. - parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent) - else: - # Else we do not attempt to migrate this setting - parsed_config_dict["legacy_conf_dir"] = v - elif k in InvokeAIAppConfig.model_fields: - # skip unknown fields - parsed_config_dict[k] = v - # When migrating the config file, we should not include currently-set environment variables. - config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict) - - return config - - -def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig: - """Migrate v4.0.0 config dictionary to a current config object. - - Args: - config_dict: A dictionary of settings from a v4.0.0 config file. - - Returns: - An instance of `InvokeAIAppConfig` with the migrated settings. - """ - parsed_config_dict: dict[str, Any] = {} - for k, v in config_dict.items(): - # autocast was removed from precision in v4.0.1 - if k == "precision" and v == "autocast": - parsed_config_dict["precision"] = "auto" - else: - parsed_config_dict[k] = v - if k == "schema_version": - parsed_config_dict[k] = CONFIG_SCHEMA_VERSION - config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict) - return config - - def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: """Load and migrate a config file to the latest version. @@ -432,29 +365,20 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: assert isinstance(loaded_config_dict, dict) - if "InvokeAI" in loaded_config_dict: - # This is a v3 config file, attempt to migrate it - shutil.copy(config_path, config_path.with_suffix(".yaml.bak")) - try: - # loaded_config_dict could be the wrong shape, but we will catch all exceptions below - migrated_config = migrate_v3_config_dict(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType] - except Exception as e: - shutil.copy(config_path.with_suffix(".yaml.bak"), config_path) - raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e - migrated_config.write_file(config_path) - return migrated_config - - if loaded_config_dict["schema_version"] == "4.0.0": - loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict) - loaded_config_dict.write_file(config_path) + shutil.copy(config_path, config_path.with_suffix(".yaml.bak")) + try: + # loaded_config_dict could be the wrong shape, but we will catch all exceptions below + migrated_config_dict = ConfigMigrator.migrate(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType] + except Exception as e: + shutil.copy(config_path.with_suffix(".yaml.bak"), config_path) + raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e # Attempt to load as a v4 config file try: - # Meta is not included in the model fields, so we need to validate it separately - config = InvokeAIAppConfig.model_validate(loaded_config_dict) + config = InvokeAIAppConfig.model_validate(migrated_config_dict) assert ( config.schema_version == CONFIG_SCHEMA_VERSION - ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}" + ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}" return config except Exception as e: raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e @@ -504,6 +428,7 @@ def get_config() -> InvokeAIAppConfig: if config.config_file_path.exists(): config_from_file = load_and_migrate_config(config.config_file_path) + config_from_file.write_file(config.config_file_path) # Clobbering here will overwrite any settings that were set via environment variables config.update_config(config_from_file, clobber=False) else: @@ -512,3 +437,73 @@ def get_config() -> InvokeAIAppConfig: default_config.write_file(config.config_file_path, as_example=False) return config + + +#################################################### +# VERSION MIGRATIONS +#################################################### + + +@ConfigMigrator.register(from_version="0.0.0", to_version="4.0.0") +def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]: + """Migrate a v3 config dictionary to a current config object. + + Args: + config_dict: A dictionary of settings from a v3 config file. + + Returns: + A dictionary of settings from a 4.0.0 config file. + + """ + parsed_config_dict: dict[str, Any] = {} + for _category_name, category_dict in config_dict["InvokeAI"].items(): + for k, v in category_dict.items(): + # `outdir` was renamed to `outputs_dir` in v4 + if k == "outdir": + parsed_config_dict["outputs_dir"] = v + # `max_cache_size` was renamed to `ram` some time in v3, but both names were used + if k == "max_cache_size" and "ram" not in category_dict: + parsed_config_dict["ram"] = v + # `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used + if k == "max_vram_cache_size" and "vram" not in category_dict: + parsed_config_dict["vram"] = v + # autocast was removed in v4.0.1 + if k == "precision" and v == "autocast": + parsed_config_dict["precision"] = "auto" + if k == "conf_path": + parsed_config_dict["legacy_models_yaml_path"] = v + if k == "legacy_conf_dir": + # The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows). + if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion": + # If if the incoming config has the default value, skip + continue + elif Path(v).name == "stable-diffusion": + # Else if the path ends in "stable-diffusion", we assume the parent is the new correct path. + parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent) + else: + # Else we do not attempt to migrate this setting + parsed_config_dict["legacy_conf_dir"] = v + elif k in InvokeAIAppConfig.model_fields: + # skip unknown fields + parsed_config_dict[k] = v + return parsed_config_dict + + +@ConfigMigrator.register(from_version="4.0.0", to_version="4.0.1") +def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]: + """Migrate v4.0.0 config dictionary to v4.0.1. + + Args: + config_dict: A dictionary of settings from a v4.0.0 config file. + + Returns: + A dictionary of settings from a v4.0.1 config file + """ + parsed_config_dict: dict[str, Any] = {} + for k, v in config_dict.items(): + # autocast was removed from precision in v4.0.1 + if k == "precision" and v == "autocast": + parsed_config_dict["precision"] = "auto" + else: + parsed_config_dict[k] = v + return parsed_config_dict diff --git a/invokeai/app/services/config/config_migrate.py b/invokeai/app/services/config/config_migrate.py new file mode 100644 index 00000000000..6f070222ee6 --- /dev/null +++ b/invokeai/app/services/config/config_migrate.py @@ -0,0 +1,129 @@ +# Copyright 2024 Lincoln D. Stein and the InvokeAI Development Team + +""" +Utility class for migrating among versions of the InvokeAI app config schema. +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Callable, List, TypeVar + +from pydantic import BaseModel, ConfigDict, field_validator +from version_parser import Version + +if TYPE_CHECKING: + pass + +AppConfigDict = TypeVar("AppConfigDict", bound=dict[str, Any]) + + +class AppVersion(Version): + """Stringlike object that sorts like a version.""" + + def __hash__(self) -> int: # noqa D105 + return hash(str(self)) + + def __repr__(self) -> str: # noqa D105 + return f"AppVersion('{str(self)}')" + + +class ConfigMigratorBase(ABC): + """This class allows migrators to register their input and output versions.""" + + @classmethod + @abstractmethod + def register( + cls, from_version: AppVersion, to_version: AppVersion + ) -> Callable[[Callable[[AppConfigDict], AppConfigDict]], Callable[[AppConfigDict], AppConfigDict]]: + """Define a decorator which registers the migration between two versions.""" + + @classmethod + @abstractmethod + def migrate(cls, config: AppConfigDict) -> AppConfigDict: + """ + Use the registered migration steps to bring config up to latest version. + + :param config: The original configuration. + :return: The new configuration, lifted up to the latest version. + + As a side effect, the new configuration will be written to disk. + """ + + +class MigrationEntry(BaseModel): + """Defines an individual migration.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + from_version: AppVersion + to_version: AppVersion + function: Callable[[AppConfigDict], AppConfigDict] + + @field_validator("from_version", "to_version", mode="before") + @classmethod + def _string_to_version(cls, v: str | AppVersion) -> AppVersion: # noqa D102 + if isinstance(v, str): + return AppVersion(v) + else: + return v + + +class ConfigMigrator(ConfigMigratorBase): + """This class allows migrators to register their input and output versions.""" + + _migrations: List[MigrationEntry] = [] + + @classmethod + def register( + cls, + from_version: AppVersion | str, + to_version: AppVersion | str, + ) -> Callable[[Callable[[AppConfigDict], AppConfigDict]], Callable[[AppConfigDict], AppConfigDict]]: + """Define a decorator which registers the migration between two versions.""" + + def decorator(function: Callable[[AppConfigDict], AppConfigDict]) -> Callable[[AppConfigDict], AppConfigDict]: + if from_version in cls._migrations: + raise ValueError( + f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered." + ) + cls._migrations.append(MigrationEntry(from_version=from_version, to_version=to_version, function=function)) + return function + + return decorator + + @staticmethod + def _check_for_overlaps(migrations: List[MigrationEntry]) -> None: + current_version = AppVersion("0.0.0") + for m in migrations: + if current_version > m.from_version: + raise ValueError(f"Version range overlap detected while processing function {m.function.__name__}") + + @classmethod + def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict: + """ + Use the registered migration steps to bring config up to latest version. + + :param config: The original configuration. + :return: The new configuration, lifted up to the latest version. + + As a side effect, the new configuration will be written to disk. + If an inconsistency in the registered migration steps' `from_version` + and `to_version` parameters are identified, this will raise a + ValueError exception. + """ + # Sort migrations by version number and raise a ValueError if + # any version range overlaps are detected. Discontinuities are ok + sorted_migrations = sorted(cls._migrations, key=lambda x: x.from_version) + cls._check_for_overlaps(sorted_migrations) + + if "InvokeAI" in config_dict: + version = AppVersion("3.0.0") + else: + version = AppVersion(config_dict["schema_version"]) + + for migration in sorted_migrations: + if version >= migration.from_version and version < migration.to_version: + config_dict = migration.function(config_dict) + version = migration.to_version + + config_dict["schema_version"] = str(version) + return config_dict diff --git a/pyproject.toml b/pyproject.toml index 86cbb8315ce..85401f8c97d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,6 +90,7 @@ dependencies = [ "semver~=3.0.1", "send2trash", "test-tube~=0.7.5", + "version-parser", "windows-curses; sys_platform=='win32'", ] From 36495b730d5d9ad01850db7b857bb5c096d55208 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 18 Apr 2024 21:54:17 -0400 Subject: [PATCH 02/25] use packaging.version rather than version-parse --- .../app/services/config/config_migrate.py | 32 +++++++------------ pyproject.toml | 1 - 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/invokeai/app/services/config/config_migrate.py b/invokeai/app/services/config/config_migrate.py index 6f070222ee6..2c43f13a572 100644 --- a/invokeai/app/services/config/config_migrate.py +++ b/invokeai/app/services/config/config_migrate.py @@ -7,8 +7,8 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Callable, List, TypeVar +from packaging.version import Version from pydantic import BaseModel, ConfigDict, field_validator -from version_parser import Version if TYPE_CHECKING: pass @@ -16,23 +16,13 @@ AppConfigDict = TypeVar("AppConfigDict", bound=dict[str, Any]) -class AppVersion(Version): - """Stringlike object that sorts like a version.""" - - def __hash__(self) -> int: # noqa D105 - return hash(str(self)) - - def __repr__(self) -> str: # noqa D105 - return f"AppVersion('{str(self)}')" - - class ConfigMigratorBase(ABC): """This class allows migrators to register their input and output versions.""" @classmethod @abstractmethod def register( - cls, from_version: AppVersion, to_version: AppVersion + cls, from_version: Version, to_version: Version ) -> Callable[[Callable[[AppConfigDict], AppConfigDict]], Callable[[AppConfigDict], AppConfigDict]]: """Define a decorator which registers the migration between two versions.""" @@ -54,15 +44,15 @@ class MigrationEntry(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - from_version: AppVersion - to_version: AppVersion + from_version: Version + to_version: Version function: Callable[[AppConfigDict], AppConfigDict] @field_validator("from_version", "to_version", mode="before") @classmethod - def _string_to_version(cls, v: str | AppVersion) -> AppVersion: # noqa D102 + def _string_to_version(cls, v: str | Version) -> Version: # noqa D102 if isinstance(v, str): - return AppVersion(v) + return Version(v) else: return v @@ -75,8 +65,8 @@ class ConfigMigrator(ConfigMigratorBase): @classmethod def register( cls, - from_version: AppVersion | str, - to_version: AppVersion | str, + from_version: Version | str, + to_version: Version | str, ) -> Callable[[Callable[[AppConfigDict], AppConfigDict]], Callable[[AppConfigDict], AppConfigDict]]: """Define a decorator which registers the migration between two versions.""" @@ -92,7 +82,7 @@ def decorator(function: Callable[[AppConfigDict], AppConfigDict]) -> Callable[[A @staticmethod def _check_for_overlaps(migrations: List[MigrationEntry]) -> None: - current_version = AppVersion("0.0.0") + current_version = Version("0.0.0") for m in migrations: if current_version > m.from_version: raise ValueError(f"Version range overlap detected while processing function {m.function.__name__}") @@ -116,9 +106,9 @@ def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict: cls._check_for_overlaps(sorted_migrations) if "InvokeAI" in config_dict: - version = AppVersion("3.0.0") + version = Version("3.0.0") else: - version = AppVersion(config_dict["schema_version"]) + version = Version(config_dict["schema_version"]) for migration in sorted_migrations: if version >= migration.from_version and version < migration.to_version: diff --git a/pyproject.toml b/pyproject.toml index 85401f8c97d..86cbb8315ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,7 +90,6 @@ dependencies = [ "semver~=3.0.1", "send2trash", "test-tube~=0.7.5", - "version-parser", "windows-curses; sys_platform=='win32'", ] From b612c739546f46a2566e938f6cd219393c6012f4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 23 Apr 2024 17:09:50 +1000 Subject: [PATCH 03/25] tidy(config): remove unused TYPE_CHECKING block --- invokeai/app/services/config/config_migrate.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/invokeai/app/services/config/config_migrate.py b/invokeai/app/services/config/config_migrate.py index 2c43f13a572..b0e43f1a40d 100644 --- a/invokeai/app/services/config/config_migrate.py +++ b/invokeai/app/services/config/config_migrate.py @@ -5,14 +5,11 @@ """ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Callable, List, TypeVar +from typing import Any, Callable, List, TypeVar from packaging.version import Version from pydantic import BaseModel, ConfigDict, field_validator -if TYPE_CHECKING: - pass - AppConfigDict = TypeVar("AppConfigDict", bound=dict[str, Any]) From e39f035264eb581be492d7ba938c11dba02cbd61 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 23 Apr 2024 17:11:13 +1000 Subject: [PATCH 04/25] tidy(config): removed extraneous ABC We don't need separate implementations for this class, let's not complicate it with an ABC --- .../app/services/config/config_migrate.py | 26 +------------------ 1 file changed, 1 insertion(+), 25 deletions(-) diff --git a/invokeai/app/services/config/config_migrate.py b/invokeai/app/services/config/config_migrate.py index b0e43f1a40d..f746a8d0498 100644 --- a/invokeai/app/services/config/config_migrate.py +++ b/invokeai/app/services/config/config_migrate.py @@ -4,7 +4,6 @@ Utility class for migrating among versions of the InvokeAI app config schema. """ -from abc import ABC, abstractmethod from typing import Any, Callable, List, TypeVar from packaging.version import Version @@ -13,29 +12,6 @@ AppConfigDict = TypeVar("AppConfigDict", bound=dict[str, Any]) -class ConfigMigratorBase(ABC): - """This class allows migrators to register their input and output versions.""" - - @classmethod - @abstractmethod - def register( - cls, from_version: Version, to_version: Version - ) -> Callable[[Callable[[AppConfigDict], AppConfigDict]], Callable[[AppConfigDict], AppConfigDict]]: - """Define a decorator which registers the migration between two versions.""" - - @classmethod - @abstractmethod - def migrate(cls, config: AppConfigDict) -> AppConfigDict: - """ - Use the registered migration steps to bring config up to latest version. - - :param config: The original configuration. - :return: The new configuration, lifted up to the latest version. - - As a side effect, the new configuration will be written to disk. - """ - - class MigrationEntry(BaseModel): """Defines an individual migration.""" @@ -54,7 +30,7 @@ def _string_to_version(cls, v: str | Version) -> Version: # noqa D102 return v -class ConfigMigrator(ConfigMigratorBase): +class ConfigMigrator: """This class allows migrators to register their input and output versions.""" _migrations: List[MigrationEntry] = [] From aca9e44a3a68098505de6b43dc45e6232948cff3 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 23 Apr 2024 17:12:19 +1000 Subject: [PATCH 05/25] fix(config): use TypeAlias instead of TypeVar TypeVar is for generics, but the usage here is as an alias --- invokeai/app/services/config/config_migrate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/app/services/config/config_migrate.py b/invokeai/app/services/config/config_migrate.py index f746a8d0498..591f86f04a6 100644 --- a/invokeai/app/services/config/config_migrate.py +++ b/invokeai/app/services/config/config_migrate.py @@ -4,12 +4,12 @@ Utility class for migrating among versions of the InvokeAI app config schema. """ -from typing import Any, Callable, List, TypeVar +from typing import Any, Callable, List, TypeAlias from packaging.version import Version from pydantic import BaseModel, ConfigDict, field_validator -AppConfigDict = TypeVar("AppConfigDict", bound=dict[str, Any]) +AppConfigDict: TypeAlias = dict[str, Any] class MigrationEntry(BaseModel): From 6f128c86b456cf3b2fc16de3c6f703327c899724 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 23 Apr 2024 17:19:54 +1000 Subject: [PATCH 06/25] tidy(config): use dataclass for MigrationEntry The only pydantic usage was to convert strings to `Version` objects. The reason to do this conversion was to allow the register decorator to accept strings. MigrationEntry is only created inside this class, so we can just create versions from each migration when instantiating MigrationEntry instead. Also, pydantic doesn't provide runtime time checking for arbitrary classes like Version, so we don't get any real benefit. --- .../app/services/config/config_migrate.py | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/invokeai/app/services/config/config_migrate.py b/invokeai/app/services/config/config_migrate.py index 591f86f04a6..3da734b2d95 100644 --- a/invokeai/app/services/config/config_migrate.py +++ b/invokeai/app/services/config/config_migrate.py @@ -4,31 +4,22 @@ Utility class for migrating among versions of the InvokeAI app config schema. """ +from dataclasses import dataclass from typing import Any, Callable, List, TypeAlias from packaging.version import Version -from pydantic import BaseModel, ConfigDict, field_validator AppConfigDict: TypeAlias = dict[str, Any] -class MigrationEntry(BaseModel): +@dataclass +class MigrationEntry: """Defines an individual migration.""" - model_config = ConfigDict(arbitrary_types_allowed=True) - from_version: Version to_version: Version function: Callable[[AppConfigDict], AppConfigDict] - @field_validator("from_version", "to_version", mode="before") - @classmethod - def _string_to_version(cls, v: str | Version) -> Version: # noqa D102 - if isinstance(v, str): - return Version(v) - else: - return v - class ConfigMigrator: """This class allows migrators to register their input and output versions.""" @@ -38,8 +29,8 @@ class ConfigMigrator: @classmethod def register( cls, - from_version: Version | str, - to_version: Version | str, + from_version: str, + to_version: str, ) -> Callable[[Callable[[AppConfigDict], AppConfigDict]], Callable[[AppConfigDict], AppConfigDict]]: """Define a decorator which registers the migration between two versions.""" @@ -48,7 +39,9 @@ def decorator(function: Callable[[AppConfigDict], AppConfigDict]) -> Callable[[A raise ValueError( f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered." ) - cls._migrations.append(MigrationEntry(from_version=from_version, to_version=to_version, function=function)) + cls._migrations.append( + MigrationEntry(from_version=Version(from_version), to_version=Version(to_version), function=function) + ) return function return decorator From 5d411e446a09c41a23b5b583a8427e5a6490a1c0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 23 Apr 2024 17:21:05 +1000 Subject: [PATCH 07/25] tidy(config): use a type alias for the migration function --- invokeai/app/services/config/config_migrate.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/invokeai/app/services/config/config_migrate.py b/invokeai/app/services/config/config_migrate.py index 3da734b2d95..6cdbca7558f 100644 --- a/invokeai/app/services/config/config_migrate.py +++ b/invokeai/app/services/config/config_migrate.py @@ -10,6 +10,7 @@ from packaging.version import Version AppConfigDict: TypeAlias = dict[str, Any] +MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict] @dataclass @@ -18,7 +19,7 @@ class MigrationEntry: from_version: Version to_version: Version - function: Callable[[AppConfigDict], AppConfigDict] + function: MigrationFunction class ConfigMigrator: @@ -31,10 +32,10 @@ def register( cls, from_version: str, to_version: str, - ) -> Callable[[Callable[[AppConfigDict], AppConfigDict]], Callable[[AppConfigDict], AppConfigDict]]: + ) -> Callable[[MigrationFunction], MigrationFunction]: """Define a decorator which registers the migration between two versions.""" - def decorator(function: Callable[[AppConfigDict], AppConfigDict]) -> Callable[[AppConfigDict], AppConfigDict]: + def decorator(function: MigrationFunction) -> MigrationFunction: if from_version in cls._migrations: raise ValueError( f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered." From d12fb7db68de7ed2d8c5ed3c5f27235d286f59bf Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 23 Apr 2024 17:25:53 +1000 Subject: [PATCH 08/25] fix(config): fix duplicate migration logic This was checking a `Version` object against a `MigrationEntry`, but what we want is to check the version object against `MigrationEntry.from_version` --- invokeai/app/services/config/config_migrate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/services/config/config_migrate.py b/invokeai/app/services/config/config_migrate.py index 6cdbca7558f..5402555a47b 100644 --- a/invokeai/app/services/config/config_migrate.py +++ b/invokeai/app/services/config/config_migrate.py @@ -36,7 +36,7 @@ def register( """Define a decorator which registers the migration between two versions.""" def decorator(function: MigrationFunction) -> MigrationFunction: - if from_version in cls._migrations: + if any(from_version == m.from_version for m in cls._migrations): raise ValueError( f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered." ) From 984dd93798dabb62387f201280d2e3b34234c071 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 23 Apr 2024 17:50:31 +1000 Subject: [PATCH 09/25] tests(config): add failing test case to for config migrator --- tests/test_config.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index a6ea2a34806..d115d0b1af4 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -14,6 +14,13 @@ ) from invokeai.frontend.cli.arg_parser import InvokeAIArgs +invalid_v4_0_1_config = """ +schema_version: 4.0.1 + +host: "192.168.1.1" +port: 8080 +""" + v4_config = """ schema_version: 4.0.0 @@ -155,11 +162,11 @@ def test_bails_on_invalid_config(tmp_path: Path, patch_rootdir: None): with pytest.raises(AssertionError): load_and_migrate_config(temp_config_file) - -def test_bails_on_config_with_unsupported_version(tmp_path: Path, patch_rootdir: None): +@pytest.mark.parametrize("config_content", [invalid_v5_config, invalid_v1_config]) +def test_bails_on_config_with_unsupported_version(tmp_path: Path, patch_rootdir: None, config_content: str): """Test reading configuration from a file.""" temp_config_file = tmp_path / "temp_invokeai.yaml" - temp_config_file.write_text(invalid_v5_config) + temp_config_file.write_text(config_content) with pytest.raises(RuntimeError, match="Invalid schema version"): load_and_migrate_config(temp_config_file) From ab9ebef345250a8588c8aa6ce960cdf16f162cb4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 23 Apr 2024 17:52:51 +1000 Subject: [PATCH 10/25] tests(config): fix typo --- tests/test_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_config.py b/tests/test_config.py index d115d0b1af4..7029b5317b4 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -162,7 +162,7 @@ def test_bails_on_invalid_config(tmp_path: Path, patch_rootdir: None): with pytest.raises(AssertionError): load_and_migrate_config(temp_config_file) -@pytest.mark.parametrize("config_content", [invalid_v5_config, invalid_v1_config]) +@pytest.mark.parametrize("config_content", [invalid_v5_config, invalid_v4_0_1_config]) def test_bails_on_config_with_unsupported_version(tmp_path: Path, patch_rootdir: None, config_content: str): """Test reading configuration from a file.""" temp_config_file = tmp_path / "temp_invokeai.yaml" From 6eaed9a9cba1ea75eb3722bd67e5e56bd9142fa4 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 24 Apr 2024 21:36:28 -0400 Subject: [PATCH 11/25] check for strictly contiguous from_version->to_version ranges --- invokeai/app/services/config/config_default.py | 2 +- invokeai/app/services/config/config_migrate.py | 17 ++++++++++------- tests/test_config.py | 18 ++++++++++++++++-- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 7734dde828d..0aedb54a377 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -444,7 +444,7 @@ def get_config() -> InvokeAIAppConfig: #################################################### -@ConfigMigrator.register(from_version="0.0.0", to_version="4.0.0") +@ConfigMigrator.register(from_version="3.0.0", to_version="4.0.0") def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]: """Migrate a v3 config dictionary to a current config object. diff --git a/invokeai/app/services/config/config_migrate.py b/invokeai/app/services/config/config_migrate.py index 5402555a47b..b3fe979d37a 100644 --- a/invokeai/app/services/config/config_migrate.py +++ b/invokeai/app/services/config/config_migrate.py @@ -48,11 +48,14 @@ def decorator(function: MigrationFunction) -> MigrationFunction: return decorator @staticmethod - def _check_for_overlaps(migrations: List[MigrationEntry]) -> None: - current_version = Version("0.0.0") + def _check_for_discontinuities(migrations: List[MigrationEntry]) -> None: + current_version = Version("3.0.0") for m in migrations: - if current_version > m.from_version: - raise ValueError(f"Version range overlap detected while processing function {m.function.__name__}") + if current_version != m.from_version: + raise ValueError( + f"Migration functions are not continuous. Expected from_version={current_version} but got from_version={m.from_version}, for migration function {m.function.__name__}" + ) + current_version = m.to_version @classmethod def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict: @@ -68,9 +71,9 @@ def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict: ValueError exception. """ # Sort migrations by version number and raise a ValueError if - # any version range overlaps are detected. Discontinuities are ok + # any version range overlaps are detected. sorted_migrations = sorted(cls._migrations, key=lambda x: x.from_version) - cls._check_for_overlaps(sorted_migrations) + cls._check_for_discontinuities(sorted_migrations) if "InvokeAI" in config_dict: version = Version("3.0.0") @@ -78,7 +81,7 @@ def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict: version = Version(config_dict["schema_version"]) for migration in sorted_migrations: - if version >= migration.from_version and version < migration.to_version: + if version == migration.from_version and version < migration.to_version: config_dict = migration.function(config_dict) version = migration.to_version diff --git a/tests/test_config.py b/tests/test_config.py index 7029b5317b4..08d28dee473 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -4,6 +4,7 @@ import pytest from omegaconf import OmegaConf +from packaging.version import Version from pydantic import ValidationError from invokeai.app.services.config.config_default import ( @@ -18,12 +19,13 @@ schema_version: 4.0.1 host: "192.168.1.1" -port: 8080 +port: "ice cream" """ v4_config = """ schema_version: 4.0.0 +precision: autocast host: "192.168.1.1" port: 8080 """ @@ -141,6 +143,16 @@ def test_migrate_v3_backup(tmp_path: Path, patch_rootdir: None): assert temp_config_file.with_suffix(".yaml.bak").read_text() == v3_config +def test_migrate_v4(tmp_path: Path, patch_rootdir: None): + """Test migration from 4.0.0 to 4.0.1""" + temp_config_file = tmp_path / "temp_invokeai.yaml" + temp_config_file.write_text(v4_config) + + conf = load_and_migrate_config(temp_config_file) + assert Version(conf.schema_version) >= Version("4.0.1") + assert conf.precision == "auto" # we expect 'autocast' to be replaced with 'auto' during 4.0.1 migration + + def test_failed_migrate_backup(tmp_path: Path, patch_rootdir: None): """Test the failed migration of the config file.""" temp_config_file = tmp_path / "temp_invokeai.yaml" @@ -162,13 +174,15 @@ def test_bails_on_invalid_config(tmp_path: Path, patch_rootdir: None): with pytest.raises(AssertionError): load_and_migrate_config(temp_config_file) + @pytest.mark.parametrize("config_content", [invalid_v5_config, invalid_v4_0_1_config]) def test_bails_on_config_with_unsupported_version(tmp_path: Path, patch_rootdir: None, config_content: str): """Test reading configuration from a file.""" temp_config_file = tmp_path / "temp_invokeai.yaml" temp_config_file.write_text(config_content) - with pytest.raises(RuntimeError, match="Invalid schema version"): + # with pytest.raises(RuntimeError, match="Invalid schema version"): + with pytest.raises(RuntimeError): load_and_migrate_config(temp_config_file) From 8144a263deefda6306ae3c6e0aa379cec5c90172 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 24 Apr 2024 22:06:16 -0400 Subject: [PATCH 12/25] updated and reinstated the test_deny_nodes() unit test --- tests/test_config.py | 54 +++++++------------------------------------- 1 file changed, 8 insertions(+), 46 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index 08d28dee473..f1b4b4a6ce7 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -3,10 +3,9 @@ from typing import Any import pytest -from omegaconf import OmegaConf from packaging.version import Version -from pydantic import ValidationError +from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.services.config.config_default import ( DefaultInvokeAIAppConfig, InvokeAIAppConfig, @@ -286,50 +285,10 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch InvokeAIArgs.did_parse = False -@pytest.mark.xfail( - reason=""" - This test fails when run as part of the full test suite. - - This test needs to deny nodes from being included in the InvocationsUnion by providing - an app configuration as a test fixture. Pytest executes all test files before running - tests, so the app configuration is already initialized by the time this test runs, and - the InvocationUnion is already created and the denied nodes are not omitted from it. - - This test passes when `test_config.py` is tested in isolation. - - Perhaps a solution would be to call `get_app_config().parse_args()` in - other test files? - """ -) -def test_deny_nodes(patch_rootdir): - # Allow integer, string and float, but explicitly deny float - allow_deny_nodes_conf = OmegaConf.create( - """ - InvokeAI: - Nodes: - allow_nodes: - - integer - - string - - float - deny_nodes: - - float - """ - ) - # must parse config before importing Graph, so its nodes union uses the config - get_config.cache_clear() - conf = get_config() - get_config.cache_clear() - conf.merge_from_file(conf=allow_deny_nodes_conf, argv=[]) - from invokeai.app.services.shared.graph import Graph - - # confirm graph validation fails when using denied node - Graph(nodes={"1": {"id": "1", "type": "integer"}}) - Graph(nodes={"1": {"id": "1", "type": "string"}}) - - with pytest.raises(ValidationError): - Graph(nodes={"1": {"id": "1", "type": "float"}}) - - from invokeai.app.invocations.baseinvocation import BaseInvocation +def test_deny_nodes(): + config = get_config() + config.allow_nodes = ["integer", "string", "float"] + config.deny_nodes = ["float"] # confirm invocations union will not have denied nodes all_invocations = BaseInvocation.get_invocations() @@ -341,3 +300,6 @@ def test_deny_nodes(patch_rootdir): assert has_integer assert has_string assert not has_float + + # may not be necessary + get_config.cache_clear() From d24877561d827e2f87b9658856e947d86e74a147 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 25 Apr 2024 00:22:09 -0400 Subject: [PATCH 13/25] reinstated failing deny_nodes validation test for Graph --- tests/test_config.py | 42 +++++++++++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index f1b4b4a6ce7..5858bfa47a9 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,9 +1,11 @@ +from contextlib import contextmanager from pathlib import Path from tempfile import TemporaryDirectory from typing import Any import pytest from packaging.version import Version +from pydantic import ValidationError from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.services.config.config_default import ( @@ -12,6 +14,7 @@ get_config, load_and_migrate_config, ) +from invokeai.app.services.shared.graph import Graph from invokeai.frontend.cli.arg_parser import InvokeAIArgs invalid_v4_0_1_config = """ @@ -285,21 +288,34 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch InvokeAIArgs.did_parse = False +@contextmanager +def clear_config(): + try: + yield None + finally: + get_config.cache_clear() + + def test_deny_nodes(): - config = get_config() - config.allow_nodes = ["integer", "string", "float"] - config.deny_nodes = ["float"] + with clear_config(): + config = get_config() + config.allow_nodes = ["integer", "string", "float"] + config.deny_nodes = ["float"] - # confirm invocations union will not have denied nodes - all_invocations = BaseInvocation.get_invocations() + # confirm graph validation fails when using denied node + Graph(nodes={"1": {"id": "1", "type": "integer"}}) + Graph(nodes={"1": {"id": "1", "type": "string"}}) - has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1 - has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1 - has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1 + with pytest.raises(ValidationError): + Graph(nodes={"1": {"id": "1", "type": "float"}}) - assert has_integer - assert has_string - assert not has_float + # confirm invocations union will not have denied nodes + all_invocations = BaseInvocation.get_invocations() - # may not be necessary - get_config.cache_clear() + has_integer = len([i for i in all_invocations if i.model_fields.get("type").default == "integer"]) == 1 + has_string = len([i for i in all_invocations if i.model_fields.get("type").default == "string"]) == 1 + has_float = len([i for i in all_invocations if i.model_fields.get("type").default == "float"]) == 1 + + assert has_integer + assert has_string + assert not has_float From d852ca7a8d269c0d0d707214c73a488d4b569d2e Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 28 Apr 2024 14:31:38 -0400 Subject: [PATCH 14/25] added test for non-contiguous migration routines --- tests/test_config.py | 41 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index 5858bfa47a9..80c7ccc950c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,7 +1,7 @@ from contextlib import contextmanager from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any +from typing import Any, Generator import pytest from packaging.version import Version @@ -9,11 +9,13 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.services.config.config_default import ( + CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, get_config, load_and_migrate_config, ) +from invokeai.app.services.config.config_migrate import ConfigMigrator from invokeai.app.services.shared.graph import Graph from invokeai.frontend.cli.arg_parser import InvokeAIArgs @@ -288,14 +290,49 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch InvokeAIArgs.did_parse = False +def test_migration_check() -> None: + new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"}) + assert new_config is not None + assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + + # Does this execute at compile time or run time? + @ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION, to_version=CONFIG_SCHEMA_VERSION + ".1") + def ok_migration(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"}) + assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".1" + + @ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".2", to_version=CONFIG_SCHEMA_VERSION + ".3") + def bad_migration(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + # Because there is no version for "*.1" => "*.2", this should fail. + with pytest.raises(ValueError): + ConfigMigrator.migrate({"schema_version": "4.0.0"}) + + @ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".1", to_version=CONFIG_SCHEMA_VERSION + ".2") + def good_migration(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + # should work now, because there is a continuous path to *.3 + new_config = ConfigMigrator.migrate(new_config) + assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".3" + + @contextmanager -def clear_config(): +def clear_config() -> Generator[None, None, None]: try: yield None finally: get_config.cache_clear() +@pytest.mark.xfail( + reason=""" + Currently this test is failing due to an issue described in issue #5983. +""" +) def test_deny_nodes(): with clear_config(): config = get_config() From a48abfacf4f3d8cf3c67c35be26a91b9990e78b7 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 2 May 2024 23:45:34 -0400 Subject: [PATCH 15/25] make config migrator into an instance; refactor location of get_config() --- invokeai/app/api_app.py | 2 +- invokeai/app/invocations/__init__.py | 2 +- invokeai/app/invocations/baseinvocation.py | 2 +- invokeai/app/services/config/__init__.py | 3 +- .../app/services/config/config_default.py | 166 ------------------ .../app/services/config/config_migrate.py | 121 +++++++++++-- invokeai/app/services/config/migrations.py | 108 ++++++++++++ .../app/services/shared/invocation_context.py | 2 +- .../app/services/shared/sqlite/sqlite_util.py | 2 +- .../sqlite_migrator/migrations/migration_8.py | 2 +- .../image_util/depth_anything/__init__.py | 2 +- .../image_util/dw_openpose/wholebody.py | 2 +- .../backend/image_util/infill_methods/lama.py | 2 +- .../image_util/infill_methods/patchmatch.py | 2 +- .../backend/image_util/invisible_watermark.py | 2 +- invokeai/backend/image_util/safety_checker.py | 2 +- .../stable_diffusion/diffusers_pipeline.py | 2 +- .../diffusion/shared_invokeai_diffusion.py | 2 +- invokeai/backend/util/devices.py | 2 +- invokeai/backend/util/logging.py | 3 +- tests/test_config.py | 102 ++++++++--- 21 files changed, 314 insertions(+), 219 deletions(-) create mode 100644 invokeai/app/services/config/migrations.py diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index ceaeb95147a..2efd338c1ec 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -26,7 +26,7 @@ import invokeai.frontend.web as web_dir from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles from invokeai.app.invocations.model import ModelIdentifierField -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.backend.util.devices import TorchDevice diff --git a/invokeai/app/invocations/__init__.py b/invokeai/app/invocations/__init__.py index cb1caa167ef..f9b0932b04d 100644 --- a/invokeai/app/invocations/__init__.py +++ b/invokeai/app/invocations/__init__.py @@ -3,7 +3,7 @@ from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config custom_nodes_path = Path(get_config().custom_nodes_path) custom_nodes_path.mkdir(parents=True, exist_ok=True) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 40c7b41caeb..ee4b88fa2e0 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -33,7 +33,7 @@ FieldKind, Input, ) -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.misc import uuid_string diff --git a/invokeai/app/services/config/__init__.py b/invokeai/app/services/config/__init__.py index 126692f08a8..ac154386dae 100644 --- a/invokeai/app/services/config/__init__.py +++ b/invokeai/app/services/config/__init__.py @@ -2,6 +2,7 @@ from invokeai.app.services.config.config_common import PagingArgumentParser -from .config_default import InvokeAIAppConfig, get_config +from .config_default import InvokeAIAppConfig +from .config_migrate import get_config __all__ = ["InvokeAIAppConfig", "get_config", "PagingArgumentParser"] diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 0aedb54a377..06d0ea27831 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -3,11 +3,8 @@ from __future__ import annotations -import locale import os import re -import shutil -from functools import lru_cache from pathlib import Path from typing import Any, Literal, Optional @@ -16,11 +13,7 @@ from pydantic import BaseModel, Field, PrivateAttr, field_validator from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict -import invokeai.configs as model_configs from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS -from invokeai.frontend.cli.arg_parser import InvokeAIArgs - -from .config_migrate import ConfigMigrator INIT_FILE = Path("invokeai.yaml") DB_FILE = Path("invokeai.db") @@ -348,162 +341,3 @@ def settings_customise_sources( file_secret_settings: PydanticBaseSettingsSource, ) -> tuple[PydanticBaseSettingsSource, ...]: return (init_settings,) - - -def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: - """Load and migrate a config file to the latest version. - - Args: - config_path: Path to the config file. - - Returns: - An instance of `InvokeAIAppConfig` with the loaded and migrated settings. - """ - assert config_path.suffix == ".yaml" - with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file: - loaded_config_dict = yaml.safe_load(file) - - assert isinstance(loaded_config_dict, dict) - - shutil.copy(config_path, config_path.with_suffix(".yaml.bak")) - try: - # loaded_config_dict could be the wrong shape, but we will catch all exceptions below - migrated_config_dict = ConfigMigrator.migrate(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType] - except Exception as e: - shutil.copy(config_path.with_suffix(".yaml.bak"), config_path) - raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e - - # Attempt to load as a v4 config file - try: - config = InvokeAIAppConfig.model_validate(migrated_config_dict) - assert ( - config.schema_version == CONFIG_SCHEMA_VERSION - ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}" - return config - except Exception as e: - raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e - - -@lru_cache(maxsize=1) -def get_config() -> InvokeAIAppConfig: - """Get the global singleton app config. - - When first called, this function: - - Creates a config object. `pydantic-settings` handles merging of settings from environment variables, but not the init file. - - Retrieves any provided CLI args from the InvokeAIArgs class. It does not _parse_ the CLI args; that is done in the main entrypoint. - - Sets the root dir, if provided via CLI args. - - Logs in to HF if there is no valid token already. - - Copies all legacy configs to the legacy conf dir (needed for conversion from ckpt to diffusers). - - Reads and merges in settings from the config file if it exists, else writes out a default config file. - - On subsequent calls, the object is returned from the cache. - """ - # This object includes environment variables, as parsed by pydantic-settings - config = InvokeAIAppConfig() - - args = InvokeAIArgs.args - - # This flag serves as a proxy for whether the config was retrieved in the context of the full application or not. - # If it is False, we should just return a default config and not set the root, log in to HF, etc. - if not InvokeAIArgs.did_parse: - return config - - # Set CLI args - if root := getattr(args, "root", None): - config._root = Path(root) - if config_file := getattr(args, "config_file", None): - config._config_file = Path(config_file) - - # Create the example config file, with some extra example values provided - example_config = DefaultInvokeAIAppConfig() - example_config.remote_api_tokens = [ - URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"), - URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"), - ] - example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True) - - # Copy all legacy configs - We know `__path__[0]` is correct here - configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue] - shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True) - - if config.config_file_path.exists(): - config_from_file = load_and_migrate_config(config.config_file_path) - config_from_file.write_file(config.config_file_path) - # Clobbering here will overwrite any settings that were set via environment variables - config.update_config(config_from_file, clobber=False) - else: - # We should never write env vars to the config file - default_config = DefaultInvokeAIAppConfig() - default_config.write_file(config.config_file_path, as_example=False) - - return config - - -#################################################### -# VERSION MIGRATIONS -#################################################### - - -@ConfigMigrator.register(from_version="3.0.0", to_version="4.0.0") -def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]: - """Migrate a v3 config dictionary to a current config object. - - Args: - config_dict: A dictionary of settings from a v3 config file. - - Returns: - A dictionary of settings from a 4.0.0 config file. - - """ - parsed_config_dict: dict[str, Any] = {} - for _category_name, category_dict in config_dict["InvokeAI"].items(): - for k, v in category_dict.items(): - # `outdir` was renamed to `outputs_dir` in v4 - if k == "outdir": - parsed_config_dict["outputs_dir"] = v - # `max_cache_size` was renamed to `ram` some time in v3, but both names were used - if k == "max_cache_size" and "ram" not in category_dict: - parsed_config_dict["ram"] = v - # `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used - if k == "max_vram_cache_size" and "vram" not in category_dict: - parsed_config_dict["vram"] = v - # autocast was removed in v4.0.1 - if k == "precision" and v == "autocast": - parsed_config_dict["precision"] = "auto" - if k == "conf_path": - parsed_config_dict["legacy_models_yaml_path"] = v - if k == "legacy_conf_dir": - # The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows). - if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion": - # If if the incoming config has the default value, skip - continue - elif Path(v).name == "stable-diffusion": - # Else if the path ends in "stable-diffusion", we assume the parent is the new correct path. - parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent) - else: - # Else we do not attempt to migrate this setting - parsed_config_dict["legacy_conf_dir"] = v - elif k in InvokeAIAppConfig.model_fields: - # skip unknown fields - parsed_config_dict[k] = v - return parsed_config_dict - - -@ConfigMigrator.register(from_version="4.0.0", to_version="4.0.1") -def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]: - """Migrate v4.0.0 config dictionary to v4.0.1. - - Args: - config_dict: A dictionary of settings from a v4.0.0 config file. - - Returns: - A dictionary of settings from a v4.0.1 config file - """ - parsed_config_dict: dict[str, Any] = {} - for k, v in config_dict.items(): - # autocast was removed from precision in v4.0.1 - if k == "precision" and v == "autocast": - parsed_config_dict["precision"] = "auto" - else: - parsed_config_dict[k] = v - return parsed_config_dict diff --git a/invokeai/app/services/config/config_migrate.py b/invokeai/app/services/config/config_migrate.py index b3fe979d37a..899fc43f398 100644 --- a/invokeai/app/services/config/config_migrate.py +++ b/invokeai/app/services/config/config_migrate.py @@ -4,12 +4,22 @@ Utility class for migrating among versions of the InvokeAI app config schema. """ +import locale +import shutil from dataclasses import dataclass -from typing import Any, Callable, List, TypeAlias +from functools import lru_cache +from pathlib import Path +from typing import Callable, List, TypeAlias +import yaml from packaging.version import Version -AppConfigDict: TypeAlias = dict[str, Any] +import invokeai.configs as model_configs +from invokeai.frontend.cli.arg_parser import InvokeAIArgs + +from .config_default import CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, URLRegexTokenPair +from .migrations import AppConfigDict, Migrations, MigrationsBase + MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict] @@ -25,22 +35,23 @@ class MigrationEntry: class ConfigMigrator: """This class allows migrators to register their input and output versions.""" - _migrations: List[MigrationEntry] = [] + def __init__(self, migrations: type[MigrationsBase]) -> None: + self._migrations: List[MigrationEntry] = [] + migrations.load(self) - @classmethod def register( - cls, + self, from_version: str, to_version: str, ) -> Callable[[MigrationFunction], MigrationFunction]: """Define a decorator which registers the migration between two versions.""" def decorator(function: MigrationFunction) -> MigrationFunction: - if any(from_version == m.from_version for m in cls._migrations): + if any((from_version == m.from_version) or (to_version == m.to_version) for m in self._migrations): raise ValueError( f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered." ) - cls._migrations.append( + self._migrations.append( MigrationEntry(from_version=Version(from_version), to_version=Version(to_version), function=function) ) return function @@ -57,8 +68,7 @@ def _check_for_discontinuities(migrations: List[MigrationEntry]) -> None: ) current_version = m.to_version - @classmethod - def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict: + def run_migrations(self, config_dict: AppConfigDict) -> AppConfigDict: """ Use the registered migration steps to bring config up to latest version. @@ -72,8 +82,8 @@ def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict: """ # Sort migrations by version number and raise a ValueError if # any version range overlaps are detected. - sorted_migrations = sorted(cls._migrations, key=lambda x: x.from_version) - cls._check_for_discontinuities(sorted_migrations) + sorted_migrations = sorted(self._migrations, key=lambda x: x.from_version) + self._check_for_discontinuities(sorted_migrations) if "InvokeAI" in config_dict: version = Version("3.0.0") @@ -87,3 +97,92 @@ def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict: config_dict["schema_version"] = str(version) return config_dict + + +@lru_cache(maxsize=1) +def get_config() -> InvokeAIAppConfig: + """Get the global singleton app config. + + When first called, this function: + - Creates a config object. `pydantic-settings` handles merging of settings from environment variables, but not the init file. + - Retrieves any provided CLI args from the InvokeAIArgs class. It does not _parse_ the CLI args; that is done in the main entrypoint. + - Sets the root dir, if provided via CLI args. + - Logs in to HF if there is no valid token already. + - Copies all legacy configs to the legacy conf dir (needed for conversion from ckpt to diffusers). + - Reads and merges in settings from the config file if it exists, else writes out a default config file. + + On subsequent calls, the object is returned from the cache. + """ + # This object includes environment variables, as parsed by pydantic-settings + config = InvokeAIAppConfig() + + args = InvokeAIArgs.args + + # This flag serves as a proxy for whether the config was retrieved in the context of the full application or not. + # If it is False, we should just return a default config and not set the root, log in to HF, etc. + if not InvokeAIArgs.did_parse: + return config + + # Set CLI args + if root := getattr(args, "root", None): + config._root = Path(root) + if config_file := getattr(args, "config_file", None): + config._config_file = Path(config_file) + + # Create the example config file, with some extra example values provided + example_config = DefaultInvokeAIAppConfig() + example_config.remote_api_tokens = [ + URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"), + URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"), + ] + example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True) + + # Copy all legacy configs - We know `__path__[0]` is correct here + configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue] + shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True) + + if config.config_file_path.exists(): + config_from_file = load_and_migrate_config(config.config_file_path) + config_from_file.write_file(config.config_file_path) + # Clobbering here will overwrite any settings that were set via environment variables + config.update_config(config_from_file, clobber=False) + else: + # We should never write env vars to the config file + default_config = DefaultInvokeAIAppConfig() + default_config.write_file(config.config_file_path, as_example=False) + + return config + + +def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: + """Load and migrate a config file to the latest version. + + Args: + config_path: Path to the config file. + + Returns: + An instance of `InvokeAIAppConfig` with the loaded and migrated settings. + """ + assert config_path.suffix == ".yaml" + with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file: + loaded_config_dict = yaml.safe_load(file) + + assert isinstance(loaded_config_dict, dict) + + shutil.copy(config_path, config_path.with_suffix(".yaml.bak")) + try: + migrator = ConfigMigrator(Migrations) + migrated_config_dict = migrator.run_migrations(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType] + except Exception as e: + shutil.copy(config_path.with_suffix(".yaml.bak"), config_path) + raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e + + # Attempt to load as a v4 config file + try: + config = InvokeAIAppConfig.model_validate(migrated_config_dict) + assert ( + config.schema_version == CONFIG_SCHEMA_VERSION + ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}" + return config + except Exception as e: + raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e diff --git a/invokeai/app/services/config/migrations.py b/invokeai/app/services/config/migrations.py new file mode 100644 index 00000000000..4c6996d7db3 --- /dev/null +++ b/invokeai/app/services/config/migrations.py @@ -0,0 +1,108 @@ +# Copyright 2024 Lincoln D. Stein and the InvokeAI Development Team + +""" +Schema migrations to perform on an InvokeAIAppConfig object. + +The Migrations class defined in this module defines a series of +schema version migration steps for the InvokeAIConfig object. + +To define a new migration, add a migration function to +Migrations.load_migrations() following the existing examples. +""" + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import TYPE_CHECKING, Any, TypeAlias + +from .config_default import InvokeAIAppConfig + +if TYPE_CHECKING: + from .config_migrate import ConfigMigrator + +AppConfigDict: TypeAlias = dict[str, Any] + + +class MigrationsBase(ABC): + """Define the config file migration steps to apply, abstract base class.""" + + @classmethod + @abstractmethod + def load(cls, migrator: "ConfigMigrator") -> None: + """Use the provided migrator to register the configuration migrations to be run.""" + + +class Migrations(MigrationsBase): + """Configuration migration steps to apply.""" + + @classmethod + def load(cls, migrator: "ConfigMigrator") -> None: + """Define migrations to perform.""" + + ################## + # 3.0.0 -> 4.0.0 # + ################## + @migrator.register(from_version="3.0.0", to_version="4.0.0") + def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]: + """Migrate a v3 config dictionary to a current config object. + + Args: + config_dict: A dictionary of settings from a v3 config file. + + Returns: + A dictionary of settings from a 4.0.0 config file. + + """ + parsed_config_dict: dict[str, Any] = {} + for _category_name, category_dict in config_dict["InvokeAI"].items(): + for k, v in category_dict.items(): + # `outdir` was renamed to `outputs_dir` in v4 + if k == "outdir": + parsed_config_dict["outputs_dir"] = v + # `max_cache_size` was renamed to `ram` some time in v3, but both names were used + if k == "max_cache_size" and "ram" not in category_dict: + parsed_config_dict["ram"] = v + # `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used + if k == "max_vram_cache_size" and "vram" not in category_dict: + parsed_config_dict["vram"] = v + # autocast was removed in v4.0.1 + if k == "precision" and v == "autocast": + parsed_config_dict["precision"] = "auto" + if k == "conf_path": + parsed_config_dict["legacy_models_yaml_path"] = v + if k == "legacy_conf_dir": + # The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows). + if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion": + # If if the incoming config has the default value, skip + continue + elif Path(v).name == "stable-diffusion": + # Else if the path ends in "stable-diffusion", we assume the parent is the new correct path. + parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent) + else: + # Else we do not attempt to migrate this setting + parsed_config_dict["legacy_conf_dir"] = v + elif k in InvokeAIAppConfig.model_fields: + # skip unknown fields + parsed_config_dict[k] = v + return parsed_config_dict + + ################## + # 4.0.0 -> 4.0.1 # + ################## + @migrator.register(from_version="4.0.0", to_version="4.0.1") + def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]: + """Migrate v4.0.0 config dictionary to v4.0.1. + + Args: + config_dict: A dictionary of settings from a v4.0.0 config file. + + Returns: + A dictionary of settings from a v4.0.1 config file + """ + parsed_config_dict: dict[str, Any] = {} + for k, v in config_dict.items(): + # autocast was removed from precision in v4.0.1 + if k == "precision" and v == "autocast": + parsed_config_dict["precision"] = "auto" + else: + parsed_config_dict[k] = v + return parsed_config_dict diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 9994d663e5e..68f7b11bcda 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -9,7 +9,7 @@ from invokeai.app.invocations.constants import IMAGE_MODES from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata from invokeai.app.services.boards.boards_common import BoardDTO -from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.invocation_services import InvocationServices diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 1eed0b44092..2dcaaa8aedd 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -1,6 +1,6 @@ from logging import Logger -from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1 diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_8.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_8.py index 154a5236cae..4fb8cf46ef4 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_8.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_8.py @@ -1,7 +1,7 @@ import sqlite3 from pathlib import Path -from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration diff --git a/invokeai/backend/image_util/depth_anything/__init__.py b/invokeai/backend/image_util/depth_anything/__init__.py index c854fba3f23..ac3692cec80 100644 --- a/invokeai/backend/image_util/depth_anything/__init__.py +++ b/invokeai/backend/image_util/depth_anything/__init__.py @@ -9,7 +9,7 @@ from PIL import Image from torchvision.transforms import Compose -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config from invokeai.app.util.download_with_progress import download_with_progress_bar from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize diff --git a/invokeai/backend/image_util/dw_openpose/wholebody.py b/invokeai/backend/image_util/dw_openpose/wholebody.py index 84f5afa989e..750f480fb23 100644 --- a/invokeai/backend/image_util/dw_openpose/wholebody.py +++ b/invokeai/backend/image_util/dw_openpose/wholebody.py @@ -5,7 +5,7 @@ import numpy as np import onnxruntime as ort -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config from invokeai.app.util.download_with_progress import download_with_progress_bar from invokeai.backend.util.devices import TorchDevice diff --git a/invokeai/backend/image_util/infill_methods/lama.py b/invokeai/backend/image_util/infill_methods/lama.py index 4268ec773d4..247dcb83c4e 100644 --- a/invokeai/backend/image_util/infill_methods/lama.py +++ b/invokeai/backend/image_util/infill_methods/lama.py @@ -6,7 +6,7 @@ from PIL import Image import invokeai.backend.util.logging as logger -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config from invokeai.app.util.download_with_progress import download_with_progress_bar from invokeai.backend.util.devices import TorchDevice diff --git a/invokeai/backend/image_util/infill_methods/patchmatch.py b/invokeai/backend/image_util/infill_methods/patchmatch.py index 7e9cdf8fa41..0ed05a1ab66 100644 --- a/invokeai/backend/image_util/infill_methods/patchmatch.py +++ b/invokeai/backend/image_util/infill_methods/patchmatch.py @@ -9,7 +9,7 @@ from PIL import Image import invokeai.backend.util.logging as logger -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config class PatchMatch: diff --git a/invokeai/backend/image_util/invisible_watermark.py b/invokeai/backend/image_util/invisible_watermark.py index 84342e442fc..18c87dbffd2 100644 --- a/invokeai/backend/image_util/invisible_watermark.py +++ b/invokeai/backend/image_util/invisible_watermark.py @@ -10,7 +10,7 @@ from PIL import Image import invokeai.backend.util.logging as logger -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config config = get_config() diff --git a/invokeai/backend/image_util/safety_checker.py b/invokeai/backend/image_util/safety_checker.py index 60dcd93fcc5..afc51c241a2 100644 --- a/invokeai/backend/image_util/safety_checker.py +++ b/invokeai/backend/image_util/safety_checker.py @@ -12,7 +12,7 @@ from transformers import AutoFeatureExtractor import invokeai.backend.util.logging as logger -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.silence_warnings import SilenceWarnings diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 8b90c815ae7..cac3f58d08a 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -20,7 +20,7 @@ from pydantic import Field from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index f418133e49f..c3552088310 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -6,7 +6,7 @@ import torch from typing_extensions import TypeAlias -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( IPAdapterData, Range, diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index e8380dc8bcd..b7320bc9f05 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -3,7 +3,7 @@ import torch from deprecated import deprecated -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config # legacy APIs TorchPrecisionNames = Literal["float32", "float16", "bfloat16"] diff --git a/invokeai/backend/util/logging.py b/invokeai/backend/util/logging.py index 968604eb3d9..217a1a6cb87 100644 --- a/invokeai/backend/util/logging.py +++ b/invokeai/backend/util/logging.py @@ -180,8 +180,7 @@ from pathlib import Path from typing import Any, Dict, Optional -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import InvokeAIAppConfig, get_config try: import syslog diff --git a/tests/test_config.py b/tests/test_config.py index 80c7ccc950c..9b3c140e5e0 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -12,10 +12,9 @@ CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, - get_config, - load_and_migrate_config, ) -from invokeai.app.services.config.config_migrate import ConfigMigrator +from invokeai.app.services.config.config_migrate import ConfigMigrator, get_config, load_and_migrate_config +from invokeai.app.services.config.migrations import Migrations, MigrationsBase from invokeai.app.services.shared.graph import Graph from invokeai.frontend.cli.arg_parser import InvokeAIArgs @@ -72,6 +71,68 @@ """ +class GoodMigrations(MigrationsBase): + @classmethod + def load(cls, migrator: ConfigMigrator) -> None: + """Define migrations to perform.""" + + @migrator.register(from_version="3.0.0", to_version="10.0.0") + def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + @migrator.register(from_version="10.0.0", to_version="10.0.1") + def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + @migrator.register(from_version="10.0.1", to_version="10.0.2") + def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + +class BadMigrations1(MigrationsBase): + """This one fails because there is no path from 10.0.1 to 10.0.2""" + + @classmethod + def load(cls, migrator: ConfigMigrator) -> None: + """Define migrations to perform.""" + + @migrator.register(from_version="3.0.0", to_version="10.0.0") + def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + @migrator.register(from_version="10.0.0", to_version="10.0.1") + def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + @migrator.register(from_version="10.0.2", to_version="10.0.3") + def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + +class BadMigrations2(MigrationsBase): + """This one fails because the path to 10.0.2 is registered twice""" + + @classmethod + def load(cls, migrator: ConfigMigrator) -> None: + """Define migrations to perform.""" + + @migrator.register(from_version="3.0.0", to_version="10.0.0") + def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + @migrator.register(from_version="10.0.0", to_version="10.0.1") + def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + @migrator.register(from_version="10.0.1", to_version="10.0.2") + def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + @migrator.register(from_version="10.0.0", to_version="10.0.2") + def migration_4(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + @pytest.fixture def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None: """This may be overkill since the current tests don't need the root dir to exist""" @@ -291,33 +352,26 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch def test_migration_check() -> None: - new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"}) + # Test the default set of migrations + migrator = ConfigMigrator(Migrations) + new_config = migrator.run_migrations({"schema_version": "4.0.0"}) assert new_config is not None assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION - # Does this execute at compile time or run time? - @ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION, to_version=CONFIG_SCHEMA_VERSION + ".1") - def ok_migration(config_dict: dict[str, Any]) -> dict[str, Any]: - return config_dict - - new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"}) - assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".1" - - @ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".2", to_version=CONFIG_SCHEMA_VERSION + ".3") - def bad_migration(config_dict: dict[str, Any]) -> dict[str, Any]: - return config_dict + # Test a custom set of migrations + migrator = ConfigMigrator(GoodMigrations) + new_config = migrator.run_migrations({"schema_version": "10.0.0"}) + assert new_config["schema_version"] == "10.0.2" - # Because there is no version for "*.1" => "*.2", this should fail. + # Test a migration that should fail validation + migrator = ConfigMigrator(BadMigrations1) with pytest.raises(ValueError): - ConfigMigrator.migrate({"schema_version": "4.0.0"}) + new_config = migrator.run_migrations({"schema_version": "10.0.0"}) - @ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".1", to_version=CONFIG_SCHEMA_VERSION + ".2") - def good_migration(config_dict: dict[str, Any]) -> dict[str, Any]: - return config_dict - - # should work now, because there is a continuous path to *.3 - new_config = ConfigMigrator.migrate(new_config) - assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".3" + # Test another bad migration + migrator = ConfigMigrator(BadMigrations2) + with pytest.raises(ValueError): + new_config = migrator.run_migrations({"schema_version": "10.0.0"}) @contextmanager From 2dd42d09177cada83ccde9d9d50e3aa872ed81f3 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 3 May 2024 06:44:03 -0400 Subject: [PATCH 16/25] check that right no. of migration steps run --- tests/test_config.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_config.py b/tests/test_config.py index 9b3c140e5e0..e274344c849 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -72,20 +72,25 @@ class GoodMigrations(MigrationsBase): + methods_run: int = 0 + @classmethod def load(cls, migrator: ConfigMigrator) -> None: """Define migrations to perform.""" @migrator.register(from_version="3.0.0", to_version="10.0.0") def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]: + cls.methods_run += 1 return config_dict @migrator.register(from_version="10.0.0", to_version="10.0.1") def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]: + cls.methods_run += 1 return config_dict @migrator.register(from_version="10.0.1", to_version="10.0.2") def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]: + cls.methods_run += 1 return config_dict @@ -359,9 +364,17 @@ def test_migration_check() -> None: assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION # Test a custom set of migrations + GoodMigrations.methods_run = 0 migrator = ConfigMigrator(GoodMigrations) new_config = migrator.run_migrations({"schema_version": "10.0.0"}) assert new_config["schema_version"] == "10.0.2" + assert GoodMigrations.methods_run == 2 + + GoodMigrations.methods_run = 0 + migrator = ConfigMigrator(GoodMigrations) + new_config = migrator.run_migrations({"schema_version": "3.0.0"}) + assert new_config["schema_version"] == "10.0.2" + assert GoodMigrations.methods_run == 3 # Test a migration that should fail validation migrator = ConfigMigrator(BadMigrations1) From fc23b16a73d4be288a90ccfbc4ddc601eac25675 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Fri, 3 May 2024 06:49:16 -0400 Subject: [PATCH 17/25] add more checking of migration step operations --- tests/test_config.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_config.py b/tests/test_config.py index e274344c849..e66962308a0 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -81,16 +81,19 @@ def load(cls, migrator: ConfigMigrator) -> None: @migrator.register(from_version="3.0.0", to_version="10.0.0") def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]: cls.methods_run += 1 + config_dict["migration_1"] = True return config_dict @migrator.register(from_version="10.0.0", to_version="10.0.1") def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]: cls.methods_run += 1 + config_dict["migration_2"] = True return config_dict @migrator.register(from_version="10.0.1", to_version="10.0.2") def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]: cls.methods_run += 1 + config_dict["migration_3"] = True return config_dict @@ -369,12 +372,15 @@ def test_migration_check() -> None: new_config = migrator.run_migrations({"schema_version": "10.0.0"}) assert new_config["schema_version"] == "10.0.2" assert GoodMigrations.methods_run == 2 + assert new_config.get("migration_2") + assert not new_config.get("migration_1") GoodMigrations.methods_run = 0 migrator = ConfigMigrator(GoodMigrations) new_config = migrator.run_migrations({"schema_version": "3.0.0"}) assert new_config["schema_version"] == "10.0.2" assert GoodMigrations.methods_run == 3 + assert all(new_config[x] for x in ["migration_1", "migration_2", "migration_3"]) # Test a migration that should fail validation migrator = ConfigMigrator(BadMigrations1) From 6946a3871f86ba855bda6db449a05eecdc31478d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 14 May 2024 16:20:40 +1000 Subject: [PATCH 18/25] feat(config): simplify config migrator logic - Remove `Migrations` class - unnecessary complexity on top of `MigrationEntry` - Move common classes to `config_common` - Tidy docstrings, variable names --- invokeai/app/services/config/config_common.py | 22 +++ .../app/services/config/config_migrate.py | 91 ++++----- invokeai/app/services/config/migrations.py | 173 +++++++++--------- 3 files changed, 141 insertions(+), 145 deletions(-) diff --git a/invokeai/app/services/config/config_common.py b/invokeai/app/services/config/config_common.py index 0765b93f2cf..0339c1bff25 100644 --- a/invokeai/app/services/config/config_common.py +++ b/invokeai/app/services/config/config_common.py @@ -12,6 +12,10 @@ import argparse import pydoc +from dataclasses import dataclass +from typing import Any, Callable, TypeAlias + +from packaging.version import Version class PagingArgumentParser(argparse.ArgumentParser): @@ -23,3 +27,21 @@ class PagingArgumentParser(argparse.ArgumentParser): def print_help(self, file=None) -> None: text = self.format_help() pydoc.pager(text) + + +AppConfigDict: TypeAlias = dict[str, Any] + +MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict] + + +@dataclass +class MigrationEntry: + """Defines an individual migration.""" + + from_version: Version + to_version: Version + function: MigrationFunction + + def __hash__(self) -> int: + # Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set. + return hash((self.from_version, self.to_version)) diff --git a/invokeai/app/services/config/config_migrate.py b/invokeai/app/services/config/config_migrate.py index 899fc43f398..77d1ce85553 100644 --- a/invokeai/app/services/config/config_migrate.py +++ b/invokeai/app/services/config/config_migrate.py @@ -6,60 +6,38 @@ import locale import shutil -from dataclasses import dataclass +from copy import deepcopy from functools import lru_cache from pathlib import Path -from typing import Callable, List, TypeAlias import yaml from packaging.version import Version import invokeai.configs as model_configs +from invokeai.app.services.config.config_common import AppConfigDict, MigrationEntry +from invokeai.app.services.config.migrations import config_migration_1, config_migration_2 from invokeai.frontend.cli.arg_parser import InvokeAIArgs from .config_default import CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, URLRegexTokenPair -from .migrations import AppConfigDict, Migrations, MigrationsBase - -MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict] - - -@dataclass -class MigrationEntry: - """Defines an individual migration.""" - - from_version: Version - to_version: Version - function: MigrationFunction class ConfigMigrator: """This class allows migrators to register their input and output versions.""" - def __init__(self, migrations: type[MigrationsBase]) -> None: - self._migrations: List[MigrationEntry] = [] - migrations.load(self) - - def register( - self, - from_version: str, - to_version: str, - ) -> Callable[[MigrationFunction], MigrationFunction]: - """Define a decorator which registers the migration between two versions.""" + def __init__(self) -> None: + self._migrations: set[MigrationEntry] = set() - def decorator(function: MigrationFunction) -> MigrationFunction: - if any((from_version == m.from_version) or (to_version == m.to_version) for m in self._migrations): - raise ValueError( - f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered." - ) - self._migrations.append( - MigrationEntry(from_version=Version(from_version), to_version=Version(to_version), function=function) + def register(self, migration: MigrationEntry) -> None: + migration_from_already_registered = any(m.from_version == migration.from_version for m in self._migrations) + migration_to_already_registered = any(m.to_version == migration.to_version for m in self._migrations) + if migration_from_already_registered or migration_to_already_registered: + raise ValueError( + f"A migration from {migration.from_version} or to {migration.to_version} has already been registered." ) - return function - - return decorator + self._migrations.add(migration) @staticmethod - def _check_for_discontinuities(migrations: List[MigrationEntry]) -> None: + def _check_for_discontinuities(migrations: list[MigrationEntry]) -> None: current_version = Version("3.0.0") for m in migrations: if current_version != m.from_version: @@ -68,35 +46,38 @@ def _check_for_discontinuities(migrations: List[MigrationEntry]) -> None: ) current_version = m.to_version - def run_migrations(self, config_dict: AppConfigDict) -> AppConfigDict: + def run_migrations(self, original_config: AppConfigDict) -> AppConfigDict: """ - Use the registered migration steps to bring config up to latest version. + Use the registered migrations to bring config up to latest version. - :param config: The original configuration. - :return: The new configuration, lifted up to the latest version. + Args: + original_config: The original configuration. - As a side effect, the new configuration will be written to disk. - If an inconsistency in the registered migration steps' `from_version` - and `to_version` parameters are identified, this will raise a - ValueError exception. + Returns: + The new configuration, lifted up to the latest version. """ - # Sort migrations by version number and raise a ValueError if - # any version range overlaps are detected. + + # Sort migrations by version number and raise a ValueError if any version range overlaps are detected. sorted_migrations = sorted(self._migrations, key=lambda x: x.from_version) self._check_for_discontinuities(sorted_migrations) - if "InvokeAI" in config_dict: + # Do not mutate the incoming dict - we don't know who else may be using it + migrated_config = deepcopy(original_config) + + # v3.0.0 configs did not have "schema_version", but did have "InvokeAI" + if "InvokeAI" in migrated_config: version = Version("3.0.0") else: - version = Version(config_dict["schema_version"]) + version = Version(migrated_config["schema_version"]) for migration in sorted_migrations: - if version == migration.from_version and version < migration.to_version: - config_dict = migration.function(config_dict) + if version == migration.from_version: + migrated_config = migration.function(migrated_config) version = migration.to_version - config_dict["schema_version"] = str(version) - return config_dict + # We must end on the latest version + assert migrated_config["schema_version"] == str(sorted_migrations[-1].to_version) + return migrated_config @lru_cache(maxsize=1) @@ -165,14 +146,16 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: """ assert config_path.suffix == ".yaml" with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file: - loaded_config_dict = yaml.safe_load(file) + loaded_config_dict: AppConfigDict = yaml.safe_load(file) assert isinstance(loaded_config_dict, dict) shutil.copy(config_path, config_path.with_suffix(".yaml.bak")) try: - migrator = ConfigMigrator(Migrations) - migrated_config_dict = migrator.run_migrations(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType] + migrator = ConfigMigrator() + migrator.register(config_migration_1) + migrator.register(config_migration_2) + migrated_config_dict = migrator.run_migrations(loaded_config_dict) except Exception as e: shutil.copy(config_path.with_suffix(".yaml.bak"), config_path) raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e diff --git a/invokeai/app/services/config/migrations.py b/invokeai/app/services/config/migrations.py index 4c6996d7db3..44b4e766b60 100644 --- a/invokeai/app/services/config/migrations.py +++ b/invokeai/app/services/config/migrations.py @@ -10,99 +10,90 @@ Migrations.load_migrations() following the existing examples. """ -from abc import ABC, abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Any, TypeAlias + +from packaging.version import Version + +from invokeai.app.services.config.config_common import AppConfigDict, MigrationEntry from .config_default import InvokeAIAppConfig -if TYPE_CHECKING: - from .config_migrate import ConfigMigrator - -AppConfigDict: TypeAlias = dict[str, Any] - - -class MigrationsBase(ABC): - """Define the config file migration steps to apply, abstract base class.""" - - @classmethod - @abstractmethod - def load(cls, migrator: "ConfigMigrator") -> None: - """Use the provided migrator to register the configuration migrations to be run.""" - - -class Migrations(MigrationsBase): - """Configuration migration steps to apply.""" - - @classmethod - def load(cls, migrator: "ConfigMigrator") -> None: - """Define migrations to perform.""" - - ################## - # 3.0.0 -> 4.0.0 # - ################## - @migrator.register(from_version="3.0.0", to_version="4.0.0") - def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]: - """Migrate a v3 config dictionary to a current config object. - - Args: - config_dict: A dictionary of settings from a v3 config file. - - Returns: - A dictionary of settings from a 4.0.0 config file. - - """ - parsed_config_dict: dict[str, Any] = {} - for _category_name, category_dict in config_dict["InvokeAI"].items(): - for k, v in category_dict.items(): - # `outdir` was renamed to `outputs_dir` in v4 - if k == "outdir": - parsed_config_dict["outputs_dir"] = v - # `max_cache_size` was renamed to `ram` some time in v3, but both names were used - if k == "max_cache_size" and "ram" not in category_dict: - parsed_config_dict["ram"] = v - # `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used - if k == "max_vram_cache_size" and "vram" not in category_dict: - parsed_config_dict["vram"] = v - # autocast was removed in v4.0.1 - if k == "precision" and v == "autocast": - parsed_config_dict["precision"] = "auto" - if k == "conf_path": - parsed_config_dict["legacy_models_yaml_path"] = v - if k == "legacy_conf_dir": - # The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows). - if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion": - # If if the incoming config has the default value, skip - continue - elif Path(v).name == "stable-diffusion": - # Else if the path ends in "stable-diffusion", we assume the parent is the new correct path. - parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent) - else: - # Else we do not attempt to migrate this setting - parsed_config_dict["legacy_conf_dir"] = v - elif k in InvokeAIAppConfig.model_fields: - # skip unknown fields - parsed_config_dict[k] = v - return parsed_config_dict - - ################## - # 4.0.0 -> 4.0.1 # - ################## - @migrator.register(from_version="4.0.0", to_version="4.0.1") - def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]: - """Migrate v4.0.0 config dictionary to v4.0.1. - - Args: - config_dict: A dictionary of settings from a v4.0.0 config file. - - Returns: - A dictionary of settings from a v4.0.1 config file - """ - parsed_config_dict: dict[str, Any] = {} - for k, v in config_dict.items(): - # autocast was removed from precision in v4.0.1 - if k == "precision" and v == "autocast": - parsed_config_dict["precision"] = "auto" + +def migrate_v300_to_v400(original_config: AppConfigDict) -> AppConfigDict: + """Migrate a v3.0.0 config dict to v4.0.0. + + Changes in this migration: + - `outdir` was renamed to `outputs_dir` + - `max_cache_size` was renamed to `ram` + - `max_vram_cache_size` was renamed to `vram` + - `conf_path`, which pointed to the old `models.yaml`, was removed - but if need to stash it to migrate the entries + to the database + - `legacy_conf_dir` was changed from a path relative to the app root, to a path relative to $INVOKEAI_ROOT/configs + + Args: + config_dict: The v3.0.0 config dict to migrate. + + Returns: + The migrated v4.0.0 config dict. + """ + migrated_config: AppConfigDict = {} + for _category_name, category_dict in original_config["InvokeAI"].items(): + for k, v in category_dict.items(): + # `outdir` was renamed to `outputs_dir` in v4 + if k == "outdir": + migrated_config["outputs_dir"] = v + # `max_cache_size` was renamed to `ram` some time in v3, but both names were used + if k == "max_cache_size" and "ram" not in category_dict: + migrated_config["ram"] = v + # `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used + if k == "max_vram_cache_size" and "vram" not in category_dict: + migrated_config["vram"] = v + if k == "conf_path": + migrated_config["legacy_models_yaml_path"] = v + if k == "legacy_conf_dir": + # The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows). + if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion": + # If if the incoming config has the default value, skip + continue + elif Path(v).name == "stable-diffusion": + # Else if the path ends in "stable-diffusion", we assume the parent is the new correct path. + migrated_config["legacy_conf_dir"] = str(Path(v).parent) else: - parsed_config_dict[k] = v - return parsed_config_dict + # Else we do not attempt to migrate this setting + migrated_config["legacy_conf_dir"] = v + elif k in InvokeAIAppConfig.model_fields: + # skip unknown fields + migrated_config[k] = v + migrated_config["schema_version"] = "4.0.0" + return migrated_config + + +config_migration_1 = MigrationEntry( + from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migrate_v300_to_v400 +) + + +def migrate_v400_to_v401(original_config: AppConfigDict) -> AppConfigDict: + """Migrate a v4.0.0 config dict to v4.0.1. + + Changes in this migration: + - `precision: "autocast"` was removed, fall back to "auto" + + Args: + config_dict: The v4.0.0 config dict to migrate. + + Returns: + The migrated v4.0.1 config dict. + """ + migrated_config: AppConfigDict = {} + for k, v in original_config.items(): + # autocast was removed from precision in v4.0.1 + if k == "precision" and v == "autocast": + migrated_config["precision"] = "auto" + migrated_config["schema_version"] = "4.0.1" + return migrated_config + + +config_migration_2 = MigrationEntry( + from_version=Version("4.0.0"), to_version=Version("4.0.1"), function=migrate_v400_to_v401 +) From 18b5aafadecf6eff6a747dd4b3650c88f4de6117 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 14 May 2024 16:21:50 +1000 Subject: [PATCH 19/25] tidy(config): add "config" to class names to differentiate from SQLite migration classes --- invokeai/app/services/config/config_common.py | 8 +- .../app/services/config/config_migrate.py | 8 +- invokeai/app/services/config/migrations.py | 6 +- tests/test_config.py | 108 +----------------- 4 files changed, 12 insertions(+), 118 deletions(-) diff --git a/invokeai/app/services/config/config_common.py b/invokeai/app/services/config/config_common.py index 0339c1bff25..b1b910226d7 100644 --- a/invokeai/app/services/config/config_common.py +++ b/invokeai/app/services/config/config_common.py @@ -31,16 +31,16 @@ def print_help(self, file=None) -> None: AppConfigDict: TypeAlias = dict[str, Any] -MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict] +ConfigMigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict] @dataclass -class MigrationEntry: - """Defines an individual migration.""" +class ConfigMigration: + """Defines an individual config migration.""" from_version: Version to_version: Version - function: MigrationFunction + function: ConfigMigrationFunction def __hash__(self) -> int: # Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set. diff --git a/invokeai/app/services/config/config_migrate.py b/invokeai/app/services/config/config_migrate.py index 77d1ce85553..355f1ccaeaa 100644 --- a/invokeai/app/services/config/config_migrate.py +++ b/invokeai/app/services/config/config_migrate.py @@ -14,7 +14,7 @@ from packaging.version import Version import invokeai.configs as model_configs -from invokeai.app.services.config.config_common import AppConfigDict, MigrationEntry +from invokeai.app.services.config.config_common import AppConfigDict, ConfigMigration from invokeai.app.services.config.migrations import config_migration_1, config_migration_2 from invokeai.frontend.cli.arg_parser import InvokeAIArgs @@ -25,9 +25,9 @@ class ConfigMigrator: """This class allows migrators to register their input and output versions.""" def __init__(self) -> None: - self._migrations: set[MigrationEntry] = set() + self._migrations: set[ConfigMigration] = set() - def register(self, migration: MigrationEntry) -> None: + def register(self, migration: ConfigMigration) -> None: migration_from_already_registered = any(m.from_version == migration.from_version for m in self._migrations) migration_to_already_registered = any(m.to_version == migration.to_version for m in self._migrations) if migration_from_already_registered or migration_to_already_registered: @@ -37,7 +37,7 @@ def register(self, migration: MigrationEntry) -> None: self._migrations.add(migration) @staticmethod - def _check_for_discontinuities(migrations: list[MigrationEntry]) -> None: + def _check_for_discontinuities(migrations: list[ConfigMigration]) -> None: current_version = Version("3.0.0") for m in migrations: if current_version != m.from_version: diff --git a/invokeai/app/services/config/migrations.py b/invokeai/app/services/config/migrations.py index 44b4e766b60..7c0e3f9251f 100644 --- a/invokeai/app/services/config/migrations.py +++ b/invokeai/app/services/config/migrations.py @@ -14,7 +14,7 @@ from packaging.version import Version -from invokeai.app.services.config.config_common import AppConfigDict, MigrationEntry +from invokeai.app.services.config.config_common import AppConfigDict, ConfigMigration from .config_default import InvokeAIAppConfig @@ -68,7 +68,7 @@ def migrate_v300_to_v400(original_config: AppConfigDict) -> AppConfigDict: return migrated_config -config_migration_1 = MigrationEntry( +config_migration_1 = ConfigMigration( from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migrate_v300_to_v400 ) @@ -94,6 +94,6 @@ def migrate_v400_to_v401(original_config: AppConfigDict) -> AppConfigDict: return migrated_config -config_migration_2 = MigrationEntry( +config_migration_2 = ConfigMigration( from_version=Version("4.0.0"), to_version=Version("4.0.1"), function=migrate_v400_to_v401 ) diff --git a/tests/test_config.py b/tests/test_config.py index e66962308a0..9feb723b015 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -9,12 +9,10 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.services.config.config_default import ( - CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, ) -from invokeai.app.services.config.config_migrate import ConfigMigrator, get_config, load_and_migrate_config -from invokeai.app.services.config.migrations import Migrations, MigrationsBase +from invokeai.app.services.config.config_migrate import get_config, load_and_migrate_config from invokeai.app.services.shared.graph import Graph from invokeai.frontend.cli.arg_parser import InvokeAIArgs @@ -71,76 +69,6 @@ """ -class GoodMigrations(MigrationsBase): - methods_run: int = 0 - - @classmethod - def load(cls, migrator: ConfigMigrator) -> None: - """Define migrations to perform.""" - - @migrator.register(from_version="3.0.0", to_version="10.0.0") - def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]: - cls.methods_run += 1 - config_dict["migration_1"] = True - return config_dict - - @migrator.register(from_version="10.0.0", to_version="10.0.1") - def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]: - cls.methods_run += 1 - config_dict["migration_2"] = True - return config_dict - - @migrator.register(from_version="10.0.1", to_version="10.0.2") - def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]: - cls.methods_run += 1 - config_dict["migration_3"] = True - return config_dict - - -class BadMigrations1(MigrationsBase): - """This one fails because there is no path from 10.0.1 to 10.0.2""" - - @classmethod - def load(cls, migrator: ConfigMigrator) -> None: - """Define migrations to perform.""" - - @migrator.register(from_version="3.0.0", to_version="10.0.0") - def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]: - return config_dict - - @migrator.register(from_version="10.0.0", to_version="10.0.1") - def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]: - return config_dict - - @migrator.register(from_version="10.0.2", to_version="10.0.3") - def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]: - return config_dict - - -class BadMigrations2(MigrationsBase): - """This one fails because the path to 10.0.2 is registered twice""" - - @classmethod - def load(cls, migrator: ConfigMigrator) -> None: - """Define migrations to perform.""" - - @migrator.register(from_version="3.0.0", to_version="10.0.0") - def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]: - return config_dict - - @migrator.register(from_version="10.0.0", to_version="10.0.1") - def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]: - return config_dict - - @migrator.register(from_version="10.0.1", to_version="10.0.2") - def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]: - return config_dict - - @migrator.register(from_version="10.0.0", to_version="10.0.2") - def migration_4(config_dict: dict[str, Any]) -> dict[str, Any]: - return config_dict - - @pytest.fixture def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None: """This may be overkill since the current tests don't need the root dir to exist""" @@ -359,40 +287,6 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch InvokeAIArgs.did_parse = False -def test_migration_check() -> None: - # Test the default set of migrations - migrator = ConfigMigrator(Migrations) - new_config = migrator.run_migrations({"schema_version": "4.0.0"}) - assert new_config is not None - assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION - - # Test a custom set of migrations - GoodMigrations.methods_run = 0 - migrator = ConfigMigrator(GoodMigrations) - new_config = migrator.run_migrations({"schema_version": "10.0.0"}) - assert new_config["schema_version"] == "10.0.2" - assert GoodMigrations.methods_run == 2 - assert new_config.get("migration_2") - assert not new_config.get("migration_1") - - GoodMigrations.methods_run = 0 - migrator = ConfigMigrator(GoodMigrations) - new_config = migrator.run_migrations({"schema_version": "3.0.0"}) - assert new_config["schema_version"] == "10.0.2" - assert GoodMigrations.methods_run == 3 - assert all(new_config[x] for x in ["migration_1", "migration_2", "migration_3"]) - - # Test a migration that should fail validation - migrator = ConfigMigrator(BadMigrations1) - with pytest.raises(ValueError): - new_config = migrator.run_migrations({"schema_version": "10.0.0"}) - - # Test another bad migration - migrator = ConfigMigrator(BadMigrations2) - with pytest.raises(ValueError): - new_config = migrator.run_migrations({"schema_version": "10.0.0"}) - - @contextmanager def clear_config() -> Generator[None, None, None]: try: From 4c081d58e0c37ff728af9aca254410f77f53c66b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 14 May 2024 16:31:11 +1000 Subject: [PATCH 20/25] tidy(config): add note about circular deps in config_migrate.py --- .../app/services/config/config_migrate.py | 73 ++++++++++--------- 1 file changed, 37 insertions(+), 36 deletions(-) diff --git a/invokeai/app/services/config/config_migrate.py b/invokeai/app/services/config/config_migrate.py index 355f1ccaeaa..c847ef77a8f 100644 --- a/invokeai/app/services/config/config_migrate.py +++ b/invokeai/app/services/config/config_migrate.py @@ -80,6 +80,43 @@ def run_migrations(self, original_config: AppConfigDict) -> AppConfigDict: return migrated_config +def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: + """Load and migrate a config file to the latest version. + + Args: + config_path: Path to the config file. + + Returns: + An instance of `InvokeAIAppConfig` with the loaded and migrated settings. + """ + assert config_path.suffix == ".yaml" + with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file: + loaded_config_dict: AppConfigDict = yaml.safe_load(file) + + assert isinstance(loaded_config_dict, dict) + + shutil.copy(config_path, config_path.with_suffix(".yaml.bak")) + try: + migrator = ConfigMigrator() + migrator.register(config_migration_1) + migrator.register(config_migration_2) + migrated_config_dict = migrator.run_migrations(loaded_config_dict) + except Exception as e: + shutil.copy(config_path.with_suffix(".yaml.bak"), config_path) + raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e + + # Attempt to load as a v4 config file + try: + config = InvokeAIAppConfig.model_validate(migrated_config_dict) + assert ( + config.schema_version == CONFIG_SCHEMA_VERSION + ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}" + return config + except Exception as e: + raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e + + +# TODO(psyche): This must must be in this file to avoid circular dependencies @lru_cache(maxsize=1) def get_config() -> InvokeAIAppConfig: """Get the global singleton app config. @@ -133,39 +170,3 @@ def get_config() -> InvokeAIAppConfig: default_config.write_file(config.config_file_path, as_example=False) return config - - -def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: - """Load and migrate a config file to the latest version. - - Args: - config_path: Path to the config file. - - Returns: - An instance of `InvokeAIAppConfig` with the loaded and migrated settings. - """ - assert config_path.suffix == ".yaml" - with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file: - loaded_config_dict: AppConfigDict = yaml.safe_load(file) - - assert isinstance(loaded_config_dict, dict) - - shutil.copy(config_path, config_path.with_suffix(".yaml.bak")) - try: - migrator = ConfigMigrator() - migrator.register(config_migration_1) - migrator.register(config_migration_2) - migrated_config_dict = migrator.run_migrations(loaded_config_dict) - except Exception as e: - shutil.copy(config_path.with_suffix(".yaml.bak"), config_path) - raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e - - # Attempt to load as a v4 config file - try: - config = InvokeAIAppConfig.model_validate(migrated_config_dict) - assert ( - config.schema_version == CONFIG_SCHEMA_VERSION - ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}" - return config - except Exception as e: - raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e From d4871029044142db46bd3211f632f8dd5b2d7a08 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 14 May 2024 16:55:22 +1000 Subject: [PATCH 21/25] fix(config): fix config _check_for_discontinuities Need to sort the migrations first. --- invokeai/app/services/config/config_migrate.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/invokeai/app/services/config/config_migrate.py b/invokeai/app/services/config/config_migrate.py index c847ef77a8f..df700ffc213 100644 --- a/invokeai/app/services/config/config_migrate.py +++ b/invokeai/app/services/config/config_migrate.py @@ -9,6 +9,7 @@ from copy import deepcopy from functools import lru_cache from pathlib import Path +from typing import Iterable import yaml from packaging.version import Version @@ -37,9 +38,10 @@ def register(self, migration: ConfigMigration) -> None: self._migrations.add(migration) @staticmethod - def _check_for_discontinuities(migrations: list[ConfigMigration]) -> None: + def _check_for_discontinuities(migrations: Iterable[ConfigMigration]) -> None: current_version = Version("3.0.0") - for m in migrations: + sorted_migrations = sorted(migrations, key=lambda x: x.from_version) + for m in sorted_migrations: if current_version != m.from_version: raise ValueError( f"Migration functions are not continuous. Expected from_version={current_version} but got from_version={m.from_version}, for migration function {m.function.__name__}" From 7d8b011f896c517238f07d3f9498ff59ee9030a0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 14 May 2024 16:55:44 +1000 Subject: [PATCH 22/25] fix(config): restore missing config field assignment in migration --- invokeai/app/services/config/migrations.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/invokeai/app/services/config/migrations.py b/invokeai/app/services/config/migrations.py index 7c0e3f9251f..c2512efbcf8 100644 --- a/invokeai/app/services/config/migrations.py +++ b/invokeai/app/services/config/migrations.py @@ -90,6 +90,9 @@ def migrate_v400_to_v401(original_config: AppConfigDict) -> AppConfigDict: # autocast was removed from precision in v4.0.1 if k == "precision" and v == "autocast": migrated_config["precision"] = "auto" + # skip unknown fields + elif k in InvokeAIAppConfig.model_fields: + migrated_config[k] = v migrated_config["schema_version"] = "4.0.1" return migrated_config From 964adb817cb7c31924e4eefc5cd37bba3437f190 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 14 May 2024 16:55:53 +1000 Subject: [PATCH 23/25] tests(config): update tests for config migration --- tests/test_config.py | 93 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) diff --git a/tests/test_config.py b/tests/test_config.py index 9feb723b015..43cbc6855f4 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -8,11 +8,12 @@ from pydantic import ValidationError from invokeai.app.invocations.baseinvocation import BaseInvocation +from invokeai.app.services.config.config_common import AppConfigDict, ConfigMigration from invokeai.app.services.config.config_default import ( DefaultInvokeAIAppConfig, InvokeAIAppConfig, ) -from invokeai.app.services.config.config_migrate import get_config, load_and_migrate_config +from invokeai.app.services.config.config_migrate import ConfigMigrator, get_config, load_and_migrate_config from invokeai.app.services.shared.graph import Graph from invokeai.frontend.cli.arg_parser import InvokeAIArgs @@ -75,6 +76,96 @@ def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None: monkeypatch.setenv("INVOKEAI_ROOT", str(tmp_path)) +def test_config_migrator_registers_migrations() -> None: + """Test that the config migrator registers migrations.""" + migrator = ConfigMigrator() + + def migration_func(config: AppConfigDict) -> AppConfigDict: + return config + + migration_1 = ConfigMigration(from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migration_func) + migration_2 = ConfigMigration(from_version=Version("4.0.0"), to_version=Version("5.0.0"), function=migration_func) + + migrator.register(migration_1) + assert migrator._migrations == {migration_1} + migrator.register(migration_2) + assert migrator._migrations == {migration_1, migration_2} + + +def test_config_migrator_rejects_duplicate_migrations() -> None: + """Test that the config migrator rejects duplicate migrations.""" + migrator = ConfigMigrator() + + def migration_func(config: AppConfigDict) -> AppConfigDict: + return config + + migration_1 = ConfigMigration(from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migration_func) + migrator.register(migration_1) + + # Re-register the same migration + with pytest.raises( + ValueError, + match=f"A migration from {migration_1.from_version} or to {migration_1.to_version} has already been registered.", + ): + migrator.register(migration_1) + + # Register a migration with the same from_version + migration_2 = ConfigMigration(from_version=Version("3.0.0"), to_version=Version("5.0.0"), function=migration_func) + with pytest.raises( + ValueError, + match=f"A migration from {migration_2.from_version} or to {migration_2.to_version} has already been registered.", + ): + migrator.register(migration_2) + + # Register a migration with the same to_version + migration_3 = ConfigMigration(from_version=Version("3.0.1"), to_version=Version("4.0.0"), function=migration_func) + with pytest.raises( + ValueError, + match=f"A migration from {migration_3.from_version} or to {migration_3.to_version} has already been registered.", + ): + migrator.register(migration_3) + + +def test_config_migrator_contiguous_migrations() -> None: + """Test that the config migrator requires contiguous migrations.""" + migrator = ConfigMigrator() + + def migration_1_func(config: AppConfigDict) -> AppConfigDict: + return {"schema_version": "4.0.0"} + + def migration_3_func(config: AppConfigDict) -> AppConfigDict: + return {"schema_version": "6.0.0"} + + migration_1 = ConfigMigration(from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migration_1_func) + migration_3 = ConfigMigration(from_version=Version("5.0.0"), to_version=Version("6.0.0"), function=migration_3_func) + + migrator.register(migration_1) + migrator.register(migration_3) + with pytest.raises(ValueError, match="Migration functions are not continuous"): + migrator._check_for_discontinuities(migrator._migrations) + + +def test_config_migrator_runs_migrations() -> None: + """Test that the config migrator runs migrations.""" + migrator = ConfigMigrator() + + def migration_1_func(config: AppConfigDict) -> AppConfigDict: + return {"schema_version": "4.0.0"} + + def migration_2_func(config: AppConfigDict) -> AppConfigDict: + return {"schema_version": "5.0.0"} + + migration_1 = ConfigMigration(from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migration_1_func) + migration_2 = ConfigMigration(from_version=Version("4.0.0"), to_version=Version("5.0.0"), function=migration_2_func) + + migrator.register(migration_1) + migrator.register(migration_2) + + original_config = {"schema_version": "3.0.0"} + migrated_config = migrator.run_migrations(original_config) + assert migrated_config == {"schema_version": "5.0.0"} + + def test_path_resolution_root_not_set(patch_rootdir: None): """Test path resolutions when the root is not explicitly set.""" config = InvokeAIAppConfig() From 6e40142a59c22e63a86d06796e03417130efc4e3 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 14 May 2024 17:16:50 +1000 Subject: [PATCH 24/25] tests(config): test migrations directly, not via `load_and_migrate_config` --- tests/test_config.py | 38 +++++++++++++++++--------------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index 43cbc6855f4..4ef4c2439cd 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -4,6 +4,7 @@ from typing import Any, Generator import pytest +import yaml from packaging.version import Version from pydantic import ValidationError @@ -14,6 +15,7 @@ InvokeAIAppConfig, ) from invokeai.app.services.config.config_migrate import ConfigMigrator, get_config, load_and_migrate_config +from invokeai.app.services.config.migrations import migrate_v300_to_v400, migrate_v400_to_v401 from invokeai.app.services.shared.graph import Graph from invokeai.frontend.cli.arg_parser import InvokeAIArgs @@ -183,12 +185,10 @@ def test_read_config_from_file(tmp_path: Path, patch_rootdir: None): assert config.port == 8080 -def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None): +def test_migration_1_migrates_settings(tmp_path: Path, patch_rootdir: None): """Test reading configuration from a file.""" - temp_config_file = tmp_path / "temp_invokeai.yaml" - temp_config_file.write_text(v3_config) - - config = load_and_migrate_config(temp_config_file) + migrated_config_dict = migrate_v300_to_v400(yaml.safe_load(v3_config)) + config = InvokeAIAppConfig.model_validate(migrated_config_dict) assert config.outputs_dir == Path("/some/outputs/dir") assert config.host == "192.168.1.1" assert config.port == 8080 @@ -212,20 +212,18 @@ def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None): ("full/custom/path", Path("full/custom/path"), True), ], ) -def test_migrate_v3_legacy_conf_dir_defaults( - tmp_path: Path, patch_rootdir: None, legacy_conf_dir: str, expected_value: Path, expected_is_set: bool +def test_migration_1_handles_legacy_conf_dir_defaults( + legacy_conf_dir: str, expected_value: Path, expected_is_set: bool ): """Test reading configuration from a file.""" config_content = f"InvokeAI:\n Paths:\n legacy_conf_dir: {legacy_conf_dir}" - temp_config_file = tmp_path / "temp_invokeai.yaml" - temp_config_file.write_text(config_content) - - config = load_and_migrate_config(temp_config_file) + migrated_config_dict = migrate_v300_to_v400(yaml.safe_load(config_content)) + config = InvokeAIAppConfig.model_validate(migrated_config_dict) assert config.legacy_conf_dir == expected_value assert ("legacy_conf_dir" in config.model_fields_set) is expected_is_set -def test_migrate_v3_backup(tmp_path: Path, patch_rootdir: None): +def test_load_and_migrate_backs_up_file(tmp_path: Path, patch_rootdir: None): """Test the backup of the config file.""" temp_config_file = tmp_path / "temp_invokeai.yaml" temp_config_file.write_text(v3_config) @@ -235,17 +233,15 @@ def test_migrate_v3_backup(tmp_path: Path, patch_rootdir: None): assert temp_config_file.with_suffix(".yaml.bak").read_text() == v3_config -def test_migrate_v4(tmp_path: Path, patch_rootdir: None): +def test_migration_2_migrates_settings(): """Test migration from 4.0.0 to 4.0.1""" - temp_config_file = tmp_path / "temp_invokeai.yaml" - temp_config_file.write_text(v4_config) - - conf = load_and_migrate_config(temp_config_file) - assert Version(conf.schema_version) >= Version("4.0.1") - assert conf.precision == "auto" # we expect 'autocast' to be replaced with 'auto' during 4.0.1 migration + migrated_config_dict = migrate_v400_to_v401(yaml.safe_load(v4_config)) + config = InvokeAIAppConfig.model_validate(migrated_config_dict) + assert Version(config.schema_version) == Version("4.0.1") + assert config.precision == "auto" # we expect 'autocast' to be replaced with 'auto' during 4.0.1 migration -def test_failed_migrate_backup(tmp_path: Path, patch_rootdir: None): +def test_load_and_migrate_failed_migrate_backup(tmp_path: Path, patch_rootdir: None): """Test the failed migration of the config file.""" temp_config_file = tmp_path / "temp_invokeai.yaml" temp_config_file.write_text(v3_config_with_bad_values) @@ -258,7 +254,7 @@ def test_failed_migrate_backup(tmp_path: Path, patch_rootdir: None): assert temp_config_file.read_text() == v3_config_with_bad_values -def test_bails_on_invalid_config(tmp_path: Path, patch_rootdir: None): +def test_load_and_migrate_bails_on_invalid_config(tmp_path: Path, patch_rootdir: None): """Test reading configuration from a file.""" temp_config_file = tmp_path / "temp_invokeai.yaml" temp_config_file.write_text(invalid_config) From 8b76d112be49b9be35421ead363069e3c7059c02 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 14 May 2024 18:02:22 +1000 Subject: [PATCH 25/25] tests(config): set root to a tmp dir if didn't parse args This prevents tests from triggering config related parsing on your "live" root. --- .../app/services/config/config_migrate.py | 3 ++ tests/test_config.py | 38 ++++++++----------- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/invokeai/app/services/config/config_migrate.py b/invokeai/app/services/config/config_migrate.py index df700ffc213..4884802e8c8 100644 --- a/invokeai/app/services/config/config_migrate.py +++ b/invokeai/app/services/config/config_migrate.py @@ -9,6 +9,7 @@ from copy import deepcopy from functools import lru_cache from pathlib import Path +from tempfile import TemporaryDirectory from typing import Iterable import yaml @@ -141,6 +142,8 @@ def get_config() -> InvokeAIAppConfig: # This flag serves as a proxy for whether the config was retrieved in the context of the full application or not. # If it is False, we should just return a default config and not set the root, log in to HF, etc. if not InvokeAIArgs.did_parse: + tmpdir = TemporaryDirectory() + config._root = Path(tmpdir.name) return config # Set CLI args diff --git a/tests/test_config.py b/tests/test_config.py index 4ef4c2439cd..0a451135cee 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,7 +1,7 @@ from contextlib import contextmanager from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any, Generator +from typing import Generator import pytest import yaml @@ -72,12 +72,6 @@ """ -@pytest.fixture -def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None: - """This may be overkill since the current tests don't need the root dir to exist""" - monkeypatch.setenv("INVOKEAI_ROOT", str(tmp_path)) - - def test_config_migrator_registers_migrations() -> None: """Test that the config migrator registers migrations.""" migrator = ConfigMigrator() @@ -168,14 +162,14 @@ def migration_2_func(config: AppConfigDict) -> AppConfigDict: assert migrated_config == {"schema_version": "5.0.0"} -def test_path_resolution_root_not_set(patch_rootdir: None): +def test_path_resolution_root_not_set(): """Test path resolutions when the root is not explicitly set.""" config = InvokeAIAppConfig() expected_root = InvokeAIAppConfig.find_root() assert config.root_path == expected_root -def test_read_config_from_file(tmp_path: Path, patch_rootdir: None): +def test_read_config_from_file(tmp_path: Path): """Test reading configuration from a file.""" temp_config_file = tmp_path / "temp_invokeai.yaml" temp_config_file.write_text(v4_config) @@ -185,7 +179,7 @@ def test_read_config_from_file(tmp_path: Path, patch_rootdir: None): assert config.port == 8080 -def test_migration_1_migrates_settings(tmp_path: Path, patch_rootdir: None): +def test_migration_1_migrates_settings(tmp_path: Path): """Test reading configuration from a file.""" migrated_config_dict = migrate_v300_to_v400(yaml.safe_load(v3_config)) config = InvokeAIAppConfig.model_validate(migrated_config_dict) @@ -223,7 +217,7 @@ def test_migration_1_handles_legacy_conf_dir_defaults( assert ("legacy_conf_dir" in config.model_fields_set) is expected_is_set -def test_load_and_migrate_backs_up_file(tmp_path: Path, patch_rootdir: None): +def test_load_and_migrate_backs_up_file(tmp_path: Path): """Test the backup of the config file.""" temp_config_file = tmp_path / "temp_invokeai.yaml" temp_config_file.write_text(v3_config) @@ -241,7 +235,7 @@ def test_migration_2_migrates_settings(): assert config.precision == "auto" # we expect 'autocast' to be replaced with 'auto' during 4.0.1 migration -def test_load_and_migrate_failed_migrate_backup(tmp_path: Path, patch_rootdir: None): +def test_load_and_migrate_failed_migrate_backup(tmp_path: Path): """Test the failed migration of the config file.""" temp_config_file = tmp_path / "temp_invokeai.yaml" temp_config_file.write_text(v3_config_with_bad_values) @@ -254,7 +248,7 @@ def test_load_and_migrate_failed_migrate_backup(tmp_path: Path, patch_rootdir: N assert temp_config_file.read_text() == v3_config_with_bad_values -def test_load_and_migrate_bails_on_invalid_config(tmp_path: Path, patch_rootdir: None): +def test_load_and_migrate_bails_on_invalid_config(tmp_path: Path): """Test reading configuration from a file.""" temp_config_file = tmp_path / "temp_invokeai.yaml" temp_config_file.write_text(invalid_config) @@ -264,7 +258,7 @@ def test_load_and_migrate_bails_on_invalid_config(tmp_path: Path, patch_rootdir: @pytest.mark.parametrize("config_content", [invalid_v5_config, invalid_v4_0_1_config]) -def test_bails_on_config_with_unsupported_version(tmp_path: Path, patch_rootdir: None, config_content: str): +def test_bails_on_config_with_unsupported_version(tmp_path: Path, config_content: str): """Test reading configuration from a file.""" temp_config_file = tmp_path / "temp_invokeai.yaml" temp_config_file.write_text(config_content) @@ -274,7 +268,7 @@ def test_bails_on_config_with_unsupported_version(tmp_path: Path, patch_rootdir: load_and_migrate_config(temp_config_file) -def test_write_config_to_file(patch_rootdir: None): +def test_write_config_to_file(): """Test writing configuration to a file, checking for correct output.""" with TemporaryDirectory() as tmpdir: temp_config_path = Path(tmpdir) / "invokeai.yaml" @@ -289,7 +283,7 @@ def test_write_config_to_file(patch_rootdir: None): assert "port: 8080" in content -def test_update_config_with_dict(patch_rootdir: None): +def test_update_config_with_dict(): """Test updating the config with a dictionary.""" config = InvokeAIAppConfig() update_dict = {"host": "10.10.10.10", "port": 6060} @@ -298,7 +292,7 @@ def test_update_config_with_dict(patch_rootdir: None): assert config.port == 6060 -def test_update_config_with_object(patch_rootdir: None): +def test_update_config_with_object(): """Test updating the config with another config object.""" config = InvokeAIAppConfig() new_config = InvokeAIAppConfig(host="10.10.10.10", port=6060) @@ -307,7 +301,7 @@ def test_update_config_with_object(patch_rootdir: None): assert config.port == 6060 -def test_set_and_resolve_paths(patch_rootdir: None): +def test_set_and_resolve_paths(): """Test setting root and resolving paths based on it.""" with TemporaryDirectory() as tmpdir: config = InvokeAIAppConfig() @@ -316,7 +310,7 @@ def test_set_and_resolve_paths(patch_rootdir: None): assert config.db_path == Path(tmpdir).resolve() / "databases" / "invokeai.db" -def test_singleton_behavior(patch_rootdir: None): +def test_singleton_behavior(): """Test that get_config always returns the same instance.""" get_config.cache_clear() config1 = get_config() @@ -325,13 +319,13 @@ def test_singleton_behavior(patch_rootdir: None): get_config.cache_clear() -def test_default_config(patch_rootdir: None): +def test_default_config(): """Test that the default config is as expected.""" config = DefaultInvokeAIAppConfig() assert config.host == "127.0.0.1" -def test_env_vars(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch, tmp_path: Path): +def test_env_vars(monkeypatch: pytest.MonkeyPatch, tmp_path: Path): """Test that environment variables are merged into the config""" monkeypatch.setenv("INVOKEAI_ROOT", str(tmp_path)) monkeypatch.setenv("INVOKEAI_HOST", "1.2.3.4") @@ -342,7 +336,7 @@ def test_env_vars(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch, tmp_path assert config.root_path == tmp_path -def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch, tmp_path: Path): +def test_get_config_writing(monkeypatch: pytest.MonkeyPatch, tmp_path: Path): """Test that get_config writes the appropriate files to disk""" # Trick the config into thinking it has already parsed args - this triggers the writing of the config file InvokeAIArgs.did_parse = True