In [1]:
from diffusers import DiffusionPipeline
import torch


import tvm
from tvm import relax
from tvm.relax.frontend.torch import dynamo_capture_subgraphs
from tvm.relax.frontend.torch import from_fx
from tvm.script import relax as R

import torch
from torch import fx

from web_stable_diffusion import utils
from web_stable_diffusion import trace

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Sanity check: show which TVM installation this kernel actually imported.
print(tvm.__file__)

/home/guoyaol/tvm/python/tvm/__init__.py


# CLIP

In [3]:
def clip_to_text_embeddings(pipe, max_length: int = 77) -> tvm.IRModule:
    """Capture the pipeline's CLIP text encoder as a Relax IRModule.

    Args:
        pipe: A loaded diffusers pipeline; only ``pipe.text_encoder`` is used.
        max_length: Token-sequence length of the dummy input used for graph
            capture. Defaults to 77, the maximum length used in this notebook.

    Returns:
        An IRModule with a single function ``"clip"`` taking int32 token ids of
        shape ``(1, max_length)`` and returning the first element of the text
        encoder's output. Weights stay as function inputs
        (``keep_params_as_input=True``).

    Raises:
        AssertionError: If dynamo capture does not yield exactly one subgraph.
    """

    class CLIPModelWrapper(torch.nn.Module):
        """Expose only the first output of the wrapped text encoder."""

        def __init__(self, clip):
            super().__init__()
            self.clip = clip

        def forward(self, text_input_ids):
            text_embeddings = self.clip(text_input_ids)[0]
            return text_embeddings

    clip_wrapper = CLIPModelWrapper(pipe.text_encoder)

    # Dummy token ids for graph capture; only shape/dtype are expected to matter.
    # (The former `torch.rand(...).to(torch.int32)` was also all zeros, since
    # values in [0, 1) truncate to 0 — say so explicitly instead.)
    text_input_ids = torch.zeros((1, max_length), dtype=torch.int32)

    mod = dynamo_capture_subgraphs(
        clip_wrapper.forward,
        text_input_ids,
        keep_params_as_input=True,
    )
    # A graph break would split the encoder into several subgraphs.
    assert len(mod.functions) == 1, (
        f"expected exactly one captured subgraph, got {len(mod.functions)}"
    )

    return tvm.IRModule({"clip": mod["subgraph_0"]})

# UNET

In [4]:
def unet_latents_to_noise_pred(
    pipe,
    device_str: str,
    guidance_scale: float = 7.5,
    latent_size: int = 128,
) -> tvm.IRModule:
    """Trace the UNet plus classifier-free guidance into a Relax IRModule.

    Args:
        pipe: A loaded diffusers pipeline, passed to ``utils.get_unet``.
        device_str: Torch device string forwarded to ``utils.get_unet``.
        guidance_scale: Classifier-free guidance weight, baked into the traced
            graph as a constant. Defaults to 7.5 (the previous hard-coded value).
        latent_size: Height/width of the square latent input. Defaults to 128
            (the previous hard-coded value).

    Returns:
        An IRModule with a single function ``"unet"``. Its inputs, in order, are:
        latents ``(1, 4, latent_size, latent_size)`` float32, a scalar int32
        timestep, ``(2, 77, 2048)`` float32 text embeddings, and ``(2, 1280)``
        and ``(2, 6)`` float32 tensors for the two ``added_cond_kwargs`` inputs.
        Weights are kept as extra inputs (``keep_params_as_input=True``).

    NOTE(review): ``vae_to_image`` and ``euler_discrete_scheduler_steps`` in this
    notebook use 64x64 latents by default, while this UNet defaults to 128x128.
    Confirm which size the deployed pipeline is meant to run at.
    """

    class UNetModelWrapper(torch.nn.Module):
        """Run the UNet on a doubled batch and apply classifier-free guidance."""

        def __init__(self, unet, guidance_scale):
            super().__init__()
            self.unet = unet
            self.guidance_scale = guidance_scale

        def forward(self, latents, timestep_tensor, text_embeddings, added_cond_kwargs_text_embeds, added_cond_kwargs_text_time_ids):
            # Duplicate the latents along the batch axis. The two halves of the
            # UNet output are read back below as (unconditional, text) predictions.
            latent_model_input = torch.cat([latents] * 2, dim=0)
            noise_pred = self.unet(latent_model_input, timestep_tensor, text_embeddings, added_cond_kwargs_text_embeds, added_cond_kwargs_text_time_ids)
            # Classifier-free guidance: uncond + scale * (text - uncond).
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
            return noise_pred

    unet = utils.get_unet(pipe, device_str)
    unet_to_noise_pred = UNetModelWrapper(unet, guidance_scale)
    graph = fx.symbolic_trace(unet_to_noise_pred)
    mod = from_fx(
        graph,
        [
            ((1, 4, latent_size, latent_size), "float32"),
            ((), "int32"),
            ((2, 77, 2048), "float32"),
            ((2, 1280), "float32"),
            ((2, 6), "float32"),
        ],
        keep_params_as_input=True,
    )
    return tvm.IRModule({"unet": mod["main"]})

