[View in Colaboratory](https://colab.research.google.com/github/dkatsios/semantic_segmentation/blob/master/AtrusUnet.ipynb)

In [0]:
# # https://pypi.python.org/pypi/pydot
# !pip install pydot
# !pip install graphviz
# !apt-get install graphviz
# import pydot

In [0]:
from keras.layers import Conv2D, Input, Concatenate, Deconv2D, Lambda, ZeroPadding2D
from keras.models import Model

import keras.backend as K
import tensorflow as tf

import numpy as np
np.random.seed(0)

# from IPython.display import SVG
# from keras.utils.vis_utils import model_to_dot

In [0]:
class AtrousUnet:
  """U-Net style encoder/decoder whose encoder uses dilated (atrous) convolutions.

  The network is assembled by `build_model`, which stores the resulting
  Keras `Model` in `self.model`.

  Args:
    img_shape: shape of the input image; it is passed to `Input` as is and
      its first two entries are treated as the spatial size.
    filters: number of filters of every branch convolution; the merged
      feature maps carry `2 * filters` channels.
    out_channels: the softmax heads emit `out_channels + 1` channels
      (presumably one extra background class -- TODO confirm).
    steps: number of stride-2 down-sampling blocks.
    out_levels: how many of the collected intermediate tensors are returned
      as additional model outputs next to the full resolution one.
    kernel_sizes: kernel sizes of the dilated branches of each
      down-sampling block; defaults to [2, 3, 5, 7].

  Raises:
    AssertionError: if `2 ** steps` exceeds the smaller spatial dimension
      or if `out_levels` is larger than `steps`.
  """

  def __init__(self, img_shape, filters, out_channels,
               steps, out_levels, kernel_sizes=None):
    self.img_shape = img_shape
    self.filters = filters
    self.out_channels = out_channels
    self.steps = steps
    self.out_levels = out_levels
    self.kernel_sizes = kernel_sizes if kernel_sizes is not None else [2, 3, 5, 7]

    # Populated later: `current_shape` by `up_sampling_block`, `model` by
    # `build_model`. Defined here so the attributes always exist.
    self.current_shape = None
    self.model = None

    assert 2 ** self.steps <= np.min(img_shape[:2]), \
        'steps={} halves the image below one pixel for img_shape={}'.format(
            self.steps, img_shape)
    assert self.out_levels <= self.steps, \
        'out_levels={} must not exceed steps={}'.format(
            self.out_levels, self.steps)

  def resize_img(self, input_tensor):
    """Resize `input_tensor[0]` to the spatial size of `input_tensor[1]`.

    Used as the function of a `Lambda` layer, hence the single list argument.
    The target size is taken from all dimensions of `input_tensor[1]` except
    the first and the last one.
    """
    return tf.image.resize_images(input_tensor[0],
                                  input_tensor[1].shape[1:-1])

  @staticmethod
  def _split_delta(delta):
    """Split a signed size difference into `(before, after)` amounts.

    The two parts always sum to `delta`; for odd values the larger share
    goes to the `after` side. Negative parts mean "crop", positive "pad".
    """
    before = delta // 2
    return before, delta - before

  def check_shape(self, x, target_shape=None):
    """Crop and/or zero-pad `x` so that its spatial size equals a target.

    Args:
      x: 4D tensor; its dimensions 1 and 2 are treated as rows and columns.
      target_shape: wanted (rows, cols). Defaults to `self.current_shape`,
        which `up_sampling_block` sets before calling this method.

    Returns:
      `x` itself if the size already matches, otherwise the adjusted tensor.

    Raises:
      AssertionError: if any side would have to change by more than one
        row or column.
    """
    if target_shape is None:
      target_shape = self.current_shape
    target_shape = tuple(target_shape)

    x_shape = tuple(K.int_shape(x)[1:-1])
    if x_shape == target_shape:
      return x

    r_pad = self._split_delta(target_shape[0] - x_shape[0])
    c_pad = self._split_delta(target_shape[1] - x_shape[1])

    # Compare element-wise first and reduce afterwards: reducing with
    # `.all()` before comparing yields a bool that is always <= 1.
    assert (np.abs(np.array([r_pad, c_pad])) <= 1).all(), \
        'cannot align spatial shape {} with {}'.format(x_shape, target_shape)

    # Each axis is handled on its own so that one axis may be cropped while
    # the other one is padded.
    r_crop = max(-r_pad[0], 0), max(-r_pad[1], 0)
    c_crop = max(-c_pad[0], 0), max(-c_pad[1], 0)
    if any(r_crop) or any(c_crop):
      rows = r_crop[0], x_shape[0] - r_crop[1]
      cols = c_crop[0], x_shape[1] - c_crop[1]
      x = Lambda(lambda t: t[:, rows[0]:rows[1], cols[0]:cols[1], :])(x)

    r_grow = max(r_pad[0], 0), max(r_pad[1], 0)
    c_grow = max(c_pad[0], 0), max(c_pad[1], 0)
    if any(r_grow) or any(c_grow):
      x = ZeroPadding2D((r_grow, c_grow))(x)
    return x

  def down_sampling_block(self, input_img, input_tensor, ind):
    """Halve the resolution of `input_tensor` with several parallel branches.

    One plain stride-2 convolution plus, per entry of `self.kernel_sizes`,
    a dilated convolution followed by a stride-2 convolution are
    concatenated. The result is projected by a 1x1 convolution and
    concatenated with a resized copy of `input_img`, so that the returned
    tensor has `2 * self.filters` channels.

    Args:
      input_img: the network input; a resized copy is appended to the output.
      input_tensor: tensor to down-sample.
      ind: block index used in the layer names; `None` leaves the 1x1
        projection unnamed.

    Returns:
      The down-sampled tensor.
    """
    name = 'down_sampled_{}'.format(ind) if ind is not None else None

    branches = [Conv2D(self.filters, (2, 2), strides=(2, 2), padding='same',
                       dilation_rate=(1, 1), activation='relu',
                       name='x_{}'.format(ind))(input_tensor)]

    for k in self.kernel_sizes:
      x = Conv2D(self.filters, (k, k), strides=(1, 1), activation='relu',
                 padding='same', dilation_rate=(2, 2))(input_tensor)
      x = Conv2D(self.filters, (2, 2), strides=(2, 2), activation='relu',
                 padding='same')(x)
      branches.append(x)

    down_sampled = Concatenate()(branches)
    resized_img = Lambda(self.resize_img)([input_img, down_sampled])

    # Leave room for the image channels so the concatenation below ends up
    # with exactly 2 * filters channels.
    img_channels = resized_img.get_shape().as_list()[-1]
    down_sampled = Conv2D(2 * self.filters - img_channels,
                          (1, 1), padding='same', name=name)(down_sampled)
    down_sampled = Concatenate()([down_sampled, resized_img])

    return down_sampled

  def up_sampling_block(self, down_1, down_0, ind=None):
    """Up-sample `down_1` by two and merge it with the skip tensor `down_0`.

    Args:
      down_1: lower resolution tensor to up-sample.
      down_0: skip connection tensor that defines the target spatial size.
      ind: level index used in the layer names; `None` leaves them unnamed.

    Returns:
      Tuple `(up_sampled, out_sampled)`: the merged features with
      `2 * self.filters` channels and a softmax map with
      `self.out_channels + 1` channels computed from them.
    """
    up_name = 'up_sampled_{}'.format(ind) if ind is not None else None
    out_name = 'out_sampled_{}'.format(ind) if ind is not None else None

    up_sampled = Deconv2D(self.filters, (1, 1), activation='relu',
                          strides=(2, 2), padding='same')(down_1)

    # The transposed convolution may overshoot the size of the skip tensor,
    # so align both before concatenating them.
    self.current_shape = K.int_shape(down_0)[1:-1]
    up_sampled = self.check_shape(up_sampled, self.current_shape)

    up_sampled = Concatenate()([up_sampled, down_0])
    up_sampled = Conv2D(2 * self.filters, (1, 1), activation='relu',
                        padding='same', name=up_name)(up_sampled)
    out_sampled = Conv2D(self.out_channels+1, (1, 1), activation='softmax',
                         padding='same', name=out_name)(up_sampled)

    return up_sampled, out_sampled

  def build_model(self):
    """Assemble the network and store it in `self.model`.

    The model outputs are the last `self.out_levels` tensors collected on
    the way up (the deepest encoder tensor followed by the softmax map of
    every intermediate level), followed by the full resolution softmax
    map `out_img`.

    Raises:
      AssertionError: if the spatial size of `out_img` differs from the
        spatial size given by `self.img_shape`.
    """
    input_img = Input(self.img_shape)
    down_sampled = [input_img]

    for i in range(self.steps):
      down_sampled.append(self.down_sampling_block(input_img,
                                                   down_sampled[i], i))

    # Walk back up from the deepest level, collecting one softmax map per
    # intermediate level.
    features = down_sampled[-1]
    up_sampled = [features]
    for i in range(2, self.steps+1):
      features, level_out = self.up_sampling_block(features,
                                                   down_sampled[-i],
                                                   self.steps-i)
      up_sampled.append(level_out)

    # Last step back to the input resolution; its own softmax head is not
    # used, `out_img` below takes that role.
    features, _ = self.up_sampling_block(features, input_img)

    # A plain `[-0:]` slice would keep everything, hence the explicit case.
    up_sampled = up_sampled[-self.out_levels:] if self.out_levels else []

    out_img = Conv2D(self.out_channels+1, (1, 1), activation='softmax',
                     name='out_img')(features)
    up_sampled.append(out_img)
    model = Model(input_img, up_sampled)

    # Check the full resolution output itself instead of relying on the
    # position of its layer inside `model.layers`.
    assert tuple(self.img_shape[:-1]) == tuple(K.int_shape(out_img)[1:-1])

    self.model = model

In [0]:
if __name__ == '__main__':
  # Demo configuration: build the network for 256x256 RGB images and
  # report how many parameters it has.
  img_shape = (256, 256, 3)
  filters = 64
  out_channels = 20
  steps = 5
  out_resized_levels = 2
  kernel_sizes = [2, 3, 5]

  atrous_unet = AtrousUnet(img_shape=img_shape,
                           filters=filters,
                           out_channels=out_channels,
                           steps=steps,
                           out_levels=out_resized_levels,
                           kernel_sizes=kernel_sizes)
  atrous_unet.build_model()

  # One loss per model output: the extra resized levels plus the full
  # resolution map ('binary_crossentropy' was noted as an alternative).
  losses = ['mse' for _ in range(out_resized_levels + 1)]
  atrous_unet.model.compile(loss=losses,
                            optimizer='Adam',
                            metrics=['accuracy'])

  n_params = atrous_unet.model.count_params()
  print('number of parameters:', n_params)
  # Optional inspection helpers, disabled by default:
  # print(atrous_unet.model.summary())
  # SVG(model_to_dot(atrous_unet.model, show_shapes=True,
  #                  show_layer_names=False).create(prog='dot', format='svg'))