[View in Colaboratory](https://colab.research.google.com/github/dkatsios/semantic_segmentation/blob/master/voc2012.ipynb)

In [0]:
colab = True

In [0]:
import numpy as np
import os
from shutil import unpack_archive
import cv2
from matplotlib import pyplot as plt
from IPython.display import Image
import PIL
from keras.optimizers import Adam
from time import time
if colab:
  from google.colab import files
import keras
from keras.models import load_model
from keras.callbacks import ModelCheckpoint
import pickle

### Download VOC 2012 dataset

In [0]:
# Working directory holding the VOC2012 data, the helper notebooks and logs/.
# On Colab, uncomment the commands below once to (re)download the VOC2012
# train/val archive and unpack it into the working directory.
# wdir has no trailing slash: every later path is built as wdir + '/...'.
if colab:
#   %rm -r /content/semantic_segmentation
#   %mkdir /content/semantic_segmentation
#   %cd /content/semantic_segmentation/
#   !wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
#   unpack_archive('VOCtrainval_11-May-2012.tar', './')
#   %rm VOCtrainval_11-May-2012.tar
  wdir = '/content/semantic_segmentation'
else:
  wdir = 'E:/Data_Files/Workspaces/PyCharm/semantic_segmentation'

In [0]:
%cd {wdir}/VOCdevkit/VOC2012
!ls

In [0]:
# Locations inside the extracted VOC2012 devkit, all rooted at wdir.
voc_root = wdir + '/VOCdevkit/VOC2012'
segmentation_sets = voc_root + '/ImageSets/Segmentation'

imgs_folder = voc_root + '/JPEGImages/'             # input images
classes_folder = voc_root + '/SegmentationClass/'   # segmentation label images
train_list_path = segmentation_sets + '/train.txt'  # train split list file
val_list_path = segmentation_sets + '/val.txt'      # validation split list file

### Import the helper and model notebooks

In [0]:
# Make the helper notebooks importable: on Colab clone a fresh copy of the repo
# and cd into it; locally, presumably they already live directly in wdir.
if colab:
    %cd {wdir}
    # drop any previous clone so `git clone` does not fail on an existing directory
    %rm -r {wdir}/semantic_segmentation/
    !git clone https://github.com/dkatsios/semantic_segmentation.git
    # the clone becomes the cwd so its notebooks can be imported in the next cell
    %cd {wdir}/semantic_segmentation
else:
    %cd {wdir}
%ls

In [0]:
!pip install import_ipynb
import import_ipynb
from voc2012_helpers import *
from AtrusUnet_2 import  *

### Set parameters

In [0]:
# ---- network architecture -------------------------------------------------
img_shape = (512, 512, 3)        # input size, presumably (H, W, C) — channels last
filters = 128                    # 256 is the alternative that was tried
segmentation_classes = 21        # VOC: 20 object classes + background
steps = 5
out_resized_levels = 0
kernel_sizes = [2, 3, 5, 7]
# NOTE(review): dilation_rates is never passed to AtrousUnet in this notebook.
dilation_rates = [1, 2, 3]
use_max_pooling, use_depthwise, pre_resized = True, True, False

# ---- training schedule ----------------------------------------------------
batch_size = 1
epochs = 1
val_steps = 10                   # validation batches drawn per epoch

### Build generators

In [0]:
# Resolve the train/val split files into per-split file lists.
train_lists, val_lists = get_lists_from_folders(train_list_path, val_list_path, imgs_folder, classes_folder)

# Quick smoke test: truncate both splits before loading, e.g.
#   train_lists = train_lists[0][:100], train_lists[1][:100]
#   val_lists = val_lists[0][:100], val_lists[1][:100]

# Load each split into arrays at the network's input shape (train first, then val).
train_arrays, val_arrays = (
    get_imgs_classes_arrays(*split_lists, img_shape)
    for split_lists in (train_lists, val_lists)
)

In [0]:
# Batches per epoch; floor division drops any final partial batch.
steps_per_epoch = len(train_lists[0]) // batch_size
# Train and validation generators share the same batching/output settings.
generator_args = (batch_size, out_resized_levels, segmentation_classes, pre_resized, steps)
train_gen = imgs_generator(*train_arrays, *generator_args)
val_gen = imgs_generator(*val_arrays, *generator_args)

### Build model

In [0]:
# Instantiate the network from the parameters cell and build the Keras model.
atrous_unet = AtrousUnet(
    img_shape, filters, segmentation_classes, steps, out_resized_levels, kernel_sizes,
    use_max_pooling=use_max_pooling,
    use_depthwise=use_depthwise,
    pre_resized=pre_resized,
    use_regularizers=False,
)
atrous_unet.build_model()

n_params = atrous_unet.model.count_params()
print('number of parameters:', n_params)

### Compile model

In [0]:
optimizer = Adam(0.0001)  # learning rate 1e-4
loss, metrics, loss_weights = get_loss_metrics_weights(out_resized_levels, use_dice=True, loss_factor=10.)

atrous_unet.model.compile(optimizer=optimizer, loss=loss,
                          metrics=metrics, loss_weights=loss_weights)

# The checkpoint and the history pickle are both written into wdir/logs, so the
# directory must exist locally as well as on Colab; exist_ok keeps re-runs from
# failing once it is there.
os.makedirs(os.path.join(wdir, 'logs'), exist_ok=True)

### Train model

In [0]:
# No class re-weighting; an alternative would be
# get_class_weight(segmentation_classes, background_ratio=1/1).
class_weight = None

# Checkpoint the model into wdir/logs only when the monitored metric improves
# (Keras monitors val_loss by default).
weights_path = wdir + '/logs/weights.hdf5'
checkpointer = ModelCheckpoint(
    filepath=weights_path,
    save_best_only=True,
    verbose=1,
)
callbacks = [checkpointer]

In [0]:
# Layer-by-layer overview (output shapes, parameter counts) of the compiled model.
atrous_unet.model.summary()

In [0]:
# Train on the generators built above. The number of epochs comes from the
# `epochs` parameter in the parameters cell rather than a literal here.
history = atrous_unet.model.fit_generator(train_gen,
                                          steps_per_epoch=steps_per_epoch, epochs=epochs,
                                          verbose=1, validation_data=val_gen, validation_steps=val_steps,
                                          class_weight=class_weight, callbacks=callbacks)

# Persist the per-epoch metrics dict so it can be downloaded / plotted later.
with open(wdir+'/logs/history.pkl', 'wb') as handle:
    pickle.dump(history.history, handle, protocol=pickle.HIGHEST_PROTOCOL)

### Download results (weights and history)

In [0]:
# Colab only: uncomment to download the checkpoint and the training history
# (`files` is imported from google.colab only when colab is True).
# files.download(wdir+'/logs/weights.hdf5')
# files.download(wdir+'/logs/history.pkl')

### Load weights

In [0]:
# Rebuild the network with exactly the constructor arguments used for training
# so the layer structure matches the checkpoint, then load the saved weights.
atrous_unet = AtrousUnet(img_shape, filters, segmentation_classes, steps, out_resized_levels,
                         kernel_sizes, use_max_pooling=use_max_pooling, use_depthwise=use_depthwise,
                         pre_resized=pre_resized, use_regularizers=False)

atrous_unet.build_model()
atrous_unet.model.load_weights(wdir + '/logs/weights.hdf5')

### Plot results

In [0]:
def get_images_from_predictions(preds, cmap_dict=None):
  """Turn per-pixel class scores into RGB colour images.

  Args:
    preds: array whose last axis holds per-class scores (or one-hot labels),
      presumably shaped (batch, height, width, classes); the arg-max over the
      last axis picks the class of each pixel.
    cmap_dict: optional mapping class index -> RGB triple. When omitted,
      get_cmap_dict(reversed=True) from the helper notebook is used.

  Returns:
    float64 array of shape preds.shape[:-1] + (3,) holding, at every pixel,
    the colour of that pixel's arg-max class.

  Raises:
    KeyError: if a predicted class index has no entry in cmap_dict.
  """
  if cmap_dict is None:
    cmap_dict = get_cmap_dict(reversed=True)
  class_map = np.argmax(preds, axis=-1)
  imgs = np.zeros((*class_map.shape, 3))
  # Paint one class at a time with a boolean mask instead of visiting every
  # pixel in Python; only classes that actually occur are looked up.
  for class_idx in np.unique(class_map):
    imgs[class_map == class_idx] = cmap_dict[class_idx]
  return imgs

In [0]:
# Draw one validation batch and predict on it.
imgs, labels = next(val_gen)
predictions = atrous_unet.model.predict_on_batch(imgs)

# If the first element is itself 4-D, predictions is presumably a list of
# per-output batches and the last entry is the full-resolution output;
# otherwise it is already a single batch array.
has_multiple_outputs = len(predictions[0].shape) > 3
final_output = predictions[-1] if has_multiple_outputs else predictions

pred_labels = get_images_from_predictions(final_output)
real_labels = get_images_from_predictions(labels['prediction'])

In [0]:
# Highest class index predicted anywhere in the batch — a quick check that the
# model predicts more than class 0 (presumably background). Uses the same
# output selection as the prediction cell, so a single-output model is checked
# over the whole batch rather than only its last sample.
final_output = predictions[-1] if len(predictions[0].shape) > 3 else predictions
np.max(np.argmax(final_output, axis=-1))

In [0]:
# Show each validation sample next to its ground-truth and predicted masks,
# with titled panels so the three views can be told apart.
panel_titles = ('image', 'ground truth', 'prediction')
for img, real_label, pred_label in zip(imgs, real_labels, pred_labels):
  f, axes = plt.subplots(1, 3)
  for ax, panel, title in zip(axes, (img, real_label, pred_label), panel_titles):
    ax.imshow(panel)
    ax.set_title(title)
    ax.axis('off')
  plt.show()