[View in Colaboratory](https://colab.research.google.com/github/dkatsios/semantic_segmentation/blob/master/AtrusUnet_2.ipynb)

In [0]:
# # https://pypi.python.org/pypi/pydot
# !pip install pydot
# !pip install graphviz
# !apt-get install graphviz
# import pydot

In [0]:
from keras.layers import Conv2D, Input, Concatenate, Deconv2D, Lambda, Flatten,\
ZeroPadding2D, SeparableConv2D, BatchNormalization, Dropout, MaxPooling2D, Dense
from keras.regularizers import l1, l2
from keras.models import Model

import keras.backend as K
import tensorflow as tf

import numpy as np
# Seed NumPy's global RNG so NumPy-based randomness is repeatable across runs.
# NOTE(review): this only seeds NumPy; TensorFlow's RNG (used by Keras weight
# initialisers and Dropout with the TF backend, presumably) is not seeded here
# -- confirm whether tf.set_random_seed is also wanted for reproducibility.
np.random.seed(0)

In [0]:
class AtrousUnet:
  """U-Net-style segmentation model built from banks of parallel (atrous) convolutions.

  The encoder applies `steps` down-sampling blocks. Each block runs one
  convolution per (kernel size, dilation rate) pair, halves the spatial size
  and concatenates a copy of the input image resized to that size. The
  decoder mirrors it with `steps` up-sampling blocks that use skip
  connections to the encoder outputs and emit a softmax map with
  `out_channels + 1` channels at every level. The model outputs the
  full-resolution 'prediction' map plus the `out_resized_levels` coarser maps
  that precede it. When `classify` is True it also outputs a 'labels' vector.
  """

  def __init__(self, img_shape, filters, out_channels,
               steps, out_resized_levels, kernel_sizes=None, dilation_rates=None, use_depthwise=False,
               use_max_pooling=True, use_regularizers=True, pre_resized=False, classify=False, dropout_rate=0.4):
    """Store the architecture hyper-parameters (the graph is built by build_model).

    Args:
      img_shape: (rows, cols, channels) of the full-resolution input image.
      filters: base filter count. Conv-bank layers use filters // 2 and block
        outputs use 2 * filters.
      out_channels: number of output classes. Softmax heads have
        out_channels + 1 channels.
      steps: number of down-/up-sampling blocks. 2 ** steps must not exceed
        the smaller image side.
      out_resized_levels: how many coarser softmax maps to output in addition
        to the full-resolution 'prediction'. Must be <= steps.
      kernel_sizes: kernel sizes of the parallel conv banks (default [2, 3, 5, 7]).
      dilation_rates: dilation rates of the encoder conv bank (default [1, 2, 3]).
      use_depthwise: use SeparableConv2D instead of Conv2D in the encoder.
      use_max_pooling: down-sample with a 1x1 conv + max pooling instead of a
        strided 3x3 conv.
      use_regularizers: attach l2 kernel / l1 activity regularizers to the
        1x1 / strided down-sampling convolutions.
      pre_resized: take pre-resized copies of the image as extra model inputs
        instead of resizing the image inside the graph.
      classify: add a 'labels' output (sigmoid vector of out_channels + 1).
      dropout_rate: dropout rate applied after every block.

    Raises:
      AssertionError: if 2 ** steps > min(rows, cols) or out_resized_levels > steps.
    """
    # Validate before touching any Keras object so bad configs fail fast.
    assert 2 ** steps <= np.min(img_shape[:2]) and out_resized_levels <= steps, \
        'need 2 ** steps <= min(image rows, cols) and out_resized_levels <= steps'

    self.img_shape = img_shape
    self.filters = filters
    self.out_channels = out_channels
    self.steps = steps
    self.out_resized_levels = out_resized_levels
    self.dropout_rate = dropout_rate
    self.use_max_pooling = use_max_pooling
    self.use_regularizers = use_regularizers
    self.pre_resized = pre_resized
    self.classify = classify
    self.kernel_sizes = kernel_sizes if kernel_sizes is not None else [2, 3, 5, 7]
    self.dilation_rates = dilation_rates if dilation_rates is not None else [1, 2, 3]
    self.kern_reg, self.act_reg = (l2(), l1()) if self.use_regularizers else (None, None)
    self.conv = SeparableConv2D if use_depthwise else Conv2D
    self.deconv = Deconv2D  # SeparableConvolution2D if use_depthwise else Deconv2D
    # Counter used to give every cropping Lambda layer a unique name.
    self._crop_count = 0

  def resize_img(self, input_tensor):
    """Resize input_tensor[0] to the spatial size of input_tensor[1] (used inside a Lambda)."""
    return tf.image.resize_images(input_tensor[0],
                                  input_tensor[1].shape[1:-1])

  @staticmethod
  def _split_diff(diff):
    """Split a size difference into (before, after), putting the odd unit after."""
    before = diff // 2
    return before, diff - before

  def check_shape(self, x):
    """Zero-pad and/or crop `x` so its spatial size equals self.current_shape.

    Callers set self.current_shape before calling. Along each axis the size
    difference is split as evenly as possible, with the odd pixel going to
    the bottom/right. Differences of at most one pixel per side are expected;
    they arise when odd sizes are halved and then doubled.
    """
    x_shape = K.int_shape(x)[1:-1]
    if tuple(x_shape) == tuple(self.current_shape):
      return x

    r_pad = self._split_diff(self.current_shape[0] - x_shape[0])
    c_pad = self._split_diff(self.current_shape[1] - x_shape[1])
    # Element-wise check: every pad/crop amount must be at most one pixel.
    assert (np.abs(np.array([r_pad, c_pad])) <= 1).all(), \
        'shape mismatch too large: %s vs %s' % (x_shape, self.current_shape)

    # Negative amounts: `x` is too large on that side, so crop it.
    r_crop = [max(-p, 0) for p in r_pad]
    c_crop = [max(-p, 0) for p in c_pad]
    if any(r_crop) or any(c_crop):
      r0, r1 = r_crop[0], x_shape[0] - r_crop[1]
      c0, c1 = c_crop[0], x_shape[1] - c_crop[1]
      # Several same-sized tensors can be cropped in one model (one per deconv
      # kernel), so the layer name needs a unique suffix.
      self._crop_count += 1
      x = Lambda(lambda t: t[:, r0:r1, c0:c1, :],
                 name='sh_lambda_%d_%d' % (x_shape[0], self._crop_count))(x)

    # Positive amounts: `x` is too small on that side, so zero-pad it.
    r_zero = tuple(max(p, 0) for p in r_pad)
    c_zero = tuple(max(p, 0) for p in c_pad)
    if any(r_zero) or any(c_zero):
      x = ZeroPadding2D((r_zero, c_zero))(x)
    return x

  def down_sampling_block(self, input_img, input_tensor, ind):
    """Encoder block: parallel dilated convs, 2x down-sampling, image injection.

    Args:
      input_img: image tensor for this level. This is the full-resolution
        image (resized in-graph), or the pre-resized input when pre_resized
        is True.
      input_tensor: output of the previous encoder block, or the image itself
        for the first block.
      ind: block index, used in layer names.

    Steps:
      - for every (kernel size k, dilation rate d): conv with filters // 2
        filters, k x k kernel, stride 1, dilation d, 'same' padding, ReLU
      - concatenate all of these maps
      - down-sample: 1x1 conv to 2 * filters followed by max pooling
        (use_max_pooling), or a 3x3 stride-2 conv to 2 * filters
      - dropout
      - bring the image to the down-sampled spatial size and concatenate it,
        giving 2 * filters + image-channels output channels
    """
    kern_reg, act_reg = self.kern_reg, self.act_reg
    dilated = []
    for k in self.kernel_sizes:
      for d in self.dilation_rates:
        name = 'ds_%d_conv_kern_%d_dil_%d' % (ind, k, d)
        x = self.conv(self.filters // 2, (k, k), strides=(1, 1), activation='relu',
#                       depthwise_regularizer=kern_reg, pointwise_regularizer=kern_reg, activity_regularizer=act_reg,
                      padding='same', dilation_rate=(d, d),
                      name=name)(input_tensor)
        dilated.append(x)

    concatenated = Concatenate(name='ds_%d_conc' % ind)(dilated)
    if self.use_max_pooling:
      shortened = Conv2D(2 * self.filters, (1, 1), activation='relu',
                         kernel_regularizer=kern_reg, activity_regularizer=act_reg,
                         padding='same', name='ds_%d_conv_short' % ind)(concatenated)
      down_sampled = MaxPooling2D(name='ds_%d_maxpool' % ind)(shortened)
    else:
      # Conv2D has no depthwise/pointwise regularizer arguments, so choose the
      # keyword set that matches the layer class actually in use.
      if self.conv is SeparableConv2D:
        reg_kwargs = dict(depthwise_regularizer=kern_reg, pointwise_regularizer=kern_reg)
      else:
        reg_kwargs = dict(kernel_regularizer=kern_reg)
      down_sampled = self.conv(2 * self.filters, (3, 3), strides=(2, 2), activation='relu',
                               activity_regularizer=act_reg,
                               padding='same',
                               name='ds_%d_conv_downsamp' % ind,
                               **reg_kwargs)(concatenated)

#     down_sampled = BatchNormalization()(down_sampled)
    down_sampled = Dropout(self.dropout_rate, name='ds_%d_dropout' % ind)(down_sampled)

    if self.pre_resized:
      # The pre-resized image may be off by one pixel; pad/crop it to match.
      self.current_shape = K.int_shape(down_sampled)[1:-1]
      resized_img = self.check_shape(input_img)
    else:
      resized_img = Lambda(self.resize_img, name='ds_%d_lambda' % ind)([input_img, down_sampled])
    merged = Concatenate(name='ds_%d_conc_merged' % ind)([down_sampled, resized_img])
    return merged

  def up_sampling_block(self, down, same, ind):
    """Decoder block: 2x up-sampling, skip connection, 1x1 fusion, softmax head.

    Args:
      down: output of the previous (coarser) decoder block, or of the last
        encoder block for the first decoder step.
      same: encoder output with the target spatial size (skip connection).
      ind: block index (build_model passes 2 .. steps + 1). When
        ind == steps + 1 the softmax head is named 'prediction'; otherwise it
        is named 'resized_<steps + 1 - ind>'.

    Steps:
      - for every kernel size k (k == 1 is skipped unless it is the only
        one): stride-2 transposed conv of `down` to filters // 2 channels,
        then pad/crop to the spatial size of `same`
      - concatenate those maps with `same`
      - 1x1 conv to 2 * filters (ReLU)
      - softmax head: 1x1 conv to out_channels + 1 channels
      - dropout on the fused features (the head is computed before dropout)

    Returns:
      [up_sampled, out_sampled]: the fused features for the next block and
      the softmax output of this level.
    """
    self.current_shape = K.int_shape(same)[1:-1]
    kern_reg, act_reg = self.kern_reg, self.act_reg
    upsampled = []
    out_name = 'prediction' if ind == (self.steps + 1) else 'resized_%d' % (self.steps + 1 - ind)

    for k in self.kernel_sizes:
      if k == 1 and len(self.kernel_sizes) != 1:
        continue

      x = self.deconv(self.filters // 2, (k, k), activation='relu',
#                       kernel_regularizer=kern_reg, activity_regularizer=act_reg,
                      strides=(2, 2), padding='same', name='us_%d_deconv_kern_%d' % (ind, k))(down)
      x = self.check_shape(x)
      upsampled.append(x)

    concatenated = Concatenate(name='us_%d_conc' % ind)([*upsampled, same])
    up_sampled = Conv2D(2 * self.filters, (1, 1), activation='relu',
                        kernel_regularizer=kern_reg, activity_regularizer=act_reg,
                        padding='same', name='us_%d_conv' % ind)(concatenated)

#     up_sampled = BatchNormalization()(up_sampled)

    out_sampled = Conv2D(self.out_channels + 1, (1, 1), activation='softmax',
                         padding='same', name=out_name)(up_sampled)

    up_sampled = Dropout(self.dropout_rate, name='us_%d_dropout' % ind)(up_sampled)

    return [up_sampled, out_sampled]

#   def build_toy_model(self):
#     input_img = Input(self.img_shape)
#     x = Conv2D(self.filters, (3, 3), activation='relu', padding='same')(input_img)
#     x = Conv2D(self.filters, (5, 5), activation='relu', padding='same')(x)
#     x = Conv2D(self.filters, (7, 7), activation='relu', padding='same')(x)
#     out = Conv2D(self.out_channels + 1, (1, 1), activation='softmax', padding='same', name='prediction')(x)
#     self.model = Model(input_img, out)

  def get_encoding(self, input, encoding_size):
    """Vector head: stride-2 conv, flatten, then a sigmoid Dense of `encoding_size` units.

    The Dense layer is named 'labels' when encoding_size == out_channels + 1,
    and 'encoding_<encoding_size>' otherwise.
    """
    x = self.conv(self.filters // 2, (3, 3), strides=(2, 2),
                  padding='same', name='encoding_%d_conv' % encoding_size, activation='relu')(input)
    # Flatten is a layer: construct it first, then call it on the tensor.
    x = Flatten(name='encoding_%d_flatten' % encoding_size)(x)
    name = 'labels' if encoding_size == (self.out_channels + 1) else 'encoding_%d' % encoding_size
    encoding = Dense(encoding_size, activation='sigmoid', name=name)(x)
    return encoding

  def build_model(self):
    """Build the Keras model and store it in self.model.

    The down-sampling stage starts from the original image and applies
    `steps` encoder blocks. The up-sampling stage applies `steps` decoder
    blocks back to the original resolution. The model outputs are:
      - the last `out_resized_levels + 1` softmax maps (coarsest first,
        ending with the full-resolution 'prediction')
      - a 'labels' vector, when classify is True
    A single output is passed to Model unwrapped.
    """
    input_img = [Input(self.img_shape, name='or_image')]

    # downsampling stage
    down_sampled = [input_img[-1]]
    for i in range(self.steps):
      if self.pre_resized:
        scale = 2 ** (i + 1)
        resized_shape = self.img_shape[0] // scale, self.img_shape[1] // scale, self.img_shape[2]
        input_img.append(Input(resized_shape, name='resized_image_%d' % (i + 1)))

      down_sampled.append(self.down_sampling_block(input_img[-1], down_sampled[i], i))

    # upsampling stage: each decoder step is paired with the encoder output
    # of the same resolution (down_sampled[-i]).
    up_sampled = []  # per-level softmax outputs, coarsest first
    up_tensor = down_sampled[-1]
    for i in range(2, self.steps + 2):
      up_tensor, out_sampled = self.up_sampling_block(up_tensor, down_sampled[-i], i)
      up_sampled.append(out_sampled)

    # keep the full-resolution prediction plus the requested coarser maps
    up_sampled = up_sampled[-(self.out_resized_levels + 1):]

    if self.classify:
      labels = self.get_encoding(down_sampled[-1], self.out_channels + 1)
      up_sampled.append(labels)

    if len(up_sampled) == 1:
      up_sampled = up_sampled[0]

    # model
    model = Model(input_img, up_sampled)
    # Look the full-resolution output up by name: model.layers[-1] is not
    # guaranteed to be it (e.g. it may be the 'labels' layer when classify=True).
    prediction_shape = model.get_layer('prediction').output_shape[1:-1]
    assert tuple(self.img_shape[:-1]) == tuple(prediction_shape)

    self.model = model
#     self.build_toy_model()

In [0]:
if __name__ == '__main__':
  # Demo configuration: 256x256 RGB input, 20 output classes (the model's
  # softmax heads add one extra channel), 5 encoder/decoder levels.
  img_shape, filters, out_channels = (256, 256, 3), 64, 20
  steps, out_resized_levels = 5, 2
  kernel_sizes = [2, 3, 5]

  atrous_unet = AtrousUnet(img_shape, filters, out_channels, steps,
                           out_resized_levels, kernel_sizes=kernel_sizes)
  atrous_unet.build_model()

  # One loss per model output: the full-resolution prediction plus the
  # `out_resized_levels` coarser maps.
  losses = ['mse' for _ in range(out_resized_levels + 1)]  # binary_crossentropy
  atrous_unet.model.compile(optimizer='Adam', loss=losses, metrics=['accuracy'])

  print('number of parameters:', atrous_unet.model.count_params())
  # print(atrous_unet.model.summary())
  # SVG(model_to_dot(atrous_unet.model, show_shapes=True,
  #                  show_layer_names=False).create(prog='dot', format='svg'))