# Installation & Set-up

First, install the dependencies for Detectron2 PointRend segmentation.

In [1]:
# Install dependencies (PyYAML is pinned to 5.1).
!pip install pyyaml==5.1
# Check the PyTorch installation: print the version and whether CUDA is available.
import torch, torchvision
print(torch.__version__, torch.cuda.is_available())

# !pip uninstall ipykernel ipython traitlets ipython_genutils
# !pip install ipykernel ipython traitlets ipython_genutils
# Colab helpers; `output.clear()` is called by later cells to hide install logs.
from google.colab import drive, output

1.11.0+cu113 True


In [2]:
# Clone the detectron2 repo (v0.6 branch/tag) into ./detectron2_repo in order to
# access the pre-defined configs of the PointRend project.
!git clone --branch v0.6 https://github.com/facebookresearch/detectron2.git detectron2_repo
# Install detectron2 from the cloned source tree in editable mode.
!pip install -e detectron2_repo
# See https://detectron2.readthedocs.io/tutorials/install.html for other installation options
# Clear this cell's output (the clone/install logs); `output` comes from google.colab.
output.clear()

Restart the runtime once, then continue with the rest of the setup.

In [3]:
# You may need to restart your runtime prior to this, to let your installation take effect
# Some basic setup:
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common libraries
import numpy as np
import cv2
import torch
from google.colab.patches import cv2_imshow
from PIL import Image

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import MetadataCatalog
coco_metadata = MetadataCatalog.get("coco_2017_val")
seg_metadata = MetadataCatalog.get("cityscapes_fine_instance_seg_val")

# import PointRend project
from detectron2.projects import point_rend

# style transfer model
import os
import json

Set up the style transfer model: clone its repository and download the TFRecords data.

In [4]:
# Work from /content and clone the style-transfer repository if it is not there yet.
%cd /content/
if not os.path.exists("/content/photorealistic_style_transfer"):
  # NOTE(review): this branch only runs when the directory is absent, so the
  # rm below looks like a no-op — confirm before relying on it as a "clean" step.
  !rm -rf '/content/photorealistic_style_transfer'
  !git clone https://github.com/ptran1203/photorealistic_style_transfer
# Switch into the repo; presumably so the later `from model import ...` and
# `from utils import ...` resolve against it — confirm.
%cd photorealistic_style_transfer

# Download and unpack the TFRecords release archive into /content once.
if not os.path.exists("/content/tfrecords"):
    !wget -O /content/tfrecords.zip https://github.com/ptran1203/photorealistic_style_transfer/releases/download/v1.0/tfrecords.zip
    !unzip /content/tfrecords.zip -d /content/
# Hide the shell output produced above.
output.clear()

If you get an error, change `tensorflow.python.keras.applications` to `keras.applications` in `data_processing.py`.

In [7]:
# !python3 train.py --train-tfrec /content/tfrecords/train.tfrec\
#                 --val-tfrec /content/tfrecords/val.tfrec\
#                 --epochs 100\
#                 --resume\
#                 --batch-size 8\
#                 --lr 2e-4\

from model import WCT2
from utils import read_img, download_weight, display_outputs
import cv2

In [8]:
# Instantiate the WCT2 style-transfer model and load its weights from the
# checkpoint file inside the cloned photorealistic_style_transfer repository.
model = WCT2()
model.load_weight('/content/photorealistic_style_transfer/checkpoints/wtc2.h5')

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5


Download the COCO annotations and define the reference image IDs used later to pick a style image for each stuff class.

In [9]:
if not os.path.exists("/content/annotations"):
  !wget -c http://images.cocodataset.org/annotations/annotations_trainval2017.zip
  !unzip -o annotations_trainval2017.zip 
# %cd /content

# Stuff class name -> id of a COCO image. get_color_image() loads that image
# as the colour/style reference for the class, and uses the 'rock' entry as the
# fallback for classes that are not listed here.
stuff_dict = {
    'blanket': 557916,
    'bridge': 384949,
    'counter': 74209,
    'curtain': 405970,
    'floor-wood': 573094,
    'flower': 565607,
    'fruit': 489305,
    'house': 356169,
    'rock': 361180,
    'wall-stone': 361180,
    'pillow': 557916,
    'platform': 290293,
    'playingfield': 135604,
    'railroad': 558421,
    'river': 220858,
    'road': 179265,
    'roof': 82821,
    'sand': 454798,
    'sea': 214703,
    'shelf': 302030,
    'snow': 247838,
    'towel': 384808,
    'wall-brick': 421923,
    'wall-tile': 262440,
    'wall-wood': 350054,
    'water': 214703,
    'window-blind': 573094,
    'tree': 179265,
    'fence': 558213,
    'ceiling': 573094,
    'sky': 13348,  # 144114 (presumably an alternative image id that was tried — confirm)
    'cabinet': 530836,
    'table': 573094,
    'floor': 573094,
    'pavement': 558213,
    'mountain': 332318,
    'grass': 518213,
    'dirt': 242678,
    'paper': 278463,
    'food': 559513,
    'building': 67616,
    'wall': 573094,
    'rug': 404484,
}

