[mm] Do not write diffuser model to disk when convert_cache set to zero (#6072)

* pass model config to _load_model

* make conversion work again

* do not write diffusers to disk when convert_cache set to 0

* adding same model to cache twice is a no-op, not an assertion error

* fix issues identified by psychedelicious during pr review

* following conversion, avoid redundant read of cached submodels

* fix error introduced while merging

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
lstein and Lincoln Stein committed Mar 29, 2024
1 parent 0ac1c0f commit 3d6d89f
Showing 14 changed files with 146 additions and 132 deletions.
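
The headline change, sketched very roughly from the loader API visible in the diffs below (this is not code from the commit; how the loader and the model config are obtained is an assumption):

    from invokeai.backend.model_manager import SubModelType

    loader = model_manager.load                     # model manager taken from the running app (assumption)
    loader.convert_cache.max_size = 0.0             # zero-sized convert cache: convert in RAM, write nothing to disk
    loaded = loader.load_model(model_config, submodel_type=SubModelType.Scheduler)  # model_config: a main checkpoint config (placeholder)
    # The conversion also pre-populates loader.ram_cache with the other submodels
    # (unet, vae, ...), so they are not re-converted or re-read on later requests.
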
10 changes: 8 additions & 2 deletions invokeai/app/api/routers/model_manager.py
@@ -614,8 +614,8 @@ async def convert_model(
The return value is the model configuration for the converted model.
"""
model_manager = ApiDependencies.invoker.services.model_manager
loader = model_manager.load
logger = ApiDependencies.invoker.services.logger
loader = ApiDependencies.invoker.services.model_manager.load
store = ApiDependencies.invoker.services.model_manager.store
installer = ApiDependencies.invoker.services.model_manager.install

@@ -630,7 +630,13 @@ async def convert_model(
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")

# loading the model will convert it into a cached diffusers file
model_manager.load.load_model(model_config, submodel_type=SubModelType.Scheduler)
try:
cc_size = loader.convert_cache.max_size
if cc_size == 0: # temporary set the convert cache to a positive number so that cached model is written
loader._convert_cache.max_size = 1.0
loader.load_model(model_config, submodel_type=SubModelType.Scheduler)
finally:
loader._convert_cache.max_size = cc_size

# Get the path of the converted model from the loader
cache_path = loader.convert_cache.cache_path(key)
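
The endpoint above still needs the converted files on disk (it reads `loader.convert_cache.cache_path(key)` afterwards, presumably for the install step in the truncated part of the hunk), so it temporarily raises the cap when the convert cache is sized at zero; inside the loader itself a zero cap now means convert in memory only (`_do_convert` below passes `None` as the output path when `max_size` is 0). The same temporary-override pattern, written as a context manager purely for illustration (not part of this commit):

    from contextlib import contextmanager

    @contextmanager
    def writable_convert_cache(loader, minimum_gb: float = 1.0):
        """Temporarily raise the convert-cache cap so converted files are written to disk."""
        original = loader.convert_cache.max_size
        try:
            if original == 0:
                loader.convert_cache.max_size = minimum_gb
            yield loader
        finally:
            loader.convert_cache.max_size = original
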
39 changes: 0 additions & 39 deletions invokeai/backend/model_manager/__init__.py
@@ -33,42 +33,3 @@
"SchedulerPredictionType",
"SubModelType",
]

########## to help populate the openapi_schema with format enums for each config ###########
# This code is no longer necessary?
# leave it here just in case
#
# import inspect
# from enum import Enum
# from typing import Any, Iterable, Dict, get_args, Set
# def _expand(something: Any) -> Iterable[type]:
# if isinstance(something, type):
# yield something
# else:
# for x in get_args(something):
# for y in _expand(x):
# yield y

# def _find_format(cls: type) -> Iterable[Enum]:
# if hasattr(inspect, "get_annotations"):
# fields = inspect.get_annotations(cls)
# else:
# fields = cls.__annotations__
# if "format" in fields:
# for x in get_args(fields["format"]):
# yield x
# for parent_class in cls.__bases__:
# for x in _find_format(parent_class):
# yield x
# return None

# def get_model_config_formats() -> Dict[str, Set[Enum]]:
# result: Dict[str, Set[Enum]] = {}
# for model_config in _expand(AnyModelConfig):
# for field in _find_format(model_config):
# if field is None:
# continue
# if not result.get(model_config.__qualname__):
# result[model_config.__qualname__] = set()
# result[model_config.__qualname__].add(field)
# return result
26 changes: 16 additions & 10 deletions invokeai/backend/model_manager/convert_ckpt_to_diffusers.py
@@ -3,7 +3,7 @@
"""Conversion script for the Stable Diffusion checkpoints."""

from pathlib import Path
from typing import Dict
from typing import Dict, Optional

import torch
from diffusers import AutoencoderKL
@@ -15,6 +15,8 @@
)
from omegaconf import DictConfig

from . import AnyModel


def convert_ldm_vae_to_diffusers(
checkpoint: Dict[str, torch.Tensor],
@@ -33,11 +35,11 @@ def convert_ldm_vae_to_diffusers(

def convert_ckpt_to_diffusers(
checkpoint_path: str | Path,
dump_path: str | Path,
dump_path: Optional[str | Path] = None,
precision: torch.dtype = torch.float16,
use_safetensors: bool = True,
**kwargs,
):
) -> AnyModel:
"""
Takes all the arguments of download_from_original_stable_diffusion_ckpt(),
and in addition a path-like object indicating the location of the desired diffusers
@@ -47,18 +49,20 @@ def convert_ckpt_to_diffusers(
pipe = pipe.to(precision)

# TO DO: save correct repo variant
pipe.save_pretrained(
dump_path,
safe_serialization=use_safetensors,
)
if dump_path:
pipe.save_pretrained(
dump_path,
safe_serialization=use_safetensors,
)
return pipe


def convert_controlnet_to_diffusers(
checkpoint_path: Path,
dump_path: Path,
dump_path: Optional[Path] = None,
precision: torch.dtype = torch.float16,
**kwargs,
):
) -> AnyModel:
"""
Takes all the arguments of download_controlnet_from_original_ckpt(),
and in addition a path-like object indicating the location of the desired diffusers
@@ -68,4 +72,6 @@ def convert_controlnet_to_diffusers(
pipe = pipe.to(precision)

# TO DO: save correct repo variant
pipe.save_pretrained(dump_path, safe_serialization=True)
if dump_path:
pipe.save_pretrained(dump_path, safe_serialization=True)
return pipe
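
With `dump_path` now optional, both converters can be called purely in memory; a hedged usage sketch (the checkpoint path is a placeholder, extra kwargs are forwarded to download_from_original_stable_diffusion_ckpt()):

    import torch

    from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers

    # No dump_path: nothing is written to disk; the converted diffusers pipeline is returned directly.
    pipe = convert_ckpt_to_diffusers(
        "/path/to/v1-5-checkpoint.safetensors",  # placeholder path
        precision=torch.float16,
        from_safetensors=True,
    )
    scheduler = pipe.scheduler  # submodels are plain attributes of the returned pipeline
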
@@ -19,11 +19,20 @@ def __init__(self, cache_path: Path, max_size: float = 10.0):
self._cache_path = cache_path
self._max_size = max_size

# adjust cache size at startup in case it has been changed
if self._cache_path.exists():
self.make_room(0.0)

@property
def max_size(self) -> float:
"""Return the maximum size of this cache directory (GB)."""
return self._max_size

@max_size.setter
def max_size(self, value: float) -> None:
"""Set the maximum size of this cache directory (GB)."""
self._max_size = value

def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key."""
return self._cache_path / key
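
With the new setter, the cap can also be tightened on a live convert cache and then enforced by pruning, the same way the constructor above does at startup. A small hedged example (`convert_cache` stands in for an existing instance):

    # Shrink the cap, then prune the cache directory down to it (0.0 requests no additional space).
    convert_cache.max_size = 2.0
    convert_cache.make_room(0.0)
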
12 changes: 12 additions & 0 deletions invokeai/backend/model_manager/load/load_base.py
@@ -83,3 +83,15 @@ def get_size_fs(
) -> int:
"""Return size in bytes of the model, calculated before loading."""
pass

@property
@abstractmethod
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the convert cache associated with this loader."""
pass

@property
@abstractmethod
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the ram cache associated with this loader."""
pass
78 changes: 44 additions & 34 deletions invokeai/backend/model_manager/load/load_default.py
@@ -3,14 +3,13 @@

from logging import Logger
from pathlib import Path
from typing import Optional, Tuple
from typing import Optional

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
InvalidModelConfigException,
ModelRepoVariant,
SubModelType,
)
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
@@ -54,51 +53,43 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo
if model_config.type is ModelType.Main and not submodel_type:
raise InvalidModelConfigException("submodel_type is required when loading a main model")

model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
model_path = self._get_model_path(model_config)

if not model_path.exists():
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")

model_path = self._convert_if_needed(model_config, model_path, submodel_type)
locker = self._load_if_needed(model_config, model_path, submodel_type)
with skip_torch_weight_init():
locker = self._convert_and_load(model_config, model_path, submodel_type)
return LoadedModel(config=model_config, _locker=locker)

def _get_model_path(
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
model_base = self._app_config.models_path
result = (model_base / config.path).resolve(), config, submodel_type
return result

def _convert_if_needed(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
) -> Path:
cache_path: Path = self._convert_cache.cache_path(config.key)
@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the convert cache associated with this loader."""
return self._convert_cache

if not self._needs_conversion(config, model_path, cache_path):
return cache_path if cache_path.exists() else model_path
@property
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the ram cache associated with this loader."""
return self._ram_cache

self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
return self._convert_model(config, model_path, cache_path)

def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
return False
def _get_model_path(self, config: AnyModelConfig) -> Path:
model_base = self._app_config.models_path
return (model_base / config.path).resolve()

def _load_if_needed(
def _convert_and_load(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
) -> ModelLockerBase:
# TO DO: This is not thread safe!
try:
return self._ram_cache.get(config.key, submodel_type)
except IndexError:
pass

model_variant = getattr(config, "repo_variant", None)
self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type))

# This is where the model is actually loaded!
with skip_torch_weight_init():
loaded_model = self._load_model(model_path, model_variant=model_variant, submodel_type=submodel_type)
cache_path: Path = self._convert_cache.cache_path(config.key)
if self._needs_conversion(config, model_path, cache_path):
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
else:
config.path = str(cache_path) if cache_path.exists() else str(self._get_model_path(config))
loaded_model = self._load_model(config, submodel_type)

self._ram_cache.put(
config.key,
@@ -123,15 +114,34 @@ def get_size_fs(
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
)

def _do_convert(
self, config: AnyModelConfig, model_path: Path, cache_path: Path, submodel_type: Optional[SubModelType] = None
) -> AnyModel:
self.convert_cache.make_room(calc_model_size_by_fs(model_path))
pipeline = self._convert_model(config, model_path, cache_path if self.convert_cache.max_size > 0 else None)
if submodel_type:
# Proactively load the various submodels into the RAM cache so that we don't have to re-convert
# the entire pipeline every time a new submodel is needed.
for subtype in SubModelType:
if subtype == submodel_type:
continue
if submodel := getattr(pipeline, subtype.value, None):
self._ram_cache.put(
config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel)
)
return getattr(pipeline, submodel_type.value) if submodel_type else pipeline

def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
return False

# This needs to be implemented in subclasses that handle checkpoints
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
raise NotImplementedError

# This needs to be implemented in the subclass
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
raise NotImplementedError
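
The `_load_model` hook also changes shape: subclasses now receive the full model config (with `config.path` already pointing at either the converted copy or the original checkpoint) rather than a bare path and repo variant. A rough sketch of a concrete subclass under the new signature; the base-class name and the `_torch_dtype` attribute are taken on trust from elsewhere in this diff and may not match exactly:

    from pathlib import Path
    from typing import Optional

    from diffusers import AutoencoderKL

    from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
    from invokeai.backend.model_manager.load.load_default import ModelLoader  # assumed class name


    class ExampleVAELoader(ModelLoader):
        """Illustrative loader: loads a diffusers-format VAE from config.path."""

        def _load_model(
            self,
            config: AnyModelConfig,
            submodel_type: Optional[SubModelType] = None,
        ) -> AnyModel:
            # The base class has already resolved config.path (and swapped in the
            # converted copy when one exists), so the path is all this subclass needs.
            return AutoencoderKL.from_pretrained(Path(config.path), torch_dtype=self._torch_dtype)
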
@@ -122,6 +122,11 @@ def max_cache_size(self) -> float:
"""Return the cap on cache size."""
return self._max_cache_size

@max_cache_size.setter
def max_cache_size(self, value: float) -> None:
"""Set the cap on cache size."""
self._max_cache_size = value

@property
def stats(self) -> Optional[CacheStats]:
"""Return collected CacheStats object."""
@@ -157,8 +162,9 @@ def put(
) -> None:
"""Store model under key and optional submodel_type."""
key = self._make_cache_key(key, submodel_type)
assert key not in self._cached_models

if key in self._cached_models:
return
self.make_room(size)
cache_record = CacheRecord(key, model, size)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
@@ -405,6 +411,8 @@ def make_room(self, model_size: int) -> None:
#
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
# immediately when their reference count hits 0.
if self.stats:
self.stats.cleared = models_cleared
gc.collect()

torch.cuda.empty_cache()
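
The relaxed `put()` contract means inserting a model under a key that is already cached is now silently ignored instead of tripping `assert key not in self._cached_models`, presumably to tolerate the proactive submodel caching in `_do_convert`. A tiny illustration; the cache instance, submodel, and size are placeholders:

    from invokeai.backend.model_manager import SubModelType

    # 'ram_cache', 'scheduler' and 'scheduler_size' stand in for a real cache instance,
    # a loaded submodel, and its size in bytes.
    ram_cache.put("model-key", submodel_type=SubModelType.Scheduler, model=scheduler, size=scheduler_size)
    ram_cache.put("model-key", submodel_type=SubModelType.Scheduler, model=scheduler, size=scheduler_size)  # now a no-op
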
@@ -2,8 +2,10 @@
"""Class for ControlNet model loading in InvokeAI."""

from pathlib import Path
from typing import Optional

from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
@@ -33,7 +35,7 @@ def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path:
else:
return True

def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
assert isinstance(config, CheckpointConfigBase)
image_size = (
512
@@ -45,12 +47,12 @@ def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path:

self._logger.info(f"Converting {model_path} to diffusers format")
with open(self._app_config.legacy_conf_path / config.config_path, "r") as config_stream:
convert_controlnet_to_diffusers(
result = convert_controlnet_to_diffusers(
model_path,
output_path,
original_config_file=config_stream,
image_size=image_size,
precision=self._torch_dtype,
from_safetensors=model_path.suffix == ".safetensors",
)
return output_path
return result