In [5]:
import efficientnet_pytorch 
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
# Set PIL to be tolerant of image files that are truncated.
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [6]:
class EfficientNet_b0(nn.Module):
    """Pretrained EfficientNet backbone followed by a small MLP classification head.

    The backbone's own ``_fc`` layer is kept as a submodule (so state-dict keys
    are unchanged) but is bypassed in ``forward``: backbone features are pooled,
    flattened, passed through the backbone's dropout, and then fed to
    ``classifier_layer``.

    Args:
        num_classes: Number of output logits produced by the head. Defaults to 4.
        model_name: ``efficientnet_pytorch`` identifier passed to
            ``EfficientNet.from_pretrained``. Defaults to ``'efficientnet-b0'``.
    """

    def __init__(self, num_classes: int = 4, model_name: str = 'efficientnet-b0'):
        super().__init__()
        self.model = efficientnet_pytorch.EfficientNet.from_pretrained(model_name)

        # Width of the pooled feature vector, read from the backbone's own head
        # rather than hard-coding 1280 (the b0 width; other variants differ).
        in_features = self.model._fc.in_features

        # Layer order is kept stable so saved state_dicts and index-based
        # access into this Sequential keep working.
        self.classifier_layer = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, inputs):
        """Return class logits for a batch of images.

        Args:
            inputs: Image batch; presumably (batch, 3, H, W) — TODO confirm.

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = self.model.extract_features(inputs)
        # Global pooling + flatten -> (batch, in_features), then the backbone's
        # own dropout, mirroring the stock EfficientNet head before its _fc.
        x = self.model._avg_pooling(x)
        x = x.flatten(start_dim=1)
        x = self.model._dropout(x)
        return self.classifier_layer(x)

In [12]:
# Build the network; the backbone loads pretrained efficientnet-b0 weights on construction.
model = EfficientNet_b0()

Loaded pretrained weights for efficientnet-b0


### The model has two parts
1. the pretrained EfficientNet-b0 backbone (`model.model`)
2. the custom classification head (`model.classifier_layer`)


In [14]:
# Display the full module tree: the EfficientNet backbone plus the classification head.
model

EfficientNet_b0(
  (model): EfficientNet(
    (_conv_stem): Conv2dStaticSamePadding(
      3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
      (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
    )
    (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_depthwise_conv): Conv2dStaticSamePadding(
          32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
          (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
        )
        (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_se_reduce): Conv2dStaticSamePadding(
          32, 8, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          8, 32, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        

In [15]:
# Inspect only the pretrained EfficientNet backbone.
model.model

EfficientNet(
  (_conv_stem): Conv2dStaticSamePadding(
    3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
    (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
  )
  (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
  (_blocks): ModuleList(
    (0): MBConvBlock(
      (_depthwise_conv): Conv2dStaticSamePadding(
        32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
        (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
      )
      (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
      (_se_reduce): Conv2dStaticSamePadding(
        32, 8, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_se_expand): Conv2dStaticSamePadding(
        8, 32, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_project_conv): Conv2dStaticSamePadding(
        32, 16, kernel_size=

In [10]:
# Inspect only the custom classification head (nn.Sequential of Linear/ReLU layers).
model.classifier_layer

Sequential(
  (0): Linear(in_features=1280, out_features=512, bias=True)
  (1): ReLU()
  (2): Linear(in_features=512, out_features=256, bias=True)
  (3): ReLU()
  (4): Linear(in_features=256, out_features=4, bias=True)
)

In [17]:
# Weights feeding the first output logit of the final Linear layer.
# Index with -1 rather than a fixed position so this still selects the output
# layer if layers (e.g. BatchNorm/Dropout) are inserted into the head.
model.classifier_layer[-1].weight[0]

tensor([-4.6963e-02,  4.9580e-02, -5.7288e-02,  5.7153e-02, -4.4709e-03,
        -4.6504e-02, -2.5128e-02, -2.1510e-02, -2.6675e-02, -4.6974e-02,
        -5.7091e-02, -2.3977e-02,  4.2503e-03,  3.2935e-03, -7.9617e-03,
        -3.5226e-02, -4.2100e-02, -4.7044e-02,  5.4168e-02,  1.0494e-02,
         4.3784e-02, -4.2649e-02,  3.7323e-02,  1.1789e-02, -2.7641e-02,
        -3.8416e-02, -3.1389e-02,  1.7239e-02,  6.0447e-02,  9.8327e-03,
        -1.0789e-02, -5.1771e-02,  2.2032e-02,  5.9040e-02, -5.5121e-02,
        -1.2798e-03,  1.1518e-03,  1.1880e-02,  3.7730e-03,  2.8631e-02,
        -6.1506e-02,  7.1166e-04,  4.9849e-02,  2.8511e-02,  7.1381e-03,
         3.6296e-02,  1.6927e-02, -1.5687e-02,  1.5262e-02, -1.3787e-02,
         1.8730e-02,  4.8113e-03,  1.9443e-02, -3.9681e-02, -9.1014e-03,
        -4.3239e-02, -4.3387e-02, -3.4540e-02, -1.4751e-02,  4.6169e-02,
        -3.5545e-02, -3.4017e-02, -1.5196e-02,  2.1473e-02,  1.5384e-02,
        -3.3887e-02, -5.0040e-02, -3.7345e-02,  3.1