In [None]:
import os
# Must be set before torch initializes CUDA so only GPU 0 is visible to this kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
from visualization_utils import visualize_tokens_attentions
from FLUX.flux_pipeline import AttentionFluxPipeline
from FLUX.flux_transformer import FluxTransformer2DModel

# Which FLUX variant to load; must be a key of FLUX_IDS.
FLUX_TYPE = "dev"

# Hugging Face Hub repository id for each supported FLUX variant.
FLUX_IDS = {
    "dev": "black-forest-labs/FLUX.1-dev",
    "schnell": "black-forest-labs/FLUX.1-schnell",
}
# Fail fast with a clear message instead of a NameError on FLUX_ID further down.
if FLUX_TYPE not in FLUX_IDS:
    raise ValueError(f"Unknown FLUX_TYPE {FLUX_TYPE!r}; expected one of {sorted(FLUX_IDS)}")
FLUX_ID = FLUX_IDS[FLUX_TYPE]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the project's FluxTransformer2DModel explicitly and inject it into the
# project's AttentionFluxPipeline (presumably so attention maps can be recorded
# into pipe.attention_store, which later cells read — confirm in FLUX/).
transformer = FluxTransformer2DModel.from_pretrained(
    FLUX_ID,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
pipe = AttentionFluxPipeline.from_pretrained(
    FLUX_ID,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to(device)
# On GPUs with limited memory, use this instead of pipe.to(device):
# pipe.enable_model_cpu_offload()

In [None]:
prompt = "a photo of an astronaut riding a horse on mars"
# Number of T5 tokens the prompt is padded/truncated to during generation.
# The visualization cells label attention rows using this same length.
PROMPT_LENGTH = 77
seed = 2

# Sampling settings that differ between FLUX variants; everything else is shared.
SAMPLING_SETTINGS = {
    "dev": {"guidance_scale": 3.5, "num_inference_steps": 30},
    "schnell": {"guidance_scale": 0.0, "num_inference_steps": 4},
}
# Without this check an unknown FLUX_TYPE would leave `image` undefined and
# fail later with a NameError at display().
if FLUX_TYPE not in SAMPLING_SETTINGS:
    raise ValueError(f"Unknown FLUX_TYPE {FLUX_TYPE!r}; expected one of {sorted(SAMPLING_SETTINGS)}")

image = pipe(
    prompt=prompt,
    height=1024,
    width=1024,
    # Seeded generator on the pipeline's device for a reproducible image.
    generator=torch.Generator(device).manual_seed(seed),
    max_sequence_length=PROMPT_LENGTH,
    **SAMPLING_SETTINGS[FLUX_TYPE],
).images[0]

display(image)

# Show per-token attention heatmaps

In [None]:
# Aggregated attention recorded by the pipeline during the last generation.
# Assumed layout: (batch, token, ...spatial) — TODO confirm in FLUX/flux_pipeline.
agg_attn = pipe.attention_store.aggregate_attention().float().cpu()

# Tokenize exactly as the prompt was encoded for generation: padded/truncated to
# PROMPT_LENGTH (mirrors diffusers' FluxPipeline T5 encoding — confirm that
# AttentionFluxPipeline does the same). Without max_length, padding="max_length"
# pads to the tokenizer's own model_max_length, so labels would not line up with
# the PROMPT_LENGTH attention rows for long prompts.
token_ids = pipe.tokenizer_2(
    prompt,
    padding="max_length",
    max_length=PROMPT_LENGTH,
    truncation=True,
    return_tensors="pt",
).input_ids[0]
# Suffix each token with its position so repeated/padding tokens stay distinguishable.
tokens_text = [f"{pipe.tokenizer_2.decode(token_id)}_{i}" for i, token_id in enumerate(token_ids)]

# Half-open range [start, stop) of token positions to visualize; keep it within
# [0, PROMPT_LENGTH) — e.g. (490, 512) only makes sense when PROMPT_LENGTH == 512.
idx_range = (0, 20)
start, stop = idx_range
visualize_tokens_attentions(
    agg_attn[0, start:stop],
    tokens_text[start:stop],
    image,
    heatmap_interpolation="bilinear",
)

# Show attention distribution over prompt tokens

In [None]:
import matplotlib.pyplot as plt

# Aggregated attention recorded by the pipeline during the last generation.
# Assumed layout: (batch, PROMPT_LENGTH, ...spatial) — TODO confirm in FLUX/flux_pipeline.
agg_attn = pipe.attention_store.aggregate_attention().float().cpu()
print(agg_attn.shape)

# Mean attention per token over the remaining dims of the first batch element.
# Indexing batch 0 explicitly avoids mixing batch elements in the reshape.
# After normalization the values equal each token's share of the summed
# attention (mean and sum differ only by a constant factor), hence the name.
sum_attn_per_token = agg_attn[0].reshape(PROMPT_LENGTH, -1).mean(dim=1)
sum_attn_per_token = sum_attn_per_token / sum_attn_per_token.sum()

# Tokenize exactly as the prompt was encoded for generation (padded/truncated to
# PROMPT_LENGTH), so token i labels attention row i. Slicing a longer padding
# would drop the end-of-sequence token for long prompts instead of truncating.
token_ids = pipe.tokenizer_2(
    prompt,
    padding="max_length",
    max_length=PROMPT_LENGTH,
    truncation=True,
    return_tensors="pt",
).input_ids[0]
tokens_text = [pipe.tokenizer_2.decode(token_id) for token_id in token_ids]

# Bar label "<token>_<position>" -> normalized attention share (plain float).
attn_per_token = {f"{t}_{i}": share for i, (t, share) in enumerate(zip(tokens_text, sum_attn_per_token.tolist()))}

# Width scales with the number of bars so rotated labels stay readable.
fig, ax = plt.subplots(figsize=(max(12, 0.5 * len(attn_per_token)), 8))
ax.bar(list(attn_per_token.keys()), list(attn_per_token.values()))
ax.tick_params(axis="x", labelrotation=90)
ax.set(
    title="Normalized attention per prompt token",
    xlabel="token_position",
    ylabel="share of attention",
)
plt.show()