# VAE

In [5]:
def vae_to_image(pipe, scaling_factor=None, latent_size: int = 64) -> tvm.IRModule:
    """Trace VAE decoding plus image post-processing into a Relax IRModule.

    Args:
        pipe: A loaded diffusers pipeline. It is passed to ``utils.get_vae``, and
            ``pipe.vae.config`` is consulted for the latent scaling factor.
        scaling_factor: Latent scaling factor; latents are multiplied by
            ``1 / scaling_factor`` before decoding. When ``None`` (default), it is
            read from ``pipe.vae.config["scaling_factor"]``, falling back to
            0.18215 if absent. The old code always used 0.18215, the SD 1.x value,
            which is wrong for VAEs whose config declares a different factor.
        latent_size: Height/width of the square latent input. Defaults to 64.
            NOTE(review): the UNet in this notebook is traced with 128x128
            latents; confirm the intended size.

    Returns:
        An IRModule with a single function ``"vae"``. It maps float32 latents of
        shape ``(1, 4, latent_size, latent_size)`` through ``post_quant_conv`` and
        the decoder. The decoder output is rescaled from [-1, 1] to [0, 1] (with
        clamping), multiplied by 255 and rounded, and axis 1 is moved to the last
        position. Weights stay as function inputs (``keep_params_as_input=True``).
    """
    if scaling_factor is None:
        # diffusers model configs are dict-like; be tolerant of a missing config.
        vae_config = getattr(getattr(pipe, "vae", None), "config", None) or {}
        scaling_factor = vae_config.get("scaling_factor", 0.18215)

    class VAEModelWrapper(torch.nn.Module):
        """Decode latents with the wrapped VAE and normalize to 0..255."""

        def __init__(self, vae, scaling_factor):
            super().__init__()
            self.vae = vae
            # Plain Python float: symbolic tracing folds it into a constant.
            self.scaling_factor = float(scaling_factor)

        def forward(self, latents):
            # Undo the latent scaling applied at encode time.
            latents = 1 / self.scaling_factor * latents
            # VAE decode.
            z = self.vae.post_quant_conv(latents)
            image = self.vae.decoder(z)
            # Map [-1, 1] -> [0, 1], then to rounded 0..255 with axis 1 moved last.
            image = (image / 2 + 0.5).clamp(min=0, max=1)
            image = (image.permute(0, 2, 3, 1) * 255).round()
            return image

    vae = utils.get_vae(pipe)
    vae_wrapper = VAEModelWrapper(vae, scaling_factor)

    graph = fx.symbolic_trace(vae_wrapper)
    mod = from_fx(
        graph,
        [((1, 4, latent_size, latent_size), "float32")],
        keep_params_as_input=True,
    )
    return tvm.IRModule({"vae": mod["main"]})

# Scheduler

