<a href="https://colab.research.google.com/github/hhaemin/computer_vision/blob/main/7_EfficientNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch, torchvision
import torchvision.models as models
import torchvision.datasets as datasets

import matplotlib.pyplot as plt
from PIL import Image

In [2]:
models.efficientnet_b0()

EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormActivat

In [3]:
### Model
efficientnet_b0 = models.efficientnet_b0(pretrained=True)

## Dataset
to_tensor = torchvision.transforms.Compose(
                [torchvision.transforms.ToTensor(),
               torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])]
                                          )

cifar10 = torchvision.datasets.CIFAR10(root='./', download=True, transform=to_tensor)

dataloader = torch.utils.data.DataLoader(cifar10, batch_size=8, shuffle=True, num_workers=2)

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "
Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-3dd342df.pth


  0%|          | 0.00/20.5M [00:00<?, ?B/s]

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./cifar-10-python.tar.gz to ./


In [4]:
for idx, data in enumerate(dataloader):
    
    img, gt = data
    
    print(img.shape)
    
    scores = efficientnet_b0(img)
    
    print(scores.shape)
    break

torch.Size([8, 3, 32, 32])
torch.Size([8, 1000])


In [5]:
for k,v in efficientnet_b0.named_parameters():
    print(k,v.shape)

features.0.0.weight torch.Size([32, 3, 3, 3])
features.0.1.weight torch.Size([32])
features.0.1.bias torch.Size([32])
features.1.0.block.0.0.weight torch.Size([32, 1, 3, 3])
features.1.0.block.0.1.weight torch.Size([32])
features.1.0.block.0.1.bias torch.Size([32])
features.1.0.block.1.fc1.weight torch.Size([8, 32, 1, 1])
features.1.0.block.1.fc1.bias torch.Size([8])
features.1.0.block.1.fc2.weight torch.Size([32, 8, 1, 1])
features.1.0.block.1.fc2.bias torch.Size([32])
features.1.0.block.2.0.weight torch.Size([16, 32, 1, 1])
features.1.0.block.2.1.weight torch.Size([16])
features.1.0.block.2.1.bias torch.Size([16])
features.2.0.block.0.0.weight torch.Size([96, 16, 1, 1])
features.2.0.block.0.1.weight torch.Size([96])
features.2.0.block.0.1.bias torch.Size([96])
features.2.0.block.1.0.weight torch.Size([96, 1, 3, 3])
features.2.0.block.1.1.weight torch.Size([96])
features.2.0.block.1.1.bias torch.Size([96])
features.2.0.block.2.fc1.weight torch.Size([4, 96, 1, 1])
features.2.0.block.2.

In [6]:
# parameter수 구하기
from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params+=params
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
    
    


In [7]:
count_parameters(efficientnet_b0)

+---------------------------------+------------+
|             Modules             | Parameters |
+---------------------------------+------------+
|       features.0.0.weight       |    864     |
|       features.0.1.weight       |     32     |
|        features.0.1.bias        |     32     |
|  features.1.0.block.0.0.weight  |    288     |
|  features.1.0.block.0.1.weight  |     32     |
|   features.1.0.block.0.1.bias   |     32     |
| features.1.0.block.1.fc1.weight |    256     |
|  features.1.0.block.1.fc1.bias  |     8      |
| features.1.0.block.1.fc2.weight |    256     |
|  features.1.0.block.1.fc2.bias  |     32     |
|  features.1.0.block.2.0.weight  |    512     |
|  features.1.0.block.2.1.weight  |     16     |
|   features.1.0.block.2.1.bias   |     16     |
|  features.2.0.block.0.0.weight  |    1536    |
|  features.2.0.block.0.1.weight  |     96     |
|   features.2.0.block.0.1.bias   |     96     |
|  features.2.0.block.1.0.weight  |    864     |
|  features.2.0.bloc

5288548

In [8]:
def count_parameters2(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params+=params
#     print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

In [9]:
count_parameters2(models.efficientnet_b0())
count_parameters2(models.efficientnet_b1())
count_parameters2(models.efficientnet_b2())
count_parameters2(models.efficientnet_b3())
count_parameters2(models.efficientnet_b4())
count_parameters2(models.efficientnet_b5())
count_parameters2(models.efficientnet_b6())
count_parameters2(models.efficientnet_b7())

Total Trainable Params: 5288548
Total Trainable Params: 7794184
Total Trainable Params: 9109994
Total Trainable Params: 12233232
Total Trainable Params: 19341616
Total Trainable Params: 30389784
Total Trainable Params: 43040704
Total Trainable Params: 66347960


66347960