# Stable Diffusion v2 Demo with Torch Compile

## Stable Diffusion v2 for Text-to-Image Generation

To start, let's look at the Text-to-Image process for Stable Diffusion v2. We will use the [Stable Diffusion v2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) model for this purpose. The main difference between Stable Diffusion v2 and Stable Diffusion v2.1 is the use of more data, more training, and less restrictive filtering of the dataset, which gives promising results for a wide range of input text prompts. More details about the model can be found in the [Stability AI blog post](https://stability.ai/blog/stablediffusion2-1-release7-dec-2022) and the original model [repository](https://github.com/Stability-AI/stablediffusion).

### Stable Diffusion in Diffusers library
To work with Stable Diffusion v2, we will use the Hugging Face [Diffusers](https://github.com/huggingface/diffusers) library. To experiment with Stable Diffusion models, Diffusers exposes the [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/using-diffusers/conditional_image_generation), similar to the [other Diffusers pipelines](https://huggingface.co/docs/diffusers/api/pipelines/overview). The code below installs the required packages and then creates the pipeline with `DiffusionPipeline.from_pretrained`; note that it loads the `stabilityai/stable-diffusion-xl-base-1.0` checkpoint.

In [None]:
# Install PyTorch/torchvision (using the PyTorch CPU wheel index as an extra index),
# Diffusers, Transformers and the other pinned model/export dependencies.
%pip install -q --extra-index-url https://download.pytorch.org/whl/cpu "torch>=2.1,<2.4" "torchvision<0.19.0" "diffusers>=0.18.0" "invisible-watermark>=0.2.0" "transformers>=4.33.0" "accelerate" "onnx" "peft==0.6.2"
# Install OpenVINO and Gradio.
%pip install -q "openvino>=2023.1.0" "gradio>=4.19"
# Install NNCF from the `fx_compress_weights` branch of a GitHub fork.
%pip install git+https://github.com/anzr299/nncf.git@fx_compress_weights

In [None]:
from torch._export import capture_pre_autograd_graph
from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
import numpy as np

In [None]:
from diffusers import DiffusionPipeline
import torch
import random
# Seed the global PyTorch, Python `random` and NumPy RNGs with the same value.
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)
# Dedicated CPU generator; it is passed to the pipeline in the inference cells.
generator = torch.Generator(device="cpu").manual_seed(42)
# Load the SDXL base 1.0 checkpoint and place the pipeline on the CPU.
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.to("cpu")


### Convert Models to Torch Fx Graph

In [8]:
# Example inputs used only to capture each sub-model as a Torch FX graph.
text_encoder_input = torch.ones((1, 77), dtype=torch.long)
text_encoder_2_input = torch.ones((1, 77), dtype=torch.long)
# NOTE(review): the VAE encoder example is 256x256 while the decoder/UNet
# latents are 128x128 — confirm these are the intended resolutions.
vae_encoder_input = torch.ones((1, 3, 256, 256))
vae_decoder_input = torch.ones((1, 4, 128, 128))

# Both text encoders are captured with hidden-state outputs enabled.
text_encoder_kwargs = {"output_hidden_states": True}

# UNet example inputs: positional (latents, timestep) plus keyword conditioning.
latents_shape = (2, 4, 128, 128)
latents = torch.randn(latents_shape)
t = torch.from_numpy(np.array(1, dtype=np.float32))
unet_input = (latents, t)
unet_kwargs = {
    "encoder_hidden_states": torch.ones((2, 77, 2048)),
    "added_cond_kwargs": {
        "text_embeds": torch.ones((2, 1280)),
        "time_ids": torch.ones((2, 6)),
    },
    "return_dict": False,
}

# Replace every sub-model of the pipeline with its captured graph. The capture
# runs without autograd and inside NNCF's `disable_patching()` context.
with torch.no_grad(), disable_patching():
    pipe.text_encoder = capture_pre_autograd_graph(
        pipe.text_encoder.eval(), args=(text_encoder_input,), kwargs=text_encoder_kwargs
    )
    pipe.text_encoder_2 = capture_pre_autograd_graph(
        pipe.text_encoder_2.eval(), args=(text_encoder_2_input,), kwargs=text_encoder_kwargs
    )
    pipe.vae.encoder = capture_pre_autograd_graph(pipe.vae.encoder.eval(), args=(vae_encoder_input,))
    pipe.vae.decoder = capture_pre_autograd_graph(pipe.vae.decoder.eval(), args=(vae_decoder_input,))
    pipe.unet = capture_pre_autograd_graph(pipe.unet.eval(), args=unet_input, kwargs=unet_kwargs)

# The example inputs are no longer needed once the graphs are captured.
del unet_kwargs, unet_input, latents, t
del vae_encoder_input, vae_decoder_input
del text_encoder_2_input, text_encoder_input, text_encoder_kwargs

