In [None]:
# Imports
from scipy.misc import imsave
import os
import numpy as np
from PIL import Image
from IPython import embed
from model import get_frontend, add_softmax, add_context
from utils import interp_map, pascal_palette
from keras.preprocessing.image import load_img, img_to_array

In [None]:
# Root of your local checkout of the repo. Set the DEEPHACKS_DIR environment
# variable to point elsewhere instead of editing this cell.
INPUT_DIR = os.environ.get('DEEPHACKS_DIR', '/home/bfortuner/workplace/deephacks/')

IMG_EXTEN = '.jpg'
STYLE_FILENAME = 'starrynight.jpg'
TEST_PHOTO_FILENAME = 'cat.jpg'  # alternative: 'good_dog_img_for_testing.jpg'

# Image input/output locations, relative to the notebook's working directory.
OUTPUT_DIR = 'images/segmentations/'
IMAGES_DIR = 'images/'
INPUT_FILE = os.path.join(IMAGES_DIR, TEST_PHOTO_FILENAME)
OUTPUT_FILE = os.path.join(OUTPUT_DIR, TEST_PHOTO_FILENAME + '_seg.png')
WEIGHTS_PATH = os.path.join(INPUT_DIR, 'conversion', 'dilation8_pascal_voc.npy')

# Upsampling factor passed to interp_map() when enlarging the network output.
ZOOM = 8
# Per-channel mean subtracted from the BGR-ordered input image before inference.
MEAN = [102.93, 111.36, 116.52]

# RGB colour for each of the 21 class indices (0-20).
# NOTE: forward_pass() colours its output with utils.pascal_palette, not this dict.
PASCAL_PALETTE = {
    0: (0, 0, 0),
    1: (128, 0, 0),
    2: (0, 128, 0),
    3: (128, 128, 0),
    4: (0, 0, 128),
    5: (128, 0, 128),
    6: (0, 128, 128),
    7: (128, 128, 128),
    8: (64, 0, 0),
    9: (192, 0, 0),
    10: (64, 128, 0),
    11: (192, 128, 0),
    12: (64, 0, 128),
    13: (192, 0, 128),
    14: (64, 128, 128),
    15: (192, 128, 128),
    16: (0, 64, 0),
    17: (128, 64, 0),
    18: (0, 192, 0),
    19: (128, 192, 0),
    20: (0, 64, 128),
}

# Class name for each predicted label index; PASCAL_CATEGORIES[i] names class i.
# Index 0 is the catch-all background class, followed by the 20 object classes.
PASCAL_CATEGORIES = [
    'background',   # 0
    'aeroplane',    # 1
    'bicycle',      # 2
    'bird',         # 3
    'boat',         # 4
    'bottle',       # 5
    'bus',          # 6
    'car',          # 7
    'cat',          # 8
    'chair',        # 9
    'cow',          # 10
    'diningtable',  # 11
    'dog',          # 12
    'horse',        # 13
    'motorbike',    # 14
    'person',       # 15
    'pottedplant',  # 16
    'sheep',        # 17
    'sofa',         # 18
    'train',        # 19
    'tvmonitor',    # 20
]

In [None]:
# Settings for the Pascal dataset
# The network takes a square 900x900 input.
input_width = input_height = 900
# Margin (px) of context padded around the image on every side; the usable
# output area is the input size minus 2 * label_margin in each dimension.
label_margin = 186
# When True, get_trained_model() inserts the context module before the softmax.
has_context_module = False

def get_trained_model():
    """Build the segmentation network and load its pretrained weights.

    Uses the module-level settings ``input_width``, ``input_height``,
    ``has_context_module`` and ``WEIGHTS_PATH``.

    Returns:
        The model produced by ``add_softmax`` with the weights loaded.

    Raises:
        ValueError: if ``WEIGHTS_PATH`` ends in neither ``.npy`` nor ``.hdf5``.
    """
    model = get_frontend(input_width, input_height)

    if has_context_module:
        model = add_context(model)

    model = add_softmax(model)

    def load_tf_weights():
        """Load pretrained weights converted from Caffe to TF (.npy dict)."""
        # The .npy file holds a pickled dict {layer_name: {'weights', 'biases'}},
        # so allow_pickle=True is required on NumPy >= 1.16.3 (the default
        # changed to False). Unpickling can run arbitrary code: only load
        # weight files from a trusted source.
        # 'latin1' enables loading .npy files created with python2.
        weights_data = np.load(WEIGHTS_PATH, encoding='latin1',
                               allow_pickle=True).item()

        # Layers absent from the weight file keep their initial weights.
        for layer in model.layers:
            if layer.name in weights_data:
                layer_weights = weights_data[layer.name]
                layer.set_weights((layer_weights['weights'],
                                   layer_weights['biases']))

    def load_keras_weights():
        """Load a Keras checkpoint."""
        model.load_weights(WEIGHTS_PATH)

    # Dispatch on the file extension of the weights file.
    if WEIGHTS_PATH.endswith('.npy'):
        load_tf_weights()
    elif WEIGHTS_PATH.endswith('.hdf5'):
        load_keras_weights()
    else:
        raise ValueError("Unknown weights format: {}".format(WEIGHTS_PATH))

    return model


