1. Convolutional layers with Batch Normalization and Dropout.
2. Attention mechanism: We'll integrate a basic channel-wise attention mechanism that gates each feature map using its average- and max-pooled channel statistics.
3. Advanced layers: We use layers such as Separable Convolutions and GlobalAveragePooling for more efficient feature extraction.

In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models

# Custom Channel-wise Attention Mechanism
class ChannelAttention(layers.Layer):
    """Channel-wise attention that rescales each feature map by a learned gate.

    Average- and max-pooled channel descriptors are passed through a shared
    two-layer MLP (``fc1`` -> ``fc2``). Their sum is squashed with a sigmoid
    and multiplied back onto the input, one gate value per channel.

    Args:
        channels: Number of channels in the input tensor. It must match the
            length of the pooled descriptor (the last axis under Keras' default
            channels_last data format).
        reduction_ratio: Bottleneck factor for the hidden MLP layer. The hidden
            width is ``channels // reduction_ratio``, clamped to at least 1 so
            that small channel counts still yield a valid Dense layer.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, channels, reduction_ratio=8, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        self.reduction_ratio = reduction_ratio
        self.global_avg_pool = layers.GlobalAveragePooling2D()
        self.global_max_pool = layers.GlobalMaxPooling2D()

        # Shared MLP applied to both pooled descriptors.
        self.fc1 = layers.Dense(max(1, channels // reduction_ratio), activation='relu')
        self.fc2 = layers.Dense(channels)

        # Weight-free helper layers are created once here instead of on every
        # call(). Building layers inside call() makes fresh objects per
        # invocation/trace, which Keras discourages.
        self.reshape = layers.Reshape((1, 1, channels))
        self.sigmoid = layers.Activation('sigmoid')

    def call(self, inputs):
        # Pooled vectors are reshaped to (batch, 1, 1, channels) so the final
        # gate broadcasts over the spatial axes of `inputs`.
        avg_pool = self.reshape(self.global_avg_pool(inputs))
        max_pool = self.reshape(self.global_max_pool(inputs))

        avg_fc = self.fc2(self.fc1(avg_pool))
        max_fc = self.fc2(self.fc1(max_pool))

        attention = self.sigmoid(avg_fc + max_fc)
        return inputs * attention

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({
            'channels': self.channels,
            'reduction_ratio': self.reduction_ratio,
        })
        return config

# Custom CNN with Attention
def _conv_bn_attention(x, filters, conv_cls):
    """Apply two (3x3 ``conv_cls`` + BatchNormalization) pairs, then ChannelAttention.

    The attention layer is sized from ``filters``, so its channel count always
    matches the preceding convolutions.
    """
    for _ in range(2):
        x = conv_cls(filters, (3, 3), padding='same', activation='relu')(x)
        x = layers.BatchNormalization()(x)
    return ChannelAttention(filters)(x)


def custom_cnn_with_attention(input_shape, num_classes):
    """Build (but do not compile) a small CNN with channel attention in each block.

    Args:
        input_shape: Shape passed to ``layers.Input`` (excluding the batch axis).
        num_classes: Number of units in the final softmax layer.

    Returns:
        A ``keras.Model`` mapping ``input_shape`` inputs to ``num_classes``
        softmax probabilities.
    """
    inputs = layers.Input(shape=input_shape)

    # Block 1: standard convolutions + attention, then downsample.
    x = _conv_bn_attention(inputs, 32, layers.Conv2D)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = layers.Dropout(0.3)(x)

    # Block 2: depthwise-separable convolutions + attention, then downsample.
    x = _conv_bn_attention(x, 64, layers.SeparableConv2D)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = layers.Dropout(0.4)(x)

    # Block 3: deeper standard convolutions + attention. There is no pooling
    # here because global pooling below collapses the spatial axes.
    x = _conv_bn_attention(x, 128, layers.Conv2D)

    # Classification head.
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    return models.Model(inputs, outputs)

# Build, compile and print a summary of the model for an example configuration:
# 128x128 inputs with 3 channels, classified into 10 classes.
input_shape = (128, 128, 3)
num_classes = 10

model = custom_cnn_with_attention(input_shape, num_classes)
model.compile(
    optimizer='adam',
    # NOTE(review): categorical_crossentropy expects one-hot targets; switch to
    # sparse_categorical_crossentropy if labels are integer class ids.
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
model.summary()






Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 128, 128, 3)]     0         
                                                                 
 conv2d (Conv2D)             (None, 128, 128, 32)      896       
                                                                 
 batch_normalization (Batch  (None, 128, 128, 32)      128       
 Normalization)                                                  
                                                                 
 conv2d_1 (Conv2D)           (None, 128, 128, 32)      9248      
                                                                 
 batch_normalization_1 (Bat  (None, 128, 128, 32)      128       
 chNormalization)                                                
                                                                 
 channel_attention (Channel  (None, 128, 128, 32)      29