In [None]:
from typing import Dict

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from keras_cv.src.models.stable_diffusion.image_encoder import ImageEncoder
from diffusion_models.stable_diffusion import StableDiffusion
from utils import process_image, augmenter

from visualisation_utils import plot_attention_location, animate_locations
from my_utils import dict_to_disk, dict_from_disk

In [None]:
# Report the GPUs TensorFlow can see, then remember the device name returned by
# `tf.test.gpu_device_name()` for the `tf.device(...)` scopes used further down.
gpus = tf.config.experimental.list_physical_devices('GPU')
print("GPUs available: ", gpus)
device = tf.test.gpu_device_name()
print(device)

# Initialize SD Model

In [None]:
# Initialize the VAE encoder and the Stable Diffusion model on `device`
with tf.device(device):
    image_encoder = ImageEncoder()
    # Functional sub-model from the encoder's input to its last layer's output
    vae = tf.keras.Model(
        inputs=image_encoder.input,
        outputs=image_encoder.layers[-1].output,
    )
    model = StableDiffusion(img_height=512, img_width=512)

# Run inference on real image

In [None]:
# Load the image, apply `augmenter`, and encode it with the VAE.
# `tf.expand_dims(..., axis=0)` adds a leading axis of size one before the
# image is handed to `vae`.
image_path = "./images/img2.jpeg"

with tf.device(device):
    image = augmenter(process_image(image_path))
    latent = vae(tf.expand_dims(image, axis=0), training=False)

In [None]:
# Nested lookup of self-attention maps, keyed first by timestep and then by
# resolution: { timestep : { resolution : self-attention map } }
self_attn_dict: Dict[int, Dict[int, np.ndarray]] = dict()

In [None]:
# Call `model.generate_image` once for each of `num_timesteps` evenly spaced
# timesteps in [0, 1000) and collect the attention weights it returns.
num_timesteps = 10

for timestep in np.arange(0, 1000, 1000 // num_timesteps):
    with tf.device(device):
        attn_weights = model.generate_image(
            batch_size=1,
            latent=latent,
            timestep=timestep,
        )
        weight_64, weight_32, weight_16, weight_8 = attn_weights

        # Reduce every map over its two leading axes and file the result under
        # its resolution for the current timestep.
        # NOTE(review): presumably axes (0, 1) are (batch, heads) — confirm.
        self_attn_dict[timestep] = {
            res: weights.mean(axis=(0, 1))
            for res, weights in (
                (8, weight_8),
                (16, weight_16),
                (32, weight_32),
                (64, weight_64),
            )
        }

In [None]:
# Persist the collected self-attention maps via `dict_to_disk`
dict_to_disk(self_attn_dict=self_attn_dict, filename="self_attn_maps/car")

### Show the image

In [None]:
# Show the input image and the output image side by side.
# NOTE(review): `input_image` and `output_image` are not assigned in any cell
# of this notebook as it stands — confirm where they should come from.
fig, axs = plt.subplots(1, 2, figsize=(16, 8))

panels = (("Input image", input_image), ("Output image", output_image))
for ax, (label, batch) in zip(axs, panels):
    ax.imshow(batch[0])  # first element of each batch
    ax.set_title(label)
    ax.axis("off")

fig.suptitle(f"Time step {timestep}", fontsize=24)
fig.tight_layout()

plt.show()

### Visualise the VAE latents

In [None]:
# Draw the first four channels of the VAE latent as grayscale panels in one row
fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(20, 5))

for i, ax in enumerate(axs):
    ax.imshow(latent[0, :, :, i], cmap="gray")
    ax.set_title(f"Channel {i+1}")
    ax.axis("off")  # hide ticks and frame

fig.suptitle("Latents")
plt.show()

In [None]:
# Change this to a value between 0 and 4095
channel64_idx = 1110

# Timestep whose stored maps are plotted; must be a key of `self_attn_dict`.
# The same value is passed as `timestep` below so that the call describes the
# maps actually shown rather than the last value left over from the loop.
selected_timestep = 0

# Set `interpolate=True` for bicubic upscaling; `False` shows raw pixel data
plot_attention_location(
    self_attn_dict[selected_timestep],
    orig_channel_idx=channel64_idx,
    orig_res=64,
    interpolate=False,
    timestep=selected_timestep
);

The following cell renders the previous $2 \times 4$ plot as an animation that iterates over each pixel in the image.

In [None]:
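# Uncomment to render the animation. This can take a while.
# NOTE(review): `res2weights` is not defined at notebook scope; pass a
# { resolution: attention_map } dictionary, e.g. `self_attn_dict[0]` — confirm.
# animate_locations(res2weights, num_frames=64**2, fps=15, interpolate=False)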

In [None]:
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

def render_attention_animation(
        timestep_range,
        orig_channel_idx=2355,
        orig_res=64,
        interpolate=False,
        save_path='attention_animation.mp4',
        fps: int = 5,
        image_path: str = "./images/img1.jpeg"
    ):
    """
    Renders an animation of self-attention maps, one frame per timestep, and
    saves it to disk.

    Parameters
    ----------
    timestep_range   : Iterable of timesteps; every element yields one frame
                       (lists, ranges and generators are all accepted)
    orig_channel_idx : Channel index specifying a location in the orig_res map
    orig_res         : Resolution of the attention map in which orig_channel_idx
                       has been chosen
    interpolate      : Boolean deciding whether to render with bicubic upscaling
    save_path        : Path to save the animation
    fps              : Frames per second of the saved animation
    image_path       : Path of the image that is encoded and fed to the model

    Returns
    -------
    None
    """
    # Materialise the timesteps so that generators work and the progress bar
    # knows the number of frames.
    frames = list(timestep_range)

    # The latent does not depend on the timestep, so encode the image once
    # instead of once per frame.
    # NOTE(review): if `augmenter` is stochastic, all frames now share a single
    # augmented image rather than a fresh augmentation per frame — confirm
    # this is the desired behaviour.
    with tf.device(device):
        image = process_image(image_path)
        image = augmenter(image)
        latent = vae(tf.expand_dims(image, axis=0), training=False)

    fig, axs = plt.subplots(2, 4, figsize=(20, 10))
    title = fig.suptitle('', fontsize=20)

    def update_frame(timestep):
        """Draw the attention maps for `timestep` and return the artists."""
        # Clear the two axes in the first column before redrawing
        axs[0, 0].cla()
        axs[1, 0].cla()

        # Set title
        title.set_text(f"Self-attention maps for timestep {timestep}")

        # Run inference to obtain self-attention maps for current time step.
        # NOTE(review): this unpacks ten values from `model.text_to_image`,
        # whereas an earlier cell unpacks four from `model.generate_image` —
        # confirm which API the current model exposes.
        with tf.device(device):
            _, _, weight_64, weight_32, weight_16, weight_8, _, _, _, _ = model.text_to_image(
                batch_size=1,
                latent=latent,
                timestep=timestep
            )

        # Store self-attention maps in a dictionary for resolution-specific access
        res2weights = { 8: weight_8, 16: weight_16, 32: weight_32, 64: weight_64 }

        artists = plot_attention_location(
            res2weights,
            orig_channel_idx=orig_channel_idx,
            orig_res=orig_res,
            interpolate=interpolate,
            timestep=timestep,
            fig=fig,
            axs=axs
        )
        return artists

    ani = FuncAnimation(fig, update_frame, frames=frames, blit=True)

    try:
        # Show a progress bar while rendering the animation, then save file to disk
        with tqdm(total=len(frames), desc="Saving animation") as pbar:
            ani.save(
                save_path,
                writer='ffmpeg',
                fps=fps,
                progress_callback=lambda i, n: pbar.update()
            )
    finally:
        # The figure is only needed for rendering; close it so repeated calls
        # do not accumulate open figures.
        plt.close(fig)

In [None]:
# Example usage: animate one location over timesteps 0, 5, ..., 1000.
# NOTE(review): the range includes timestep 1000, while the collection loop
# earlier in the notebook only samples timesteps below 1000 — confirm that the
# model accepts 1000.
location_idx = 1404

render_attention_animation(
    timestep_range=list(range(0, 1001, 5)),
    orig_channel_idx=location_idx,
    save_path=f"hockey_loc{location_idx}_timesteps.mp4",
    fps=30,
)

In [None]:
# `torch` is not imported anywhere else in this notebook; import it here so the
# cell also runs on a fresh kernel.
import torch

# Toy check of a weighted sum over the leading axis: five constant 2 x 2 blocks
# (values 2..6) are copied along a new axis of size 4, giving shape (5, 4, 2, 2).
x = torch.tensor([
        [[2., 2.], [2., 2.]],
        [[3., 3.], [3., 3.]],
        [[4., 4.], [4., 4.]],
        [[5., 5.], [5., 5.]],
        [[6., 6.], [6., 6.]],
]).unsqueeze(1).repeat((1, 4, 1, 1))

# One weight per block, chosen as the reciprocal of the block's value
scalars = torch.tensor([1/2, 1/3, 1/4, 1/5, 1/6])

# Give the weights shape (5, 1, 1, 1) so they broadcast over each block, then
# sum the weighted blocks over the leading axis.
x = x * scalars[:, None, None, None]
x = x.sum(dim=0)
x.shape

### Compute disk space used for all self-attention maps

In [None]:
# Size in bytes assumed for one stored value
# NOTE(review): 4 bytes presumes float32 maps — confirm.
bytes_per_float = 4

# Per resolution `res` (64, 32, 16, 8) the estimate counts res**4 values for
# each of the `num_timesteps` timesteps and converts bytes to GB (1e9 bytes).
size_64, size_32, size_16, size_8 = (
    res**4 * num_timesteps * bytes_per_float / 1e9
    for res in (64, 32, 16, 8)
)

f"{(size_64 + size_32 + size_16 + size_8):.2f} GB"

### Plot all eight attention heads for a single pixel and time step

In [None]:
# Time step whose 64 x 64 attention maps are inspected
t = 900

# 2 x 4 grid: one panel per attention head
fig, axs = plt.subplots(2, 4, figsize=(15, 8))
fig.suptitle(f"Attention heads for 64 x 64 map at time step {t}", fontsize=20)

# `axs.flat` walks the grid row by row, so panel i holds head i (0 to 7).
# NOTE(review): the collection loop earlier in the notebook stores maps that
# were already averaged over axes (0, 1); the `[0][i]` indexing below presumably
# expects un-averaged per-head weights — confirm which data this cell needs.
for i, ax in enumerate(axs.flat):
    head_map = self_attn_dict[t][64][0][i].reshape(64, 64, -1)[:, :, 2700]
    ax.imshow(head_map)
    ax.set_title(f'Head {i}')

plt.tight_layout()
plt.show()