# Implementation of the Xception CNN architecture

In [1]:
from kerastuner.applications import xception

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras import Model
from tensorflow.keras import layers
import tensorflow.keras.backend as K
from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.python.util import nest

In [2]:
def build_xception_model(
    input_shape,
    output_shape,
    normalize=True,
    conv2d_num_filters=64,
    kernel_size=5,
    initial_strides=2,
    activation="selu",
    sep_num_filters=256,
    num_residual_blocks=4,
    pooling="avg",
    dropout_rate=0,
):
    """Assemble an Xception-style Keras model for image regression.

    Architecture after François Chollet, https://arxiv.org/pdf/1610.02357.pdf;
    the code is derived from Autokeras and Keras Tuner.

    Args:
        input_shape: Shape of one input image; must have exactly three entries.
        output_shape: Shape of the target; only its last entry is used, as the
            number of units of the final Dense layer.
        normalize: When true, a Normalization layer is placed right after the
            input. Its statistics are stored as model weights, so call
            adapt_model(model, X) before fitting.
        conv2d_num_filters: Number of filters of the initial convolution.
        kernel_size: Kernel size of the initial convolution.
        initial_strides: Strides of the initial convolution.
        activation: "relu" or "selu"; forwarded to every Xception block.
        sep_num_filters: Number of filters of each residual block; the exit
            block uses twice as many.
        num_residual_blocks: How many residual blocks (without max pooling)
            precede the exit block.
        pooling: "avg", "max" or "flatten"; how the feature map is reduced
            before the regression head.
        dropout_rate: A Dropout layer is inserted before the head only when
            this is greater than 0.

    Returns:
        The (uncompiled) Keras Model mapping the input to the Dense output.
    """
    assert len(input_shape) == 3, "The input images should have a channel dimension"
    assert activation in ["relu", "selu"]
    assert pooling in ["avg", "flatten", "max"]

    inputs = layers.Input(shape=input_shape)

    # The Normalization layer learns the dataset mean/variance through
    # adapt_model(model, X), which has to run before the model is fitted.
    features = preprocessing.Normalization()(inputs) if normalize else inputs

    # Entry flow: a single plain convolution.
    features = xception.conv(
        features,
        conv2d_num_filters,
        kernel_size=kernel_size,
        activation=activation,
        strides=initial_strides,
    )

    # Middle flow (no pooling) followed by one wider exit block with pooling,
    # expressed as a list of (num_filters, max_pooling) pairs.
    block_plan = [(sep_num_filters, False)] * num_residual_blocks
    block_plan.append((2 * sep_num_filters, True))
    for num_filters, use_max_pooling in block_plan:
        features = xception.residual(
            features, num_filters, activation=activation, max_pooling=use_max_pooling
        )

    # Collapse the spatial dimensions with the requested strategy.
    pooling_cls = {
        "flatten": layers.Flatten,
        "avg": layers.GlobalAveragePooling2D,
        "max": layers.GlobalMaxPooling2D,
    }[pooling]
    features = pooling_cls()(features)

    # Regression head: optional dropout, then a linear Dense layer.
    if dropout_rate > 0:
        features = layers.Dropout(dropout_rate)(features)
    outputs = layers.Dense(output_shape[-1])(features)

    return Model(inputs, outputs)

In [None]:
def adapt_model(model, dataset):
    """Adapt the model's leading preprocessing layers to the given data.

    For every model input, the chain of PreprocessingLayer instances (e.g.
    Normalization()) that directly follows that input is walked, and each
    layer's adapt() is called with the data belonging to that input.

    Args:
        model: Keras model whose preprocessing layers should be adapted.
        dataset: Either a tf.data.Dataset whose elements are (x, y) pairs, or
            a (possibly nested) structure of inputs that is flattened and
            indexed in the same order as the flattened model inputs.

    Returns:
        The same model object, with its preprocessing layers adapted.
    """
    from_tf_dataset = isinstance(dataset, tf.data.Dataset)
    if from_tf_dataset:
        # Keep only the features; the targets are not needed for adapting.
        x = dataset.map(lambda x, y: x)
    else:
        x = nest.flatten(dataset)

    def find_consumer(tensor):
        # First non-input layer whose (first) input is exactly this tensor,
        # or None when no such layer exists.
        wanted = nest.flatten(tensor)[0]
        non_input_layers = (
            candidate
            for candidate in model.layers
            if not isinstance(candidate, tf.keras.layers.InputLayer)
        )
        return next(
            (
                candidate
                for candidate in non_input_layers
                if nest.flatten(candidate.input)[0] is wanted
            ),
            None,
        )

    for position, model_input in enumerate(nest.flatten(model.input)):
        if from_tf_dataset:

            def select_input(*elements):
                # Used immediately by map() below, so reading `position`
                # from the enclosing loop is safe here.
                return elements[position]

            input_data = x.map(select_input)
        else:
            input_data = x[position]

        # Follow the chain of preprocessing layers starting at this input.
        current = find_consumer(model_input)
        while isinstance(current, preprocessing.PreprocessingLayer):
            current.adapt(input_data)
            current = find_consumer(current.output)

    return model

In [3]:
# # Usage examples:
# model = build_xception_model(input_shape=(64, 64, 1), output_shape=(5,), num_residual_blocks=7)
# model.summary()

# # Adapt the normalization layer to the data
# adapt_model(model, data)
# model.fit(...)