In [0]:
import os
import numpy as np
import matplotlib.pyplot as plt

from keras.models import load_model
from keras.regularizers import l2
from keras.optimizers import Adam, RMSprop, SGD
import keras.backend as K

from preprocessing import Preprocessing
from model import FullyConvolutionalNetwork
from metrics import Metrics

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [0]:
# from google.colab import drive
# drive.mount('/content/gdrive', force_remount=True)

In [0]:
# !unzip gdrive/My\ Drive/PASCALVOCdataset12000images

In [0]:
# --- Dataset / model configuration -------------------------------------------
# NOTE(review): PASCAL VOC segmentation is commonly 21 classes (20 + background);
# confirm what the 22nd class index represents inside Preprocessing.
NUM_CLASSES = 22
# Network input resolution in pixels (square inputs).
height = width = 224
# NOTE(review): imagePath ends with "/" but annotationPath does not — presumably
# Preprocessing joins them differently; confirm before changing either.
imagePath = "PASCALVOCdataset12000images/Images/"
annotationPath = "PASCALVOCdataset12000images/SegmentationClassAug"

In [0]:
# Sanity check: number of entries (files and subdirectories alike) in the image directory.
len(os.listdir(imagePath))

In [0]:
# Build the preprocessor for (height, width) inputs and NUM_CLASSES classes, then split the
# image filenames into train / validation / test lists.
# NOTE(review): the two 0.2 arguments are presumably the validation and test fractions — confirm
# their order in Preprocessing.get_test_train_filenames. No random seed is set in this notebook,
# so if the split is random it will differ between runs.
prePro = Preprocessing(height, width, NUM_CLASSES)
trainImages, valImages, testImages = prePro.get_test_train_filenames(imagePath, 0.2, 0.2)

In [0]:
# Size of each split, displayed as a (train, val, test) tuple.
tuple(map(len, (trainImages, valImages, testImages)))

In [0]:
# Visual sanity check: show the first few (image, mask) pairs produced by the training
# generator, image on the left and per-pixel argmax class index on the right.
# The generator is called with a batch size argument of 1, so index 0 of the batch axis
# is the single sample in each batch.
N_PREVIEW = 3  # number of preview rows; figure height scales with it

# squeeze=False keeps `axs` 2-D even when N_PREVIEW == 1.
fig, axs = plt.subplots(N_PREVIEW, 2, figsize=(15, 5 * N_PREVIEW), squeeze=False)
preview_batches = prePro.data_gen(trainImages, imagePath, annotationPath, 1)
# zip stops as soon as range() is exhausted, so exactly N_PREVIEW batches are drawn
# from the (possibly infinite) generator.
for row, batch in zip(range(N_PREVIEW), preview_batches):
    images, masks = batch[0], batch[1]
    axs[row, 0].imshow(images[0])
    axs[row, 0].set_title("input image")
    # Collapse the class axis (last axis) into a class-index map for display.
    axs[row, 1].imshow(np.argmax(masks[0], axis=-1))
    axs[row, 1].set_title("mask (argmax class index)")
plt.show()

In [0]:
from tensorflow.python.ops import math_ops, array_ops
from tensorflow.python.framework import ops

def new_sparse_categorical_accuracy(y_true, y_pred):
    """Element-wise sparse categorical accuracy that also handles rank > 2 outputs.

    Workaround for Keras' built-in ``sparse_categorical_accuracy`` on
    segmentation-style outputs.
    Credits: https://github.com/keras-team/keras/issues/11348#issuecomment-468568429

    Args:
        y_true: integer class labels. If it has the same rank as ``y_pred``
            (i.e. it carries a trailing size-1 axis), that last axis is squeezed away.
        y_pred: class scores, with the class axis as the last axis.

    Returns:
        A tensor of dtype ``K.floatx()`` holding 1.0 where ``argmax(y_pred, -1)``
        equals ``y_true`` and 0.0 elsewhere.

    NOTE(review): this metric is not passed to any ``compile()`` call in this
    notebook (training uses one-hot targets with ``categorical_crossentropy``);
    it only applies if targets are switched to sparse integer labels.
    """
    # Convert once and keep the converted tensors, so the rank and dtype checks below
    # operate on tensors even when the caller passes array-like inputs.
    y_true = ops.convert_to_tensor(y_true)
    y_pred = ops.convert_to_tensor(y_pred)
    y_true_rank = y_true.get_shape().ndims
    y_pred_rank = y_pred.get_shape().ndims

    # Labels shaped (..., 1) have the same rank as scores (..., C); drop the trailing
    # axis so they line up with argmax(y_pred) below.
    if y_true_rank is not None and y_pred_rank is not None and y_true_rank == y_pred_rank:
        y_true = array_ops.squeeze(y_true, [-1])
    y_pred = math_ops.argmax(y_pred, axis=-1)
    # argmax's integer output dtype may differ from the label dtype; cast so that
    # equal() compares like with like.
    if K.dtype(y_pred) != K.dtype(y_true):
        y_pred = math_ops.cast(y_pred, K.dtype(y_true))
    return math_ops.cast(math_ops.equal(y_true, y_pred), K.floatx())

# Fully Convolutional Network

In [0]:
# Builder for both FCN variants; it exposes get_model() and get_modified_model() below.
# FullyConvolutionalNetwork is already imported in the first cell (and kept fresh by
# %autoreload), so it is not re-imported here.
# NOTE(review): despite its name, `model` is this builder, not a Keras model.
# Input shape is (height, width, 3): 3-channel images (presumably RGB).
model = FullyConvolutionalNetwork((height, width, 3), NUM_CLASSES)

## Traditional FCN