### Quantization

#### Collect Calibration Dataset

In [None]:
import datasets
import numpy as np
from tqdm.notebook import tqdm
from typing import Any, Dict, List
import torch

def disable_progress_bar(pipeline, disable=True):
    """Set the ``'disable'`` flag in ``pipeline._progress_bar_config``.

    If the pipeline has no ``_progress_bar_config`` attribute yet, a new
    dict holding only the flag is created; otherwise the existing dict is
    updated in place and its other entries are kept.
    """
    try:
        config = pipeline._progress_bar_config
    except AttributeError:
        pipeline._progress_bar_config = {'disable': disable}
        return
    config['disable'] = disable


class UNetWrapper(torch.nn.Module):
    """Pass-through module that records a random subset of its call inputs.

    Every call is forwarded unchanged to the wrapped ``unet``. With
    probability 0.7 (one draw from NumPy's global RNG per call) the inputs
    are appended to ``captured_args`` as one flat tuple: the positional
    arguments followed by the keyword-argument values in call order.
    """

    def __init__(self, unet):
        super().__init__()
        self.captured_args = []
        self.unet = unet

    def forward(self, *args, **kwargs):
        keep_sample = np.random.rand() <= 0.7
        if keep_sample:
            self.captured_args.append(args + tuple(kwargs.values()))
        return self.unet(*args, **kwargs)

def collect_calibration_data(ov_pipe, calibration_dataset_size: int, num_inference_steps: int) -> List[tuple]:
    """Collect denoiser inputs by running the pipeline on dataset captions.

    The pipeline's denoiser is temporarily replaced by a ``UNetWrapper`` that
    records a random subset of the inputs it is called with. Captions from the
    shuffled Conceptual Captions train split are fed through the pipeline
    until at least ``calibration_dataset_size`` samples were recorded or the
    dataset is exhausted. The original denoiser and the progress-bar setting
    are restored afterwards, even if the pipeline raises.

    Args:
        ov_pipe: Pipeline object; it must be callable with a prompt, expose
            ``tokenizer`` and hold its denoiser as ``transformer`` or ``unet``.
        calibration_dataset_size: Minimum number of samples to collect. The
            result may be longer, because all samples captured during the
            last prompt are kept.
        num_inference_steps: Denoising steps to run per prompt.

    Returns:
        A list of flat tuples as recorded by ``UNetWrapper``.
    """
    # Keep using `transformer` when the pipeline has one; otherwise fall back to
    # `unet`, which is where UNet-based pipelines (such as SDXL) keep the denoiser.
    if getattr(ov_pipe, "transformer", None) is not None:
        denoiser_attr = "transformer"
    else:
        denoiser_attr = "unet"
    original_denoiser = getattr(ov_pipe, denoiser_attr)
    wrapped_denoiser = UNetWrapper(original_denoiser)
    calibration_data = []

    dataset = datasets.load_dataset("google-research-datasets/conceptual_captions", split="train", trust_remote_code=True).shuffle(seed=42)

    disable_progress_bar(ov_pipe)
    setattr(ov_pipe, denoiser_attr, wrapped_denoiser)
    pbar = tqdm(total=calibration_dataset_size)
    try:
        # Run inference for data collection
        for batch in dataset:
            prompt = batch["caption"]
            print(prompt)
            # Cheap pre-filter: compares the caption's character count with the
            # tokenizer limit, so over-long captions are skipped without tokenizing.
            if len(prompt) > ov_pipe.tokenizer.model_max_length:
                continue
            ov_pipe(prompt, num_inference_steps=num_inference_steps)
            calibration_data.extend(wrapped_denoiser.captured_args)
            wrapped_denoiser.captured_args = []
            pbar.update(len(calibration_data) - pbar.n)
            if pbar.n >= calibration_dataset_size:
                break
    finally:
        # Always undo the temporary changes made to the caller's pipeline.
        pbar.close()
        setattr(ov_pipe, denoiser_attr, original_denoiser)
        disable_progress_bar(ov_pipe, disable=False)

    return calibration_data

In [10]:
def collect_ops_with_weights(graph_module):
    """Return the names of all graph nodes whose name contains ``"linear"``.

    Nodes are visited in the order of ``graph_module.graph.nodes``; the match
    is a plain substring test on ``node.name``.
    """
    return [node.name for node in graph_module.graph.nodes if "linear" in node.name]

# Number of recorded UNet input samples after which collection stops.
calibration_dataset_size = 30
unet_calibration_data = collect_calibration_data(
    pipe,
    calibration_dataset_size=calibration_dataset_size,
    num_inference_steps=20,
)
# Names of the UNet graph nodes containing "linear"; passed as the ignored
# scope when the UNet is quantized.
unet_ignored_scope = collect_ops_with_weights(pipe.unet)

