In [None]:
import tensorflow as tf
import tensorflow.keras.layers as layers
import keras
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
from math import ceil

# Check the reference EfficientNetB0 model shipped with Keras/TensorFlow

In [None]:
from keras.applications import EfficientNetB0

# Instantiate the reference EfficientNetB0 from keras.applications with its default
# arguments and print its layer-by-layer summary, so its structure and parameter count
# can be compared against the custom implementation defined later in this notebook.
model_eff = EfficientNetB0()
model_eff.summary()

# Custom blocks for EfficientNet

In [None]:
@keras.saving.register_keras_serializable(package="efficient_net", name="ConvBlock")
class ConvBlock(tf.keras.layers.Layer):
    """Simple convolutional block: Conv2D (no bias) + Batch Normalization + SiLU activation.

    Args:
        out_channels: Number of filters of the convolution.
        kernel_size: Kernel size forwarded to `Conv2D`.
        stride: Stride forwarded to `Conv2D` as `strides`.
        padding: Padding mode forwarded to `Conv2D` (e.g. 'same' or 'valid').
        **kwargs: Base-layer arguments (`name`, `trainable`, `dtype`, ...) forwarded to
            `tf.keras.layers.Layer`.

    Serialization: `get_config` returns the base-layer config plus the four constructor
    arguments, so the inherited `from_config` (which calls `cls(**config)`) can rebuild
    the layer, including its name, trainable flag and dtype.
    """
    def __init__(self, out_channels, kernel_size, stride, padding, **kwargs):
        # Accepting **kwargs is what makes the layer re-creatable from its config:
        # the base config always carries `name`, `trainable` and `dtype`.
        super().__init__(**kwargs)

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        self.conv = layers.Conv2D(out_channels, kernel_size, strides=stride, padding=padding, use_bias=False)
        self.bn = layers.BatchNormalization()
        self.activation = layers.Activation(tf.nn.silu)

    def call(self, inputs, training=None):
        """Apply convolution, batch normalization and SiLU to `inputs`."""
        x = self.conv(inputs)
        # `training=None` lets BatchNormalization fall back to the framework's own
        # resolution of the training flag, exactly as when no argument is passed.
        x = self.bn(x, training=training)
        x = self.activation(x)
        return x

    def build(self, input_shape):
        """Create the weights of the convolution and of the batch normalization."""
        self.conv.build(input_shape)
        # BatchNormalization sees the output of the convolution, not the block input.
        self.bn.build(self.conv.compute_output_shape(input_shape))
        self.built = True

    def compute_output_shape(self, input_shape):
        """Return the output shape of the convolution for `input_shape`."""
        return self.conv.compute_output_shape(input_shape)

    def get_config(self):
        """Return the base-layer config extended with the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "out_channels": self.out_channels,
                "kernel_size": self.kernel_size,
                "stride": self.stride,
                "padding": self.padding,
            }
        )

        return config

@keras.saving.register_keras_serializable(package="efficient_net", name="SqueezeExcite")
class SqueezeExcitation(tf.keras.layers.Layer):
    """Squeeze-and-Excitation block for per-channel weighting of the feature maps.

    The input is globally average-pooled, reshaped to (batch, 1, 1, in_channels), passed
    through a 1x1 convolution with SiLU (`reduced_dim` filters) and a 1x1 convolution
    with sigmoid (`in_channels` filters), and the result multiplies the input.

    Args:
        in_channels: Number of channels of the input feature map.
        reduced_dim: Number of filters of the bottleneck 1x1 convolution.
        **kwargs: Base-layer arguments (`name`, `trainable`, `dtype`, ...) forwarded to
            `tf.keras.layers.Layer`.

    Raises:
        ValueError: In `build`, if the last input dimension is known and differs from
            `in_channels`.
    """
    def __init__(self, in_channels, reduced_dim, **kwargs):
        # **kwargs lets the inherited `from_config` (cls(**config)) restore name,
        # trainable and dtype, which are always part of the base config.
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.reduced_dim = reduced_dim
        self.global_avg_pool = layers.GlobalAveragePooling2D()
        self.conv1 = layers.Conv2D(reduced_dim, kernel_size=1, activation=tf.nn.silu, use_bias=True)
        self.conv2 = layers.Conv2D(in_channels, kernel_size=1, activation='sigmoid', use_bias=True)
        self.reshape = layers.Reshape((1, 1, in_channels))
        # Created once here instead of instantiating a new layer on every call.
        self.multiply = layers.Multiply()

    def call(self, inputs):
        """Return `inputs` scaled channel-wise by the learned gate."""
        x = self.global_avg_pool(inputs)
        x = self.reshape(x)  # Reshape back to (batch_size, 1, 1, in_channels)
        x = self.conv1(x)
        x = self.conv2(x)
        return self.multiply([inputs, x])

    def build(self, input_shape):
        """Create the weights of the two 1x1 convolutions."""
        channels = input_shape[-1]
        if channels is not None and channels != self.in_channels:
            raise ValueError(
                f"SqueezeExcitation was created with in_channels={self.in_channels} "
                f"but received an input with {channels} channels."
            )
        self.conv1.build((None, 1, 1, self.in_channels))
        self.conv2.build((None, 1, 1, self.reduced_dim))
        self.built = True

    def compute_output_shape(self, input_shape):
        """The block only rescales its input, so the shape is unchanged."""
        return input_shape

    def get_config(self):
        """Return the base-layer config extended with the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "in_channels": self.in_channels,
                "reduced_dim": self.reduced_dim,
            }
        )

        return config

    
