Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
6ab9a5e
Draft
StAlKeR7779 Jul 5, 2023
788dcbd
fix(nodes): add missing import
psychedelicious Jul 8, 2023
82fa39b
feat(nodes): add controlnet nodes type hint
psychedelicious Jul 8, 2023
96c9db6
chore(ui): typegen
psychedelicious Jul 8, 2023
29b2e59
fix(nodes): fix ref to ctx mgr service, missing import
psychedelicious Jul 8, 2023
5ac1145
feat(ui): add controlnet field to nodes
psychedelicious Jul 8, 2023
76dc47e
remove frontend constants, use backend response for controlnet models…
Jul 11, 2023
0d41346
feat(ui): fix controlNet models
psychedelicious Jul 15, 2023
ae72f37
fix(nodes): do not use hardcoded controlnet model
psychedelicious Jul 15, 2023
d270f21
feat(nodes): valid controlnet weights are -1 to 2
psychedelicious Jul 15, 2023
8f66d82
feat(ui): refactor controlnet UI components to use local memoized sel…
psychedelicious Jul 15, 2023
7b6d91c
feat(ui): control net UI weights 0 to 2
psychedelicious Jul 15, 2023
952a7a8
feat(ui): do not autoprocess if user just disabled autoconfig
psychedelicious Jul 15, 2023
77ad3c9
feat(ui): tweak slider styles
psychedelicious Jul 15, 2023
8a14c5d
feat(ui): wip controlnet layout
psychedelicious Jul 15, 2023
19e076c
fix(ui): fix no controlnet model selected by default
psychedelicious Jul 15, 2023
401727b
feat(ui): add cnet advanced tooltip
psychedelicious Jul 15, 2023
13d182e
feat(ui): move cnet add button to top of list
psychedelicious Jul 15, 2023
7dec2d0
feat(ui): disable specific controlnet inputs when that controlnet is …
psychedelicious Jul 15, 2023
d1ecd00
feat(ui): promote controlnet to be just under general
psychedelicious Jul 15, 2023
457e4b7
feat(ui): tweak slider label spacing
psychedelicious Jul 15, 2023
7daafc0
fix(ui): fix invoke button styles when processing
psychedelicious Jul 15, 2023
4ac0ce5
fix(ui): add custom label to IAIMantineSelects
psychedelicious Jul 15, 2023
be4705e
feat(ui): move control mode and processor to main view
psychedelicious Jul 15, 2023
8e0ba24
feat(ui): fix cnet ui alignment
psychedelicious Jul 15, 2023
b265956
fix(ui): disable drop when controlnet disabled
psychedelicious Jul 15, 2023
f7230d0
feat(ui): fix controlnet image preview alignment
psychedelicious Jul 15, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions invokeai/app/invocations/controlnet_image_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from PIL import Image
from pydantic import BaseModel, Field, validator

from ...backend.model_management import BaseModelType, ModelType
from ..models.image import ImageField, ImageCategory, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
Expand Down Expand Up @@ -105,9 +106,15 @@
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]


class ControlNetModelField(BaseModel):
    """Identifies a ControlNet model: its registered name plus the base model it targets.

    Replaces the previous plain-string model reference so the model manager can
    resolve the ControlNet by (model_name, base_model) pair.
    """

    # Registered name of the ControlNet model in the model manager.
    model_name: str = Field(description="Name of the ControlNet model")
    # Base model family this ControlNet is compatible with; values come from
    # BaseModelType (imported from ...backend.model_management).
    base_model: BaseModelType = Field(description="Base model")

