## Setup

In [None]:
# Google Colab setup
# Ignore if you're not running on Colab: the import below fails and
# IN_COLAB is set to False.
import os

GDRIVE_PWD = 'ContextSeg'

try:
    from google.colab import drive
    IN_COLAB = True
except ImportError:
    # Only a missing google.colab means "not on Colab"; any other error
    # should surface instead of being silently swallowed.
    IN_COLAB = False

if IN_COLAB:
    # Work inside a persistent Google Drive folder so data, checkpoints and
    # logs survive the Colab session.
    drive.mount('/content/gdrive', force_remount=True)
    root_dir = "/content/gdrive/My Drive/"
    base_dir = os.path.join(root_dir, GDRIVE_PWD)
    os.makedirs(base_dir, exist_ok=True)
    os.chdir(base_dir)

In [None]:
import os

# Project-relative working directories (relative to the current working
# directory, which the Colab cell may have changed).
P_DATA = 'data'
P_SAVEDMODEL = 'models/checkpoints'
P_LOGS = 'logs'
P_OUTPUT = 'output'

dirs = [P_DATA, P_SAVEDMODEL, P_LOGS, P_OUTPUT]

# Create only the directories that do not exist yet (checked lazily, one at a
# time, in the order listed above).
for path in (p for p in dirs if not os.path.exists(p)):
    os.makedirs(path)

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import keras
from keras import backend as K
import numpy as np
import os
import h5py
import matplotlib.pylab as plt
from models.ContextSeg import ContextSeg
from utils.colormap import label_defs
from data_loader.data_generator import DataGenerator

In [None]:
# os.environ['KMP_DUPLICATE_LIB_OK']='True'
# K.tensorflow_backend._get_available_gpus()

In [None]:
# HDF5 dataset location and the checkpoint filename template; ModelCheckpoint
# fills in {epoch} and {val_acc} from the per-epoch logs.
DATASET = os.path.join(P_DATA, 'cihp_dataset.h5')
# NOTE(review): assumes this Keras version logs the metric as 'val_acc';
# newer releases name it 'val_accuracy' — confirm against the installed version.
MODEL_SAVE_PATH = os.path.join(
    P_SAVEDMODEL,
    'dweights.{epoch:02d}-{val_acc:.2f}.hdf5',  # path to save the model
)

# Training hyper-parameters.
INPUT_SHAPE = (320, 320, 3)
BATCH_SIZE, NUM_CLASSES, EPOCHS = 8, 20, 80

In [None]:
# create hdf5 dataset from tar file
# from utils.create_hdf5 import create_hdf5

# create_hdf5('instance-level_human_parsing.tar.gz', DATASET)

## Train

In [None]:
# Open the dataset read-only (SWMR mode). The handle is deliberately left open:
# the generators below are given h5py datasets from it, presumably read lazily
# during training.
hf = h5py.File(DATASET, 'r', libver='latest', swmr=True)

seed = 1

# Random augmentation, applied to the training stream only.
datagen_args = dict(
    rotation_range=20,
    width_shift_range=0.05,
    height_shift_range=0.05,
    brightness_range=[0.7, 1.4],
    shear_range=0.05,
    channel_shift_range=30,
    horizontal_flip=True,
    rescale=1/255,
    fill_mode='reflect'
)

# Validation/test data get the same intensity rescaling but no random
# transforms. Augmenting them made val metrics vary between epochs for reasons
# unrelated to the model, which also skews ModelCheckpoint's "best" selection.
# NOTE(review): assumes DataGenerator accepts a rescale-only datagen_args dict.
eval_datagen_args = dict(rescale=datagen_args['rescale'])

train_generator = DataGenerator(
    hf['x_train'], y=hf['y_train'], datagen_args=datagen_args,
    input_dim=INPUT_SHAPE, batch_size=BATCH_SIZE,
    colormap=label_defs, seed=seed)

