In [1]:
import math
import sys
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from operation import Conv1dNormActivation, _make_divisible, SqueezeExcitation

from marnasnet import BlockConfig, ConvBlock, SeparableConvBlock, MBConvBlock, marnasnet_a, marnasnet_b, marnasnet_c, marnasnet_d, marnasnet_e

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
input = torch.randn(1, 3, 112)
# Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation
layer = Conv1dNormActivation(3, 32, 3, 1) 

In [3]:
output = layer(input)
output.shape

torch.Size([1, 32, 112])

In [4]:
cnf = BlockConfig(
    conv_op=ConvBlock,
    repeats=5, 
    kernel=5, stride=1, 
    input_channels=32, out_channels=16, 
    skip_op='identity',
    se_ratio=0.25)
    
blocks_setting = [
    cnf
]

* dummy input

In [5]:
input = torch.randn(1, 32, 256)

In [6]:
convblock = ConvBlock(cnf)
convblock(input).shape

torch.Size([1, 16, 256])

In [7]:
sepconv = SeparableConvBlock(cnf)
sepconv(input).shape

torch.Size([1, 16, 256])

In [8]:
mbconv = MBConvBlock(cnf)
mbconv(input).shape

torch.Size([1, 16, 256])

In [9]:
block_conf = partial(BlockConfig, se_ratio=0.25)

In [10]:
input = torch.randn(1, 3, 256)
model = marnasnet_a(init_channels=3)
output = model(input)
output.shape


torch.Size([1, 6])

In [11]:
input = torch.randn(1, 3, 256)
model = marnasnet_b(init_channels=3)
output = model(input)
output.shape

torch.Size([1, 6])

In [12]:
input = torch.randn(1, 3, 256)
model = marnasnet_c(init_channels=3)
output = model(input)
output.shape

torch.Size([1, 6])

In [13]:
input = torch.randn(1, 3, 256)
model = marnasnet_d(init_channels=3)
output = model(input)
output.shape

torch.Size([1, 6])

In [14]:
input = torch.randn(1, 3, 256)
model = marnasnet_e(init_channels=3)
output = model(input)
output.shape

torch.Size([1, 6])

In [25]:
# TF Version Blcok Analysis

In [4]:
import tensorflow as tf
import abc

from enum import Enum
from abc import *

class BaseAttention(metaclass=abc.ABCMeta):
    def __init__(self, filters, block_name):
        self.filters = filters
        self.block_name = block_name

    def __call__(self, x):
        return x


# Squeeze-and-Excitation module
class SqueezeAndExcite(BaseAttention):
    """squeeze-and-excitation module
    """
    def __init__(self, filters, se_ratio=0.25, block_name=""):
        """
        Parameters
        ----------
        filters: int
            output filter size
        se_ratio: float
            se ratio, se_ratio must be greater than 0 and less than or equal to 1.
        block_name: str
            block name
        """
        super().__init__(filters, block_name)
        self.se_ratio = se_ratio

    def __call__(self, x):
        assert 0 < self.se_ratio <= 1, "se_ratio must be greater than 0 and less than or equal to 1."

        filters_se = max(1, int(self.filters * self.se_ratio))
        se = tf.keras.layers.GlobalAveragePooling1D(name="{}_se_squeeze".format(self.block_name))(x)
        se = tf.keras.layers.Reshape((1, self.filters), name="{}_se_reshape".format(self.block_name))(se)
        se = tf.keras.layers.Conv1D(
            filters_se,
            1,
            padding="same",
            activation="relu",
            kernel_initializer="he_normal",
            name="{}_se_reduce".format(self.block_name)
        )(se)
        se = tf.keras.layers.Conv1D(
            self.filters,
            1,
            padding="same",
            activation="sigmoid",
            kernel_initializer="he_normal",
            name="{}_se_expand".format(self.block_name)
        )(se)
        x = tf.keras.layers.multiply([x, se], name="{}_se_excite".format(self.block_name))
        return x

