Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: default processors for controlnet & t2i adapter #5896

Merged
merged 16 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 7 additions & 2 deletions invokeai/app/services/model_records/model_records_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import ModelDefaultSettings, ModelVariantType, SchedulerPredictionType
from invokeai.backend.model_manager.config import (
ControlAdapterDefaultSettings,
MainModelDefaultSettings,
ModelVariantType,
SchedulerPredictionType,
)


class DuplicateModelException(Exception):
Expand Down Expand Up @@ -68,7 +73,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
description: Optional[str] = Field(description="Model description", default=None)
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[ModelDefaultSettings] = Field(
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None
)

Expand Down
28 changes: 20 additions & 8 deletions invokeai/backend/model_manager/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import time
from enum import Enum
from typing import Literal, Optional, Type, Union
from typing import Literal, Optional, Type, TypeAlias, Union

import torch
from diffusers.models.modeling_utils import ModelMixin
Expand Down Expand Up @@ -131,7 +131,7 @@ class ModelSourceType(str, Enum):
HFRepoID = "hf_repo_id"


class ModelDefaultSettings(BaseModel):
class MainModelDefaultSettings(BaseModel):
vae: str | None
vae_precision: str | None
scheduler: SCHEDULER_NAME_VALUES | None
Expand All @@ -140,6 +140,11 @@ class ModelDefaultSettings(BaseModel):
cfg_rescale_multiplier: float | None


class ControlAdapterDefaultSettings(BaseModel):
# This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer.
preprocessor: str | None


class ModelConfigBase(BaseModel):
"""Base class for model configuration information."""

Expand All @@ -156,9 +161,6 @@ class ModelConfigBase(BaseModel):
source_api_response: Optional[str] = Field(
description="The original API response from the source, as stringified JSON.", default=None
)
default_settings: Optional[ModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)

@staticmethod
Expand Down Expand Up @@ -232,7 +234,13 @@ def get_tag() -> Tag:
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")


class ControlNetDiffusersConfig(DiffusersConfigBase):
class ControlAdapterConfigBase(BaseModel):
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None
)


class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase):
"""Model config for ControlNet models (diffusers version)."""

type: Literal[ModelType.ControlNet] = ModelType.ControlNet
Expand All @@ -243,7 +251,7 @@ def get_tag() -> Tag:
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}")


class ControlNetCheckpointConfig(CheckpointConfigBase):
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase):
"""Model config for ControlNet models (diffusers version)."""

type: Literal[ModelType.ControlNet] = ModelType.ControlNet
Expand Down Expand Up @@ -279,6 +287,9 @@ def get_tag() -> Tag:
class MainConfigBase(ModelConfigBase):
type: Literal[ModelType.Main] = ModelType.Main
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)


class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
Expand Down Expand Up @@ -324,7 +335,7 @@ def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")


class T2IAdapterConfig(ModelConfigBase):
class T2IAdapterConfig(ModelConfigBase, ControlAdapterConfigBase):
"""Model config for T2I."""

type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
Expand Down Expand Up @@ -376,6 +387,7 @@ def get_model_discriminator_value(v: Any) -> str:
]

AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, ControlAdapterDefaultSettings]