--2022-05-11 02:42:49--  http://images.cocodataset.org/annotations/annotations_trainval2017.zip
Resolving images.cocodataset.org (images.cocodataset.org)... 52.216.143.44
Connecting to images.cocodataset.org (images.cocodataset.org)|52.216.143.44|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 252907541 (241M) [application/zip]
Saving to: ‘annotations_trainval2017.zip’


2022-05-11 02:42:52 (77.7 MB/s) - ‘annotations_trainval2017.zip’ saved [252907541/252907541]

Archive:  annotations_trainval2017.zip
  inflating: annotations/instances_train2017.json  
  inflating: annotations/instances_val2017.json  
  inflating: annotations/captions_train2017.json  
  inflating: annotations/captions_val2017.json  
  inflating: annotations/person_keypoints_train2017.json  
  inflating: annotations/person_keypoints_val2017.json  


Import COCO API

In [10]:
from pycocotools.coco import COCO
import skimage.io as io
import matplotlib.pyplot as plt
# `dataDir` is kept because it was defined here originally; it does not take
# part in the annotation path below.
dataDir = '..'
# COCO split whose instance annotations are loaded.
dataType = 'val2017'
# The original literal had no '{}' placeholders, so its
# `.format(dataDir, dataType)` call silently did nothing. Build the file name
# from `dataType` explicitly; the resulting path is unchanged
# ('annotations/instances_val2017.json', relative to the working directory).
annFile = 'annotations/instances_{}.json'.format(dataType)
# initialize COCO api for instance annotations
coco = COCO(annFile)

loading annotations into memory...
Done (t=0.70s)
creating index...
index created!


Load the panoptic segmentation predictor.

In [11]:
# load predictor model
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml")
predictor = DefaultPredictor(cfg)
catalog = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])

model_final_cafdb1.pkl: 261MB [00:05, 51.3MB/s]                           


# Segment Image and get Classes

We first download an image from the COCO dataset:

In [12]:
# Fetch a sample image quietly (-q) and store it as input.jpg in the working directory.
!wget https://farm9.staticflickr.com/8302/7877972164_51f717def0_z.jpg -q -O input.jpg

# Size argument handed to read_img(); also the default `image_size` of
# load_inputs() and compare_rgb().
image_size = 512

def load_inputs(input_file="input.jpg", image_size=image_size):
  """Load an image twice: as a colour reference and as the grayscale input.

  Args:
    input_file: path of the image to read.
    image_size: size argument forwarded to `read_img`.

  Returns:
    A tuple `(im, colored_source_img)`:
      * `im` is the first element of the batch that `read_img` returns for the
        grayscale copy written to /content/content.jpeg.
      * `colored_source_img` is the `read_img` result for `input_file` after a
        `cv2.COLOR_BGR2RGB` conversion.

  Side effects:
    Overwrites /content/content.jpeg with the grayscale version of the input.
  """
  grayscale_path = "/content/content.jpeg"

  # Colour version, kept for comparison with the colourised result later.
  colored_source_img = cv2.cvtColor(read_img(input_file, image_size), cv2.COLOR_BGR2RGB)

  # Flag 0 makes cv2.imread load the file as a single-channel grayscale image;
  # it is written to disk and read back through read_img with a batch axis.
  Image.fromarray(cv2.imread(input_file, 0)).save(grayscale_path)
  grayscale_batch = read_img(grayscale_path, image_size, expand_dims=True)

  return grayscale_batch[0], colored_source_img

# im, colored_source = load_inputs()

Then, we use the detectron2 `DefaultPredictor` created above to run inference on this image, making a prediction with panoptic segmentation.

