In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models

from utils.blocks.stemblock import StemBlock
from utils.blocks.cefe import CeFEBlock
from utils.blocks.emvit import EMViTBlock
from utils.blocks.faces import FACeSBlock

In [2]:
def create_mhadformer_model(num_classes=5, image_size=224):
    """Build the MHADFormer image classifier as a Keras functional model.

    Layer order: stem -> EMViT -> CeFE (x2) -> EMViT -> FACeS -> 1x1 conv head
    (conv, swish, batch norm) -> global average pooling -> dropout -> softmax.

    Args:
        num_classes: Number of output classes (width of the softmax layer).
        image_size: Height and width of the square, 3-channel input images.

    Returns:
        A ``tf.keras.Model`` named "MHADFormer" that maps
        (batch, image_size, image_size, 3) inputs to class probabilities.
    """
    act = "swish"

    image_input = layers.Input(shape=(image_size, image_size, 3), name="input_layer")

    # Backbone: stem, first EMViT stage, then two CeFE stages.
    features = StemBlock(name="stem_block")(image_input)
    features = EMViTBlock(
        num_blocks=1,
        projection_dim=16,
        strides=1,
        activation=act,
        dropout_rate=0.1,
        name="emvit_block1",
    )(features)
    for cefe_name, cefe_filters, cefe_strides in (
        ("cefe_block1", 32, 1),
        ("cefe_block2", 64, 2),
    ):
        features = CeFEBlock(
            filters=cefe_filters,
            strides=cefe_strides,
            activation=act,
            name=cefe_name,
        )(features)

    # Second EMViT stage followed by the FACeS block.
    features = EMViTBlock(
        num_blocks=1,
        projection_dim=128,
        strides=1,
        activation=act,
        dropout_rate=0.1,
        name="emvit_block2",
    )(features)
    features = FACeSBlock(name="faces_block")(features)

    # 1x1 projection head.
    head = layers.Conv2D(
        filters=256,
        kernel_size=(1, 1),
        strides=(1, 1),
        padding="same",
        name="final_conv",
    )(features)
    head = layers.Activation(act, name="final_activation")(head)
    # NOTE(review): batch norm is applied *after* the activation here, the
    # reverse of the common conv -> BN -> activation ordering. Kept as-is
    # because reordering would change the architecture.
    head = layers.BatchNormalization(name="final_bn")(head)

    # Classifier.
    pooled = layers.GlobalAveragePooling2D(name="global_avg_pool")(head)
    pooled = layers.Dropout(0.5, name="final_dropout")(pooled)
    class_probs = layers.Dense(
        num_classes, activation="softmax", name="output_layer"
    )(pooled)

    return models.Model(inputs=image_input, outputs=class_probs, name="MHADFormer")

In [3]:
# Build the default configuration (5 classes, 224x224x3 input) and print its layer table.
model = create_mhadformer_model(num_classes=5, image_size=224)
model.summary()

Model: "MHADFormer"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_layer (InputLayer)    [(None, 224, 224, 3)]     0         
                                                                 
 stem_block (StemBlock)      (None, 56, 56, 3)         162       
                                                                 
 emvit_block1 (EMViTBlock)   (None, 28, 28, 16)        106928    
                                                                 
 cefe_block1 (CeFEBlock)     (None, 28, 28, 32)        7376      
                                                                 
 cefe_block2 (CeFEBlock)     (None, 14, 14, 64)        29088     
                                                                 
 emvit_block2 (EMViTBlock)   (None, 7, 7, 128)         291524    
                                                                 
 faces_block (FACeSBlock)    (None, 7, 7, 256)         6

In [4]:
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

def get_model_flops(model, input_shape=(1, 224, 224, 3)):
    """
    Computes the FLOPs of a Keras model by statically profiling its frozen graph.

    The model is traced once with a fixed-shape float32 input. Its variables are
    folded into constants, and the TF v1 profiler's float-operation counter is
    run over the resulting graph. The model is not executed on any data.

    Note: TF's registered flop statistics for ops such as Conv2D and MatMul
    count a multiply-accumulate as two operations. The result is therefore
    roughly twice a MAC count. Check the convention before comparing with
    published figures.

    Args:
        model: A tf.keras.Model instance.
        input_shape: A tuple representing the input shape including batch
            dimension. If None, the shape is taken from ``model.input_shape``
            (single-input models only) with the batch dimension set to 1.

    Returns:
        total_flops: Total number of FLOPs (as an integer), or 0 if the
            profiler returns no result.
    """
    if input_shape is None:
        # Keras reports an unknown (None) batch dimension; pin it to 1 so every
        # op in the traced graph has a fully static shape.
        input_shape = (1,) + tuple(model.input_shape[1:])

    # Trace the model into a single concrete graph for this input signature.
    concrete_func = tf.function(model).get_concrete_function(
        tf.TensorSpec(input_shape, tf.float32)
    )

    # Fold variables into constants so the profiler sees a self-contained graph.
    frozen_func = convert_variables_to_constants_v2(concrete_func)
    graph_def = frozen_func.graph.as_graph_def()

    # Count float ops only. with_empty_output() stops the profiler from
    # printing its full per-scope report into the notebook output.
    opts = (
        tf.compat.v1.profiler.ProfileOptionBuilder(
            tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
        )
        .with_empty_output()
        .build()
    )

    # Import the frozen graph into a fresh graph and run the profiler on it.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
        flops_profile = tf.compat.v1.profiler.profile(
            graph=graph,
            run_meta=tf.compat.v1.RunMetadata(),
            cmd="scope",
            options=opts,
        )

    return flops_profile.total_float_ops if flops_profile is not None else 0

# Profile a single 224x224x3 input (batch size 1) and express the count in billions.
flops = get_model_flops(model, (1, 224, 224, 3))
flops_giga = flops / 1_000_000_000

Instructions for updating:
Use `tf.compat.v1.graph_util.tensor_shape_from_node_def_name`


In [5]:
# Summary metrics. Expects `model` and `flops_giga` to be defined by earlier cells.
# Use the public `model.name` property rather than the private `_name` attribute.
params = model.count_params() / 10**6
print(f"Model: {model.name}")
print(f"FLOPS: {flops_giga:.2f} GFLOPS")
print(f"Params: {params:.2f} M")

Model: MHADFormer
FLOPS: 0.26 GFLOPS
Params: 1.13 M
