In [1]:
import tensorflow as tf
import numpy as np
import argparse
import os
import json
import glob
import random
import collections
import math
import time

  from ._conv import register_converters as _register_converters


In [2]:
def down_block(input, padding='same', first_block=False):
    if first_block:
        input = tf.keras.layers.Conv2D(32, 3, padding=padding)(input)
    x = tf.keras.layers.Activation('relu')(input)
    conv1 = x = tf.keras.layers.Conv2D(64, 3, padding=padding)(x)
    x = tf.keras.layers.Activation('relu')(x)
    conv2 = x = tf.layers.Conv2D(32, 3, padding=padding)(x)
    plus = tf.keras.layers.Add()([x, conv2]) if not first_block else conv2
    max_pooling = tf.layers.MaxPooling2D(2,2)(plus)
    return plus, max_pooling

In [10]:
def up_block(x1, x2, padding='same'):
    x = tf.keras.layers.Concatenate(axis=-1)([x1, x2])
    x = tf.keras.layers.Activation('relu')(x)
    conv1 = x = tf.keras.layers.Conv2D(64, 3, padding='same')(x)
    x = tf.keras.layers.Activation('relu')(x)
    conv2 = x = tf.keras.layers.Conv2D(32, 3, padding='same')(x)
    plus = tf.keras.layers.Add()([x1, conv2])
    upsample_layers = tf.keras.layers.Conv2DTranspose(32, 3, 2, padding='same')(plus)
    return upsample_layers

In [11]:
tf.reset_default_graph()
input = tf.keras.layers.Input(shape=[640, 640, 3])

In [12]:
plus_layers = []
_ = input
for i in range(7):
    y, _ = down_block(_, first_block=True)
    plus_layers.append(y)

In [13]:
plus_layers

[<tf.Tensor 'conv2d_2/BiasAdd:0' shape=(?, 640, 640, 32) dtype=float32>,
 <tf.Tensor 'conv2d_5/BiasAdd:0' shape=(?, 320, 320, 32) dtype=float32>,
 <tf.Tensor 'conv2d_8/BiasAdd:0' shape=(?, 160, 160, 32) dtype=float32>,
 <tf.Tensor 'conv2d_11/BiasAdd:0' shape=(?, 80, 80, 32) dtype=float32>,
 <tf.Tensor 'conv2d_14/BiasAdd:0' shape=(?, 40, 40, 32) dtype=float32>,
 <tf.Tensor 'conv2d_17/BiasAdd:0' shape=(?, 20, 20, 32) dtype=float32>,
 <tf.Tensor 'conv2d_20/BiasAdd:0' shape=(?, 10, 10, 32) dtype=float32>]

In [14]:
_ = tf.keras.layers.Conv2DTranspose(32, 3, 2, padding='same')(_)

In [15]:
_

<tf.Tensor 'conv2d_transpose/BiasAdd:0' shape=(?, 10, 10, 32) dtype=float32>

In [16]:
for plus_layer in reversed(plus_layers):
    print(_, plus_layer)
    _ = up_block(_, plus_layer)
    

Tensor("conv2d_transpose/BiasAdd:0", shape=(?, 10, 10, 32), dtype=float32) Tensor("conv2d_20/BiasAdd:0", shape=(?, 10, 10, 32), dtype=float32)
Tensor("conv2d_transpose_1/BiasAdd:0", shape=(?, 20, 20, 32), dtype=float32) Tensor("conv2d_17/BiasAdd:0", shape=(?, 20, 20, 32), dtype=float32)
Tensor("conv2d_transpose_2/BiasAdd:0", shape=(?, 40, 40, 32), dtype=float32) Tensor("conv2d_14/BiasAdd:0", shape=(?, 40, 40, 32), dtype=float32)
Tensor("conv2d_transpose_3/BiasAdd:0", shape=(?, 80, 80, 32), dtype=float32) Tensor("conv2d_11/BiasAdd:0", shape=(?, 80, 80, 32), dtype=float32)
Tensor("conv2d_transpose_4/BiasAdd:0", shape=(?, 160, 160, 32), dtype=float32) Tensor("conv2d_8/BiasAdd:0", shape=(?, 160, 160, 32), dtype=float32)
Tensor("conv2d_transpose_5/BiasAdd:0", shape=(?, 320, 320, 32), dtype=float32) Tensor("conv2d_5/BiasAdd:0", shape=(?, 320, 320, 32), dtype=float32)
Tensor("conv2d_transpose_6/BiasAdd:0", shape=(?, 640, 640, 32), dtype=float32) Tensor("conv2d_2/BiasAdd:0", shape=(?, 640, 640

In [18]:
model = tf.keras.Model([input], [_])
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 640, 640, 3)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 640, 640, 32) 896         input_1[0][0]                    
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 640, 640, 32) 0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 640, 640, 64) 18496       activation_1[0][0]               
__________________________________________________________________________________________________
activation