@keras.saving.register_keras_serializable(package="efficient_net", name="MBConvBlock")
class MBBlock(tf.keras.layers.Layer):
    """MBConv block with skip connections.

    Pipeline: optional 1x1 expansion (`ConvBlock`, only when `expand_rate != 1`),
    depthwise convolution + BatchNorm + SiLU, Squeeze-and-Excitation, 1x1 projection +
    BatchNorm. When `stride == 1` and the input and output channel counts match, the
    result is passed through dropout and added to the block input.

    Args:
        in_channels: Number of channels of the block input.
        out_channels: Number of filters of the projection convolution.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        padding: Padding mode used by every convolution in the block.
        expand_rate: Multiplier applied to `in_channels` to get the hidden width.
        reduction: Divisor applied to `in_channels` to get the SE bottleneck width.
        dropout_rate: Dropout rate applied to the residual branch during training.
        **kwargs: Base-layer arguments (`name`, `trainable`, `dtype`, ...) forwarded to
            `tf.keras.layers.Layer`.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, expand_rate, reduction=4, dropout_rate=0.2, **kwargs):
        # **kwargs lets the inherited `from_config` (cls(**config)) restore name,
        # trainable and dtype, which are always part of the base config.
        super().__init__(**kwargs)

        # Parameters
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.expand_rate = expand_rate
        self.reduction = reduction
        self.dropout_rate = dropout_rate

        self.hidden_dim = self.in_channels * self.expand_rate
        self.expand = self.in_channels != self.hidden_dim
        # Clamp to 1 so that narrow blocks (in_channels < reduction) do not end up
        # with a zero-filter convolution inside the Squeeze-and-Excitation block.
        reduced_dim = max(1, int(self.in_channels / self.reduction))

        if self.expand:
            self.expand_conv = ConvBlock(self.hidden_dim, kernel_size=1, stride=1, padding=padding)

        self.conv = tf.keras.Sequential([
            layers.DepthwiseConv2D(kernel_size, stride, padding=padding, use_bias=False),
            layers.BatchNormalization(),
            layers.Activation(tf.nn.silu),
            SqueezeExcitation(self.hidden_dim, reduced_dim),
            layers.Conv2D(out_channels, kernel_size=1, strides=1, padding=padding, use_bias=False),
            layers.BatchNormalization()
        ])

        self.dropout = layers.Dropout(dropout_rate)
        self.skip_add = layers.Add()

    def call(self, inputs, training=False):
        """Run the block on `inputs`; `training` toggles dropout and batch-norm mode."""
        x = self.expand_conv(inputs) if self.expand else inputs
        x = self.conv(x, training=training)

        if self.stride == 1 and inputs.shape[-1] == x.shape[-1]:
            # Dropout regularises the residual branch only: blocks without a skip
            # connection (strided or channel-changing) must pass their features
            # through untouched, otherwise activations are dropped with no identity
            # path to compensate.
            x = self.dropout(x, training=training)
            x = self.skip_add([x, inputs])

        return x

    def build(self, input_shape):
        """Create the weights of the expansion block (if any) and of the main path."""
        if self.expand:
            self.expand_conv.build(input_shape)
            # The main path consumes the expanded feature map.
            self.conv.build(self.expand_conv.compute_output_shape(input_shape))
        else:
            self.conv.build(input_shape)

        self.built = True

    def get_config(self):
        """Return the base-layer config extended with the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "in_channels": self.in_channels,
                "out_channels": self.out_channels,
                "kernel_size": self.kernel_size,
                "stride": self.stride,
                "padding": self.padding,
                "expand_rate": self.expand_rate,
                "reduction": self.reduction,
                "dropout_rate": self.dropout_rate
            }
        )

        return config

