Correct Code

In [3]:
import tensorflow as tf
from tensorflow.keras.layers import Layer
import tensorflow.keras as K
import tensorflow.keras.backend as Kback

def SAM_avg(x, cam):
    """Spatial attention branch that pools the channel axis by its mean.

    Two separable convolutions (1x1 then 3x3) and batch norm refine ``x``;
    the result is multiplied by ``cam``, averaged over the last axis into a
    single map, and turned into a sigmoid gate by a 7x7 convolution that
    re-weights the ``cam``-scaled features. Shapes are printed at each step.

    Args:
        x: 4-D feature map; its last axis is used as the channel count.
        cam: tensor multiplied element-wise with the refined features
            (in this file, the output of ``CAM``).

    Returns:
        The ``cam``-scaled features multiplied by the spatial gate.
    """
    print("---")
    _, _, _, channel = x.shape
    print("SAM_avg!!! Input Shape:", x.shape)
    # SeparableConv2D builds its weights from depthwise/pointwise initializers;
    # a `kernel_initializer` kwarg is ignored (tf.keras 2) or rejected (Keras 3).
    # A fresh HeNormal instance is created per use because unseeded
    # initializer instances are not meant to be shared between variables.
    x = K.layers.SeparableConv2D(
        channel,
        kernel_size=1,
        padding="same",
        depthwise_initializer=tf.keras.initializers.HeNormal(),
        pointwise_initializer=tf.keras.initializers.HeNormal(),
    )(x)
    print("SepConv2D_1 shape:", x.shape)
    x = K.layers.SeparableConv2D(
        channel,
        kernel_size=3,
        padding="same",
        depthwise_initializer=tf.keras.initializers.HeNormal(),
        pointwise_initializer=tf.keras.initializers.HeNormal(),
    )(x)
    print("SepConv2D_2 shape:", x.shape)
    x = K.layers.BatchNormalization()(x)
    print("BN shape:", x.shape)
    x = x * cam
    print("X*CAM:", x.shape)

    ## Average Pooling over the last axis -> one value per spatial position
    x1 = tf.reduce_mean(x, axis=-1)
    print("MEAN:", x1.shape)
    x1 = tf.expand_dims(x1, axis=-1)
    print("Expand:", x1.shape)

    ## Conv layer producing a single-channel sigmoid gate
    feats = K.layers.Conv2D(
        1,
        kernel_size=7,
        padding="same",
        activation="sigmoid",
        kernel_initializer=tf.keras.initializers.HeNormal(),
    )(x1)
    print("Conv2D shape:", feats.shape)
    # The (…, 1) gate broadcasts across all channels of x.
    feats = K.layers.Multiply()([x, feats])
    print("Out shape:", feats.shape)
    return feats

def SAM_max(x, cam):
    """Spatial attention branch that pools the channel axis by its maximum.

    Mirrors ``SAM_avg`` except for the reduction. Two separable convolutions
    (1x1 then 3x3) and batch norm refine ``x``. The result is multiplied by
    ``cam`` and collapsed to one map with ``reduce_max`` over the last axis.
    A 7x7 convolution turns that map into a sigmoid gate, which re-weights
    the ``cam``-scaled features. Shapes are printed at each step.

    Args:
        x: 4-D feature map; its last axis is used as the channel count.
        cam: tensor multiplied element-wise with the refined features
            (in this file, the output of ``CAM``).

    Returns:
        The ``cam``-scaled features multiplied by the spatial gate.
    """
    _, _, _, channel = x.shape
    print("---")
    print("SAM_max!!! Input Shape:", x.shape)
    # SeparableConv2D builds its weights from depthwise/pointwise initializers;
    # a `kernel_initializer` kwarg is ignored (tf.keras 2) or rejected (Keras 3).
    x = K.layers.SeparableConv2D(
        channel,
        kernel_size=1,
        padding="same",
        depthwise_initializer=tf.keras.initializers.HeNormal(),
        pointwise_initializer=tf.keras.initializers.HeNormal(),
    )(x)
    print("SepConv2D_1 shape:", x.shape)
    x = K.layers.SeparableConv2D(
        channel,
        kernel_size=3,
        padding="same",
        depthwise_initializer=tf.keras.initializers.HeNormal(),
        pointwise_initializer=tf.keras.initializers.HeNormal(),
    )(x)
    print("SepConv2D_2 shape:", x.shape)
    x = K.layers.BatchNormalization()(x)
    print("BN shape:", x.shape)
    x = x * cam
    print("X*CAM:", x.shape)

    ## Max Pooling over the last axis -> one value per spatial position
    x2 = tf.reduce_max(x, axis=-1)
    print("MAX:", x2.shape)
    x2 = tf.expand_dims(x2, axis=-1)
    print("Expand:", x2.shape)

    ## Conv layer producing a single-channel sigmoid gate.
    # HeNormal matches the 7x7 gate conv in SAM_avg and every other conv here.
    feats = K.layers.Conv2D(
        1,
        kernel_size=7,
        padding="same",
        activation="sigmoid",
        kernel_initializer=tf.keras.initializers.HeNormal(),
    )(x2)
    print("Conv2D shape:", feats.shape)
    # The (…, 1) gate broadcasts across all channels of x.
    feats = K.layers.Multiply()([x, feats])
    print("Out shape:", feats.shape)
    return feats