class ModelConfigFactory(object):
Expand Down
47 changes: 43 additions & 4 deletions invokeai/backend/model_manager/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .config import (
AnyModelConfig,
BaseModelType,
ControlAdapterDefaultSettings,
InvalidModelConfigException,
ModelConfigFactory,
ModelFormat,
Expand Down Expand Up @@ -159,6 +160,12 @@ def probe(
fields["format"] = fields.get("format") or probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)

fields["default_settings"] = (
fields.get("default_settings") or probe.get_default_settings(fields["name"])
if isinstance(probe, ControlAdapterProbe)
else None
)

if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()

Expand Down Expand Up @@ -329,6 +336,38 @@ def _scan_model(cls, model_name: str, checkpoint: Path) -> None:
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")


class ControlAdapterProbe(ProbeBase):
"""Adds `get_default_settings` for ControlNet and T2IAdapter probes"""

# TODO(psyche): It would be nice to get these from the invocations, but that creates circular dependencies.
# "canny": CannyImageProcessorInvocation.get_type()
MODEL_NAME_TO_PREPROCESSOR = {
"canny": "canny_image_processor",
"mlsd": "mlsd_image_processor",
"depth": "depth_anything_image_processor",
"bae": "normalbae_image_processor",
"normal": "normalbae_image_processor",
"sketch": "pidi_image_processor",
"scribble": "lineart_image_processor",
"lineart": "lineart_image_processor",
"lineart_anime": "lineart_anime_image_processor",
"softedge": "hed_image_processor",
"shuffle": "content_shuffle_image_processor",
"pose": "dw_openpose_image_processor",
"mediapipe": "mediapipe_face_processor",
"pidi": "pidi_image_processor",
"zoe": "zoe_depth_image_processor",
"color": "color_map_image_processor",
}

@classmethod
def get_default_settings(cls, model_name: str) -> Optional[ControlAdapterDefaultSettings]:
for k, v in cls.MODEL_NAME_TO_PREPROCESSOR.items():
if k in model_name:
return ControlAdapterDefaultSettings(preprocessor=v)
return None


# ##################################################3
# Checkpoint probing
# ##################################################3
Expand Down Expand Up @@ -452,7 +491,7 @@ def get_base_type(self) -> BaseModelType:
raise InvalidModelConfigException(f"{self.model_path}: Could not determine base type")


class ControlNetCheckpointProbe(CheckpointProbeBase):
class ControlNetCheckpointProbe(CheckpointProbeBase, ControlAdapterProbe):
"""Class for probing controlnets."""

def get_base_type(self) -> BaseModelType:
Expand Down Expand Up @@ -480,7 +519,7 @@ def get_base_type(self) -> BaseModelType:
raise NotImplementedError()


class T2IAdapterCheckpointProbe(CheckpointProbeBase):
class T2IAdapterCheckpointProbe(CheckpointProbeBase, ControlAdapterProbe):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()

Expand Down Expand Up @@ -618,7 +657,7 @@ def get_variant_type(self) -> ModelVariantType:
return ModelVariantType.Normal


class ControlNetFolderProbe(FolderProbeBase):
class ControlNetFolderProbe(FolderProbeBase, ControlAdapterProbe):
def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json"
if not config_file.exists():
Expand Down Expand Up @@ -692,7 +731,7 @@ def get_base_type(self) -> BaseModelType:
return BaseModelType.Any


class T2IAdapterFolderProbe(FolderProbeBase):
class T2IAdapterFolderProbe(FolderProbeBase, ControlAdapterProbe):
def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json"
if not config_file.exists():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { map } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';
import { isNonRefinerMainModelConfig } from 'services/api/types';

export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => {
startAppListening({
Expand All @@ -36,61 +37,64 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni

const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap();

if (!modelConfig || !modelConfig.default_settings) {
if (!modelConfig) {
return;
}

const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = modelConfig.default_settings;
if (isNonRefinerMainModelConfig(modelConfig) && modelConfig.default_settings) {
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } =
modelConfig.default_settings;

if (vae) {
// we store this as "default" within default settings
// to distinguish it from no default set
if (vae === 'default') {
dispatch(vaeSelected(null));
} else {
const { data } = modelsApi.endpoints.getVaeModels.select()(state);
const vaeArray = map(data?.entities);
const validVae = vaeArray.find((model) => model.key === vae);
if (vae) {
// we store this as "default" within default settings
// to distinguish it from no default set
if (vae === 'default') {
dispatch(vaeSelected(null));
} else {
const { data } = modelsApi.endpoints.getVaeModels.select()(state);
const vaeArray = map(data?.entities);
const validVae = vaeArray.find((model) => model.key === vae);

const result = zParameterVAEModel.safeParse(validVae);
if (!result.success) {
return;
const result = zParameterVAEModel.safeParse(validVae);
if (!result.success) {
return;
}
dispatch(vaeSelected(result.data));
}
dispatch(vaeSelected(result.data));
}
}

if (vae_precision) {
if (isParameterPrecision(vae_precision)) {
dispatch(vaePrecisionChanged(vae_precision));
if (vae_precision) {
if (isParameterPrecision(vae_precision)) {
dispatch(vaePrecisionChanged(vae_precision));
}
}
}

if (cfg_scale) {
if (isParameterCFGScale(cfg_scale)) {
dispatch(setCfgScale(cfg_scale));
if (cfg_scale) {
if (isParameterCFGScale(cfg_scale)) {
dispatch(setCfgScale(cfg_scale));
}
}
}

if (cfg_rescale_multiplier) {
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
if (cfg_rescale_multiplier) {
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
}
}
}

if (steps) {
if (isParameterSteps(steps)) {
dispatch(setSteps(steps));
if (steps) {
if (isParameterSteps(steps)) {
dispatch(setSteps(steps));
}
}
}

if (scheduler) {
if (isParameterScheduler(scheduler)) {
dispatch(setScheduler(scheduler));
if (scheduler) {
if (isParameterScheduler(scheduler)) {
dispatch(setScheduler(scheduler));
}
}
}

dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) })));
}
},
});
};
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
import { useControlAdapterShouldAutoConfig } from 'features/controlAdapters/hooks/useControlAdapterShouldAutoConfig';
import { controlAdapterAutoConfigToggled } from 'features/controlAdapters/store/controlAdaptersSlice';
import { isNil } from 'lodash-es';
Expand All @@ -14,12 +15,13 @@ type Props = {
const ControlAdapterShouldAutoConfig = ({ id }: Props) => {
const isEnabled = useControlAdapterIsEnabled(id);
const shouldAutoConfig = useControlAdapterShouldAutoConfig(id);
const { modelConfig } = useControlAdapterModel(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();

const handleShouldAutoConfigChanged = useCallback(() => {
dispatch(controlAdapterAutoConfigToggled({ id }));
}, [id, dispatch]);
dispatch(controlAdapterAutoConfigToggled({ id, modelConfig }));
}, [id, dispatch, modelConfig]);

if (isNil(shouldAutoConfig)) {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import { useControlAdapterModel } from 'features/controlAdapters/hooks/useContro
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
import { memo, useCallback, useMemo } from 'react';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';

Expand All @@ -17,30 +16,30 @@ type ParamControlAdapterModelProps = {
const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
const isEnabled = useControlAdapterIsEnabled(id);
const controlAdapterType = useControlAdapterType(id);
const model = useControlAdapterModel(id);
const { modelConfig } = useControlAdapterModel(id);
const dispatch = useAppDispatch();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);

const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType);

const _onChange = useCallback(
(model: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
if (!model) {
(modelConfig: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => {
if (!modelConfig) {
return;
}
dispatch(
controlAdapterModelChanged({
id,
model: getModelKeyAndBase(model),
modelConfig,
})
);
},
[dispatch, id]
);

const selectedModel = useMemo(
() => (model && controlAdapterType ? { ...model, model_type: controlAdapterType } : null),
[controlAdapterType, model]
() => (modelConfig && controlAdapterType ? { ...modelConfig, model_type: controlAdapterType } : null),
[controlAdapterType, modelConfig]
);

const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({
Expand Down