In [13]:
# Inference with a panoptic segmentation model
def get_masks_and_classes(im):
  """Run the panoptic predictor on `im` and collect the masks used for styling.

  The original version also rendered a `Visualizer` overlay on every call and
  then discarded it (the display line was commented out); that unused work has
  been removed.

  Args:
    im: image array passed directly to the module-level `predictor`.

  Returns:
    A tuple `(mask, seg_mask, pointer_to_things, segments_info)`:
      * `mask`: argmax over dim 0 of the predictor's "sem_seg" output, on CPU.
      * `seg_mask`: the panoptic segmentation tensor, on CPU.
      * `pointer_to_things`: dict mapping a `category_id` to the list of
        panoptic segment ids (`'id'`) that carry an `'instance_id'` key.
      * `segments_info`: the per-segment info list returned by the predictor.
  """
  outputs = predictor(im)

  # segmentation info for binary masking later
  panoptic_seg, segments_info = outputs["panoptic_seg"]

  mask = outputs["sem_seg"].argmax(dim=0).to("cpu")
  seg_mask = panoptic_seg.to("cpu")

  # Group panoptic segment ids by category. Segments without an 'instance_id'
  # key are skipped (presumably the stuff regions — confirm against the
  # detectron2 output format).
  pointer_to_things = {}
  for segment in segments_info:
    if 'instance_id' not in segment:
      continue
    pointer_to_things.setdefault(segment['category_id'], []).append(segment['id'])

  return mask, seg_mask, pointer_to_things, segments_info

# mask, seg_mask, pointer_to_things, segments_info, catalog = get_masks_and_classes(im)

# # print(classes_in_seg)
# print(pointer_to_things)

Next, find all classes and corresponding class IDs found during segmentation

In [14]:
# get all classes found in image
def get_categories(segments_info):
  """Collect the class names and category ids present in a segmentation.

  Args:
    segments_info: iterable of dicts with at least `'category_id'` and
      `'isthing'` keys.

  Returns:
    A tuple `(cats_in_image, pointer_to_cat, class_ids)`:
      * `cats_in_image`: set of class names found.
      * `pointer_to_cat`: dict mapping `category_id` to its class name.
      * `class_ids`: list of the keys of `pointer_to_cat`.

  NOTE(review): thing and stuff segments share one key space in
  `pointer_to_cat`; if a thing id and a stuff id can be equal, the later
  segment overwrites the earlier name — confirm against the catalog.
  """
  cats_in_image = set()
  pointer_to_cat = {}

  for segment in segments_info:
    category_id = segment['category_id']
    # Names come from the thing list or the stuff list of the module-level catalog.
    class_names = catalog.thing_classes if segment['isthing'] else catalog.stuff_classes
    name = class_names[category_id]
    cats_in_image.add(name)
    pointer_to_cat[category_id] = name

  return cats_in_image, pointer_to_cat, list(pointer_to_cat)

# cats_in_image, pointer_to_cat, class_ids = get_categories(catalog, segments_info)

# print(cats_in_image)
# print(pointer_to_cat)

# Style Transfer

Below are all functions used for binary mask style transfer

In [15]:
def make_masked_image(class_id, masks, original_image, pointer_to_things):
    """Keep only the pixels of `original_image` that belong to `class_id`.

    The binary mask is chosen in this order:
      1. If `class_id` occurs in the semantic mask, use `sem_seg == class_id`.
      2. Otherwise, if `class_id` is a key of `pointer_to_things`, add up the
         panoptic masks of every segment id listed for it.
      3. Otherwise fall back to `sem_seg == 0`.

    Args:
        class_id: category id to extract.
        masks: dict with tensors under the keys "sem_seg" and "pan_seg".
        original_image: array indexed as (row, column, channel).
        pointer_to_things: dict mapping a category id to a list of panoptic
            segment ids.

    Returns:
        `original_image` multiplied by the uint8 mask (broadcast over the
        channel axis).
    """
    mask = masks["sem_seg"]
    seg_mask = masks["pan_seg"]

    def _as_binary(tensor, wanted):
        # 0/1 uint8 array marking where `tensor` equals `wanted`.
        return (tensor.to("cpu") == wanted).numpy().astype(np.uint8)

    if class_id in mask:
        binary_mask = _as_binary(mask, class_id)
    elif class_id in pointer_to_things:
        segment_ids = pointer_to_things[class_id]
        binary_mask = _as_binary(seg_mask, segment_ids[0])
        # Several instances of the same class: accumulate their masks.
        for segment_id in segment_ids[1:]:
            binary_mask += _as_binary(seg_mask, segment_id)
    else:
        binary_mask = _as_binary(mask, 0)

    return binary_mask[:, :, None] * original_image

