<a href="https://colab.research.google.com/github/kikiru328/Bone_Detection/blob/main/TJ_Net(2020).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Sequential, layers, models
from tensorflow.keras.layers import (
    Activation,
    BatchNormalization,
    Concatenate,
    Conv2D,
    Dense,
    Dropout,
    Flatten,
    GlobalAveragePooling2D,
    GlobalMaxPool2D,
    Input,
    MaxPooling2D,
    Reshape,
)

# Bottleneck residual block with CBAM attention, shared by the ResNet-50/101/152 variants
def conv_block(inputs, filter_num, reduction_ratio, stride=1, name=None):
    """Build one bottleneck residual block refined by CBAM attention.

    The block runs a 1x1 -> 3x3 -> 1x1 convolution stack (each followed by
    batch normalisation).  It then applies channel attention followed by
    spatial attention (CBAM), and adds a 1x1-projected shortcut of
    ``inputs`` before the final ReLU.

    Args:
        inputs: 4-D Keras tensor.  Channels-last layout is assumed, because
            BatchNormalization uses ``axis=3`` and the spatial-attention
            pools reduce over axis 3.
        filter_num: Sequence of three ints ``[f1, f2, f3]``.  These are the
            filter counts of the first 1x1, the 3x3 and the final 1x1
            convolution.  ``f3`` is also the output channel count.
        reduction_ratio: Integer divisor for the hidden width of the
            channel-attention MLP (``f3 // reduction_ratio`` units, at
            least 1).
        stride: Stride of the first 1x1 convolution and of the shortcut
            projection.  A value > 1 downsamples the block.
        name: Prefix used to derive every layer name in the block.
            Required.

    Returns:
        The block's output tensor with ``filter_num[2]`` channels.

    Raises:
        ValueError: If ``name`` is not given.
    """
    if name is None:
        raise ValueError("conv_block requires a layer-name prefix via `name`.")

    filters_out = filter_num[2]

    # Bottleneck: 1x1 (possibly strided) -> 3x3 -> 1x1 expansion.
    x = Conv2D(filter_num[0], (1, 1), strides=stride, padding='same',
               name=name + '_conv1')(inputs)
    x = BatchNormalization(axis=3, name=name + '_bn1')(x)
    x = Activation('relu', name=name + '_relu1')(x)

    x = Conv2D(filter_num[1], (3, 3), strides=1, padding='same',
               name=name + '_conv2')(x)
    x = BatchNormalization(axis=3, name=name + '_bn2')(x)
    x = Activation('relu', name=name + '_relu2')(x)

    x = Conv2D(filters_out, (1, 1), strides=1, padding='same',
               name=name + '_conv3')(x)
    x = BatchNormalization(axis=3, name=name + '_bn3')(x)

    # Channel attention: a shared two-layer MLP is applied to the global
    # average- and max-pooled descriptors.  The two outputs are summed and
    # squashed with a sigmoid.
    avgpool = GlobalAveragePooling2D(name=name + '_channel_avgpool')(x)
    maxpool = GlobalMaxPool2D(name=name + '_channel_maxpool')(x)
    # max(1, ...) keeps the hidden layer non-empty if filters_out < reduction_ratio.
    hidden_units = max(1, filters_out // reduction_ratio)
    dense_layer1 = Dense(hidden_units, activation='relu',
                         name=name + '_channel_fc1')
    # The second MLP layer is linear.  A ReLU here would make both branch
    # outputs >= 0, so the sigmoid below could never go below 0.5 and the
    # attention could never suppress a channel.
    dense_layer2 = Dense(filters_out, name=name + '_channel_fc2')
    avg_out = dense_layer2(dense_layer1(avgpool))
    max_out = dense_layer2(dense_layer1(maxpool))

    channel = layers.add([avg_out, max_out], name=name + '_channel_add')
    channel = Activation('sigmoid', name=name + '_channel_sigmoid')(channel)
    channel = Reshape((1, 1, filters_out),
                      name=name + '_channel_reshape')(channel)
    # The (1, 1, C) attention map broadcasts over the spatial dimensions.
    channel_out = layers.multiply([x, channel], name=name + '_channel_scale')

    # Spatial attention: per-pixel mean and max across channels are
    # concatenated, then passed through a 7x7 conv and a sigmoid.  The tf
    # reductions are wrapped in Lambda layers, because newer Keras versions
    # reject raw TensorFlow ops applied to symbolic Keras tensors.
    spatial_avg = layers.Lambda(
        lambda t: tf.reduce_mean(t, axis=3, keepdims=True),
        name=name + '_spatial_avgpool')(channel_out)
    spatial_max = layers.Lambda(
        lambda t: tf.reduce_max(t, axis=3, keepdims=True),
        name=name + '_spatial_maxpool')(channel_out)
    spatial = Concatenate(axis=3, name=name + '_spatial_concat')(
        [spatial_avg, spatial_max])
    spatial = Conv2D(1, (7, 7), strides=1, padding='same',
                     name=name + '_spatial_conv2d')(spatial)
    spatial_out = Activation('sigmoid', name=name + '_spatial_sigmoid')(spatial)
    # The (H, W, 1) attention map broadcasts over the channels.
    cbam_out = layers.multiply([channel_out, spatial_out],
                               name=name + '_spatial_scale')

    # Residual connection: project the input so its stride and channel
    # count match the block output.
    shortcut = Conv2D(filters_out, (1, 1), strides=stride, padding='same',
                      name=name + '_residual')(inputs)
    x = layers.add([cbam_out, shortcut], name=name + '_add')
    x = Activation('relu', name=name + '_relu3')(x)

    return x

def build_block(x, filter_num, blocks, reduction_ratio=16, stride=1, name=None):
    """Stack CBAM bottleneck blocks into a single ResNet stage.

    Only the first block receives ``stride``, so it alone can downsample.
    Any further blocks use stride 1 and are named ``<name>_block1``,
    ``<name>_block2``, and so on.  At least one block is always built, even
    when ``blocks`` < 1.

    Args:
        x: Input tensor of the stage.
        filter_num: Three filter counts forwarded to every ``conv_block``.
        blocks: Total number of blocks in the stage.
        reduction_ratio: Channel-attention reduction ratio for every block.
        stride: Stride of the stage's first block.
        name: Layer-name prefix of the stage.

    Returns:
        Output tensor of the last block in the stage.
    """
    out = conv_block(x, filter_num, reduction_ratio, stride, name=name)

    for block_idx in range(1, blocks):
        block_name = name + '_block' + str(block_idx)
        out = conv_block(out, filter_num, reduction_ratio,
                         stride=1, name=block_name)

    return out


# Build the ResNet-50/101/152 backbone (TJ-Net) with CBAM attention in every block
def SE_ResNet(Netname, nb_classes, input_shape=(224, 224, 3)):
    """Build TJ-Net: a ResNet-50/101/152 classifier whose blocks use CBAM.

    The ``SE_`` prefix is kept for backward compatibility.  Each bottleneck
    block (see ``build_block`` / ``conv_block``) applies channel and spatial
    attention.

    Args:
        Netname: One of ``'ResNet50'``, ``'ResNet101'`` or ``'ResNet152'``.
            Selects the number of blocks per stage and is also used as the
            model name.
        nb_classes: Number of output classes (width of the softmax layer).
        input_shape: Shape of one input image, channels-last.  Defaults to
            ``(224, 224, 3)``, the size this function always used before.

    Returns:
        An uncompiled ``Model`` mapping images to class probabilities.

    Raises:
        KeyError: If ``Netname`` is not a supported variant.
    """
    resnet_config = {'ResNet50': [3, 4, 6, 3],
                     'ResNet101': [3, 4, 23, 3],
                     'ResNet152': [3, 8, 36, 3]}
    try:
        layers_dims = resnet_config[Netname]
    except KeyError:
        raise KeyError(f"Unsupported Netname {Netname!r}; "
                       f"expected one of {sorted(resnet_config)}") from None

    # Per-stage [1x1, 3x3, 1x1] filter counts, first-block strides and
    # channel-attention reduction ratios.
    stage_filters = ([64, 64, 256],
                     [128, 128, 512],
                     [256, 256, 1024],
                     [512, 512, 2048])
    stage_strides = (1, 2, 2, 2)
    se_reduction = (16, 16, 16, 16)

    img_input = Input(shape=input_shape)

    # Stem: 7x7/2 conv followed by 3x3/2 max-pool.
    x = Conv2D(64, (7, 7), strides=(2, 2), padding='same',
               name='stem_conv')(img_input)
    x = BatchNormalization(axis=3, name='stem_bn')(x)
    x = Activation('relu', name='stem_relu')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2), padding='same',
                     name='stem_pool')(x)

    # Four residual stages, named conv1 .. conv4.
    stages = zip(stage_filters, layers_dims, se_reduction, stage_strides)
    for stage_idx, (filters, blocks, ratio, stride) in enumerate(stages, start=1):
        x = build_block(x, filters, blocks, ratio, stride=stride,
                        name=f'conv{stage_idx}')

    # Classification head.
    x = GlobalAveragePooling2D(name='top_layer_pool')(x)
    x = Dense(nb_classes, activation='softmax', name='fc')(x)

    model = models.Model(img_input, x, name=Netname)

    return model
    

# Script entry point: build ResNet-50 TJ-Net for 1000 classes and print its
# layer summary.  The original compared against ' __main__ ' (with spaces),
# which never matched, so this block never ran.
if __name__ == '__main__':
    model = SE_ResNet('ResNet50', 1000)
    model.summary()