Skip to content

Commit

Permalink
[MM] Add support for probing and loading SDXL VAE checkpoint files (#…
Browse files Browse the repository at this point in the history
…6524)

* add support for probing and loading SDXL VAE checkpoint files

* broaden regexp probe for SDXL VAEs

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
  • Loading branch information
lstein and Lincoln Stein committed Jun 20, 2024
1 parent a43d602 commit b03073d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
11 changes: 3 additions & 8 deletions invokeai/backend/model_manager/load/model_loaders/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class VAELoader(GenericDiffusersLoader):
"""Class to load VAE models."""

Expand All @@ -40,12 +39,8 @@ def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path:
return True

def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
# TODO(MM2): check whether sdxl VAE models convert.
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"VAE conversion not supported for model type: {config.base}")
else:
assert isinstance(config, CheckpointConfigBase)
config_file = self._app_config.legacy_conf_path / config.config_path
assert isinstance(config, CheckpointConfigBase)
config_file = self._app_config.legacy_conf_path / config.config_path

if model_path.suffix == ".safetensors":
checkpoint = safetensors_load_file(model_path, device="cpu")
Expand Down
12 changes: 10 additions & 2 deletions invokeai/backend/model_manager/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,16 @@ def get_scheduler_prediction_type(self) -> SchedulerPredictionType:

class VaeCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
# I can't find any standalone 2.X VAEs to test with!
return BaseModelType.StableDiffusion1
# VAEs of all base types have the same structure, so we wimp out and
# guess using the name.
for regexp, basetype in [
(r"xl", BaseModelType.StableDiffusionXL),
(r"sd2", BaseModelType.StableDiffusion2),
(r"vae", BaseModelType.StableDiffusion1),
]:
if re.search(regexp, self.model_path.name, re.IGNORECASE):
return basetype
raise InvalidModelConfigException("Cannot determine base type")


class LoRACheckpointProbe(CheckpointProbeBase):
Expand Down

0 comments on commit b03073d

Please sign in to comment.