In [1]:
from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import MemoryEfficientSwish
import torch
from torch import nn

In [2]:
eff_b0 = EfficientNet.from_pretrained('efficientnet-b0', num_classes=168)

Loaded pretrained weights for efficientnet-b0


In [3]:
eff_b0

EfficientNet(
  (_conv_stem): Conv2dStaticSamePadding(
    3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
    (static_padding): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)
  )
  (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
  (_blocks): ModuleList(
    (0): MBConvBlock(
      (_depthwise_conv): Conv2dStaticSamePadding(
        32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
        (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
      )
      (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
      (_se_reduce): Conv2dStaticSamePadding(
        32, 8, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_se_expand): Conv2dStaticSamePadding(
        8, 32, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_project_conv): Conv2dStaticSamePadding(
        32, 16, kernel_size=

In [4]:
class EfficientNet_0(nn.Module):
    def __init__(self):
        super(EfficientNet_0, self).__init__()
        #eff_b0 = EfficientNet.from_pretrained('efficientnet-b0', num_classes=168)
        eff_b0 = EfficientNet.from_name('efficientnet-b0')

        self._conv_stem = eff_b0._conv_stem
        self._bn0 = eff_b0._bn0
        self._blocks = eff_b0._blocks
        self._conv_head = eff_b0._conv_head
        self._bn1 = eff_b0._bn1
        self._avg_pooling = eff_b0._avg_pooling
        self._dropout = eff_b0._dropout
        
        self._root_fc = nn.Linear(1280, 168)
        # self._root_swish = MemoryEfficientSwish()
        
        self._vowel_fc = nn.Linear(1280, 11)
        # self._vowel_swish = MemoryEfficientSwish()
        
        self._consonant_fc = nn.Linear(1280, 7)
        # self._consonant_swish = MemoryEfficientSwish()
        
    def forward(self, x):
        x = self._conv_stem(x)
        x = self._bn0(x)
        for m in self._blocks:
            x = m(x)
        x = self._conv_head(x)
        x = self._bn1(x)
        x = self._avg_pooling(x)
        x = self._dropout(x)
        x = x.view(-1, 1280)
        
        print(x.shape)
        x_root = self._root_fc(x)
        x_root = self._root_swish(x_root)
        
        x_vowel = self._vowel_fc(x)
        x_vowel = self._vowel_swish(x_vowel)
        
        x_consonant = self._consonant_fc(x)
        x_consonant = self._consonant_swish(x_consonant)
            
        return x_root, x_vowel, x_consonant


In [5]:
eff_b0

EfficientNet(
  (_conv_stem): Conv2dStaticSamePadding(
    3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
    (static_padding): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)
  )
  (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
  (_blocks): ModuleList(
    (0): MBConvBlock(
      (_depthwise_conv): Conv2dStaticSamePadding(
        32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
        (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
      )
      (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
      (_se_reduce): Conv2dStaticSamePadding(
        32, 8, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_se_expand): Conv2dStaticSamePadding(
        8, 32, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_project_conv): Conv2dStaticSamePadding(
        32, 16, kernel_size=

In [29]:
sample = torch.ones(1, 3, 448, 448)

In [30]:
eff_b0(sample)

tensor([[ 0.0663,  0.0688,  0.0039,  0.0073, -0.0647, -0.0123, -0.0613,  0.0333,
          0.0170, -0.0830, -0.0433,  0.0202,  0.0354, -0.0099,  0.0176, -0.1546,
         -0.0821, -0.0662,  0.0436,  0.0054, -0.1167,  0.0418,  0.0095,  0.1410,
         -0.0465,  0.1290,  0.0819, -0.0243,  0.1016,  0.0840,  0.0785, -0.0158,
         -0.0962,  0.0991, -0.0241, -0.0428,  0.0098,  0.1685,  0.0681,  0.0103,
          0.0335,  0.1109, -0.1152,  0.1044,  0.0535,  0.0026,  0.1808,  0.0354,
         -0.1374, -0.0405,  0.1394, -0.0345,  0.0213,  0.0504, -0.0221,  0.1774,
          0.0006, -0.0384, -0.1365,  0.0628, -0.0082, -0.0333, -0.0788,  0.1868,
          0.2300,  0.1783, -0.0711, -0.0080,  0.0119, -0.1350,  0.0942,  0.0092,
         -0.0327, -0.0811, -0.0175,  0.1605,  0.0222, -0.0078, -0.0858, -0.1015,
          0.1242,  0.0028,  0.0592, -0.1092,  0.1310, -0.0343, -0.1368,  0.0264,
          0.0521, -0.0041,  0.0330,  0.0892,  0.0550, -0.0755,  0.0916,  0.0463,
          0.1030, -0.0349,  

In [5]:
eff_seq = EfficientNet_0()

In [6]:
eff_seq

EfficientNet_0(
  (_conv_stem): Conv2dStaticSamePadding(
    3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
    (static_padding): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)
  )
  (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
  (_blocks): ModuleList(
    (0): MBConvBlock(
      (_depthwise_conv): Conv2dStaticSamePadding(
        32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
        (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
      )
      (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
      (_se_reduce): Conv2dStaticSamePadding(
        32, 8, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_se_expand): Conv2dStaticSamePadding(
        8, 32, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_project_conv): Conv2dStaticSamePadding(
        32, 16, kernel_siz

In [33]:
eff_seq(sample)


torch.Size([1, 1280])


(tensor([[-2.2357e-01, -1.5539e-02,  8.6996e-01, -1.9838e-01, -1.1289e-01,
          -6.2418e-02,  4.4452e-02,  3.6350e-01,  1.1067e+00, -2.7643e-01,
          -8.2375e-02,  3.7778e+00,  1.8769e-03, -2.5986e-01, -3.0666e-02,
           1.4120e-01,  1.7073e-01, -2.6595e-01, -2.1492e-01,  3.0552e-01,
           6.8114e-02, -1.6309e-01,  3.4651e-01, -2.4062e-01, -9.9122e-02,
           1.2292e+00,  3.0076e+00,  3.1739e-02,  2.1770e+00,  5.5068e-01,
          -1.2920e-01,  1.2439e+00, -1.2519e-01, -2.5127e-01,  1.9527e+00,
           3.9038e-01, -1.1028e-01,  1.0218e+00,  1.0944e+00, -1.3938e-01,
           1.5669e+00, -2.1018e-01,  1.1484e+00,  1.8699e+00, -1.0653e-01,
           6.7655e-01, -1.2773e-01,  5.1968e-01,  3.9642e-01, -1.1439e-01,
          -1.9535e-01, -2.7716e-01,  1.2832e-01, -2.3728e-01, -2.6666e-01,
          -2.2307e-01,  3.5386e-01,  1.4820e-02, -2.0384e-01,  1.0852e+00,
          -2.5707e-01, -2.6543e-01,  3.0489e+00,  2.4817e+00,  4.3210e-01,
          -2.7059e-01,  7