def get_color_image(cat):
  """
  Fetch a colour reference image for a class name from the COCO dataset.

  Thing classes get a randomly chosen image of that category. Any other name
  is looked up in `stuff_dict`; a KeyError raised during that lookup falls
  back to the 'rock' entry (a generically gray filler image).

  Returns the array that `io.imread` produces for the record's 'coco_url'.
  """
  if cat in catalog.thing_classes:
    candidate_ids = coco.getImgIds(catIds=coco.getCatIds(catNms=[cat]))
    pick = np.random.randint(0, len(candidate_ids))
    record = coco.loadImgs(candidate_ids[pick])[0]
  else:
    try:
      record = coco.loadImgs(ids=stuff_dict[cat])[0]
    except KeyError:
      record = coco.loadImgs(ids=stuff_dict['rock'])[0]

  return io.imread(record['coco_url'])

# for cat in cats_in_image:
#   plt.imshow(get_color_image(cat))
#   plt.show()

def style_transfer(content, idx, masks, pointers, max_tries=10):
  """Style one class region of the grayscale image and return it masked.

  A style image for the class is fetched with `get_color_image` and applied
  with `model.transfer`. If the output contains NaNs, another style image is
  tried, up to `max_tries` times.

  Fix over the previous version: success used to be judged by the attempt
  counter (`tries >= 10`), so a valid result produced on the final attempt was
  thrown away and the grayscale fallback returned. The decision now depends
  on whether a NaN-free result was actually obtained.

  Args:
    content: unused; the content image is always re-read from
      /content/content.jpeg (kept for interface compatibility).
    idx: category id of the class to style; key into `pointers["cat"]`.
    masks: dict with "sem_seg" and "pan_seg", forwarded to `make_masked_image`.
    pointers: dict with "cat" (category id -> class name) and "things"
      (category id -> panoptic segment ids).
    max_tries: maximum number of style images to try.

  Returns:
    The masked stylised image for the class, or the masked unstyled content
    image when no attempt produced a NaN-free result.

  Side effects:
    Overwrites /content/style.jpeg on every attempt and /content/test.png on
    success.
  """
  image_size = 512
  content = read_img("/content/content.jpeg", image_size, expand_dims=True)

  gen = None
  for _ in range(max_tries):
    style = get_color_image(pointers["cat"][idx])
    # Round-trip through disk so the style image goes through read_img as well.
    Image.fromarray(style).save("/content/style.jpeg")
    style = read_img("/content/style.jpeg", image_size, expand_dims=True)

    candidate = model.transfer(content, style, 1.0)
    if not np.isnan(candidate[0]).any():
      gen = candidate
      break

  if gen is None:  # no usable style found, so return b&w
    return make_masked_image(idx, masks, content[0], pointers["things"])

  # Save the pre-mask output with the channel order reversed for cv2.imwrite.
  cv2.imwrite('/content/test.png', gen[0][...,::-1])

  return make_masked_image(idx, masks, gen[0], pointers["things"])

# Image.fromarray(final_image.astype('uint8'))

Run style transfer on each class

In [16]:
def full_colorization(class_ids, im, masks, pointers):
  """Sum the masked, style-transferred layers of every class into one image.

  Args:
    class_ids: iterable of category ids to colourise.
    im: content image forwarded to `style_transfer`.
    masks: dict forwarded to `style_transfer`.
    pointers: dict forwarded to `style_transfer`.

  Returns:
    The accumulated image, or None when `class_ids` is empty. The first layer
    is used as the accumulator and later layers are added to it in place.
  """
  layers = (style_transfer(im, class_id, masks, pointers) for class_id in class_ids)

  final_image = next(layers, None)
  if final_image is not None:
    for layer in layers:
      final_image += layer

  return final_image

# save and display result
# final_result = full_colorization(class_ids, im)
# final_image = Image.fromarray(final_result.astype('uint8'))
# final_image.save("/content/result.jpeg")
# final_image

