In [1]:
greeting = "Hello World"
print(greeting)  # sanity check that the kernel executes cells

Hello World


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from diffusers import StableDiffusionPipeline
from PIL import Image
import cv2

# Checkpoint choice: this smaller Stable Diffusion 2 base model replaced a larger one that proved
# too large for the available time and compute.
model_id = "stabilityai/stable-diffusion-2-base"
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float32,  # keep full-precision weights
)
pipe.to("cpu")  # all inference below runs on the CPU
pipe.enable_attention_slicing()  # compute attention in slices to reduce peak memory

Fetching 13 files: 100%|██████████| 13/13 [1:13:10<00:00, 337.76s/it]
Loading pipeline components...: 100%|██████████| 6/6 [00:03<00:00,  1.82it/s]


In [None]:
##### Due to the deadline I was only able to load the model, not test the rest of the code.
##### Initially I used a different Stable Diffusion model that turned out to be too large, and I lost a significant amount of time on it.
##### I then switched to this smaller model, but even it took nearly two hours to download and load, which left no time to test
##### the rest of the code.


# Step 2: Hook into attention
# UNet module name -> list with one tensor per forward call of that module (one per denoising
# step in the standard pipeline), each shaped (batch, heads, tokens, pixels).
attention_store = {}
# Handles of the hooks registered below. Looked up in globals() so that re-running this cell can
# still remove the hooks added by the previous run instead of stacking duplicates.
_attention_hook_handles = globals().get("_attention_hook_handles", [])


def save_cross_attention(name, average_heads=True):
    """Build a forward hook that records a cross-attention module's attention probabilities.

    diffusers' ``Attention`` modules return only the attended hidden states, never the attention
    weights, so the hook recomputes softmax(Q K^T) from the module's inputs. It uses the module's
    own ``to_q`` / ``to_k`` projections and ``get_attention_scores``. The module output is left
    untouched.

    Args:
        name: Key under which each captured map is appended in ``attention_store``.
        average_heads: If True (default), average over attention heads at capture time, keeping a
            size-1 head axis. This bounds CPU memory when recording every step of every layer.

    Returns:
        A hook to register with ``module.register_forward_hook(hook, with_kwargs=True)``.
    """
    def hook(module, args, kwargs, output):
        hidden_states = args[0] if args else kwargs.get("hidden_states")
        context = kwargs.get("encoder_hidden_states")
        if context is None and len(args) > 1:
            context = args[1]
        if hidden_states is None or context is None:
            return None  # no text context passed: not a cross-attention call, nothing to record

        with torch.no_grad():
            if hidden_states.ndim == 4:  # (batch, channels, h, w) -> (batch, pixels, channels)
                hidden_states = hidden_states.flatten(2).transpose(1, 2)
            # NOTE(review): mirrors diffusers' default processor for the plain path only; assumes
            # attn2 has no group_norm / spatial_norm (true for SD UNet transformer blocks - confirm
            # before reusing with other models).
            if getattr(module, "norm_cross", None) is not None:
                context = module.norm_encoder_hidden_states(context)
            query = module.head_to_batch_dim(module.to_q(hidden_states))
            key = module.head_to_batch_dim(module.to_k(context))
            probs = module.get_attention_scores(query, key, None)  # (batch*heads, pixels, tokens)
            batch = hidden_states.shape[0]
            # Split the fused batch*heads axis, then put tokens before pixels.
            probs = probs.reshape(batch, module.heads, *probs.shape[1:]).transpose(-1, -2)
            if average_heads:
                probs = probs.mean(dim=1, keepdim=True)
        attention_store.setdefault(name, []).append(probs.detach().cpu())
        return None  # returning None keeps the module's real output
    return hook


def register_attention_hooks(pipe, layers=None, average_heads=True):
    """Attach attention-capture hooks to the UNet's cross-attention (``attn2``) modules.

    Hooks added by a previous call are removed, and ``attention_store`` is cleared, first. Calling
    this again therefore never records each map twice.

    Args:
        pipe: Pipeline whose ``unet`` is searched with ``named_modules()``.
        layers: Optional collection of module names to hook; ``None`` hooks every ``attn2``.
        average_heads: Forwarded to ``save_cross_attention``.

    Returns:
        List of the registered hook handles (call ``.remove()`` on each to detach).
    """
    for handle in _attention_hook_handles:
        handle.remove()
    _attention_hook_handles.clear()
    attention_store.clear()

    for name, module in pipe.unet.named_modules():
        # Match the attention module itself, not its sub-layers (attn2.to_q, attn2.to_out.0, ...).
        if not name.endswith("attn2"):
            continue
        if layers is not None and name not in layers:
            continue
        # with_kwargs=True: the text context arrives as the encoder_hidden_states keyword.
        handle = module.register_forward_hook(
            save_cross_attention(name, average_heads=average_heads), with_kwargs=True
        )
        _attention_hook_handles.append(handle)
    return list(_attention_hook_handles)

# Attach the recording hooks to the UNet's `attn2` (cross-attention) modules. This must run before
# the generation cell, because the hooks only capture maps during the UNet's forward passes.
register_attention_hooks(pipe)

In [None]:

# Step 3: Generate image
prompt = "A dog playing with a ball on the beach"
SEED = 0  # fixed seed so the image (and therefore its attention maps) is reproducible

