Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow lists of basemodel objects in omegaconf #5922

Merged
merged 6 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 18 additions & 1 deletion invokeai/app/services/config/config_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints

from omegaconf import DictConfig, ListConfig, OmegaConf
from omegaconf import DictConfig, DictKeyType, ListConfig, OmegaConf
from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict

from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str
Expand Down Expand Up @@ -62,6 +63,22 @@ def to_yaml(self) -> str:
assert isinstance(category, str)
if category not in field_dict[type]:
field_dict[type][category] = {}
if isinstance(value, BaseModel):
dump = value.model_dump(exclude_defaults=True, exclude_unset=True, exclude_none=True)
field_dict[type][category][name] = dump
continue
if isinstance(value, list):
if not value or len(value) == 0:
continue
primitive = isinstance(value[0], get_args(DictKeyType))
if not primitive:
val_list: List[Dict[str, Any]] = []
for list_val in value:
if isinstance(list_val, BaseModel):
dump = list_val.model_dump(exclude_defaults=True, exclude_unset=True, exclude_none=True)
val_list.append(dump)
field_dict[type][category][name] = val_list
continue
# keep paths as strings to make it easier to read
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
conf = OmegaConf.create(field_dict)
Expand Down
16 changes: 14 additions & 2 deletions invokeai/backend/install/invokeai_configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from argparse import Namespace
from enum import Enum
from pathlib import Path
from shutil import get_terminal_size
from shutil import copy, get_terminal_size, move
from typing import Any, Optional, Set, Tuple, Type, get_args, get_type_hints
from urllib import request

Expand Down Expand Up @@ -929,6 +929,10 @@ def main() -> None:

errors = set()
FORCE_FULL_PRECISION = opt.full_precision # FIXME global
new_init_file = config.root_path / "invokeai.yaml"
backup_init_file = new_init_file.with_suffix(".bak")
if new_init_file.exists():
copy(new_init_file, backup_init_file)

try:
# if we do a root migration/upgrade, then we are keeping previous
Expand All @@ -943,7 +947,6 @@ def main() -> None:
install_helper = InstallHelper(config, logger)

models_to_download = default_user_selections(opt, install_helper)
new_init_file = config.root_path / "invokeai.yaml"

if opt.yes_to_all:
write_default_options(opt, new_init_file)
Expand Down Expand Up @@ -975,8 +978,17 @@ def main() -> None:
input("Press any key to continue...")
except WindowTooSmallException as e:
logger.error(str(e))
if backup_init_file.exists():
move(backup_init_file, new_init_file)
hipsterusername marked this conversation as resolved.
Show resolved Hide resolved
except KeyboardInterrupt:
print("\nGoodbye! Come back soon.")
if backup_init_file.exists():
move(backup_init_file, new_init_file)
except Exception:
print("An error occurred during installation.")
if backup_init_file.exists():
move(backup_init_file, new_init_file)
print(traceback.format_exc(), file=sys.stderr)


# -------------------------------------
Expand Down