In [None]:
import nncf
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
from nncf.quantization.range_estimator import RangeEstimatorParametersSet

def disable_progress_bar(pipeline, disable=True):
    """Write ``disable`` into the pipeline's progress-bar configuration.

    An existing ``_progress_bar_config`` dict is updated in place; if the
    attribute is missing, it is created with ``'disable'`` as its only key.
    """
    if hasattr(pipeline, "_progress_bar_config"):
        pipeline._progress_bar_config['disable'] = disable
        return
    pipeline._progress_bar_config = {'disable': disable}

# Compress/quantize the captured graphs inside NNCF's `disable_patching()`
# context and without autograd tracking.
with disable_patching():
    with torch.no_grad():
        # Weight compression for both text encoders and the VAE encoder/decoder.
        # The return values are discarded, so this relies on the models being
        # modified in place — NOTE(review): confirm for the installed NNCF branch.
        nncf.compress_weights(pipe.text_encoder)
        nncf.compress_weights(pipe.text_encoder_2)
        nncf.compress_weights(pipe.vae.encoder)
        nncf.compress_weights(pipe.vae.decoder)
        # Quantize the UNet with the collected calibration samples. All samples
        # are used (subset_size equals the dataset length) and the node names
        # gathered in `unet_ignored_scope` are excluded.
        quantized_unet = nncf.quantize( #2
            model=pipe.unet,
            calibration_dataset=nncf.Dataset(unet_calibration_data),
            subset_size=len(unet_calibration_data),
            model_type=nncf.ModelType.TRANSFORMER,
            ignored_scope=nncf.IgnoredScope(names=unet_ignored_scope),
            # NOTE(review): `AdvancedSmoothQuantParameters` is imported in this cell but
            # unused; `smooth_quant_alpha=-1` presumably disables SmoothQuant — confirm.
            advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alpha=-1, disable_bias_correction=True, weights_range_estimator_params=RangeEstimatorParametersSet.MINMAX, activations_range_estimator_params=RangeEstimatorParametersSet.MINMAX)
        )

### Compile Models with OV Backend

In [17]:
# Wrap every sub-model with torch.compile using the 'openvino' backend; the
# quantized UNet replaces the pipeline's UNet here.
# NOTE(review): the 'openvino' backend has to be registered with torch.compile;
# some OpenVINO releases need `import openvino.torch` first — confirm for the
# installed version.
pipe.unet = torch.compile(quantized_unet, backend='openvino')
pipe.text_encoder = torch.compile(pipe.text_encoder, backend='openvino')
pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, backend='openvino')
pipe.vae.encoder = torch.compile(pipe.vae.encoder, backend='openvino')
pipe.vae.decoder = torch.compile(pipe.vae.decoder, backend='openvino')

In [None]:
def get_model_size(models):
    """Return the combined size of ``models`` in mebibytes.

    A model's size is the byte size (element count times element size) of
    every tensor yielded by its ``parameters()`` and ``buffers()``, divided
    by ``1024**2``. Returns ``0`` when ``models`` is empty.
    """
    def _tensor_bytes(tensors):
        # Integer byte count of a tensor iterable.
        return sum(tensor.nelement() * tensor.element_size() for tensor in tensors)

    size_mb = 0
    for module in models:
        n_bytes = _tensor_bytes(module.parameters()) + _tensor_bytes(module.buffers())
        size_mb += n_bytes / 1024**2
    return size_mb
# Printed values are in MiB (get_model_size divides the byte count by 1024**2).
print("Unet Model Size: ", get_model_size([pipe.unet]))
print("Pipeline Models Size: ", get_model_size([pipe.unet, pipe.vae.encoder, pipe.vae.decoder, pipe.text_encoder, pipe.text_encoder_2]))

### Inference for Compilation

In [None]:
# Warm up the model: a single-step run that triggers the initial compilation.
prompt = "cute cat 4k, high-res, masterpiece, best quality, soft lighting, dynamic angle"
# prompt = "valley in the Alps at sunset, epic vista, beautiful landscape, 4k, 8k"
negative_prompt = "frames, borderline, text, charachter, duplicate, error, out of frame, watermark, low quality, ugly, deformed, blur"
# One denoising step is enough here; the resulting image is overwritten by the next cell.
num_steps = 1
with torch.no_grad():
    image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_steps, generator=generator).images[0]

In [None]:
# Re-seed the global RNGs so the final image is reproducible.
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)
# The pipeline call below is given `generator`, whose state has been advanced
# by any earlier pipeline call that used it. Reset it as well; re-seeding the
# global RNGs alone does not affect it.
generator.manual_seed(42)
with torch.no_grad():
    image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=25, generator=generator).images[0]
# Display the result and store it next to the notebook.
image.show()
image.save("SDXL.png")