In [17]:
def colorize(input_file="input.jpg", save_path="/content/result.jpeg"):
  """Colourise one image end to end and save the result.

  Steps: load the grayscale input, segment it, resolve the classes present,
  style-transfer each class region, then write the combined image.

  Args:
    input_file: path of the image to colourise.
    save_path: where the resulting image is written.

  Returns:
    A tuple `(final_result, final_image)`: the raw array returned by
    `full_colorization` and the PIL image built from it after a uint8 cast.
  """
  grayscale_input, _colored_source = load_inputs(input_file=input_file)

  sem_mask, pan_mask, things_by_class, segments = get_masks_and_classes(grayscale_input)
  _class_names, names_by_id, class_ids = get_categories(segments)

  mask_bundle = {"sem_seg": sem_mask, "pan_seg": pan_mask}
  pointer_bundle = {"cat": names_by_id, "things": things_by_class}
  final_result = full_colorization(class_ids, grayscale_input, mask_bundle, pointer_bundle)

  final_image = Image.fromarray(final_result.astype('uint8'))
  final_image.save(save_path)

  return final_result, final_image

In [18]:
# result, final_img = colorize()
# final_img

# Quantitative Comparison

In [19]:
import math
def compare_rgb(source_path, result_path, epsilon=10, image_size=image_size):
  """
  Fraction of pixels whose colour distance between two images is <= epsilon.

  The distance is the weighted formula from
  https://www.compuphase.com/cmetric.htm

  Changes over the previous version:
    * All arithmetic is done in float64. Before, the channel sum, the channel
      difference and its square were computed in the array's native dtype,
      which wraps around if `read_img` returns uint8 data.
    * The per-pixel Python loop is replaced by vectorised NumPy operations.

  Args:
    source_path: path of the reference image.
    result_path: path of the image being evaluated.
    epsilon: largest colour distance still counted as a match.
    image_size: size argument forwarded to `read_img`.

  Returns:
    Matching pixels divided by the total pixel count, as a float.

  NOTE(review): channel 0 is treated as red after the BGR2RGB conversion;
  whether that holds depends on the channel order `read_img` returns — confirm.
  """
  og = cv2.cvtColor(read_img(source_path, image_size), cv2.COLOR_BGR2RGB)
  colored = cv2.cvtColor(read_img(result_path, image_size), cv2.COLOR_BGR2RGB)

  og = np.asarray(og, dtype=np.float64)
  colored = np.asarray(colored, dtype=np.float64)

  h, w, channels = og.shape
  total_px = h * w

  rmean = (og[..., 0] + colored[..., 0]) / 2
  rgb_diff = og - colored
  r_term = (2 + (rmean / 256)) * (rgb_diff[..., 0] ** 2)
  g_term = 4 * (rgb_diff[..., 1] ** 2)
  b_term = (2 + ((255 - rmean) / 256)) * (rgb_diff[..., 2] ** 2)
  color_dist = np.sqrt(r_term + g_term + b_term)

  # pixels within the threshold count as correct
  correct_px = int(np.count_nonzero(color_dist <= epsilon))

  return correct_px / total_px

# Testing

In [20]:
def get_test_images(imgs_per_cat=5):
  """Pick random, distinct COCO image ids for every category.

  Changes over the previous version:
    * Sampling is done without replacement in a single step. The old
      rejection loop redrew until it found an unseen id, which never
      terminates for a category with fewer than `imgs_per_cat` images, and it
      re-queried `coco.getImgIds` on every retry.
    * The number of images per category is a parameter (default unchanged).

  Args:
    imgs_per_cat: number of images to draw per category. Categories with
      fewer images contribute all the images they have.

  Returns:
    Dict mapping each category name to a set of image ids.
  """
  test_images = {}

  for cat in coco.cats.values():
    imgs_in_cat = coco.getImgIds(catIds=[cat['id']])
    # Shuffle the positions and keep the first few; indexing back into the
    # list keeps the ids in their original type.
    chosen_positions = np.random.permutation(len(imgs_in_cat))[:imgs_per_cat]
    test_images[cat['name']] = {imgs_in_cat[int(pos)] for pos in chosen_positions}

  return test_images

In [21]:
# import shutil

# shutil.rmtree('/content/results/chair')

In [None]:
import csv

# get random images per category
test_imgs = get_test_images()
print(test_imgs)

# create results
data_dir = "/content/results/"

if not os.path.exists(data_dir):
  %cd /content
  os.mkdir(data_dir)

# csv file
csv_file_name = data_dir + 'accuracy_data.csv'
csv_header = ['COCO id', 'accuracy']
with open(csv_file_name, 'w', encoding='UTF8') as f:
    writer = csv.writer(f)

    # write the header
    writer.writerow(csv_header)


# for progress check
count = 0
total_count = 160

