### CNN Model (MobileViTV2: convolutional + transformer backbone with a 10-class head)

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from transformers import AutoFeatureExtractor, SwinForImageClassification, get_scheduler
from sklearn.metrics import confusion_matrix, precision_score, recall_score
import seaborn as sns

# Directory where artifacts produced by this notebook are written.
OUTPUT_DIR = './outputs'

# Create the output directory. `exist_ok=True` makes this idempotent (safe under
# Restart & Run All) and avoids the check-then-create race of an
# `os.path.exists` guard followed by `os.mkdir`.
os.makedirs(OUTPUT_DIR, exist_ok=True)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from transformers import AutoImageProcessor, MobileViTV2ForImageClassification

# Number of target classes for the replacement classification head.
NUM_CLASSES = 10

# Load the pre-trained MobileViTV2 checkpoint and its matching image processor.
model_name = "apple/mobilevitv2-1.0-imagenet1k-256"
feature_extractor = AutoImageProcessor.from_pretrained(model_name)
model_MobileViT = MobileViTV2ForImageClassification.from_pretrained(model_name)

# Replace the checkpoint's classifier with a freshly initialised NUM_CLASSES-way head,
# keeping the backbone's feature width (`in_features`).
model_MobileViT.classifier = nn.Linear(model_MobileViT.classifier.in_features, NUM_CLASSES)

# Swapping the module alone leaves the model's label bookkeeping at the checkpoint's
# original class count. Hugging Face classification heads use `num_labels` when they
# compute the loss (forward called with `labels=`), so keep both in sync with the new head.
model_MobileViT.config.num_labels = NUM_CLASSES
model_MobileViT.num_labels = NUM_CLASSES

In [4]:
from torchsummary import summary

In [5]:
# Dump the full module tree (layer types, channel counts, kernel sizes) as plain text.
print(repr(model_MobileViT))

MobileViTV2ForImageClassification(
  (mobilevitv2): MobileViTV2Model(
    (conv_stem): MobileViTV2ConvLayer(
      (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (normalization): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): SiLU()
    )
    (encoder): MobileViTV2Encoder(
      (layer): ModuleList(
        (0): MobileViTV2MobileNetLayer(
          (layer): ModuleList(
            (0): MobileViTV2InvertedResidual(
              (expand_1x1): MobileViTV2ConvLayer(
                (convolution): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (activation): SiLU()
              )
              (conv_3x3): MobileViTV2ConvLayer(
                (convolution): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
 

In [6]:
# Complexity estimate with fvcore.
# NOTE: fvcore documents that it counts one fused multiply-add as ONE flop, so
# `flops.total()` is effectively a multiply-accumulate count (compare with thop's
# MACs in the next cell), not multiply+add FLOPs.
# Operators fvcore has no handler for (the "Unsupported operator" warnings, e.g.
# silu/add/softmax/mul) are not counted, so the total is a lower bound.
from fvcore.nn import FlopCountAnalysis, parameter_count

model = model_MobileViT
dummy_input = torch.randn(1, 3, 256, 256)  # named to avoid shadowing the builtin `input`

# Trace in inference mode (so BatchNorm running stats are not updated by the
# random input), then restore whatever mode the model was in before.
was_training = model.training
model.eval()
flops = FlopCountAnalysis(model, dummy_input)
total_flops = flops.total()  # FlopCountAnalysis traces lazily, on first query
model.train(was_training)

params = parameter_count(model)  # key '' holds the whole-model total

print(f"FLOPs (fvcore, 1 multiply-add = 1 FLOP): {total_flops}")
print(f"Parameters: {params['']}")


Unsupported operator aten::silu encountered 25 time(s)
Unsupported operator aten::add encountered 19 time(s)
Unsupported operator aten::im2col encountered 3 time(s)
Unsupported operator aten::softmax encountered 9 time(s)
Unsupported operator aten::mul encountered 21 time(s)
Unsupported operator aten::sum encountered 9 time(s)
Unsupported operator aten::expand_as encountered 9 time(s)
Unsupported operator aten::col2im encountered 3 time(s)
Unsupported operator aten::mean encountered 1 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
mobilevitv2.encoder.layer.2.transformer.layer.0.dropout1, mobilevitv2.encoder.layer.2.transformer.layer.1.dropout1, mobilevitv2.encoder.layer.3.transformer.layer.0.dropout1, mobile

FLOPs: 1843303424
Parameters: 4393971


In [7]:
# Cross-check with thop, which reports multiply-accumulate operations (MACs) and a
# parameter total for a single 1x3x256x256 dummy input. thop only counts modules it
# has registered a handler for (see the "[INFO] Register ..." lines in the output).
from thop import clever_format, profile

model = model_MobileViT
input = torch.randn(1, 3, 256, 256)

macs, params = profile(model, inputs=(input,))
macs, params = clever_format([macs, params], "%.3f")  # human-readable strings, e.g. "1.863G"

print(f"MACs: {macs}\nParameters: {params}")


[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
MACs: 1.863G
Parameters: 4.386M


In [8]:
# Layer-by-layer table of output shapes and parameter counts, produced by torchinfo.
# NOTE: this rebinds `summary`, shadowing the torchsummary import from an earlier cell.
from torchinfo import summary

model = model_MobileViT
summary(model=model, input_size=(1, 3, 256, 256))



Layer (type:depth-idx)                                                           Output Shape              Param #
MobileViTV2ForImageClassification                                                [1, 10]                   --
├─MobileViTV2Model: 1-1                                                          [1, 512]                  --
│    └─MobileViTV2ConvLayer: 2-1                                                 [1, 32, 128, 128]         --
│    │    └─Conv2d: 3-1                                                          [1, 32, 128, 128]         864
│    │    └─BatchNorm2d: 3-2                                                     [1, 32, 128, 128]         64
│    │    └─SiLU: 3-3                                                            [1, 32, 128, 128]         --
│    └─MobileViTV2Encoder: 2-2                                                   [1, 512, 8, 8]            --
│    │    └─ModuleList: 3-4                                                      --                        4,387,9