In [1]:
!pip install torchinfo



In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from utils import get_network, get_training_dataloader, get_test_dataloader, WarmUpLR, \
    most_recent_folder, most_recent_weights, last_epoch, best_acc_weights
from conf import settings


In [5]:
# Training-set loader: built from the CIFAR-100 training mean/std constants
# in `settings`, 512 samples per batch, 4 worker processes, shuffled.
cifar100_training_loader = get_training_dataloader(
    settings.CIFAR100_TRAIN_MEAN,
    settings.CIFAR100_TRAIN_STD,
    batch_size=512,
    num_workers=4,
    shuffle=True,
)

# Test-set loader: receives the same training-set mean/std constants and the
# same batch/worker settings as the training loader.
# NOTE(review): shuffle=True is kept as-is; evaluation does not require it.
cifar100_test_loader = get_test_dataloader(
    settings.CIFAR100_TRAIN_MEAN,
    settings.CIFAR100_TRAIN_STD,
    batch_size=512,
    num_workers=4,
    shuffle=True,
)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
from torchinfo import summary
from models.vgg import vgg16_bn
# Build the vgg16_bn model from models.vgg; the variable name suggests it is
# meant to serve as the teacher network.
teacher_net = vgg16_bn()
# Print torchinfo's per-layer output shapes and parameter counts for a single
# (1, 3, 32, 32) input.
summary(teacher_net, input_size=(1,3,32,32))

Layer (type:depth-idx)                   Output Shape              Param #
VGG                                      [1, 100]                  --
├─Sequential: 1-1                        [1, 512, 1, 1]            --
│    └─Conv2d: 2-1                       [1, 64, 32, 32]           1,792
│    └─BatchNorm2d: 2-2                  [1, 64, 32, 32]           128
│    └─ReLU: 2-3                         [1, 64, 32, 32]           --
│    └─Conv2d: 2-4                       [1, 64, 32, 32]           36,928
│    └─BatchNorm2d: 2-5                  [1, 64, 32, 32]           128
│    └─ReLU: 2-6                         [1, 64, 32, 32]           --
│    └─MaxPool2d: 2-7                    [1, 64, 16, 16]           --
│    └─Conv2d: 2-8                       [1, 128, 16, 16]          73,856
│    └─BatchNorm2d: 2-9                  [1, 128, 16, 16]          256
│    └─ReLU: 2-10                        [1, 128, 16, 16]          --
│    └─Conv2d: 2-11                      [1, 128, 16, 16]          147,

