In [1]:
import cv2
import glob
import tensorflow as tf
import numpy as np
import os
from scipy import misc
import argparse
import sys
# Will use matplotlib for showing the image
#from matplotlib import pyplot as plt

In [2]:
g_mean = np.array(([126.88,120.24,112.19])).reshape([1,1,3])
output_folder = "./test_output"

In [3]:
def rgba2rgb(img):
    return img[:,:,:3]*np.expand_dims(img[:,:,3],2)

In [4]:
def enhance_images():
    
    print("Enhancing images...")
    
    if not os.path.exists('./enhanced'):
        os.mkdir('./enhanced')
    
    input_image_filenames = glob.glob("../sketchy_database/256x256/photo/tx_000100000000/zebra/*.jpg")
    for filename in input_image_filenames:
        
        image = cv2.imread(filename, cv2.IMREAD_COLOR)
    
        image_yuv = cv2.cvtColor(image, cv2.COLOR_BGR2YUV)
        # equalize the histogram of the Y channel
        image_yuv[:,:,0] = cv2.equalizeHist(image_yuv[:,:,0])
        output_image = cv2.cvtColor(image_yuv, cv2.COLOR_YUV2BGR)
        
        input_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        output_rgb = cv2.cvtColor(image_yuv, cv2.COLOR_YUV2RGB)
        #plt.imshow(np.hstack((input_rgb,output_rgb)))
        
        
        filename_without_path = filename[56:]
        
        cv2.imwrite('./enhanced/' + filename_without_path, output_image)

In [5]:
def generate_salience_maps():
    
    print("Generating salience maps...")
    
    rgb_folder = "./enhanced"
    gpu_fraction = 1.0
    
    if not os.path.exists('./salience_maps'):
        os.mkdir('./salience_maps')
    
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = gpu_fraction)
    with tf.Session(config=tf.ConfigProto(gpu_options = gpu_options)) as sess:
        saver = tf.train.import_meta_graph('./meta_graph/my-model.meta')
        saver.restore(sess,tf.train.latest_checkpoint('./salience_model'))
        image_batch = tf.get_collection('image_batch')[0]
        pred_mattes = tf.get_collection('mask')[0]

        
        rgb_pths = os.listdir(rgb_folder)
        for rgb_pth in rgb_pths:
            rgb = misc.imread(os.path.join(rgb_folder,rgb_pth))
            if rgb.shape[2]==4:
                rgb = rgba2rgb(rgb)
            origin_shape = rgb.shape
            rgb = np.expand_dims(misc.imresize(rgb.astype(np.uint8),[320,320,3],interp="nearest").astype(np.float32)-g_mean,0)

            feed_dict = {image_batch:rgb}
            pred_alpha = sess.run(pred_mattes,feed_dict = feed_dict)
            final_alpha = misc.imresize(np.squeeze(pred_alpha),origin_shape)
            misc.imsave(os.path.join('./salience_maps',rgb_pth),final_alpha)


In [6]:
def convert_salience_maps_to_binary_masks():
    
    print("Converting salience maps to binary masks...")
    
    if not os.path.exists('./binary_masks'):
        os.mkdir('./binary_masks')
    
    input_image_filenames = glob.glob("./salience_maps/*.jpg")
    for filename in input_image_filenames:
        
        image = cv2.imread(filename, cv2.IMREAD_GRAYSCALE)
    
        (thresh, im_bw) = cv2.threshold(image, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
        
        before = cv2.cvtColor(image,cv2.COLOR_GRAY2RGB)
        after = cv2.cvtColor(im_bw,cv2.COLOR_GRAY2RGB)
        #plt.imshow(np.hstack((before,after)))
        
        filename_without_path = filename[16:]
        print(filename_without_path)
        
        cv2.imwrite('./binary_masks/' + filename_without_path, im_bw)

In [7]:
def apply_binary_masks_to_original_images():
    
    print("Apply binary masks to original images...")
    
    if not os.path.exists('./extracted'):
        os.mkdir('./extracted')
    
    mask_filenames = glob.glob("./binary_masks/*.jpg")
    for filename in mask_filenames:
        
        mask = cv2.imread(filename, cv2.IMREAD_COLOR)
        filename_without_path = filename[15:]
        original = cv2.imread('../sketchy_database/256x256/photo/tx_000100000000/zebra/' + filename_without_path, cv2.IMREAD_COLOR)
        
        extracted = cv2.bitwise_and(original, mask)
        
        #inverted_mask = cv2.bitwise_not(mask)
        
        original_rgb = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)
        original_and_mask = np.hstack((original_rgb, mask))
        extracted_rgb = cv2.cvtColor(extracted, cv2.COLOR_BGR2RGB)
        original_and_mask_and_extracted = np.hstack((original_and_mask, extracted_rgb))
        #plt.imshow(original_and_mask_and_extracted)
        
        cv2.imwrite('./extracted/' + filename_without_path, extracted)

In [8]:
enhance_images()
generate_salience_maps()
convert_salience_maps_to_binary_masks()
apply_binary_masks_to_original_images()

Enhancing images...
Generating salience maps...


AttributeError: module 'tensorflow' has no attribute 'GPUOptions'