In [None]:
import os
# Restrict this process to GPU 0. This is set before `import torch` below so the
# restriction is already in place when CUDA gets initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
from visualization_utils import visualize_tokens_attentions
from FLUX.flux_pipeline import AttentionFluxPipeline
from FLUX.flux_transformer import FluxTransformer2DModel

# Which FLUX.1 variant to load; must be one of the keys of FLUX_IDS below.
FLUX_TYPE = "dev"

# Hugging Face model ids for the supported variants.
FLUX_IDS = {
    "dev": "black-forest-labs/FLUX.1-dev",
    "schnell": "black-forest-labs/FLUX.1-schnell",
}

# Fail immediately with a clear message instead of leaving FLUX_ID undefined
# (which previously surfaced later as an unrelated NameError).
if FLUX_TYPE not in FLUX_IDS:
    raise ValueError(
        f"Unsupported FLUX_TYPE {FLUX_TYPE!r}; expected one of {sorted(FLUX_IDS)}"
    )
FLUX_ID = FLUX_IDS[FLUX_TYPE]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the project's transformer separately and inject it into the pipeline so
# the pipeline uses this transformer instance. Both are loaded in bfloat16.
transformer = FluxTransformer2DModel.from_pretrained(
    FLUX_ID,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
pipe = AttentionFluxPipeline.from_pretrained(
    FLUX_ID,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to(device)
# Left disabled by the original author; presumably an alternative to
# `pipe.to(device)` when GPU memory is tight -- TODO confirm before enabling.
# pipe.enable_model_cpu_offload()

In [None]:
# Display dependencies: imported once per cell run rather than on every
# iteration of the nested generation loop below.
from PIL import Image
import matplotlib.pyplot as plt

# Alternative prompt sets. Exactly one `prompt = ...` line should be active;
# every entry in the list is generated in the same batch.
# prompt = ["a photo of an astronaut in the beach", "a photo of an astronaut riding a horse in the snow", "a photo of an astronaut in the forest"]
# prompt = ["a photo of a cat in a room", "a photo of a cat in the forest", "a photo of a cat in the snow", "a photo of a cat wearing a hat"]
# prompt = ["a photo of a girl in a room", "a photo of a girl in the forest", "a photo of a girl in the snow", "a photo of a girl wearing a hat", "a photo of a girl in the desert"]
# prompt = ["a photo of an girl in the beach", "a photo of an girl riding a horse in the snow", "a photo of an girl in the forest", "a photo of an girl playing with a cat"]
# prompt = ["a photo of a dog in the beach", "a photo of a dog riding bicycles in the snow", "a photo of a dog in the forest", "a photo of a dog playing with a cat"]
prompt = ["origami style of A dragon, atop an ancient castle", "origami style of A dragon, breathing a plume of fire", "origami style of A dragon, coiling around its hoard", "origami style of A dragon, guarding a treasure trove", "origami style of A dragon, resting in a cavern"]
# prompt = ["A photo of A dog, chasing a frisbee", "A photo of A dog, on a beach", "A photo of A dog, sleeping on a porch", "A photo of A dog, dressed in a superhero cape", "A photo of A dog, sitting by a fireplace"]
# prompt = ["A photo of a robot, in a room", "A photo of a robot, in a forest", "A photo of a robot, in a snow", "A photo of a robot, in a desert", "A photo of a robot, in a cave"]
# prompt = ["A photo of a man in the kitchen", "A photo of a man and a woman sitting in the living room", "A photo of a man and a woman in the garden", "A photo of a woman in the bathroom"]


# Passed to the pipeline as `max_sequence_length`; the later analysis cells
# also use it as the number of text tokens.
PROMPT_LENGTH = 512
seed = 2

# Forwarded unchanged to the pipeline's `extended_attn_kwargs` argument.
extended_attn_kwargs = {'t_range': [(0, 25)]}

# Config names listed by the original author (NOTE(review): the active pair
# below uses "even" / "second_half", so this list may predate the current
# `layers_extended_config` format -- confirm against the pipeline):
#["none",  "multi", "multi_even",  "multi_first_half", "multi_second_half", "q1", "q2", "q3", "q4", "single", "single_even",  "single_first_half", "single_second_half", "mix"]:
for single_config, multi_config in [("even", "second_half")]:
    for dropout in [0., 0.25, 0.5]:
        images = pipe(
            prompt=prompt,
            guidance_scale=3.5,
            height=1024,
            width=1024,
            num_inference_steps=30,
            extended_attn_kwargs=extended_attn_kwargs,
            layers_extended_config={'single': single_config, 'multi': multi_config},
            # A fresh generator with the same seed for every run, so each
            # dropout setting starts from the same random state.
            generator=torch.Generator(device).manual_seed(seed),
            max_sequence_length=PROMPT_LENGTH,
            dropout_value=dropout,
            same_latents=False,
        ).images

        # One row of subplots, one column per generated image.
        # `squeeze=False` always yields a 2-D axes array, so a single-image
        # batch needs no special casing.
        fig, axes = plt.subplots(1, len(images), figsize=(20, 5), squeeze=False)
        fig.suptitle(f'Layer Config: Single - {single_config}, Multi - {multi_config}, dropout - {dropout}', y=1.05, fontsize=14)

        for idx, (ax, img) in enumerate(zip(axes[0], images), start=1):
            ax.imshow(img)
            ax.axis('off')
            ax.set_title(f'Image {idx}')

        plt.tight_layout()
        plt.show()
        # Release the figure explicitly so repeated runs in the loop do not
        # accumulate open figures.
        plt.close(fig)


# Show Heatmap

In [None]:
# Per-token attention heatmaps for the first prompt.
# Relies on `pipe`, `prompt`, `PROMPT_LENGTH` and `images` from the cells above;
# `images` (and presumably the attention store) reflect the most recent
# `pipe(...)` call -- TODO confirm the store is not accumulated across runs.
agg_attn = pipe.attention_store.aggregate_attention().float().cpu()

# Tokenise with the same length limit the pipeline was given via
# `max_sequence_length`, so token positions line up with the attention maps.
first_prompt_ids = pipe.tokenizer_2(
    prompt,
    padding="max_length",
    max_length=PROMPT_LENGTH,
    truncation=True,
    return_tensors="pt",
).input_ids[0]
# Suffix each decoded token with its position so repeated tokens (e.g. padding)
# remain distinguishable in the plot.
tokens_text = [f"{pipe.tokenizer_2.decode(tok)}_{i}" for i, tok in enumerate(first_prompt_ids)]

# Token window to visualise (start inclusive, stop exclusive).
# idx_range = (490, 512)
idx_range = (0, 20)
start, stop = idx_range

# The original code passed an undefined name `image`, raising NameError on a
# fresh kernel. Tokens and attention are both taken for entry 0, so the
# matching generated image is `images[0]`.
visualize_tokens_attentions(
    agg_attn[0, start:stop],
    tokens_text[start:stop],
    images[0],
    heatmap_interpolation="bilinear",
)

# Show Distribution

In [None]:
import matplotlib.pyplot as plt

# Share of aggregated attention received by each text token of the first prompt.
agg_attn = pipe.attention_store.aggregate_attention().float().cpu()
print(agg_attn.shape)

# The heatmap cell indexes this tensor as `agg_attn[0, token_range]`, i.e. a
# leading per-prompt dimension followed by the token dimension. Select entry 0
# before flattening so each row is exactly one token of the first prompt;
# flattening the whole tensor would mix prompts whenever that leading
# dimension is larger than 1. NOTE(review): confirm the layout against
# `aggregate_attention`.
first_prompt_attn = agg_attn[0].reshape(PROMPT_LENGTH, -1)
sum_attn_per_token = first_prompt_attn.mean(dim=1)
# Normalise so the bars sum to 1.
sum_attn_per_token = sum_attn_per_token / sum_attn_per_token.sum()

# Tokenise with the same length limit the pipeline was given via
# `max_sequence_length`, so there is one label per attention row.
first_prompt_ids = pipe.tokenizer_2(
    prompt,
    padding="max_length",
    max_length=PROMPT_LENGTH,
    truncation=True,
    return_tensors="pt",
).input_ids[0]
tokens_text = [pipe.tokenizer_2.decode(tok) for tok in first_prompt_ids][:PROMPT_LENGTH]

# Position-suffixed labels keep repeated tokens (e.g. padding) as separate bars.
bar_labels = [f'{t}_{i}' for i, t in enumerate(tokens_text)]
# Plain Python floats, truncated to the number of labels (mirrors the original
# per-label lookup), rather than 0-dim tensors.
bar_heights = sum_attn_per_token[:len(bar_labels)].tolist()

# Show a bar plot of the attention per token
plt.figure(figsize=(100, 30))
plt.bar(bar_labels, bar_heights)
plt.title('Normalized mean attention per token (first prompt)')
plt.ylabel('Share of attention')
plt.xticks(rotation=90)
plt.show()