From 88c22a888c4ebc2e5751507129bb14b9b964c6f0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 9 Mar 2024 19:43:24 +1100 Subject: [PATCH 1/4] feat(nodes): "ModelField" -> "ModelIdentifierField", add hash/name/base/type --- .../controlnet_image_processors.py | 6 ++-- invokeai/app/invocations/ip_adapter.py | 10 +++--- invokeai/app/invocations/latent.py | 4 +-- invokeai/app/invocations/model.py | 34 +++++++++++-------- invokeai/app/invocations/sdxl.py | 6 ++-- invokeai/app/invocations/t2i_adapter.py | 6 ++-- .../app/services/shared/invocation_context.py | 10 +++--- invokeai/invocation_api/__init__.py | 4 +-- 8 files changed, 44 insertions(+), 36 deletions(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index d4da0c25a1d..7b6cfaaaf13 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -35,7 +35,7 @@ WithBoard, WithMetadata, ) -from invokeai.app.invocations.model import ModelField +from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.shared.invocation_context import InvocationContext @@ -55,7 +55,7 @@ class ControlField(BaseModel): image: ImageField = Field(description="The control image") - control_model: ModelField = Field(description="The ControlNet model to use") + control_model: ModelIdentifierField = Field(description="The ControlNet model to use") control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") begin_step_percent: float = Field( default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)" @@ -91,7 +91,7 @@ class ControlNetInvocation(BaseInvocation): """Collects ControlNet 
info to pass to other nodes""" image: ImageField = InputField(description="The control image") - control_model: ModelField = InputField( + control_model: ModelIdentifierField = InputField( description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel ) control_weight: Union[float, List[float]] = InputField( diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index e7d33654f5b..c12352071fb 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -11,7 +11,7 @@ invocation_output, ) from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType -from invokeai.app.invocations.model import ModelField +from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.shared.invocation_context import InvocationContext @@ -20,8 +20,8 @@ class IPAdapterField(BaseModel): image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).") - ip_adapter_model: ModelField = Field(description="The IP-Adapter model to use.") - image_encoder_model: ModelField = Field(description="The name of the CLIP image encoder model.") + ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model to use.") + image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.") weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") begin_step_percent: float = Field( default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)" @@ -54,7 +54,7 @@ class IPAdapterInvocation(BaseInvocation): # Inputs image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).") - 
ip_adapter_model: ModelField = InputField( + ip_adapter_model: ModelIdentifierField = InputField( description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, @@ -97,7 +97,7 @@ def invoke(self, context: InvocationContext) -> IPAdapterOutput: ip_adapter=IPAdapterField( image=self.image, ip_adapter_model=self.ip_adapter_model, - image_encoder_model=ModelField(key=image_encoder_models[0].key), + image_encoder_model=ModelIdentifierField(key=image_encoder_models[0].key), weight=self.weight, begin_step_percent=self.begin_step_percent, end_step_percent=self.end_step_percent, diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index f21e28cfa42..94cd5a75f01 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -76,7 +76,7 @@ invocation_output, ) from .controlnet_image_processors import ControlField -from .model import ModelField, UNetField, VAEField +from .model import ModelIdentifierField, UNetField, VAEField if choose_torch_device() == torch.device("mps"): from torch import mps @@ -245,7 +245,7 @@ def invoke(self, context: InvocationContext) -> GradientMaskOutput: def get_scheduler( context: InvocationContext, - scheduler_info: ModelField, + scheduler_info: ModelIdentifierField, scheduler_name: str, seed: int, ) -> Scheduler: diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 7806f61a8f4..0ae2d272295 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -6,7 +6,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.shared.models import FreeUConfig -from invokeai.backend.model_manager.config import SubModelType +from invokeai.backend.model_manager.config import BaseModelType, ModelType, SubModelType from .baseinvocation import ( BaseInvocation, @@ -16,33 +16,39 
@@ ) -class ModelField(BaseModel): - key: str = Field(description="Key of the model") - submodel_type: Optional[SubModelType] = Field(description="Submodel type", default=None) +class ModelIdentifierField(BaseModel): + key: str = Field(description="The model's unique key") + hash: str = Field(description="The model's BLAKE3 hash") + name: str = Field(description="The model's name") + base: BaseModelType = Field(description="The model's base model type") + type: ModelType = Field(description="The model's type") + submodel_type: Optional[SubModelType] = Field( + description="The submodel to load, if this is a main model", default=None + ) class LoRAField(BaseModel): - lora: ModelField = Field(description="Info to load lora model") + lora: ModelIdentifierField = Field(description="Info to load lora model") weight: float = Field(description="Weight to apply to lora model") class UNetField(BaseModel): - unet: ModelField = Field(description="Info to load unet submodel") - scheduler: ModelField = Field(description="Info to load scheduler submodel") + unet: ModelIdentifierField = Field(description="Info to load unet submodel") + scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel") loras: List[LoRAField] = Field(description="LoRAs to apply on model loading") seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration") class CLIPField(BaseModel): - tokenizer: ModelField = Field(description="Info to load tokenizer submodel") - text_encoder: ModelField = Field(description="Info to load text_encoder submodel") + tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel") + text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel") skipped_layers: int = Field(description="Number of skipped layers in text_encoder") loras: List[LoRAField] = 
Field(description="LoRAs to apply on model loading") class VAEField(BaseModel): - vae: ModelField = Field(description="Info to load vae submodel") + vae: ModelIdentifierField = Field(description="Info to load vae submodel") seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') @@ -84,7 +90,7 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput): class MainModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" - model: ModelField = InputField( + model: ModelIdentifierField = InputField( description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel ) # TODO: precision? @@ -119,7 +125,7 @@ class LoRALoaderOutput(BaseInvocationOutput): class LoRALoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" - lora: ModelField = InputField( + lora: ModelIdentifierField = InputField( description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel ) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) @@ -190,7 +196,7 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput): class SDXLLoRALoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" - lora: ModelField = InputField( + lora: ModelIdentifierField = InputField( description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel ) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) @@ -264,7 +270,7 @@ def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput: class VAELoaderInvocation(BaseInvocation): """Loads a VAE model, outputting a VaeLoaderOutput""" - vae_model: ModelField = InputField( + vae_model: ModelIdentifierField = InputField( description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel ) diff --git a/invokeai/app/invocations/sdxl.py 
b/invokeai/app/invocations/sdxl.py index 17b6ef20534..9676a6cec0b 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -8,7 +8,7 @@ invocation, invocation_output, ) -from .model import CLIPField, ModelField, UNetField, VAEField +from .model import CLIPField, ModelIdentifierField, UNetField, VAEField @invocation_output("sdxl_model_loader_output") @@ -34,7 +34,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput): class SDXLModelLoaderInvocation(BaseInvocation): """Loads an sdxl base model, outputting its submodels.""" - model: ModelField = InputField( + model: ModelIdentifierField = InputField( description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel ) # TODO: precision? @@ -72,7 +72,7 @@ def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: class SDXLRefinerModelLoaderInvocation(BaseInvocation): """Loads an sdxl refiner model, outputting its submodels.""" - model: ModelField = InputField( + model: ModelIdentifierField = InputField( description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel ) # TODO: precision? 
diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index d399d17864d..71eb31c3aa5 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -10,14 +10,14 @@ ) from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType -from invokeai.app.invocations.model import ModelField +from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.shared.invocation_context import InvocationContext class T2IAdapterField(BaseModel): image: ImageField = Field(description="The T2I-Adapter image prompt.") - t2i_adapter_model: ModelField = Field(description="The T2I-Adapter model to use.") + t2i_adapter_model: ModelIdentifierField = Field(description="The T2I-Adapter model to use.") weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter") begin_step_percent: float = Field( default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)" @@ -52,7 +52,7 @@ class T2IAdapterInvocation(BaseInvocation): # Inputs image: ImageField = InputField(description="The IP-Adapter image prompt.") - t2i_adapter_model: ModelField = InputField( + t2i_adapter_model: ModelIdentifierField = InputField( description="The T2I-Adapter model.", title="T2I-Adapter Model", input=Input.Direct, diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index abf131a1254..4e445a693a2 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: from invokeai.app.invocations.baseinvocation import BaseInvocation - from invokeai.app.invocations.model 
import ModelField + from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem """ @@ -300,7 +300,7 @@ def load(self, name: str) -> ConditioningFieldData: class ModelsInterface(InvocationContextInterface): - def exists(self, identifier: Union[str, "ModelField"]) -> bool: + def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool: """Checks if a model exists. Args: @@ -314,7 +314,9 @@ def exists(self, identifier: Union[str, "ModelField"]) -> bool: return self._services.model_manager.store.exists(identifier.key) - def load(self, identifier: Union[str, "ModelField"], submodel_type: Optional[SubModelType] = None) -> LoadedModel: + def load( + self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None + ) -> LoadedModel: """Loads a model. Args: @@ -361,7 +363,7 @@ def load_by_attrs( return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data) - def get_config(self, identifier: Union[str, "ModelField"]) -> AnyModelConfig: + def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig: """Gets a model's config. 
Args: diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py index c15beb446e9..300ecd751b0 100644 --- a/invokeai/invocation_api/__init__.py +++ b/invokeai/invocation_api/__init__.py @@ -36,7 +36,7 @@ CLIPField, CLIPOutput, LoRALoaderOutput, - ModelField, + ModelIdentifierField, ModelLoaderOutput, SDXLLoRALoaderOutput, UNetField, @@ -114,7 +114,7 @@ "MetadataItemOutput", "MetadataOutput", # invokeai.app.invocations.model - "ModelField", + "ModelIdentifierField", "UNetField", "CLIPField", "VAEField", From e8fd3e467ec52d97223b2fbb47d99217a9cc332e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 9 Mar 2024 19:44:15 +1100 Subject: [PATCH 2/4] feat(api): add ModelIdentifierField to openapi schema - Also add `ProgressImage` --- invokeai/app/api_app.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 703774b77bb..8d6c4fc137a 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -3,6 +3,8 @@ # values from the command line or config file. 
import sys +from invokeai.app.invocations.model import ModelIdentifierField +from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.version.invokeai_version import __version__ from .services.config import InvokeAIAppConfig @@ -156,17 +158,19 @@ def custom_openapi() -> dict[str, Any]: openapi_schema["components"]["schemas"][schema_key] = output_schema openapi_schema["components"]["schemas"][schema_key]["class"] = "output" - # Add Node Editor UI helper schemas - ui_config_schemas = models_json_schema( + # Some models don't end up in the schemas as standalone definitions + additional_schemas = models_json_schema( [ (UIConfigBase, "serialization"), (InputFieldJSONSchemaExtra, "serialization"), (OutputFieldJSONSchemaExtra, "serialization"), + (ModelIdentifierField, "serialization"), + (ProgressImage, "serialization"), ], ref_template="#/components/schemas/{model}", ) - for schema_key, ui_config_schema in ui_config_schemas[1]["$defs"].items(): - openapi_schema["components"]["schemas"][schema_key] = ui_config_schema + for schema_key, schema_json in additional_schemas[1]["$defs"].items(): + openapi_schema["components"]["schemas"][schema_key] = schema_json # Add a reference to the output type to additionalProperties of the invoker schema for invoker in all_invocations: From 8f9b1302f690f5325a42d45e8634155cdfb83b6a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 9 Mar 2024 19:49:14 +1100 Subject: [PATCH 3/4] chore(ui): typegen --- .../frontend/web/src/services/api/schema.ts | 97 +++++++++++++------ 1 file changed, 66 insertions(+), 31 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 58b1ca309e4..fad626e9557 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -1243,9 +1243,9 @@ export type components = { /** CLIPField 
*/ CLIPField: { /** @description Info to load tokenizer submodel */ - tokenizer: components["schemas"]["ModelField"]; + tokenizer: components["schemas"]["ModelIdentifierField"]; /** @description Info to load text_encoder submodel */ - text_encoder: components["schemas"]["ModelField"]; + text_encoder: components["schemas"]["ModelIdentifierField"]; /** * Skipped Layers * @description Number of skipped layers in text_encoder @@ -2248,7 +2248,7 @@ export type components = { /** @description The control image */ image: components["schemas"]["ImageField"]; /** @description The ControlNet model to use */ - control_model: components["schemas"]["ModelField"]; + control_model: components["schemas"]["ModelIdentifierField"]; /** * Control Weight * @description The weight given to the ControlNet @@ -2447,7 +2447,7 @@ export type components = { /** @description The control image */ image?: components["schemas"]["ImageField"]; /** @description ControlNet model to load */ - control_model: components["schemas"]["ModelField"]; + control_model: components["schemas"]["ModelIdentifierField"]; /** * Control Weight * @description The weight given to the ControlNet @@ -4103,7 +4103,7 @@ export type components = { * @description The nodes in this graph */ nodes: { - [key: string]: components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["MetadataItemInvocation"] | 
components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["IdealSizeInvocation"] 
| components["schemas"]["StringJoinInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ColorCorrectInvocation"] | 
components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"]; + [key: string]: components["schemas"]["BlendLatentsInvocation"] | 
components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["FaceMaskInvocation"] | 
components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | 
components["schemas"]["ImageLerpInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["MetadataInvocation"] | 
components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"]; }; /** * Edges @@ -4140,7 +4140,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["ColorCollectionOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["String2Output"] | components["schemas"]["NoiseOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["IdealSizeOutput"] | 
components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["IPAdapterOutput"]; + [key: string]: components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["TileToPropertiesOutput"] | 
components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["String2Output"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["SchedulerOutput"]; }; /** * Errors @@ -4347,9 +4347,9 @@ export type components = { */ image: components["schemas"]["ImageField"] | components["schemas"]["ImageField"][]; /** @description The IP-Adapter model to use. */ - ip_adapter_model: components["schemas"]["ModelField"]; + ip_adapter_model: components["schemas"]["ModelIdentifierField"]; /** @description The name of the CLIP image encoder model. */ - image_encoder_model: components["schemas"]["ModelField"]; + image_encoder_model: components["schemas"]["ModelIdentifierField"]; /** * Weight * @description The weight given to the ControlNet @@ -4400,7 +4400,7 @@ export type components = { * IP-Adapter Model * @description The IP-Adapter model. 
*/ - ip_adapter_model: components["schemas"]["ModelField"]; + ip_adapter_model: components["schemas"]["ModelIdentifierField"]; /** * Weight * @description The weight given to the IP-Adapter @@ -6397,7 +6397,7 @@ export type components = { /** LoRAField */ LoRAField: { /** @description Info to load lora model */ - lora: components["schemas"]["ModelField"]; + lora: components["schemas"]["ModelIdentifierField"]; /** * Weight * @description Weight to apply to lora model @@ -6430,7 +6430,7 @@ export type components = { * LoRA * @description LoRA model to load */ - lora: components["schemas"]["ModelField"]; + lora: components["schemas"]["ModelIdentifierField"]; /** * Weight * @description The weight at which the LoRA is applied to each model @@ -6779,7 +6779,7 @@ export type components = { */ use_cache?: boolean; /** @description Main model (UNet, VAE, CLIP) to load */ - model: components["schemas"]["ModelField"]; + model: components["schemas"]["ModelIdentifierField"]; /** * type * @default main_model_loader @@ -7275,25 +7275,39 @@ export type components = { */ type: "mlsd_image_processor"; }; - /** ModelField */ - ModelField: { + /** + * ModelFormat + * @description Storage format of model. 
+ * @enum {string} + */ + ModelFormat: "diffusers" | "checkpoint" | "lycoris" | "onnx" | "olive" | "embedding_file" | "embedding_folder" | "invokeai"; + /** ModelIdentifierField */ + ModelIdentifierField: { /** * Key - * @description Key of the model + * @description The model's unique key */ key: string; /** - * @description Submodel type + * Hash + * @description The model's BLAKE3 hash + */ + hash: string; + /** + * Name + * @description The model's name + */ + name: string; + /** @description The model's base model type */ + base: components["schemas"]["BaseModelType"]; + /** @description The model's type */ + type: components["schemas"]["ModelType"]; + /** + * @description The submodel to load, if this is a main model * @default null */ submodel_type?: components["schemas"]["SubModelType"] | null; }; - /** - * ModelFormat - * @description Storage format of model. - * @enum {string} - */ - ModelFormat: "diffusers" | "checkpoint" | "lycoris" | "onnx" | "olive" | "embedding_file" | "embedding_folder" | "invokeai"; /** * ModelInstallJob * @description Object that tracks the current status of an install request. 
@@ -8404,7 +8418,7 @@ export type components = { * LoRA * @description LoRA model to load */ - lora: components["schemas"]["ModelField"]; + lora: components["schemas"]["ModelIdentifierField"]; /** * Weight * @description The weight at which the LoRA is applied to each model @@ -8486,7 +8500,7 @@ export type components = { */ use_cache?: boolean; /** @description SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load */ - model: components["schemas"]["ModelField"]; + model: components["schemas"]["ModelIdentifierField"]; /** * type * @default sdxl_model_loader @@ -8612,7 +8626,7 @@ export type components = { */ use_cache?: boolean; /** @description SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load */ - model: components["schemas"]["ModelField"]; + model: components["schemas"]["ModelIdentifierField"]; /** * type * @default sdxl_refiner_model_loader @@ -9712,7 +9726,7 @@ export type components = { /** @description The T2I-Adapter image prompt. */ image: components["schemas"]["ImageField"]; /** @description The T2I-Adapter model to use. */ - t2i_adapter_model: components["schemas"]["ModelField"]; + t2i_adapter_model: components["schemas"]["ModelIdentifierField"]; /** * Weight * @description The weight given to the T2I-Adapter @@ -9767,7 +9781,7 @@ export type components = { * T2I-Adapter Model * @description The T2I-Adapter model. 
*/ - t2i_adapter_model: components["schemas"]["ModelField"]; + t2i_adapter_model: components["schemas"]["ModelIdentifierField"]; /** * Weight * @description The weight given to the T2I-Adapter @@ -10127,9 +10141,9 @@ export type components = { /** UNetField */ UNetField: { /** @description Info to load unet submodel */ - unet: components["schemas"]["ModelField"]; + unet: components["schemas"]["ModelIdentifierField"]; /** @description Info to load scheduler submodel */ - scheduler: components["schemas"]["ModelField"]; + scheduler: components["schemas"]["ModelIdentifierField"]; /** * Loras * @description LoRAs to apply on model loading @@ -10379,7 +10393,7 @@ export type components = { /** VAEField */ VAEField: { /** @description Info to load vae submodel */ - vae: components["schemas"]["ModelField"]; + vae: components["schemas"]["ModelIdentifierField"]; /** * Seamless Axes * @description Axes("x" and "y") to which apply seamless @@ -10412,7 +10426,7 @@ export type components = { * VAE * @description VAE model to load */ - vae_model: components["schemas"]["ModelField"]; + vae_model: components["schemas"]["ModelIdentifierField"]; /** * type * @default vae_loader @@ -10784,6 +10798,27 @@ export type components = { /** Ui Order */ ui_order: number | null; }; + /** + * ProgressImage + * @description The progress image sent intermittently during processing + */ + ProgressImage: { + /** + * Width + * @description The effective width of the image in pixels + */ + width: number; + /** + * Height + * @description The effective height of the image in pixels + */ + height: number; + /** + * Dataurl + * @description The image data as a b64 data URL + */ + dataURL: string; + }; /** * UIComponent * @description The type of UI component to use for a field, used to override the default components, which are From bbf92c20476532b0d7f124c4353174f503ce5791 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 9 Mar 2024 19:51:15 
+1100 Subject: [PATCH 4/4] fix(ui): update all components and logic to use enriched ModelIdentifierField --- .../frontend/web/.storybook/ReduxInit.tsx | 2 +- .../common/hooks/useGroupedModelCombobox.ts | 4 +-- .../web/src/common/hooks/useModelCombobox.ts | 4 +-- .../src/common/hooks/useModelCustomSelect.ts | 4 +-- .../store/controlAdaptersSlice.ts | 3 +- .../web/src/features/lora/store/loraSlice.ts | 4 +-- .../src/features/metadata/util/handlers.ts | 4 +-- .../metadata/util/modelFetchingHelpers.ts | 6 ---- .../web/src/features/metadata/util/parsers.ts | 21 +++++------- .../src/features/nodes/types/common.test-d.ts | 5 ++- .../web/src/features/nodes/types/common.ts | 34 ++++++++++++------- .../web/src/features/nodes/types/field.ts | 14 ++++---- .../nodes/util/graph/addLoRAsToGraph.ts | 18 ++++------ .../nodes/util/graph/addSDXLLoRAstoGraph.ts | 21 ++++-------- .../MainModel/ParamMainModelSelect.tsx | 7 +++- .../VAEModel/ParamVAEModelSelect.tsx | 4 +-- .../parameters/types/parameterSchemas.ts | 16 ++++----- .../parameters/util/optimalDimension.ts | 4 +-- .../ParamSDXLRefinerModelSelect.tsx | 4 +-- 19 files changed, 85 insertions(+), 94 deletions(-) diff --git a/invokeai/frontend/web/.storybook/ReduxInit.tsx b/invokeai/frontend/web/.storybook/ReduxInit.tsx index 7d3f8e0d2bf..d50d52754c2 100644 --- a/invokeai/frontend/web/.storybook/ReduxInit.tsx +++ b/invokeai/frontend/web/.storybook/ReduxInit.tsx @@ -10,7 +10,7 @@ export const ReduxInit = memo((props: PropsWithChildren) => { const dispatch = useAppDispatch(); useGlobalModifiersInit(); useEffect(() => { - dispatch(modelChanged({ key: 'test_model', base: 'sd-1' })); + dispatch(modelChanged({ key: 'test_model', hash: 'some_hash', name: 'some name', base: 'sd-1', type: 'main' })); }, []); return props.children; diff --git a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts index 2fffd7bda0b..ee9da8ea660 100644 --- 
a/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts +++ b/invokeai/frontend/web/src/common/hooks/useGroupedModelCombobox.ts @@ -2,7 +2,7 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; import type { EntityState } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import type { GroupBase } from 'chakra-react-select'; -import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; +import type { ModelIdentifierField } from 'features/nodes/types/common'; import { groupBy, map, reduce } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -10,7 +10,7 @@ import type { AnyModelConfig } from 'services/api/types'; type UseGroupedModelComboboxArg = { modelEntities: EntityState | undefined; - selectedModel?: ModelIdentifierWithBase | null; + selectedModel?: ModelIdentifierField | null; onChange: (value: T | null) => void; getIsDisabled?: (model: T) => boolean; isLoading?: boolean; diff --git a/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts b/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts index e0718d64132..3d9109a5ef1 100644 --- a/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts +++ b/invokeai/frontend/web/src/common/hooks/useModelCombobox.ts @@ -1,6 +1,6 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; import type { EntityState } from '@reduxjs/toolkit'; -import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; +import type { ModelIdentifierField } from 'features/nodes/types/common'; import { map } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -8,7 +8,7 @@ import type { AnyModelConfig } from 'services/api/types'; type UseModelComboboxArg = { modelEntities: EntityState | undefined; - selectedModel?: ModelIdentifierWithBase | null; + selectedModel?: ModelIdentifierField | 
null; onChange: (value: T | null) => void; getIsDisabled?: (model: T) => boolean; optionsFilter?: (model: T) => boolean; diff --git a/invokeai/frontend/web/src/common/hooks/useModelCustomSelect.ts b/invokeai/frontend/web/src/common/hooks/useModelCustomSelect.ts index 5626f4c3952..60de28468c2 100644 --- a/invokeai/frontend/web/src/common/hooks/useModelCustomSelect.ts +++ b/invokeai/frontend/web/src/common/hooks/useModelCustomSelect.ts @@ -1,7 +1,7 @@ import type { Item } from '@invoke-ai/ui-library'; import type { EntityState } from '@reduxjs/toolkit'; import { EMPTY_ARRAY } from 'app/store/constants'; -import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; +import type { ModelIdentifierField } from 'features/nodes/types/common'; import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants'; import { filter } from 'lodash-es'; import { useCallback, useMemo } from 'react'; @@ -11,7 +11,7 @@ import type { AnyModelConfig } from 'services/api/types'; type UseModelCustomSelectArg = { data: EntityState | undefined; isLoading: boolean; - selectedModel?: ModelIdentifierWithBase | null; + selectedModel?: ModelIdentifierField | null; onChange: (value: T | null) => void; modelFilter?: (model: T) => boolean; isModelDisabled?: (model: T) => boolean; diff --git a/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts b/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts index ee36d10e28f..395886b13a9 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts @@ -4,6 +4,7 @@ import { getSelectorsOptions } from 'app/store/createMemoizedSelector'; import type { PersistConfig, RootState } from 'app/store/store'; import { buildControlAdapter } from 'features/controlAdapters/util/buildControlAdapter'; import { buildControlAdapterProcessor } from 
'features/controlAdapters/util/buildControlAdapterProcessor'; +import { zModelIdentifierField } from 'features/nodes/types/common'; import { cloneDeep, merge, uniq } from 'lodash-es'; import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types'; import { socketInvocationError } from 'services/events/actions'; @@ -197,7 +198,7 @@ export const controlAdaptersSlice = createSlice({ return; } - const model = { key: modelConfig.key, base: modelConfig.base }; + const model = zModelIdentifierField.parse(modelConfig); if (!isControlNetOrT2IAdapter(cn)) { caAdapter.updateOne(state, { id, changes: { model } }); diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts index 7d6a2fccaf8..8f1c138f98d 100644 --- a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts +++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts @@ -1,7 +1,7 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; -import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers'; +import { zModelIdentifierField } from 'features/nodes/types/common'; import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas'; import type { LoRAModelConfig } from 'services/api/types'; @@ -31,7 +31,7 @@ export const loraSlice = createSlice({ initialState: initialLoraState, reducers: { loraAdded: (state, action: PayloadAction) => { - const model = getModelKeyAndBase(action.payload); + const model = zModelIdentifierField.parse(action.payload); state.loras[model.key] = { ...defaultLoRAConfig, model }; }, loraRecalled: (state, action: PayloadAction) => { diff --git a/invokeai/frontend/web/src/features/metadata/util/handlers.ts b/invokeai/frontend/web/src/features/metadata/util/handlers.ts index 09d65b7d922..069668f74ae 100644 --- 
a/invokeai/frontend/web/src/features/metadata/util/handlers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/handlers.ts @@ -13,13 +13,13 @@ import type { } from 'features/metadata/types'; import { fetchModelConfig } from 'features/metadata/util/modelFetchingHelpers'; import { validators } from 'features/metadata/util/validators'; -import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; +import type { ModelIdentifierField } from 'features/nodes/types/common'; import { t } from 'i18next'; import { parsers } from './parsers'; import { recallers } from './recallers'; -const renderModelConfigValue: MetadataRenderValueFunc = async (value) => { +const renderModelConfigValue: MetadataRenderValueFunc = async (value) => { try { const modelConfig = await fetchModelConfig(value.key); return `${modelConfig.name} (${modelConfig.base.toUpperCase()})`; diff --git a/invokeai/frontend/web/src/features/metadata/util/modelFetchingHelpers.ts b/invokeai/frontend/web/src/features/metadata/util/modelFetchingHelpers.ts index 3c0745917ac..a237582ed83 100644 --- a/invokeai/frontend/web/src/features/metadata/util/modelFetchingHelpers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/modelFetchingHelpers.ts @@ -1,5 +1,4 @@ import { getStore } from 'app/store/nanostores/store'; -import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common'; import { modelsApi } from 'services/api/endpoints/models'; import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types'; @@ -105,8 +104,3 @@ export const getModelKey = async (modelIdentifier: unknown, type: ModelType, mes } throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`); }; - -export const getModelKeyAndBase = (modelConfig: AnyModelConfig): ModelIdentifierWithBase => ({ - key: modelConfig.key, - base: modelConfig.base, -}); diff --git 
a/invokeai/frontend/web/src/features/metadata/util/parsers.ts b/invokeai/frontend/web/src/features/metadata/util/parsers.ts index 24274b8e6a7..4483a4d4148 100644 --- a/invokeai/frontend/web/src/features/metadata/util/parsers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/parsers.ts @@ -13,12 +13,7 @@ import type { T2IAdapterConfigMetadata, } from 'features/metadata/types'; import { fetchModelConfigWithTypeGuard, getModelKey } from 'features/metadata/util/modelFetchingHelpers'; -import { - zControlField, - zIPAdapterField, - zModelIdentifierWithBase, - zT2IAdapterField, -} from 'features/nodes/types/common'; +import { zControlField, zIPAdapterField, zModelIdentifierField, zT2IAdapterField } from 'features/nodes/types/common'; import type { ParameterCFGRescaleMultiplier, ParameterCFGScale, @@ -181,7 +176,7 @@ const parseMainModel: MetadataParseFunc = async (metadata) => { const model = await getProperty(metadata, 'model', undefined); const key = await getModelKey(model, 'main'); const mainModelConfig = await fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig); - const modelIdentifier = zModelIdentifierWithBase.parse(mainModelConfig); + const modelIdentifier = zModelIdentifierField.parse(mainModelConfig); return modelIdentifier; }; @@ -189,7 +184,7 @@ const parseRefinerModel: MetadataParseFunc = async (m const refiner_model = await getProperty(metadata, 'refiner_model', undefined); const key = await getModelKey(refiner_model, 'main'); const refinerModelConfig = await fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig); - const modelIdentifier = zModelIdentifierWithBase.parse(refinerModelConfig); + const modelIdentifier = zModelIdentifierField.parse(refinerModelConfig); return modelIdentifier; }; @@ -197,7 +192,7 @@ const parseVAEModel: MetadataParseFunc = async (metadata) => const vae = await getProperty(metadata, 'vae', undefined); const key = await getModelKey(vae, 'vae'); const vaeModelConfig = await 
fetchModelConfigWithTypeGuard(key, isVAEModelConfig); - const modelIdentifier = zModelIdentifierWithBase.parse(vaeModelConfig); + const modelIdentifier = zModelIdentifierField.parse(vaeModelConfig); return modelIdentifier; }; @@ -211,7 +206,7 @@ const parseLoRA: MetadataParseFunc = async (metadataItem) => { const loraModelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig); return { - model: zModelIdentifierWithBase.parse(loraModelConfig), + model: zModelIdentifierField.parse(loraModelConfig), weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight, isEnabled: true, }; @@ -258,7 +253,7 @@ const parseControlNet: MetadataParseFunc = async (meta const controlNet: ControlNetConfigMetadata = { type: 'controlnet', isEnabled: true, - model: zModelIdentifierWithBase.parse(controlNetModel), + model: zModelIdentifierField.parse(controlNetModel), weight: typeof control_weight === 'number' ? control_weight : initialControlNet.weight, beginStepPct: begin_step_percent ?? initialControlNet.beginStepPct, endStepPct: end_step_percent ?? initialControlNet.endStepPct, @@ -309,7 +304,7 @@ const parseT2IAdapter: MetadataParseFunc = async (meta const t2iAdapter: T2IAdapterConfigMetadata = { type: 't2i_adapter', isEnabled: true, - model: zModelIdentifierWithBase.parse(t2iAdapterModel), + model: zModelIdentifierField.parse(t2iAdapterModel), weight: typeof weight === 'number' ? weight : initialT2IAdapter.weight, beginStepPct: begin_step_percent ?? initialT2IAdapter.beginStepPct, endStepPct: end_step_percent ?? initialT2IAdapter.endStepPct, @@ -354,7 +349,7 @@ const parseIPAdapter: MetadataParseFunc = async (metada id: uuidv4(), type: 'ip_adapter', isEnabled: true, - model: zModelIdentifierWithBase.parse(ipAdapterModel), + model: zModelIdentifierField.parse(ipAdapterModel), controlImage: image?.image_name ?? null, weight: weight ?? initialIPAdapter.weight, beginStepPct: begin_step_percent ?? 
initialIPAdapter.beginStepPct, diff --git a/invokeai/frontend/web/src/features/nodes/types/common.test-d.ts b/invokeai/frontend/web/src/features/nodes/types/common.test-d.ts index 13c3db52fd4..f2ebf94b062 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.test-d.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.test-d.ts @@ -1,5 +1,4 @@ import type { - BaseModel, BoardField, Classification, ColorField, @@ -7,6 +6,7 @@ import type { ImageField, ImageOutput, IPAdapterField, + ModelIdentifierField, ProgressImage, SchedulerField, T2IAdapterField, @@ -33,10 +33,9 @@ describe('Common types', () => { test('T2IAdapterField', () => assert>()); // Model component types - test('BaseModel', () => assert>()); + test('ModelIdentifier', () => assert>()); // Misc types - // @ts-expect-error TODO(psyche): There is no `ProgressImage` in the server types yet test('ProgressImage', () => assert>()); test('ImageOutput', () => assert>()); test('Classification', () => assert>()); diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index cbbe150ed41..06d5ecd5c75 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -55,6 +55,17 @@ export type SchedulerField = z.infer; // #region Model-related schemas const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); +const zModelType = z.enum([ + 'main', + 'vae', + 'lora', + 'controlnet', + 't2i_adapter', + 'ip_adapter', + 'embedding', + 'onnx', + 'clip_vision', +]); const zSubModelType = z.enum([ 'unet', 'text_encoder', @@ -67,26 +78,25 @@ const zSubModelType = z.enum([ 'scheduler', 'safety_checker', ]); - -const zModelIdentifier = z.object({ +export const zModelIdentifierField = z.object({ key: z.string().min(1), + hash: z.string().min(1), + name: z.string().min(1), + base: zBaseModel, + type: zModelType, submodel_type: zSubModelType.nullish(), 
}); -export const isModelIdentifier = (field: unknown): field is ModelIdentifier => - zModelIdentifier.safeParse(field).success; +export const isModelIdentifier = (field: unknown): field is ModelIdentifierField => + zModelIdentifierField.safeParse(field).success; export const isModelIdentifierV2 = (field: unknown): field is ModelIdentifierV2 => zModelIdentifierV2.safeParse(field).success; -const zModelFieldBase = zModelIdentifier; -export const zModelIdentifierWithBase = zModelIdentifier.extend({ base: zBaseModel }); -export type BaseModel = z.infer; -type ModelIdentifier = z.infer; -export type ModelIdentifierWithBase = z.infer; +export type ModelIdentifierField = z.infer; // #endregion // #region Control Adapters export const zControlField = z.object({ image: zImageField, - control_model: zModelFieldBase, + control_model: zModelIdentifierField, control_weight: z.union([z.number(), z.array(z.number())]).optional(), begin_step_percent: z.number().optional(), end_step_percent: z.number().optional(), @@ -97,7 +107,7 @@ export type ControlField = z.infer; export const zIPAdapterField = z.object({ image: zImageField, - ip_adapter_model: zModelFieldBase, + ip_adapter_model: zModelIdentifierField, weight: z.number(), begin_step_percent: z.number().optional(), end_step_percent: z.number().optional(), @@ -106,7 +116,7 @@ export type IPAdapterField = z.infer; export const zT2IAdapterField = z.object({ image: zImageField, - t2i_adapter_model: zModelFieldBase, + t2i_adapter_model: zModelIdentifierField, weight: z.union([z.number(), z.array(z.number())]).optional(), begin_step_percent: z.number().optional(), end_step_percent: z.number().optional(), diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index 39fa903fd77..84eceff94a4 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -1,6 +1,6 @@ import { z } from 'zod'; 
-import { zBoardField, zColorField, zImageField, zModelIdentifierWithBase, zSchedulerField } from './common'; +import { zBoardField, zColorField, zImageField, zModelIdentifierField, zSchedulerField } from './common'; /** * zod schemas & inferred types for fields. @@ -277,7 +277,7 @@ export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputT const zMainModelFieldType = zFieldTypeBase.extend({ name: z.literal('MainModelField'), }); -export const zMainModelFieldValue = zModelIdentifierWithBase.optional(); +export const zMainModelFieldValue = zModelIdentifierField.optional(); const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ value: zMainModelFieldValue, }); @@ -348,7 +348,7 @@ export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLR const zVAEModelFieldType = zFieldTypeBase.extend({ name: z.literal('VAEModelField'), }); -export const zVAEModelFieldValue = zModelIdentifierWithBase.optional(); +export const zVAEModelFieldValue = zModelIdentifierField.optional(); const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({ value: zVAEModelFieldValue, }); @@ -372,7 +372,7 @@ export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelField const zLoRAModelFieldType = zFieldTypeBase.extend({ name: z.literal('LoRAModelField'), }); -export const zLoRAModelFieldValue = zModelIdentifierWithBase.optional(); +export const zLoRAModelFieldValue = zModelIdentifierField.optional(); const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({ value: zLoRAModelFieldValue, }); @@ -396,7 +396,7 @@ export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFie const zControlNetModelFieldType = zFieldTypeBase.extend({ name: z.literal('ControlNetModelField'), }); -export const zControlNetModelFieldValue = zModelIdentifierWithBase.optional(); +export const zControlNetModelFieldValue = zModelIdentifierField.optional(); const zControlNetModelFieldInputInstance = 
zFieldInputInstanceBase.extend({ value: zControlNetModelFieldValue, }); @@ -420,7 +420,7 @@ export const isControlNetModelFieldInputTemplate = (val: unknown): val is Contro const zIPAdapterModelFieldType = zFieldTypeBase.extend({ name: z.literal('IPAdapterModelField'), }); -export const zIPAdapterModelFieldValue = zModelIdentifierWithBase.optional(); +export const zIPAdapterModelFieldValue = zModelIdentifierField.optional(); const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ value: zIPAdapterModelFieldValue, }); @@ -444,7 +444,7 @@ export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapt const zT2IAdapterModelFieldType = zFieldTypeBase.extend({ name: z.literal('T2IAdapterModelField'), }); -export const zT2IAdapterModelFieldValue = zModelIdentifierWithBase.optional(); +export const zT2IAdapterModelFieldValue = zModelIdentifierField.optional(); const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ value: zT2IAdapterModelFieldValue, }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts index c65bddc9b1b..28981a1a8a4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts @@ -1,15 +1,10 @@ import type { RootState } from 'app/store/store'; -import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; +import { zModelIdentifierField } from 'features/nodes/types/common'; import { filter, size } from 'lodash-es'; -import { - type CoreMetadataInvocation, - isLoRAModelConfig, - type LoRALoaderInvocation, - type NonNullableGraph, -} from 'services/api/types'; +import type { CoreMetadataInvocation, LoRALoaderInvocation, NonNullableGraph } from 'services/api/types'; import { CLIP_SKIP, LORA_LOADER, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING } from 
'./constants'; -import { getModelMetadataField, upsertMetadata } from './metadata'; +import { upsertMetadata } from './metadata'; export const addLoRAsToGraph = async ( state: RootState, @@ -49,19 +44,18 @@ export const addLoRAsToGraph = async ( const { weight } = lora; const { key } = lora.model; const currentLoraNodeId = `${LORA_LOADER}_${key}`; + const parsedModel = zModelIdentifierField.parse(lora.model); const loraLoaderNode: LoRALoaderInvocation = { type: 'lora_loader', id: currentLoraNodeId, is_intermediate: true, - lora: { key }, + lora: parsedModel, weight, }; - const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig); - loraMetadata.push({ - model: getModelMetadataField(modelConfig), + model: parsedModel, weight, }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts index fb71e5f76e6..1a803102b12 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts @@ -1,12 +1,7 @@ import type { RootState } from 'app/store/store'; -import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; +import { zModelIdentifierField } from 'features/nodes/types/common'; import { filter, size } from 'lodash-es'; -import { - type CoreMetadataInvocation, - isLoRAModelConfig, - type NonNullableGraph, - type SDXLLoRALoaderInvocation, -} from 'services/api/types'; +import type { CoreMetadataInvocation, NonNullableGraph, SDXLLoRALoaderInvocation } from 'services/api/types'; import { LORA_LOADER, @@ -16,7 +11,7 @@ import { SDXL_REFINER_INPAINT_CREATE_MASK, SEAMLESS, } from './constants'; -import { getModelMetadataField, upsertMetadata } from './metadata'; +import { upsertMetadata } from './metadata'; export const addSDXLLoRAsToGraph = async ( state: RootState, @@ -63,20 +58,18 @@ export const 
addSDXLLoRAsToGraph = async ( enabledLoRAs.forEach(async (lora) => { const { weight } = lora; - const { key } = lora.model; - const currentLoraNodeId = `${LORA_LOADER}_${key}`; + const currentLoraNodeId = `${LORA_LOADER}_${lora.model.key}`; + const parsedModel = zModelIdentifierField.parse(lora.model); const loraLoaderNode: SDXLLoRALoaderInvocation = { type: 'sdxl_lora_loader', id: currentLoraNodeId, is_intermediate: true, - lora: { key }, + lora: parsedModel, weight, }; - const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig); - - loraMetadata.push({ model: getModelMetadataField(modelConfig), weight }); + loraMetadata.push({ model: parsedModel, weight }); // add to graph graph.nodes[currentLoraNodeId] = loraLoaderNode; diff --git a/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx index 0759020cc82..c3989721f6e 100644 --- a/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/MainModel/ParamMainModelSelect.tsx @@ -3,6 +3,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; import { useModelCustomSelect } from 'common/hooks/useModelCustomSelect'; +import { zModelIdentifierField } from 'features/nodes/types/common'; import { modelSelected } from 'features/parameters/store/actions'; import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { memo, useCallback } from 'react'; @@ -24,7 +25,11 @@ const ParamMainModelSelect = () => { if (!model) { return; } - dispatch(modelSelected({ key: model.key, base: model.base })); + try { + 
dispatch(modelSelected(zModelIdentifierField.parse(model))); + } catch { + // no-op + } }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx b/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx index 7a9a4302865..282723b6bf0 100644 --- a/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/VAEModel/ParamVAEModelSelect.tsx @@ -3,7 +3,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; -import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers'; +import { zModelIdentifierField } from 'features/nodes/types/common'; import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/generationSlice'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; @@ -30,7 +30,7 @@ const ParamVAEModelSelect = () => { ); const _onChange = useCallback( (vae: VAEModelConfig | null) => { - dispatch(vaeSelected(vae ? getModelKeyAndBase(vae) : null)); + dispatch(vaeSelected(vae ? 
zModelIdentifierField.parse(vae) : null)); }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts index 2476634c265..75693cd47fa 100644 --- a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts @@ -1,6 +1,6 @@ import { NUMPY_RAND_MAX } from 'app/constants'; import { roundToMultiple } from 'common/util/roundDownToMultiple'; -import { zModelIdentifierWithBase, zSchedulerField } from 'features/nodes/types/common'; +import { zModelIdentifierField, zSchedulerField } from 'features/nodes/types/common'; import { z } from 'zod'; /** @@ -92,37 +92,37 @@ export const isParameterHeight = (val: unknown): val is ParameterHeight => zPara // #endregion // #region Model -export const zParameterModel = zModelIdentifierWithBase; +export const zParameterModel = zModelIdentifierField; export type ParameterModel = z.infer<typeof zParameterModel>; // #endregion // #region SDXL Refiner Model -const zParameterSDXLRefinerModel = zModelIdentifierWithBase; +const zParameterSDXLRefinerModel = zModelIdentifierField; export type ParameterSDXLRefinerModel = z.infer<typeof zParameterSDXLRefinerModel>; // #endregion // #region VAE Model -export const zParameterVAEModel = zModelIdentifierWithBase; +export const zParameterVAEModel = zModelIdentifierField; export type ParameterVAEModel = z.infer<typeof zParameterVAEModel>; // #endregion // #region LoRA Model -const zParameterLoRAModel = zModelIdentifierWithBase; +const zParameterLoRAModel = zModelIdentifierField; export type ParameterLoRAModel = z.infer<typeof zParameterLoRAModel>; // #endregion // #region ControlNet Model -const zParameterControlNetModel = zModelIdentifierWithBase; +const zParameterControlNetModel = zModelIdentifierField; export type ParameterControlNetModel = z.infer<typeof zParameterControlNetModel>; // #endregion // #region IP Adapter Model -const zParameterIPAdapterModel = zModelIdentifierWithBase; +const zParameterIPAdapterModel = zModelIdentifierField;
export type ParameterIPAdapterModel = z.infer<typeof zParameterIPAdapterModel>; // #endregion // #region T2I Adapter Model -const zParameterT2IAdapterModel = zModelIdentifierWithBase; +const zParameterT2IAdapterModel = zModelIdentifierField; export type ParameterT2IAdapterModel = z.infer<typeof zParameterT2IAdapterModel>; // #endregion diff --git a/invokeai/frontend/web/src/features/parameters/util/optimalDimension.ts b/invokeai/frontend/web/src/features/parameters/util/optimalDimension.ts index 92b4f182727..a2f612137b6 100644 --- a/invokeai/frontend/web/src/features/parameters/util/optimalDimension.ts +++ b/invokeai/frontend/web/src/features/parameters/util/optimalDimension.ts @@ -1,11 +1,11 @@ -import type { ModelIdentifierWithBase } from 'features/nodes/types/common'; +import type { ModelIdentifierField } from 'features/nodes/types/common'; /** * Gets the optimal dimension for a givel model, based on the model's base_model * @param model The model identifier * @returns The optimal dimension for the model */ -export const getOptimalDimension = (model?: ModelIdentifierWithBase | null): number => +export const getOptimalDimension = (model?: ModelIdentifierField | null): number =>
1024 : 512; const MIN_AREA_FACTOR = 0.8; diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx index e3dde06d72f..c5151087957 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerModelSelect.tsx @@ -3,7 +3,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; import { useModelCombobox } from 'common/hooks/useModelCombobox'; -import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers'; +import { zModelIdentifierField } from 'features/nodes/types/common'; import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; @@ -26,7 +26,7 @@ const ParamSDXLRefinerModelSelect = () => { dispatch(refinerModelChanged(null)); return; } - dispatch(refinerModelChanged(getModelKeyAndBase(model))); + dispatch(refinerModelChanged(zModelIdentifierField.parse(model))); }, [dispatch] );