Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ui): connection validation rework #6386

Merged
merged 55 commits into from
May 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
c54fe01
feat(ui): add `originalType` to FieldType, improved connection valida…
psychedelicious May 17, 2024
d8000fd
feat(nodes): make all `ModelIdentifierField` inputs accept connections
psychedelicious May 17, 2024
79cebb9
fix(ui): stupid ts
psychedelicious May 17, 2024
df7cdc1
fix(ui): do not allow comparison between undefined original types
psychedelicious May 17, 2024
8c04641
feat(ui): add ModelIdentifierField field type
psychedelicious May 17, 2024
a8b9c75
feat(nodes): add `ModelIdentifierInvocation`
psychedelicious May 17, 2024
f6f83af
fix(ui): rebase resolution
psychedelicious May 17, 2024
0b4bf72
chore(ui): lint
psychedelicious May 17, 2024
1eff66a
tidy(ui): tidy connection validation functions and logic
psychedelicious May 18, 2024
0d45d1b
feat(ui): makeConnectionErrorSelector now creates a parameterized sel…
psychedelicious May 18, 2024
677c37f
fix(ui): rebase conflict
psychedelicious May 18, 2024
1a339e3
tidy(ui): clean up addnodepopover hotkeys
psychedelicious May 18, 2024
ce0b0e5
fix(ui): rework edge update logic
psychedelicious May 18, 2024
ceabd9a
fix(ui): call updateNodeInternals when making connections
psychedelicious May 18, 2024
92b74ad
tests(ui): add tests for consolidated connection validation
psychedelicious May 18, 2024
9878224
tests(ui): finish test cases for validateConnection
psychedelicious May 18, 2024
503abf5
tidy(ui): areTypesEqual var names
psychedelicious May 18, 2024
c0eff8f
tidy(ui): validateConnection code clarity
psychedelicious May 18, 2024
da8db49
tests(ui): coverage for validateConnectionTypes
psychedelicious May 18, 2024
0baff06
tests(ui): coverage for getCollectItemType
psychedelicious May 18, 2024
c18938c
tests(ui): add iterate to test schema
psychedelicious May 18, 2024
9637cc3
feat(ui): better types for validateConnection
psychedelicious May 18, 2024
d2093ef
feat(ui): add strict mode to validateConnection
psychedelicious May 18, 2024
4fe32f4
feat(ui): use new validateConnection
psychedelicious May 18, 2024
1a3431f
fix(ui): handling for in-progress edge updates during conection valid…
psychedelicious May 18, 2024
67f2461
tidy(ui): extraneous vars in makeConnectionErrorSelector
psychedelicious May 18, 2024
ec99d72
tests(ui): add buildNode convenience wrapper for buildInvocationNode
psychedelicious May 18, 2024
2d35d0f
feat(ui): extract logic for finding candidate fields to own function
psychedelicious May 18, 2024
87a5216
tests(ui): candidate fields, getFirstValidConnection (wip)
psychedelicious May 18, 2024
0271c9b
feat(ui): rework getFirstValidConnection with new helpers
psychedelicious May 18, 2024
e1ec42f
tests(ui): coverage for getFirstValidConnection
psychedelicious May 18, 2024
da1dd71
tests(ui): coverage for getCollectItemType
psychedelicious May 19, 2024
9d01282
feat(ui): rework pendingConnection
psychedelicious May 19, 2024
c70392d
chore(ui): knip
psychedelicious May 19, 2024
75ec910
fix(ui): duplicated edges when updating edge with lazy connect
psychedelicious May 19, 2024
838bff9
feat(ui): rework node and edge mutation logic
psychedelicious May 19, 2024
338bb39
fix(ui): edge styling
psychedelicious May 19, 2024
c1ec9d9
fix(ui): collapsed edges selected state
psychedelicious May 19, 2024
5775db8
feat(ui): tweak edge styling
psychedelicious May 19, 2024
ff99fd2
feat(ui): get rid of nodeAdded
psychedelicious May 19, 2024
fea1305
fix(ui): group edge selection actions
psychedelicious May 19, 2024
7257f1f
feat(ui): remove nodeReplaced action
psychedelicious May 19, 2024
bcb8df9
feat(ui): remove selectedAll action
psychedelicious May 19, 2024
0343f32
feat(ui): remove selectionPasted action
psychedelicious May 19, 2024
5bb3239
feat(ui): remove selectionDeleted action
psychedelicious May 19, 2024
0362026
tidy(ui): more succinct syntax for edge and node updates
psychedelicious May 19, 2024
ed7aac9
feat(ui): remove nodeExclusivelySelected action
psychedelicious May 19, 2024
8888455
fix(ui): do not remove exposed fields when updating workflows
psychedelicious May 19, 2024
0fe0640
perf(ui): ignore all no-op node and edge changes
psychedelicious May 19, 2024
1fefe3b
fix(ui): set nodeDragThreshold to prevent spurious position change ev…
psychedelicious May 19, 2024
0490383
feat(ui): use connection validationResults directly in components
psychedelicious May 19, 2024
16c59fe
perf(ui): memoize WorkflowName selectors
psychedelicious May 19, 2024
dde1725
feat(nodes): make ModelIdentifierInvocation a prototype
psychedelicious May 19, 2024
dae60bb
fix(ui): delete edges when their source or target no longer exists
psychedelicious May 19, 2024
31d43cc
fix(ui): missed node execution state for progress images
psychedelicious May 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions invokeai/app/invocations/controlnet_image_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
OutputField,
UIType,
Expand Down Expand Up @@ -80,13 +79,13 @@ class ControlOutput(BaseInvocationOutput):
control: ControlField = OutputField(description=FieldDescriptions.control)


