In [None]:
import re

import matplotlib.pyplot as plt
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel

import src.hooked_model.scheduler
from src.hooked_model.hooked_model import HookedDiffusionModel
from src.hooked_model.hooks import AblateHook
from src.hooked_model.utils import (
    get_timesteps,
)


In [None]:
# Hugging Face Hub repository id of the Stable Diffusion v1.5 checkpoint to load.
model_name: str = "sd-legacy/stable-diffusion-v1-5"


### How to register an ablation hook and use it during inference

In [None]:
# Load the Stable Diffusion pipeline in float16 from safetensors weights,
# then move every component onto the GPU.
pipe = DiffusionPipeline.from_pretrained(
    model_name, torch_dtype=torch.float16, use_safetensors=True
)
pipe = pipe.to("cuda")


In [None]:
# Reuse the UNet the pipeline already loaded (same checkpoint, "unet"
# subfolder, float16, on CUDA) instead of loading a second identical copy,
# which would roughly double the UNet's GPU memory footprint. The pipeline's
# own forward pass (`pipe(...)`) is never invoked in this notebook, so sharing
# the module with the hooked model is safe.
# NOTE(review): if a hook ever mutates weights in place, load a separate copy
# again so `pipe` stays pristine.
model = pipe.unet


In [None]:
# Use the project's DDIMScheduler, built from the pipeline's scheduler config
# so timestep/beta settings match the loaded checkpoint.
scheduler = src.hooked_model.scheduler.DDIMScheduler.from_config(
    pipe.scheduler.config
)


In [None]:
# Wrap the UNet for hooked inference: the denoiser and its scheduler, the
# project's timestep helper, and the prompt encoder and VAE taken from `pipe`.
hooked_model = HookedDiffusionModel(
    # denoising network + sampling schedule
    model=model,
    scheduler=scheduler,
    # components borrowed from the pipeline
    encode_prompt=pipe.encode_prompt,
    # project helper for deriving the timesteps to run
    get_timesteps=get_timesteps,
    vae=pipe.vae,
)


In [None]:
# Collect the names of every `<...>.attentions.<index>` submodule; each one is
# ablated separately below. Enumerate `model` (the UNet passed to
# HookedDiffusionModel) rather than `pipe.unet`, so the names are guaranteed
# to resolve against the module the hooks are attached to.
pattern = re.compile(r".*\.attentions\.(\d+)$")
hookpoints = [
    name for name, _ in model.named_modules() if pattern.match(name)
]
# Fail here with a clear message instead of later inside plt.subplots(0, ...).
assert hookpoints, "No attention submodules matched `pattern`; check the regex."

for name in hookpoints:
    print(name)


In [None]:
# The same prompt repeated four times, so each hooked run receives a batch of four.
prompts = ["A photo of an astronaut in Van Gogh style"] * 4

In [None]:
# Sampling settings shared by every ablation run.
SEED = 1
NUM_INFERENCE_STEPS = 50
GUIDANCE_SCALE = 7.5

# Run the full prompt batch once per hookpoint with an AblateHook on that
# hookpoint. A fresh generator with the same seed is created for every run,
# presumably so all runs start from identical initial noise and differ only
# in which module was ablated.
all_images = []
# Pure inference: disabling autograd avoids building a graph over the 50
# denoising steps (harmless if run_with_hooks already does this internally,
# which is not visible here).
with torch.no_grad():
    for hookpoint in hookpoints:
        batch_images = hooked_model.run_with_hooks(
            {hookpoint: AblateHook()},
            prompt=prompts,
            num_inference_steps=NUM_INFERENCE_STEPS,
            guidance_scale=GUIDANCE_SCALE,
            generator=torch.Generator(device="cuda").manual_seed(SEED),
        )
        # One entry per hookpoint, in the same order as `hookpoints`.
        all_images.append(batch_images)

In [None]:
def display_images(all_images, hookpoints, images_per_row=4):
    """Plot a grid with one row of images per hookpoint.

    Args:
        all_images: Sequence with one entry per row; entry ``i`` is an iterable
            of images to draw in row ``i``. Each image must be accepted by
            ``Axes.imshow`` (presumably the PIL images returned by
            ``run_with_hooks`` — TODO confirm).
        hookpoints: Row labels; ``hookpoints[i]`` is used as the title of the
            first cell in row ``i``. Must have at least ``len(all_images)``
            entries.
        images_per_row: Number of grid columns. Images beyond this count in a
            row are not drawn; cells a short row leaves empty are blanked.

    Returns:
        None. The figure is rendered via ``plt.show()``. Nothing is drawn when
        ``all_images`` is empty.
    """
    rows = len(all_images)
    if rows == 0:
        # plt.subplots(0, ...) raises ValueError; there is nothing to show anyway.
        return

    # squeeze=False guarantees a 2-D array of Axes, so indexing is uniform
    # even for a single row and/or a single column.
    fig, axes = plt.subplots(
        rows,
        images_per_row,
        figsize=(images_per_row * 3, rows * 3),
        squeeze=False,
    )

    for row_idx, (row_axes, row_images) in enumerate(zip(axes, all_images)):
        # Hide ticks/frames on every cell first, including any left empty.
        for ax in row_axes:
            ax.axis("off")
        # zip truncates to `images_per_row` columns.
        for ax, image in zip(row_axes, row_images):
            ax.imshow(image)
        row_axes[0].set_title(hookpoints[row_idx])

    fig.tight_layout()
    plt.show()


# One grid row per ablated hookpoint, four images per row.
display_images(all_images, hookpoints, images_per_row=4)