In [28]:
class MyCompressNet(nn.Module):
    """Small CNN: five 3x3 conv + batch-norm stages followed by a four-layer MLP head.

    Every convolution uses padding=1 and therefore keeps the spatial size; the
    three 2x2 max-pools shrink it 8x. The flatten step assumes a 4x4 final
    feature map, i.e. inputs of shape (N, 3, 32, 32). The output is a (N, 100)
    tensor of raw, unnormalised class scores.

    Args:
        num_channels: Base channel width; the conv stages use 1x, 2x and 4x
            this value.
        dr_rate: Dropout probability applied after ``fc1`` and ``fc2``.
    """

    def __init__(self, num_channels = 64, dr_rate = 0.3):
        super(MyCompressNet, self).__init__()
        self.num_channels = num_channels

        # Convolutional feature extractor (each conv is followed by batch norm).
        self.conv1 = nn.Conv2d(3, self.num_channels, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(self.num_channels)
        self.conv2 = nn.Conv2d(self.num_channels, self.num_channels*2, 3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(self.num_channels*2)
        self.conv3 = nn.Conv2d(self.num_channels*2, self.num_channels*2, 3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(self.num_channels*2)
        self.conv4 = nn.Conv2d(self.num_channels*2, self.num_channels*4, 3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(self.num_channels*4)
        self.conv5 = nn.Conv2d(self.num_channels*4, self.num_channels*4, 3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(self.num_channels*4)

        # Classifier head: flattened 4x4 map -> ... -> 100 class scores.
        self.fc1 = nn.Linear(4*4*self.num_channels*4, self.num_channels*4*4)
        self.fc2 = nn.Linear(self.num_channels*4*4, self.num_channels*4)
        self.fc3 = nn.Linear(self.num_channels*4, self.num_channels*2)
        self.fc4 = nn.Linear(self.num_channels*2, 100)
        self.dropout_rate = dr_rate

    def forward(self,x):
        """Compute class logits for a batch of images.

        Args:
            x: Image batch; the flatten step requires a 4x4 map after three
                poolings, i.e. (N, 3, 32, 32) inputs.

        Returns:
            Tensor of shape (N, 100) holding raw logits (no softmax applied).
        """
        x = self.bn1(self.conv1(x))
        x = F.relu(F.max_pool2d(x,2))
        x = self.bn2(self.conv2(x))
        x = F.relu(x)
        x = self.bn3(self.conv3(x))
        x = F.relu(F.max_pool2d(x,2))
        x = self.bn4(self.conv4(x))
        x = F.relu(x)
        x = self.bn5(self.conv5(x))
        x = F.relu(F.max_pool2d(x,2))
        x = x.view(-1, 4*4*self.num_channels*4)
        x = F.relu(self.fc1(x))
        # F.dropout defaults to training=True; forwarding self.training makes
        # dropout switch off after model.eval() so inference is deterministic.
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = F.relu(self.fc2(x))
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x
        
        

In [4]:
class MyCompressNet2(nn.Module):
    """Deeper variant: seven 3x3 conv + batch-norm stages and a six-layer MLP head.

    Every convolution uses padding=1 and keeps the spatial size; three 2x2
    max-pools shrink it 8x. The flatten step assumes a 4x4 final feature map,
    i.e. inputs of shape (N, 3, 32, 32). The output is a (N, 100) tensor of
    raw, unnormalised class scores.

    Args:
        num_channels: Base channel width; the conv stages use 1x, 2x and 4x
            this value.
        dr_rate: Dropout probability applied after ``fc1`` and ``fc4``.
    """

    def __init__(self, num_channels = 64, dr_rate = 0.3):
        super(MyCompressNet2, self).__init__()
        self.num_channels = num_channels

        # Convolutional feature extractor (each conv is followed by batch norm).
        self.conv1 = nn.Conv2d(3, self.num_channels, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(self.num_channels)
        self.conv2 = nn.Conv2d(self.num_channels, self.num_channels*2, 3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(self.num_channels*2)
        self.conv3 = nn.Conv2d(self.num_channels*2, self.num_channels*2, 3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(self.num_channels*2)
        self.conv4 = nn.Conv2d(self.num_channels*2, self.num_channels*4, 3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(self.num_channels*4)
        self.conv5 = nn.Conv2d(self.num_channels*4, self.num_channels*4, 3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(self.num_channels*4)
        self.conv6 = nn.Conv2d(self.num_channels*4, self.num_channels*4, 3, stride=1, padding=1)
        self.bn6 = nn.BatchNorm2d(self.num_channels*4)
        self.conv7 = nn.Conv2d(self.num_channels*4, self.num_channels*4, 3, stride=1, padding=1)
        self.bn7 = nn.BatchNorm2d(self.num_channels*4)

        # Classifier head: flattened 4x4 map -> ... -> 100 class scores.
        self.fc1 = nn.Linear(4*4*self.num_channels*4, self.num_channels*4*4)
        self.fc2 = nn.Linear(self.num_channels*4*4, self.num_channels*4*4)
        self.fc3 = nn.Linear(self.num_channels*4*4, self.num_channels*4)
        self.fc4 = nn.Linear(self.num_channels*4, self.num_channels*4)
        self.fc5 = nn.Linear(self.num_channels*4, self.num_channels*2)
        self.fc6 = nn.Linear(self.num_channels*2, 100)
        self.dropout_rate = dr_rate

    def forward(self,x):
        """Compute class logits for a batch of images.

        Args:
            x: Image batch; the flatten step requires a 4x4 map after three
                poolings, i.e. (N, 3, 32, 32) inputs.

        Returns:
            Tensor of shape (N, 100) holding raw logits (no softmax applied).
        """
        x = self.bn1(self.conv1(x))
        x = F.relu(F.max_pool2d(x,2))
        x = self.bn2(self.conv2(x))
        x = F.relu(x)
        x = self.bn3(self.conv3(x))
        x = F.relu(F.max_pool2d(x,2))
        x = self.bn4(self.conv4(x))
        x = F.relu(x)
        x = self.bn5(self.conv5(x))
        x = F.relu(F.max_pool2d(x,2))
        x = self.bn6(self.conv6(x))
        x = F.relu(x)
        x = self.bn7(self.conv7(x))
        x = F.relu(x)
        x = x.view(-1, 4*4*self.num_channels*4)
        x = F.relu(self.fc1(x))
        # F.dropout defaults to training=True; forwarding self.training makes
        # dropout switch off after model.eval() so inference is deterministic.
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = F.relu(self.fc5(x))
        x = self.fc6(x)
        return x
        
        

In [8]:
class MyCompressNet4(nn.Module):
    """Wider CNN mixing 3x3 and 1x1 convolutions, with an eight-layer MLP head.

    All convolutions preserve the spatial size; two functional 2x2 max-pools
    plus ``mp1`` shrink it 8x. The flatten step assumes a 4x4 final feature
    map, i.e. inputs of shape (N, 3, 32, 32). The output is a (N, 100) tensor
    of raw, unnormalised class scores.

    Attribute names (conv1, conv2, conv5..conv9, fc1..fc8, ...) are kept
    unchanged so that state dicts keyed by them keep loading.

    Args:
        num_channels: Base channel width; the conv stages use 1x, 4x and 8x
            this value.
        dr_rate: Dropout probability applied after ``fc1`` and ``fc4``.
    """

    def __init__(self, num_channels = 64, dr_rate = 0.3):
        super(MyCompressNet4, self).__init__()
        self.num_channels = num_channels

        # Convolutional feature extractor (each conv is followed by batch norm).
        self.conv1 = nn.Conv2d(3, self.num_channels, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(self.num_channels)
        self.conv2 = nn.Conv2d(self.num_channels, self.num_channels*4, 3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(self.num_channels*4)
        self.conv5 = nn.Conv2d(self.num_channels*4, self.num_channels*4, 1, stride=1, padding='same')
        self.bn5 = nn.BatchNorm2d(self.num_channels*4)
        self.conv6 = nn.Conv2d(self.num_channels*4, self.num_channels*4, 3, stride=1, padding=1)
        self.bn6 = nn.BatchNorm2d(self.num_channels*4)
        self.conv7 = nn.Conv2d(self.num_channels*4, self.num_channels*8, 3, stride=1, padding=1)
        self.bn7 = nn.BatchNorm2d(self.num_channels*8)
        self.conv8 = nn.Conv2d(self.num_channels*8, self.num_channels*8, 1, stride=1, padding='same')
        self.bn8 = nn.BatchNorm2d(self.num_channels*8)
        self.conv9 = nn.Conv2d(self.num_channels*8, self.num_channels*8, 3, stride=1, padding='same')
        self.bn9 = nn.BatchNorm2d(self.num_channels*8)
        self.mp1 = nn.MaxPool2d(2, stride=2)

        # Classifier head: flattened 4x4 map -> ... -> 100 class scores.
        self.fc1 = nn.Linear(4*4*self.num_channels*8, self.num_channels*4*4*2)
        self.fc2 = nn.Linear(self.num_channels*4*4*2, self.num_channels*4*4)
        self.fc3 = nn.Linear(self.num_channels*4*4, self.num_channels*4*4)
        self.fc4 = nn.Linear(self.num_channels*4*4, self.num_channels*4*2)
        self.fc5 = nn.Linear(self.num_channels*4*2, self.num_channels*4)
        self.fc6 = nn.Linear(self.num_channels*4, self.num_channels*4)
        self.fc7 = nn.Linear(self.num_channels*4, self.num_channels*2)
        self.fc8 = nn.Linear(self.num_channels*2, 100)
        self.dropout_rate = dr_rate

    def forward(self,x):
        """Compute class logits for a batch of images.

        Args:
            x: Image batch; the flatten step requires a 4x4 map after three
                poolings, i.e. (N, 3, 32, 32) inputs.

        Returns:
            Tensor of shape (N, 100) holding raw logits (no softmax, no ReLU).
        """
        x = self.bn1(self.conv1(x))
        x = F.relu(F.max_pool2d(x,2))
        x = self.bn2(self.conv2(x))
        x = F.relu(x)
        x = self.bn5(self.conv5(x))
        x = F.relu(F.max_pool2d(x,2))
        x = self.bn6(self.conv6(x))
        x = F.relu(x)
        x = self.bn7(self.conv7(x))
        x = F.relu(x)
        x = self.bn8(self.conv8(x))
        x = F.relu(x)
        x = self.bn9(self.conv9(x))
        x = F.relu(x)
        x = self.mp1(x)
        x = x.view(-1, 4*4*self.num_channels*8)
        x = F.relu(self.fc1(x))
        # F.dropout defaults to training=True; forwarding self.training makes
        # dropout switch off after model.eval() so inference is deterministic.
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = F.relu(self.fc5(x))
        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))
        # No activation on the last layer: a ReLU here would clamp every
        # negative class score to zero and cut its gradient. Whenever the top
        # score is positive the arg-max prediction is the same either way.
        x = self.fc8(x)
        return x
        
        

In [23]:
class MyCompressNet5(nn.Module):
    """Five-conv CNN (3x3 and 1x1 kernels) with an eight-layer MLP head.

    All convolutions preserve the spatial size; one functional 2x2 max-pool
    plus ``mp1`` and ``mp2`` shrink it 8x. The flatten step assumes a 4x4
    final feature map, i.e. inputs of shape (N, 3, 32, 32). The output is a
    (N, 100) tensor of raw, unnormalised class scores.

    Attribute names (conv1, conv2, conv7..conv9, fc1..fc8, ...) are kept
    unchanged so that state dicts keyed by them keep loading.

    Args:
        num_channels: Base channel width; the conv stages use 1x, 4x and 8x
            this value.
        dr_rate: Dropout probability applied after ``fc1`` and ``fc4``.
    """

    def __init__(self, num_channels = 64, dr_rate = 0.3):
        super(MyCompressNet5, self).__init__()
        self.num_channels = num_channels

        # Convolutional feature extractor (each conv is followed by batch norm).
        self.conv1 = nn.Conv2d(3, self.num_channels, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(self.num_channels)
        self.conv2 = nn.Conv2d(self.num_channels, self.num_channels*4, 3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(self.num_channels*4)
        self.mp1 = nn.MaxPool2d(2, stride=2)
        self.conv7 = nn.Conv2d(self.num_channels*4, self.num_channels*8, 3, stride=1, padding=1)
        self.bn7 = nn.BatchNorm2d(self.num_channels*8)
        self.conv8 = nn.Conv2d(self.num_channels*8, self.num_channels*8, 1, stride=1, padding='same')
        self.bn8 = nn.BatchNorm2d(self.num_channels*8)
        self.conv9 = nn.Conv2d(self.num_channels*8, self.num_channels*8, 3, stride=1, padding='same')
        self.bn9 = nn.BatchNorm2d(self.num_channels*8)
        self.mp2 = nn.MaxPool2d(2, stride=2)

        # Classifier head: flattened 4x4 map -> ... -> 100 class scores.
        self.fc1 = nn.Linear(4*4*self.num_channels*8, self.num_channels*4*4*2)
        self.fc2 = nn.Linear(self.num_channels*4*4*2, self.num_channels*4*4)
        self.fc3 = nn.Linear(self.num_channels*4*4, self.num_channels*4*4)
        self.fc4 = nn.Linear(self.num_channels*4*4, self.num_channels*4*2)
        self.fc5 = nn.Linear(self.num_channels*4*2, self.num_channels*4)
        self.fc6 = nn.Linear(self.num_channels*4, self.num_channels*4)
        self.fc7 = nn.Linear(self.num_channels*4, self.num_channels*2)
        self.fc8 = nn.Linear(self.num_channels*2, 100)
        self.dropout_rate = dr_rate

    def forward(self,x):
        """Compute class logits for a batch of images.

        Args:
            x: Image batch; the flatten step requires a 4x4 map after three
                poolings, i.e. (N, 3, 32, 32) inputs.

        Returns:
            Tensor of shape (N, 100) holding raw logits (no softmax, no ReLU).
        """
        x = self.bn1(self.conv1(x))
        x = F.relu(F.max_pool2d(x,2))
        x = self.bn2(self.conv2(x))
        x = F.relu(x)
        x = self.mp1(x)
        x = self.bn7(self.conv7(x))
        x = F.relu(x)
        x = self.bn8(self.conv8(x))
        x = F.relu(x)
        x = self.bn9(self.conv9(x))
        x = F.relu(x)
        x = self.mp2(x)
        x = x.view(-1, 4*4*self.num_channels*8)
        x = F.relu(self.fc1(x))
        # F.dropout defaults to training=True; forwarding self.training makes
        # dropout switch off after model.eval() so inference is deterministic.
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = F.relu(self.fc5(x))
        x = F.relu(self.fc6(x))
        x = F.relu(self.fc7(x))
        # No activation on the last layer: a ReLU here would clamp every
        # negative class score to zero and cut its gradient. Whenever the top
        # score is positive the arg-max prediction is the same either way.
        x = self.fc8(x)
        return x
        
        

In [24]:
# Smoke test: push one random (1, 3, 32, 32) batch through a fresh
# MyCompressNet5 and print the output shape.
x = MyCompressNet5()
# Call the module itself instead of .forward() so nn.Module.__call__ runs
# (it is what dispatches any registered forward hooks).
print(x(torch.rand(1,3,32,32)).shape)

torch.Size([1, 100])


In [9]:
# Rebind `x` to a fresh MyCompressNet4 and print torchinfo's per-layer output
# shapes and parameter counts for a single (1, 3, 32, 32) input.
x = MyCompressNet4()
summary(x, input_size=(1,3,32,32))


Layer (type:depth-idx)                   Output Shape              Param #
MyCompressNet4                           [1, 100]                  --
├─Conv2d: 1-1                            [1, 64, 32, 32]           1,792
├─BatchNorm2d: 1-2                       [1, 64, 32, 32]           128
├─Conv2d: 1-3                            [1, 256, 16, 16]          147,712
├─BatchNorm2d: 1-4                       [1, 256, 16, 16]          512
├─Conv2d: 1-5                            [1, 256, 16, 16]          65,792
├─BatchNorm2d: 1-6                       [1, 256, 16, 16]          512
├─Conv2d: 1-7                            [1, 256, 8, 8]            590,080
├─BatchNorm2d: 1-8                       [1, 256, 8, 8]            512
├─Conv2d: 1-9                            [1, 512, 8, 8]            1,180,160
├─BatchNorm2d: 1-10                      [1, 512, 8, 8]            1,024
├─Conv2d: 1-11                           [1, 512, 8, 8]            262,656
├─BatchNorm2d: 1-12                      [1, 512,

In [46]:
# Shape check: run one random (1, 3, 32, 32) batch through whichever model
# `x` currently refers to and print the size of the result.
print(x(torch.rand(1,3,32,32)).shape)

torch.Size([1, 100])


In [47]:
# Loss functions definition
def loss_fn(outputs, labels):
    """Plain cross-entropy between raw logits and targets.

    Uses PyTorch's default settings (mean reduction, no class weights).

    Args:
        outputs: Raw, unnormalised class scores.
        labels: Targets in any form accepted by cross-entropy.

    Returns:
        Scalar loss tensor.
    """
    return F.cross_entropy(outputs, labels)

def loss_fn_kd(student_outputs, labels, teacher_outputs, alpha = 0.9, T=0.01):
    """Knowledge-distillation loss: soft-target KL term blended with hard-label cross-entropy.

    loss = alpha * T^2 * KL(softmax(teacher / T) || softmax(student / T))
           + (1 - alpha) * CE(student, labels)

    Args:
        student_outputs: Raw student logits; classes are on dim 1.
        labels: Hard targets in any form accepted by ``F.cross_entropy``.
        teacher_outputs: Raw teacher logits, same shape as ``student_outputs``.
        alpha: Weight of the distillation term; ``1 - alpha`` weights the
            cross-entropy term.
        T: Temperature both sets of logits are divided by before the softmax.

    Returns:
        Scalar loss tensor.
    """
    # Read the student's logits from the parameter; the name `outputs` is not
    # defined in this function and would resolve to a stray global or fail.
    student_log_probs = F.log_softmax(student_outputs / T, dim=1)
    teacher_probs = F.softmax(teacher_outputs / T, dim=1)
    # 'batchmean' sums over classes and averages over the batch, matching the
    # mathematical KL divergence; the default 'mean' would additionally divide
    # by the number of classes.
    distill_term = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
    hard_term = F.cross_entropy(student_outputs, labels)
    return distill_term * (alpha * T * T) + hard_term * (1. - alpha)

In [3]:
# Run the project's test.py in a subprocess for the "v4" network using the
# v4-200-regular.pth weights file.
# NOTE(review): -b=128 is presumably the batch size; confirm against test.py's argument parser.
!python3 test.py -net=v4 -weights=v4-200-regular.pth -b=128

Files already downloaded and verified
MyCompressNet4(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), padding=same)
  (bn5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv7): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv8): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), padding=same)
  (bn8): BatchNorm2d(512, eps=1

In [2]:
!pip install matplotlib

