In [1]:
from diffusers import DiffusionPipeline
import torch


import tvm
from tvm import relax
from tvm.relax.frontend.torch import dynamo_capture_subgraphs
from tvm.relax.frontend.torch import from_fx
from tvm.script import relax as R

import torch
from torch import fx

from web_stable_diffusion import utils
from web_stable_diffusion import trace

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
print(tvm.__file__)

# CLIP

In [None]:
def clip_to_text_embeddings(pipe) -> tvm.IRModule:
    # Define the wrapper torch.nn.Module for CLIP.
    class CLIPModelWrapper(torch.nn.Module):
        def __init__(self, clip):
            super().__init__()
            self.clip = clip

        def forward(self, text_input_ids):
            result = self.clip(text_input_ids, output_hidden_states=True)
            text_embeddings = result.hidden_states[-2]
            pool_text_embeddings = result[0]
            return text_embeddings, pool_text_embeddings

    clip = pipe.text_encoder
    clip_to_text_embeddings = CLIPModelWrapper(clip)

    # Create random input (77 is the maximum length).
    text_input_ids = torch.rand((1, 77)).to(torch.int32)
    # Capture CLIP's computational graph.
    mod = dynamo_capture_subgraphs(
        clip_to_text_embeddings.forward,
        text_input_ids,
        keep_params_as_input=True,
    )
    assert len(mod.functions) == 1

    return tvm.IRModule({"clip": mod["subgraph_0"]})

# UNET

In [None]:
def unet_latents_to_noise_pred(pipe, device_str: str) -> tvm.IRModule:
    class UNetModelWrapper(torch.nn.Module):
        def __init__(self, unet):
            super().__init__()
            self.unet = unet
            # Default guidance scale factor in stable diffusion.
            self.guidance_scale = 7.5

        def forward(self, latents, timestep_tensor, text_embeddings, added_cond_kwargs_text_embeds, added_cond_kwargs_text_time_ids):
            # Latent concatenation.
            latent_model_input = torch.cat([latents] * 2, dim=0)
            # UNet forward.
            noise_pred = self.unet(latent_model_input, timestep_tensor, text_embeddings, added_cond_kwargs_text_embeds, added_cond_kwargs_text_time_ids)
            # Classifier-free guidance.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
            return noise_pred

    unet = utils.get_unet(pipe, device_str)
    unet_to_noise_pred = UNetModelWrapper(unet)
    graph = fx.symbolic_trace(unet_to_noise_pred)
    mod = from_fx(
        graph,
        [((1, 4, 128, 128), "float32"), ((), "int32"), ((2, 77, 2048), "float32"), 
         ((2, 1280), "float32"), ((2, 6), "float32")],
        keep_params_as_input=True,
    )
    return tvm.IRModule({"unet": mod["main"]})

# VAE

In [None]:
def vae_to_image(pipe) -> tvm.IRModule:
    class VAEModelWrapper(torch.nn.Module):
        def __init__(self, vae):
            super().__init__()
            self.vae = vae

        def forward(self, latents):
            # Scale the latents so that it can be decoded by VAE.
            latents = 1 / 0.18215 * latents
            # VAE decode
            z = self.vae.post_quant_conv(latents)
            image = self.vae.decoder(z)
            # Image normalization
            image = (image / 2 + 0.5).clamp(min=0, max=1)
            image = (image.permute(0, 2, 3, 1) * 255).round()
            return image

    vae = utils.get_vae(pipe)
    vae_to_image = VAEModelWrapper(vae)

    # z = torch.rand((1, 4, 64, 64), dtype=torch.float32)
    # mod = dynamo_capture_subgraphs(
    #     vae_to_image.forward,
    #     z,
    #     keep_params_as_input=True,
    # )
    # assert len(mod.functions) == 1

    # return tvm.IRModule({"vae": mod["subgraph_0"]})
    graph = fx.symbolic_trace(vae_to_image)
    mod = from_fx(
        graph,
        [((1, 4, 64, 64), "float32")],
        keep_params_as_input=True,
    )
    return tvm.IRModule({"vae": mod["main"]})

# Concat Embeddings