def CAM(x, ratio=8):
    """Channel attention: re-weight the channels of ``x``.

    The global-average-pooled and global-max-pooled descriptors of ``x``
    both pass through the same two Dense layers (C -> hidden with ReLU -> C).
    The two results are summed, squashed with a sigmoid, and multiplied back
    onto ``x``. Shapes are printed along the way.

    Args:
        x: 4-D feature map; its last axis is taken as the channel count C.
        ratio: bottleneck reduction factor for the hidden Dense layer.
            The hidden width is ``max(1, C // ratio)``, so that C < ratio
            does not create a zero-unit layer.

    Returns:
        ``x`` scaled by the per-channel sigmoid weights (same shape as ``x``).
    """
    _, _, _, channel = x.shape
    print("CAM!!! Input Shape:", x.shape)
    ## Shared layers: the same MLP is applied to both pooled descriptors.
    hidden_units = max(1, channel // ratio)
    l1 = K.layers.Dense(hidden_units, activation="relu", use_bias=False)
    l2 = K.layers.Dense(channel, use_bias=False)
    ## Global Average Pooling
    x1 = K.layers.GlobalAveragePooling2D()(x)
    print("GAP shape:", x1.shape)
    x1 = l1(x1)
    print("GAP + Dense shape:", x1.shape)
    x1 = l2(x1)
    print("GAP + Dense + Dense shape:", x1.shape)
    ## Global Max Pooling
    x2 = K.layers.GlobalMaxPooling2D()(x)
    print("GMP shape:", x2.shape)
    x2 = l1(x2)
    print("GMP + Dense shape:", x2.shape)
    x2 = l2(x2)
    print("GMP + Dense + Dense shape:", x2.shape)
    ## Add both the features and pass through sigmoid
    feats = x1 + x2
    feats = K.layers.Activation("sigmoid")(feats)
    # The weights are (batch, C); Multiply broadcasts them over the spatial
    # axes of x (the captured output shows the result keeps x's shape).
    feats = K.layers.Multiply()([x, feats])
    print("Out shape:", feats.shape)
    return feats

class ChannelDropout(K.layers.Layer):
    """Gate the input with a learnable per-channel mask, then keep the top-k values.

    A trainable weight of shape ``(1, 1, 1, C)`` (initialised to ones) is
    passed through a ``RichardsSigmoid`` activation. The result multiplies
    the input, broadcasting over the batch and spatial axes. Afterwards
    ``tf.nn.top_k`` keeps the ``int(C * (1 - drop_ratio))`` largest values
    along the last axis, in descending order.

    NOTE(review): ``tf.nn.top_k`` operates independently at every
    (batch, row, col) position. Each output pixel therefore holds that
    pixel's largest gated values, sorted, so channel identity is not
    preserved and different pixels may keep different channels. If the
    intent is to drop the *same* channels everywhere (e.g. ranked by mask
    importance), this selection needs revisiting.

    Args:
        drop_ratio: fraction of channels to discard.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, drop_ratio=0.2, **kwargs):
        super(ChannelDropout, self).__init__(**kwargs)
        self.drop_ratio = drop_ratio
        # Stored on self so Keras tracks (and trains) its A/Q/mu weights.
        self.mask_activation = RichardsSigmoid(units=1)

    def build(self, input_shape):
        _, _, _, self.channels = input_shape
        # Raw trainable mask, one value per channel, initialised to ones.
        # `name=` is passed by keyword: in Keras 3 the first positional
        # argument of add_weight is `shape`.
        self.mask_weight = self.add_weight(
            name="mask",
            shape=(1, 1, 1, self.channels),
            initializer="ones",
            trainable=True,
        )
        super(ChannelDropout, self).build(input_shape)

    def call(self, x):
        # The activation is applied here, not in build(), so that it is part
        # of the traced computation and gradients reach the mask and A/Q/mu.
        mask = self.mask_activation(self.mask_weight)
        # (1, 1, 1, C) broadcasts over batch and spatial dims; no tile needed.
        x = x * mask
        # Uses drop_ratio (an earlier commented-out version hard-coded channels // 1.25).
        num_channels_to_keep = int(self.channels * (1 - self.drop_ratio))
        # top_k already returns exactly k entries on the last axis.
        kept, _ = tf.nn.top_k(x, k=num_channels_to_keep, sorted=True)
        return kept

    def get_config(self):
        config = super(ChannelDropout, self).get_config()
        config.update({"drop_ratio": self.drop_ratio})
        return config

class RichardsSigmoid(K.layers.Layer):
    """Element-wise learnable sigmoid-shaped activation.

    Computes ``1 / (1 + exp(-A * exp(-Q * (x - mu))))``. The trainable
    parameters ``A``, ``Q`` and ``mu`` each have shape ``(units,)`` and are
    broadcast against the last axis of the input.

    Args:
        units: length of each parameter vector. With ``units=1`` the same
            scalar parameters apply to every element.
        **kwargs: forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, units=1, **kwargs):
        super(RichardsSigmoid, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # Initialize learnable parameters: A, Q, mu
        self.A = self.add_weight(name='A', shape=(self.units,), initializer='uniform', trainable=True)
        self.Q = self.add_weight(name='Q', shape=(self.units,), initializer='uniform', trainable=True)
        self.mu = self.add_weight(name='mu', shape=(self.units,), initializer='uniform', trainable=True)

        super(RichardsSigmoid, self).build(input_shape)

    def call(self, x):
        return 1 / (1 + tf.exp(-self.A * tf.exp(-self.Q * (x - self.mu))))

    def compute_output_shape(self, input_shape):
        # call() is element-wise, with (units,) parameters broadcast against
        # the last axis. With units == 1 the output keeps the input shape.
        # Otherwise the last axis is 1 or `units` and broadcasts to `units`.
        input_shape = tuple(input_shape)
        if not input_shape:
            return (self.units,)
        last = input_shape[-1]
        if self.units != 1 and last in (1, None):
            last = self.units
        return input_shape[:-1] + (last,)

    def get_config(self):
        config = super(RichardsSigmoid, self).get_config()
        config.update({"units": self.units})
        return config

def CSSAM(x, cam):
    """Fuse the average- and max-pooled spatial attention branches with ``cam``.

    Each branch (``SAM_avg`` then ``SAM_max``) is applied to ``(x, cam)``.
    The two outputs and ``cam`` are concatenated, and the concatenation is
    passed through ``ChannelDropout(drop_ratio=0.5)``. Shapes are printed
    after each stage.

    Args:
        x: feature map fed to both spatial-attention branches.
        cam: channel-attention tensor shared by both branches and also
            concatenated with their outputs.

    Returns:
        The ``ChannelDropout`` output.
    """
    print("---")
    print("CSSAM!!! Input Shape:", x.shape)

    branch_outputs = []
    for shape_label, branch in (
        ("CSSAM!!! SAM_AVG shape:", SAM_avg),
        ("CSSAM!!! SAM_MAX shape:", SAM_max),
    ):
        branch_out = branch(x, cam)
        print(shape_label, branch_out.shape)
        branch_outputs.append(branch_out)

    fused = K.layers.Concatenate()([*branch_outputs, cam])
    print("CSSAM!!! Concat shape:", fused.shape)

    dropped = ChannelDropout(drop_ratio=0.5)(fused)
    print("CSSAM!!! Channel Dropout shape:", dropped.shape)
    return dropped

def main():
    """Smoke-test CAM and CSSAM on a random tensor, printing intermediate shapes."""
    # Example input with shape (1, 8, 8, 1024)
    sample = tf.random.normal([1, 8, 8, 1024])
    attention = CAM(sample)
    # Only the printed shapes matter here; the result itself is not used.
    CSSAM(sample, attention)


if __name__ == "__main__":
    main()

CAM!!! Input Shape: (1, 8, 8, 1024)
GAP shape: (1, 1024)
GAP + Dense shape: (1, 128)
GAP + Dense + Dense shape: (1, 1024)
GMP shape: (1, 1024)
GMP + Dense shape: (1, 128)
GMP + Dense + Dense shape: (1, 1024)
Out shape: (1, 8, 8, 1024)
---
CSSAM!!! Input Shape: (1, 8, 8, 1024)
---
SAM_avg!!! Input Shape: (1, 8, 8, 1024)
SepConv2D_1 shape: (1, 8, 8, 1024)
SepConv2D_2 shape: (1, 8, 8, 1024)
BN shape: (1, 8, 8, 1024)
X*CAM: (1, 8, 8, 1024)
MEAN: (1, 8, 8)
Expand: (1, 8, 8, 1)
Conv2D shape: (1, 8, 8, 1)
Out shape: (1, 8, 8, 1024)
CSSAM!!! SAM_AVG shape: (1, 8, 8, 1024)
---
SAM_max!!! Input Shape: (1, 8, 8, 1024)
SepConv2D_1 shape: (1, 8, 8, 1024)
SepConv2D_2 shape: (1, 8, 8, 1024)
BN shape: (1, 8, 8, 1024)
X*CAM: (1, 8, 8, 1024)
MAX: (1, 8, 8)
Expand: (1, 8, 8, 1)
Conv2D shape: (1, 8, 8, 1)
Out shape: (1, 8, 8, 1024)
CSSAM!!! SAM_MAX shape: (1, 8, 8, 1024)
CSSAM!!! Concat shape: (1, 8, 8, 3072)
CSSAM!!! Channel Dropout shape: (1, 8, 8, 1536)





```
input tensor: torch.Size([1, 1024, 8, 8])
---
CAM
GAP: torch.Size([1, 1024])
GAP + Dense: torch.Size([1, 128])
GAP + Dense + Dense: torch.Size([1, 1024])
GMP: torch.Size([1, 1024, 1, 1])
GMP + Dense: torch.Size([1, 128])
GMP + Dense + Dense: torch.Size([1, 1024])
Feats shape: torch.Size([1, 1024])
Output shape: torch.Size([1, 1024, 8, 8])
Output shape after CAM: torch.Size([1, 1024, 8, 8])
---
SAM_AVG
Input shape: torch.Size([1, 1024, 8, 8])
SepConv2d_1 shape: torch.Size([1, 1024, 8, 8])
SepConv2d_2 shape: torch.Size([1, 1024, 8, 8])
BatchNorm2d shape: torch.Size([1, 1024, 8, 8])
X*CAM shape: torch.Size([1, 1024, 8, 8])
Mean pooling shape: torch.Size([1, 1, 8, 8])
Conv2d shape: torch.Size([1, 1, 8, 8])
Sigmoid: torch.Size([1, 1, 8, 8])
Output shape:: torch.Size([1, 1024, 8, 8])
---
SAM_MAX
SepConv2d_1 shape: torch.Size([1, 1024, 8, 8])
SepConv2d_2 shape: torch.Size([1, 1024, 8, 8])
BatchNorm2d shape: torch.Size([1, 1024, 8, 8])
X*CAM shape: torch.Size([1, 1024, 8, 8])
MAX pooling shape: torch.Size([1, 1, 8, 8])
Conv2d shape: torch.Size([1, 1, 8, 8])
Sigmoid: torch.Size([1, 1, 8, 8])
Output shape: torch.Size([1, 1024, 8, 8])
Output shape after CSSAM: torch.Size([1, 3072, 8, 8])

CSSAM:

CSSAM: SAM_avg shape: torch.Size([1, 1024, 8, 8])
CSSAM: SAM_max shape: torch.Size([1, 1024, 8, 8])
CSSAM: Concatenation shape: torch.Size([1, 3072, 8, 8])
CSSAM: Channel Dropout shape: torch.Size([1, 1536, 8, 8])

```