diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 0b56b7f6c02..094ade63838 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -18,7 +18,12 @@ ModelFormat, ModelType, ) -from invokeai.backend.model_manager.config import ModelDefaultSettings, ModelVariantType, SchedulerPredictionType +from invokeai.backend.model_manager.config import ( + ControlAdapterDefaultSettings, + MainModelDefaultSettings, + ModelVariantType, + SchedulerPredictionType, +) class DuplicateModelException(Exception): @@ -68,7 +73,7 @@ class ModelRecordChanges(BaseModelExcludeNull): description: Optional[str] = Field(description="Model description", default=None) base: Optional[BaseModelType] = Field(description="The base model.", default=None) trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) - default_settings: Optional[ModelDefaultSettings] = Field( + default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field( description="Default settings for this model", default=None ) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index de86b20fb28..9261f0e50f1 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -22,7 +22,7 @@ import time from enum import Enum -from typing import Literal, Optional, Type, Union +from typing import Literal, Optional, Type, TypeAlias, Union import torch from diffusers.models.modeling_utils import ModelMixin @@ -131,7 +131,7 @@ class ModelSourceType(str, Enum): HFRepoID = "hf_repo_id" -class ModelDefaultSettings(BaseModel): +class MainModelDefaultSettings(BaseModel): vae: str | None vae_precision: str | None scheduler: SCHEDULER_NAME_VALUES | None @@ -140,6 +140,11 @@ class ModelDefaultSettings(BaseModel): cfg_rescale_multiplier: float | None +class ControlAdapterDefaultSettings(BaseModel): + # This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer. + preprocessor: str | None + + class ModelConfigBase(BaseModel): """Base class for model configuration information.""" @@ -156,9 +161,6 @@ class ModelConfigBase(BaseModel): source_api_response: Optional[str] = Field( description="The original API response from the source, as stringified JSON.", default=None ) - default_settings: Optional[ModelDefaultSettings] = Field( - description="Default settings for this model", default=None - ) cover_image: Optional[str] = Field(description="Url for image to preview model", default=None) @staticmethod @@ -232,7 +234,13 @@ def get_tag() -> Tag: return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}") -class ControlNetDiffusersConfig(DiffusersConfigBase): +class ControlAdapterConfigBase(BaseModel): + default_settings: Optional[ControlAdapterDefaultSettings] = Field( + description="Default settings for this model", default=None + ) + + +class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase): """Model config for ControlNet models (diffusers version).""" type: Literal[ModelType.ControlNet] = ModelType.ControlNet @@ -243,7 +251,7 @@ def get_tag() -> Tag: return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}") -class ControlNetCheckpointConfig(CheckpointConfigBase): +class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase): """Model config for ControlNet models (diffusers version).""" type: Literal[ModelType.ControlNet] = ModelType.ControlNet @@ -279,6 +287,9 @@ def get_tag() -> Tag: class MainConfigBase(ModelConfigBase): type: Literal[ModelType.Main] = ModelType.Main trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) + default_settings: Optional[MainModelDefaultSettings] = Field( + description="Default settings for this model", default=None + ) class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase): @@ -324,7 +335,7 @@ def get_tag() -> Tag: return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}") -class T2IAdapterConfig(ModelConfigBase): +class T2IAdapterConfig(ModelConfigBase, ControlAdapterConfigBase): """Model config for T2I.""" type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter @@ -376,6 +387,7 @@ def get_model_discriminator_value(v: Any) -> str: ] AnyModelConfigValidator = TypeAdapter(AnyModelConfig) +AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, ControlAdapterDefaultSettings] class ModelConfigFactory(object): diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 75925dcf0b6..cfcb1f154eb 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -14,6 +14,7 @@ from .config import ( AnyModelConfig, BaseModelType, + ControlAdapterDefaultSettings, InvalidModelConfigException, ModelConfigFactory, ModelFormat, @@ -159,6 +160,12 @@ def probe( fields["format"] = fields.get("format") or probe.get_format() fields["hash"] = fields.get("hash") or ModelHash().hash(model_path) + fields["default_settings"] = ( + fields.get("default_settings") or probe.get_default_settings(fields["name"]) + if isinstance(probe, ControlAdapterProbe) + else None + ) + if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase): fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant() @@ -329,6 +336,38 @@ def _scan_model(cls, model_name: str, checkpoint: Path) -> None: raise Exception("The model {model_name} is potentially infected by malware. Aborting import.") +class ControlAdapterProbe(ProbeBase): + """Adds `get_default_settings` for ControlNet and T2IAdapter probes""" + + # TODO(psyche): It would be nice to get these from the invocations, but that creates circular dependencies. + # "canny": CannyImageProcessorInvocation.get_type() + MODEL_NAME_TO_PREPROCESSOR = { + "canny": "canny_image_processor", + "mlsd": "mlsd_image_processor", + "depth": "depth_anything_image_processor", + "bae": "normalbae_image_processor", + "normal": "normalbae_image_processor", + "sketch": "pidi_image_processor", + "scribble": "lineart_image_processor", + "lineart": "lineart_image_processor", + "lineart_anime": "lineart_anime_image_processor", + "softedge": "hed_image_processor", + "shuffle": "content_shuffle_image_processor", + "pose": "dw_openpose_image_processor", + "mediapipe": "mediapipe_face_processor", + "pidi": "pidi_image_processor", + "zoe": "zoe_depth_image_processor", + "color": "color_map_image_processor", + } + + @classmethod + def get_default_settings(cls, model_name: str) -> Optional[ControlAdapterDefaultSettings]: + for k, v in cls.MODEL_NAME_TO_PREPROCESSOR.items(): + if k in model_name: + return ControlAdapterDefaultSettings(preprocessor=v) + return None + + # ##################################################3 # Checkpoint probing # ##################################################3 @@ -452,7 +491,7 @@ def get_base_type(self) -> BaseModelType: raise InvalidModelConfigException(f"{self.model_path}: Could not determine base type") -class ControlNetCheckpointProbe(CheckpointProbeBase): +class ControlNetCheckpointProbe(CheckpointProbeBase, ControlAdapterProbe): """Class for probing controlnets.""" def get_base_type(self) -> BaseModelType: @@ -480,7 +519,7 @@ def get_base_type(self) -> BaseModelType: raise NotImplementedError() -class T2IAdapterCheckpointProbe(CheckpointProbeBase): +class T2IAdapterCheckpointProbe(CheckpointProbeBase, ControlAdapterProbe): def get_base_type(self) -> BaseModelType: raise NotImplementedError() @@ -618,7 +657,7 @@ def get_variant_type(self) -> ModelVariantType: return ModelVariantType.Normal -class ControlNetFolderProbe(FolderProbeBase): +class ControlNetFolderProbe(FolderProbeBase, ControlAdapterProbe): def get_base_type(self) -> BaseModelType: config_file = self.model_path / "config.json" if not config_file.exists(): @@ -692,7 +731,7 @@ def get_base_type(self) -> BaseModelType: return BaseModelType.Any -class T2IAdapterFolderProbe(FolderProbeBase): +class T2IAdapterFolderProbe(FolderProbeBase, ControlAdapterProbe): def get_base_type(self) -> BaseModelType: config_file = self.model_path / "config.json" if not config_file.exists(): diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts index da7cf6b6fea..e76f9de8f0b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/setDefaultSettings.ts @@ -21,6 +21,7 @@ import { makeToast } from 'features/system/util/makeToast'; import { t } from 'i18next'; import { map } from 'lodash-es'; import { modelsApi } from 'services/api/endpoints/models'; +import { isNonRefinerMainModelConfig } from 'services/api/types'; export const addSetDefaultSettingsListener = (startAppListening: AppStartListening) => { startAppListening({ @@ -36,61 +37,64 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap(); - if (!modelConfig || !modelConfig.default_settings) { + if (!modelConfig) { return; } - const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = modelConfig.default_settings; + if (isNonRefinerMainModelConfig(modelConfig) && modelConfig.default_settings) { + const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = + modelConfig.default_settings; - if (vae) { - // we store this as "default" within default settings - // to distinguish it from no default set - if (vae === 'default') { - dispatch(vaeSelected(null)); - } else { - const { data } = modelsApi.endpoints.getVaeModels.select()(state); - const vaeArray = map(data?.entities); - const validVae = vaeArray.find((model) => model.key === vae); + if (vae) { + // we store this as "default" within default settings + // to distinguish it from no default set + if (vae === 'default') { + dispatch(vaeSelected(null)); + } else { + const { data } = modelsApi.endpoints.getVaeModels.select()(state); + const vaeArray = map(data?.entities); + const validVae = vaeArray.find((model) => model.key === vae); - const result = zParameterVAEModel.safeParse(validVae); - if (!result.success) { - return; + const result = zParameterVAEModel.safeParse(validVae); + if (!result.success) { + return; + } + dispatch(vaeSelected(result.data)); } - dispatch(vaeSelected(result.data)); } - } - if (vae_precision) { - if (isParameterPrecision(vae_precision)) { - dispatch(vaePrecisionChanged(vae_precision)); + if (vae_precision) { + if (isParameterPrecision(vae_precision)) { + dispatch(vaePrecisionChanged(vae_precision)); + } } - } - if (cfg_scale) { - if (isParameterCFGScale(cfg_scale)) { - dispatch(setCfgScale(cfg_scale)); + if (cfg_scale) { + if (isParameterCFGScale(cfg_scale)) { + dispatch(setCfgScale(cfg_scale)); + } } - } - if (cfg_rescale_multiplier) { - if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) { - dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier)); + if (cfg_rescale_multiplier) { + if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) { + dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier)); + } } - } - if (steps) { - if (isParameterSteps(steps)) { - dispatch(setSteps(steps)); + if (steps) { + if (isParameterSteps(steps)) { + dispatch(setSteps(steps)); + } } - } - if (scheduler) { - if (isParameterScheduler(scheduler)) { - dispatch(setScheduler(scheduler)); + if (scheduler) { + if (isParameterScheduler(scheduler)) { + dispatch(setScheduler(scheduler)); + } } - } - dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) }))); + dispatch(addToast(makeToast({ title: t('toast.parameterSet', { parameter: 'Default settings' }) }))); + } }, }); }; diff --git a/invokeai/frontend/web/src/features/controlAdapters/components/ControlAdapterShouldAutoConfig.tsx b/invokeai/frontend/web/src/features/controlAdapters/components/ControlAdapterShouldAutoConfig.tsx index 7399febb2df..cb3d36c58d9 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/components/ControlAdapterShouldAutoConfig.tsx +++ b/invokeai/frontend/web/src/features/controlAdapters/components/ControlAdapterShouldAutoConfig.tsx @@ -1,6 +1,7 @@ import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library'; import { useAppDispatch } from 'app/store/storeHooks'; import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled'; +import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel'; import { useControlAdapterShouldAutoConfig } from 'features/controlAdapters/hooks/useControlAdapterShouldAutoConfig'; import { controlAdapterAutoConfigToggled } from 'features/controlAdapters/store/controlAdaptersSlice'; import { isNil } from 'lodash-es'; @@ -14,12 +15,13 @@ type Props = { const ControlAdapterShouldAutoConfig = ({ id }: Props) => { const isEnabled = useControlAdapterIsEnabled(id); const shouldAutoConfig = useControlAdapterShouldAutoConfig(id); + const { modelConfig } = useControlAdapterModel(id); const dispatch = useAppDispatch(); const { t } = useTranslation(); const handleShouldAutoConfigChanged = useCallback(() => { - dispatch(controlAdapterAutoConfigToggled({ id })); - }, [id, dispatch]); + dispatch(controlAdapterAutoConfigToggled({ id, modelConfig })); + }, [id, dispatch, modelConfig]); if (isNil(shouldAutoConfig)) { return null; diff --git a/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx b/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx index d7cf2e8452f..8bd80a8d2ec 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx +++ b/invokeai/frontend/web/src/features/controlAdapters/components/parameters/ParamControlAdapterModel.tsx @@ -6,7 +6,6 @@ import { useControlAdapterModel } from 'features/controlAdapters/hooks/useContro import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery'; import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType'; import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; -import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers'; import { memo, useCallback, useMemo } from 'react'; import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types'; @@ -17,21 +16,21 @@ type ParamControlAdapterModelProps = { const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => { const isEnabled = useControlAdapterIsEnabled(id); const controlAdapterType = useControlAdapterType(id); - const model = useControlAdapterModel(id); + const { modelConfig } = useControlAdapterModel(id); const dispatch = useAppDispatch(); const currentBaseModel = useAppSelector((s) => s.generation.model?.base); const { data, isLoading } = useControlAdapterModelQuery(controlAdapterType); const _onChange = useCallback( - (model: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => { - if (!model) { + (modelConfig: ControlNetModelConfig | IPAdapterModelConfig | T2IAdapterModelConfig | null) => { + if (!modelConfig) { return; } dispatch( controlAdapterModelChanged({ id, - model: getModelKeyAndBase(model), + modelConfig, }) ); }, @@ -39,8 +38,8 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => { ); const selectedModel = useMemo( - () => (model && controlAdapterType ? { ...model, model_type: controlAdapterType } : null), - [controlAdapterType, model] + () => (modelConfig && controlAdapterType ? { ...modelConfig, model_type: controlAdapterType } : null), + [controlAdapterType, modelConfig] ); const { items, selectedItem, onChange, placeholder } = useModelCustomSelect({ diff --git a/invokeai/frontend/web/src/features/controlAdapters/hooks/useAddControlAdapter.ts b/invokeai/frontend/web/src/features/controlAdapters/hooks/useAddControlAdapter.ts index 7fd1088767b..82d6e8c5d65 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/hooks/useAddControlAdapter.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/hooks/useAddControlAdapter.ts @@ -1,7 +1,9 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants'; import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice'; -import type { ControlAdapterType } from 'features/controlAdapters/store/types'; +import { type ControlAdapterType, isControlAdapterProcessorType } from 'features/controlAdapters/store/types'; import { useCallback, useMemo } from 'react'; +import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types'; import { useControlAdapterModels } from './useControlAdapterModels'; @@ -11,7 +13,7 @@ export const useAddControlAdapter = (type: ControlAdapterType) => { const models = useControlAdapterModels(type); - const firstModel = useMemo(() => { + const firstModel: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig | undefined = useMemo(() => { // prefer to use a model that matches the base model const firstCompatibleModel = models.filter((m) => (baseModel ? m.base === baseModel : true))[0]; @@ -28,6 +30,26 @@ export const useAddControlAdapter = (type: ControlAdapterType) => { if (isDisabled) { return; } + + if ( + (type === 'controlnet' || type === 't2i_adapter') && + (firstModel?.type === 'controlnet' || firstModel?.type === 't2i_adapter') + ) { + const defaultPreprocessor = firstModel.default_settings?.preprocessor; + const processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none'; + const processorNode = CONTROLNET_PROCESSORS[processorType].default; + dispatch( + controlAdapterAdded({ + type, + overrides: { + model: firstModel, + processorType, + processorNode, + }, + }) + ); + return; + } dispatch( controlAdapterAdded({ type, diff --git a/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModel.ts b/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModel.ts index 1416c8c9f1e..4de2aeac7fb 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModel.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/hooks/useControlAdapterModel.ts @@ -1,3 +1,4 @@ +import { skipToken } from '@reduxjs/toolkit/query'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { @@ -5,18 +6,22 @@ import { selectControlAdaptersSlice, } from 'features/controlAdapters/store/controlAdaptersSlice'; import { useMemo } from 'react'; +import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard'; +import { isControlAdapterModelConfig } from 'services/api/types'; export const useControlAdapterModel = (id: string) => { const selector = useMemo( () => createMemoizedSelector( selectControlAdaptersSlice, - (controlAdapters) => selectControlAdapterById(controlAdapters, id)?.model + (controlAdapters) => selectControlAdapterById(controlAdapters, id)?.model?.key ), [id] ); - const model = useAppSelector(selector); + const key = useAppSelector(selector); - return model; + const result = useGetModelConfigWithTypeGuard(key ?? skipToken, isControlAdapterModelConfig); + + return result; }; diff --git a/invokeai/frontend/web/src/features/controlAdapters/store/constants.ts b/invokeai/frontend/web/src/features/controlAdapters/store/constants.ts index 01c7d4217fe..91f792bf97b 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/store/constants.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/store/constants.ts @@ -253,23 +253,3 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = { }, }, }; - -export const CONTROLNET_MODEL_DEFAULT_PROCESSORS: { - [key: string]: ControlAdapterProcessorType; -} = { - canny: 'canny_image_processor', - mlsd: 'mlsd_image_processor', - depth: 'depth_anything_image_processor', - bae: 'normalbae_image_processor', - sketch: 'pidi_image_processor', - scribble: 'lineart_image_processor', - lineart: 'lineart_image_processor', - lineart_anime: 'lineart_anime_image_processor', - softedge: 'hed_image_processor', - shuffle: 'content_shuffle_image_processor', - openpose: 'dw_openpose_image_processor', - mediapipe: 'mediapipe_face_processor', - pidi: 'pidi_image_processor', - zoe: 'zoe_depth_image_processor', - color: 'color_map_image_processor', -}; diff --git a/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts b/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts index a20e287011f..ee36d10e28f 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts @@ -3,20 +3,14 @@ import { createEntityAdapter, createSlice, isAnyOf } from '@reduxjs/toolkit'; import { getSelectorsOptions } from 'app/store/createMemoizedSelector'; import type { PersistConfig, RootState } from 'app/store/store'; import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter'; -import type { - ParameterControlNetModel, - ParameterIPAdapterModel, - ParameterT2IAdapterModel, -} from 'features/parameters/types/parameterSchemas'; +import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor'; import { cloneDeep, merge, uniq } from 'lodash-es'; +import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types'; import { socketInvocationError } from 'services/events/actions'; import { v4 as uuidv4 } from 'uuid'; import { controlAdapterImageProcessed } from './actions'; -import { - CONTROLNET_MODEL_DEFAULT_PROCESSORS as CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS, - CONTROLNET_PROCESSORS, -} from './constants'; +import { CONTROLNET_PROCESSORS } from './constants'; import type { ControlAdapterConfig, ControlAdapterProcessorType, @@ -194,15 +188,17 @@ export const controlAdaptersSlice = createSlice({ state, action: PayloadAction<{ id: string; - model: ParameterControlNetModel | ParameterT2IAdapterModel | ParameterIPAdapterModel; + modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig; }> ) => { - const { id, model } = action.payload; + const { id, modelConfig } = action.payload; const cn = selectControlAdapterById(state, id); if (!cn) { return; } + const model = { key: modelConfig.key, base: modelConfig.base }; + if (!isControlNetOrT2IAdapter(cn)) { caAdapter.updateOne(state, { id, changes: { model } }); return; @@ -215,24 +211,14 @@ export const controlAdaptersSlice = createSlice({ update.changes.processedControlImage = null; - let processorType: ControlAdapterProcessorType | undefined = undefined; - - for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) { - // TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType - if (model.key.includes(modelSubstring)) { - processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring]; - break; - } + if (modelConfig.type === 'ip_adapter') { + // should never happen... + return; } - if (processorType) { - update.changes.processorType = processorType; - update.changes.processorNode = CONTROLNET_PROCESSORS[processorType] - .default as RequiredControlAdapterProcessorNode; - } else { - update.changes.processorType = 'none'; - update.changes.processorNode = CONTROLNET_PROCESSORS.none.default as RequiredControlAdapterProcessorNode; - } + const processor = buildControlAdapterProcessor(modelConfig); + update.changes.processorType = processor.processorType; + update.changes.processorNode = processor.processorNode; caAdapter.updateOne(state, update); }, @@ -324,39 +310,23 @@ export const controlAdaptersSlice = createSlice({ state, action: PayloadAction<{ id: string; + modelConfig?: ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig; }> ) => { - const { id } = action.payload; + const { id, modelConfig } = action.payload; const cn = selectControlAdapterById(state, id); - if (!cn || !isControlNetOrT2IAdapter(cn)) { + if (!cn || !isControlNetOrT2IAdapter(cn) || modelConfig?.type === 'ip_adapter') { return; } - const update: Update = { id, changes: { shouldAutoConfig: !cn.shouldAutoConfig }, }; - if (update.changes.shouldAutoConfig) { - // manage the processor for the user - let processorType: ControlAdapterProcessorType | undefined = undefined; - - for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) { - // TODO(MM2): matching modelSubstring to the model key is no longer a valid way to figure out the default processorType - if (cn.model?.key.includes(modelSubstring)) { - processorType = CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring]; - break; - } - } - - if (processorType) { - update.changes.processorType = processorType; - update.changes.processorNode = CONTROLNET_PROCESSORS[processorType] - .default as RequiredControlAdapterProcessorNode; - } else { - update.changes.processorType = 'none'; - update.changes.processorNode = CONTROLNET_PROCESSORS.none.default as RequiredControlAdapterProcessorNode; - } + if (update.changes.shouldAutoConfig && modelConfig) { + const processor = buildControlAdapterProcessor(modelConfig); + update.changes.processorType = processor.processorType; + update.changes.processorNode = processor.processorNode; } caAdapter.updateOne(state, update); diff --git a/invokeai/frontend/web/src/features/controlAdapters/store/types.test.ts b/invokeai/frontend/web/src/features/controlAdapters/store/types.test.ts new file mode 100644 index 00000000000..3bde8bc6c6f --- /dev/null +++ b/invokeai/frontend/web/src/features/controlAdapters/store/types.test.ts @@ -0,0 +1,10 @@ +import type { ControlAdapterProcessorType, zControlAdapterProcessorType } from 'features/controlAdapters/store/types'; +import type { Equals } from 'tsafe'; +import { assert } from 'tsafe'; +import { describe, test } from 'vitest'; +import type { z } from 'zod'; + +describe('Control Adapter Types', () => { + test('ControlAdapterProcessorType', () => + assert>>()); +}); diff --git a/invokeai/frontend/web/src/features/controlAdapters/store/types.ts b/invokeai/frontend/web/src/features/controlAdapters/store/types.ts index 3665355ecfd..28e375fe493 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/store/types.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/store/types.ts @@ -47,6 +47,25 @@ export type ControlAdapterProcessorNode = * Any ControlNet processor type */ export type ControlAdapterProcessorType = NonNullable; +export const zControlAdapterProcessorType = z.enum([ + 'canny_image_processor', + 'color_map_image_processor', + 'content_shuffle_image_processor', + 'depth_anything_image_processor', + 'hed_image_processor', + 'lineart_anime_image_processor', + 'lineart_image_processor', + 'mediapipe_face_processor', + 'midas_depth_image_processor', + 'mlsd_image_processor', + 'normalbae_image_processor', + 'dw_openpose_image_processor', + 'pidi_image_processor', + 'zoe_depth_image_processor', + 'none', +]); +export const isControlAdapterProcessorType = (v: unknown): v is ControlAdapterProcessorType => + zControlAdapterProcessorType.safeParse(v).success; /** * The Canny processor node, with parameters flagged as required diff --git a/invokeai/frontend/web/src/features/controlAdapters/util/buildControlAdapterProcessor.ts b/invokeai/frontend/web/src/features/controlAdapters/util/buildControlAdapterProcessor.ts new file mode 100644 index 00000000000..911bacc7874 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlAdapters/util/buildControlAdapterProcessor.ts @@ -0,0 +1,11 @@ +import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants'; +import { isControlAdapterProcessorType } from 'features/controlAdapters/store/types'; +import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types'; + +export const buildControlAdapterProcessor = (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => { + const defaultPreprocessor = modelConfig.default_settings?.preprocessor; + const processorType = isControlAdapterProcessorType(defaultPreprocessor) ? defaultPreprocessor : 'none'; + const processorNode = CONTROLNET_PROCESSORS[processorType].default; + + return { processorType, processorNode }; +}; diff --git a/invokeai/frontend/web/src/features/metadata/util/parsers.ts b/invokeai/frontend/web/src/features/metadata/util/parsers.ts index 30ec37991c6..24274b8e6a7 100644 --- a/invokeai/frontend/web/src/features/metadata/util/parsers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/parsers.ts @@ -1,9 +1,9 @@ -import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants'; import { initialControlNet, initialIPAdapter, initialT2IAdapter, } from 'features/controlAdapters/util/buildControlAdapter'; +import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor'; import type { LoRA } from 'features/lora/store/loraSlice'; import { defaultLoRAConfig } from 'features/lora/store/loraSlice'; import type { @@ -253,8 +253,7 @@ const parseControlNet: MetadataParseFunc = async (meta .catch(null) .parse(getProperty(metadataItem, 'resize_mode')); - const processorType = 'none'; - const processorNode = CONTROLNET_PROCESSORS.none.default; + const { processorType, processorNode } = buildControlAdapterProcessor(controlNetModel); const controlNet: ControlNetConfigMetadata = { type: 'controlnet', @@ -305,8 +304,7 @@ const parseT2IAdapter: MetadataParseFunc = async (meta .catch(null) .parse(getProperty(metadataItem, 'resize_mode')); - const processorType = 'none'; - const processorNode = CONTROLNET_PROCESSORS.none.default; + const { processorType, processorNode } = buildControlAdapterProcessor(t2iAdapterModel); const t2iAdapter: T2IAdapterConfigMetadata = { type: 't2i_adapter', diff --git a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useControlNetOrT2IAdapterDefaultSettings.ts b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useControlNetOrT2IAdapterDefaultSettings.ts new file mode 100644 index 00000000000..826bec17b17 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useControlNetOrT2IAdapterDefaultSettings.ts @@ -0,0 +1,23 @@ +import { skipToken } from '@reduxjs/toolkit/query'; +import { isNil } from 'lodash-es'; +import { useMemo } from 'react'; +import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard'; +import { isControlNetOrT2IAdapterModelConfig } from 'services/api/types'; + +export const useControlNetOrT2IAdapterDefaultSettings = (modelKey?: string | null) => { + const { modelConfig, isLoading } = useGetModelConfigWithTypeGuard( + modelKey ?? skipToken, + isControlNetOrT2IAdapterModelConfig + ); + + const defaultSettingsDefaults = useMemo(() => { + return { + preprocessor: { + isEnabled: !isNil(modelConfig?.default_settings?.preprocessor), + value: modelConfig?.default_settings?.preprocessor || 'none', + }, + }; + }, [modelConfig?.default_settings]); + + return { defaultSettingsDefaults, isLoading }; +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useMainModelDefaultSettings.ts b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useMainModelDefaultSettings.ts new file mode 100644 index 00000000000..207199ec933 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useMainModelDefaultSettings.ts @@ -0,0 +1,65 @@ +import { skipToken } from '@reduxjs/toolkit/query'; +import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { useAppSelector } from 'app/store/storeHooks'; +import { selectConfigSlice } from 'features/system/store/configSlice'; +import { isNil } from 'lodash-es'; +import { useMemo } from 'react'; +import { useGetModelConfigWithTypeGuard } from 'services/api/hooks/useGetModelConfigWithTypeGuard'; +import { isNonRefinerMainModelConfig } from 'services/api/types'; + +const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => { + const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision } = config.sd; + + return { + initialSteps: steps.initial, + initialCfg: guidance.initial, + initialScheduler: scheduler, + initialCfgRescaleMultiplier: cfgRescaleMultiplier.initial, + initialVaePrecision: vaePrecision, + }; +}); + +export const useMainModelDefaultSettings = (modelKey?: string | null) => { + const { modelConfig, isLoading } = useGetModelConfigWithTypeGuard(modelKey ?? skipToken, isNonRefinerMainModelConfig); + + const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } = + useAppSelector(initialStatesSelector); + + const defaultSettingsDefaults = useMemo(() => { + return { + vae: { + isEnabled: !isNil(modelConfig?.default_settings?.vae), + value: modelConfig?.default_settings?.vae || 'default', + }, + vaePrecision: { + isEnabled: !isNil(modelConfig?.default_settings?.vae_precision), + value: modelConfig?.default_settings?.vae_precision || initialVaePrecision || 'fp32', + }, + scheduler: { + isEnabled: !isNil(modelConfig?.default_settings?.scheduler), + value: modelConfig?.default_settings?.scheduler || initialScheduler || 'euler', + }, + steps: { + isEnabled: !isNil(modelConfig?.default_settings?.steps), + value: modelConfig?.default_settings?.steps || initialSteps, + }, + cfgScale: { + isEnabled: !isNil(modelConfig?.default_settings?.cfg_scale), + value: modelConfig?.default_settings?.cfg_scale || initialCfg, + }, + cfgRescaleMultiplier: { + isEnabled: !isNil(modelConfig?.default_settings?.cfg_rescale_multiplier), + value: modelConfig?.default_settings?.cfg_rescale_multiplier || initialCfgRescaleMultiplier, + }, + }; + }, [ + modelConfig?.default_settings, + initialSteps, + initialCfg, + initialScheduler, + initialCfgRescaleMultiplier, + initialVaePrecision, + ]); + + return { defaultSettingsDefaults, isLoading }; +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings.tsx new file mode 100644 index 00000000000..d54320b9dd7 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings.tsx @@ -0,0 +1,105 @@ +import { Button, Flex, Heading, Text } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/hooks/useControlNetOrT2IAdapterDefaultSettings'; +import { DefaultPreprocessor } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/DefaultPreprocessor'; +import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings'; +import { addToast } from 'features/system/store/systemSlice'; +import { makeToast } from 'features/system/util/makeToast'; +import { useCallback } from 'react'; +import type { SubmitHandler } from 'react-hook-form'; +import { useForm } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; +import { PiCheckBold } from 'react-icons/pi'; +import { useUpdateModelMutation } from 'services/api/endpoints/models'; + +export type ControlNetOrT2IAdapterDefaultSettingsFormData = { + preprocessor: FormField; +}; + +export const ControlNetOrT2IAdapterDefaultSettings = () => { + const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + + const { defaultSettingsDefaults, isLoading: isLoadingDefaultSettings } = + useControlNetOrT2IAdapterDefaultSettings(selectedModelKey); + + const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation(); + + const { handleSubmit, control, formState, reset } = useForm({ + defaultValues: defaultSettingsDefaults, + }); + + const onSubmit = useCallback>( + (data) => { + if (!selectedModelKey) { + return; + } + + const body = { + preprocessor: data.preprocessor.isEnabled ? data.preprocessor.value : null, + }; + + updateModel({ + key: selectedModelKey, + body: { default_settings: body }, + }) + .unwrap() + .then((_) => { + dispatch( + addToast( + makeToast({ + title: t('modelManager.defaultSettingsSaved'), + status: 'success', + }) + ) + ); + reset(data); + }) + .catch((error) => { + if (error) { + dispatch( + addToast( + makeToast({ + title: `${error.data.detail} `, + status: 'error', + }) + ) + ); + } + }); + }, + [selectedModelKey, dispatch, reset, updateModel, t] + ); + + if (isLoadingDefaultSettings) { + return {t('common.loading')}; + } + + return ( + <> + + {t('modelManager.defaultSettings')} + + + + + + + + + + + + ); +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/DefaultPreprocessor.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/DefaultPreprocessor.tsx new file mode 100644 index 00000000000..b2284336bf4 --- /dev/null +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/DefaultPreprocessor.tsx @@ -0,0 +1,66 @@ +import type { ComboboxOnChange } from '@invoke-ai/ui-library'; +import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library'; +import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; +import type { ControlNetOrT2IAdapterDefaultSettingsFormData } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings'; +import type { FormField } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings'; +import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle'; +import { useCallback, useMemo } from 'react'; +import type { UseControllerProps } from 'react-hook-form'; +import { useController } from 'react-hook-form'; +import { useTranslation } from 'react-i18next'; + +const OPTIONS = [ + { label: 'Canny', value: 'canny_image_processor' }, + { label: 'MLSD', value: 'mlsd_image_processor' }, + { label: 'Depth Anything', value: 'depth_anything_image_processor' }, + { label: 'Normal BAE', value: 'normalbae_image_processor' }, + { label: 'Pidi', value: 'pidi_image_processor' }, + { label: 'Lineart', value: 'lineart_image_processor' }, + { label: 'Lineart Anime', value: 'lineart_anime_image_processor' }, + { label: 'HED', value: 'hed_image_processor' }, + { label: 'Content Shuffle', value: 'content_shuffle_image_processor' }, + { label: 'DW OpenPose', value: 'dw_openpose_image_processor' }, + { label: 'MediaPipe Face', value: 'mediapipe_face_processor' }, + { label: 'ZoeDepth', value: 'zoe_depth_image_processor' }, + { label: 'Color Map', value: 'color_map_image_processor' }, + { label: 'None', value: 'none' }, +] as const; + +type DefaultSchedulerType = ControlNetOrT2IAdapterDefaultSettingsFormData['preprocessor']; + +export function DefaultPreprocessor(props: UseControllerProps) { + const { t } = useTranslation(); + const { field } = useController(props); + + const onChange = useCallback( + (v) => { + if (!v) { + return; + } + const updatedValue = { + ...(field.value as FormField), + value: v.value, + }; + field.onChange(updatedValue); + }, + [field] + ); + + const value = useMemo(() => OPTIONS.find((o) => o.value === (field.value as FormField).value), [field]); + + const isDisabled = useMemo(() => { + return !(field.value as DefaultSchedulerType).isEnabled; + }, [field.value]); + + return ( + + + + {t('controlnet.processor')} + + + + + + ); +} diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings.tsx deleted file mode 100644 index 0e31d3b53e9..00000000000 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings.tsx +++ /dev/null @@ -1,68 +0,0 @@ -import { Text } from '@invoke-ai/ui-library'; -import { skipToken } from '@reduxjs/toolkit/query'; -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; -import { useAppSelector } from 'app/store/storeHooks'; -import { selectConfigSlice } from 'features/system/store/configSlice'; -import { isNil } from 'lodash-es'; -import { useMemo } from 'react'; -import { useTranslation } from 'react-i18next'; -import { useGetModelConfigQuery } from 'services/api/endpoints/models'; - -import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm'; - -const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config) => { - const { steps, guidance, scheduler, cfgRescaleMultiplier, vaePrecision } = config.sd; - - return { - initialSteps: steps.initial, - initialCfg: guidance.initial, - initialScheduler: scheduler, - initialCfgRescaleMultiplier: cfgRescaleMultiplier.initial, - initialVaePrecision: vaePrecision, - }; -}); - -export const DefaultSettings = () => { - const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); - const { t } = useTranslation(); - - const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken); - const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } = - useAppSelector(initialStatesSelector); - - const defaultSettingsDefaults = useMemo(() => { - return { - vae: { isEnabled: !isNil(data?.default_settings?.vae), value: data?.default_settings?.vae || 'default' }, - vaePrecision: { - isEnabled: !isNil(data?.default_settings?.vae_precision), - value: data?.default_settings?.vae_precision || initialVaePrecision || 'fp32', - }, - scheduler: { - isEnabled: !isNil(data?.default_settings?.scheduler), - value: data?.default_settings?.scheduler || initialScheduler || 'euler', - }, - steps: { isEnabled: !isNil(data?.default_settings?.steps), value: data?.default_settings?.steps || initialSteps }, - cfgScale: { - isEnabled: !isNil(data?.default_settings?.cfg_scale), - value: data?.default_settings?.cfg_scale || initialCfg, - }, - cfgRescaleMultiplier: { - isEnabled: !isNil(data?.default_settings?.cfg_rescale_multiplier), - value: data?.default_settings?.cfg_rescale_multiplier || initialCfgRescaleMultiplier, - }, - }; - }, [ - data?.default_settings, - initialSteps, - initialCfg, - initialScheduler, - initialCfgRescaleMultiplier, - initialVaePrecision, - ]); - - if (isLoading) { - return {t('common.loading')}; - } - - return ; -}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultCfgRescaleMultiplier.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultCfgRescaleMultiplier.tsx similarity index 91% rename from invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultCfgRescaleMultiplier.tsx rename to invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultCfgRescaleMultiplier.tsx index 5e1cfe990a9..d16ce1460c2 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultCfgRescaleMultiplier.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultCfgRescaleMultiplier.tsx @@ -1,17 +1,17 @@ import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library'; import { useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; -import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/SettingToggle'; +import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle'; import { useCallback, useMemo } from 'react'; import type { UseControllerProps } from 'react-hook-form'; import { useController } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import type { DefaultSettingsFormData } from './DefaultSettingsForm'; +import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings'; -type DefaultCfgRescaleMultiplierType = DefaultSettingsFormData['cfgRescaleMultiplier']; +type DefaultCfgRescaleMultiplierType = MainModelDefaultSettingsFormData['cfgRescaleMultiplier']; -export function DefaultCfgRescaleMultiplier(props: UseControllerProps) { +export function DefaultCfgRescaleMultiplier(props: UseControllerProps) { const { field } = useController(props); const sliderMin = useAppSelector((s) => s.config.sd.cfgRescaleMultiplier.sliderMin); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultCfgScale.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultCfgScale.tsx similarity index 90% rename from invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultCfgScale.tsx rename to invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultCfgScale.tsx index a1bb34868b9..293261bc352 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultCfgScale.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultCfgScale.tsx @@ -1,17 +1,17 @@ import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library'; import { useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; -import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/SettingToggle'; +import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle'; import { useCallback, useMemo } from 'react'; import type { UseControllerProps } from 'react-hook-form'; import { useController } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import type { DefaultSettingsFormData } from './DefaultSettingsForm'; +import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings'; -type DefaultCfgType = DefaultSettingsFormData['cfgScale']; +type DefaultCfgType = MainModelDefaultSettingsFormData['cfgScale']; -export function DefaultCfgScale(props: UseControllerProps) { +export function DefaultCfgScale(props: UseControllerProps) { const { field } = useController(props); const sliderMin = useAppSelector((s) => s.config.sd.guidance.sliderMin); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultScheduler.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultScheduler.tsx similarity index 86% rename from invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultScheduler.tsx rename to invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultScheduler.tsx index d195b2cc387..4397e35a51e 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultScheduler.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultScheduler.tsx @@ -1,7 +1,7 @@ import type { ComboboxOnChange } from '@invoke-ai/ui-library'; import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; -import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/SettingToggle'; +import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle'; import { SCHEDULER_OPTIONS } from 'features/parameters/types/constants'; import { isParameterScheduler } from 'features/parameters/types/parameterSchemas'; import { useCallback, useMemo } from 'react'; @@ -9,11 +9,11 @@ import type { UseControllerProps } from 'react-hook-form'; import { useController } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import type { DefaultSettingsFormData } from './DefaultSettingsForm'; +import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings'; -type DefaultSchedulerType = DefaultSettingsFormData['scheduler']; +type DefaultSchedulerType = MainModelDefaultSettingsFormData['scheduler']; -export function DefaultScheduler(props: UseControllerProps) { +export function DefaultScheduler(props: UseControllerProps) { const { t } = useTranslation(); const { field } = useController(props); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultSteps.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultSteps.tsx similarity index 90% rename from invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultSteps.tsx rename to invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultSteps.tsx index 37282564b54..9c1912a0f78 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultSteps.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultSteps.tsx @@ -1,17 +1,17 @@ import { CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library'; import { useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; -import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/SettingToggle'; +import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle'; import { useCallback, useMemo } from 'react'; import type { UseControllerProps } from 'react-hook-form'; import { useController } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import type { DefaultSettingsFormData } from './DefaultSettingsForm'; +import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings'; -type DefaultSteps = DefaultSettingsFormData['steps']; +type DefaultSteps = MainModelDefaultSettingsFormData['steps']; -export function DefaultSteps(props: UseControllerProps) { +export function DefaultSteps(props: UseControllerProps) { const { field } = useController(props); const sliderMin = useAppSelector((s) => s.config.sd.steps.sliderMin); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultVae.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultVae.tsx similarity index 90% rename from invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultVae.tsx rename to invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultVae.tsx index d00f6b212fa..dcaab943779 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultVae.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultVae.tsx @@ -3,7 +3,7 @@ import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library'; import { skipToken } from '@reduxjs/toolkit/query'; import { useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; -import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/SettingToggle'; +import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle'; import { map } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import type { UseControllerProps } from 'react-hook-form'; @@ -11,11 +11,11 @@ import { useController } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; import { useGetModelConfigQuery, useGetVaeModelsQuery } from 'services/api/endpoints/models'; -import type { DefaultSettingsFormData } from './DefaultSettingsForm'; +import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings'; -type DefaultVaeType = DefaultSettingsFormData['vae']; +type DefaultVaeType = MainModelDefaultSettingsFormData['vae']; -export function DefaultVae(props: UseControllerProps) { +export function DefaultVae(props: UseControllerProps) { const { t } = useTranslation(); const { field } = useController(props); const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultVaePrecision.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultVaePrecision.tsx similarity index 86% rename from invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultVaePrecision.tsx rename to invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultVaePrecision.tsx index b5cfe6f81e8..d33cf4e08da 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultVaePrecision.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultVaePrecision.tsx @@ -1,23 +1,23 @@ import type { ComboboxOnChange } from '@invoke-ai/ui-library'; import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; -import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/SettingToggle'; +import { SettingToggle } from 'features/modelManagerV2/subpanels/ModelPanel/SettingToggle'; import { isParameterPrecision } from 'features/parameters/types/parameterSchemas'; import { useCallback, useMemo } from 'react'; import type { UseControllerProps } from 'react-hook-form'; import { useController } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; -import type { DefaultSettingsFormData } from './DefaultSettingsForm'; +import type { MainModelDefaultSettingsFormData } from './MainModelDefaultSettings'; const options = [ { label: 'FP16', value: 'fp16' }, { label: 'FP32', value: 'fp32' }, ]; -type DefaultVaePrecisionType = DefaultSettingsFormData['vaePrecision']; +type DefaultVaePrecisionType = MainModelDefaultSettingsFormData['vaePrecision']; -export function DefaultVaePrecision(props: UseControllerProps) { +export function DefaultVaePrecision(props: UseControllerProps) { const { t } = useTranslation(); const { field } = useController(props); diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultSettingsForm.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings.tsx similarity index 83% rename from invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultSettingsForm.tsx rename to invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings.tsx index 3c8551a52f3..9766fc1a146 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/DefaultSettingsForm.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/MainModelDefaultSettings.tsx @@ -1,5 +1,6 @@ -import { Button, Flex, Heading } from '@invoke-ai/ui-library'; +import { Button, Flex, Heading, Text } from '@invoke-ai/ui-library'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useMainModelDefaultSettings } from 'features/modelManagerV2/hooks/useMainModelDefaultSettings'; import type { ParameterScheduler } from 'features/parameters/types/parameterSchemas'; import { addToast } from 'features/system/store/systemSlice'; import { makeToast } from 'features/system/util/makeToast'; @@ -22,7 +23,7 @@ export interface FormField { isEnabled: boolean; } -export type DefaultSettingsFormData = { +export type MainModelDefaultSettingsFormData = { vae: FormField; vaePrecision: FormField; scheduler: FormField; @@ -31,22 +32,21 @@ export type DefaultSettingsFormData = { cfgRescaleMultiplier: FormField; }; -export const DefaultSettingsForm = ({ - defaultSettingsDefaults, -}: { - defaultSettingsDefaults: DefaultSettingsFormData; -}) => { - const dispatch = useAppDispatch(); - const { t } = useTranslation(); +export const MainModelDefaultSettings = () => { const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + + const { defaultSettingsDefaults, isLoading: isLoadingDefaultSettings } = + useMainModelDefaultSettings(selectedModelKey); - const [updateModel, { isLoading }] = useUpdateModelMutation(); + const [updateModel, { isLoading: isLoadingUpdateModel }] = useUpdateModelMutation(); - const { handleSubmit, control, formState, reset } = useForm({ + const { handleSubmit, control, formState, reset } = useForm({ defaultValues: defaultSettingsDefaults, }); - const onSubmit = useCallback>( + const onSubmit = useCallback>( (data) => { if (!selectedModelKey) { return; @@ -93,6 +93,10 @@ export const DefaultSettingsForm = ({ [selectedModelKey, dispatch, reset, updateModel, t] ); + if (isLoadingDefaultSettings) { + return {t('common.loading')}; + } + return ( <> @@ -104,7 +108,7 @@ export const DefaultSettingsForm = ({ isDisabled={!formState.isDirty} onClick={handleSubmit(onSubmit)} type="submit" - isLoading={isLoading} + isLoading={isLoadingUpdateModel} > {t('common.save')} diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx index a167823596c..adb123f24db 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/ModelView.tsx @@ -1,11 +1,12 @@ import { Box, Flex, Text } from '@invoke-ai/ui-library'; import { skipToken } from '@reduxjs/toolkit/query'; import { useAppSelector } from 'app/store/storeHooks'; +import { ControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings'; import { TriggerPhrases } from 'features/modelManagerV2/subpanels/ModelPanel/TriggerPhrases'; import { useTranslation } from 'react-i18next'; import { useGetModelConfigQuery } from 'services/api/endpoints/models'; -import { DefaultSettings } from './DefaultSettings'; +import { MainModelDefaultSettings } from './MainModelDefaultSettings/MainModelDefaultSettings'; import { ModelAttrView } from './ModelAttrView'; export const ModelView = () => { @@ -59,9 +60,14 @@ export const ModelView = () => { )} - {data.type === 'main' && ( + {data.type === 'main' && data.base !== 'sdxl-refiner' && ( - + + + )} + {(data.type === 'controlnet' || data.type === 't2i_adapter') && ( + + )} {(data.type === 'main' || data.type === 'lora') && ( diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/SettingToggle.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/SettingToggle.tsx similarity index 79% rename from invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/SettingToggle.tsx rename to invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/SettingToggle.tsx index 7f5cd8efb98..15e4693c4f0 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/DefaultSettings/SettingToggle.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelPanel/SettingToggle.tsx @@ -4,9 +4,9 @@ import { useCallback, useMemo } from 'react'; import type { UseControllerProps } from 'react-hook-form'; import { useController } from 'react-hook-form'; -import type { DefaultSettingsFormData, FormField } from './DefaultSettingsForm'; +import type { FormField } from './MainModelDefaultSettings/MainModelDefaultSettings'; -export function SettingToggle(props: UseControllerProps) { +export function SettingToggle>>(props: UseControllerProps) { const { field } = useController(props); const value = useMemo(() => { diff --git a/invokeai/frontend/web/src/services/api/hooks/useGetModelConfigWithTypeGuard.ts b/invokeai/frontend/web/src/services/api/hooks/useGetModelConfigWithTypeGuard.ts new file mode 100644 index 00000000000..6de29414034 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/hooks/useGetModelConfigWithTypeGuard.ts @@ -0,0 +1,20 @@ +import { skipToken } from '@reduxjs/toolkit/query'; +import { useGetModelConfigQuery } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/types'; + +export const useGetModelConfigWithTypeGuard = ( + key: string | typeof skipToken, + typeGuard: (config: AnyModelConfig) => config is T +) => { + const result = useGetModelConfigQuery(key ?? skipToken, { + selectFromResult: (result) => { + const modelConfig = result.data; + return { + ...result, + modelConfig: modelConfig && typeGuard(modelConfig) ? modelConfig : undefined, + }; + }, + }); + + return result; +}; diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 12f7a88ba14..58b1ca309e4 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -1376,8 +1376,6 @@ export type components = { * @description The original API response from the source, as stringified JSON. */ source_api_response?: string | null; - /** @description Default settings for this model */ - default_settings?: components["schemas"]["ModelDefaultSettings"] | null; /** * Cover Image * @description Url for image to preview model @@ -2240,6 +2238,11 @@ export type components = { */ type: "content_shuffle_image_processor"; }; + /** ControlAdapterDefaultSettings */ + ControlAdapterDefaultSettings: { + /** Preprocessor */ + preprocessor: string | null; + }; /** ControlField */ ControlField: { /** @description The control image */ @@ -2284,6 +2287,8 @@ export type components = { * @description Model config for ControlNet models (diffusers version). */ ControlNetCheckpointConfig: { + /** @description Default settings for this model */ + default_settings?: components["schemas"]["ControlAdapterDefaultSettings"] | null; /** * Key * @description A unique key for this model. @@ -2323,8 +2328,6 @@ export type components = { * @description The original API response from the source, as stringified JSON. */ source_api_response?: string | null; - /** @description Default settings for this model */ - default_settings?: components["schemas"]["ModelDefaultSettings"] | null; /** * Cover Image * @description Url for image to preview model @@ -2358,6 +2361,8 @@ export type components = { * @description Model config for ControlNet models (diffusers version). */ ControlNetDiffusersConfig: { + /** @description Default settings for this model */ + default_settings?: components["schemas"]["ControlAdapterDefaultSettings"] | null; /** * Key * @description A unique key for this model. @@ -2397,8 +2402,6 @@ export type components = { * @description The original API response from the source, as stringified JSON. */ source_api_response?: string | null; - /** @description Default settings for this model */ - default_settings?: components["schemas"]["ModelDefaultSettings"] | null; /** * Cover Image * @description Url for image to preview model @@ -4100,7 +4103,7 @@ export type components = { * @description The nodes in this graph */ nodes: { - [key: string]: components["schemas"]["FaceOffInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["InfillColorInvocation"]; + [key: string]: components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"]; }; /** * Edges @@ -4137,7 +4140,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["T2IAdapterOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["String2Output"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["CalculateImageTilesOutput"]; + [key: string]: components["schemas"]["ColorCollectionOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["String2Output"] | components["schemas"]["NoiseOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["IPAdapterOutput"]; }; /** * Errors @@ -4317,8 +4320,6 @@ export type components = { * @description The original API response from the source, as stringified JSON. */ source_api_response?: string | null; - /** @description Default settings for this model */ - default_settings?: components["schemas"]["ModelDefaultSettings"] | null; /** * Cover Image * @description Url for image to preview model @@ -6370,8 +6371,6 @@ export type components = { * @description The original API response from the source, as stringified JSON. */ source_api_response?: string | null; - /** @description Default settings for this model */ - default_settings?: components["schemas"]["ModelDefaultSettings"] | null; /** * Cover Image * @description Url for image to preview model @@ -6523,8 +6522,6 @@ export type components = { * @description The original API response from the source, as stringified JSON. */ source_api_response?: string | null; - /** @description Default settings for this model */ - default_settings?: components["schemas"]["ModelDefaultSettings"] | null; /** * Cover Image * @description Url for image to preview model @@ -6629,8 +6626,6 @@ export type components = { * @description The original API response from the source, as stringified JSON. */ source_api_response?: string | null; - /** @description Default settings for this model */ - default_settings?: components["schemas"]["ModelDefaultSettings"] | null; /** * Cover Image * @description Url for image to preview model @@ -6647,6 +6642,8 @@ export type components = { * @description Set of trigger phrases for this model */ trigger_phrases?: string[] | null; + /** @description Default settings for this model */ + default_settings?: components["schemas"]["MainModelDefaultSettings"] | null; /** * Format * @default checkpoint @@ -6717,8 +6714,6 @@ export type components = { * @description The original API response from the source, as stringified JSON. */ source_api_response?: string | null; - /** @description Default settings for this model */ - default_settings?: components["schemas"]["ModelDefaultSettings"] | null; /** * Cover Image * @description Url for image to preview model @@ -6735,6 +6730,8 @@ export type components = { * @description Set of trigger phrases for this model */ trigger_phrases?: string[] | null; + /** @description Default settings for this model */ + default_settings?: components["schemas"]["MainModelDefaultSettings"] | null; /** * Format * @default diffusers @@ -6744,6 +6741,21 @@ export type components = { /** @default */ repo_variant?: components["schemas"]["ModelRepoVariant"] | null; }; + /** MainModelDefaultSettings */ + MainModelDefaultSettings: { + /** Vae */ + vae: string | null; + /** Vae Precision */ + vae_precision: string | null; + /** Scheduler */ + scheduler: ("ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm") | null; + /** Steps */ + steps: number | null; + /** Cfg Scale */ + cfg_scale: number | null; + /** Cfg Rescale Multiplier */ + cfg_rescale_multiplier: number | null; + }; /** * Main Model * @description Loads a main model, outputting its submodels. @@ -7263,21 +7275,6 @@ export type components = { */ type: "mlsd_image_processor"; }; - /** ModelDefaultSettings */ - ModelDefaultSettings: { - /** Vae */ - vae: string | null; - /** Vae Precision */ - vae_precision: string | null; - /** Scheduler */ - scheduler: ("ddim" | "ddpm" | "deis" | "lms" | "lms_k" | "pndm" | "heun" | "heun_k" | "euler" | "euler_k" | "euler_a" | "kdpm_2" | "kdpm_2_a" | "dpmpp_2s" | "dpmpp_2s_k" | "dpmpp_2m" | "dpmpp_2m_k" | "dpmpp_2m_sde" | "dpmpp_2m_sde_k" | "dpmpp_sde" | "dpmpp_sde_k" | "unipc" | "lcm") | null; - /** Steps */ - steps: number | null; - /** Cfg Scale */ - cfg_scale: number | null; - /** Cfg Rescale Multiplier */ - cfg_rescale_multiplier: number | null; - }; /** ModelField */ ModelField: { /** @@ -7445,8 +7442,11 @@ export type components = { * @description Set of trigger phrases for this model */ trigger_phrases?: string[] | null; - /** @description Default settings for this model */ - default_settings?: components["schemas"]["ModelDefaultSettings"] | null; + /** + * Default Settings + * @description Default settings for this model + */ + default_settings?: components["schemas"]["MainModelDefaultSettings"] | components["schemas"]["ControlAdapterDefaultSettings"] | null; /** @description The variant of the model. */ variant?: components["schemas"]["ModelVariantType"] | null; /** @description The prediction type of the model. */ @@ -9649,6 +9649,8 @@ export type components = { * @description Model config for T2I. */ T2IAdapterConfig: { + /** @description Default settings for this model */ + default_settings?: components["schemas"]["ControlAdapterDefaultSettings"] | null; /** * Key * @description A unique key for this model. @@ -9688,8 +9690,6 @@ export type components = { * @description The original API response from the source, as stringified JSON. */ source_api_response?: string | null; - /** @description Default settings for this model */ - default_settings?: components["schemas"]["ModelDefaultSettings"] | null; /** * Cover Image * @description Url for image to preview model @@ -9901,8 +9901,6 @@ export type components = { * @description The original API response from the source, as stringified JSON. */ source_api_response?: string | null; - /** @description Default settings for this model */ - default_settings?: components["schemas"]["ModelDefaultSettings"] | null; /** * Cover Image * @description Url for image to preview model @@ -9965,8 +9963,6 @@ export type components = { * @description The original API response from the source, as stringified JSON. */ source_api_response?: string | null; - /** @description Default settings for this model */ - default_settings?: components["schemas"]["ModelDefaultSettings"] | null; /** * Cover Image * @description Url for image to preview model @@ -10290,8 +10286,6 @@ export type components = { * @description The original API response from the source, as stringified JSON. */ source_api_response?: string | null; - /** @description Default settings for this model */ - default_settings?: components["schemas"]["ModelDefaultSettings"] | null; /** * Cover Image * @description Url for image to preview model @@ -10364,8 +10358,6 @@ export type components = { * @description The original API response from the source, as stringified JSON. */ source_api_response?: string | null; - /** @description Default settings for this model */ - default_settings?: components["schemas"]["ModelDefaultSettings"] | null; /** * Cover Image * @description Url for image to preview model diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 7a870a321f4..2d304a8333f 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -83,6 +83,18 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd return config.type === 't2i_adapter'; }; +export const isControlAdapterModelConfig = ( + config: AnyModelConfig +): config is ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig => { + return isControlNetModelConfig(config) || isT2IAdapterModelConfig(config) || isIPAdapterModelConfig(config); +}; + +export const isControlNetOrT2IAdapterModelConfig = ( + config: AnyModelConfig +): config is ControlNetModelConfig | T2IAdapterModelConfig => { + return isControlNetModelConfig(config) || isT2IAdapterModelConfig(config); +}; + export const isNonRefinerMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => { return config.type === 'main' && config.base !== 'sdxl-refiner'; }; diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py index 78a9ec50b49..1aba632380d 100644 --- a/tests/test_model_probe.py +++ b/tests/test_model_probe.py @@ -3,7 +3,7 @@ import pytest from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant -from invokeai.backend.model_manager.probe import VaeFolderProbe +from invokeai.backend.model_manager.probe import ControlAdapterProbe, VaeFolderProbe @pytest.mark.parametrize( @@ -28,3 +28,17 @@ def test_repo_variant(datadir: Path): probe = VaeFolderProbe(datadir / "vae" / "taesdxl-fp16") repo_variant = probe.get_repo_variant() assert repo_variant == ModelRepoVariant.FP16 + + +def test_controlnet_t2i_default_settings(): + should_be_canny = ControlAdapterProbe.get_default_settings("some_canny_model") + assert should_be_canny and should_be_canny.preprocessor == "canny_image_processor" + + should_be_depth_anything = ControlAdapterProbe.get_default_settings("some_depth_model") + assert should_be_depth_anything and should_be_depth_anything.preprocessor == "depth_anything_image_processor" + + should_be_dw_openpose = ControlAdapterProbe.get_default_settings("some_pose_model") + assert should_be_dw_openpose and should_be_dw_openpose.preprocessor == "dw_openpose_image_processor" + + should_be_none = ControlAdapterProbe.get_default_settings("i like turtles") + assert should_be_none is None