In [None]:
def concat_embeddings() -> tvm.IRModule:
    bb = relax.BlockBuilder()
    cond_embeddings = relax.Var("cond_embeddings", R.Tensor([1, 77, 768], "float32"))
    uncond_embeddings = relax.Var(
        "uncond_embeddings", R.Tensor([1, 77, 768], "float32")
    )
    with bb.function("concat_embeddings", [cond_embeddings, uncond_embeddings]):
        res = bb.emit(
            relax.op.concat([cond_embeddings, uncond_embeddings], axis=0)
        )
        bb.emit_func_output(res)
    return bb.get()

# Image to rgba

In [None]:
def image_to_rgba() -> tvm.IRModule:
    from tvm import te

    def f_image_to_rgba(A):
        def fcompute(y, x):
            return (
                A[0, y, x, 0].astype("uint32")
                | (A[0, y, x, 1].astype("uint32") << 8)
                | (A[0, y, x, 2].astype("uint32") << 16)
                | tvm.tir.const(255 << 24, "uint32")
            )

        return te.compute((1024, 1024), fcompute, name="image_to_rgba")

    bb = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor([1, 1024, 1024, 3], "float32"))
    with bb.function("image_to_rgba", [x]):
        image = bb.emit(
            bb.call_te(f_image_to_rgba, x, primfunc_name_hint="tir_image_to_rgba")
        )
        bb.emit_func_output(image)
    return bb.get()

# Sheduler

In [None]:
def euler_discrete_scheduler_steps() -> tvm.IRModule:
    bb = relax.BlockBuilder()

    # step, the function.
    sample = relax.Var("sample", R.Tensor((1, 4, 64, 64), "float32"))
    model_output = relax.Var("model_output", R.Tensor((1, 4, 64, 64), "float32"))
    sigma = relax.Var(f"sigma", R.Tensor((), "float32"))
    sigma_next = relax.Var(f"sigma", R.Tensor((), "float32"))

    with bb.function(
        "euler_discrete_scheduler_step",
        [sample, model_output, sigma, sigma_next],
    ):
        prev_sample = bb.emit(
            sample + model_output * (sigma_next - sigma),
            "prev_sample",
        )
        bb.emit_func_output(prev_sample)

    return bb.get()

In [None]:
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

torch_dev_key = "cpu"

clip = clip_to_text_embeddings(pipe)
unet = unet_latents_to_noise_pred(pipe, torch_dev_key)
vae = vae_to_image(pipe)
concat_embeddings = concat_embeddings()
image_to_rgba = image_to_rgba()
scheduler = euler_discrete_scheduler_steps()

mod: tvm.IRModule = utils.merge_irmodules(
    clip,
    unet,
    vae,
    concat_embeddings,
    image_to_rgba,
    scheduler,
)

In [None]:
mod, params = relax.frontend.detach_params(mod)

In [None]:
mod = relax.pipeline.get_pipeline()(mod)

In [None]:
model_names = ["clip", "unet", "vae"]
scheduler_func_names = ["euler_discrete_scheduler_step"]
entry_funcs = (
    model_names + scheduler_func_names  + ["image_to_rgba", "concat_embeddings"]
)

# Clean up unused parts of the IRModule.
mod = relax.transform.DeadCodeElimination(entry_funcs)(mod)

In [None]:
mod = relax.transform.LiftTransformParams()(mod)

In [None]:
mod_transform, mod_deploy = utils.split_transform_deploy_mod(
    mod, model_names, entry_funcs
)

In [None]:
def print_relax_funcnames(mod: tvm.IRModule):
    for global_var, func in mod.functions.items():
        if isinstance(func, relax.Function):
            print(global_var.name_hint)
    print()
    
print("In IRModule for build stage:")
print_relax_funcnames(mod_transform)

print("In IRModule for deployment stage:")
print_relax_funcnames(mod_deploy)

In [None]:
# Compute and save the scheduler constants.

# trace.compute_save_scheduler_consts(artifact_path="dist")
#TODO: add this compute

# Compute and save the models's weight parameters.
new_params = utils.transform_params(mod_transform, params)
utils.save_params(new_params, artifact_path="dist")

In [None]:
from tvm import meta_schedule as ms

target = tvm.target.Target("cuda")
device = tvm.cuda()

with target, tvm.transform.PassContext(opt_level=3):
    mod_deploy = tvm.tir.transform.DefaultGPUSchedule()(mod_deploy)

In [None]:
ex = relax.build(mod=mod_deploy, target=target)

In [None]:
ex.export_library("dist/stable_diffusion.so")