def forward_pass():
    """Segment INPUT_FILE with the pretrained network and save a colour map.

    Reads the module-level settings INPUT_FILE, OUTPUT_FILE, MEAN, ZOOM,
    input_width, input_height and label_margin. Writes the colour-coded
    segmentation to OUTPUT_FILE, creating its directory if needed.

    Returns:
        (prediction, color_image):
        ``prediction`` is the per-pixel argmax over axis 2, the class axis.
        ``color_image`` is ``prediction`` mapped through ``pascal_palette``,
        with a trailing axis of size 3 added.

    Raises:
        ValueError: if the image is larger than the network's output area
            (input size minus 2 * label_margin). This simplified code does
            not tile large inputs.
    """
    model = get_trained_model()

    # Load the image as 3-channel RGB; grayscale, RGBA or palette files would
    # otherwise have the wrong number of channels. The `with` closes the file.
    with Image.open(INPUT_FILE) as img:
        image_rgb = np.array(img.convert('RGB')).astype(np.float32)
    # Swap RGB -> BGR and subtract the channel means to match the trained weights.
    image = image_rgb[:, :, ::-1] - MEAN
    image_size = image.shape

    # Network input shape (batch_size=1)
    net_in = np.zeros((1, input_height, input_width, 3), dtype=np.float32)

    output_height = input_height - 2 * label_margin
    output_width = input_width - 2 * label_margin

    # This simplified prediction code is correct only if the output size is
    # large enough to cover the input without tiling. An image exactly the
    # output size still fits: after padding it fills the input exactly.
    if image_size[0] > output_height or image_size[1] > output_width:
        raise ValueError(
            'Image of size {}x{} exceeds the maximum {}x{} supported without '
            'tiling'.format(image_size[0], image_size[1],
                            output_height, output_width))

    # Center pad the original image by label_margin.
    # This initial pad adds the context required for the prediction
    # according to the preprocessing during training.
    image = np.pad(image,
                   ((label_margin, label_margin),
                    (label_margin, label_margin),
                    (0, 0)), 'reflect')

    # Add the remaining margin to fill the network input width. This
    # time the image is aligned to the upper left corner though.
    margins_h = (0, input_height - image.shape[0])
    margins_w = (0, input_width - image.shape[1])
    image = np.pad(image,
                   (margins_h,
                    margins_w,
                    (0, 0)), 'reflect')

    # Run inference
    net_in[0] = image
    prob = model.predict(net_in)[0]

    # The network outputs a flat array of per-class scores per position;
    # reshape it back to a square (edge, edge, n_classes) grid. The class
    # count is inferred (-1) rather than hard-coded.
    prob_edge = int(round(np.sqrt(prob.shape[0])))
    prob = prob.reshape((prob_edge, prob_edge, -1))

    # Upsample back to the original image size.
    if ZOOM > 1:
        prob = interp_map(prob, ZOOM, image_size[1], image_size[0])

    # Recover the most likely prediction (actual segment class)
    prediction = np.argmax(prob, axis=2)
    # Apply the color palette to the segmented image
    color_image = np.array(pascal_palette)[prediction.ravel()].reshape(
        prediction.shape + (3,))

    output_dir = os.path.dirname(OUTPUT_FILE)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    print('Saving results to: ', OUTPUT_FILE)
    # PIL only accepts 8-bit RGB here; palette values fit in uint8.
    Image.fromarray(color_image.astype(np.uint8)).save(OUTPUT_FILE)
    return prediction, color_image

In [None]:
# Segment the test photo: `pred` holds per-pixel class indices,
# `color_image` the palette-coloured rendering saved to OUTPUT_FILE.
(pred, color_image) = forward_pass()

In [None]:
# PASCAL VOC class index to extract as a binary mask below.
# Other useful values: 15 = person, 12 = dog.
TARGET_CATEGORY = 8  # cat

In [None]:
# Sanity-check the predicted label map: the largest and smallest class index,
# the set of classes present, and the map's shape.
print(pred.max())
print(pred.min())
print(np.unique(pred))
print(pred.shape)

In [None]:
# Add a trailing channel axis so the label map looks like a 1-channel image.
out = np.expand_dims(pred, axis=-1)
out.shape

In [None]:
# Turn the label map into a binary mask: 1 (white) where the pixel belongs to
# TARGET_CATEGORY, 0 (black) everywhere else.
# The comparison is done in one step, so it is also correct when
# TARGET_CATEGORY == 0 (the old two-step in-place version set every pixel
# to 1). It rebinds `pred` instead of mutating it, so the array returned by
# forward_pass() and views of it such as `out` keep the original labels.
# The integer dtype of the labels is kept.
# Run once per forward_pass(): re-running would compare the mask itself
# against TARGET_CATEGORY.
pred = (pred == TARGET_CATEGORY).astype(pred.dtype)

In [None]:
# Save the mask as an 8-bit grayscale PNG: zero pixels black, non-zero white.
# scipy.misc.imsave was removed in SciPy 1.2, so write through PIL instead.
# Unlike imsave's min/max rescaling, this keeps an all-target mask white.
# NOTE: this overwrites the colour segmentation forward_pass() wrote to the
# same OUTPUT_FILE.
mask_image = Image.fromarray(np.where(pred != 0, 255, 0).astype(np.uint8))
mask_image.save(OUTPUT_FILE)

In [None]:
# Reload the saved file from disk and render it as this cell's output.
segimg = load_img(path=OUTPUT_FILE)
segimg