In [5]:
# ConvOps enum
class ConvBlock(Enum):
    Conv = "Conv"
    SeparableConv = "SeparableConv"
    MBConv = "MBConv"
    ExtremeInception = "ExtremeInception"


# SkipOps enum
class SkipOperation(Enum):
    none = "none"
    pool = "pool"
    identity = "identity"


# Base conv block
class BaseBlock(metaclass=ABCMeta):
    def __init__(self, repeats, kernel_size, skip_op, strides,
                 se_ratio, block_id=1):
        self.repeats = repeats
        self.kernel_size = kernel_size
        self.skip_op = skip_op
        self.strides = strides
        self.se_ratio = se_ratio
        self.block_id = block_id

    @abstractmethod
    def __call__(self, x):
        raise NotImplementedError()

In [6]:
class RegularConvBlock(BaseBlock):
    def __init__(self, repeats: int, kernel_size: int, filters: int, skip_op: SkipOperation, strides: int,
                 se_ratio: float, block_id=1):
        """
        Parameters
        ----------
        repeats: int
            the number of convolutional layers
        kernel_size: int
            the dimension of the convolution window
        filters: int
            the number of filters
        skip_op: Blossom.options.SkipOperation
            skip operation
        strides: int
            the stride ot the convolution
        se_ratio: float
            between 0 and 1, fraction to squeeze the input filters
        block_id: int
            larger than 1, the block id
        """
        super().__init__(repeats, kernel_size, skip_op, strides, se_ratio, block_id)
        self.filters = filters

    def __call__(self, x):
        inputs = x

        for i in range(self.repeats):
            x = tf.keras.layers.Conv1D(
                self.filters,
                self.kernel_size,
                self.strides,
                padding='same',
                activation='relu',
                kernel_initializer='he_normal',
                name="block{}{}_conv".format(self.block_id, chr(i + 97))
            )(x)

        if 0 < self.se_ratio <= 1:
            x = SqueezeAndExcite(self.filters, self.se_ratio, block_name="block{}".format(self.block_id))(x)

        if self.skip_op == SkipOperation.pool:
            x = tf.keras.layers.MaxPooling1D(name="block{}_pool".format(self.block_id), padding='same')(x)
        elif self.skip_op == SkipOperation.identity:
            if self.strides == 1:
                shortcut = inputs
                if int(inputs.shape[-1]) != int(x.shape[-1]):
                    shortcut = tf.keras.layers.Conv1D(int(x.shape[-1]),
                                                      1,
                                                      strides=self.strides,
                                                      kernel_initializer="he_normal",
                                                      padding='valid',
                                                      name="block{}_shortcut".format(self.block_id))(x)

                x = tf.keras.layers.add([x, shortcut], name="block{}_add".format(self.block_id))

        return x

In [7]:
conv_block = RegularConvBlock(3, 3, 32, 'identity', 1, 0.25, 1)

In [19]:
# 입력 데이터의 크기 정의 (예: [배치 크기, 시퀀스 길이, 채널 수])
input_shape = (None, 112, 3)  # None은 배치 크기를 나중에 지정할 수 있음을 의미합니다.

# 입력 텐서 생성
input_tensor = tf.keras.layers.Input(shape=input_shape[1:])  # 배치 크기를 제외한 나머지 차원을 지정합니다.

# RegularConvBlock에 입력 텐서 전달
output_tensor = conv_block(input_tensor)

# 모델 정의 (입력과 출력 텐서를 연결)
model = tf.keras.Model(inputs=input_tensor, outputs=output_tensor)

# 모델 요약 정보 출력
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 112, 3)]     0           []                               
                                                                                                  
 block1a_conv (Conv1D)          (None, 112, 32)      320         ['input_2[0][0]']                
                                                                                                  
 block1b_conv (Conv1D)          (None, 112, 32)      3104        ['block1a_conv[0][0]']           
                                                                                                  
 block1c_conv (Conv1D)          (None, 112, 32)      3104        ['block1b_conv[0][0]']           
                                                                                            

In [24]:
plot_model(model, to_file='model.png', show_shapes=True)

You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.