# Load Back

In [3]:
# Load the model weight parameters back.
target = tvm.target.Target("cuda")
device = tvm.cuda()
const_params_dict = utils.load_params(artifact_path="dist", device=device)
# Load the model executable back from the shared library.
ex = tvm.runtime.load_module("dist/stable_diffusion.so")

In [4]:
vm = relax.VirtualMachine(rt_mod=ex, device=device)

In [5]:
def wrapper(f, params):
    def wrapped_f(*args):
        return f(*args, params)

    return wrapped_f

In [6]:
import json
import numpy as np

from web_stable_diffusion import runtime


class EulerDiscreteScheduler(runtime.Scheduler):
    scheduler_name = "euler-discrete-solver"

    def __init__(self, artifact_path: str, device) -> None:
        with open(
            f"{artifact_path}/scheduler_euler_discrete_consts.json", "r"
        ) as file:
            jsoncontent = file.read()
        scheduler_consts = json.loads(jsoncontent)

        def f_convert(data, dtype):
            return [tvm.nd.array(np.array(t, dtype=dtype), device) for t in data]

        self.timesteps = f_convert(scheduler_consts["timesteps"], "int32")
        self.sigma = f_convert(scheduler_consts["sigma"], "float32")

        # self.last_model_output: tvm.nd.NDArray = tvm.nd.empty(
        #     (1, 4, 64, 64), "float32", device
        # )

    def step(
        self,
        vm: relax.VirtualMachine,
        model_output: tvm.nd.NDArray,
        sample: tvm.nd.NDArray,
        counter: int,
    ) -> tvm.nd.NDArray:
        # model_output = vm["dpm_solver_multistep_scheduler_convert_model_output"](
        #     sample, model_output, self.alpha[counter], self.sigma[counter]
        # )
        prev_latents = vm["euler_discrete_scheduler_step"](
            sample,
            model_output,
            self.sigma[counter],
            self.sigma[counter+1]
        )
        # self.last_model_output = model_output
        return prev_latents

# Stable Diffusion XL pipeline

In [7]:
from PIL import Image
from tqdm import tqdm
from transformers import CLIPTokenizer


class TVMSDPipeline:
    def __init__(
        self,
        vm: relax.VirtualMachine,
        tokenizer: CLIPTokenizer,
        scheduler: runtime.Scheduler,
        tvm_device,
        param_dict,
    ):
        def wrapper(f, params):
            def wrapped_f(*args):
                return f(*args, params)

            return wrapped_f

        self.vm = vm
        self.clip_to_text_embeddings = wrapper(vm["clip"], param_dict["clip"])
        self.unet_latents_to_noise_pred = wrapper(vm["unet"], param_dict["unet"])
        self.vae_to_image = wrapper(vm["vae"], param_dict["vae"])
        self.concat_embeddings = vm["concat_embeddings"]
        self.image_to_rgba = vm["image_to_rgba"]
        self.tokenizer = tokenizer
        self.scheduler = scheduler
        self.tvm_device = tvm_device
        self.param_dict = param_dict

    def __call__(self, prompt: str, negative_prompt: str = ""):
        # The height and width are fixed to 512.

        # Compute the embeddings for the prompt and negative prompt.
        list_text_embeddings = []
        for text in [negative_prompt, prompt]:
            text = [text]
            # Tokenize the text.
            text_inputs = self.tokenizer(
                text,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,  # 77
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids.to(torch.int32)
            # Clip the text if the length exceeds the maximum allowed length.
            if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
                text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]

            # Compute text embeddings.
            text_input_ids = tvm.nd.array(text_input_ids.cpu().numpy(), self.tvm_device)
            clip_output = self.clip_to_text_embeddings(text_input_ids)
            text_embeddings = clip_output[0]
            pooled_prompt_embeds = clip_output[1]

            list_text_embeddings.append(text_embeddings)

        #TODO convert data to tvm.nd.array, fold into TVM
        torch_template = torch.from_numpy(pooled_prompt_embeds.asnumpy())
        negative_pooled_prompt_embeds = torch.zeros_like(torch_template)
        negative_pooled_prompt_embeds = tvm.nd.array(negative_pooled_prompt_embeds, self.tvm_device)
        pooled_list_text_embeddings = [negative_pooled_prompt_embeds, pooled_prompt_embeds]

        
        # Concatenate the text embeddings.
        text_embeddings = self.concat_embeddings(*list_text_embeddings)

        add_text_embeds = self.concat_embeddings(*pooled_list_text_embeddings)
        print(add_text_embeds.shape)

        #TODO: check correct, fold into TVM
        add_time_ids = torch.tensor([[1024., 1024., 0., 0., 1024., 1024.],[1024., 1024., 0., 0., 1024., 1024.]], dtype=torch.float32)
        add_time_ids = tvm.nd.array(add_time_ids, self.tvm_device)


        # Randomly initialize the latents.
        latents = torch.randn(
            (1, 4, 128, 128),
            device="cpu",
            dtype=torch.float32,
        )
        latents = tvm.nd.array(latents.numpy(), self.tvm_device)

        # UNet iteration.
        for i in tqdm(range(len(self.scheduler.timesteps))):
            #TODO: add this
            #latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            t = self.scheduler.timesteps[i]
            noise_pred = self.unet_latents_to_noise_pred(latents, t, text_embeddings, add_text_embeds, add_time_ids)
            latents = self.scheduler.step(self.vm, noise_pred, latents, i)

        # VAE decode.
        image = self.vae_to_image(latents)

        # Transform generated image to RGBA mode.
        image = self.image_to_rgba(image)
        return Image.fromarray(image.numpy().view("uint8").reshape(1024, 1024, 4))