class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image")
control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(default=0, ge=0, le=1,
Expand All @@ -118,22 +125,23 @@ class ControlField(BaseModel):
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")

@validator("control_weight")
def validate_control_weight(cls, v):
    """Validate that control weights lie in the supported [-1, 2] range.

    `control_weight` may be a single float or a per-step list of floats
    (see the Union[float, List[float]] field declaration).

    Raises:
        ValueError: if any weight falls outside [-1, 2].
    """
    # Normalize scalar vs. list so both cases share one range check;
    # the raised message is identical either way.
    weights = v if isinstance(v, list) else [v]
    for weight in weights:
        if weight < -1 or weight > 2:
            raise ValueError('Control weights must be within -1 to 2 range')
    return v
class Config:
schema_extra = {
"required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"],
"ui": {
"type_hints": {
"control_weight": "float",
"control_model": "controlnet_model",
# "control_weight": "number",
}
}
Expand All @@ -154,10 +162,10 @@ class ControlNetInvocation(BaseInvocation):
type: Literal["controlnet"] = "controlnet"
# Inputs
image: ImageField = Field(default=None, description="The control image")
# NOTE(review): the default is still a bare HF repo string, but the field
# type is now ControlNetModelField — pydantic v1 does not validate defaults,
# so this only breaks when a caller relies on the default. Confirm whether
# the default should be a ControlNetModelField instance or None.
control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
                                            description="control model used")
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
# begin/end are percentages of total steps, so they must stay within [0, 1];
# the -1..2 range introduced in this PR applies only to control_weight.
begin_step_percent: float = Field(default=0, ge=0, le=1,
                                  description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1,
                                description="When the ControlNet is last applied (% of total steps)")
Expand Down
103 changes: 65 additions & 38 deletions invokeai/app/invocations/latent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

from contextlib import ExitStack
from typing import List, Literal, Optional, Union

import einops
Expand All @@ -11,6 +12,7 @@

from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models.base import ModelType

from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
Expand Down Expand Up @@ -71,16 +73,21 @@ def get_scheduler(
scheduler_name: str,
) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
scheduler_name, SCHEDULER_MAP['ddim'])
scheduler_name, SCHEDULER_MAP['ddim']
)
orig_scheduler_info = context.services.model_manager.get_model(
**scheduler_info.dict())
**scheduler_info.dict()
)
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config

if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **
scheduler_extra_config, "_backup": scheduler_config}
scheduler_config = {
**scheduler_config,
**scheduler_extra_config,
"_backup": scheduler_config,
}
scheduler = scheduler_class.from_config(scheduler_config)

# hack copied over from generate.py
Expand Down Expand Up @@ -137,8 +144,11 @@ class Config(InvocationConfig):

# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, source_node_id: str,
intermediate_state: PipelineIntermediateState) -> None:
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
Expand All @@ -147,11 +157,16 @@ def dispatch_progress(
)

def get_conditioning_data(
self, context: InvocationContext, scheduler) -> ConditioningData:
self,
context: InvocationContext,
scheduler,
) -> ConditioningData:
c, extra_conditioning_info = context.services.latents.get(
self.positive_conditioning.conditioning_name)
self.positive_conditioning.conditioning_name
)
uc, _ = context.services.latents.get(
self.negative_conditioning.conditioning_name)
self.negative_conditioning.conditioning_name
)

conditioning_data = ConditioningData(
unconditioned_embeddings=uc,
Expand All @@ -178,7 +193,10 @@ def get_conditioning_data(
return conditioning_data

def create_pipeline(
self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
self,
unet,
scheduler,
) -> StableDiffusionGeneratorPipeline:
# TODO:
# configure_model_padding(
# unet,
Expand Down Expand Up @@ -213,6 +231,7 @@ def prep_control_data(
model: StableDiffusionGeneratorPipeline,
control_input: List[ControlField],
latents_shape: List[int],
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]:

Expand All @@ -238,25 +257,19 @@ def prep_control_data(
control_data = []
control_models = []
for control_info in control_list:
# handle control models
if ("," in control_info.control_model):
control_model_split = control_info.control_model.split(",")
control_name = control_model_split[0]
control_subfolder = control_model_split[1]
print("Using HF model subfolders")
print(" control_name: ", control_name)
print(" control_subfolder: ", control_subfolder)
control_model = ControlNetModel.from_pretrained(
control_name, subfolder=control_subfolder,
torch_dtype=model.unet.dtype).to(
model.device)
else:
control_model = ControlNetModel.from_pretrained(
control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
control_model = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model,
)
)

control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(
control_image_field.image_name)
control_image_field.image_name
)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
Expand All @@ -278,7 +291,8 @@ def prep_control_data(
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,)
control_mode=control_info.control_mode,
)
control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data
Expand All @@ -289,7 +303,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:

# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id)
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]

def step_callback(state: PipelineIntermediateState):
Expand All @@ -298,14 +313,17 @@ def step_callback(state: PipelineIntermediateState):
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
**lora.dict(exclude={"weight"})
)
yield (lora_info.context.model, lora.weight)
del lora_info
return

unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict())
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
**self.unet.unet.dict()
)
with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:

scheduler = get_scheduler(
Expand All @@ -322,6 +340,7 @@ def _lora_loader():
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)

# TODO: Verify the noise is the right size
Expand Down Expand Up @@ -374,7 +393,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:

# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id)
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]

def step_callback(state: PipelineIntermediateState):
Expand All @@ -383,14 +403,17 @@ def step_callback(state: PipelineIntermediateState):
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
**lora.dict(exclude={"weight"})
)
yield (lora_info.context.model, lora.weight)
del lora_info
return

unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict())
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
**self.unet.unet.dict()
)
with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:

scheduler = get_scheduler(
Expand All @@ -407,11 +430,13 @@ def _lora_loader():
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)

# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=unet.device, dtype=latent.dtype)
latent, device=unet.device, dtype=latent.dtype
)

timesteps, _ = pipeline.get_img2img_timesteps(
self.steps,
Expand Down Expand Up @@ -535,7 +560,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
resized_latents = torch.nn.functional.interpolate(
latents, size=(self.height // 8, self.width // 8),
mode=self.mode, antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,)
if self.mode in ["bilinear", "bicubic"] else False,
)

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
Expand Down Expand Up @@ -569,7 +595,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
resized_latents = torch.nn.functional.interpolate(
latents, scale_factor=self.scale_factor, mode=self.mode,
antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,)
if self.mode in ["bilinear", "bicubic"] else False,
)

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ import { RootState } from 'app/store/store';

const moduleLog = log.child({ namespace: 'controlNet' });

const predicate: AnyListenerPredicate<RootState> = (action, state) => {
const predicate: AnyListenerPredicate<RootState> = (
action,
state,
prevState
) => {
const isActionMatched =
controlNetProcessorParamsChanged.match(action) ||
controlNetModelChanged.match(action) ||
Expand All @@ -25,6 +29,16 @@ const predicate: AnyListenerPredicate<RootState> = (action, state) => {
return false;
}

if (controlNetAutoConfigToggled.match(action)) {
// do not process if the user just disabled auto-config
if (
prevState.controlNet.controlNets[action.payload.controlNetId]
.shouldAutoConfig === true
) {
return false;
}
}

const { controlImage, processorType, shouldAutoConfig } =
state.controlNet.controlNets[action.payload.controlNetId];

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import { zMainModel } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { forEach } from 'lodash-es';
import { startAppListening } from '..';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';

const moduleLog = log.child({ module: 'models' });

Expand Down Expand Up @@ -51,7 +52,14 @@ export const addModelSelectedListener = () => {
modelsCleared += 1;
}

// TODO: handle incompatible controlnet; pending model manager support
const { controlNets } = state.controlNet;
forEach(controlNets, (controlNet, controlNetId) => {
if (controlNet.model?.base_model !== base_model) {
dispatch(controlNetRemoved({ controlNetId }));
modelsCleared += 1;
}
});

if (modelsCleared > 0) {
dispatch(
addToast(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
import { forEach, some } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';
import { startAppListening } from '..';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';

const moduleLog = log.child({ module: 'models' });

Expand Down Expand Up @@ -127,7 +128,22 @@ export const addModelsLoadedListener = () => {
matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state
// TODO: pending model manager controlnet support
const controlNets = getState().controlNet.controlNets;

forEach(controlNets, (controlNet, controlNetId) => {
const isControlNetAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === controlNet?.model?.model_name &&
m?.base_model === controlNet?.model?.base_model
);

if (isControlNetAvailable) {
return;
}

dispatch(controlNetRemoved({ controlNetId }));
});
},
});
};
Loading