Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(nodes): enriched model identifiers #5910

Merged
merged 4 commits into from
Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 8 additions & 4 deletions invokeai/app/api_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# values from the command line or config file.
import sys

from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.version.invokeai_version import __version__

from .services.config import InvokeAIAppConfig
Expand Down Expand Up @@ -156,17 +158,19 @@ def custom_openapi() -> dict[str, Any]:
openapi_schema["components"]["schemas"][schema_key] = output_schema
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"

# Add Node Editor UI helper schemas
ui_config_schemas = models_json_schema(
# Some models don't end up in the schemas as standalone definitions
additional_schemas = models_json_schema(
[
(UIConfigBase, "serialization"),
(InputFieldJSONSchemaExtra, "serialization"),
(OutputFieldJSONSchemaExtra, "serialization"),
(ModelIdentifierField, "serialization"),
(ProgressImage, "serialization"),
],
ref_template="#/components/schemas/{model}",
)
for schema_key, ui_config_schema in ui_config_schemas[1]["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = ui_config_schema
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema_json

# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
Expand Down
6 changes: 3 additions & 3 deletions invokeai/app/invocations/controlnet_image_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
Expand All @@ -55,7 +55,7 @@

class ControlField(BaseModel):
image: ImageField = Field(description="The control image")
control_model: ModelField = Field(description="The ControlNet model to use")
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
Expand Down Expand Up @@ -91,7 +91,7 @@ class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""

image: ImageField = InputField(description="The control image")
control_model: ModelField = InputField(
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel
)
control_weight: Union[float, List[float]] = InputField(
Expand Down
10 changes: 5 additions & 5 deletions invokeai/app/invocations/ip_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
Expand All @@ -20,8 +20,8 @@

class IPAdapterField(BaseModel):
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
ip_adapter_model: ModelField = Field(description="The IP-Adapter model to use.")
image_encoder_model: ModelField = Field(description="The name of the CLIP image encoder model.")
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model to use.")
image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
Expand Down Expand Up @@ -54,7 +54,7 @@ class IPAdapterInvocation(BaseInvocation):

# Inputs
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
ip_adapter_model: ModelField = InputField(
ip_adapter_model: ModelIdentifierField = InputField(
description="The IP-Adapter model.",
title="IP-Adapter Model",
input=Input.Direct,
Expand Down Expand Up @@ -97,7 +97,7 @@ def invoke(self, context: InvocationContext) -> IPAdapterOutput:
ip_adapter=IPAdapterField(
image=self.image,
ip_adapter_model=self.ip_adapter_model,
image_encoder_model=ModelField(key=image_encoder_models[0].key),
image_encoder_model=ModelIdentifierField(key=image_encoder_models[0].key),
weight=self.weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
Expand Down
4 changes: 2 additions & 2 deletions invokeai/app/invocations/latent.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
invocation_output,
)
from .controlnet_image_processors import ControlField
from .model import ModelField, UNetField, VAEField
from .model import ModelIdentifierField, UNetField, VAEField

if choose_torch_device() == torch.device("mps"):
from torch import mps
Expand Down Expand Up @@ -245,7 +245,7 @@ def invoke(self, context: InvocationContext) -> GradientMaskOutput:

def get_scheduler(
context: InvocationContext,
scheduler_info: ModelField,
scheduler_info: ModelIdentifierField,
scheduler_name: str,
seed: int,
) -> Scheduler:
Expand Down
34 changes: 20 additions & 14 deletions invokeai/app/invocations/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import SubModelType
from invokeai.backend.model_manager.config import BaseModelType, ModelType, SubModelType

from .baseinvocation import (
BaseInvocation,
Expand All @@ -16,33 +16,39 @@
)


class ModelField(BaseModel):
key: str = Field(description="Key of the model")
submodel_type: Optional[SubModelType] = Field(description="Submodel type", default=None)
class ModelIdentifierField(BaseModel):
key: str = Field(description="The model's unique key")
hash: str = Field(description="The model's BLAKE3 hash")
name: str = Field(description="The model's name")
base: BaseModelType = Field(description="The model's base model type")
type: ModelType = Field(description="The model's type")
submodel_type: Optional[SubModelType] = Field(
description="The submodel to load, if this is a main model", default=None
)


class LoRAField(BaseModel):
lora: ModelField = Field(description="Info to load lora model")
lora: ModelIdentifierField = Field(description="Info to load lora model")
weight: float = Field(description="Weight to apply to lora model")


class UNetField(BaseModel):
unet: ModelField = Field(description="Info to load unet submodel")
scheduler: ModelField = Field(description="Info to load scheduler submodel")
unet: ModelIdentifierField = Field(description="Info to load unet submodel")
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")


class CLIPField(BaseModel):
tokenizer: ModelField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelField = Field(description="Info to load text_encoder submodel")
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")


class VAEField(BaseModel):
vae: ModelField = Field(description="Info to load vae submodel")
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')


Expand Down Expand Up @@ -84,7 +90,7 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""

model: ModelField = InputField(
model: ModelIdentifierField = InputField(
description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel
)
# TODO: precision?
Expand Down Expand Up @@ -119,7 +125,7 @@ class LoRALoaderOutput(BaseInvocationOutput):
class LoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""

lora: ModelField = InputField(
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
Expand Down Expand Up @@ -190,7 +196,7 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
class SDXLLoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""

lora: ModelField = InputField(
lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
Expand Down Expand Up @@ -264,7 +270,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:
class VAELoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""

vae_model: ModelField = InputField(
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel
)

Expand Down
6 changes: 3 additions & 3 deletions invokeai/app/invocations/sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
invocation,
invocation_output,
)
from .model import CLIPField, ModelField, UNetField, VAEField
from .model import CLIPField, ModelIdentifierField, UNetField, VAEField


@invocation_output("sdxl_model_loader_output")
Expand All @@ -34,7 +34,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""

model: ModelField = InputField(
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
)
# TODO: precision?
Expand Down Expand Up @@ -72,7 +72,7 @@ def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""

model: ModelField = InputField(
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
)
# TODO: precision?
Expand Down
6 changes: 3 additions & 3 deletions invokeai/app/invocations/t2i_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
)
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext


class T2IAdapterField(BaseModel):
image: ImageField = Field(description="The T2I-Adapter image prompt.")
t2i_adapter_model: ModelField = Field(description="The T2I-Adapter model to use.")
t2i_adapter_model: ModelIdentifierField = Field(description="The T2I-Adapter model to use.")
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
Expand Down Expand Up @@ -52,7 +52,7 @@ class T2IAdapterInvocation(BaseInvocation):

# Inputs
image: ImageField = InputField(description="The IP-Adapter image prompt.")
t2i_adapter_model: ModelField = InputField(
t2i_adapter_model: ModelIdentifierField = InputField(
description="The T2I-Adapter model.",
title="T2I-Adapter Model",
input=Input.Direct,
Expand Down
10 changes: 6 additions & 4 deletions invokeai/app/services/shared/invocation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.model import ModelField
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem

"""
Expand Down Expand Up @@ -300,7 +300,7 @@ def load(self, name: str) -> ConditioningFieldData:


class ModelsInterface(InvocationContextInterface):
def exists(self, identifier: Union[str, "ModelField"]) -> bool:
def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
"""Checks if a model exists.

Args:
Expand All @@ -314,7 +314,9 @@ def exists(self, identifier: Union[str, "ModelField"]) -> bool:

return self._services.model_manager.store.exists(identifier.key)

def load(self, identifier: Union[str, "ModelField"], submodel_type: Optional[SubModelType] = None) -> LoadedModel:
def load(
self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
"""Loads a model.

Args:
Expand Down Expand Up @@ -361,7 +363,7 @@ def load_by_attrs(

return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)

def get_config(self, identifier: Union[str, "ModelField"]) -> AnyModelConfig:
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
"""Gets a model's config.

Args:
Expand Down
2 changes: 1 addition & 1 deletion invokeai/frontend/web/.storybook/ReduxInit.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export const ReduxInit = memo((props: PropsWithChildren) => {
const dispatch = useAppDispatch();
useGlobalModifiersInit();
useEffect(() => {
dispatch(modelChanged({ key: 'test_model', base: 'sd-1' }));
dispatch(modelChanged({ key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' }));
}, []);

return props.children;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { groupBy, map, reduce } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';

type UseGroupedModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined;
selectedModel?: ModelIdentifierWithBase | null;
selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean;
isLoading?: boolean;
Expand Down
4 changes: 2 additions & 2 deletions invokeai/frontend/web/src/common/hooks/useModelCombobox.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { map } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';

type UseModelComboboxArg<T extends AnyModelConfig> = {
modelEntities: EntityState<T, string> | undefined;
selectedModel?: ModelIdentifierWithBase | null;
selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean;
optionsFilter?: (model: T) => boolean;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import type { Item } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
import { filter } from 'lodash-es';
import { useCallback, useMemo } from 'react';
Expand All @@ -11,7 +11,7 @@ import type { AnyModelConfig } from 'services/api/types';
type UseModelCustomSelectArg<T extends AnyModelConfig> = {
data: EntityState<T, string> | undefined;
isLoading: boolean;
selectedModel?: ModelIdentifierWithBase | null;
selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void;
modelFilter?: (model: T) => boolean;
isModelDisabled?: (model: T) => boolean;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import type { PersistConfig, RootState } from 'app/store/store';
import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter';
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { cloneDeep, merge, uniq } from 'lodash-es';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { socketInvocationError } from 'services/events/actions';
Expand Down Expand Up @@ -197,7 +198,7 @@ export const controlAdaptersSlice = createSlice({
return;
}

const model = { key: modelConfig.key, base: modelConfig.base };
const model = zModelIdentifierField.parse(modelConfig);

if (!isControlNetOrT2IAdapter(cn)) {
caAdapter.updateOne(state, { id, changes: { model } });
Expand Down
4 changes: 2 additions & 2 deletions invokeai/frontend/web/src/features/lora/store/loraSlice.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import type { LoRAModelConfig } from 'services/api/types';

Expand Down Expand Up @@ -31,7 +31,7 @@ export const loraSlice = createSlice({
initialState: initialLoraState,
reducers: {
loraAdded: (state, action: PayloadAction<LoRAModelConfig>) => {
const model = getModelKeyAndBase(action.payload);
const model = zModelIdentifierField.parse(action.payload);
state.loras[model.key] = { ...defaultLoRAConfig, model };
},
loraRecalled: (state, action: PayloadAction<LoRA>) => {
Expand Down