<a href="https://colab.research.google.com/github/gauss5930/Deep-Learning-Paper/blob/main/Computer%20Vision/CNN/EfficientNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from .utils import (
    round_filters,
    round_repeats,
    drop_connect,
    get_same_padding_conv2d,
    get_model_params,
    efficientnet_params,
    load_pretrained_weights,
    Swish,
    MemoryEfficientSwish,
    calculate_output_image_size
)

# Model names accepted by EfficientNet._check_model_name_is_valid.
# The 'b' variants run from b0 through b8; 'efficientnet-l2' is also listed so
# that it can be constructed, although without pretrained weights.
VALID_MODELS = tuple(f'efficientnet-b{i}' for i in range(9)) + ('efficientnet-l2',)

class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck block (MBConv).

    The block chains an optional 1x1 expansion convolution, a depthwise
    convolution, an optional squeeze-and-excitation gate and a 1x1 projection
    convolution, each convolution (except the SE ones) followed by batch
    normalization. When the block keeps both its stride and its channel
    count, the input is added back to the output as a skip connection.

    Args:
        block_args: Per-block configuration. The attributes read here are
            ``input_filters``, ``output_filters``, ``expand_ratio``,
            ``kernel_size``, ``stride``, ``se_ratio`` and ``id_skip``.
        global_params: Model-wide configuration. The attributes read here are
            ``batch_norm_momentum`` and ``batch_norm_epsilon``.
        image_size: Input image size forwarded to ``get_same_padding_conv2d``
            and ``calculate_output_image_size``. Defaults to ``None``.
    """

    def __init__(self, block_args, global_params, image_size=None):
        super().__init__()
        self._block_args = block_args
        # The configured momentum is inverted before being handed to
        # nn.BatchNorm2d (presumably it is stored in the TensorFlow
        # convention -- verify against the parameter source).
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon
        se_ratio = self._block_args.se_ratio
        self.has_se = (se_ratio is not None) and (0 < se_ratio <= 1)
        self.id_skip = block_args.id_skip  # enables skip connection and drop connect

        # Expansion phase
        inp = self._block_args.input_filters  # number of input channels
        oup = self._block_args.input_filters * self._block_args.expand_ratio  # expanded channels

        if self._block_args.expand_ratio != 1:
            Conv2d = get_same_padding_conv2d(image_size=image_size)
            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup,
                                       kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom,
                                       eps=self._bn_eps)
            # image_size is not recomputed after the 1x1 expansion convolution.

        # Depthwise convolution phase
        k = self._block_args.kernel_size
        s = self._block_args.stride
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._depthwise_conv = Conv2d(
            in_channels=oup, out_channels=oup,
            groups=oup,  # groups == channels is what makes the convolution depthwise
            kernel_size=k, stride=s, bias=False,
        )
        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom,
                                   eps=self._bn_eps)
        image_size = calculate_output_image_size(image_size, s)

        # Squeeze-and-excitation layer (operates on a 1x1 pooled feature map)
        if self.has_se:
            Conv2d = get_same_padding_conv2d(image_size=(1, 1))
            # The squeezed width is derived from the block's *input* filters,
            # not from the expanded width, and is never allowed to reach 0.
            num_squeezed_channels = max(
                1, int(self._block_args.input_filters * self._block_args.se_ratio)
            )
            self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels,
                                     kernel_size=1)
            self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup,
                                     kernel_size=1)

        # Pointwise (projection) convolution
        final_oup = self._block_args.output_filters
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup,
                                    kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom,
                                   eps=self._bn_eps)
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """Run the block on ``inputs``.

        Args:
            inputs: Input tensor of the block.
            drop_connect_rate: Probability passed to ``drop_connect``. It is
                only used on the skip-connection path; a falsy value (``None``
                or ``0``) disables drop connect.

        Returns:
            The output tensor of the block.
        """
        # Expansion & depthwise convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._expand_conv(inputs)
            x = self._bn0(x)
            x = self._swish(x)

        x = self._depthwise_conv(x)
        x = self._bn1(x)
        x = self._swish(x)

        # Squeeze & excitation: pool to 1x1, reduce, expand, then gate x
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_reduce(x_squeezed)
            x_squeezed = self._swish(x_squeezed)
            x_squeezed = self._se_expand(x_squeezed)
            x = torch.sigmoid(x_squeezed) * x

        # Pointwise (projection) convolution -- no activation afterwards
        x = self._project_conv(x)
        x = self._bn2(x)

        # Skip connection & drop connect
        input_filters = self._block_args.input_filters
        output_filters = self._block_args.output_filters

        if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
            # Skip connection combined with drop connect gives stochastic depth
            if drop_connect_rate:
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            x = x + inputs  # skip connection

        return x

    def set_swish(self, memory_efficient=True):
        """Select the swish implementation used by this block.

        Args:
            memory_efficient: If true use ``MemoryEfficientSwish``, otherwise
                the plain ``Swish`` module.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()