# run colorization on all test images
for cat, imgs in test_imgs.items():
  # if cat in ['orange','broccoli','carrot','hot dog','pizza','donut','cake','wine glass','cup','fork','knife','spoon','bowl','banana','apple','sandwich','tennis racket','bottle','skis','snowboard','sports ball','kite','baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'giraffe','umbrella', 'handbag', 'backpack', 'tie','suitcase','frisbee','cow','elephant','bear','zebra','horse','sheep','boat','truck','airplane','bicycle','car','motorcycle','person','bus','train','cat','dog','bird','bench','parking meter','stop sign','fire hydrant','traffic light']:
  #   continue

  print(cat)
  #create category directory
  dir_loc = data_dir + cat 
  if not os.path.exists(dir_loc):
    try:
        os.mkdir(dir_loc)
    except FileExistsError:
        print('Directory {} already exists'.format(dir_loc))
    else:
        print('Directory {} created'.format(dir_loc))

  for i in imgs:
    # save input image
    input_name = str(i) + "_input.jpeg"
    file_loc = dir_loc + "/" + input_name
    input_image = coco.loadImgs(i)[0]
    input_image = io.imread(input_image['coco_url'])
    input_image = Image.fromarray(input_image)
    input_image.save(file_loc)

    # save result image
    save_loc = dir_loc + "/" + str(i) + "_result.jpeg"
    final_result, final_image = colorize(input_file=file_loc, save_path=save_loc)

    # calc accuracy and add to csv file
    img_acc = compare_rgb(file_loc, save_loc, epsilon=650)
      # First, open the old CSV file in append mode, hence mentioned as 'a'
      # Then, for the CSV file, create a file object
    with open(csv_file_name, 'a', newline='') as f_object:  
        # Pass the CSV  file object to the writer() function
        writer_object = csv.writer(f_object)
        # Pass the data in the list as an argument into the writerow() function
        writer_object.writerow([i, img_acc])  
        f_object.close()

    count += 1
    print(count, " out of ", total_count)

print("finished, check results folder")

{'person': {101068, 148719, 571857, 569976, 97337}, 'bicycle': {50145, 350122, 169996, 458223, 306136}, 'car': {191013, 228942, 345361, 33109, 17627}, 'motorcycle': {227399, 336232, 81394, 147740, 408830}, 'airplane': {477441, 425221, 490413, 52017, 379453}, 'bus': {76416, 338625, 380706, 128051, 82846}, 'train': {363072, 492937, 146825, 184400, 151857}, 'truck': {336587, 491213, 148719, 400815, 281693}, 'boat': {427649, 160772, 36678, 61418, 289741}, 'traffic light': {125572, 385190, 555050, 39484, 423229}, 'fire hydrant': {338560, 8899, 161128, 9769, 306893}, 'stop sign': {369442, 222094, 15440, 724, 100283}, 'parking meter': {67616, 179265, 129062, 53994, 162366}, 'bench': {361730, 325031, 259690, 48504, 166747}, 'bird': {119233, 383339, 227187, 473015, 315001}, 'cat': {101762, 551815, 311789, 264441, 112798}, 'dog': {385997, 479155, 286422, 246454, 395801}, 'horse': {384513, 134856, 367818, 377486, 100510}, 'sheep': {154947, 207728, 547383, 546556, 176606}, 'cow': {200667, 137576, 

  max_size = (max_size + (stride - 1)) // stride * stride


1  out of  160
2  out of  160
3  out of  160
4  out of  160
5  out of  160
bicycle
Directory /content/results/bicycle created
6  out of  160
7  out of  160
8  out of  160
9  out of  160
10  out of  160
car
Directory /content/results/car created
11  out of  160
12  out of  160
13  out of  160
14  out of  160
15  out of  160
motorcycle
Directory /content/results/motorcycle created
16  out of  160
17  out of  160
18  out of  160
19  out of  160
20  out of  160
airplane
Directory /content/results/airplane created
21  out of  160
22  out of  160
23  out of  160
24  out of  160
25  out of  160
bus
Directory /content/results/bus created
26  out of  160
27  out of  160
28  out of  160
29  out of  160
30  out of  160
train
Directory /content/results/train created
31  out of  160
32  out of  160
33  out of  160
34  out of  160
35  out of  160
truck
Directory /content/results/truck created
36  out of  160
37  out of  160
38  out of  160
39  out of  160
40  out of  160
boat
Directory /content/resu