# Drop maps captured by earlier runs of this cell so the store describes exactly this generation.
attention_store.clear()
generator = torch.Generator(device="cpu").manual_seed(SEED)
output = pipe(prompt, num_inference_steps=50, guidance_scale=7.5, generator=generator)
image = output.images[0]
image.save("output.png")

# Step 4: Process attention (pick one layer)
def get_token_attention(prompt, layer="down_blocks.2.attentions.1.transformer_blocks.0.attn2",
                        step=-1, store=None, tokenizer=None):
    """Return one layer's head-averaged cross-attention map for a single recorded step.

    Args:
        prompt: Prompt used for generation; tokenized only to label the map's rows.
        layer: UNet module name whose recorded maps are read.
        step: Index into the recorded calls; -1 (default) is the last one.
        store: Mapping of layer name -> list of per-call tensors shaped
            (batch, heads, tokens, pixels). Defaults to the global ``attention_store``.
        tokenizer: Tokenizer used to label the tokens. Defaults to ``pipe.tokenizer``.

    Returns:
        ``(attention, tokens)``. ``attention`` is a (tokens, pixels) tensor for the last batch
        entry. ``tokens`` lists the prompt's tokenizer tokens; position ``i`` labels row ``i``.

    Raises:
        KeyError: if nothing was recorded for ``layer``.
    """
    if store is None:
        store = attention_store
    if tokenizer is None:
        tokenizer = pipe.tokenizer

    # Only the token strings are needed, so the ids stay on the CPU (the pipeline has no GPU).
    token_ids = tokenizer(prompt).input_ids
    tokens = tokenizer.convert_ids_to_tokens(token_ids)

    recorded = store.get(layer)
    if not recorded:
        raise KeyError(f"No attention recorded for layer {layer!r}; available layers: {sorted(store)}")

    # stack (not cat) keeps a separate step axis: (steps, batch, heads, tokens, pixels)
    attention = torch.stack(recorded, dim=0)
    attention = attention.mean(dim=2)  # average heads -> (steps, batch, tokens, pixels)

    # With guidance_scale > 1, diffusers' StableDiffusionPipeline batches [unconditional,
    # conditional] embeddings. The last batch entry is the prompt-conditioned one, and the only
    # one when guidance is off.
    step_attention = attention[step, -1]  # (tokens, pixels)
    # Padding follows the prompt tokens, so row i still labels token i. Anything beyond the
    # recorded token axis (over-long prompts) has no map.
    return step_attention, tokens[: step_attention.shape[0]]

# Step 5: Visualize heatmap
def show_attention_map(image, attn_map, token_index, token_label, alpha=0.4):
    """Overlay one token's attention map on the generated image and display it.

    Args:
        image: Generated image; ``np.array(image)`` must give an H x W x 3 uint8 RGB array
            (true for the PIL image returned by the pipeline).
        attn_map: (tokens, pixels) tensor, e.g. from ``get_token_attention``. The pixel axis must
            be a flattened square grid.
        token_index: Row of ``attn_map`` to display.
        token_label: Text shown in the figure title.
        alpha: Heatmap weight in the blend; the image gets ``1 - alpha``. The default 0.4 gives
            the 0.6 / 0.4 blend.

    Raises:
        ValueError: if the pixel count is not a perfect square.
    """
    image_np = np.array(image)
    h, w = image_np.shape[:2]

    token_attn = attn_map[token_index].detach().float().cpu().numpy()
    side = int(round(np.sqrt(token_attn.size)))
    if side * side != token_attn.size:
        raise ValueError(f"Attention map has {token_attn.size} pixels, which is not a square grid")
    # The grid side depends on the layer's resolution, so it is inferred rather than hard-coded.
    grid = token_attn.reshape(side, side)
    grid = cv2.resize(grid, (w, h))  # cv2 takes (width, height)

    peak = grid.max()
    grid = grid / peak if peak > 0 else np.zeros_like(grid)  # avoid 0/0 on an all-zero map

    heatmap = cv2.applyColorMap(np.uint8(255 * grid), cv2.COLORMAP_JET)
    # applyColorMap returns BGR, while the PIL-derived image and matplotlib use RGB.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    overlay = cv2.addWeighted(image_np, 1.0 - alpha, heatmap, alpha, 0)

    fig, ax = plt.subplots()
    ax.imshow(overlay)
    ax.set_title(f"Attention for token: {token_label}")
    ax.axis("off")
    plt.show()
    plt.close(fig)  # release the figure; this function is called once per token in a loop

# Step 6: Pick token and visualize
TARGET_WORDS = ("dog", "ball", "beach")  # words from `prompt` whose attention maps are shown

attn_map, tokens = get_token_attention(prompt)
n_shown = 0
for idx, token in enumerate(tokens):
    # The CLIP BPE tokenizer marks word-final tokens with "</w>" (e.g. "dog</w>"). Strip it,
    # otherwise no token ever equals a plain word.
    word = token.replace("</w>", "")
    if word in TARGET_WORDS:
        show_attention_map(image, attn_map, idx, word)
        n_shown += 1
if n_shown == 0:
    print(f"None of {TARGET_WORDS} found among prompt tokens: {tokens}")