# Stable Diffusion v2 Demo with Torch Compile

## Stable Diffusion v2 for Text-to-Image Generation

To start, let's look at the Text-to-Image process for Stable Diffusion v2. We will use the [Stable Diffusion v2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) model for this purpose. The main difference between Stable Diffusion v2 and Stable Diffusion v2.1 is the use of more data, more training, and less restrictive filtering of the dataset, which gives promising results for a wide range of input text prompts. More details about the model can be found in the [Stability AI blog post](https://stability.ai/blog/stablediffusion2-1-release7-dec-2022) and the original model [repository](https://github.com/Stability-AI/stablediffusion).

### Stable Diffusion in Diffusers library
To work with Stable Diffusion v2, we will use the Hugging Face [Diffusers](https://github.com/huggingface/diffusers) library. To experiment with Stable Diffusion models, Diffusers exposes the [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/using-diffusers/conditional_image_generation), similar to the [other Diffusers pipelines](https://huggingface.co/docs/diffusers/api/pipelines/overview). The code below demonstrates how to create a `StableDiffusionPipeline` using `stable-diffusion-2-1`:

In [None]:
# %pip install -q "gradio>=4.19" "torch>=2.1"  "transformers" "nncf>=2.12.0" "datasets>=2.14.6" "opencv-python" "pillow" "peft>=0.7.0" --extra-index-url https://download.pytorch.org/whl/cpu
# %pip install -qU "openvino>=2024.3.0"
%pip install git+https://github.com/anzr299/nncf.git@174fb328dd841c61f231361f5d80e89c103f264e

In [1]:
from torch._export import capture_pre_autograd_graph
from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
import numpy as np
import gc

INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, onnx, openvino


In [2]:
from diffusers import StableDiffusion3Pipeline
import torch
import nncf
import random
# Seed torch, Python's `random` and NumPy with the same value.
torch.manual_seed(141)
random.seed(141)
np.random.seed(141)
# Dedicated CPU generator seeded with the same value.
generator = torch.Generator(device="cpu").manual_seed(141)
# Load the SD3-medium pipeline; `None` is passed for the third text encoder and
# its tokenizer (the printed pipeline config below shows both as null).
pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", text_encoder_3=None, tokenizer_3=None)
pipe.to("cpu")

  deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

StableDiffusion3Pipeline {
  "_class_name": "StableDiffusion3Pipeline",
  "_diffusers_version": "0.29.0.dev0",
  "_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
  "scheduler": [
    "diffusers",
    "FlowMatchEulerDiscreteScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModelWithProjection"
  ],
  "text_encoder_2": [
    "transformers",
    "CLIPTextModelWithProjection"
  ],
  "text_encoder_3": [
    null,
    null
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "tokenizer_2": [
    "transformers",
    "CLIPTokenizer"
  ],
  "tokenizer_3": [
    null,
    null
  ],
  "transformer": [
    "diffusers",
    "SD3Transformer2DModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}

### FP Inference

In [3]:
# Re-seed torch, Python's `random` and NumPy before building the inputs.
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)
# Gaussian noise tensor of shape (1, 16, 128, 128), created with NumPy and
# converted to a float32 CPU torch tensor.
latents = np.random.randn(1, 16, 128, 128).astype(np.float32)
latents = torch.from_numpy(latents).to("cpu")
generator = torch.Generator(device="cpu").manual_seed(42)
prompt = "valley in the Alps at sunset, epic vista, beautiful landscape, 4k, 8k"
# The floating-point reference run is commented out in this cell, so only the
# seeds, `latents`, `generator` and `prompt` are set up here.
# prompt = 'A raccoon trapped inside a glass jar full of colorful candies, the background is steamy with vivid colors'
# with torch.no_grad():
#     image = pipe(prompt=prompt, negative_prompt='', num_inference_steps=28, generator=generator, guidance_scale=5).images[0]
# image.show()

### Convert Models to Torch Fx Graph

In [4]:
# Example inputs used only for capturing the sub-models; they are deleted at
# the end of the cell.
# Text encoders: a (1, 77) int64 tensor, with hidden states requested.
text_encoder_input = torch.ones((1, 77), dtype=torch.long)
text_encoder_kwargs = {}
text_encoder_kwargs['output_hidden_states'] = True

# VAE encoder / decoder example inputs.
vae_encoder_input = torch.ones((1, 3, 128, 128))
vae_decoder_input = torch.ones((1, 16, 128, 128))

# Keyword inputs for `pipe.transformer` (named "unet" here although the
# captured module is the pipeline's transformer). The leading dimension is 2.
unet_kwargs = {}
unet_kwargs["hidden_states"] = torch.ones((2, 16, 128, 128))
unet_kwargs["timestep"] = torch.from_numpy(np.array([1,2], dtype=np.float32))
unet_kwargs["encoder_hidden_states"] = torch.ones((2, 154, 4096))
unet_kwargs["pooled_projections"] = torch.ones((2, 2048))
unet_kwargs["return_dict"] = False


# Capture every sub-model in eval mode without gradient tracking.
# NOTE(review): `disable_patching` presumably suspends NNCF's PyTorch patching
# while the graphs are captured - confirm against the NNCF documentation.
with torch.no_grad():
    with disable_patching():
        text_encoder = capture_pre_autograd_graph(pipe.text_encoder.eval(), args=(text_encoder_input,), kwargs=(text_encoder_kwargs))
        text_encoder_2 = capture_pre_autograd_graph(pipe.text_encoder_2.eval(), args=(text_encoder_input,), kwargs=(text_encoder_kwargs))
        vae_encoder = capture_pre_autograd_graph(pipe.vae.encoder.eval(), args=(vae_encoder_input,))
        vae_decoder = capture_pre_autograd_graph(pipe.vae.decoder.eval(), args=(vae_decoder_input,))
        transformer = capture_pre_autograd_graph(pipe.transformer.eval(), args=(), kwargs=(unet_kwargs))
# Drop the example inputs and trigger a garbage-collection pass.
del unet_kwargs
del vae_encoder_input
del vae_decoder_input
del text_encoder_input
del text_encoder_kwargs
gc.collect()

90955

### Collect Calibration Dataset

In [5]:
import datasets
import numpy as np
from tqdm.notebook import tqdm
from typing import Any, Dict, List
import torch
import pickle
from pathlib import Path

def disable_progress_bar(pipeline, disable=True):
    """Set the ``'disable'`` flag in ``pipeline._progress_bar_config``.

    An existing config dict is updated in place (its other keys are kept);
    when the attribute is missing, a new dict holding only the flag is
    attached to the pipeline.
    """
    if hasattr(pipeline, "_progress_bar_config"):
        pipeline._progress_bar_config['disable'] = disable
        return
    pipeline._progress_bar_config = {'disable': disable}


class UNetWrapper(torch.nn.Module):
    """Pass-through module that records a random subset of the inputs it sees.

    Each call is forwarded unchanged to the wrapped module. With probability
    0.7 (one ``np.random.rand()`` draw per call) the positional arguments,
    followed by the keyword argument values in call order, are stored as one
    tuple in ``captured_args``.
    """

    def __init__(self, unet):
        super().__init__()
        # Recorded input tuples; the caller reads and resets this list.
        self.captured_args = []
        self.unet = unet

    def forward(self, *args, **kwargs):
        keep_sample = np.random.rand() <= 0.7
        if keep_sample:
            self.captured_args.append(args + tuple(kwargs.values()))
        return self.unet(*args, **kwargs)

# Pickle file used to cache the collected calibration samples between runs.
CALIBRATION_DATASET_PATH = Path('test_sd3')

def collect_calibration_data(ov_pipe, calibration_dataset_size: int, num_inference_steps: int) -> List[Any]:
    """Return calibration samples for the pipeline's transformer.

    If ``CALIBRATION_DATASET_PATH`` already exists, the samples are loaded from
    it and the pipeline is not touched. Otherwise the transformer of
    ``ov_pipe`` is temporarily replaced by a ``UNetWrapper``, the pipeline is
    run on shuffled Conceptual Captions prompts until at least
    ``calibration_dataset_size`` input tuples have been recorded, and the
    result is written to ``CALIBRATION_DATASET_PATH``.

    Args:
        ov_pipe: Pipeline exposing ``transformer``, ``tokenizer`` and being
            callable as ``ov_pipe(prompt, num_inference_steps=...)``.
        calibration_dataset_size: Minimum number of samples to record.
        num_inference_steps: Denoising steps per prompt.

    Returns:
        A list of recorded input tuples (positional arguments followed by
        keyword argument values of each sampled transformer call). The list
        may be longer than ``calibration_dataset_size`` because all samples of
        the last prompt are kept.
    """
    if CALIBRATION_DATASET_PATH.exists():
        # NOTE: unpickling can execute arbitrary code; only reuse a cache file
        # that was produced locally by this notebook.
        with open(CALIBRATION_DATASET_PATH, "rb") as fp:
            return pickle.load(fp)

    calibration_data = []
    dataset = datasets.load_dataset("google-research-datasets/conceptual_captions", split="train", trust_remote_code=True).shuffle(seed=42)

    original_transformer = ov_pipe.transformer
    wrapped_transformer = UNetWrapper(original_transformer)

    disable_progress_bar(ov_pipe)
    # Install the recorder on the pipeline that was passed in, not on a global.
    ov_pipe.transformer = wrapped_transformer
    pbar = tqdm(total=calibration_dataset_size)
    try:
        # Run inference for data collection
        for batch in dataset:
            prompt = batch["caption"]
            print(prompt)
            # Crude filter: compares the character count of the caption with
            # the tokenizer's maximum length and skips longer captions.
            if len(prompt) > ov_pipe.tokenizer.model_max_length:
                continue
            ov_pipe(prompt, num_inference_steps=num_inference_steps)
            calibration_data.extend(wrapped_transformer.captured_args)
            wrapped_transformer.captured_args = []
            pbar.update(len(calibration_data) - pbar.n)
            if pbar.n >= calibration_dataset_size:
                break
    finally:
        # Always restore the pipeline, even when collection fails midway.
        pbar.close()
        disable_progress_bar(ov_pipe, disable=False)
        ov_pipe.transformer = original_transformer

    with open(CALIBRATION_DATASET_PATH, "wb") as fp:
        pickle.dump(calibration_data, fp)

    return calibration_data

In [22]:
# Apply NNCF weight compression to the captured text encoders and VAE halves.
# The transformer is not compressed here; it is quantized separately with
# `nncf.quantize` in a later cell.
with disable_patching():
    with torch.no_grad():
        text_encoder = nncf.compress_weights(text_encoder)
        text_encoder_2 = nncf.compress_weights(text_encoder_2)
        vae_encoder = nncf.compress_weights(vae_encoder)
        vae_decoder = nncf.compress_weights(vae_decoder)

INFO:nncf:Statistics of the bitwidth distribution:
┍━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑
│   Num bits (N) │ % all parameters (layers)   │ % ratio-defining parameters (layers)   │
┝━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥
│              8 │ 100% (75 / 75)              │ 100% (75 / 75)                         │
┕━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙


Output()

INFO:nncf:Statistics of the bitwidth distribution:
┍━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑
│   Num bits (N) │ % all parameters (layers)   │ % ratio-defining parameters (layers)   │
┝━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥
│              8 │ 100% (195 / 195)            │ 100% (195 / 195)                       │
┕━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙


Output()

INFO:nncf:Statistics of the bitwidth distribution:
┍━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑
│   Num bits (N) │ % all parameters (layers)   │ % ratio-defining parameters (layers)   │
┝━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥
│              8 │ 100% (31 / 31)              │ 100% (31 / 31)                         │
┕━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙


Output()

INFO:nncf:Statistics of the bitwidth distribution:
┍━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑
│   Num bits (N) │ % all parameters (layers)   │ % ratio-defining parameters (layers)   │
┝━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥
│              8 │ 100% (39 / 39)              │ 100% (39 / 39)                         │
┕━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙


Output()

In [6]:
# Collect (or load from the cache file) at least 210 transformer input samples,
# running 30 inference steps per prompt.
calibration_dataset_size = 210
unet_calibration_data = collect_calibration_data(pipe,
                                                    calibration_dataset_size=calibration_dataset_size,
                                                    num_inference_steps=30)

  0%|          | 0/210 [00:00<?, ?it/s]

a soldier clears the area outside the school .
a man is digging in soil and he is putting the soil into a wheelbarrow .
young couple in love embraces gently against the blue sky and white clouds
person before grooming left side view
photo : hikers on the trail
actor attends the opening night
portrait of a living grasshopper
wind turbine on a wheat field in the summer
olympic athlete competes during the first day
love the low , bare back and full tulle skirt .
automotive industry business at show


In [8]:
# Persist the collected calibration samples to the "test_sd3" pickle file.
with open("test_sd3", "wb") as fp:
            pickle.dump(unet_calibration_data, fp)

In [10]:
import nncf
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
from nncf.quantization.range_estimator import RangeEstimatorParametersSet


# Post-training quantization of the captured transformer graph with NNCF.
with disable_patching():
    with torch.no_grad():
        quantized_transformer = nncf.quantize( #2
            model=transformer,
            # Every collected sample is used: subset_size equals the dataset length.
            calibration_dataset=nncf.Dataset(unet_calibration_data),
            subset_size=len(unet_calibration_data),
            model_type=nncf.ModelType.TRANSFORMER,
            # The node named 'conv2d' is excluded from quantization.
            ignored_scope=nncf.IgnoredScope(names=['conv2d']),
            # MINMAX range estimators are used for both weights and activations.
            advanced_parameters=nncf.AdvancedQuantizationParameters(weights_range_estimator_params=RangeEstimatorParametersSet.MINMAX, activations_range_estimator_params=RangeEstimatorParametersSet.MINMAX)
        )



Output()

Output()

INFO:nncf:1 ignored nodes were found by names in the NNCFGraph
INFO:nncf:Not adding activation input quantizer for operation: 6 conv2d
8 add__tensor



Output()

Output()

In [11]:
from torch.ao.quantization.fx.utils import create_getattr_from_value
from copy import copy
def update_constant(model: torch.fx.GraphModule, node: torch.fx.Node, value: torch.Tensor, input_port_id: int = 1):
    """Replace the constant feeding ``node`` at ``input_port_id`` with ``value``.

    A new ``get_attr`` node holding ``value`` is created right before ``node``
    (named ``<node.name>_updated_constant`` plus whatever suffix
    ``create_getattr_from_value`` appends). Every user of the previous
    constant node, not only ``node``, is rewired to the new one, after which
    the previous constant node is erased from the graph.

    Args:
        model: Graph module that owns ``node``; it is modified in place.
        node: Node whose input constant is replaced.
        value: Tensor stored as the new constant.
        input_port_id: Index into ``node.args`` of the constant to replace.

    Note:
        Dead code is eliminated from the graph, but the module is not
        recompiled here; callers must call ``model.recompile()`` (or rebuild
        the ``GraphModule``) before running it.
    """
    graph = model.graph
    previous_const = node.args[input_port_id]

    with graph.inserting_before(node):
        new_constant = create_getattr_from_value(model, graph, node.name + "_updated_constant", value)

    # Carry over the old node's metadata, then point "val" at the new tensor.
    new_constant.meta = copy(previous_const.meta)
    new_constant.meta["val"] = value

    # Snapshot the users first: rewiring mutates `previous_const.users`.
    for consumer in list(previous_const.users):
        consumer.replace_input_with(previous_const, new_constant)
    graph.erase_node(previous_const)
    graph.eliminate_dead_code()

In [12]:
import nncf.common
import nncf.common.factory
import nncf.common.graph
import nncf.common.graph.graph
from nncf.experimental.torch.fx.transformations import constant_update_transformation_builder, constant_update_fn

def create_int8_constant(model):
    """Precompute the output of every ``quantize_per_channel`` node as an int8 constant.

    For each node whose target is
    ``torch.ops.quantized_decomposed.quantize_per_channel.default`` the op is
    evaluated eagerly on the node's arguments, the result is cast to
    ``torch.int8`` and stored through ``update_constant`` in place of the
    node's input at port 0.

    Returns the (re-created) graph module.
    """
    # NOTE(review): `count` is incremented but never read.
    count = 1
    for node in model.graph.nodes:
        if node.target == torch.ops.quantized_decomposed.quantize_per_channel.default:
            # Resolve the op's arguments: Node arguments are looked up as
            # attributes of the model by their target, everything else is
            # passed through as a literal.
            input_tup = []
            for i in node.args:
                if isinstance(i, torch.fx.Node):
                    input_tup.append(getattr(model, i.target))
                else:
                    input_tup.append(i)
            # Run the quantize op eagerly and store the result as int8.
            result = node.target(*tuple(input_tup)).type(torch.int8)
            count += 1
            update_constant(model, node, result, 0)
            # NOTE(review): the module is rebuilt after every replacement while
            # the loop keeps iterating over the graph it started on; verify
            # this is intended before restructuring.
            model = torch.fx.GraphModule(model, model.graph)
            # model.graph.eliminate_dead_code()
            # model.recompile()
    return model

In [13]:
from copy import deepcopy
# Work on a deep copy so `quantized_transformer` itself is left unmodified.
comp_model = deepcopy(quantized_transformer)
comp_model = create_int8_constant(comp_model)

In [14]:
def pattern(weight, scale, zero_point, mid, low, high, dtype):
    """Subgraph to search for: per-channel quantize followed by dequantize.

    Both ops receive the same ``scale``, ``zero_point``, ``mid``, ``low``,
    ``high`` and ``dtype`` arguments; the quantize output is the dequantize
    input. The dequantized tensor is returned.
    """
    shared_args = (scale, zero_point, mid, low, high, dtype)
    return torch.ops.quantized_decomposed.dequantize_per_channel.default(
        torch.ops.quantized_decomposed.quantize_per_channel.default(weight, *shared_args),
        *shared_args,
    )

def replacement(x, scale, zero_point, mid, low, high, dtype):
    """Replacement subgraph: multiply ``x`` by ``scale`` expanded along dim 1.

    ``zero_point``, ``mid``, ``low``, ``high`` and ``dtype`` are accepted so the
    signature matches ``pattern`` but are not used.
    """
    column_scale = torch.unsqueeze(scale, 1)
    return torch.mul(x, column_scale)

def constant_compression_transformation(model: torch.fx.GraphModule, pattern, replacement):
    """Rewrite every occurrence of ``pattern`` in ``model`` with ``replacement``.

    The list of matches reported by the subgraph rewriter is printed, and the
    same (in-place modified) ``model`` object is returned.
    """
    matches = torch.fx.subgraph_rewriter.replace_pattern(model, pattern, replacement)
    print(matches)
    return model

# Replace each quantize/dequantize pair in `comp_model` by a multiply with the
# scale, then clean up and recompile the rewritten graph module.
# `compressed_model` is the same object as `comp_model` (the transformation
# returns the model it was given).
with disable_patching():
    # compressed_model = create_int8_constant(quantized_transformer)
    compressed_model = constant_compression_transformation(comp_model, pattern, replacement)
    compressed_model.graph.eliminate_dead_code()
    compressed_model.recompile()


[Match(anchor=dequantize_per_channel_default, nodes_map={dequantize_per_channel_default: dequantize_per_channel_default, quantize_per_channel_default: quantize_per_channel_default, weight: quantize_per_channel_default_updated_constant0, scale: linear_scale_0, zero_point: linear_zero_point_0, mid: 0, low: -128, high: 127, dtype: torch.int8}), Match(anchor=dequantize_per_channel_default, nodes_map={dequantize_per_channel_default: dequantize_per_channel_default_1, quantize_per_channel_default: quantize_per_channel_default_1, weight: quantize_per_channel_default_1_updated_constant0, scale: linear_1_scale_0, zero_point: linear_1_zero_point_0, mid: 0, low: -128, high: 127, dtype: torch.int8}), Match(anchor=dequantize_per_channel_default, nodes_map={dequantize_per_channel_default: dequantize_per_channel_default_2, quantize_per_channel_default: quantize_per_channel_default_2, weight: quantize_per_channel_default_2_updated_constant0, scale: linear_2_scale_0, zero_point: linear_2_zero_point_0, m

In [23]:
# Put the compressed graphs back into the pipeline, each wrapped with
# `torch.compile` using the 'openvino' backend.
pipe.text_encoder = torch.compile(text_encoder, backend='openvino')
pipe.text_encoder_2 = torch.compile(text_encoder_2, backend='openvino')
pipe.vae.encoder = torch.compile(vae_encoder, backend='openvino')
pipe.vae.decoder = torch.compile(vae_decoder, backend='openvino')
pipe.transformer = torch.compile(compressed_model, backend='openvino')

### Inference for Compilation

In [25]:
# Warm up the model: a single-step run so the initial compilation happens here.
prompt = "valley in the Alps at sunset, epic vista, beautiful landscape, 4k, 8k"
# NOTE(review): `negative_prompt` is defined but the call below passes an empty
# string instead - confirm which one is intended.
negative_prompt = "frames, borderline, text, charachter, duplicate, error, out of frame, watermark, low quality, ugly, deformed, blur"
num_steps = 1
# `generator` is the one created in an earlier cell.
with torch.no_grad():
    image = pipe(prompt=prompt, negative_prompt='', num_inference_steps=num_steps, generator=generator).images[0]

  0%|          | 0/1 [00:00<?, ?it/s]

In [27]:
# Re-seed torch, Python's `random` and NumPy, then build a fresh generator and
# explicit starting latents.
seed = 42
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
generator = torch.Generator(device="cpu").manual_seed(seed)
# Gaussian noise of shape (1, 16, 128, 128) as a float32 CPU tensor, passed to
# the pipeline through `latents=`.
latents = np.random.randn(1, 16, 128, 128).astype(np.float32)
latents = torch.from_numpy(latents).to("cpu")
prompt = "valley in the Alps at sunset, epic vista, beautiful landscape, 4k, 8k"
# prompt = 'A raccoon trapped inside a glass jar full of colorful candies, the background is steamy with vivid colors'
# 28-step generation with guidance scale 5 using the compiled, compressed models.
with torch.no_grad():
    image = pipe(prompt=prompt, negative_prompt='', num_inference_steps=28, generator=generator, guidance_scale=5, latents=latents).images[0]
# Display the result and save it next to the notebook.
image.show()
image.save("stable_diffusion_3_unet_compressed_compiled_prompt1.png")

  0%|          | 0/28 [00:00<?, ?it/s]

In [33]:
from nncf.common.factory import NNCFGraphFactory
# Build an NNCF graph from the quantized transformer and write it to 'graph.dot'.
NNCFGraphFactory.create(quantized_transformer).visualize_graph('graph.dot')

In [None]:
# Debug cell: print dtype and shape of the constant named
# 'linear_13_updated_constant0' in `comp_model`.
# NOTE(review): `count` is never incremented, so the final print always shows 0.
count = 0
for node in comp_model.graph.nodes:
    if node.op == 'get_attr':
        node_value = getattr(comp_model, node.target)
        if node.name == 'linear_13_updated_constant0':
            print(node_value.dtype, node_value.shape)

print(count)

In [24]:
def get_model_size(models):
    """Return the combined size of ``models`` in MiB.

    For every model, the byte sizes (``nelement() * element_size()``) of all
    parameters and all buffers are added up and divided by ``1024**2``; the
    per-model values are then summed.

    Args:
        models: Iterable of objects exposing ``parameters()`` and ``buffers()``.

    Returns:
        Total size in MiB; ``0`` for an empty iterable.
    """
    total_size = 0
    for model in models:
        param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
        buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
        total_size += (param_size + buffer_size) / 1024**2
    return total_size
# Report sizes in MiB: first the transformer alone, then all five pipeline
# sub-models together. The last call is a bare expression, so its value is
# shown as the cell's result rather than printed.
print("Transformer Size:")
print(get_model_size([pipe.transformer]))
print("Pipeline Size:")
get_model_size([pipe.transformer, pipe.vae.encoder, pipe.vae.decoder, pipe.text_encoder, pipe.text_encoder_2])

Transformer Size:
2162.203857421875
Pipeline Size:


3026.265887260437