<p class="caption" style="font-size:1.5em;">EfficientNet Core Architecture</p>  
<p align="">
  <a ><img src="../images/model/EfficientNet_core.png" alt="EfficientNetV2 Structure" title="EfficientNet Core Architecture"></a>
</p>

<p class="caption" style="font-size:1.5em;">MBConv Block Architecture</p>  
<p align="">
  <a><img src="../images/model/MBConvBlock.png" alt="MBConv Block Architecture" title="MBConv Block Architecture"></a>
</p>

In [None]:
""" EfficientNet https://arxiv.org/abs/2104.00298"""

import keras
import tensorflow as tf
from pydantic import BaseModel, Field

from heartkit.models.blocks import batch_norm, conv2d, mbconv_block, relu6
from heartkit.models.defines import KerasLayer, MBConvParams
from heartkit.models.utils import make_divisible

from heartkit.models import EfficientNetParams, EfficientNetV2, MBConvParams, ModelFactory


def EfficientNetV2_model(
    inputs: tf.Tensor,
    num_classes: int,
) -> keras.Model:
    """Build the reference EfficientNetV2 configuration.

    Four MBConv stages are configured, all sharing a (1, 3) kernel, a (1, 2)
    stride and an expansion ratio of 1; only the filter count, depth and
    squeeze-excite ratio vary per stage. The stem uses ``input_filters=24``
    with the same kernel/stride, ``output_filters`` is 0, the top is included,
    and both dropout rates are 0.0.

    Args:
        inputs (tf.Tensor): Model inputs
        num_classes (int): Number of classes

    Returns:
        keras.Model: Model
    """
    kernel = (1, 3)
    stride = (1, 2)

    # One row per stage, in network order: (filters, depth, se_ratio).
    stage_specs = (
        (32, 2, 2),
        (48, 1, 4),
        (64, 2, 4),
        (80, 1, 4),
    )
    stages = [
        MBConvParams(
            filters=n_filters,
            depth=n_repeats,
            ex_ratio=1,
            kernel_size=kernel,
            strides=stride,
            se_ratio=squeeze,
        )
        for n_filters, n_repeats, squeeze in stage_specs
    ]

    config = EfficientNetParams(
        input_filters=24,
        input_kernel_size=kernel,
        input_strides=stride,
        blocks=stages,
        output_filters=0,
        include_top=True,
        dropout=0.0,
        drop_connect_rate=0.0,
    )
    return EfficientNetV2(inputs, params=config, num_classes=num_classes)

class EfficientNetParams(BaseModel):
    """EfficientNet parameters"""

    # Per-stage MBConv settings, consumed in order by the core builder.
    blocks: list[MBConvParams] = Field(default_factory=list, description="EfficientNet blocks")

    # Stem convolution settings; EfficientNetV2 only adds the stem when input_filters > 0.
    input_filters: int = Field(0, description="Input filters")
    input_kernel_size: int | tuple[int, int] = Field(3, description="Input kernel size")
    input_strides: int | tuple[int, int] = Field(2, description="Input stride")

    # Filters of the 1x1 "neck" convolution; a falsy value (0) skips the neck.
    output_filters: int = Field(0, description="Output filters")

    # Whether the pooling/dropout/dense head is appended.
    include_top: bool = Field(True, description="Include top")

    # Head dropout rate; EfficientNetV2 applies it only when strictly between 0 and 1.
    dropout: float = Field(0.2, description="Dropout rate")

    # Forwarded to the core builder, which scales it per block.
    drop_connect_rate: float = Field(0.2, description="Drop connect rate")

    # Name given to the resulting keras.Model.
    # NOTE(review): under pydantic v2 the "model_" prefix falls in the default protected
    # namespace and may emit a warning -- confirm against the pinned pydantic version.
    model_name: str = Field("EfficientNetV2", description="Model name")


