In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ch_attention(nn.Module):
    """Channel branch of the Bottleneck Attention Module (BAM).

    Each channel of ``x`` is global-average-pooled, the resulting
    ``(B, n_ch)`` vector goes through a bottleneck MLP of width
    ``n_ch // r``, and the per-channel scores are broadcast back over the
    spatial dimensions of ``x``.  No output activation is applied here.

    Args:
        n_ch: number of channels of the input feature map.
        r: reduction ratio of the hidden layer (hidden width ``n_ch // r``).

    Note:
        ``nn.BatchNorm1d`` in training mode needs more than one sample per
        batch.
    """

    def __init__(self, n_ch, r=16):
        super().__init__()
        hidden = n_ch // r
        self.layers = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),   # (B, C, H, W) -> (B, C, 1, 1)
            View(-1),                       # (B, C, 1, 1) -> (B, C)
            nn.Linear(n_ch, hidden, bias=False),
            nn.BatchNorm1d(hidden),
            nn.ReLU(True),
            nn.Linear(hidden, n_ch),
        )

    def forward(self, x):
        # Insert two trailing singleton dims ((B, C) -> (B, C, 1, 1)), then
        # broadcast them to the full spatial size of the input.
        scores = self.layers(x)
        return scores[:, :, None, None].expand_as(x)
    
class spatial_attention(nn.Module):
    """Spatial branch of the Bottleneck Attention Module (BAM).

    A 1x1 conv reduces ``n_ch`` to ``n_ch // r`` channels. Two dilated 3x3
    convs (each followed by BatchNorm and ReLU) enlarge the receptive field.
    A final 1x1 conv collapses the features to a single-channel map, which is
    then broadcast across all channels of ``x``.

    Args:
        n_ch: number of channels of the input feature map.
        r: channel reduction ratio inside the branch.
        dilation: dilation rate of the two 3x3 convolutions.
    """

    def __init__(self, n_ch, r=16, dilation=4):
        super().__init__()
        mid = n_ch // r
        modules = [nn.Conv2d(n_ch, mid, kernel_size=1, bias=False),
                   nn.BatchNorm2d(mid),
                   nn.ReLU(True)]
        for _ in range(2):
            # With kernel_size=3, padding equal to the dilation keeps H and W
            # unchanged: H + 2*dilation - dilation*(3 - 1) == H.
            modules += [nn.Conv2d(mid, mid, kernel_size=3, padding=dilation,
                                  dilation=dilation, bias=False),
                        nn.BatchNorm2d(mid),
                        nn.ReLU(True)]
        modules.append(nn.Conv2d(mid, 1, kernel_size=1))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        # (B, 1, H, W) map, replicated over every channel of x.
        attention_map = self.layers(x)
        return attention_map.expand_as(x)

class BAM(nn.Module):
    """Bottleneck Attention Module.

    Computes ``x * (1 + sigmoid(Mc(x) * Ms(x)))``, where ``Mc`` is the channel
    branch and ``Ms`` the spatial branch. Both branches already return tensors
    expanded to the shape of ``x``, so the product is element-wise.

    NOTE(review): the BAM paper describes combining the two branches with an
    element-wise sum; this implementation multiplies them. Confirm that this
    is intended.
    NOTE: ``spatial_attention`` is looked up when ``BAM`` is constructed. In
    this notebook the Keras cell later defines a function with the same name,
    which rebinds it. Re-run the PyTorch cell before building this module
    again.

    Args:
        n_ch: number of channels of the input feature map.
        r: reduction ratio used by both branches.
        dilation: dilation rate of the spatial branch's 3x3 convolutions.
    """

    def __init__(self, n_ch, r=16, dilation=4):
        super().__init__()
        self.channel_att = ch_attention(n_ch, r)
        self.spatial_att = spatial_attention(n_ch, r, dilation)

    def forward(self, x):
        gate = torch.sigmoid(self.channel_att(x) * self.spatial_att(x))
        # The residual form x + x * gate keeps the original features when the
        # gate is near zero.
        return (1 + gate) * x

