[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/LayerDiffuse-jupyter/blob/main/LayerDiffuse_jupyter.ipynb)

In [None]:
!pip install -q torchsde einops diffusers accelerate xformers==0.0.25

%cd /content
!git clone -b totoro https://github.com/camenduru/ComfyUI-layerdiffuse /content/layerdiffuse
%cd /content/layerdiffuse

import os, torch
import numpy as np
from lib_layerdiffusion.utils import load_file_from_url
import totoro.sd
from PIL import Image
import totoro.model_management
from totoro.utils import load_torch_file
from lib_layerdiffusion.models import TransparentVAEDecoder
from lib_layerdiffusion.enums import StableDiffusionVersion
from rng_noise_generator import ImageRNGNoise
import nodes
import nodes_compositing

def generate_latent(width, height, seed, subseed, subseed_strength, seed_resize_from_h=None, seed_resize_from_w=None):
    """Build the starting latent dict for the sampler from seeded noise.

    Args:
        width: Target image width in pixels; the latent uses ``width // 8``.
        height: Target image height in pixels; the latent uses ``height // 8``.
        seed: Main noise seed, forwarded to ``ImageRNGNoise`` as a one-item list.
        subseed: Secondary seed, forwarded as a one-item list.
        subseed_strength: Forwarded unchanged to ``ImageRNGNoise``.
        seed_resize_from_h: Forwarded unchanged to ``ImageRNGNoise``.
        seed_resize_from_w: Forwarded unchanged to ``ImageRNGNoise``.

    Returns:
        ``{"samples": noise}`` where ``noise`` is the first draw of the generator.
    """
    # Four latent channels at 1/8 of the pixel resolution.
    latent_shape = [4, height // 8, width // 8]
    noise_source = ImageRNGNoise(
        shape=latent_shape,
        seeds=[seed],
        subseeds=[subseed],
        subseed_strength=subseed_strength,
        seed_resize_from_h=seed_resize_from_h,
        seed_resize_from_w=seed_resize_from_w,
    )
    return {"samples": noise_source.next()}

def get_conds(prompt, clip_skip=-2):
    """Encode a text prompt into the conditioning list passed to the sampler.

    Uses the module-level ``clip`` object loaded with the checkpoint.

    Args:
        prompt: Text to tokenize and encode.
        clip_skip: Layer index to select on ``clip`` before encoding. Defaults
            to -2, the value that was previously hard-coded.

    Returns:
        A one-entry list ``[[cond, {"pooled_output": pooled}]]``.
    """
    with torch.inference_mode():
        # Only re-select the layer when it differs from the one currently set.
        if clip.layer_idx != clip_skip:
            clip.layer_idx = clip_skip
            clip.clip_layer(clip_skip)
        tokens = clip.tokenize(prompt)
        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
        return [[cond, {"pooled_output": pooled}]]

# Loaded transparent decoders keyed by StableDiffusionVersion, kept at module
# level so repeated decode() calls reuse the weights instead of reloading them.
_TRANSPARENT_DECODER_CACHE = {}


def _get_transparent_decoder(sd_version):
    """Return the TransparentVAEDecoder for ``sd_version``, loading it on first use."""
    decoder = _TRANSPARENT_DECODER_CACHE.get(sd_version)
    if decoder is not None:
        return decoder

    if sd_version == StableDiffusionVersion.SD1x:
        url = "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/layer_sd15_vae_transparent_decoder.safetensors"
        file_name = "layer_sd15_vae_transparent_decoder.safetensors"
    elif sd_version == StableDiffusionVersion.SDXL:
        url = "https://huggingface.co/LayerDiffusion/layerdiffusion-v1/resolve/main/vae_transparent_decoder.safetensors"
        file_name = "vae_transparent_decoder.safetensors"
    else:
        # Previously this fell through and failed later with an unbound ``url``.
        raise ValueError(f"No transparent VAE decoder available for {sd_version!r}.")

    model_path = load_file_from_url(
        url=url, model_dir="/content/layerdiffuse/model", file_name=file_name
    )
    use_fp16 = totoro.model_management.should_use_fp16()
    decoder = TransparentVAEDecoder(
        load_torch_file(model_path),
        device=totoro.model_management.get_torch_device(),
        dtype=torch.float16 if use_fp16 else torch.float32,
    )
    _TRANSPARENT_DECODER_CACHE[sd_version] = decoder
    return decoder


def decode(samples, images, sd_version: str, sub_batch_size: int):
    """Decode an image batch plus its latents into an image and an alpha tensor.

    Args:
        samples: Dict whose ``"samples"`` entry holds the latent batch that is
            sliced along dimension 0 and handed to the transparent decoder.
        images: 4-D tensor with channels last; its last axis is moved to
            position 1 before decoding.
        sd_version: Value accepted by ``StableDiffusionVersion(...)``.
        sub_batch_size: Number of batch items decoded per ``decode_pixel`` call.

    Returns:
        ``(image, alpha)``: ``image`` is channels ``1:`` of the decoder output
        (channels last) and ``alpha`` is ``1.0 -`` channel ``0``.

    Raises:
        ValueError: If no decoder weights are defined for ``sd_version``.
        AssertionError: If the image height or width is not a multiple of 64.
    """
    sd_version = StableDiffusionVersion(sd_version)
    decoder = _get_transparent_decoder(sd_version)

    pixel = images.movedim(-1, 1)
    _, _, H, W = pixel.shape
    assert H % 64 == 0, f"Height({H}) is not multiple of 64."
    assert W % 64 == 0, f"Width({W}) is not multiple of 64."

    latents = samples["samples"]
    # Decode in slices of sub_batch_size so only one slice is processed at a time.
    decoded = [
        decoder.decode_pixel(
            pixel[start_idx : start_idx + sub_batch_size],
            latents[start_idx : start_idx + sub_batch_size],
        )
        for start_idx in range(0, latents.shape[0], sub_batch_size)
    ]
    pixel_with_alpha = torch.cat(decoded, dim=0).movedim(1, -1)

    image = pixel_with_alpha[..., 1:]
    # NOTE(review): the alpha channel is returned inverted; presumably this
    # matches the mask convention of the downstream compositing node - confirm.
    alpha = 1.0 - pixel_with_alpha[..., 0]
    return (image, alpha)

def decode_sample(sample):
    """Decode a latent tensor to pixels with the module-level ``vae``.

    Args:
        sample: Latent tensor; it is cast to float32 before decoding.

    Returns:
        The detached result of ``vae.decode_tiled`` on the float32 latent.
    """
    with torch.inference_mode():
        latent_fp32 = sample.to(torch.float32)
        # Make sure the VAE weights are on the GPU before the tiled decode.
        vae.first_stage_model.cuda()
        pixels = vae.decode_tiled(latent_fp32)
        return pixels.detach()

# Download the Juggernaut-XL v9 checkpoint into the local model directory.
ckpt_path = load_file_from_url(
    url="https://huggingface.co/RunDiffusion/Juggernaut-XL-v9/resolve/main/Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors",
    model_dir="/content/layerdiffuse/model",
    file_name="Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors",
)

# Load every component from the checkpoint without tracking gradients; the
# names bound here are used as globals by the helper functions above.
with torch.no_grad():
    model_patcher, clip, vae, clipvision = totoro.sd.load_checkpoint_guess_config(
        ckpt_path,
        output_vae=True,
        output_clip=True,
        embedding_directory=None,
    )

In [None]:
# Starting latent (1024x1024 image) and the positive / negative conditioning.
latent = generate_latent(1024, 1024, 1, 1, 0, 1024, 1024)
cond = get_conds("toy car")
n_cond = get_conds("")

# Run 20 Euler steps over the full schedule.
sample = nodes.common_ksampler(
    model=model_patcher,
    seed=42423,
    steps=20,
    cfg=8.0,
    sampler_name="euler",
    scheduler="normal",
    positive=cond,
    negative=n_cond,
    latent=latent,
    denoise=1,
    disable_noise=False,
    start_step=0,
    last_step=20,
    force_full_denoise=True,
)

# Decode the sampled latent to pixels with the regular VAE.
decoded = decode_sample(sample[0]["samples"])

# The transparent decoder must receive the sampled latent that `decoded` was
# produced from (sample[0]), not the initial noise held in `latent`.
t_image, t_alpha = decode(sample[0], decoded, StableDiffusionVersion.SDXL, 1)

# Merge colour and alpha, then show the first image as the cell output.
result = nodes_compositing.JoinImageWithAlpha().join_image_with_alpha(t_image, t_alpha)
Image.fromarray(np.array(result[0]*255, dtype=np.uint8)[0])