In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

In [6]:
import tensorflow as tf
import numpy as np
import const as C

In [3]:
# Convolution layer (encoder side).
def conv2d(x, weight, stride=2, batch_norm=True, is_training=True, leaky_relu=True):
    """Strided 2-D convolution, then optional batch norm and leaky ReLU.

    Args:
        x: input feature map; the ``[1, stride, stride, 1]`` strides assume
            NHWC layout.
        weight: filter tensor laid out as
            [kernel_height, kernel_width, input_channel, output_channel].
        stride: spatial stride used for both height and width. With 'SAME'
            padding the output size is ceil(input_size / stride), so the
            default of 2 roughly halves height and width.
        batch_norm: if True, apply ``tf.layers.batch_normalization``.
        is_training: passed as ``training=`` to the batch-norm layer.
        leaky_relu: if True, finish with ``tf.nn.leaky_relu`` (slope 0.2).

    Returns:
        The output tensor of the last enabled stage.
    """
    stride_spec = [1, stride, stride, 1]
    out = tf.nn.conv2d(x, filter=weight, strides=stride_spec, padding='SAME')
    out = tf.layers.batch_normalization(out, training=is_training) if batch_norm else out
    return tf.nn.leaky_relu(out, 0.2) if leaky_relu else out

In [4]:
# Transposed-convolution layer (decoder side).
def de_conv2d(x, weight, output_shape, stride=2, batch_norm=True, is_training=True, relu=True):
    """Strided transposed convolution, then optional batch norm and ReLU.

    Args:
        x: input feature map; the ``[1, stride, stride, 1]`` strides assume
            NHWC layout.
        weight: filter tensor laid out as
            [kernel_height, kernel_width, output_channel, input_channel]
            (conv2d_transpose puts output channels BEFORE input channels).
        output_shape: desired result shape,
            [batch_size, target_height, target_width, target_channel].
        stride: spatial stride used for both height and width.
        batch_norm: if True, apply ``tf.layers.batch_normalization``.
        is_training: passed as ``training=`` to the batch-norm layer.
        relu: if True, finish with ``tf.nn.relu``.

    Returns:
        The output tensor of the last enabled stage.
    """
    upsampled = tf.nn.conv2d_transpose(
        x,
        filter=weight,
        output_shape=output_shape,
        strides=[1, stride, stride, 1],
        padding='SAME',
    )
    if not batch_norm:
        normed = upsampled
    else:
        normed = tf.layers.batch_normalization(upsampled, training=is_training)
    return tf.nn.relu(normed) if relu else normed

In [None]:
# Decoder step with a skip connection: upsample, then concatenate along channels.
def concat_upconv(input_A, input_B, weight, output_shape=None, is_training=True):
    """Upsample ``input_A`` with de_conv2d and concatenate ``input_B`` to it.

    Args:
        input_A: decoder feature map to upsample (NHWC assumed, as in de_conv2d).
        input_B: skip-connection feature map appended after upsampling; its
            batch/height/width must match the upsampled tensor.
        weight: transposed-convolution filter laid out as
            [kernel_height, kernel_width, output_channel, input_channel].
        output_shape: shape of the upsampled tensor,
            [batch_size, height, width, channel]. When None (default) it is
            built from ``input_B``'s batch/height/width and ``weight``'s
            output-channel dimension so that the concatenation lines up.
            The filter's static shape must be known in that case.
        is_training: forwarded to de_conv2d's batch-normalization layer.

    Returns:
        ``[upsampled_A, input_B]`` concatenated on the last (channel) axis.
    """
    if output_shape is None:
        # de_conv2d requires an explicit output shape; mirror the skip tensor.
        # The channel count is read statically so batch norm inside de_conv2d
        # still sees a known channel dimension.
        out_channels = tf.convert_to_tensor(weight).get_shape().as_list()[2]
        skip_dims = tf.shape(input_B)
        output_shape = tf.stack([skip_dims[0], skip_dims[1], skip_dims[2], out_channels])
    de_conv = de_conv2d(input_A, weight, output_shape, is_training=is_training)
    return tf.concat([de_conv, input_B], axis=-1)

In [None]:
def UNet(is_training=True):
    """Build a 6-level convolutional encoder-decoder and return its layers.

    The encoder applies conv2d six times (stride 2 by default, 5x5 kernels)
    with channels 1 -> 16 -> 32 -> 64 -> 128 -> 256 -> 512. The decoder
    mirrors it with de_conv2d. Each decoder stage is upsampled to the exact
    shape of the matching encoder layer, and the last stage to the input shape.

    Args:
        is_training: forwarded to every batch-normalization layer. The default
            True matches the previous hard-coded behaviour. In TF1,
            tf.layers.batch_normalization only updates its moving statistics
            when the ops in tf.GraphKeys.UPDATE_OPS run with the train op.

    Returns:
        dict with keys 'input', 'conv1'..'conv6' and 'de_conv1'..'de_conv6'.
        'input' is a float32 variable of shape
        (C.BATCH_SIZE, C.IMAGE_HEIGHT, C.PATCH_LENGTH, 1), initialised to zeros.

    NOTE(review): despite the name, no skip connections are used
    (concat_upconv is never called), so decoder stage i only sees stage i-1.
    The output layer also gets batch norm + ReLU like the hidden layers.
    Confirm both are intended.
    """
    def _weight(shape, name):
        # Trainable filter variable; tf.nn.conv2d / conv2d_transpose need a
        # 4-D filter tensor, not just its shape.
        return tf.Variable(tf.truncated_normal(shape, stddev=0.02), name=name)

    channels = [1, 16, 32, 64, 128, 256, 512]
    net = {}
    # NOTE(review): the input is a tf.Variable, so data is presumably written
    # with assign(). trainable=False stops optimizers, which collect trainable
    # variables by default, from "learning" the input as if it were a weight.
    net['input'] = tf.Variable(
        np.zeros((C.BATCH_SIZE, C.IMAGE_HEIGHT, C.PATCH_LENGTH, 1)).astype('float32'),
        trainable=False,
        name='input')

    # Encoder: conv1 .. conv6.
    prev = net['input']
    for level in range(1, 7):
        key = 'conv%d' % level
        shape = [5, 5, channels[level - 1], channels[level]]  # [kh, kw, in, out]
        prev = conv2d(prev, weight=_weight(shape, key), is_training=is_training)
        net[key] = prev

    # Decoder: de_conv1 .. de_conv6, each mirroring an encoder layer's shape.
    for level in range(1, 7):
        key = 'de_conv%d' % level
        c_in, c_out = channels[7 - level], channels[6 - level]
        shape = [5, 5, c_out, c_in]  # conv2d_transpose layout: [kh, kw, out, in]
        # conv2d_transpose needs the full 4-D [batch, h, w, c] output shape.
        target = net['conv%d' % (6 - level)] if level < 6 else net['input']
        prev = de_conv2d(prev,
                         weight=_weight(shape, key),
                         output_shape=target.get_shape().as_list(),
                         is_training=is_training)
        net[key] = prev
    return net