@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.1")
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.2")
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""

image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
)
control_weight: Union[float, List[float]] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
Expand Down
5 changes: 2 additions & 3 deletions invokeai/app/invocations/ip_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing_extensions import Self

from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, TensorField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, TensorField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
Expand Down Expand Up @@ -58,7 +58,7 @@ class IPAdapterOutput(BaseInvocationOutput):
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}


@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.0")
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.1")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes."""

Expand All @@ -67,7 +67,6 @@ class IPAdapterInvocation(BaseInvocation):
ip_adapter_model: ModelIdentifierField = InputField(
description="The IP-Adapter model.",
title="IP-Adapter Model",
input=Input.Direct,
ui_order=-1,
ui_type=UIType.IPAdapterModel,
)
Expand Down
52 changes: 40 additions & 12 deletions invokeai/app/invocations/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
Expand Down Expand Up @@ -93,19 +94,46 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
pass


@invocation_output("model_identifier_output")
class ModelIdentifierOutput(BaseInvocationOutput):
"""Model identifier output"""

model: ModelIdentifierField = OutputField(description="Model identifier", title="Model")


@invocation(
"model_identifier",
title="Model identifier",
tags=["model"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class ModelIdentifierInvocation(BaseInvocation):
"""Selects any model, outputting it its identifier. Be careful with this one! The identifier will be accepted as
input for any model, even if the model types don't match. If you connect this to a mismatched input, you'll get an
error."""

model: ModelIdentifierField = InputField(description="The model to select", title="Model")

def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
if not context.models.exists(self.model.key):
raise Exception(f"Unknown model {self.model.key}")

return ModelIdentifierOutput(model=self.model)


@invocation(
"main_model_loader",
title="Main Model",
tags=["model"],
category="model",
version="1.0.2",
version="1.0.3",
)
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""

model: ModelIdentifierField = InputField(
description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel
)
model: ModelIdentifierField = InputField(description=FieldDescriptions.main_model, ui_type=UIType.MainModel)
# TODO: precision?

def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
Expand Down Expand Up @@ -134,12 +162,12 @@ class LoRALoaderOutput(BaseInvocationOutput):
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")


@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.2")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.3")
class LoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""

lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
Expand Down Expand Up @@ -197,12 +225,12 @@ class LoRASelectorOutput(BaseInvocationOutput):
lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA")


@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.0")
@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.1")
class LoRASelectorInvocation(BaseInvocation):
"""Selects a LoRA model and weight."""

lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)

