Copyright 2023 Google LLC

Use of this source code is governed by an MIT-style
license that can be found in the LICENSE file or at
https://opensource.org/licenses/MIT.

# Instructions
1. Download the COCO-Stuff [annotations](http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip) and the COCO [val images](http://images.cocodataset.org/zips/val2017.zip).
*  Unzip the [annotations](http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip) first and rename their `val2017` folder to `annotation_val2017`, so it does not clash with the `val2017` image folder.


2. Download [Cityscapes](https://www.cityscapes-dataset.com/).
* Cityscapes download requires login.
* Please download `leftImg8bit_trainvaltest.zip` and `gtFine_trainvaltest.zip` to your data folder.

3. Run the cells in order, and run either 2a or 2b, not both.
* 2a: load COCO-Stuff data.
* 2b: load Cityscapes data. When using Cityscapes, also change `KL_THRESHOLD` in the inference cell from 1.1 to 0.9.

4. Metrics
* The inference cell prints pixel accuracy (ACC) and mean IoU (mIoU) every 10 batches and once more at the end.

# Imports

In [None]:
import tensorflow as tf
import numpy as np
from tqdm import tqdm
from diffseg.segmentor import DiffSeg
from keras_cv.src.models.stable_diffusion.image_encoder import ImageEncoder
from third_party.keras_cv.stable_diffusion import StableDiffusion 
from data.cityscapes import cityscapes_data
from data.coco import coco_data
from diffseg.utils import hungarian_matching

!nvidia-smi

# 1. Initialize SD Model

In [None]:
# Build the Stable Diffusion models under a MirroredStrategy so they are
# created for all visible GPUs.
strategy = tf.distribute.MirroredStrategy()
num_devices = strategy.num_replicas_in_sync
print(f'Number of devices: {num_devices}')
with strategy.scope():
  # Encoder that maps input images to the latents fed to the diffusion model.
  image_encoder = ImageEncoder()
  # Functional wrapper running from the encoder input to its last layer.
  vae = tf.keras.Model(image_encoder.input, image_encoder.layers[-1].output)
  model = StableDiffusion(img_width=512, img_height=512)

# 2a. Load COCO-Stuff Data

In [None]:
ROOT = "../coco_data/"  # Change this directory to your coco data folder.
FINE_TO_COARSE_PATH = "./data/coco/fine_to_coarse_dict.pickle"
FILE_LIST_PATH = "./data/coco/Coco164kFull_Stuff_Coarse_7.txt"
BATCH_SIZE = strategy.num_replicas_in_sync  # Global batch = number of replicas.

# Mapping from fine COCO-Stuff labels to the coarse label set.
fine_to_coarse_map = coco_data.get_fine_to_coarse(FINE_TO_COARSE_PATH)

# COCO-Stuff validation split: file list -> (image, label) paths -> dataset.
file_list = coco_data.load_imdb(FILE_LIST_PATH)
image_list, label_list = coco_data.create_path(ROOT, file_list)
val_dataset = coco_data.prepare_dataset(image_list, label_list, batch_size=BATCH_SIZE)

# 2b. Load Cityscapes Data

In [None]:
ROOT = "../cityscapes_data/"  # Change this directory to your Cityscapes data folder.
BATCH_SIZE = strategy.num_replicas_in_sync  # Global batch = number of replicas.

# Mapping from fine Cityscapes labels to the coarse label set.
fine_to_coarse_map = cityscapes_data.get_fine_to_coarse()

# Cityscapes validation split built from the image/label paths under ROOT.
# NOTE: the inference cell's KL_THRESHOLD should be 0.9 for Cityscapes.
image_list, label_list = cityscapes_data.create_path(ROOT)
val_dataset = cityscapes_data.prepare_dataset(image_list, label_list, batch_size=BATCH_SIZE)

# 3. Run Inference

In [None]:
N_CLASS = 27
TP = np.zeros(N_CLASS)
FP = np.zeros(N_CLASS)
FN = np.zeros(N_CLASS)
ALL = 0  # Running denominator for pixel accuracy.

# Initialize DiffSeg.
# KL_THRESHOLD controls the merge threshold for masks: 1.1 for CoCo-Stuff and
# 0.9 for Cityscapes. Edit it to match the dataset loaded in section 2a/2b.
KL_THRESHOLD = [1.1]*3
NUM_POINTS = 16
REFINEMENT = False  # Whether to use K-Means refinement (raises inference time from ~2s to ~3s).
LOG_EVERY = 10  # Print running metrics every LOG_EVERY batches.


def compute_metrics(tp, fp, fn, total):
  """Compute pixel accuracy, per-class IoU and mean IoU from accumulated counts.

  Args:
    tp, fp, fn: Per-class true-positive / false-positive / false-negative
      count arrays, as accumulated from `hungarian_matching`.
    total: Denominator for pixel accuracy (the accumulated fourth return
      value of `hungarian_matching`).

  Returns:
    A tuple (acc, iou, miou). A class with TP + FP + FN == 0 gets a NaN IoU,
    which `np.nanmean` skips; the 0/0 RuntimeWarning for it is suppressed.
  """
  with np.errstate(divide="ignore", invalid="ignore"):
    acc = tp.sum() / total
    iou = tp / (tp + fp + fn)
  return acc, iou, np.nanmean(iou)


with strategy.scope():
  segmentor = DiffSeg(KL_THRESHOLD, REFINEMENT, NUM_POINTS)

  for i, batch in enumerate(tqdm(val_dataset)):
    images = batch["images"]
    labels = fine_to_coarse_map(batch["labels"][:, :, :, 0])
    latent = vae(images, training=False)

    # Extract attention maps from a single iteration of diffusion. The first
    # return value is unused; it is discarded instead of rebinding `images`.
    _, weight_64, weight_32, weight_16, weight_8, _, _, _, _ = model.text_to_image(
        None,
        batch_size=images.shape[0],
        latent=latent,
        timestep=300,
    )

    # Segment using DiffSeg.
    pred = segmentor.segment(weight_64, weight_32, weight_16, weight_8)  # b x 512 x 512

    # Run Hungarian matching for evaluation. `n_total` (not `all`) avoids
    # shadowing the builtin `all` in the notebook's global namespace.
    tp, fp, fn, n_total = hungarian_matching(pred, labels, N_CLASS)
    TP += tp
    FP += fp
    FN += fn
    ALL += n_total

    # Print accuracy and mean IoU occasionally.
    if (i + 1) % LOG_EVERY == 0:
      acc, iou, miou = compute_metrics(TP, FP, FN, ALL)
      print("pixel accuracy:{}, mIoU:{}".format(acc, miou))

# Print final accuracy and mean IoU.
acc, iou, miou = compute_metrics(TP, FP, FN, ALL)
print("final pixel accuracy:{}, mIoU:{}".format(acc, miou))