In [None]:
def get_simplest_model2() -> keras.Model:
    """Build and compile a small CNN with two parallel channel-attention branches.

    Architecture (Keras functional API):
      * three conv modules (64, 128, 256 filters); each one downsamples the
        spatial dims with a stride-2 conv and then a 2x2 max-pool;
      * two attention modules with independent weights, both applied to the
        same 256-channel feature map and concatenated on the last axis;
      * a dense classifier head ending in a softmax over ``num_classes``.

    The model has two outputs:
      * ``'root'`` -- class probabilities, shape ``(batch, num_classes)``;
      * ``'dot'`` -- cosine similarity (``Dot(normalize=True)``) between the
        globally average-pooled outputs of the two attention branches,
        shape ``(batch, 1)``.

    Relies on the globals ``SMALLER_HEIGHT``, ``SMALLER_WIDTH`` and
    ``num_classes``, which are defined elsewhere in the notebook.

    Returns:
        The compiled model (Adam optimizer, categorical cross-entropy loss,
        accuracy metric). The model summary is printed as a side effect.
    """
    def get_attention_module(prev: keras.layers.Layer) -> keras.layers.Layer:
        """Re-weight the channels of ``prev`` with learned softmax weights.

        ``prev`` is the symbolic output tensor of the previous layer. It is
        assumed to be channels-last (the Reshape to ``(1, 1, C)`` below relies
        on that). The tensor is global-average-pooled to one vector per sample,
        passed through two Dense layers (ReLU, then softmax over the channel
        axis), and the resulting ``(1, 1, C)`` weights are multiplied back into
        ``prev``, broadcasting over height and width.
        """
        # Derive the channel count from the input instead of hard-coding 256,
        # so the module stays valid if the conv stack changes. int() also
        # unwraps a TF1-style ``Dimension``.
        channels = int(prev.shape[-1])
        gap_layer = keras.layers.GlobalAveragePooling2D()(prev)
        gap_layer_res = keras.layers.Reshape((1, 1, channels))(gap_layer)
        dense = keras.layers.Dense(channels, activation='relu')(gap_layer_res)
        dense = keras.layers.Dense(channels, activation='softmax')(dense)
        mul_layer = keras.layers.Multiply()([prev, dense])

        return mul_layer

    def get_conv_module(
            prev: keras.layers.Layer,
            filters: int,
            drop_rate: float,
            kernel_size: int) -> keras.layers.Layer:
        """Apply two Conv2D+BatchNorm stages, then 2x2 max-pooling and dropout.

        The first conv uses ``strides=2`` and the max-pool halves the size
        again, so each module reduces height and width by roughly 4x.

        Args:
            prev: symbolic output tensor of the previous layer.
            filters: number of filters in both conv layers.
            drop_rate: dropout rate applied at the end of the module.
            kernel_size: kernel size for both conv layers.
        """
        module = keras.layers.Conv2D(filters, kernel_size, strides=2, activation='relu', padding='same')(prev)
        module = keras.layers.BatchNormalization()(module)
        module = keras.layers.Conv2D(filters, kernel_size, activation='relu', padding='same')(module)
        module = keras.layers.BatchNormalization()(module)
        module = keras.layers.MaxPool2D(pool_size=2)(module)
        module = keras.layers.Dropout(drop_rate)(module)

        return module

    def get_classifier_module(prev: keras.layers.Layer) -> keras.layers.Layer:
        """Flatten ``prev`` and apply Dense(512, ReLU), BatchNorm and Dropout(0.5)."""
        classifier = keras.layers.Flatten()(prev)
        classifier = keras.layers.Dense(512, activation='relu')(classifier)
        classifier = keras.layers.BatchNormalization()(classifier)
        classifier = keras.layers.Dropout(.5)(classifier)

        return classifier

    _input = keras.layers.Input(shape=(SMALLER_HEIGHT, SMALLER_WIDTH, 3))
    conv = get_conv_module(_input, 64, .25, 3)
    conv = get_conv_module(conv, 128, .4, 3)
    conv = get_conv_module(conv, 256, .5, 3)

    # Two attention branches over the same features; each call creates its
    # own Dense layers, so the branches do not share weights.
    attention1 = get_attention_module(conv)
    attention2 = get_attention_module(conv)
    merged_attentions = keras.layers.concatenate([attention1, attention2])
    classifier = get_classifier_module(merged_attentions)
    output = keras.layers.Dense(num_classes, activation='softmax', name='root')(classifier)

    # Auxiliary output: cosine similarity between the pooled features of the
    # two attention branches.
    gap_attention1 = keras.layers.GlobalAveragePooling2D()(attention1)
    gap_attention2 = keras.layers.GlobalAveragePooling2D()(attention2)
    aux_output = keras.layers.Dot(axes=1, normalize=True, name='dot')([gap_attention1, gap_attention2])
    model = keras.Model(_input, outputs=[output, aux_output])

    # NOTE(review): the model has two outputs, but only a single loss/metric
    # string is given. It is presumably applied to both, so the scalar 'dot'
    # output is also trained with categorical cross-entropy and reports
    # 'accuracy', which is unlikely to be meaningful. Confirm the intended aux
    # target/loss and consider a per-output dict, e.g.
    # loss={'root': 'categorical_crossentropy', 'dot': ...}.
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'])

    # summary() prints the table itself and returns None; wrapping it in
    # print() only added a stray "None" line to the output.
    model.summary()

    return model