Expand Down Expand Up @@ -273,13 +301,13 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
title="SDXL LoRA",
tags=["lora", "model"],
category="model",
version="1.0.2",
version="1.0.3",
)
class SDXLLoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""

lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel
description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
unet: Optional[UNetField] = InputField(
Expand Down Expand Up @@ -414,12 +442,12 @@ def invoke(self, context: InvocationContext) -> SDXLLoRALoaderOutput:
return output


@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.2")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.3")
class VAELoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""

vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel
description=FieldDescriptions.vae_model, title="VAE", ui_type=UIType.VAEModel
)

def invoke(self, context: InvocationContext) -> VAEOutput:
Expand Down
10 changes: 5 additions & 5 deletions invokeai/app/invocations/sdxl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import SubModelType

Expand Down Expand Up @@ -30,12 +30,12 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")


@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.2")
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.3")
class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""

model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
description=FieldDescriptions.sdxl_main_model, ui_type=UIType.SDXLMainModel
)
# TODO: precision?

Expand Down Expand Up @@ -67,13 +67,13 @@ def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
title="SDXL Refiner Model",
tags=["model", "sdxl", "refiner"],
category="model",
version="1.0.2",
version="1.0.3",
)
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""

model: ModelIdentifierField = InputField(
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
description=FieldDescriptions.sdxl_refiner_model, ui_type=UIType.SDXLRefinerModel
)
# TODO: precision?

Expand Down
5 changes: 2 additions & 3 deletions invokeai/app/invocations/t2i_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
Expand Down Expand Up @@ -45,7 +45,7 @@ class T2IAdapterOutput(BaseInvocationOutput):


@invocation(
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.2"
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.3"
)
class T2IAdapterInvocation(BaseInvocation):
"""Collects T2I-Adapter info to pass to other nodes."""
Expand All @@ -55,7 +55,6 @@ class T2IAdapterInvocation(BaseInvocation):
t2i_adapter_model: ModelIdentifierField = InputField(
description="The T2I-Adapter model.",
title="T2I-Adapter Model",
input=Input.Direct,
ui_order=-1,
ui_type=UIType.T2IAdapterModel,
)
Expand Down
3 changes: 3 additions & 0 deletions invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,9 @@
"cannotConnectToSelf": "Cannot connect to self",
"cannotDuplicateConnection": "Cannot create duplicate connections",
"cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types",
"missingNode": "Missing invocation node",
"missingInvocationTemplate": "Missing invocation template",
"missingFieldTemplate": "Missing field template",
"nodePack": "Node pack",
"collection": "Collection",
"collectionFieldType": "{{name}} Collection",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketGeneratorProgress } from 'services/events/actions';

Expand All @@ -18,6 +18,7 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
nes.status = zNodeStatus.enum.IN_PROGRESS;
nes.progress = (step + 1) / total_steps;
nes.progressImage = progress_image ?? null;
upsertExecutionState(nes.nodeId, nes);
}
},
});
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { updateAllNodesRequested } from 'features/nodes/store/actions';
import { $templates, nodeReplaced } from 'features/nodes/store/nodesSlice';
import { $templates, nodesChanged } from 'features/nodes/store/nodesSlice';
import { NodeUpdateError } from 'features/nodes/types/error';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { getNeedsUpdate, updateNode } from 'features/nodes/util/node/nodeUpdate';
Expand Down Expand Up @@ -31,7 +31,12 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi
}
try {
const updatedNode = updateNode(node, template);
dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode }));
dispatch(
nodesChanged([
{ type: 'remove', id: updatedNode.id },
{ type: 'add', item: updatedNode },
])
);
} catch (e) {
if (e instanceof NodeUpdateError) {
unableToUpdateCount++;
Expand Down
Loading
Loading