In [None]:
import torch
from SDXL.sdxl_pipeline import AttentionStableDiffusionXLPipeline
from visualization_utils import show_image_and_heatmap, visualize_tokens_attentions

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Half precision is only requested on CUDA. On the CPU fallback the same fp16
# checkpoint variant is loaded but cast to float32, since float16 pipelines are
# not expected to run on CPU.
# NOTE(review): assumes AttentionStableDiffusionXLPipeline.from_pretrained forwards
# `torch_dtype`/`variant` like the diffusers base pipeline does - confirm.
torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
pipe = AttentionStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch_dtype, variant="fp16"
)
pipe = pipe.to(device)

In [None]:
# Generation settings: the seed feeds the generator passed to the pipeline.
seed = 2
prompt = "a photo of an astronaut riding a horse on mars"
# Alternative prompt to try: "a puppy sitting on a chair"

# Build the seeded generator on the same device the pipeline was moved to.
generator = torch.Generator(device=device)
generator.manual_seed(seed)

# Run the pipeline for 50 inference steps and keep the first returned image.
result = pipe(prompt, generator=generator, num_inference_steps=50)
image = result.images[0]

display(image)

# Show per-token attention heatmaps

In [None]:
# Aggregated attention maps taken from the pipeline's attention store.
agg_attn = pipe.attention_store.aggregate_attention()

# Tokenize the prompt (padded/truncated to the tokenizer's maximum length) and
# decode every id on its own so there is one text label per token position.
encoded_prompt = pipe.tokenizer(
    prompt,
    padding="max_length",
    max_length=pipe.tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
tokens_ids = encoded_prompt.input_ids[0]
tokens_text = [pipe.tokenizer.decode(token_id) for token_id in tokens_ids]

# Half-open range of token positions to visualize.
idx_range = (0, 20)
first_idx, last_idx = idx_range

# Move the last axis to position 1, take entry 0 along the first axis, and keep
# only the selected token positions.
# NOTE(review): presumably agg_attn is (batch, height, width, tokens) - confirm.
token_maps = agg_attn.permute(0, 3, 1, 2)[0, first_idx:last_idx]
token_labels = tokens_text[first_idx:last_idx]
visualize_tokens_attentions(token_maps, token_labels, image, heatmap_interpolation="bilinear")

# Show distribution of attention across tokens

In [None]:
import matplotlib.pyplot as plt


agg_attn = pipe.attention_store.aggregate_attention()

# Collapse every leading axis so each row is one attention vector over token
# positions. The token count is read from the tensor itself instead of being
# hard-coded, and reshape (unlike view) also accepts non-contiguous tensors.
num_tokens = agg_attn.shape[-1]
sum_attn_per_token = agg_attn.reshape(-1, num_tokens).mean(dim=0).cpu().float()
# Normalize so the per-token values sum to 1.
sum_attn_per_token = sum_attn_per_token / sum_attn_per_token.sum()

# Tokenize the prompt (padded/truncated to the tokenizer's maximum length) and
# decode each id separately to get one label per token position.
tokens_ids = pipe.tokenizer(prompt, padding="max_length",
                            max_length=pipe.tokenizer.model_max_length,
                            truncation=True, return_tensors="pt").input_ids[0]
tokens_text = [pipe.tokenizer.decode(x) for x in tokens_ids]

# Map "<token>_<position>" to its attention share. The position suffix keeps
# repeated tokens (e.g. padding) as distinct bars; values are plain floats so
# matplotlib does not have to convert 0-d tensors.
attn_per_token = {f'{t}_{i}': sum_attn_per_token[i].item() for i, t in enumerate(tokens_text)}

# Show a bar plot of the attention per token
plt.figure(figsize=(25, 5))
plt.bar(list(attn_per_token.keys()), list(attn_per_token.values()))
plt.xticks(rotation=90)
plt.xlabel("token_position")
plt.ylabel("normalized mean attention")
plt.title("Share of aggregated attention per prompt token")
plt.show()