class EfficientNet(nn.Module):
    """EfficientNet: a convolutional stem, a stack of MBConv blocks, a
    convolutional head and a linear classifier.

    ``from_name`` and ``from_pretrained`` are alternate constructors that
    look the configuration up by model name.

    Args:
        blocks_args: Non-empty list of per-stage block configurations. Each
            entry must support ``_replace`` (namedtuple-like) and expose
            ``input_filters``, ``output_filters``, ``num_repeat`` and
            ``stride``.
        global_params: Model-wide configuration. The attributes read by this
            class are ``batch_norm_momentum``, ``batch_norm_epsilon``,
            ``image_size``, ``dropout_rate``, ``num_classes`` and
            ``drop_connect_rate``.
    """

    def __init__(self, blocks_args=None, global_params=None):
        super().__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'block args must be greater than 0'
        self._global_params = global_params
        self._blocks_args = blocks_args

        # BatchNorm parameters (momentum is inverted before being handed to
        # nn.BatchNorm2d)
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # The convolution implementation is chosen from the image size
        image_size = global_params.image_size
        Conv2d = get_same_padding_conv2d(image_size=image_size)

        # Stem
        in_channels = 3  # rgb
        out_channels = round_filters(32, self._global_params)  # number of output channels
        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2,
                                 bias=False)
        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
        image_size = calculate_output_image_size(image_size, 2)

        # Build the stack of blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:
            # Scale filters by the width multiplier and repeats by the depth
            # multiplier
            block_args = block_args._replace(
                input_filters=round_filters(block_args.input_filters, self._global_params),
                output_filters=round_filters(block_args.output_filters, self._global_params),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params),
            )

            # The first block of a stage handles the stride and the change in
            # filter count
            self._blocks.append(
                MBConvBlock(block_args, self._global_params, image_size=image_size)
            )
            image_size = calculate_output_image_size(image_size, block_args.stride)
            if block_args.num_repeat > 1:
                # The remaining blocks of the stage keep the output size:
                # input filters equal output filters and the stride is 1
                block_args = block_args._replace(
                    input_filters=block_args.output_filters, stride=1
                )

            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(
                    MBConvBlock(block_args, self._global_params, image_size=image_size)
                )

        # Head
        in_channels = block_args.output_filters  # output of the final block
        out_channels = round_filters(1280, self._global_params)
        Conv2d = get_same_padding_conv2d(image_size=image_size)
        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Final linear layer
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(self._global_params.dropout_rate)
        self._fc = nn.Linear(out_channels, self._global_params.num_classes)
        self._swish = MemoryEfficientSwish()

    def set_swish(self, memory_efficient=True):
        """Select the swish implementation for the model and all its blocks.

        Args:
            memory_efficient: If true use ``MemoryEfficientSwish``, otherwise
                the plain ``Swish`` module.
        """
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_endpoints(self, inputs):
        """Run the convolutional layers and collect intermediate feature maps.

        Args:
            inputs: Input tensor.

        Returns:
            A dict keyed ``'reduction_1'``, ``'reduction_2'``, ... in order of
            insertion. An entry is recorded with the activation that *precedes*
            every block which shrinks dimension 2 of the tensor (the height
            for an NCHW input); the last entry is the output of the head.
        """
        endpoints = dict()

        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        prev_x = x

        # Blocks
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                # Scale the rate linearly with the block index
                drop_connect_rate *= idx / len(self._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)
            if prev_x.size(2) > x.size(2):
                endpoints[f'reduction_{len(endpoints) + 1}'] = prev_x
            prev_x = x

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))
        endpoints[f'reduction_{len(endpoints) + 1}'] = x

        return endpoints

    def extract_features(self, inputs):
        """Run the convolutional layers (stem, blocks and head).

        Args:
            inputs: Input tensor.

        Returns:
            The output tensor of the head convolution, after batch
            normalization and swish.
        """
        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))

        # Blocks
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                # Scale the rate linearly with the block index
                drop_connect_rate *= idx / len(self._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))

        return x

    def forward(self, inputs):
        """Forward pass: feature extraction, pooling, dropout and classifier.

        Args:
            inputs: Input tensor.

        Returns:
            The output of the final linear layer.
        """
        # Convolution layers
        x = self.extract_features(inputs)

        # Pooling & final linear layer
        x = self._avg_pooling(x)
        x = x.flatten(start_dim=1)
        x = self._dropout(x)
        x = self._fc(x)

        return x

    @classmethod
    def from_name(cls, model_name, in_channels=3, **override_params):
        """Create an EfficientNet from its model name.

        Args:
            model_name: Name of the model; must be one of ``VALID_MODELS``.
            in_channels: Number of input channels of the stem convolution.
            **override_params: Passed to ``get_model_params`` as a dict.

        Returns:
            The constructed model.

        Raises:
            ValueError: If ``model_name`` is not in ``VALID_MODELS``.
        """
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(model_name, override_params)
        model = cls(blocks_args, global_params)
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def from_pretrained(cls, model_name, weights_path=None, advprop=False,
                        in_channels=3, num_classes=1000, **override_params):
        """Create an EfficientNet by name and load pretrained weights into it.

        Args:
            model_name: Name of the model; must be one of ``VALID_MODELS``.
            weights_path: Forwarded to ``load_pretrained_weights``.
            advprop: Forwarded to ``load_pretrained_weights``.
            in_channels: Number of input channels of the stem convolution.
            num_classes: Number of output classes. The classifier weights are
                loaded only when this equals 1000.
            **override_params: Forwarded to ``from_name``.

        Returns:
            The model with pretrained weights loaded.

        Raises:
            ValueError: If ``model_name`` is not in ``VALID_MODELS``.
        """
        model = cls.from_name(model_name, num_classes=num_classes, **override_params)
        load_pretrained_weights(model, model_name, weights_path=weights_path,
                                load_fc=(num_classes == 1000), advprop=advprop)
        # The stem is replaced only after loading, so the model has a
        # 3-channel stem while the weights are loaded; a replaced stem is a
        # newly constructed convolution.
        model._change_in_channels(in_channels)
        return model

    @classmethod
    def get_image_size(cls, model_name):
        """Return the third value of ``efficientnet_params(model_name)``
        (named ``res`` here; presumably the input image resolution).

        Raises:
            ValueError: If ``model_name`` is not in ``VALID_MODELS``.
        """
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """Validate ``model_name`` against ``VALID_MODELS``.

        Raises:
            ValueError: If ``model_name`` is not in ``VALID_MODELS``.
        """
        if model_name not in VALID_MODELS:
            raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS))

    def _change_in_channels(self, in_channels):
        """Replace the stem convolution when ``in_channels`` is not 3.

        Args:
            in_channels: Number of input channels for the stem convolution.
        """
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
            out_channels = round_filters(32, self._global_params)
            self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2,
                                     bias=False)