In [None]:
"""

Octave Residual UNetの実装

"""

In [None]:
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import *

# Names exported by a star-import of this module.
# NOTE(review): 'octave_conv_2d' is listed but no definition with that name is
# visible in this file — confirm it exists elsewhere or drop it from the list,
# otherwise a star-import will fail on the missing attribute.
__all__ = ['OctaveConv2D', 'octave_conv_2d']


class OctaveConv2D(Layer):
    """Octave convolution over a high- and a low-resolution feature map.

    The layer owns up to four inner convolutions (high->high, low->high,
    high->low and low->low) plus an average pooling and an up-sampling layer
    that move features between the two resolutions. The contributions that
    target the same resolution are summed.

    # Arguments
        filters: Total number of output filters, split between the branches.
        kernel_size, strides, data_format, dilation_rate, use_bias,
        kernel_initializer, bias_initializer, kernel_regularizer,
        bias_regularizer, activity_regularizer, kernel_constraint,
        bias_constraint: Forwarded unchanged to every inner convolution,
            which always uses `padding='same'`.
        octave: Pool size / up-sampling factor used between the branches.
            The rows and columns of the high input must be divisible by it.
        ratio_out: Fraction of `filters` given to the low-resolution output;
            the low branch gets `int(filters * ratio_out)` filters and the
            high branch gets the rest.
        activation: Optional activation (name or callable) applied to each
            returned output after the branch contributions are summed.
            `None` leaves the outputs linear.
        use_transpose: If True the inner convolutions are `Conv2DTranspose`
            layers instead of `Conv2D` layers.

    # Input
        Either a single tensor (high resolution only) or a list
        `[high, low]`.

    # Output
        A list `[high, low]`, or a single tensor when one of the two
        branches has zero filters.

    # References
        - [Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution]
          (https://arxiv.org/pdf/1904.05049.pdf)
    """

    def __init__(self,
                 filters,
                 kernel_size=(3,3),
                 octave=2,
                 ratio_out=0.125,
                 strides=(1, 1),
                 data_format=None,
                 dilation_rate=(1, 1),
                 activation=None,
                 use_bias=False,
                 use_transpose=False,
                 kernel_initializer='he_normal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        super(OctaveConv2D, self).__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.octave = octave
        self.ratio_out = ratio_out
        self.strides = strides
        self.data_format = data_format
        self.dilation_rate = dilation_rate
        self.activation = activation
        self.use_bias = use_bias
        self.use_transpose = use_transpose
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.kernel_regularizer = kernel_regularizer
        self.bias_regularizer = bias_regularizer
        self.activity_regularizer = activity_regularizer
        self.kernel_constraint = kernel_constraint
        self.bias_constraint = bias_constraint

        self.filters_low = int(filters * self.ratio_out)
        self.filters_high = filters - self.filters_low

        # The inner layer names (including the "HL"/"LL" suffixes and the
        # missing "Trans-" prefix on the low->high layer) are kept exactly as
        # they have always been so that previously saved weights still match.
        if self.use_transpose:
            make_conv = self._init_transconv
            name_high_to_high = '{}-Trans-Conv2D-HH'
            name_low_to_high = '{}-Conv2D-LH'
            name_low_to_low = '{}-Trans-Conv2D-HL'
            name_high_to_low = '{}-Trans-Conv2D-LL'
        else:
            make_conv = self._init_conv
            name_high_to_high = '{}-Conv2D-HH'
            name_low_to_high = '{}-Conv2D-LH'
            name_low_to_low = '{}-Conv2D-HL'
            name_high_to_low = '{}-Conv2D-LL'

        self.conv_high_to_high, self.conv_low_to_high = None, None
        if self.filters_high > 0:
            self.conv_high_to_high = make_conv(
                self.filters_high, name=name_high_to_high.format(self.name))
            self.conv_low_to_high = make_conv(
                self.filters_high, name=name_low_to_high.format(self.name))

        self.conv_low_to_low, self.conv_high_to_low = None, None
        if self.filters_low > 0:
            self.conv_low_to_low = make_conv(
                self.filters_low, name=name_low_to_low.format(self.name))
            self.conv_high_to_low = make_conv(
                self.filters_low, name=name_high_to_low.format(self.name))

        # Resolution-changing layers are identical for both convolution kinds.
        self.pooling = AveragePooling2D(
            pool_size=self.octave,
            padding='valid',
            data_format=data_format,
            name='{}-AveragePooling2D'.format(self.name),
        )
        self.up_sampling = UpSampling2D(
            size=self.octave,
            data_format=data_format,
            name='{}-UpSampling2D'.format(self.name)
        )

        # Only created when requested so the default (linear) graph is
        # unchanged.
        self._activation_layer = None
        if activation is not None:
            self._activation_layer = Activation(
                activation, name='{}-Activation'.format(self.name))

    def _conv_kwargs(self, filters, name):
        """Return the keyword arguments shared by every inner convolution."""
        return dict(
            filters=filters,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding='same',
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
            use_bias=self.use_bias,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            kernel_regularizer=self.kernel_regularizer,
            bias_regularizer=self.bias_regularizer,
            activity_regularizer=self.activity_regularizer,
            kernel_constraint=self.kernel_constraint,
            bias_constraint=self.bias_constraint,
            name=name,
        )

    def _init_transconv(self, filters, name):
        """Create an inner `Conv2DTranspose` configured from this layer."""
        return Conv2DTranspose(**self._conv_kwargs(filters, name))

    def _init_conv(self, filters, name):
        """Create an inner `Conv2D` configured from this layer."""
        return Conv2D(**self._conv_kwargs(filters, name))

    def _inner_convs(self):
        """Return the existing inner convolutions in a fixed order."""
        candidates = (
            self.conv_high_to_high,
            self.conv_low_to_high,
            self.conv_high_to_low,
            self.conv_low_to_low,
        )
        return [conv for conv in candidates if conv is not None]

    def _apply_activation(self, outputs):
        """Apply the optional activation to one output tensor."""
        if self._activation_layer is None:
            return outputs
        return self._activation_layer(outputs)

    def build(self, input_shape):
        """Validate the input shapes and build the inner convolutions.

        # Raises
            ValueError: If a channel dimension is undefined, or if the rows
                or columns of the high input are not divisible by `octave`.
        """
        if isinstance(input_shape, list):
            input_shape_high, input_shape_low = input_shape
        else:
            input_shape_high, input_shape_low = input_shape, None
        if self.data_format == 'channels_first':
            channel_axis, rows_axis, cols_axis = 1, 2, 3
        else:
            rows_axis, cols_axis, channel_axis = 1, 2, 3
        if input_shape_high[channel_axis] is None:
            raise ValueError('The channel dimension of the higher spatial inputs '
                             'should be defined. Found `None`.')
        if input_shape_low is not None and input_shape_low[channel_axis] is None:
            raise ValueError('The channel dimension of the lower spatial inputs '
                             'should be defined. Found `None`.')
        if input_shape_high[rows_axis] is not None and input_shape_high[rows_axis] % self.octave != 0 or \
           input_shape_high[cols_axis] is not None and input_shape_high[cols_axis] % self.octave != 0:
            raise ValueError('The rows and columns of the higher spatial inputs should be divisible by the octave. '
                             'Found {} and {}.'.format(input_shape_high, self.octave))
        if input_shape_low is None:
            # Without a low-resolution input the low->* paths cannot exist.
            self.conv_low_to_high, self.conv_low_to_low = None, None

        if self.conv_high_to_high is not None:
            with K.name_scope(self.conv_high_to_high.name):
                self.conv_high_to_high.build(input_shape_high)
        if self.conv_low_to_high is not None:
            with K.name_scope(self.conv_low_to_high.name):
                self.conv_low_to_high.build(input_shape_low)
        if self.conv_high_to_low is not None:
            with K.name_scope(self.conv_high_to_low.name):
                self.conv_high_to_low.build(input_shape_high)
        if self.conv_low_to_low is not None:
            with K.name_scope(self.conv_low_to_low.name):
                self.conv_low_to_low.build(input_shape_low)
        super(OctaveConv2D, self).build(input_shape)

    @property
    def trainable_weights(self):
        """Trainable weights of every inner convolution, in a fixed order."""
        weights = []
        for conv in self._inner_convs():
            weights += conv.trainable_weights
        return weights

    @property
    def non_trainable_weights(self):
        """Non-trainable weights of every inner convolution, in a fixed order."""
        weights = []
        for conv in self._inner_convs():
            weights += conv.non_trainable_weights
        return weights

    def compute_output_shape(self, input_shape):
        """Return the output shape(s) derived from the high input shape.

        The result is a list `[high, low]`, or a single shape when one of
        the branches has zero filters.
        """
        if isinstance(input_shape, list):
            input_shape_high, input_shape_low = input_shape
        else:
            input_shape_high, input_shape_low = input_shape, None

        output_shape_high = None
        if self.filters_high > 0:
            output_shape_high = self.conv_high_to_high.compute_output_shape(input_shape_high)
        output_shape_low = None
        if self.filters_low > 0:
            output_shape_low = self.conv_high_to_low.compute_output_shape(
                self.pooling.compute_output_shape(input_shape_high),
            )

        if self.filters_low == 0:
            return output_shape_high
        if self.filters_high == 0:
            return output_shape_low
        return [output_shape_high, output_shape_low]

    def call(self, inputs, **kwargs):
        """Run the octave convolution.

        # Arguments
            inputs: A single high-resolution tensor or a list `[high, low]`.

        # Returns
            `[high, low]`, or a single tensor when one branch has zero
            filters. The optional activation is applied to every returned
            tensor.
        """
        if isinstance(inputs, list):
            inputs_high, inputs_low = inputs
        else:
            inputs_high, inputs_low = inputs, None

        # Contributions to the high-resolution output.
        outputs_high_to_high, outputs_low_to_high = 0.0, 0.0
        if self.conv_high_to_high is not None:
            outputs_high_to_high = self.conv_high_to_high(inputs_high)
        if self.conv_low_to_high is not None:
            outputs_low_to_high = self.up_sampling(self.conv_low_to_high(inputs_low))
        outputs_high = outputs_high_to_high + outputs_low_to_high

        # Contributions to the low-resolution output.
        outputs_low_to_low, outputs_high_to_low = 0.0, 0.0
        if self.conv_low_to_low is not None:
            outputs_low_to_low = self.conv_low_to_low(inputs_low)
        if self.conv_high_to_low is not None:
            if self.use_transpose:
                # Transposed mode convolves first and pools afterwards.
                outputs_high_to_low = self.pooling(self.conv_high_to_low(inputs_high))
            else:
                outputs_high_to_low = self.conv_high_to_low(self.pooling(inputs_high))
        outputs_low = outputs_low_to_low + outputs_high_to_low

        if self.filters_low == 0:
            return self._apply_activation(outputs_high)
        if self.filters_high == 0:
            return self._apply_activation(outputs_low)
        return [self._apply_activation(outputs_high),
                self._apply_activation(outputs_low)]

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer.

        Values are stored as they were passed to the constructor.
        """
        config = super(OctaveConv2D, self).get_config()
        config.update({
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'octave': self.octave,
            'ratio_out': self.ratio_out,
            'strides': self.strides,
            'data_format': self.data_format,
            'dilation_rate': self.dilation_rate,
            'activation': self.activation,
            'use_bias': self.use_bias,
            # Without this a transposed layer would be restored as a plain
            # convolution.
            'use_transpose': self.use_transpose,
            'kernel_initializer': self.kernel_initializer,
            'bias_initializer': self.bias_initializer,
            'kernel_regularizer': self.kernel_regularizer,
            'bias_regularizer': self.bias_regularizer,
            'activity_regularizer': self.activity_regularizer,
            'kernel_constraint': self.kernel_constraint,
            'bias_constraint': self.bias_constraint
        })
        return config

In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import math
import numpy as np
import tensorflow as tf
from sklearn.utils import class_weight

#from keras_radam.training import RAdamOptimizer
from tensorflow.keras import layers
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import *
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler

def _bn_relu_pair(high, low):
    """Apply BatchNormalization followed by ReLU to each branch separately."""
    high = layers.Activation("relu")(layers.BatchNormalization()(high))
    low = layers.Activation("relu")(layers.BatchNormalization()(low))
    return high, low


def _octave_residual_block(high, low, filters, pre_activation=True, mid_filters=None):
    """Build one residual unit made of octave convolutions.

    Main path: (optional BN+ReLU) -> OctaveConv2D(mid_filters) -> BN+ReLU ->
    OctaveConv2D(filters). Shortcut path: OctaveConv2D(filters) -> BN on the
    untouched block inputs. The two paths are added per branch.
    """
    if mid_filters is None:
        mid_filters = filters

    main_high, main_low = high, low
    if pre_activation:
        main_high, main_low = _bn_relu_pair(main_high, main_low)
    main_high, main_low = OctaveConv2D(mid_filters)([main_high, main_low])
    main_high, main_low = _bn_relu_pair(main_high, main_low)
    main_high, main_low = OctaveConv2D(filters)([main_high, main_low])

    shortcut_high, shortcut_low = OctaveConv2D(filters)([high, low])
    shortcut_high = layers.BatchNormalization()(shortcut_high)
    shortcut_low = layers.BatchNormalization()(shortcut_low)

    output_high = layers.Add()([main_high, shortcut_high])
    output_low = layers.Add()([main_low, shortcut_low])
    return output_high, output_low


def _octave_upsample_merge(high, low, skip_high, skip_low, filters):
    """Upsample both branches by 2 and concatenate them with encoder skips."""
    up_high, up_low = OctaveConv2D(filters, use_transpose=True, strides=(2, 2))([high, low])
    up_high, up_low = _bn_relu_pair(up_high, up_low)
    # axis=3 is the channel axis for channels_last tensors.
    merged_high = layers.concatenate([skip_high, up_high], axis=3)
    merged_low = layers.concatenate([skip_low, up_low], axis=3)
    return merged_high, merged_low


def OctResUnet(pretrained_weights = None,input_size = (256,256,3)):
    """Build and compile the Octave Residual U-Net.

    The network keeps a high- and a low-resolution branch throughout: four
    residual encoder stages (64, 128, 256, 512 filters), a 1024-filter bridge
    with dropout, four residual decoder stages with skip connections, and a
    1x1 sigmoid convolution producing a single-channel map.

    # Arguments
        pretrained_weights: Optional path to a weights file. When given, the
            weights are loaded into the model after it has been compiled.
        input_size: Shape of one input sample as (rows, cols, channels).
            The four 2x poolings followed by the octave split at the bridge
            mean rows and cols should be divisible by 32.

    # Returns
        A `Model` compiled with Adam (learning rate 1e-4), binary
        cross-entropy loss and the accuracy metric. The model summary is
        printed as a side effect.
    """
    inputs = Input(input_size)

    # The low-resolution branch starts as a 2x down-sampled copy of the input.
    high, low = inputs, layers.AveragePooling2D(2)(inputs)

    # Encoder: the first stage has no pre-activation because it sees raw input.
    skips = []
    for stage, filters in enumerate((64, 128, 256, 512)):
        high, low = _octave_residual_block(high, low, filters, pre_activation=stage > 0)
        skips.append((high, low))
        high = layers.MaxPooling2D(2)(high)
        low = layers.MaxPooling2D(2)(low)

    # Bridge: two BN-ReLU-conv units, each followed by dropout on both branches.
    for _ in range(2):
        high, low = _bn_relu_pair(high, low)
        high, low = OctaveConv2D(1024)([high, low])
        high = layers.Dropout(0.4)(high)
        low = layers.Dropout(0.4)(low)

    # Decoder: every stage consumes the residual output of the previous stage,
    # so the shortcut branches actually take part in the model.
    # NOTE(review): the last stage uses 128 filters for its first convolution
    # while all other stages use the stage width; kept as written — confirm
    # whether 64 was intended.
    for filters, mid_filters in ((512, 512), (256, 256), (128, 128), (64, 128)):
        skip_high, skip_low = skips.pop()
        merged_high, merged_low = _octave_upsample_merge(high, low, skip_high, skip_low, filters)
        high, low = _octave_residual_block(merged_high, merged_low, filters, mid_filters=mid_filters)

    # ratio_out=0.0 folds both branches into a single high-resolution tensor.
    conv9 = OctaveConv2D(32, ratio_out=0.0)([high, low])
    # NOTE(review): a sigmoid here is followed by another sigmoid in the final
    # convolution; kept as written — confirm this is intended.
    conv9 = layers.Activation("sigmoid")(conv9)
    conv10 = layers.Conv2D(1, 1, activation = 'sigmoid')(conv9)

    model = Model(inputs=inputs, outputs=conv10)

    model.summary()

    model.compile(optimizer = Adam(learning_rate=1e-4), loss = 'binary_crossentropy', metrics = ['accuracy'])

    if pretrained_weights:
        model.load_weights(pretrained_weights)

    return model