In [6]:
def euler_discrete_scheduler_steps(latent_shape=(1, 4, 64, 64)) -> tvm.IRModule:
    """Build a Relax module containing one Euler-discrete scheduler step.

    The generated function is
    ``euler_discrete_scheduler_step(sample, model_output, sigma, sigma_next)``.
    It returns ``sample + model_output * (sigma_next - sigma)``. ``sample`` and
    ``model_output`` are float32 tensors of ``latent_shape``; ``sigma`` and
    ``sigma_next`` are float32 scalars.

    Args:
        latent_shape: Static shape of the latent tensors. Defaults to
            ``(1, 4, 64, 64)`` (the previous hard-coded value).
            NOTE(review): the UNet in this notebook is traced with
            ``(1, 4, 128, 128)`` latents; confirm the intended size.

    Returns:
        The IRModule built by the ``relax.BlockBuilder``.
    """
    latent_shape = tuple(latent_shape)
    bb = relax.BlockBuilder()

    sample = relax.Var("sample", R.Tensor(latent_shape, "float32"))
    model_output = relax.Var("model_output", R.Tensor(latent_shape, "float32"))
    sigma = relax.Var("sigma", R.Tensor((), "float32"))
    # Give the second scalar its own name hint. It used to be "sigma" too,
    # leaving the function with two identically named parameters.
    sigma_next = relax.Var("sigma_next", R.Tensor((), "float32"))

    with bb.function(
        "euler_discrete_scheduler_step",
        [sample, model_output, sigma, sigma_next],
    ):
        prev_sample = bb.emit(
            sample + model_output * (sigma_next - sigma),
            "prev_sample",
        )
        bb.emit_func_output(prev_sample)

    return bb.get()

In [7]:
# Load SDXL base 1.0 from the Hugging Face hub.
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

# Torch device string handed to the UNet loader (utils.get_unet).
torch_dev_key = "cpu"

# Trace each stage into its own IRModule, in the same order as before.
clip = clip_to_text_embeddings(pipe)
unet = unet_latents_to_noise_pred(pipe, torch_dev_key)
vae = vae_to_image(pipe)
scheduler = euler_discrete_scheduler_steps()

# Combine the per-stage modules into a single IRModule.
mod: tvm.IRModule = utils.merge_irmodules(clip, unet, vae, scheduler)

The config attributes {'add_watermarker': None} were passed to StableDiffusionXLPipeline, but are not expected and will be ignored. Please verify your model_index.json configuration file.
Keyword arguments {'add_watermarker': None} are not expected by StableDiffusionXLPipeline and will be ignored.
The config attributes {'force_upcast': True} were passed to AutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.


enter arange!
enter arange!


In [8]:
# Separate the model weights from `mod` (returned as `params`) so they can be
# transformed and saved independently of the compiled code.
mod, params = relax.frontend.detach_params(mod)

In [9]:
# Run the default Relax lowering pipeline returned by `get_pipeline()`.
mod = relax.pipeline.get_pipeline()(mod)



In [10]:
# Model entry points (their weights get lifted/transformed) ...
model_names = ["clip", "unet", "vae"]
# ... plus the scheduler entry point; together these must survive DCE.
scheduler_func_names = ["euler_discrete_scheduler_step"]
entry_funcs = model_names + scheduler_func_names

# Prune everything not reachable from `entry_funcs`.
mod = relax.transform.DeadCodeElimination(entry_funcs)(mod)

In [11]:
# Lift weight-only computation into separate `<model>_transform_params`
# functions (listed in the build-stage output further down).
mod = relax.transform.LiftTransformParams()(mod)

In [12]:
# Split into a build-time module (the `*_transform_params` functions) and a
# deployment module (model and scheduler entry points), as the listing printed
# below shows.
mod_transform, mod_deploy = utils.split_transform_deploy_mod(
    mod, model_names, entry_funcs
)

