# Instructions
Please run the following cells in order:
1. Initialize the Stable Diffusion (SD) model.
2. Add your own image and update the ``image_path`` variable.
3. Feel free to experiment with DiffSeg hyper-parameters such as ``KL_THRESHOLD``.

# Import

In [None]:
import os
import tensorflow as tf
from keras_cv.src.models.stable_diffusion.image_encoder import ImageEncoder
from diffusion_models.stable_diffusion import StableDiffusion
from utils import process_image, augmenter, vis_without_label
from diffseg.segmentor import DiffSeg

# !nvidia-smi # Uncomment if you have an NVIDIA GPU

In [None]:
# Report the GPUs TensorFlow can see, then the name of the GPU device it would
# use (tf.test.gpu_device_name() returns an empty string when none is available).
print("GPUs available: ", tf.config.list_physical_devices('GPU'))
device = tf.test.gpu_device_name()
print(device)

# 1. Initialize SD Model

In [None]:
# Initialize the Stable Diffusion model on GPU:0.
# `vae` is a functional Keras model that runs `image_encoder` through its last
# layer; the inference cell calls it to turn a preprocessed image into `latent`.
with tf.device('/GPU:0'):
    image_encoder = ImageEncoder()
    vae = tf.keras.Model(
        inputs=image_encoder.input,
        outputs=image_encoder.layers[-1].output,
    )
    model = StableDiffusion(img_width=512, img_height=512)

# 2. Run Inference on Real Images

In [None]:
# NOTE: the first run of this cell can be slow, because pre-trained weights may
# need to be downloaded and loaded.

image_path = "./images/img1.jpeg"  # Path to the image you want to segment

with tf.device('/GPU:0'):
    # Preprocess + augment the image, then add a leading batch axis of size 1
    # before encoding it into a latent with the VAE encoder.
    images = augmenter(process_image(image_path))
    latent = vae(tf.expand_dims(images, axis=0), training=False)

    # `images` is rebound here to the first value returned by text_to_image
    # (presumably the decoded image batch — confirm in StableDiffusion), and
    # that is what the segmentation cell visualizes. weight_64 .. weight_8 are
    # passed to DiffSeg; the four trailing outputs are discarded.
    images, weight_64, weight_32, weight_16, weight_8, _, _, _, _ = (
        model.text_to_image(prompt=None, batch_size=1, latent=latent, timestep=300)
    )

# 3. Generate Segmentation Masks

In [None]:
# DiffSeg hyper-parameters.
# KL_THRESHOLD controls the merging threshold; this list holds three values
# (what each entry applies to is defined inside DiffSeg — confirm there).
KL_THRESHOLD = [0.9] * 3
NUM_POINTS = 16
REFINEMENT = True


with tf.device('/GPU:0'):
    # Positional order matches the original call: (kl_threshold, refinement, num_points).
    segmentor = DiffSeg(KL_THRESHOLD, REFINEMENT, NUM_POINTS)
    pred = segmentor.segment(weight_64, weight_32, weight_16, weight_8)  # b x 512 x 512

# Visualization needs no GPU device scope, so iterate each predicted mask
# together with its matching image directly instead of indexing by position.
for mask, image in zip(pred, images):
    # num_class = number of distinct label values present in this mask.
    vis_without_label(mask, image, num_class=len(set(mask.flatten())))