In [None]:
import matplotlib.pyplot as plt
import torch
from diffusers import (
    AutoencoderKL,
    DiffusionPipeline,
)

from src.hooked_model.hooked_model_sd3 import HookedDiffusionModel
from src.hooked_model.hooks import AblateHook


In [None]:
# Model identifier passed to both `from_pretrained` calls below (pipeline and VAE).
model_name = "stabilityai/stable-diffusion-3-medium-diffusers"


### How to register an ablation hook and use it during inference

In [None]:
# Stand-alone VAE, loaded from the checkpoint's "vae" subfolder in half precision.
# NOTE(review): unlike the pipeline below, the VAE is never moved to CUDA in this
# notebook - confirm that HookedDiffusionModel handles device placement itself.
vae = AutoencoderKL.from_pretrained(
    model_name, subfolder="vae", torch_dtype=torch.float16
)

# Pipeline loaded without its third text encoder/tokenizer and without its own VAE
# (the separately loaded `vae` above is handed to the hooked model instead).
pipe = DiffusionPipeline.from_pretrained(
    model_name,
    vae=None,
    text_encoder_3=None,
    tokenizer_3=None,
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# The transformer is the module whose attention layers get hooked later on.
model = pipe.transformer


In [None]:
# Bundle the transformer, the pipeline's scheduler, the stand-alone VAE and the
# pipeline's prompt encoder into a HookedDiffusionModel; its `run_with_hooks`
# method is what the ablation loop further down calls.
hooked_model = HookedDiffusionModel(
    model=model,
    scheduler=pipe.scheduler,
    vae=vae,
    encode_prompt=pipe.encode_prompt,
)


In [None]:
# Display the attention module of the first transformer block (bare expression, so
# the notebook renders its repr). Presumably used to look up the sub-module names
# that are ablated later - verify against the rendered output.
model.transformer_blocks[0].attn

In [None]:
import re

# Collect the qualified names of every `transformer_blocks.<idx>.attn` module.
# Both dots are escaped so that only a literal "." separates the block index from
# "attn"; an unescaped "." would accept any character in that position.
hookpoints = []
pattern = re.compile(r"transformer_blocks\.(\d+)\.attn$")
for name, _ in model.named_modules():
    # `match` anchors at the start of the name and `$` at the end, so deeper
    # sub-modules such as `...attn.to_q` are not collected.
    if pattern.match(name):
        hookpoints.append(name)
        print(name)


In [None]:
# Four copies of the same prompt, passed together as one batch to each generation run.
prompts = ["An image of cat"] * 4

In [None]:
# One entry per hookpoint: whatever `run_with_hooks` returns for the prompt batch.
all_images = []
# Names of the attention sub-modules that are appended to each hookpoint and
# given an AblateHook.
matrices_to_ablate = ["to_q", "to_k", "to_v", "add_k_proj", "add_v_proj", "add_q_proj"]

# Run one generation per attention block, ablating all listed sub-modules of that
# block at once.
for hookpoint in hookpoints:
    # Maps "<hookpoint>.<sub-module>" to a fresh AblateHook instance.
    all_hookpoints = {
        f"{hookpoint}.{matrix}": AblateHook() for matrix in matrices_to_ablate
    }
    with torch.no_grad():
        image = hooked_model.run_with_hooks(
            all_hookpoints,
            prompt=prompts,
            num_inference_steps=28,
            guidance_scale=7.0,
            # A new generator with the same seed is created for every run, so all
            # runs start from the same random state.
            generator=torch.Generator(device="cuda").manual_seed(1),
        )

    all_images.append(image)

In [None]:
def display_images(all_images, hookpoints, images_per_row=4):
    """Show a grid of images with one row per entry of ``all_images``.

    Args:
        all_images: Sequence of image batches, one batch per grid row. Each batch
            must support slicing and contain objects accepted by ``Axes.imshow``.
        hookpoints: Row labels; ``hookpoints[i]`` becomes the title of the first
            image in row ``i``.
        images_per_row: Number of grid columns. At most this many images of each
            batch are drawn; surplus images are ignored.

    Returns:
        None. The figure is rendered through ``plt.show()``. Nothing is drawn when
        ``all_images`` is empty.
    """
    rows = len(all_images)
    if rows == 0:
        # A grid with zero rows cannot be created, so there is nothing to show.
        return

    # squeeze=False keeps `axes` two-dimensional even for a single row or a single
    # column, so `axes[i, j]` is valid for every grid shape.
    _, axes = plt.subplots(
        rows,
        images_per_row,
        figsize=(images_per_row * 3, rows * 3),
        squeeze=False,
    )

    for i, row_images in enumerate(all_images):
        drawn = 0
        for j, image in enumerate(row_images[:images_per_row]):
            ax = axes[i, j]
            ax.imshow(image)
            ax.axis("off")  # Hide ticks and frame for a cleaner look
            if j == 0:
                ax.set_title(hookpoints[i])
            drawn = j + 1
        # Hide the remaining cells of a row that has fewer images than columns,
        # otherwise they would show up as empty axis frames.
        for j in range(drawn, images_per_row):
            axes[i, j].axis("off")

    # tight_layout() computes the spacing between subplots, so no manual
    # subplots_adjust() call is needed beforehand.
    plt.tight_layout()
    plt.show()


# Render the grid: one row per ablated hookpoint, titled with the hookpoint's name.
display_images(all_images, hookpoints)