In [13]:
def print_relax_funcnames(mod: tvm.IRModule):
    """Print the name of each Relax function in ``mod``, one per line.

    Non-Relax functions (e.g. TIR PrimFuncs) are skipped. A blank line is
    printed at the end.
    """
    relax_names = [
        gvar.name_hint
        for gvar, fn in mod.functions.items()
        if isinstance(fn, relax.Function)
    ]
    for name in relax_names:
        print(name)
    print()
    
# Show which Relax functions ended up on each side of the split.
print("In IRModule for build stage:")
print_relax_funcnames(mod_transform)

print("In IRModule for deployment stage:")
print_relax_funcnames(mod_deploy)

In IRModule for build stage:
unet_transform_params
vae_transform_params
clip_transform_params

In IRModule for deployment stage:
clip
euler_discrete_scheduler_step
unet
vae



In [14]:
# Compute and save the scheduler constants.
# NOTE(review): this call is disabled, so this cell writes no scheduler
# constants. The DPM-Solver scheduler class further down reads
# `scheduler_dpm_solver_multistep_consts.json` from the artifact directory;
# confirm that file is produced elsewhere before using it.
# trace.compute_save_scheduler_consts(artifact_path="dist")
# Run the build-stage transform functions over the detached weights and save
# the result under dist/.
new_params = utils.transform_params(mod_transform, params)
utils.save_params(new_params, artifact_path="dist")

Start storing to cache dist/params
[2016/2016] saving clip_195 
All finished, 188 total shards committed, record saved to dist/params/ndarray-cache.json
Also saved a bf16 record to dist/params/ndarray-cache-b16.json


In [15]:
from tvm import meta_schedule as ms

# Runtime device and compilation target, both CUDA.
device = tvm.cuda()
target = tvm.target.Target("cuda")

# Give every TIR function default GPU thread bindings so the deploy module can
# be built without tuning; the target is supplied by the enclosing `with`.
with target, tvm.transform.PassContext(opt_level=3):
    mod_deploy = tvm.tir.transform.DefaultGPUSchedule()(mod_deploy)

In [16]:
# Compile the deployment module for the CUDA target.
ex = relax.build(mod=mod_deploy, target=target)

In [17]:
# Persist the compiled executable as a shared library next to the saved weights.
ex.export_library("dist/stable_diffusion.so")

# Load Back

In [18]:
# Load the saved model weights back onto `device`; the result is indexed by
# model name (e.g. "clip", "unet") when binding VM functions.
const_params_dict = utils.load_params(artifact_path="dist", device=device)
# Load the compiled executable back from the exported shared library.
ex = tvm.runtime.load_module("dist/stable_diffusion.so")

In [19]:
# Wrap the reloaded library in a Relax VM running on the CUDA device.
vm = relax.VirtualMachine(rt_mod=ex, device=device)

In [20]:
def wrapper(f, params):
    """Return a callable that invokes ``f`` with ``params`` appended last.

    Calling the result as ``g(a, b)`` is equivalent to ``f(a, b, params)``.
    This is used to bind model weights to VM functions that take them as a
    trailing argument.
    """
    return lambda *args: f(*args, params)

In [21]:
run_clip = wrapper(vm["clip"], const_params_dict["clip"])

# Dummy token ids (all zeros) just to exercise the compiled CLIP function.
# The old `torch.rand(...).to(torch.int32)` was also all zeros; it is named
# `clip_input_ids` rather than `input` so the builtin `input` is not shadowed
# in the kernel.
clip_input_ids = torch.zeros((1, 77), dtype=torch.int32)
input_nd = tvm.nd.array(clip_input_ids, device=device)

nd_res1 = run_clip(input_nd)

In [22]:
run_unet = wrapper(vm["unet"], const_params_dict["unet"])

# Random inputs with the shapes/dtypes the UNet was traced with: latents,
# timestep, text embeddings, and the two added-conditioning tensors.
input1, input2, input3, input4, input5 = (
    torch.rand((1, 4, 128, 128)).to(torch.float32),
    torch.tensor(3).to(torch.int32),
    torch.rand((2, 77, 2048)).to(torch.float32),
    torch.rand((2, 1280)).to(torch.float32),
    torch.rand((2, 6)).to(torch.float32),
)