def efficientnet_core(blocks: list[MBConvParams], drop_connect_rate: float = 0) -> KerasLayer:
    """Build the stack of MBConv stages that forms the EfficientNet body.

    Each entry in ``blocks`` is one stage whose MBConv block is repeated
    ``depth`` times. Only the first repeat of a stage uses the stage's
    strides; later repeats use a stride of 1. The per-block drop rate is
    ``drop_connect_rate * block_index / total_blocks``, so it starts at 0 for
    the first block and grows linearly with the block's position.

    Args:
        blocks (list[MBConvParams]): MBConv params, one per stage
        drop_connect_rate (float, optional): Drop connect rate. Defaults to 0.

    Returns:
        KerasLayer: Callable applying every stage, in order, to a tensor
    """

    def layer(x: tf.Tensor) -> tf.Tensor:
        n_total = sum(stage.depth for stage in blocks)
        n_emitted = 0  # running count of MBConv blocks created so far
        out = x
        for stage_no, stage in enumerate(blocks, start=1):
            # Channel count is fixed for the whole stage.
            width = make_divisible(stage.filters, 8)
            for rep_no in range(1, stage.depth + 1):
                is_first_repeat = rep_no == 1
                out = mbconv_block(
                    width,
                    stage.ex_ratio,
                    stage.kernel_size,
                    stage.strides if is_first_repeat else 1,
                    stage.se_ratio,
                    droprate=drop_connect_rate * n_emitted / n_total,
                    name=f"stage{stage_no}.mbconv{rep_no}",
                )(out)
                n_emitted += 1
        return out

    return layer


def EfficientNetV2(
    x: tf.Tensor,
    params: EfficientNetParams,
    num_classes: int | None = None,
) -> keras.Model:
    """Create EfficientNet V2 TF functional model

    The graph is assembled in order: an optional stem (conv2d -> batch_norm ->
    relu6), the MBConv core, an optional 1x1 "neck" (conv2d -> batch_norm ->
    relu6), and an optional top (global average pooling -> dropout -> dense).

    Args:
        x (tf.Tensor): Input tensor. When its shape has three entries, a
            singleton dimension is inserted after the first axis so the rest
            of the network operates on a 4D tensor.
        params (EfficientNetParams): Model parameters.
        num_classes (int, optional): # classes. Required when
            ``params.include_top`` is set; unused otherwise.

    Returns:
        keras.Model: Model

    Raises:
        ValueError: If ``params.include_top`` is set and ``num_classes`` is None.
    """

    # Fail fast with a clear message instead of building the whole graph and
    # then erroring inside the Dense layer constructor.
    if params.include_top and num_classes is None:
        raise ValueError("num_classes must be provided when params.include_top is True")
    # END IF

    # Force input to be 4D (add dummy dimension)
    requires_reshape = len(x.shape) == 3
    if requires_reshape:
        y = keras.layers.Reshape((1,) + x.shape[1:])(x)
    else:
        y = x
    # END IF

    # Stem: only added when a positive filter count is requested
    if params.input_filters > 0:
        name = "stem"
        filters = make_divisible(params.input_filters, 8)
        y = conv2d(
            filters,
            kernel_size=params.input_kernel_size,
            strides=params.input_strides,
            name=name,
        )(y)
        y = batch_norm(name=name)(y)
        y = relu6(name=name)(y)
    # END IF

    # Core: stacked MBConv stages
    y = efficientnet_core(blocks=params.blocks, drop_connect_rate=params.drop_connect_rate)(y)

    # Neck: optional 1x1 convolution applied after the core
    if params.output_filters:
        name = "neck"
        filters = make_divisible(params.output_filters, 8)
        y = conv2d(filters, kernel_size=(1, 1), strides=(1, 1), padding="same", name=name)(y)
        y = batch_norm(name=name)(y)
        y = relu6(name=name)(y)
    # END IF

    # Top: pooled features -> optional dropout -> dense layer with num_classes units
    if params.include_top:
        name = "top"
        y = keras.layers.GlobalAveragePooling2D(name=f"{name}.pool")(y)
        # Dropout is skipped for rates outside the open interval (0, 1)
        if 0 < params.dropout < 1:
            y = keras.layers.Dropout(params.dropout)(y)
        y = keras.layers.Dense(num_classes, name=name)(y)
    # END IF

    model = keras.Model(x, y, name=params.model_name)
    return model