[View in Colaboratory](https://colab.research.google.com/github/dkatsios/semantic_segmentation/blob/master/voc2012.ipynb)

In [0]:
colab = True

In [0]:
import numpy as np
import os
from shutil import unpack_archive
import cv2
from matplotlib import pyplot as plt
from IPython.display import Image
import PIL
from keras.optimizers import Adam
from time import time
if colab:
  from google.colab import files
import keras
from keras.models import load_model
from keras.callbacks import ModelCheckpoint
import pickle

### Download VOC 2012 dataset

In [0]:
if colab:
#   %rm -r /content/semantic_segmentation
#   %mkdir /content/semantic_segmentation
#   %cd /content/semantic_segmentation/
#   !wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
#   unpack_archive('VOCtrainval_11-May-2012.tar', './')
#   %rm VOCtrainval_11-May-2012.tar
  wdir = '/content/semantic_segmentation'
else:
  wdir = 'E:/Data_Files/Workspaces/PyCharm/semantic_segmentation/'

In [65]:
%cd {wdir}/VOCdevkit/VOC2012
!ls

/content/semantic_segmentation/VOCdevkit/VOC2012
Annotations  ImageSets	JPEGImages  SegmentationClass  SegmentationObject


In [0]:
imgs_folder = wdir + '/VOCdevkit/VOC2012/JPEGImages/'
classes_folder = wdir + '/VOCdevkit/VOC2012/SegmentationClass/'
train_list_path = wdir + '/VOCdevkit/VOC2012/ImageSets/Segmentation/train.txt'
val_list_path = wdir + '/VOCdevkit/VOC2012/ImageSets/Segmentation/val.txt'

### Import helpers and Model files

In [67]:
if colab:
    %cd {wdir}
    %rm -r {wdir}/semantic_segmentation/
    !git clone https://github.com/dkatsios/semantic_segmentation.git
    %cd {wdir}/semantic_segmentation
else:
    %cd {wdir}
%ls

/content/semantic_segmentation
Cloning into 'semantic_segmentation'...
remote: Counting objects: 63, done.[K
remote: Compressing objects: 100% (62/62), done.[K
remote: Total 63 (delta 35), reused 0 (delta 0), pack-reused 0[K
Unpacking objects: 100% (63/63), done.
/content/semantic_segmentation/semantic_segmentation
AtrusUnet_2.ipynb  README.md              voc2012.ipynb
AtrusUnet.ipynb    voc2012_helpers.ipynb


In [68]:
!pip install import_ipynb
import import_ipynb
from voc2012_helpers import *
# from AtrusUnet import  *



In [0]:
from keras.layers import Conv2D, Input, Concatenate, Deconv2D, Lambda, \
ZeroPadding2D, SeparableConv2D, BatchNormalization, Dropout, MaxPooling2D
from keras.regularizers import l1, l2
from keras.models import Model

import keras.backend as K
import tensorflow as tf

import numpy as np
np.random.seed(0)

In [0]:
class AtrousUnet:
  def __init__(self, img_shape, filters, out_channels,
               steps, out_resized_levels, kernel_sizes=None, dilation_rates=None, use_depthwise=False,
               use_max_pooling=True, use_regularizers=True, pre_resized=False, dropout_rate=0.4):
    self.img_shape = img_shape
    self.filters = filters
    self.out_channels = out_channels
    self.steps = steps
    self.out_resized_levels = out_resized_levels
    self.dropout_rate = dropout_rate
    self.use_max_pooling = use_max_pooling
    self.use_regularizers = use_regularizers
    self.pre_resized = pre_resized
    self.kernel_sizes = kernel_sizes if kernel_sizes is not None else [2, 3, 5, 7]
    self.dilation_rates = dilation_rates if dilation_rates is not None else [1, 2, 3]
    self.kern_reg, self.act_reg = (l2(), l1()) if self.use_regularizers else (None, None)
    self.conv = SeparableConv2D if use_depthwise else Conv2D
    self.deconv = Deconv2D  # SeparableConvolution2D if use_depthwise else Deconv2D 
    
    assert 2 ** self.steps <= np.min(img_shape[:2]) and \
    self.out_resized_levels <= self.steps
    
  def resize_img(self, input_tensor):
    return tf.image.resize_images(input_tensor[0],
                                  input_tensor[1].shape[1:-1])
  
  def check_shape(self, x):
    x_shape = K.int_shape(x)[1:-1]
    if x_shape == self.current_shape:
      return x
    
    dr = self.current_shape[0] - x_shape[0]
    if dr == 0:
      r_pad = 0, 0
    elif dr % 2 == 0:
      r_pad = dr // 2, dr // 2
    else:
      r_pad = dr // 2, dr // 2 + 1
    
    dc = self.current_shape[1] - x_shape[1]
    if dc == 0:
      c_pad = 0, 0
    elif dc % 2 == 0:
      c_pad = dc // 2, dc // 2
    else:
      c_pad = dc // 2, dc // 2 + 1
      
    assert np.abs(np.array([r_pad, c_pad])).all() <= 1
    if dr < 0 or dc < 0:
      r = -r_pad[0], x_shape[0] + r_pad[1]
      c = -c_pad[0], x_shape[1] + c_pad[1]
      x = Lambda(lambda x: x[:,r[0]:r[1],c[0]:c[1],:])(x)
    else:
      x = ZeroPadding2D((r_pad, c_pad))(x)
    return x
  
  def down_sampling_block(self, input_img, input_tensor, ind):
    """
    The downsampling block takes as input the original image (input_img)
    and the output of the previous downsampling block (input_tensor).
    Steps:
      - downsampling (kernel 2x2, strides 2) of 'input_tensor' to half its shape to 'down_sampled'
      - for each kernel size:
        - convolution with this kernel size, strides 1 and dilation 2
        - downsampling (convolution) with kernel 2x2 and strides 2
      - concatenation of the last one with 'down_sampled' to 'down_sampled'
      - resize of input_img to the same size as 'down_sampled' to 'resized_img'
      - convolution of 'down_sampled'  with kernel 1x1
      - concatenation of 'down_sampled' with 'resized_img' to total channels 2*filters
      - batch normalization.
    """
    ###############
    kern_reg, act_reg = self.kern_reg, self.act_reg
    dilated = []
    for k in self.kernel_sizes:
      for d in self.dilation_rates:
        x = self.conv(self.filters // 2, (k, k), strides=(1, 1), activation='relu',
#                       depthwise_regularizer=kern_reg, pointwise_regularizer=kern_reg, activity_regularizer=act_reg,
                      padding='same', dilation_rate=(d, d))(input_tensor)
        dilated.append(x)
    ###############        
    concatenated = Concatenate()([*dilated])
    if self.use_max_pooling:
      shortened = Conv2D(2 * self.filters, (1, 1), activation='relu',
                         kernel_regularizer=kern_reg, activity_regularizer=act_reg,
                         padding='same')(concatenated)
      down_sampled = MaxPooling2D()(shortened)
    else:
      down_sampled = self.conv(2 * self.filters, (3, 3), strides=(2, 2), activation='relu',
                               depthwise_regularizer=kern_reg, pointwise_regularizer=kern_reg,
                               activity_regularizer=act_reg,
                               padding='same')(concatenated)
    
#     down_sampled = BatchNormalization()(down_sampled)
    down_sampled = Dropout(self.dropout_rate)(down_sampled)
    ###############
    if self.pre_resized:
      self.current_shape = K.int_shape(down_sampled)[1:-1]
      resized_img = input_img
      resized_img = self.check_shape(resized_img)
    else:
      resized_img = Lambda(self.resize_img)([input_img, down_sampled])
    merged = Concatenate()([down_sampled, resized_img])
    ###############
    return merged
  
  def up_sampling_block(self, down, same, index):
    """
    The upsampling block takes as input the output of the previous upsampling block (down)
    and the output of the downsampling block of the same level (same).
    Steps:
      - upsampling (deconvolution) of 'down' to the shape of 'same'
      - concatenation of the previous one with the 'same' to the 'concatenated'
      - convolution of the 'concatenated' to same shape and standard filters
      - batch normalization of previous
      - optionally convolves to 1x1 filters and output channels for intermediate loss.
    """
    self.current_shape = K.int_shape(same)[1:-1]
    kern_reg, act_reg = self.kern_reg, self.act_reg
    upsampled = []
    out_name = 'prediction' if index == (self.steps+1) else 'resized_%d' % (self.steps + 1 - index)
    
    for k in self.kernel_sizes:
      if k == 1 and len(self.kernel_sizes) != 1:
        continue
      
      x = self.deconv(self.filters // 2, (k, k), activation='relu',
#                       kernel_regularizer=kern_reg, activity_regularizer=act_reg,
                      strides=(2, 2), padding='same')(down)
      x = self.check_shape(x)
      upsampled.append(x)
    
    concatenated = Concatenate()([*upsampled, same])
    up_sampled = Conv2D(2 * self.filters, (1, 1), activation='relu',
                        kernel_regularizer=kern_reg, activity_regularizer=act_reg,
                        padding='same')(concatenated)
    
#     up_sampled = BatchNormalization()(up_sampled)
    
    out_sampled = Conv2D(self.out_channels + 1, (1, 1), activation='softmax',
                         padding='same', name=out_name)(up_sampled)
    
    up_sampled = Dropout(self.dropout_rate)(up_sampled)
    
    return [up_sampled, out_sampled]
  
#   def build_toy_model(self):
#     input_img = Input(self.img_shape)
#     x = Conv2D(self.filters, (3, 3), activation='relu', padding='same')(input_img)
#     x = Conv2D(self.filters, (5, 5), activation='relu', padding='same')(x)
#     x = Conv2D(self.filters, (7, 7), activation='relu', padding='same')(x)
#     out = Conv2D(self.out_channels + 1, (1, 1), activation='softmax', padding='same', name='prediction')(x)
#     self.model = Model(input_img, out)
  
  def build_model(self):
    """
    The model has the downsample stage and the upsample stage.
    The downsample stage has the original image and n steps of the downsampling block.
    The upsample stage has n steps of the upsampling block and the original image.
    The loss is computed over the last result (original image shape) and m resized results (up_sampled).
    """
    input_img = [Input(self.img_shape, name='or_image')]

    # downsampling stage
    down_sampled = [input_img[-1]]
    for i in range(self.steps):
      if self.pre_resized:
        resized_shape = self.img_shape[0] // (2 ** (i+1)), self.img_shape[1] // (2 ** (i+1)), self.img_shape[2]
        input_img.append(Input(resized_shape, name='resized_image_%d' % (i+1)))
      
      down_sampled.append(self.down_sampling_block(input_img[-1], down_sampled[i], i))

    # upsampling stage
    up_sampled = [down_sampled[-1]]
    up_results = down_sampled[-1],
    for i in range(2, self.steps+2):
      up_results = self.up_sampling_block(up_results[0],
                                          down_sampled[-i], i)
      up_sampled.append(up_results[1])
    
    up_sampled = up_sampled[-(self.out_resized_levels + 1):]
    if len(up_sampled) == 1:
      up_sampled = up_sampled[0]
      
    # model
    model = Model(input_img, up_sampled)
    assert self.img_shape[:-1] == model.layers[-1].output_shape[1:-1]
    
    self.model = model
#     self.build_toy_model()

### Set parameters

In [0]:
img_shape = 512, 512, 3
filters = 128  # 256
segmentation_classes = 21
steps = 5
out_resized_levels = 0
kernel_sizes = [2, 3, 5, 7]  # [2, 3, 5, 7]
dilation_rates = [1, 2, 3]  # [1, 2, 3]
use_max_pooling = True
use_depthwise = True
pre_resized = False

batch_size = 1
epochs = 1
val_steps = 10

### Build generators

In [0]:
# train_lists, val_lists = get_lists_from_folders(train_list_path, val_list_path, imgs_folder, classes_folder)

# # train_lists = train_lists[0][:100], train_lists[1][:100]
# # val_lists = val_lists[0][:100], val_lists[1][:100]

# train_arrays = get_imgs_classes_arrays(*train_lists, img_shape)
# val_arrays = get_imgs_classes_arrays(*val_lists, img_shape)

In [0]:
def get_pre_resized(batch_imgs, steps):
  pre_resized_imgs = {'or_image': batch_imgs}
  or_size = batch_imgs.shape[1:-1]
  for i in range(1, steps+1):
    key = 'resized_image_%d' % i
    size = or_size[0] // (2 ** i), or_size[1] // (2 ** i)
    value = np.zeros((batch_imgs.shape[0], size[0], size[1], batch_imgs.shape[3]))
    
    for batch in range(batch_imgs.shape[0]):
      value[batch] = cv2.resize(batch_imgs[batch], size)
      
    pre_resized_imgs[key] = value
#   for key, value in pre_resized_imgs.items():
#     print(key, value.shape)
  return pre_resized_imgs

In [0]:
def imgs_generator(rgb_imgs, num_classes, batch_size, out_resized_levels, segmentation_classes, pre_resized, steps):
  while True:
    inds = np.random.randint(0, rgb_imgs.shape[0], batch_size)
    batch_imgs = rgb_imgs[inds]
    if pre_resized:
      batch_imgs = get_pre_resized(batch_imgs, steps)
    batch_classes = get_resized(num_classes[inds], out_resized_levels, segmentation_classes)[::-1]
    if isinstance(batch_classes, list):
      classes_dict = {'prediction': batch_classes.pop()}
      for i in range(1, out_resized_levels + 1):
        classes_dict['resized_%d' % i] = batch_classes.pop()
    else:
      classes_dict = {'prediction': batch_classes}
    yield batch_imgs, classes_dict

In [0]:
steps_per_epoch= len(train_lists[0]) // batch_size
train_gen = imgs_generator(*train_arrays, batch_size, out_resized_levels, segmentation_classes, pre_resized, steps)
val_gen = imgs_generator(*val_arrays, batch_size, out_resized_levels, segmentation_classes, pre_resized, steps)

### Build model

In [129]:
atrous_unet = AtrousUnet(img_shape, filters, segmentation_classes, steps, out_resized_levels,
                         kernel_sizes, use_max_pooling=use_max_pooling, use_depthwise=use_depthwise,
                         pre_resized=pre_resized, use_regularizers=False)

atrous_unet.build_model()
print('number of parameters:', atrous_unet.model.count_params())

number of parameters: 9802913


### Compile model

In [130]:
optimizer = Adam(0.0001)
loss, metrics, loss_weights = get_loss_metrics_weights(out_resized_levels, use_dice=True, loss_factor=10.)

# loss_weights = [loss_factor / out_resized_levels] * out_resized_levels + [loss_factor]


atrous_unet.model.compile(optimizer=optimizer, loss=loss,
                          metrics=metrics, loss_weights=loss_weights)
if colab:
    %mkdir {wdir}/logs/

mkdir: cannot create directory ‘/content/semantic_segmentation/logs/’: File exists


### Train model

In [0]:
class_weight = None  # get_class_weight(segmentation_classes, background_ratio=1/1)
weights_path = wdir + '/logs/weights.hdf5'
checkpointer = ModelCheckpoint(filepath=weights_path, verbose=1, save_best_only=True)
callbacks = [checkpointer]

In [0]:
# preds = atrous_unet.model.predict(np.zeros((5, *img_shape)))
# for pr in preds:
#   print(pr.shape)

In [0]:
history = atrous_unet.model.fit_generator(train_gen,
                                          steps_per_epoch=steps_per_epoch, epochs=20,
                                          verbose=1, validation_data=val_gen, validation_steps=val_steps,
                                          class_weight=class_weight, callbacks=callbacks)

with open(wdir+'/logs/history.pkl', 'wb') as handle:
    pickle.dump(history.history, handle, protocol=pickle.HIGHEST_PROTOCOL)

Epoch 1/20
 230/1464 [===>..........................] - ETA: 1:14:50 - loss: -6.6185 - dice_coef: 0.6619 - categorical_accuracy: 0.6634

### Download results (weights and history)

In [0]:
# files.download(wdir+'/logs/weights.hdf5')
# files.download(wdir+'/logs/history.pkl')

### Load weights

In [0]:
atrous_unet = AtrousUnet(img_shape, filters, segmentation_classes, steps, out_resized_levels,
                         kernel_sizes,use_max_pooling=use_max_pooling, use_depthwise=use_depthwise)

atrous_unet.build_model()
atrous_unet.model.load_weights(wdir + '/logs/weights.hdf5')

### Plot results

In [0]:
def get_images_from_predictions(preds):
  cmap_dict = get_cmap_dict(reversed=True)
  preds = np.argmax(preds, axis=-1)
  imgs = np.zeros((*preds.shape, 3))
  for i, pred in enumerate(preds):
    for j in range(pred.shape[0]):
      for k in range(pred.shape[1]):
        imgs[i, j, k, :] = cmap_dict[pred[j, k]]
  
  return imgs

In [0]:
imgs, labels = val_gen.__next__()
predictions = atrous_unet.model.predict_on_batch(imgs)
pred_labels = predictions[-1] if len(predictions[0].shape) > 3 else predictions
pred_labels = get_images_from_predictions(pred_labels)
real_labels = labels['prediction']  # labels[-1] if isinstance(labels, list) and len(labels[0].shape) > 3 else labels
real_labels = get_images_from_predictions(real_labels)

In [0]:
np.max(np.argmax(predictions[-1], axis=-1))

In [0]:
for img, real_label, pred_label in zip(imgs, real_labels, pred_labels):
  f, (ax1, ax2, ax3) = plt.subplots(1, 3)
  ax1.imshow(img)
  ax1.axis('off')
  ax2.imshow(real_label)
  ax2.axis('off')
  ax3.imshow(pred_label)
  ax3.axis('off')
  plt.show()