In [1]:
import os
import sys
sys.path.append("./../")
import numpy as np  # type:ignore
import pandas as pd  # type:ignore
import torch  # type:ignore
import torch.nn as nn  # type:ignore
import torchvision.transforms as transforms  # type:ignore
from PyTorch_CIFAR10.cifar10_models.mobilenetv2 import mobilenet_v2
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10  # type:ignore
# from PyTorch_CIFAR10.UTILS_TORCH import *

In [2]:
class ModifiedTeacher(nn.Module):
    """Teacher wrapper that also exposes an attention map from one inner block.

    ``original_model.features`` is split into three parts: the blocks before a
    chosen split index, the block at that index, and the blocks after it. The
    forward pass returns both the classifier output and an attention map
    computed from the output of the block at the split index.

    Args:
        original_model: Model exposing an indexable/sliceable ``features``
            container and an iterable ``classifier`` container.
        divider: The split index is ``len(original_model.features) // divider``
            (so 2 picks roughly the middle block, larger values pick earlier
            blocks). Expected to be a positive integer; other values are not
            validated here.

    Note:
        ``forward`` calls a module-level ``attention_map`` function that is not
        defined in this class. It has to be in scope before ``forward`` is used
        (NOTE(review): it presumably comes from ``PyTorch_CIFAR10.UTILS_TORCH``,
        whose import is commented out in this notebook -- confirm).
    """

    def __init__(self, original_model, divider):
        super().__init__()
        feature_blocks = original_model.features
        num_blocks = len(feature_blocks)
        # ``divider == 1`` would yield an index one past the last block and fail
        # with IndexError; clamp so the last block becomes the split block.
        middle_index = min(num_blocks // divider, num_blocks - 1)
        print("divider: ", divider)
        print("middle_index: ", middle_index)
        self.front_layers = nn.Sequential(*feature_blocks[:middle_index])
        self.middle_layer = feature_blocks[middle_index]
        self.end_layers = nn.Sequential(*feature_blocks[middle_index + 1 :])
        self.classifier = nn.Sequential(*original_model.classifier)

    def forward(self, x):
        """Run the model and return ``(classifier_output, attention_maps)``.

        ``attention_maps`` is whatever the external ``attention_map`` function
        returns for the output of ``self.middle_layer``.
        """
        x = self.front_layers(x)
        middle_feature_maps = self.middle_layer(x)
        attention_maps = attention_map(middle_feature_maps)
        x = self.end_layers(middle_feature_maps)
        # Average over dims 2 and 3 before the classifier
        # (assumes an (N, C, H, W) feature tensor -- TODO confirm).
        x = x.mean([2, 3])
        x = self.classifier(x)
        return x, attention_maps

In [3]:
class SmallerMobileNet(nn.Module):
    """Truncated copy of a MobileNet-style model with a fresh linear head.

    The trailing ``layer`` children of ``original_model.features`` are dropped.
    The output of the remaining blocks is globally average-pooled, flattened
    and fed to a new ``nn.Linear`` classifier.

    Args:
        original_model: Model exposing a ``features`` module whose children are
            the feature blocks.
        layer: Number of trailing feature blocks to remove. ``0`` keeps every
            block.
        num_classes: Number of outputs of the new classifier (default 10).

    Raises:
        ValueError: If the kept blocks contain no ``nn.Conv2d`` (for example
            because every block was removed), so the classifier input width
            cannot be determined.
    """

    def __init__(self, original_model, layer, num_classes=10):
        super().__init__()
        blocks = list(original_model.features.children())
        # ``blocks[:-0]`` is an empty list, so "remove nothing" needs its own branch.
        kept_blocks = blocks[:-layer] if layer else blocks
        self.features = nn.Sequential(*kept_blocks)

        num_output_channels = self._last_conv_out_channels(self.features)
        if num_output_channels is None:
            raise ValueError(
                "no nn.Conv2d left after removing %r feature block(s); "
                "cannot size the classifier" % (layer,)
            )

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(num_output_channels, num_classes),
        )

    @staticmethod
    def _last_conv_out_channels(features):
        """Return ``out_channels`` of the last ``nn.Conv2d`` in ``features``, or None.

        Searches every nested submodule, so blocks that wrap their convolution
        in a plain ``nn.Sequential`` are handled as well as blocks exposing a
        ``conv`` attribute. Assumes submodules are registered in the order they
        are applied, so the last registered convolution sets the output width.
        """
        for module in reversed(list(features.modules())):
            if isinstance(module, nn.Conv2d):
                return module.out_channels
        return None

    def forward(self, x):
        """Apply the kept feature blocks, global average pooling and the classifier."""
        x = self.features(x)
        x = self.pool(x)
        x = self.classifier(x)
        return x

In [7]:
# Build the teacher wrapper for several dividers; the constructor prints the
# divider and the resulting split index for each one.
for divider in (2, 4, 8, 20):
    ModifiedTeacher(mobilenet_v2(pretrained=True), divider)

divider:  2
middle_index:  9
divider:  4
middle_index:  4
divider:  8
middle_index:  2
divider:  20
middle_index:  0


In [9]:
# Display the feature blocks of the pretrained MobileNetV2 to inspect their layout.
mobilenet_v2(pretrained=True).features

Sequential(
  (0): ConvBNReLU(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU6(inplace=True)
  )
  (1): InvertedResidual(
    (conv): Sequential(
      (0): ConvBNReLU(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace=True)
      )
      (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (2): InvertedResidual(
    (conv): Sequential(
      (0): ConvBNReLU(
        (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace=True)
      

In [17]:
# Drop the last 17 feature blocks of a non-pretrained MobileNetV2 and display the result.
SmallerMobileNet(mobilenet_v2(pretrained=False), 17)

SmallerMobileNet(
  (features): Sequential(
    (0): ConvBNReLU(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNReLU(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (pool): AdaptiveAvgPool2d(output_size=(1, 1))
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=16, out_features=10, bias=True)
  )
)