class CNN_Attention(nn.Module):
    """Two-stage CNN classifier with a BAM block after each conv stage.

    The layout is: conv(1->32) -> ReLU -> BAM -> max-pool, then
    conv(32->64) -> ReLU -> BAM -> max-pool, then flatten -> FC(128) -> FC(10).
    The classifier's input width (64 * 7 * 7 = 3136) is hard-coded for
    single-channel 28x28 inputs (28 -> 14 -> 7 after the two 2x2 pools).
    The model returns raw logits; no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            # stage 1: 28x28 -> 14x14
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            BAM(32),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # stage 2: 14x14 -> 7x7
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            BAM(64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # classifier head
            View(-1),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        logits = self.layers(x)
        return logits

class View(nn.Module):
    """Reshape each sample of a batch while keeping the leading batch dim.

    ``View(-1)`` flattens every sample; ``View(c, h, w)`` reshapes every
    sample to ``(c, h, w)``.

    Args:
        *shape: per-sample target shape (the batch dimension is excluded).
    """

    def __init__(self, *shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, x):
        # reshape (rather than view) returns a view whenever possible, which
        # is identical to the old behaviour for contiguous tensors. It also
        # works on non-contiguous inputs (e.g. after transpose/permute), where
        # Tensor.view raises a RuntimeError.
        return x.reshape(x.shape[0], *self.shape)

if __name__ == '__main__':
    from torchsummary import summary

    model = CNN_Attention()
    # torchsummary's summary() defaults to device="cuda" and creates its
    # dummy input on the GPU whenever CUDA is available. Put the model and
    # the dummy input on the same device so the call works on both CPU-only
    # and GPU machines.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    summary(model, (1, 28, 28), device=device)



----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 28, 28]             320
              ReLU-2           [-1, 32, 28, 28]               0
 AdaptiveAvgPool2d-3             [-1, 32, 1, 1]               0
              View-4                   [-1, 32]               0
            Linear-5                    [-1, 2]              64
       BatchNorm1d-6                    [-1, 2]               4
              ReLU-7                    [-1, 2]               0
            Linear-8                   [-1, 32]              96
      ch_attention-9           [-1, 32, 28, 28]               0
           Conv2d-10            [-1, 2, 28, 28]              64
      BatchNorm2d-11            [-1, 2, 28, 28]               4
             ReLU-12            [-1, 2, 28, 28]               0
           Conv2d-13            [-1, 2, 28, 28]              36
      BatchNorm2d-14            [-1, 2,

In [10]:
# Keras implementation of the same BAM model

from keras.layers import Input, Conv2D, GlobalAveragePooling2D, Dense, Activation, BatchNormalization, Flatten, Reshape, multiply, add, MaxPooling2D, Dropout
from keras.models import Model

def spatial_attention(in_layer, in_ch, r, d):
    """Build the spatial branch of BAM with the Keras functional API.

    The branch applies a 1x1 conv (reducing to ``in_ch // r`` channels), then
    two dilated 3x3 convs; each conv is followed by BatchNorm and ReLU. A
    final 1x1 conv collapses the features to a single channel. All convs use
    ``padding='same'`` with the default stride, so H and W are preserved.

    NOTE: this rebinds the name ``spatial_attention``, which the PyTorch cell
    of this notebook also uses for its class.

    Args:
        in_layer: input feature-map tensor.
        in_ch: number of channels of ``in_layer``.
        r: channel reduction ratio.
        d: dilation rate of the two 3x3 convolutions.

    Returns:
        Tensor with a single channel and the same spatial size as the input.
    """
    width = in_ch // r
    feat = in_layer
    conv_specs = [dict(kernel_size=(1, 1)),
                  dict(kernel_size=(3, 3), dilation_rate=d),
                  dict(kernel_size=(3, 3), dilation_rate=d)]
    for spec in conv_specs:
        feat = Conv2D(width, padding='same', use_bias=False, **spec)(feat)
        feat = BatchNormalization()(feat)
        feat = Activation('relu')(feat)
    # The last 1x1 conv keeps its bias and produces one output channel.
    return Conv2D(1, kernel_size=(1, 1), padding='same')(feat)

def channel_attention(in_layer, in_ch, r):
    """Build the channel branch of BAM with the Keras functional API.

    The branch global-average-pools the input, then runs a bottleneck MLP:
    Dense(in_ch // r, no bias) -> BatchNorm -> ReLU -> Dense(in_ch). The
    scores are reshaped to ``(1, 1, in_ch)`` so they broadcast over H and W.

    Args:
        in_layer: input feature-map tensor.
        in_ch: number of channels of ``in_layer``.
        r: reduction ratio of the hidden Dense layer.

    Returns:
        Tensor of per-channel scores with shape (batch, 1, 1, in_ch).
    """
    # GlobalAveragePooling2D already yields a (batch, channels) tensor, so
    # no Flatten layer is needed before the Dense layers.
    pooled = GlobalAveragePooling2D()(in_layer)
    head = (Dense(in_ch // r, use_bias=False),
            BatchNormalization(),
            Activation('relu'),
            Dense(in_ch),
            Reshape((1, 1, in_ch)))
    out = pooled
    for layer in head:
        out = layer(out)
    return out

def Bottleneck_Attention_Module(in_layer, in_ch, r, d):
    """Apply a BAM block to ``in_layer`` with the Keras functional API.

    Computes ``in_layer + in_layer * sigmoid(Ms * Mc)``. The spatial map Ms
    has shape (H, W, 1) and the channel scores Mc have shape (1, 1, C);
    their product broadcasts to the full (H, W, C) attention map.

    Args:
        in_layer: input feature-map tensor.
        in_ch: number of channels of ``in_layer``.
        r: reduction ratio used by both branches.
        d: dilation rate of the spatial branch.

    Returns:
        Refined tensor with the same shape as ``in_layer``.
    """
    channel_map = channel_attention(in_layer, in_ch, r)
    spatial_map = spatial_attention(in_layer, in_ch, r, d)
    combined = multiply([spatial_map, channel_map])
    gate = Activation('sigmoid')(combined)
    # Residual gating: this equals in_layer * (1 + gate).
    gated = multiply([in_layer, gate])
    return add([in_layer, gated])
    
# BAM hyperparameters shared by both attention blocks:
#   r - channel reduction ratio inside the attention branches
#   d - dilation rate of the spatial branch's 3x3 convolutions
r = 16
d = 4

# Single-channel 28x28 images (assumes Keras' default channels_last format).
inputs = Input(shape=(28,28, 1))

# Stage 1: conv -> BAM -> 2x2 max-pool (28x28 -> 14x14), then dropout.
# NOTE(review): dropout is used only here, and the PyTorch CNN_Attention
# in this notebook has no dropout at all — confirm this is intended.
x = Conv2D(filters=32, kernel_size=(3, 3), strides=(1,1), padding='same', activation='relu')(inputs)
x = Bottleneck_Attention_Module(x, 32, r, d)
x = MaxPooling2D(pool_size=(2,2))(x)
x = Dropout(0.5)(x)
# Stage 2: conv -> BAM -> 2x2 max-pool (14x14 -> 7x7).
x = Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu')(x)
x = Bottleneck_Attention_Module(x, 64, r, d)
x = MaxPooling2D(pool_size=(2, 2))(x)
# Classifier head: flatten -> FC(128) -> 10-way softmax.
x = Flatten()(x)
x = Dense(128, activation='relu')(x)
outputs = Dense(10, activation='softmax')(x)

model = Model(inputs=inputs, outputs=outputs)

# categorical_crossentropy expects one-hot labels (see to_categorical below).
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

model.summary()

Model: "model_7"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_9 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv2d_71 (Conv2D)              (None, 28, 28, 32)   320         input_9[0][0]                    
__________________________________________________________________________________________________
conv2d_72 (Conv2D)              (None, 28, 28, 2)    64          conv2d_71[0][0]                  
__________________________________________________________________________________________________
batch_normalization_56 (BatchNo (None, 28, 28, 2)    8           conv2d_72[0][0]                  
____________________________________________________________________________________________

dense_40 (Dense)                (None, 64)           320         activation_74[0][0]              
__________________________________________________________________________________________________
conv2d_80 (Conv2D)              (None, 14, 14, 1)    5           activation_77[0][0]              
__________________________________________________________________________________________________
reshape_14 (Reshape)            (None, 1, 1, 64)     0           dense_40[0][0]                   
__________________________________________________________________________________________________
multiply_27 (Multiply)          (None, 14, 14, 64)   0           conv2d_80[0][0]                  
                                                                 reshape_14[0][0]                 
__________________________________________________________________________________________________
activation_78 (Activation)      (None, 14, 14, 64)   0           multiply_27[0][0]                
__________

In [None]:
# Training: data preparation

import keras
from keras.datasets import mnist

num_classes = 10

# The original cell never loaded the data, so x_train/y_train/x_test/y_test
# were undefined. load_data() returns uint8 images of shape (N, 28, 28) and
# integer labels.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Add the trailing channel axis expected by Input(shape=(28, 28, 1)), then
# scale pixels to [0, 1].
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32')
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)