[View in Colaboratory](https://colab.research.google.com/github/dkatsios/semantic_segmentation/blob/master/AtrusUnet_2.ipynb)

In [0]:
# # https://pypi.python.org/pypi/pydot
# !pip install pydot
# !pip install graphviz
# !apt-get install graphviz
# import pydot

In [0]:
import keras.backend as K
import numpy as np
import tensorflow as tf
from keras.layers import Conv2D, Input, Concatenate, Deconv2D, Lambda, ZeroPadding2D, SeparableConv2D, BatchNormalization, MaxPooling2D, Dropout
from keras.models import Model
from keras.regularizers import l1, l2, l1_l2

try:
  # Keras 1.x helper names; Keras 2 dropped them and expects a plain
  # l1 / l2 / l1_l2 regularizer to be passed as `activity_regularizer`.
  from keras.regularizers import activity_l1, activity_l2, activity_l1l2
except ImportError:
  activity_l1, activity_l2, activity_l1l2 = l1, l2, l1_l2
np.random.seed(0)

# from IPython.display import SVG
# from keras.utils.vis_utils import model_to_dot

In [0]:
class AtrousUnet:
  """
  U-Net style encoder/decoder whose blocks run several kernel sizes
  (and, on the way down, several dilation rates) in parallel.

  Arguments:
    img_shape: (rows, cols, channels) of the input image, without batch axis.
    filters: base width; parallel branches use filters // 2 channels and the
      merged tensor of every block has 2 * filters channels.
    out_channels: the softmax outputs have out_channels + 1 channels.
    steps: number of downsampling (and upsampling) blocks.
    out_levels: how many of the lower-resolution softmax outputs are kept as
      extra model outputs next to the full-resolution one.
    kernel_sizes: kernel sizes of the parallel branches, default [2, 3, 5, 7].
    dilation_rates: dilation rates of the encoder branches, default [1, 2, 3].
    use_depthwise: use SeparableConv2D instead of Conv2D for convolutions.
    dropout_rate: rate of the Dropout layer closing every block.

  After build_model() the Keras model is available as `self.model`.
  """

  def __init__(self, img_shape, filters, out_channels,
               steps, out_levels, kernel_sizes=None, dilation_rates=None, use_depthwise=False, dropout_rate=0.5):
    self.img_shape = img_shape
    self.filters = filters
    self.out_channels = out_channels
    self.steps = steps
    self.out_levels = out_levels
    self.dropout_rate = dropout_rate
    self.kernel_sizes = kernel_sizes if kernel_sizes is not None else [2, 3, 5, 7]
    self.dilation_rates = dilation_rates if dilation_rates is not None else [1, 2, 3]
    self.conv = SeparableConv2D if use_depthwise else Conv2D
    self.deconv = Deconv2D  # SeparableConvolution2D if use_depthwise else Deconv2D

    # Every step halves the spatial size, so the smaller image side bounds `steps`.
    assert 2 ** self.steps <= np.min(img_shape[:2]) and \
        self.out_levels <= self.steps, \
        'need 2 ** steps <= min(img_shape[:2]) and out_levels <= steps'

  @staticmethod
  def _regularizer_kwargs(layer_cls):
    """
    Regularizer keyword arguments accepted by `layer_cls`.

    Only SeparableConv2D understands depthwise_/pointwise_regularizer; the
    other convolution layers raise TypeError on them and take
    kernel_regularizer instead. Fresh regularizer objects are created on
    every call.
    """
    kwargs = {'activity_regularizer': l1_l2()}
    if layer_cls is SeparableConv2D:
      kwargs['depthwise_regularizer'] = l2()
      kwargs['pointwise_regularizer'] = l2()
    else:
      kwargs['kernel_regularizer'] = l2()
    return kwargs

  @staticmethod
  def _concat(tensors):
    """Concatenate on the channel axis; a single tensor is passed through
    because Concatenate refuses a one-element list."""
    tensors = list(tensors)
    if len(tensors) == 1:
      return tensors[0]
    return Concatenate()(tensors)

  @staticmethod
  def _split_delta(delta):
    """Split a signed size difference into (before, after); for odd values
    the 'after' side receives the extra unit."""
    before = delta // 2
    return before, delta - before

  @classmethod
  def _crop_and_pad(cls, d_rows, d_cols):
    """
    Turn the (target - actual) size differences of both spatial axes into
    ((top, bottom), (left, right)) crop amounts and pad amounts.

    A negative difference yields cropping, a positive one yields padding, so
    the two axes may need different treatment.
    """
    splits = [cls._split_delta(d) for d in (d_rows, d_cols)]
    # A stride-2 'same' deconvolution can overshoot by at most one pixel per side.
    assert all(abs(v) <= 1 for pair in splits for v in pair), \
        'shape mismatch larger than one pixel per side: %r' % (splits,)
    crop = tuple((max(-before, 0), max(-after, 0)) for before, after in splits)
    pad = tuple((max(before, 0), max(after, 0)) for before, after in splits)
    return crop, pad

  def resize_img(self, input_tensor):
    """Resize input_tensor[0] to the spatial size of input_tensor[1]."""
    image, reference = input_tensor
    # `resize_images` is the TF 1.x name; TF 2.x only ships `resize`.
    resize = getattr(tf.image, 'resize_images', None) or tf.image.resize
    return resize(image, reference.shape[1:-1])

  def check_shape(self, x):
    """
    Crop and/or zero-pad `x` so that its spatial size equals
    `self.current_shape` (set by up_sampling_block).
    """
    x_shape = K.int_shape(x)[1:-1]
    if x_shape == self.current_shape:
      return x

    crop, pad = self._crop_and_pad(self.current_shape[0] - x_shape[0],
                                   self.current_shape[1] - x_shape[1])

    (top, bottom), (left, right) = crop
    if top or bottom or left or right:
      row_end = x_shape[0] - bottom
      col_end = x_shape[1] - right
      x = Lambda(lambda t: t[:, top:row_end, left:col_end, :])(x)

    if any(amount for pair in pad for amount in pair):
      x = ZeroPadding2D(pad)(x)
    return x

  def down_sampling_block(self, input_img, input_tensor, ind):
    """
    The downsampling block takes as input the original image (input_img)
    and the output of the previous downsampling block (input_tensor).
    `ind` is accepted for the caller's bookkeeping and is not used.
    Steps:
      - for each kernel size and each dilation rate: convolution of
        'input_tensor' with strides 1 and filters // 2 channels
      - concatenation of all these branches
      - downsampling (convolution 3x3, strides 2) to
        2 * filters - image channels
      - batch normalization and dropout
      - resize of input_img to the spatial size of the downsampled tensor
      - concatenation of both, giving 2 * filters channels in total.
    """
    branches = []
    for k in self.kernel_sizes:
      for rate in self.dilation_rates:
        branch = self.conv(self.filters // 2, (k, k), strides=(1, 1), activation='relu',
                           padding='same', dilation_rate=(rate, rate),
                           **self._regularizer_kwargs(self.conv))(input_tensor)
        branches.append(branch)
    concatenated = self._concat(branches)

    # Leave room for the image channels so that the merge has 2 * filters channels.
    img_channels = K.int_shape(input_img)[-1]
    down_sampled = self.conv(2 * self.filters - img_channels,
                             (3, 3), strides=(2, 2), activation='relu',
                             padding='same',
                             **self._regularizer_kwargs(self.conv))(concatenated)
    down_sampled = BatchNormalization()(down_sampled)
    down_sampled = Dropout(self.dropout_rate)(down_sampled)

    resized_img = Lambda(self.resize_img)([input_img, down_sampled])
    merged = Concatenate()([down_sampled, resized_img])

    return merged

  def up_sampling_block(self, down, same, ind=None):
    """
    The upsampling block takes as input the output of the previous upsampling block (down)
    and the output of the downsampling block of the same level (same).
    `ind` is accepted for the caller's bookkeeping and is not used.
    Steps:
      - for each kernel size: upsampling (deconvolution, strides 2) of 'down',
        cropped/padded to the spatial shape of 'same'
      - concatenation of these with 'same'
      - 1x1 convolution to 2 * filters channels, then batch normalization
      - 1x1 softmax convolution to out_channels + 1 channels (side output
        for an intermediate loss)
      - dropout on the feature tensor.
    Returns [features, side_output].
    """
    self.current_shape = K.int_shape(same)[1:-1]

    upsampled = []
    for k in self.kernel_sizes:
      x = self.deconv(self.filters // 2, (k, k), activation='relu',
                      strides=(2, 2), padding='same',
                      **self._regularizer_kwargs(self.deconv))(down)
      upsampled.append(self.check_shape(x))

    concatenated = Concatenate()(upsampled + [same])
    up_sampled = self.conv(2 * self.filters, (1, 1), activation='relu',
                           padding='same',
                           **self._regularizer_kwargs(self.conv))(concatenated)
    up_sampled = BatchNormalization()(up_sampled)

    # The side output is taken before dropout.
    out_sampled = self.conv(self.out_channels + 1, (1, 1), activation='softmax',
                            padding='same',
                            **self._regularizer_kwargs(self.conv))(up_sampled)

    up_sampled = Dropout(self.dropout_rate)(up_sampled)

    return [up_sampled, out_sampled]

  def build_model(self):
    """
    The model has the downsample phase and the upsample phase.
    The downsample phase has the original image and n steps of the outputs of the downsampling block.
    The upsample phase has n steps of the outputs of the upsampling block and the original image.
    The model outputs are the last `out_levels` lower-resolution results
    followed by the full-resolution softmax named 'out_img'.
    The built model is stored in `self.model`.
    """
    input_img = Input(self.img_shape)
    down_sampled = [input_img]
    for i in range(self.steps):
      down_sampled.append(self.down_sampling_block(input_img,
                                                   down_sampled[i], i))

    # NOTE(review): the list is seeded with the bottleneck tensor, so
    # out_levels == steps exposes raw bottleneck features as a model output;
    # confirm whether that is intended.
    up_sampled = [down_sampled[-1]]

    features = down_sampled[-1]
    for i in range(2, self.steps + 1):
      features, side_output = self.up_sampling_block(features,
                                                     down_sampled[-i], self.steps - i)
      up_sampled.append(side_output)

    # Last step back to image resolution; its side output is not used.
    features, _ = self.up_sampling_block(features, input_img)

    # A slice starting at -0 would keep everything, hence the explicit case.
    up_sampled = [] if self.out_levels == 0 else up_sampled[-self.out_levels:]

    features = BatchNormalization()(features)
    out_img = self.conv(self.out_channels + 1, (1, 1), activation='softmax',
                        name='out_img')(features)
    up_sampled.append(out_img)
    model = Model(input_img, up_sampled)

    # The final output must have the spatial size of the input image.
    assert tuple(self.img_shape[:-1]) == tuple(K.int_shape(out_img)[1:-1])

    self.model = model

In [0]:
if __name__ == '__main__':
  # Demo: build the network for 256x256 RGB images and report its size.
  img_shape = (256, 256, 3)
  filters = 64
  out_channels = 20
  steps = 5
  out_resized_levels = 2
  kernel_sizes = [2, 3, 5]

  atrous_unet = AtrousUnet(img_shape=img_shape,
                           filters=filters,
                           out_channels=out_channels,
                           steps=steps,
                           out_levels=out_resized_levels,
                           kernel_sizes=kernel_sizes)
  atrous_unet.build_model()

  # One loss per model output: the kept lower-resolution outputs plus the
  # full-resolution one (binary_crossentropy was the noted alternative).
  losses = ['mse' for _ in range(out_resized_levels + 1)]
  atrous_unet.model.compile(optimizer='Adam',
                            loss=losses,
                            metrics=['accuracy'])

  print('number of parameters:', atrous_unet.model.count_params())
  # Optional inspection helpers:
  # print(atrous_unet.model.summary())
  # SVG(model_to_dot(atrous_unet.model, show_shapes=True,
  #                  show_layer_names=False).create(prog='dot', format='svg'))