# Instructions
Please run the following cells sequentially:
1. Initialize SD Model
2. Add your own image and update the ``image_path`` variable.
3. Feel free to play with DiffSeg hyper-parameters such as the ``KL_THRESHOLD``.

# Import

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
from keras_cv.src.models.stable_diffusion.image_encoder import ImageEncoder
from diffusion_models.stable_diffusion import StableDiffusion
from utils import process_image, augmenter, vis_without_label
from diffseg.segmentor import DiffSeg

from visualisation_utils import plot_attention_location, animate_locations

# !nvidia-smi # Uncomment if you have an NVIDIA GPU

In [None]:
# List the GPUs TensorFlow can see and choose the device used by the cells below.
# `tf.test.gpu_device_name()` returns an empty string when no GPU is available,
# so fall back to the CPU explicitly instead of carrying an empty device name around.
print("GPUs available: ", tf.config.list_physical_devices("GPU"))
device = tf.test.gpu_device_name() or "/CPU:0"
print(device)

# Initialize SD Model

In [None]:
# Initialize the VAE image encoder and the Stable Diffusion model on the selected device
with tf.device(device):
    image_encoder = ImageEncoder()
    # Functional model spanning the encoder's input up to the output of its last layer.
    vae = tf.keras.Model(
        inputs=image_encoder.input,
        outputs=image_encoder.layers[-1].output,
    )
    model = StableDiffusion(img_height=512, img_width=512)

# Run inference on real image

In [None]:
# The first run of this cell is slow because the model needs to download and load its pre-trained weights.

image_path = "./images/polar_bear.jpg"  # Specify the path to your image

timestep = 999

with tf.device(device):
    # Load the image and pass it through the augmenter.
    images = augmenter(process_image(image_path))
    # A leading batch axis is added before the image is encoded into a latent.
    latent = vae(tf.expand_dims(images, axis=0), training=False)
    (
        input_image,
        output_image,
        weight_64,
        weight_32,
        weight_16,
        weight_8,
        _,
        _,
        _,
        _,
    ) = model.text_to_image(batch_size=1, latent=latent, timestep=timestep)

# Map each resolution to its self-attention maps for resolution-specific access
res2weights = dict(
    zip(
        (8, 16, 32, 64),
        (weight_8, weight_16, weight_32, weight_64),
    )
)

### Show the image

In [None]:
# Show the first input image and the first output image side by side.
fig, axs = plt.subplots(1, 2, figsize=(16, 8))

panels = (
    (input_image[0], "Input image"),
    (output_image[0], "Output image"),
)
for ax, (picture, title) in zip(axs, panels):
    ax.imshow(picture)
    ax.set_title(title)
    ax.axis("off")  # Hide axis

plt.show()

### Visualise the VAE latents

In [None]:
# Draw each of the four latent channels as a grayscale panel in a single row.
n_channels = 4
fig, axs = plt.subplots(1, n_channels, figsize=(20, 5))

for idx, ax in enumerate(axs):
    # Slice out one channel of the first (only) batch entry.
    ax.imshow(latent[0, :, :, idx], cmap="gray")
    ax.set_title(f"Channel {idx + 1}")
    ax.axis("off")  # Hide axis

fig.suptitle("Latents")
plt.show()

### Visualise self-attention maps

Now we compare two ways of visualising the same self-attention map by reshaping a self-attention tensor `weight_N` (for `N` $\in \{8, 16, 32, 64\}$) in two different ways.

As an example, take a $64 \times 64$ attention map with original tensor shape $(1, 8, 4096, 4096):$

Interpretation $A$ reshapes the $(1, 8, 4096, 4096)$ tensor to $(1, 8, 4096, 64,   64),$ then sums across the $8$ attention heads. The resulting shape of $A$ is $(4096, 64,   64).$

Interpretation $B$ reshapes the $(1, 8, 4096, 4096)$ tensor to $(1, 8, 64,  64, 4096),$ then sums across the $8$ attention heads. The resulting shape of $B$ is $(64,   64, 4096).$

In [None]:
# Index into the 64x64 attention map to visualise; change this to any value between 0 and 4095.
channel64_idx = 2353

# `interpolate=False` shows the raw pixel data.
# NOTE(review): presumably `interpolate=True` renders a smoothed version — confirm in `visualisation_utils`.
# The trailing `;` suppresses the cell's return-value repr.
plot_attention_location(
    res2weights,
    orig_channel_idx=channel64_idx,
    orig_res=64,
    interpolate=False,
    timestep=timestep
);

The following cell renders the previous $2 \times 4$ plot as an animation that iterates over each pixel in the image.

In [None]:
# Uncomment to render animation. This can take a while.
# animate_locations(res2weights, num_frames=64**2, fps=15, interpolate=False)

# Generate Segmentation Masks

In [None]:
KL_THRESHOLD = [0.9] * 3  # KL_THRESHOLD controls the merging threshold
NUM_POINTS = 16
REFINEMENT = True


with tf.device(device):
    segmentor = DiffSeg(KL_THRESHOLD, REFINEMENT, NUM_POINTS)
    pred = segmentor.segment(weight_64, weight_32, weight_16, weight_8)  # b x 512 x 512

    # Iterate over the batch of predictions rather than over `images`: `images` is the
    # single un-batched image (it needs `tf.expand_dims(..., axis=0)` before encoding),
    # so `len(images)` counts image rows, not batch entries, and `pred[i]` would run
    # past the batch axis. `input_image` is the batched image shown with `input_image[0]`.
    # NOTE(review): assumes `input_image` and `pred` share the same batch size — confirm.
    for i in range(len(pred)):
        # The number of classes is the count of distinct labels in this mask.
        vis_without_label(pred[i], input_image[i], num_class=len(set(pred[i].flatten())))