## Build EfficientNetB0 from scratch and compare it to the model from the original paper (~5.3M parameters)

In [None]:

@keras.saving.register_keras_serializable(package="efficient_net", name="EfficientNet")
class EfficientNet(keras.Model):
    """EfficientNet builder class.

    Builds a feature extractor from the EfficientNetB0 stage table (scaled by
    `width_factor` and `depth_factor`) and, optionally, a classification head made of
    global average pooling, dropout and a softmax `Dense` layer.

    Args:
        input_shape: Shape of one input image, without the batch dimension.
        num_classes: Number of units of the softmax classifier.
        name: Name of the model.
        include_head: Whether to append the classification head.
        dropout: Dropout rate used in the classification head.
        width_factor: Multiplier applied to the channel counts of the stage table.
        depth_factor: Multiplier applied to the repeat counts of the stage table.
        **kwargs: Extra base-model arguments (e.g. `trainable`, `dtype`) forwarded to
            `keras.Model`.
    """
    def __init__(
            self,
            input_shape,
            num_classes,
            name = 'EfficientNet',
            include_head = True,
            dropout = 0.2,
            width_factor = 1.0,
            depth_factor = 1.0,
            **kwargs
    ):
        # The name goes through the base constructor rather than being assigned
        # afterwards, and **kwargs lets the inherited `from_config` (cls(**config))
        # pass the `trainable` / `dtype` entries of the base config straight through.
        super().__init__(name=name, **kwargs)
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.include_head = include_head
        self.dropout = dropout
        self.width_factor = width_factor
        self.depth_factor = depth_factor

        # Model blocks
        self.feature_extractor = self.build_feature_head(self.input_shape,
                                                         self.width_factor,
                                                         self.depth_factor)

        if include_head:
            self.mlp = keras.Sequential(
                [   layers.GlobalAveragePooling2D(name='feature_pooling'),
                    layers.Dropout(self.dropout, name ='classifier_dropout'),
                    layers.Dense(self.num_classes, activation='softmax', name='classifier')
                ], name = 'classifier_head'
            )

    def call(self, inputs, training= False):
        """Return class probabilities, or the feature map when `include_head` is False."""
        x = self.feature_extractor(inputs, training = training)

        if self.include_head:
            x = self.mlp(x, training = training)

        return x

    def build_functional(self):
        """Wrap this model in a functional `keras.Model` with an explicit input layer."""
        input_layer = layers.Input(self.input_shape, name = 'Input')
        output_layer = self(input_layer)
        return keras.Model(inputs=input_layer, outputs=output_layer, name = self.name)

    def build_feature_head(self, input_shape, width_factor, depth_factor):
        """Build the convolutional feature extractor as a functional model.

        Args:
            input_shape: Shape of one input image, without the batch dimension.
            width_factor: Multiplier applied to every channel count.
            depth_factor: Multiplier applied to every stage repeat count.

        Returns:
            A `keras.Model` named 'feature_extractor' mapping images to the output of
            the final 1x1 `ConvBlock`.
        """
        # Baseline EfficientNetB0 parameters
        basic_mb_params = [
            # expansion_rate, channels(c), repeats(t), stride(s), kernel_size(k)
            [1, 16, 1, 1, 3],
            [6, 24, 2, 2, 3],
            [6, 40, 2, 2, 5],
            [6, 80, 3, 2, 3],
            [6, 112, 3, 1, 5],
            [6, 192, 4, 2, 5],
            [6, 320, 1, 1, 3],
        ]

        # Calculate the filter size of the last convolutional block
        last_conv_dims = ceil(1280 * width_factor)

        inputs = layers.Input(shape=input_shape, name = 'Input')
        channels = int(32 * width_factor)
        x = ConvBlock(channels, 3, stride=2, padding='same')(inputs)
        in_channels = channels

        for r, c_o, repeat, s, k in basic_mb_params:
            # Scaled channel count, rounded up to a multiple of 4.
            out_channels = 4 * ceil(int(c_o * width_factor) / 4)
            num_layers = ceil(repeat * depth_factor)

            for layer in range(num_layers):
                # In stages with more than one repetition, only the first block gets the stage stride
                stride = s if layer == 0 else 1
                x = MBBlock(in_channels, out_channels, kernel_size=k, stride=stride, padding='same', expand_rate=r)(x)
                in_channels = out_channels

        x = ConvBlock(last_conv_dims, kernel_size=1, stride=1, padding='valid')(x)
        feature_extractor = keras.Model(inputs, x, name='feature_extractor')
        return feature_extractor

    def get_config(self):
        """Return the base config extended with the constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "input_shape": self.input_shape,
                "num_classes": self.num_classes,
                "name": self.name,
                "include_head": self.include_head,
                "dropout": self.dropout,
                "width_factor": self.width_factor,
                "depth_factor": self.depth_factor
            }
        )

        return config

    
# Smoke test: build the ImageNet-sized model, push one random image through it so
# every layer gets built, then print the summary to check the parameter count.
model = EfficientNet(
    input_shape=(224, 224, 3),
    num_classes=1000,
    name='EfficientNetB0',
    include_head=True,
)
dummy_batch = np.random.rand(224, 224, 3)[np.newaxis, ...]  # batch of one image
model(dummy_batch)
model.summary()

# Train model on CIFAR10

In [None]:
# `tf.data.experimental.AUTOTUNE` is a deprecated alias of the same constant.
AUTOTUNE = tf.data.AUTOTUNE

def preprocess_image(image, label, image_size):
    """Resize `image` to a square of side `image_size` and divide its values by 255.

    The label is returned unchanged.
    """
    target_size = (image_size, image_size)
    resized = tf.image.resize(image, target_size)
    return resized / 255.0, label

def load_data(image_size, batch_size):
    """Load CIFAR-10 from TensorFlow Datasets and build the input pipelines.

    Args:
        image_size: Side length the images are resized to by `preprocess_image`.
        batch_size: Number of examples per batch.

    Returns:
        A tuple `(train_ds, val_ds, ds_info, num_classes)`: the batched training and
        validation datasets, the dataset info object returned by `tfds.load`, and the
        number of label classes.
    """
    (train_ds, val_ds), ds_info = tfds.load(
        'cifar10',
        split=['train', 'test'],
        as_supervised=True,
        with_info=True
    )

    num_classes = ds_info.features['label'].num_classes
    num_train_examples = ds_info.splits['train'].num_examples

    def _prepare(image, label):
        return preprocess_image(image, label, image_size)

    # Shuffle the raw (small) images before resizing: a buffer covering the whole split
    # gives a uniform shuffle, and costs far less memory than buffering resized floats.
    train_ds = (
        train_ds
        .shuffle(buffer_size=num_train_examples)
        .map(_prepare, num_parallel_calls=AUTOTUNE)
        .batch(batch_size, drop_remainder=True)
        .prefetch(AUTOTUNE)
    )
    # Keep the final partial batch so evaluation covers every validation example.
    val_ds = (
        val_ds
        .map(_prepare, num_parallel_calls=AUTOTUNE)
        .batch(batch_size)
        .prefetch(AUTOTUNE)
    )

    return train_ds, val_ds, ds_info, num_classes

image_size = 224
batch_size = 128
ds_train, ds_test, ds_info, num_classes = load_data(image_size, batch_size)


model = EfficientNet(input_shape =(image_size,image_size,3),
                     num_classes = num_classes,
                     name = 'EfficientNetB0',
                     include_head = True
                    ).build_functional()
model.summary()

#-------------------------------------------------------------------------------------------------
# Prepare training arguments
# ------------------------------------------------------------------------------------------------
epochs = 50
warmup_epochs = 10
steps_per_epoch = len(ds_train)

early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

# Learning-rate schedule: linear warm-up from `min_learning_rate` to the peak
# `initial_learning_rate` over the first `warmup_epochs`, then cosine decay back down
# to `min_learning_rate` over the remaining epochs.
initial_learning_rate = 1e-3
min_learning_rate = 1e-5
decay_steps = steps_per_epoch * (epochs - warmup_epochs)
# CosineDecay decays towards `alpha * warmup_target`, so alpha is the ratio min/peak.
alpha = min_learning_rate / initial_learning_rate
warmup_steps = steps_per_epoch * warmup_epochs
# In CosineDecay the first argument is the rate at step 0 (the warm-up start) and
# `warmup_target` is the rate reached at the end of the warm-up (the peak).
lr_schedule = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=min_learning_rate,
    decay_steps=decay_steps,
    alpha=alpha,
    warmup_target=initial_learning_rate,
    warmup_steps=warmup_steps
)
optimizer = tf.keras.optimizers.AdamW(learning_rate=lr_schedule, weight_decay=5e-5)

model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train the model
history = model.fit(ds_train, epochs=epochs, validation_data=ds_test, callbacks = [early_stopping])

# Check results and training metrics

In [None]:
def plot_training_history(history):
    """Plot accuracy and loss curves (train vs. validation) side by side.

    `history` is expected to expose a `history` mapping with the keys 'accuracy',
    'val_accuracy', 'loss' and 'val_loss'.
    """
    curves = history.history
    _, (ax_accuracy, ax_loss) = plt.subplots(1, 2, figsize=(12, 6))

    # Left panel: training & validation accuracy values
    ax_accuracy.plot(curves['accuracy'])
    ax_accuracy.plot(curves['val_accuracy'])
    ax_accuracy.set_title('Model accuracy')
    ax_accuracy.set_ylabel('Accuracy')
    ax_accuracy.set_xlabel('Epoch')
    ax_accuracy.legend(['Train', 'Validation'], loc='upper left')

    # Right panel: training & validation loss values
    ax_loss.plot(curves['loss'])
    ax_loss.plot(curves['val_loss'])
    ax_loss.set_title('Model loss')
    ax_loss.set_ylabel('Loss')
    ax_loss.set_xlabel('Epoch')
    ax_loss.legend(['Train', 'Validation'], loc='upper right')

    plt.tight_layout()
    plt.show()

# Score the trained model on the held-out split, report the accuracy,
# then draw the learning curves recorded during training.
evaluation = model.evaluate(ds_test)
loss, accuracy = evaluation
print(f"Test accuracy: {accuracy:.2f}")

plot_training_history(history)

In [None]:
def random_predictions(model, dataset, ds_info, num_images=8):
    """Show `num_images` consecutive images from a random offset with their predictions.

    Args:
        model: Model whose `predict` returns one score vector per image.
        dataset: Batched dataset of `(image, label)` pairs; `len(dataset)` must be defined.
        ds_info: Dataset info object providing `features['label'].names`.
        num_images: Number of images to display.

    Each title shows the true and the predicted class name, in green when they agree
    and in red otherwise.
    """
    class_names = ds_info.features['label'].names

    # Skip a random number of images (bounded by the number of batches, so it is always
    # smaller than the number of images), then take the next `num_images` images.
    start = np.random.randint(len(dataset))
    ds_subset = dataset.unbatch().skip(start).take(num_images).batch(num_images)

    # Make predictions
    images, labels = next(iter(ds_subset))
    true_classes = labels.numpy()
    # No extra softmax: argmax is unaffected by it, so this works whether the model
    # outputs probabilities (as the classifier head here does) or raw logits.
    predicted_scores = model.predict(images)
    predicted_classes = np.argmax(predicted_scores, axis=-1)

    # Display results on a grid of up to 4 columns, sized to the images actually fetched
    num_shown = len(true_classes)
    n_cols = min(4, num_shown)
    n_rows = -(-num_shown // n_cols)  # ceiling division
    plt.figure(figsize=(15, 5 * n_rows))
    for i in range(num_shown):
        true_idx = int(true_classes[i])
        predicted_idx = int(predicted_classes[i])

        plt.subplot(n_rows, n_cols, i+1)
        plt.imshow((images[i].numpy() * 255).astype(np.uint8))
        color = 'green' if predicted_idx == true_idx else 'red'

        plt.title(f'True: {class_names[true_idx]}\nPredicted: {class_names[predicted_idx]}', color=color)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# Display 8 sample predictions.
# NOTE(review): this draws from the training split (ds_train); pass ds_test instead if
# the intent is to inspect predictions on held-out images — confirm with the author.
random_predictions(model, ds_train, ds_info, num_images=8)