[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/LayerDiffuse-jupyter/blob/main/LayerDiffuse_jupyter.ipynb)

In [None]:
# Install the Python packages this notebook needs (xformers is pinned to 0.0.25).
!pip install -q torchsde diffusers accelerate einops xformers==0.0.25

# Clone the "totoro" branch of ComfyUI-layerdiffuse and make the checkout the
# working directory (presumably so the repo's modules imported below are
# importable -- confirm).
%cd /content
!git clone -b totoro https://github.com/camenduru/ComfyUI-layerdiffuse /content/layerdiffuse
%cd /content/layerdiffuse

from typing import Optional
import os, torch
from urllib.parse import urlparse
import numpy as np

def load_file_from_url(url: str, *, model_dir: str, progress: bool = True, file_name: Optional[str] = None,) -> str:
    """Return the local path for ``url``, downloading it into ``model_dir`` if needed.

    Args:
        url: Remote location of the file.
        model_dir: Directory that holds the local copy; created when missing.
        progress: Forwarded to ``torch.hub.download_url_to_file``.
        file_name: Local file name. When empty or ``None`` the last component
            of the URL path is used.

    Returns:
        Absolute path of the local file. Nothing is downloaded when a file
        already exists at that path.
    """
    os.makedirs(model_dir, exist_ok=True)

    target_name = file_name or os.path.basename(urlparse(url).path)
    local_path = os.path.abspath(os.path.join(model_dir, target_name))

    # Already on disk: reuse it without touching the network.
    if os.path.exists(local_path):
        return local_path

    print(f'Downloading: "{url}" to {local_path}\n')
    # Imported lazily so torch.hub is only needed when a download happens.
    from torch.hub import download_url_to_file
    download_url_to_file(url, local_path, progress=progress)
    return local_path

# Fetch the Juggernaut XL v9 checkpoint into /content/model (the download is
# skipped when the file is already there).
ckpt_path = load_file_from_url(url="https://huggingface.co/RunDiffusion/Juggernaut-XL-v9/resolve/main/Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors", model_dir="/content/model", file_name="Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors")

import totoro.sd, torch
# Same location as the path returned above, restated as a literal.
ckpt_path = "/content/model/Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors"
# Load the checkpoint with gradient tracking disabled. Of the four returned
# objects, model_patcher, clip and vae are the ones later cells use.
with torch.no_grad():
  model_patcher, clip, vae, clipvision = totoro.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=None)