In [0]:
# Build the baseline ("traditional") FCN from the builder and print its layer summary.
fcn = model.get_model()
fcn.summary()

In [0]:
# Adam at learning rate 1e-4; categorical cross-entropy expects one-hot (class-axis) targets.
# NOTE(review): `lr` is the older Keras spelling; newer releases use `learning_rate` —
# match it to the installed Keras version.
fcn.compile(
    optimizer=Adam(lr=0.0001),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [0]:
# Train the baseline FCN on generator-fed batches. The last data_gen argument (32) is
# presumably the batch size. Each epoch stops after 128 training and 48 validation batches
# (steps_per_epoch / validation_steps), which need not be a full pass over the split.
# NOTE(review): with epochs=1 the per-epoch accuracy/loss history holds a single point.
history_fcn = fcn.fit_generator(
    prePro.data_gen(trainImages, imagePath, annotationPath, 32),
    steps_per_epoch=128,
    epochs=1,
    validation_data=prePro.data_gen(valImages, imagePath, annotationPath, 32),
    validation_steps=48,
)

In [0]:
# Persist the trained baseline model. "gdrive/My Drive" normally comes from mounting Google
# Drive, but the drive.mount cell near the top is commented out. Create the folder if it is
# missing so a fresh Restart & Run All does not fail here with a missing-directory error.
# NOTE(review): if Drive is later mounted onto ./gdrive, remove any locally created folder first.
fcn_save_path = "gdrive/My Drive/Traditional_FCN(Adam).h5"
os.makedirs(os.path.dirname(fcn_save_path), exist_ok=True)
fcn.save(fcn_save_path)

## Modified FCN

In [0]:
# Build the modified FCN variant from the same builder and print its layer summary.
modified_fcn = model.get_modified_model()
modified_fcn.summary()

In [0]:
# Same optimizer / loss / metric setup as the baseline so the two variants are comparable.
# NOTE(review): `lr` is the older Keras spelling; newer releases use `learning_rate` —
# match it to the installed Keras version.
modified_fcn.compile(
    optimizer=Adam(lr=0.0001),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [0]:
# Train the modified FCN with the same schedule as the baseline: data_gen's last argument
# (32, presumably the batch size), 128 training and 48 validation batches per epoch.
# NOTE(review): with epochs=1 the per-epoch accuracy/loss history holds a single point.
history_modified_fcn = modified_fcn.fit_generator(
    prePro.data_gen(trainImages, imagePath, annotationPath, 32),
    steps_per_epoch=128,
    epochs=1,
    validation_data=prePro.data_gen(valImages, imagePath, annotationPath, 32),
    validation_steps=48,
)

In [0]:
# Persist the trained modified model. "gdrive/My Drive" normally comes from mounting Google
# Drive, but the drive.mount cell near the top is commented out. Create the folder if it is
# missing so a fresh Restart & Run All does not fail here with a missing-directory error.
# NOTE(review): if Drive is later mounted onto ./gdrive, remove any locally created folder first.
modified_fcn_save_path = "gdrive/My Drive/Modified_FCN(Adam).h5"
os.makedirs(os.path.dirname(modified_fcn_save_path), exist_ok=True)
modified_fcn.save(modified_fcn_save_path)

In [0]:
# import IPython
# import keras
# keras.utils.plot_model(fcn, to_file='fcn_modified_model.png', show_shapes=True)
# IPython.display.Image('fcn_modified_model.png')

# Results and Plots

## Traditional Fully Convolutional Network

In [0]:
# Bundle the dataset paths, the trained baseline model and the preprocessor for the
# plotting helpers used in the cells below.
met_fcn = Metrics(imagePath, annotationPath, fcn, prePro)

### Validation Set

In [0]:
# Baseline model predictions on validation-split images
# (presumably shown next to ground truth — see Metrics.plot_predictions).
met_fcn.plot_predictions(valImages)

In [0]:
# Baseline accuracy curve.
# NOTE(review): arguments look like (history key, x label, y label, title) — confirm in
# Metrics.plot_graphs. 'acc' is presumably the history key older Keras uses for
# metrics=['accuracy'] (newer versions use 'accuracy'). Training ran with epochs=1, so a
# per-epoch curve has a single point.
met_fcn.plot_graphs('acc', 'epoch', 'accuracy', 'Accuracy')

In [0]:
# Baseline loss curve.
# NOTE(review): training ran with epochs=1, so a per-epoch curve has a single point.
met_fcn.plot_graphs('loss', 'epoch', 'loss', 'Loss')

### Test Set

In [0]:
# Baseline model predictions on held-out test-split images.
met_fcn.plot_predictions(testImages)

## Modified Fully Convolutional Network

In [0]:
# Same bundle as for the baseline, but wrapping the trained modified model.
met_modified_fcn = Metrics(imagePath, annotationPath, modified_fcn, prePro)

### Validation Set

In [0]:
# Modified model predictions on validation-split images
# (presumably shown next to ground truth — see Metrics.plot_predictions).
met_modified_fcn.plot_predictions(valImages)

In [0]:
# Modified model accuracy curve.
# NOTE(review): 'acc' is presumably the history key older Keras uses for metrics=['accuracy']
# (newer versions use 'accuracy'). Training ran with epochs=1, so a per-epoch curve has a
# single point.
met_modified_fcn.plot_graphs('acc', 'epoch', 'accuracy', 'Accuracy')

In [0]:
# Modified model loss curve.
# NOTE(review): training ran with epochs=1, so a per-epoch curve has a single point.
met_modified_fcn.plot_graphs('loss', 'epoch', 'loss', 'Loss')

### Test Set

In [0]:
# Modified model predictions on held-out test-split images.
met_modified_fcn.plot_predictions(testImages)