74 changes: 68 additions & 6 deletions invokeai/app/invocations/latent.py
@@ -146,11 +146,8 @@ class TextToLatentsInvocation(BaseInvocation):
     # TODO: consider making prompt optional to enable providing prompt through a link
     # fmt: off
     prompt: Optional[str] = Field(description="The prompt to generate an image from")
-    seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
     noise: Optional[LatentsField] = Field(description="The noise to use")
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
-    width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
-    height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
     cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
     scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
     seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
@@ -363,9 +360,74 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
             session_id=context.graph_execution_state_id, node=self
         )
 
+        torch.cuda.empty_cache()
+
         context.services.images.save(image_type, image_name, image, metadata)
         return build_image_output(
-            image_type=image_type,
-            image_name=image_name,
-            image=image
+            image_type=image_type, image_name=image_name, image=image
         )
+
+
+LATENTS_INTERPOLATION_MODE = Literal[
+    "nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
+]
+
+
+class ResizeLatentsInvocation(BaseInvocation):
+    """Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
+
+    type: Literal["lresize"] = "lresize"
+
+    # Inputs
+    latents: Optional[LatentsField] = Field(description="The latents to resize")
+    width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
+    height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
+    mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
+    antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
+
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        latents = context.services.latents.get(self.latents.latents_name)
+
+        resized_latents = torch.nn.functional.interpolate(
+            latents,
+            size=(self.height // 8, self.width // 8),
+            mode=self.mode,
+            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
+        )
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        torch.cuda.empty_cache()
+
+        name = f"{context.graph_execution_state_id}__{self.id}"
+        context.services.latents.set(name, resized_latents)
+        return LatentsOutput(latents=LatentsField(latents_name=name))
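Note for reviewers: the resize node reduces to a single `torch.nn.functional.interpolate` call on the latents, with the requested pixel dimensions floor-divided by 8 to get latent-space dimensions. A minimal standalone sketch of the same operation (the `[batch, 4, H/8, W/8]` latent layout and the example sizes are illustrative assumptions, not part of this diff):

```python
import torch

# Sketch of what ResizeLatentsInvocation computes, outside the node graph.
# Assumes the usual Stable Diffusion latent layout: [batch, 4, H/8, W/8].
latents = torch.randn(1, 4, 64, 64)      # latents for a 512x512 image
target_width, target_height = 768, 512   # requested pixel dimensions

resized = torch.nn.functional.interpolate(
    latents,
    size=(target_height // 8, target_width // 8),  # (64, 96) in latent space
    mode="bilinear",
    antialias=True,  # accepted here because mode is bilinear
)
print(resized.shape)  # torch.Size([1, 4, 64, 96])
```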


+class ScaleLatentsInvocation(BaseInvocation):
+    """Scales latents by a given factor."""
+
+    type: Literal["lscale"] = "lscale"
+
+    # Inputs
+    latents: Optional[LatentsField] = Field(description="The latents to scale")
+    scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
+    mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
+    antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
+
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        latents = context.services.latents.get(self.latents.latents_name)
+
+        # resizing
+        resized_latents = torch.nn.functional.interpolate(
+            latents,
+            scale_factor=self.scale_factor,
+            mode=self.mode,
+            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
+        )
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        torch.cuda.empty_cache()
+
+        name = f"{context.graph_execution_state_id}__{self.id}"
+        context.services.latents.set(name, resized_latents)
+        return LatentsOutput(latents=LatentsField(latents_name=name))
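The scale node forwards `scale_factor` straight through to `interpolate`, so the factor applies to both spatial dimensions of the latent grid (and therefore to the decoded image). Both nodes guard `antialias` the same way because `interpolate` only accepts `antialias=True` for the `bilinear` and `bicubic` modes. A hedged sketch under the same assumed latent layout:

```python
import torch

# Sketch of what ScaleLatentsInvocation computes.
latents = torch.randn(1, 4, 64, 64)

scaled = torch.nn.functional.interpolate(
    latents,
    scale_factor=0.5,  # halves both spatial dims: 64x64 -> 32x32
    mode="bilinear",
    antialias=True,    # only accepted for "bilinear" / "bicubic"
)
print(scaled.shape)  # torch.Size([1, 4, 32, 32])
```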