from rng_noise_generator import ImageRNGNoise
def generate_latent(width, height, seed, subseed, subseed_strength, seed_resize_from_h=None, seed_resize_from_w=None):
    """Draw one noise tensor from ``ImageRNGNoise`` and wrap it as a latent dict.

    The generator is built for a shape of ``[4, height // 8, width // 8]``
    with a single seed / subseed pair; the remaining arguments are forwarded
    to ``ImageRNGNoise`` unchanged.

    Returns:
        ``{"samples": noise}`` where ``noise`` is the first value produced by
        the generator's ``next()``.
    """
    latent_shape = [4, height // 8, width // 8]
    noise_source = ImageRNGNoise(
        shape=latent_shape,
        seeds=[seed],
        subseeds=[subseed],
        subseed_strength=subseed_strength,
        seed_resize_from_h=seed_resize_from_h,
        seed_resize_from_w=seed_resize_from_w,
    )
    return {"samples": noise_source.next()}
# 1024x1024 request with seed 1, subseed 1 and subseed strength 0.
# NOTE(review): this noise tensor is later passed to the sampler as the starting
# latent (`latent=latent`) while `noise=None`, so the sampler prepares its own
# noise on top of it -- confirm that this is intended.
latent = generate_latent(1024, 1024, 1, 1, 0, 1024, 1024)

def common_ksampler_with_custom_noise(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent,
                                      denoise=1.0, disable_noise=False, start_step=None, last_step=None,
                                      force_full_denoise=False, noise=None):
    """Run ``totoro.sample.sample`` on a latent dict, optionally with caller-supplied noise.

    Noise selection, in priority order:
      1. ``noise`` given: a CPU copy of ``noise.next()`` is used.
      2. ``disable_noise`` true: an all-zero CPU tensor matching the latent.
      3. otherwise: ``totoro.sample.prepare_noise`` seeded with ``seed`` (and
         ``latent["batch_index"]`` when that key is present).

    Args:
        latent: Dict holding the starting latent under ``"samples"``; an
            optional ``"noise_mask"`` entry is forwarded to the sampler.
        noise: Optional object exposing ``next()`` that yields a noise tensor.
        Remaining arguments are passed straight through to the sampler.

    Returns:
        A one-element tuple containing a shallow copy of ``latent`` whose
        ``"samples"`` entry is replaced by the sampler output.
    """
    from totoro.sample import sample as run_sampler

    start_latent = latent["samples"]

    if noise is not None:
        # Detach and move to CPU, then clone so the generator's tensor is not shared.
        noise = noise.next().detach().cpu().clone()
    elif disable_noise:
        noise = torch.zeros(
            start_latent.size(),
            dtype=start_latent.dtype,
            layout=start_latent.layout,
            device="cpu",
        )
    else:
        from totoro.sample import prepare_noise
        noise = prepare_noise(start_latent, seed, latent.get("batch_index"))

    sampled = run_sampler(
        model, noise, steps, cfg, sampler_name, scheduler, positive, negative, start_latent,
        denoise=denoise,
        disable_noise=disable_noise,
        start_step=start_step,
        last_step=last_step,
        force_full_denoise=force_full_denoise,
        noise_mask=latent.get("noise_mask"),
        callback=None,
        disable_pbar=False,
        seed=seed,
    )

    # Keep every other key of the input dict; only the samples are replaced.
    return ({**latent, "samples": sampled},)

def get_conds(prompt, clip_skip=-2):
    """Encode a text prompt into the conditioning structure the sampler expects.

    Uses the module-level ``clip`` object loaded from the checkpoint.

    Args:
        prompt: Text to tokenize and encode.
        clip_skip: Layer index to configure on ``clip`` before encoding.
            Defaults to ``-2``, the value that used to be hard-coded here.

    Returns:
        ``[[cond, {"pooled_output": pooled}]]`` built from the encoder output.
    """
    with torch.inference_mode():
        # Reconfigure the encoder only when it is not already at the requested
        # layer. (The earlier `clip_skip != clip_skip` test was always false and
        # the trailing self-assignment was a no-op, so both are dropped.)
        if clip.layer_idx != clip_skip:
            clip.layer_idx = clip_skip
            clip.clip_layer(clip_skip)
        tokens = clip.tokenize(prompt)
        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
        return [[cond, {"pooled_output": pooled}]]


In [None]:
from PIL import Image
# Positive conditioning for the prompt and an empty-prompt negative conditioning.
cond = get_conds("duck")
n_cond = get_conds("")
# 20 Euler steps at CFG 8.0 with full denoise; noise=None and disable_noise=False,
# so the sampler derives its noise from seed=0. The result is a 1-tuple whose
# element is a latent dict.
sample = common_ksampler_with_custom_noise(model=model_patcher, seed=0, steps=20, cfg=8.0, sampler_name="euler", scheduler="normal", positive=cond, negative=n_cond,
                                           latent=latent, denoise=1, disable_noise=False, start_step=0, last_step=20, force_full_denoise=True, noise=None)
def decode_sample(sample):
    """Decode a latent tensor with the module-level ``vae`` using tiled decoding.

    The latent is cast to float32 and ``vae.first_stage_model`` is moved to
    CUDA before ``vae.decode_tiled`` runs; everything executes under
    ``torch.inference_mode()``.

    Returns:
        The detached output of ``vae.decode_tiled``.
    """
    with torch.inference_mode():
        latent_fp32 = sample.to(torch.float32)
        vae.first_stage_model.cuda()
        return vae.decode_tiled(latent_fp32).detach()
# Decode the sampled latent to pixels.
decoded = decode_sample(sample[0]["samples"])
# Scale to 0..255, clamp, convert to uint8 and keep the first batch item.
# (assumes `decoded` holds values in roughly [0, 1] -- TODO confirm)
np_array = np.clip(255. * decoded.cpu().numpy(), 0, 255).astype(np.uint8)[0]
image = Image.fromarray(np_array)
image = image.convert("RGB")
# sample[0]["samples"]
# Last expression of the cell: the notebook renders the image.
image

In [None]:
import totoro.model_management
from totoro.utils import load_torch_file
from lib_layerdiffusion.utils import load_file_from_url
from lib_layerdiffusion.models import TransparentVAEDecoder
from lib_layerdiffusion.enums import StableDiffusionVersion

# One decoder per Stable Diffusion version, kept for the lifetime of the session
# so repeated decode() calls do not reload the weights every time.
_TRANSPARENT_DECODER_CACHE = {}


def decode(samples, images, sd_version: str, sub_batch_size: int):
    """Run the transparent VAE decoder on decoded pixels plus their latents.

    Args:
        samples: Latent dict; ``samples["samples"]`` is sliced along dim 0.
        images: Pixel tensor laid out ``[B, H, W, C]``; H and W must be
            multiples of 64.
        sd_version: A ``StableDiffusionVersion`` member or a value accepted by
            its constructor. Only SD1x and SDXL have decoder weights.
        sub_batch_size: How many images to decode in a single pass.
            See https://github.com/huchenlei/TotoroUI-layerdiffuse/pull/4 for
            more context.

    Returns:
        ``(image, alpha)``: ``image`` is every channel after the first of the
        decoder output (channels last); ``alpha`` is ``1.0 -`` the first
        channel.

    Raises:
        ValueError: If ``sd_version`` has no transparent decoder.
        AssertionError: If the image height or width is not 64-aligned.
    """
    sd_version = StableDiffusionVersion(sd_version)

    decoder = _TRANSPARENT_DECODER_CACHE.get(sd_version)
    if decoder is None:
        if sd_version == StableDiffusionVersion.SD1x:
            url = "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_vae_transparent_decoder.safetensors"
            file_name = "layer_sd15_vae_transparent_decoder.safetensors"
        elif sd_version == StableDiffusionVersion.SDXL:
            url = "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/vae_transparent_decoder.safetensors"
            file_name = "vae_transparent_decoder.safetensors"
        else:
            # Previously this fell through and failed later with a NameError on `url`.
            raise ValueError(f"No transparent VAE decoder available for {sd_version}.")

        model_path = load_file_from_url(
            url=url, model_dir="/content/model", file_name=file_name
        )
        decoder = TransparentVAEDecoder(
            load_torch_file(model_path),
            device=totoro.model_management.get_torch_device(),
            dtype=(
                torch.float16
                if totoro.model_management.should_use_fp16()
                else torch.float32
            ),
        )
        _TRANSPARENT_DECODER_CACHE[sd_version] = decoder

    pixel = images.movedim(-1, 1)  # [B, H, W, C] => [B, C, H, W]

    # Decoder requires dimension to be 64-aligned.
    B, C, H, W = pixel.shape
    assert H % 64 == 0, f"Height({H}) is not multiple of 64."
    assert W % 64 == 0, f"Width({W}) is not multiple of 64."

    latents = samples["samples"]
    # Decode in slices of `sub_batch_size` rather than the whole batch at once.
    decoded = [
        decoder.decode_pixel(
            pixel[start_idx : start_idx + sub_batch_size],
            latents[start_idx : start_idx + sub_batch_size],
        )
        for start_idx in range(0, latents.shape[0], sub_batch_size)
    ]
    pixel_with_alpha = torch.cat(decoded, dim=0)
    # [B, C, H, W] => [B, H, W, C]
    pixel_with_alpha = pixel_with_alpha.movedim(1, -1)
    image = pixel_with_alpha[..., 1:]
    # The first channel is inverted before being returned.
    alpha = 1.0 - pixel_with_alpha[..., 0]
    return (image, alpha)

def resize_mask(mask, shape):
    """Bilinearly resize a mask to ``(shape[0], shape[1])``.

    The mask's last two dimensions are treated as height and width; all
    leading dimensions are flattened into the batch dimension.

    Args:
        mask: Tensor with at least two dimensions.
        shape: Sequence whose first two entries are the target height and
            width; further entries are ignored.

    Returns:
        Tensor of shape ``[N, shape[0], shape[1]]``.
    """
    target_h, target_w = shape[0], shape[1]
    # interpolate() wants [N, C, H, W], so add a singleton channel dimension.
    batched = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
    resized = torch.nn.functional.interpolate(batched, size=(target_h, target_w), mode="bilinear")
    return resized.squeeze(1)

def join_image_with_alpha(image: torch.Tensor, alpha: torch.Tensor):
    """Attach an alpha channel to each image, producing 4-channel images.

    ``alpha`` is resized to the image's height and width with ``resize_mask``
    and inverted (``1.0 - alpha``) before it is appended as the last channel
    after the first three image channels. Only the first
    ``min(len(image), len(alpha))`` items are combined.

    Returns:
        A one-element tuple holding the stacked result.
    """
    inverted_alpha = 1.0 - resize_mask(alpha, image.shape[1:])

    # zip() stops at the shorter of the two batches.
    combined = [
        torch.cat((rgb[:, :, :3], a.unsqueeze(2)), dim=2)
        for rgb, a in zip(image, inverted_alpha)
    ]
    return (torch.stack(combined),)

In [None]:
# Give the transparent decoder the *sampled* latent (the sampler output) together
# with the pixels the VAE decoded from that same latent. Passing `latent` here
# would hand it the initial noise from generate_latent instead.
t_image, t_alpha = decode(sample[0], decoded, StableDiffusionVersion.SDXL, 1)
# Merge the colour channels and the alpha into a single 4-channel image batch.
result = join_image_with_alpha(t_image, t_alpha)

In [None]:
import PIL
# Show the first colour image returned by decode(), scaled to 0..255 as uint8.
PIL.Image.fromarray(np.array(t_image*255, dtype=np.uint8)[0])

In [None]:
import PIL
# Show the first alpha map returned by decode(), scaled to 0..255 as uint8.
PIL.Image.fromarray(np.array(t_alpha*255, dtype=np.uint8)[0])

In [None]:
import PIL
# Show the first joined image (colour channels plus alpha) from
# join_image_with_alpha, scaled to 0..255 as uint8.
PIL.Image.fromarray(np.array(result[0]*255, dtype=np.uint8)[0])