In [None]:
import os
import requests
import sys
import copy
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.utils.peft_utils import set_weights_and_activate_adapters
from peft import LoraConfig


# from .model import make_1step_sched, my_vae_encoder_fwd, my_vae_decoder_fwd
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline
from diffusers import DDPMScheduler
from diffusers import UNet2DModel
from diffusers import AutoencoderTiny

unet2d_config = {
    "sample_size": 64,
    "in_channels": 4,
    "out_channels": 4,
    "center_input_sample": False,
    "time_embedding_type": "positional",
    "freq_shift": 0,
    "flip_sin_to_cos": True,
    "down_block_types": ("DownBlock2D", "DownBlock2D", "DownBlock2D"),
    "up_block_types": ("UpBlock2D", "UpBlock2D", "UpBlock2D"),
    "block_out_channels": [320, 640, 1280],
    "layers_per_block": 1,
    "mid_block_scale_factor": 1,
    "downsample_padding": 1,
    "downsample_type": "conv",
    "upsample_type": "conv",
    "dropout": 0.0,
    "act_fn": "silu",
    "norm_num_groups": 32,
    "norm_eps": 1e-05,
    "resnet_time_scale_shift": "default",
    "add_attention": False,
}


class Pix2PixLight(torch.nn.Module):
    def __init__(self, dtype=torch.bfloat16):
        super().__init__()
        sched = DDPMScheduler.from_pretrained(
            "stabilityai/sd-turbo",
            subfolder="scheduler",
        )
        sched.set_timesteps(1, device="cuda")
        sched.alphas_cumprod = sched.alphas_cumprod.cuda()
        sched.betas = sched.betas.to(dtype).cuda()
        sched.alphas = sched.alphas.to(dtype).cuda()
        sched.one = sched.one.to(dtype).cuda()
        sched.alphas_cumprod = sched.alphas_cumprod.to(dtype).cuda()
        self.sched = sched

        vae = AutoencoderTiny.from_pretrained(
            "madebyollin/taesd",
            torch_device="cuda",
            torch_dtype=dtype,
        ).cuda()

        vae.decoder.ignore_skip = False
        unet = UNet2DModel(**unet2d_config).to("cuda").to(dtype)

        # vae.decoder.gamma = 1
        self.timesteps = torch.tensor([999], device="cuda").long()
        self.unet = unet
        self.vae = vae

    def set_eval(self):
        self.unet.eval()
        self.vae.eval()
        self.unet.requires_grad_(False)
        self.vae.requires_grad_(False)

    def set_train(self):
        self.unet.train()
        self.vae.train()

    def forward(self, c_t):
        encoded_control = (
            self.vae.encode(c_t, False)[0] * self.vae.config.scaling_factor
        )
        model_pred = self.unet(
            encoded_control,
            self.timesteps,
            return_dict=False,
        )[0]
        x_denoised = self.sched.step(
            model_pred,
            self.timesteps,
            encoded_control,
            return_dict=False,
        )[0]
        output_image = (
            self.vae.decode(
                x_denoised / self.vae.config.scaling_factor,
                return_dict=False,
            )[0]
        ).clamp(-1, 1)

        return output_image

    def save_model(self, outf):
        self.unet.save_pretrained(outf + "unet")
        self.vae.save_pretrained(outf + "vae")


model = Pix2PixLight()

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(
The config attributes {'shift_factor': 0.0, 'upsample_fn': 'nearest'} were passed to AutoencoderTiny, but are not expected and will be ignored. Please verify your config.json configuration file.


In [3]:
import time

model.set_eval()
begin = time.time()
steps = 90
for i in range(steps):
    c_t = torch.randn(
        [1, 3, 512, 512],
        device="cuda",
        dtype=torch.bfloat16,
    )
    model(c_t)
total = time.time() - begin
print(f"Inference time: {total}s {total/steps}s per image")

Inference time: 1.1434836387634277s 0.012705373764038085s per image


In [4]:
from sfast.compilers.diffusion_pipeline_compiler import (
    compile,
    compile_unet,
    compile_vae,
)
from sfast.compilers.diffusion_pipeline_compiler import compile, CompilationConfig

config = CompilationConfig.Default()
config.enable_cuda_graph = True
config.enable_triton = True
config.enable_xformers = True
model.vae = compile_vae(model.vae, config)
model.unet = compile_unet(model.unet, config)

In [9]:
import time

begin = time.time()
steps = 90
for i in range(steps):
    c_t = torch.randn(
        [1, 3, 512, 512],
        device="cuda",
        dtype=torch.bfloat16,
    )
    model(c_t)
total = time.time() - begin
print(f"Inference time: {total}s {total/steps}s per image")

Inference time: 0.6635997295379639s 0.0073733303281995986s per image
