In [None]:
from pathlib import Path
from typing import Dict

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from keras_cv.src.models.stable_diffusion.image_encoder import ImageEncoder
from diffusion_models.stable_diffusion import StableDiffusion
from utils import process_image, augmenter

from visualisation_utils import plot_attention_location, animate_locations
from my_utils import dict_to_disk, dict_from_disk

In [None]:
# List the GPUs TensorFlow can see and pick a GPU device name for the cells
# below. NOTE(review): `tf.test.gpu_device_name()` returns an empty string when
# no GPU is found; `tf.device("")` then presumably leaves placement to
# TensorFlow — confirm on a CPU-only machine.
print("GPUs available: ", tf.config.list_physical_devices('GPU'))
device = tf.test.gpu_device_name()
print(device)

# Initialize the Stable Diffusion model

In [None]:
# Initialize the Stable Diffusion components on the device chosen in the
# GPU-discovery cell above (a GPU when one is available).
with tf.device(device):
    # VAE encoder: a Keras model from the encoder's input to the output of its
    # final layer; it is used below to turn the input image into latents.
    image_encoder = ImageEncoder()
    vae = tf.keras.Model(
        inputs=image_encoder.input,
        outputs=image_encoder.layers[-1].output,
    )
    # Diffusion model configured for 512 x 512 images.
    model = StableDiffusion(img_width=512, img_height=512)

# Encode the input image into VAE latents

In [None]:
image_path = "./images/img2.jpeg"

# Load the image, pass it through `augmenter`, add a leading batch axis and
# encode it with the VAE. NOTE(review): if `augmenter` is stochastic, `latent`
# will differ between runs because no seed is set here — confirm.
with tf.device(device):
    input_image = process_image(image_path)
    input_augmented = augmenter(input_image)
    image_batch = tf.expand_dims(input_augmented, 0)
    latent = vae(image_batch, training=False)

### Show the image

In [None]:
# Display the loaded image. The integer cast goes into a separate,
# display-only variable so `input_image` itself is not overwritten; the cell
# stays idempotent and earlier state is preserved.
# NOTE(review): the int cast assumes `process_image` yields pixel values
# in 0..255 — confirm, otherwise imshow will clip.
input_image_display = tf.cast(input_image, tf.int32)

fig, ax = plt.subplots()
ax.imshow(input_image_display)
ax.axis("off");

# Run inference on the input image

In [None]:
# Container filled by the inference loop, keyed first by timestep and then by
# attention-map resolution:
#   { timestep : { resolution : head-averaged self-attention map } }
self_attn_dict: Dict[int, Dict[int, np.ndarray]] = dict()

In [None]:
# Run the model once per sampled timestep and record its self-attention maps.
# Timesteps are evenly spaced over [0, 1000); with num_timesteps = 10 they are
# 0, 100, ..., 900.
num_timesteps = 10

for timestep in np.arange(0, 1000, 1000 // num_timesteps):
    with tf.device(device):
        weight_64, weight_32, weight_16, weight_8 = model.generate_image(
            batch_size=1,
            latent=latent,
            timestep=timestep,
        )

        # Average over axes 0 and 1 and store this timestep's maps by
        # resolution. NOTE(review): intended as the average over attention
        # heads; axes 0/1 are presumably (batch, heads) — confirm against
        # `generate_image`.
        # `np.arange` yields np.int64 values, so the key is cast to a plain
        # int to match the Dict[int, ...] annotation (and to serialise cleanly).
        self_attn_dict[int(timestep)] = {
            8:  weight_8.mean(axis=(0, 1)),
            16: weight_16.mean(axis=(0, 1)),
            32: weight_32.mean(axis=(0, 1)),
            64: weight_64.mean(axis=(0, 1)),
        }

In [None]:
# Save the self-attention maps to disk. The output name is derived from the
# input image's file name (e.g. "./images/img2.jpeg" -> "self_attn_maps/img2"),
# so changing `image_path` cannot silently overwrite another image's results.
dict_to_disk(
    self_attn_dict=self_attn_dict,
    filename=f"self_attn_maps/{Path(image_path).stem}",
)

### Visualise the VAE latents

In [None]:
# Plot each of the four latent channels side by side (1 row x 4 panels),
# in grayscale and without axes.
fig, axs = plt.subplots(1, 4, figsize=(20, 5))

for channel_idx, ax in enumerate(axs):
    ax.imshow(latent[0, :, :, channel_idx], cmap="gray")
    ax.set_title(f"Channel {channel_idx + 1}")
    ax.axis("off")

fig.suptitle("Latents")
plt.show()

In [None]:
# Sanity check: shape of the resolution-32 map stored for timestep 0.
attn_t0_res32 = self_attn_dict[0][32]
attn_t0_res32.shape

In [None]:
# Pixel index into the 64 x 64 map; valid range is 0 .. 4095 (= 64**2 - 1).
channel64_idx = 1140
# Timestep to inspect; must be one of the keys stored in `self_attn_dict`.
t = 900

assert 0 <= channel64_idx < 64**2, f"channel64_idx must be in [0, {64**2 - 1}]"
assert t in self_attn_dict, f"t must be one of {sorted(self_attn_dict)}"

# `interpolate=False` shows the raw pixel data; set it to `True` to interpolate.
plot_attention_location(
    self_attn_dict[t],
    orig_channel_idx=channel64_idx,
    orig_res=64,
    interpolate=False,
    # Pass the selected timestep, not the loop variable left over from the
    # inference cell (which only coincidentally equals 900).
    timestep=t,
);

The following cell renders the previous $2 \times 4$ plot as an animation that iterates over every pixel location of the $64 \times 64$ map ($64^2 = 4096$ frames).

In [None]:
# Uncomment to render animation. This can take a while.
# It renders 64**2 = 4096 frames (one per pixel of the 64 x 64 map) at 15 fps
# for the timestep `t` selected in the cell above.
# animate_locations(self_attn_dict[t], num_frames=64**2, fps=15, interpolate=False)

In [None]:
# Example usage
# NOTE(review): `render_attention_animation` does not appear to be imported or
# defined anywhere in this notebook, so uncommenting this cell as-is would
# raise NameError — import it first.
# NOTE(review): it requests timesteps 0..1000 in steps of 5, while the inference
# loop above only stores 0, 100, ..., 900 in `self_attn_dict`. It is not passed
# `self_attn_dict`, so it presumably computes its own maps — confirm.
# render_attention_animation(
#     list(range(0, 1001, 5)),
#     orig_channel_idx=1404,
#     save_path=f"hockey_loc{1404}_timesteps.mp4",
#     fps=30
# )

### Plot all eight attention heads for a single pixel and time step

In [None]:
# NOTE(review): this cell appears stale. The inference loop stores maps that
# have already been averaged with `.mean(axis=(0, 1))`, so
# `self_attn_dict[t][64][0][i]` presumably no longer selects batch 0 / head i.
# Per-head maps would need to be stored separately for this plot to work.
# t = 900

# # Create a 2 x 4 grid of subplots
# fig, axs = plt.subplots(2, 4, figsize=(15, 8))
# fig.suptitle(f"Attention heads for 64 x 64 map at time step {t}", fontsize=20)

# # Loop through heads 0 to 7
# for i in range(8):
#     # Compute row and column for subplot
#     row = i // 4
#     col = i % 4

#     # Plotting each head
#     axs[row, col].imshow(self_attn_dict[t][64][0][i].reshape(64, 64, -1)[:, :, 2700])
#     axs[row, col].set_title(f'Head {i}')

# plt.tight_layout()
# plt.show()