val_generator = DataGenerator(
    hf['x_val'], y=hf['y_val'], datagen_args=eval_datagen_args,
    input_dim=INPUT_SHAPE, batch_size=BATCH_SIZE,
    colormap=label_defs, seed=seed)

test_generator = DataGenerator(
    hf['x_test'], datagen_args=eval_datagen_args,
    input_dim=INPUT_SHAPE, batch_size=BATCH_SIZE,
    colormap=label_defs, seed=seed)

In [None]:
# Build the segmentation network for the configured input size and class count.
model = ContextSeg(INPUT_SHAPE, NUM_CLASSES)

# RMSprop with a small fixed learning rate (decay=0.0: no learning-rate decay).
# NOTE: `lr`/`decay` are the older Keras spellings of these arguments.
opt = keras.optimizers.RMSprop(lr=5e-5, rho=0.9, epsilon=1e-08, decay=0.0)
model.compile(optimizer=opt,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

In [None]:
# Save the full model (not just weights) whenever val_loss improves.
# `period=3` (older Keras argument) sets the interval, in epochs, between
# checkpoint attempts.
modelchk = keras.callbacks.ModelCheckpoint(
    MODEL_SAVE_PATH,
    monitor='val_loss',
    mode='auto',
    save_best_only=True,
    save_weights_only=False,
    period=3,
    verbose=1,
)

# TensorBoard summaries (graph and images, no weight histograms) under P_LOGS.
tensorboard = keras.callbacks.TensorBoard(
    log_dir=P_LOGS, histogram_freq=0, write_graph=True, write_images=True)

# Per-epoch metrics appended to one CSV, so repeated runs accumulate there.
csv_logger = keras.callbacks.CSVLogger(
    os.path.join(P_LOGS, 'keras_log.csv'), append=True)

In [None]:
# Keep the returned History object: its `.history` dict holds the per-epoch
# loss/metric curves that evalModel plots (e.g. evalModel(history.history)).
history = model.fit_generator(train_generator,
                              epochs=EPOCHS,
                              verbose=1,
                              validation_data=val_generator,
                              callbacks=[modelchk, tensorboard, csv_logger],
                              workers=8)

## Evaluation

In [None]:
import seaborn as sns
import numpy as np

plt.style.use('ggplot')


def plot2D(points, title, labelX, labelY, legends=None, save=None):
    """Plot two series on the current pyplot figure and show it.

    Args:
        points: pair of 1-D sequences; ``points[0]`` is drawn in red and
            ``points[1]`` in green.
        title: figure title.
        labelX: x-axis label.
        labelY: y-axis label.
        legends: optional legend labels, one per series. No legend is drawn
            when this is ``None``.
        save: optional file path; when given, the figure is written there
            before being shown.
    """
    plt.plot(points[0], 'r')
    plt.plot(points[1], 'g')
    plt.xlabel(labelX, fontsize=11)
    plt.ylabel(labelY, fontsize=11)
    # "Comic Sans MS Bold" is a face name, not a font family, so matplotlib
    # cannot resolve it; request the family plus a bold weight instead.
    plt.title(title, fontname="Comic Sans MS", fontweight='bold', fontsize=14)
    # plt.legend(None) raises a TypeError, so only add a legend when labels
    # were actually supplied.
    if legends is not None:
        plt.legend(legends)
    if save:
        plt.savefig(save)
    plt.show()
    
    
def evalModel(history, save=None, plot_dir=None):
    """Plot training vs. validation loss and accuracy curves.

    Args:
        history: a Keras ``History`` object or its ``.history`` dict. It must
            contain ``loss``/``val_loss`` plus accuracy curves under either
            ``acc``/``val_acc`` or ``accuracy``/``val_accuracy``.
        save: optional filename prefix; when given, ``<save>_loss.png`` and
            ``<save>_acc.png`` are written to ``plot_dir``.
        plot_dir: directory for saved plots; defaults to ``P_OUTPUT`` (the
            output directory created in the setup cell).
    """
    # Accept either the History callback returned by fit/fit_generator or its dict.
    h = getattr(history, 'history', history)
    # Keras versions differ in how they name the accuracy log entries.
    acc_key = 'acc' if 'acc' in h else 'accuracy'

    if save:
        out_dir = P_OUTPUT if plot_dir is None else plot_dir
        sv1 = os.path.join(out_dir, save + '_loss.png')
        sv2 = os.path.join(out_dir, save + '_acc.png')
    else:
        sv1, sv2 = None, None

    plt.figure(0)
    plot2D((h['loss'], h['val_loss']), 'Training and validation loss',
           'Epochs', 'Loss', ['train', 'val'], save=sv1)

    plt.figure(1)
    plot2D((h[acc_key], h['val_' + acc_key]), 'Training and validation accuracy',
           'Epochs', 'Accuracy', ['train', 'val'], save=sv2)
    

In [None]:

from glob import glob
import data_processing as dp

# Inference configuration for the pretrained two-input model evaluated below.
output_dir = 'pred/'
test_dir = 'data/test/Images'
image_shape = (384, 512, 3)
target_shape = (48, 64)
batch_size = 8
num_classes = 20

# Deterministic (sorted) list of test image paths.
test_images = sorted(glob(os.path.join(test_dir, '*.jpg')))

# Class index -> RGB colour array, in label_defs order.
label_colors = dict(enumerate(np.array(l.color) for l in dp.label_defs))

In [None]:
from keras.models import load_model

# NOTE(review): this file lives in the working directory, not under
# P_SAVEDMODEL with the 'dweights.' prefix the checkpoint callback uses —
# confirm which checkpoint is intended.
PRETRAINED_MODEL_PATH = 'weights.39-0.56.hdf5'
model = load_model(PRETRAINED_MODEL_PATH)

In [None]:
# Show how many test images were found plus a short preview, instead of
# dumping every path into the notebook output.
len(test_images), test_images[:5]

In [None]:
from skimage.transform import rescale, resize
import skimage.io as io


# Number of test images to preprocess; capped so a smaller test set does not
# raise IndexError.
N = min(1000, len(test_images))
low_res_shape = (image_shape[0] // 4, image_shape[1] // 4, image_shape[2])

# float32 halves memory versus numpy's float64 default
# (~4.7 GB -> ~2.4 GB for the full-resolution buffer at N=1000).
images = np.empty((N, *image_shape), dtype=np.float32)
images_low_res = np.empty((N, *low_res_shape), dtype=np.float32)

masks = np.empty((N, *target_shape, 3))

for i in range(N):
    if i % 100 == 0:
        print(f'{i}/{N}')
    images[i] = resize(io.imread(test_images[i]), image_shape,
                       mode='reflect')

    # Resize to an explicit (H/4, W/4, 3) shape. `rescale` without a channel
    # argument also scales the channel axis (3 -> 1), and numpy then silently
    # broadcasts that single channel back to 3 on assignment.
    images_low_res[i] = resize(images[i], low_res_shape, mode='reflect')

In [None]:
import time

In [None]:
# Time one batched forward pass over all loaded images. perf_counter is a
# monotonic, high-resolution clock meant for measuring intervals, unlike the
# wall-clock time.time().
s_time = time.perf_counter()

predicted = model.predict([images, images_low_res])

e_time = time.perf_counter()
print('Running time on {} images: {:3.3f}'.format(len(images),
                                                  e_time - s_time))

In [None]:
# Throughput of the timed prediction above (images per second).
elapsed = e_time - s_time
print(f'FPS: {len(images) / elapsed: 3.3f}')

In [None]:
# Colour lookup table: row k is the RGB colour of class k. Classes without an
# entry in label_colors stay black, as with the original zero-initialised mask.
n_palette = max(max(label_colors) + 1, predicted.shape[-1])
palette = np.zeros((n_palette, 3))
for label, color in label_colors.items():
    palette[label] = color

# The prediction PNGs are written here; the directory is not created elsewhere.
os.makedirs(output_dir, exist_ok=True)

for i in range(N):
    labels = np.argmax(predicted[i], axis=-1).reshape(target_shape)
    labels_colored = palette[labels]
    masks[i] = labels_colored

    # Nearest-neighbour (order=0) upsampling keeps every pixel an exact
    # palette colour; the default bilinear interpolation blends neighbouring
    # class colours at boundaries into values that match no label.
    # preserve_range keeps the 0-255 colour values as-is.
    final_out = resize(labels_colored, image_shape, order=0,
                       preserve_range=True)

    basename = os.path.splitext(os.path.basename(test_images[i]))[0]
    io.imsave(os.path.join(output_dir, basename + '.png'),
              final_out.astype(np.uint8))

## Analysis

In [None]:
nrows = len(label_defs)


def plot_color_gradients(cmap_list, nrows=None):
    """Draw one colour swatch per label, with the label name to its left.

    Args:
        cmap_list: sequence of label definitions exposing ``.color``
            (presumably RGB in 0-255, given vmin/vmax below) and ``.name``.
        nrows: number of subplot rows; defaults to ``len(cmap_list)``. Rows
            beyond the number of labels are left blank.
    """
    if nrows is None:
        nrows = len(cmap_list)
    # squeeze=False keeps `axes` indexable even when nrows == 1.
    fig, axes = plt.subplots(nrows=nrows, figsize=(10, 7), squeeze=False)
    axes = axes[:, 0]
    fig.subplots_adjust(top=0.95, bottom=0.01, left=0.2, right=0.99)
    axes[0].set_title('CIHP Colormap', fontsize=14, color='black')

    for ax, label in zip(axes, cmap_list):
        ax.imshow([[label.color]], aspect='auto', vmin=0, vmax=255)
        # Put the name just left of the swatch, vertically centred, in
        # figure coordinates.
        x0, y0, _, height = ax.get_position().bounds
        fig.text(x0 - 0.01, y0 + height / 2., label.name,
                 va='center', ha='right',
                 fontsize=12, color='black')

    # Turn off *all* ticks & spines, not just the ones with colormaps.
    for ax in axes:
        ax.set_axis_off()


plot_color_gradients(label_defs, nrows)
plt.show()


In [None]:
# Per-class probability maps for a few classes, next to each input image.
classes = [0, 2, 7, 9, 13, 15, 19]
# Only len(images) test images were loaded and predicted, so never index past
# them. A local count is used so the global N from the loading cell is not
# overwritten.
n_show = min(len(test_images), len(images), len(predicted))

# squeeze=False keeps `axes` 2-D even when only one row is shown.
fig, axes = plt.subplots(nrows=n_show, ncols=len(classes) + 1,
                         figsize=(22, 20), squeeze=False)
fig.subplots_adjust(top=0.95, bottom=0, left=0.1, right=0.99,
                    wspace=0.1, hspace=0)

for i in range(n_show):
    ax_orig = axes[i, 0]
    ax_orig.imshow(images[i])
    if i == 0:
        ax_orig.text(0, -85, 'Original', fontsize=21,
                     color='black', va='top')
    ax_orig.set_axis_off()

    # Reshape works whether the model emits (H, W, C) or flat (H*W, C) per
    # image, matching the argmax/reshape used when saving masks.
    probs = predicted[i].reshape(*target_shape, -1)
    for ax, cls in zip(axes[i, 1:], classes):
        ax.imshow(probs[..., cls], cmap='viridis')
        if i == 0:
            name = dp.label_defs[cls].name if cls != 0 else 'Background'
            ax.text(0, -10, name, fontsize=21, color='black', va='top')
        ax.set_axis_off()

plt.show()