# Copy each input to the target device, preserving order.
input1_nd, input2_nd, input3_nd, input4_nd, input5_nd = (
    tvm.nd.array(t, device=device) for t in (input1, input2, input3, input4, input5)
)

nd_res2 = run_unet(input1_nd, input2_nd, input3_nd, input4_nd, input5_nd).numpy()

In [None]:
import json
import numpy as np

from web_stable_diffusion import runtime


class DPMSolverMultistepScheduler(runtime.Scheduler):
    """Multistep DPM-Solver scheduler whose math runs in compiled VM functions.

    Per-step constants are read once from
    ``{artifact_path}/scheduler_dpm_solver_multistep_consts.json`` and copied to
    ``device``. :meth:`step` only indexes them by ``counter``.
    """

    scheduler_name = "multistep-dpm-solver"

    def __init__(self, artifact_path: str, device, latent_shape=(1, 4, 64, 64)) -> None:
        """Load the scheduler constants and allocate the model-output history.

        Args:
            artifact_path: Directory containing the constants JSON file.
            device: TVM device on which all constant arrays are placed.
            latent_shape: Shape of the model-output history buffer. Defaults to
                ``(1, 4, 64, 64)`` (the previous hard-coded value).

        Raises:
            FileNotFoundError: If the constants file does not exist.
            KeyError: If the JSON lacks ``timesteps``, ``alpha``, ``sigma``,
                ``c0``, ``c1`` or ``c2``.
        """
        consts_path = f"{artifact_path}/scheduler_dpm_solver_multistep_consts.json"
        with open(consts_path, "r", encoding="utf-8") as file:
            scheduler_consts = json.load(file)

        def f_convert(data, dtype):
            # One device array per step, so `step` can index by counter.
            return [tvm.nd.array(np.array(t, dtype=dtype), device) for t in data]

        self.timesteps = f_convert(scheduler_consts["timesteps"], "int32")
        self.alpha = f_convert(scheduler_consts["alpha"], "float32")
        self.sigma = f_convert(scheduler_consts["sigma"], "float32")
        self.c0 = f_convert(scheduler_consts["c0"], "float32")
        self.c1 = f_convert(scheduler_consts["c1"], "float32")
        self.c2 = f_convert(scheduler_consts["c2"], "float32")

        # Zero-initialise the model-output history. `tvm.nd.empty` left it
        # uninitialised, so the first `step` passed arbitrary memory to the VM.
        # Even a zero coefficient cannot cancel a NaN/Inf bit pattern
        # (0 * NaN = NaN).
        self.last_model_output: tvm.nd.NDArray = tvm.nd.array(
            np.zeros(tuple(latent_shape), dtype="float32"), device
        )

    def step(
        self,
        vm: relax.VirtualMachine,
        model_output: tvm.nd.NDArray,
        sample: tvm.nd.NDArray,
        counter: int,
    ) -> tvm.nd.NDArray:
        """Advance one denoising step and return the new latents.

        First converts ``model_output`` using ``alpha[counter]`` and
        ``sigma[counter]``. Then passes ``sample``, the converted output, the
        previous converted output, and ``c0/c1/c2[counter]`` to the step function.
        Finally stores the converted output as history for the next call.

        NOTE(review): this requires the VM to export
        ``dpm_solver_multistep_scheduler_convert_model_output`` and
        ``dpm_solver_multistep_scheduler_step``. The module built in this
        notebook keeps only ``euler_discrete_scheduler_step`` as its scheduler
        entry point.
        """
        model_output = vm["dpm_solver_multistep_scheduler_convert_model_output"](
            sample, model_output, self.alpha[counter], self.sigma[counter]
        )
        prev_latents = vm["dpm_solver_multistep_scheduler_step"](
            sample,
            model_output,
            self.last_model_output,
            self.c0[counter],
            self.c1[counter],
            self.c2[counter],
        )
        self.last_model_output = model_output
        return prev_latents