Merged

22 commits
5aa7bfe
Fix masked generation with inpaint models
StAlKeR7779 Aug 16, 2023
bf0dfca
Add inpaint mask field class
StAlKeR7779 Aug 17, 2023
ff5c725
Update mask field type
StAlKeR7779 Aug 17, 2023
b213335
feat: Add InpaintMask Field type
blessedcoolant Aug 17, 2023
e9a294f
Merge branch 'main' into fix/inpaint_gen
lstein Aug 17, 2023
cfd827c
Added node for creating inpaint mask
StAlKeR7779 Aug 18, 2023
3c43594
Merge branch 'main' into fix/inpaint_gen
psychedelicious Aug 18, 2023
c49851e
chore: minor cleanup after merge & flake8
psychedelicious Aug 18, 2023
e9633a3
Merge branch 'main' into fix/inpaint_gen
blessedcoolant Aug 26, 2023
382a55a
fix: merge conflicts
blessedcoolant Aug 26, 2023
af3e316
chore: Regen schema
blessedcoolant Aug 26, 2023
226721c
feat: Setup UnifiedCanvas to work with new InpaintMaskField
blessedcoolant Aug 26, 2023
c923d09
rename: Inpaint Mask to Denoise Mask
blessedcoolant Aug 26, 2023
521da55
feat: Update color of Denoise Mask socket
blessedcoolant Aug 26, 2023
249048a
fix: Reorder DenoiseMask socket fields
blessedcoolant Aug 26, 2023
b18695d
fix: Update color of denoise mask socket
blessedcoolant Aug 26, 2023
3f8d17d
chore: Black linting
blessedcoolant Aug 26, 2023
71c3955
feat: Add Scale Before Processing To Canvas Txt2Img / Img2Img (w/ SDXL)
blessedcoolant Aug 26, 2023
1811b54
Provide metadata to image creation call
StAlKeR7779 Aug 27, 2023
526c7e7
Provide antialias argument as behaviour will be changed in future (dep…
StAlKeR7779 Aug 27, 2023
3e6c490
Change antialias to True as input - image
StAlKeR7779 Aug 27, 2023
4f00dbe
Merge branch 'main' into fix/inpaint_gen
hipsterusername Aug 28, 2023
4 changes: 4 additions & 0 deletions invokeai/app/invocations/image.py
@@ -375,6 +375,9 @@ class ImageResizeInvocation(BaseInvocation):
width: int = InputField(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = InputField(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
metadata: Optional[CoreMetadata] = InputField(
default=None, description=FieldDescriptions.core_metadata, ui_hidden=True
)

def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@@ -393,6 +396,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
)

return ImageOutput(
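A note on the image.py change above: the resize invocation gains an optional metadata input and forwards it as a plain dict (or None) to the image-creation call. The following is a minimal, self-contained sketch of that serialize-if-present pattern; the names (CoreMetadata, create_image) are illustrative stand-ins, not the real InvokeAI service API.

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class CoreMetadata:
    generation_mode: str
    seed: int

    def dict(self) -> dict:
        return {"generation_mode": self.generation_mode, "seed": self.seed}


def create_image(pixels: bytes, metadata: Optional[dict] = None) -> dict:
    # Stand-in for the image service: persists the record with whatever
    # metadata (possibly None) the invocation supplied.
    return {"pixels": pixels, "metadata": metadata}


meta = CoreMetadata(generation_mode="txt2img", seed=123)
record = create_image(b"...", metadata=meta.dict() if meta else None)
assert record["metadata"]["seed"] == 123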
155 changes: 108 additions & 47 deletions invokeai/app/invocations/latent.py
@@ -21,6 +21,8 @@

from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import (
DenoiseMaskField,
DenoiseMaskOutput,
ImageField,
ImageOutput,
LatentsField,
@@ -31,8 +33,8 @@
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings

from ...backend.model_management.models import BaseModelType
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import BaseModelType
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData,
@@ -44,16 +46,7 @@
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
FieldDescriptions,
Input,
InputField,
InvocationContext,
UIType,
tags,
title,
)
from .baseinvocation import BaseInvocation, FieldDescriptions, Input, InputField, InvocationContext, UIType, tags, title
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .model import ModelInfo, UNetField, VaeField
@@ -64,6 +57,72 @@
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]


@title("Create Denoise Mask")
@tags("mask", "denoise")
class CreateDenoiseMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""

# Metadata
type: Literal["create_denoise_mask"] = "create_denoise_mask"

# Inputs
vae: VaeField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32, ui_order=4)

def prep_mask_tensor(self, mask_image):
if mask_image.mode != "L":
mask_image = mask_image.convert("L")
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
if mask_tensor.dim() == 3:
mask_tensor = mask_tensor.unsqueeze(0)
# if shape is not None:
# mask_tensor = tv_resize(mask_tensor, shape, T.InterpolationMode.BILINEAR)
return mask_tensor

@torch.no_grad()
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
if self.image is not None:
image = context.services.images.get_pil_image(self.image.image_name)
image = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image.dim() == 3:
image = image.unsqueeze(0)
else:
image = None

mask = self.prep_mask_tensor(
context.services.images.get_pil_image(self.mask.image_name),
)

if image is not None:
vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
context=context,
)

img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0)
# TODO:
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())

masked_latents_name = f"{context.graph_execution_state_id}__{self.id}_masked_latents"
context.services.latents.save(masked_latents_name, masked_latents)
else:
masked_latents_name = None

mask_name = f"{context.graph_execution_state_id}__{self.id}_mask"
context.services.latents.save(mask_name, mask)

return DenoiseMaskOutput(
denoise_mask=DenoiseMaskField(
mask_name=mask_name,
masked_latents_name=masked_latents_name,
),
)


def get_scheduler(
context: InvocationContext,
scheduler_info: ModelInfo,
@@ -126,10 +185,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
control: Union[ControlField, list[ControlField]] = InputField(
default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5
)
latents: Optional[LatentsField] = InputField(
description=FieldDescriptions.latents, input=Input.Connection, ui_order=4
)
mask: Optional[ImageField] = InputField(
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None,
description=FieldDescriptions.mask,
)
@@ -342,19 +399,18 @@ def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_en

return num_inference_steps, timesteps, init_timestep

def prep_mask_tensor(self, mask, context, lantents):
if mask is None:
return None
def prep_inpaint_mask(self, context, latents):
if self.denoise_mask is None:
return None, None

mask_image = context.services.images.get_pil_image(mask.image_name)
if mask_image.mode != "L":
# FIXME: why do we get passed an RGB image here? We can only use single-channel.
mask_image = mask_image.convert("L")
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
if mask_tensor.dim() == 3:
mask_tensor = mask_tensor.unsqueeze(0)
mask_tensor = tv_resize(mask_tensor, lantents.shape[-2:], T.InterpolationMode.BILINEAR)
return 1 - mask_tensor
mask = context.services.latents.get(self.denoise_mask.mask_name)
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
if self.denoise_mask.masked_latents_name is not None:
masked_latents = context.services.latents.get(self.denoise_mask.masked_latents_name)
else:
masked_latents = None

return 1 - mask, masked_latents

@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
Expand All @@ -375,7 +431,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
if seed is None:
seed = 0

mask = self.prep_mask_tensor(self.mask, context, latents)
mask, masked_latents = self.prep_inpaint_mask(context, latents)

# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
@@ -406,6 +462,8 @@ def _lora_loader():
noise = noise.to(device=unet.device, dtype=unet.dtype)
if mask is not None:
mask = mask.to(device=unet.device, dtype=unet.dtype)
if masked_latents is not None:
masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)

scheduler = get_scheduler(
context=context,
@@ -442,6 +500,7 @@ def _lora_loader():
noise=noise,
seed=seed,
mask=mask,
masked_latents=masked_latents,
num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
@@ -663,26 +722,11 @@ class ImageToLatentsInvocation(BaseInvocation):
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)

@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
# image = context.services.images.get(
# self.image.image_type, self.image.image_name
# )
image = context.services.images.get_pil_image(self.image.image_name)

# vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
context=context,
)

image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

@staticmethod
def vae_encode(vae_info, upcast, tiled, image_tensor):
with vae_info as vae:
orig_dtype = vae.dtype
if self.fp32:
if upcast:
vae.to(dtype=torch.float32)

use_torch_2_0_or_xformers = isinstance(
Expand All @@ -707,7 +751,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
vae.to(dtype=torch.float16)
# latents = latents.half()

if self.tiled:
if tiled:
vae.enable_tiling()
else:
vae.disable_tiling()
@@ -721,6 +765,23 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = vae.config.scaling_factor * latents
latents = latents.to(dtype=orig_dtype)

return latents

@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.services.images.get_pil_image(self.image.image_name)

vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
context=context,
)

image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)

name = f"{context.graph_execution_state_id}__{self.id}"
latents = latents.to("cpu")
context.services.latents.save(name, latents)
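Taken together, the latent.py changes split mask handling into two stages: CreateDenoiseMaskInvocation computes the mask tensor (and, when a source image is supplied, the VAE-encoded masked image) once, storing both under generated names, and DenoiseLatentsInvocation.prep_inpaint_mask later resolves those names back into tensors; the former ImageToLatentsInvocation.invoke body is factored into a reusable vae_encode staticmethod so the new node can share it. Below is a condensed, runnable sketch of that two-stage flow, with a plain dict standing in for the latents service and simple multiplication standing in for VAE encoding; names are illustrative.

from typing import Optional

import torch

store: dict = {}  # stand-in for context.services.latents


def create_denoise_mask(image: Optional[torch.Tensor], mask: torch.Tensor) -> dict:
    # Stage 1 (CreateDenoiseMaskInvocation): compute once, store by name.
    masked_latents_name = None
    if image is not None:
        # Blank out the region to regenerate, then encode. The real node
        # calls ImageToLatentsInvocation.vae_encode() on the result.
        masked_image = image * torch.where(mask < 0.5, 0.0, 1.0)
        masked_latents_name = "exec1__node2_masked_latents"
        store[masked_latents_name] = masked_image
    mask_name = "exec1__node2_mask"
    store[mask_name] = mask
    return {"mask_name": mask_name, "masked_latents_name": masked_latents_name}


def prep_inpaint_mask(field: dict):
    # Stage 2 (DenoiseLatentsInvocation): resolve names back to tensors.
    mask = store[field["mask_name"]]  # the real code also resizes to the latents' H/W
    name = field["masked_latents_name"]
    masked_latents = store[name] if name is not None else None
    return 1 - mask, masked_latents  # after inversion, 1 marks the region to regenerate


mask = (torch.rand(1, 1, 64, 64) > 0.5).float()
field = create_denoise_mask(torch.rand(1, 3, 64, 64), mask)
inverted_mask, masked_latents = prep_inpaint_mask(field)
assert masked_latents is not None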
19 changes: 19 additions & 0 deletions invokeai/app/invocations/primitives.py
@@ -294,6 +294,25 @@ def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
return ImageCollectionOutput(collection=self.collection)


# endregion

# region DenoiseMask


class DenoiseMaskField(BaseModel):
"""An inpaint mask field"""

mask_name: str = Field(description="The name of the mask image")
masked_latents_name: Optional[str] = Field(description="The name of the masked image latents")


class DenoiseMaskOutput(BaseInvocationOutput):
"""Base class for nodes that output a single image"""

type: Literal["denoise_mask_output"] = "denoise_mask_output"
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")


# endregion

# region Latents
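The new primitives are deliberately thin: DenoiseMaskField carries only names that reference tensors saved in the latents store, never the tensors themselves, which keeps graph payloads small and serializable. A minimal sketch of the two cases, assuming a pydantic BaseModel as in the codebase (masked_latents_name stays None for non-inpaint models; the name strings are illustrative):

from typing import Optional

from pydantic import BaseModel, Field


class DenoiseMaskField(BaseModel):
    mask_name: str = Field(description="The name of the mask image")
    masked_latents_name: Optional[str] = Field(
        default=None, description="The name of the masked image latents"
    )


# Non-inpaint models only need the mask...
regular = DenoiseMaskField(mask_name="exec1__node7_mask")
# ...while inpaint-trained UNets also need the encoded masked image.
inpaint = DenoiseMaskField(
    mask_name="exec1__node7_mask",
    masked_latents_name="exec1__node7_masked_latents",
)
assert regular.masked_latents_name is None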
41 changes: 21 additions & 20 deletions invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -144,7 +144,7 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool = Tr
w, h = trim_to_multiple_of(*image.size, multiple_of=multiple_of)
transformation = T.Compose(
[
T.Resize((h, w), T.InterpolationMode.LANCZOS),
T.Resize((h, w), T.InterpolationMode.LANCZOS, antialias=True),
T.ToTensor(),
]
)
@@ -358,6 +358,7 @@ def latents_from_embeddings(
callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None,
mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None,
seed: Optional[int] = None,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if init_timestep.shape[0] == 0:
@@ -376,28 +377,28 @@
latents = self.scheduler.add_noise(latents, noise, batched_t)

if mask is not None:
# if no noise provided, noisify unmasked area based on seed(or 0 as fallback)
if noise is None:
noise = torch.randn(
orig_latents.shape,
dtype=torch.float32,
device="cpu",
generator=torch.Generator(device="cpu").manual_seed(seed or 0),
).to(device=orig_latents.device, dtype=orig_latents.dtype)

latents = self.scheduler.add_noise(latents, noise, batched_t)
latents = torch.lerp(
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
)

if is_inpainting_model(self.unet):
# You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint
# (that's why there's a mask!) but it seems to really want that blanked out.
# masked_latents = latents * torch.where(mask < 0.5, 1, 0) TODO: inpaint/outpaint/infill
if masked_latents is None:
raise Exception("Source image required for inpaint mask when inpaint model used!")

# TODO: we should probably pass this in so we don't have to try/finally around setting it.
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(self._unet_forward, mask, orig_latents)
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
self._unet_forward, mask, masked_latents
)
else:
# if no noise provided, noisify unmasked area based on seed(or 0 as fallback)
if noise is None:
noise = torch.randn(
orig_latents.shape,
dtype=torch.float32,
device="cpu",
generator=torch.Generator(device="cpu").manual_seed(seed or 0),
).to(device=orig_latents.device, dtype=orig_latents.dtype)

latents = self.scheduler.add_noise(latents, noise, batched_t)
latents = torch.lerp(
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
)

additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise))

try:
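The reordering in latents_from_embeddings makes the two masking strategies explicit: an inpaint-trained UNet now requires masked_latents (hence the new exception) and receives them through AddsMaskLatents, while a standard UNet keeps the original latents outside the mask via blending and AddsMaskGuidance. A small runnable sketch of the blending branch, with random tensors standing in for real latents and scheduler output:

import torch

orig_latents = torch.randn(1, 4, 8, 8)
noise = torch.randn_like(orig_latents)
noised = orig_latents + 0.5 * noise            # stands in for scheduler.add_noise()
mask = (torch.rand(1, 1, 8, 8) > 0.5).float()  # 1 = regenerate, 0 = keep

# torch.lerp(a, b, w) = a + w * (b - a): where the mask is 0 the original
# latents survive exactly, so only the masked region gets re-denoised.
blended = torch.lerp(orig_latents, noised, mask)

keep = mask[0, 0] == 0
assert torch.equal(blended[:, :, keep], orig_latents[:, :, keep])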
InputFieldRenderer.tsx
@@ -10,6 +10,7 @@ import ColorInputField from './inputs/ColorInputField';
import ConditioningInputField from './inputs/ConditioningInputField';
import ControlInputField from './inputs/ControlInputField';
import ControlNetModelInputField from './inputs/ControlNetModelInputField';
import DenoiseMaskInputField from './inputs/DenoiseMaskInputField';
import EnumInputField from './inputs/EnumInputField';
import ImageCollectionInputField from './inputs/ImageCollectionInputField';
import ImageInputField from './inputs/ImageInputField';
@@ -105,6 +106,19 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
);
}

if (
field?.type === 'DenoiseMaskField' &&
fieldTemplate?.type === 'DenoiseMaskField'
) {
return (
<DenoiseMaskInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}

if (
field?.type === 'ConditioningField' &&
fieldTemplate?.type === 'ConditioningField'
inputs/DenoiseMaskInputField.tsx
@@ -0,0 +1,17 @@
import {
DenoiseMaskInputFieldTemplate,
DenoiseMaskInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';

const DenoiseMaskInputFieldComponent = (
_props: FieldComponentProps<
DenoiseMaskInputFieldValue,
DenoiseMaskInputFieldTemplate
>
) => {
return null;
};

export default memo(DenoiseMaskInputFieldComponent);