In [8]:
pipe = TVMSDPipeline(
    vm=vm,
    tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14"),
    scheduler=runtime.EulerDiscreteScheduler(artifact_path="dist", device=device),
    tvm_device=device,
    param_dict=const_params_dict,
)

In [9]:
import time

prompt = "Jellyfish floating in a forest"

start = time.time()
image = pipe(prompt)
end = time.time()

print(f"Time elapsed: {end - start} seconds.")

(2, 77, 768)


  0%|          | 0/50 [00:00<?, ?it/s]


ValueError: Traceback (most recent call last):
  [bt] (8) /home/guoyaol/tvm/build/libtvm.so(tvm::runtime::relax_vm::VirtualMachineImpl::InvokeClosurePacked(tvm::runtime::ObjectRef const&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)+0x1b3) [0x7f02a71028c3]
  [bt] (7) /home/guoyaol/tvm/build/libtvm.so(+0x5146b8a) [0x7f02a7105b8a]
  [bt] (6) /home/guoyaol/tvm/build/libtvm.so(tvm::runtime::relax_vm::VirtualMachineImpl::InvokeBytecode(long, std::vector<tvm::runtime::TVMRetValue, std::allocator<tvm::runtime::TVMRetValue> > const&)+0x2b6) [0x7f02a71056a6]
  [bt] (5) /home/guoyaol/tvm/build/libtvm.so(tvm::runtime::relax_vm::VirtualMachineImpl::RunLoop()+0x2bd) [0x7f02a710392d]
  [bt] (4) /home/guoyaol/tvm/build/libtvm.so(tvm::runtime::relax_vm::VirtualMachineImpl::RunInstrCall(tvm::runtime::relax_vm::VMFrame*, tvm::runtime::relax_vm::Instruction)+0xb2e) [0x7f02a710475e]
  [bt] (3) /home/guoyaol/tvm/build/libtvm.so(tvm::runtime::relax_vm::VirtualMachineImpl::InvokeClosurePacked(tvm::runtime::ObjectRef const&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)+0x65) [0x7f02a7102775]
  [bt] (2) /home/guoyaol/tvm/build/libtvm.so(tvm::runtime::relax_vm::CheckTensorInfo(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)+0x5a9) [0x7f02a70c89c9]
  [bt] (1) /home/guoyaol/tvm/build/libtvm.so(tvm::runtime::detail::LogFatal::Entry::Finalize()+0x3d) [0x7f02a48e358d]
  [bt] (0) /home/guoyaol/tvm/build/libtvm.so(tvm::runtime::Backtrace[abi:cxx11]()+0x2c) [0x7f02a707af5c]
  File "/home/guoyaol/tvm/src/runtime/relax_vm/builtin.cc", line 177
ValueError: Check failed: (ptr->dl_tensor.ndim == ndim) is false: ErrorContext(fn=unet, loc=param[3], param=inp_3, annotation=R.Tensor((2, 1280), dtype="float32"))  expect Tensor with ndim 2 but get 3