Skip to content

Commit

Permalink
Update l2i invoke and seamless to support AutoencoderTiny, remove att… (
Browse files Browse the repository at this point in the history
#5936)

…ention processors if no mid_block is detected

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [x] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [ ] Yes
- [x] No


## Description
L2i throws an assertion error when run with `madebyollin/taesdxl` due to
it requiring a different class in diffusers to load it. This is a small
PR to update seamless and l2i to accept AutoencoderTiny models and not
throw exceptions while processing them.

## QA Instructions, Screenshots, Recordings

<img width="445" alt="Screenshot 2024-03-12 at 12 04 29 PM"
src="https://github.com/invoke-ai/InvokeAI/assets/58442074/34a17e44-d911-4fef-8fc1-71f7b688688c">
Run an sdxl pipeline using a vae that requires AutoencoderTiny and
validate that the image successfully encodes and decodes.

## Merge Plan

This PR can be merged when approved
  • Loading branch information
blessedcoolant committed Mar 12, 2024
2 parents 43948e0 + 8d2a4db commit 54f1a1f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
6 changes: 3 additions & 3 deletions invokeai/app/invocations/latent.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,14 +837,14 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)

vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL))
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, torch.nn.Module)
latents = latents.to(vae.device)
if self.fp32:
vae.to(dtype=torch.float32)

use_torch_2_0_or_xformers = isinstance(
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
Expand Down Expand Up @@ -1018,7 +1018,7 @@ def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: t
if upcast:
vae.to(dtype=torch.float32)

use_torch_2_0_or_xformers = isinstance(
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
Expand Down
3 changes: 2 additions & 1 deletion invokeai/backend/stable_diffusion/seamless.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import torch.nn as nn
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel


Expand All @@ -26,7 +27,7 @@ def _conv_forward_asymmetric(self, input, weight, bias):


@contextmanager
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]):
# Callable: (input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor
to_restore: list[tuple[nn.Conv2d | nn.ConvTranspose2d, Callable]] = []
try:
Expand Down

0 comments on commit 54f1a1f

Please sign in to comment.