<a href="https://colab.research.google.com/github/matteobarato/sc/blob/master/nas_resnet_on_cifar_100_(2).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [68]:
!pip install torch 
!pip install torchvision
!pip install torch_pruning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [69]:
# Downalod ResNet18 weights
!wget -c http://cipizio.it/storage/Nas-ResNet/resnet18_net_e_199.pth
!mkdir checkpoint
!mv resnet18_net_e_199.pth ./checkpoint

URL transformed to HTTPS due to an HSTS policy
--2022-06-14 18:53:05--  https://cipizio.it/storage/Nas-ResNet/resnet18_net_e_199.pth
Resolving cipizio.it (cipizio.it)... 51.83.75.172
Connecting to cipizio.it (cipizio.it)|51.83.75.172|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 44773645 (43M) [application/octet-stream]
Saving to: ‘resnet18_net_e_199.pth’


2022-06-14 18:53:11 (8.50 MB/s) - ‘resnet18_net_e_199.pth’ saved [44773645/44773645]

mkdir: cannot create directory ‘checkpoint’: File exists


In [70]:
'''Train CIFAR10 with PyTorch.'''
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms
# from KDLib.vanilla_kd import GatedKD


import torch_pruning as tp

import os
import argparse

In [71]:
def progress_bar(current, total, msg=None):
    if current >= (total -1) : print(f"Batch {current}/{total} : {msg}")

def resume_checkpoint(net, path, map_location=torch.device('cuda') ):
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/'+path+'.pth', map_location=map_location )
    net.load_state_dict(checkpoint['net'])
    return  checkpoint['acc'], checkpoint['epoch']
    # best_acc = checkpoint['acc']
    # start_epoch = checkpoint['epoch']


In [72]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
args = {'resume':False, 'lr':0.001, 'lambda_gating':0.1}

if device == 'cuda':
    # net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

## Dataset CIFAR10

In [73]:
# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')


==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


## ResNet18

In [74]:
'''ResNet in PyTorch.

For Pre-activation ResNet, see 'preact_resnet.py'.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion *
                               planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])


def ResNet34():
    return ResNet(BasicBlock, [3, 4, 6, 3])


def ResNet50():
    return ResNet(Bottleneck, [3, 4, 6, 3])


def ResNet101():
    return ResNet(Bottleneck, [3, 4, 23, 3])


def ResNet152():
    return ResNet(Bottleneck, [3, 8, 36, 3])


def test():
    net = ResNet18()
    y = net(torch.randn(1, 3, 32, 32))
    print(y.size())

def freezeAllLayers(model):
    for param in model.parameters():
        param.requires_grad = False

test()

torch.Size([1, 10])


## KD

In [75]:
import os
from copy import deepcopy

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter


class BaseClass:
    """
    Basic implementation of a general Knowledge Distillation framework

    :param teacher_model (torch.nn.Module): Teacher model
    :param student_model (torch.nn.Module): Student model
    :param train_loader (torch.utils.data.DataLoader): Dataloader for training
    :param val_loader (torch.utils.data.DataLoader): Dataloader for validation/testing
    :param optimizer_teacher (torch.optim.*): Optimizer used for training teacher
    :param optimizer_student (torch.optim.*): Optimizer used for training student
    :param loss_fn (torch.nn.Module): Loss Function used for distillation
    :param temp (float): Temperature parameter for distillation
    :param distil_weight (float): Weight paramter for distillation loss
    :param device (str): Device used for training; 'cpu' for cpu and 'cuda' for gpu
    :param log (bool): True if logging required
    :param logdir (str): Directory for storing logs
    """

    def __init__(
        self,
        teacher_model,
        student_model,
        train_loader,
        val_loader,
        optimizer_teacher,
        optimizer_student,
        loss_fn=nn.KLDivLoss(),
        temp=20.0,
        distil_weight=0.5,
        device="cpu",
        log=False,
        logdir="./Experiments",
    ):

        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer_teacher = optimizer_teacher
        self.optimizer_student = optimizer_student
        self.temp = temp
        self.distil_weight = distil_weight
        self.log = log
        self.logdir = logdir

        if self.log:
            self.writer = SummaryWriter(logdir)

        if device == "cpu":
            self.device = torch.device("cpu")
        elif device == "cuda":
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
            else:
                print(
                    "Either an invalid device or CUDA is not available. Defaulting to CPU."
                )
                self.device = torch.device("cpu")

        if teacher_model:
            self.teacher_model = teacher_model.to(self.device)
        else:
            print("Warning!!! Teacher is NONE.")

        self.student_model = student_model.to(self.device)
        self.loss_fn = loss_fn.to(self.device)
        self.ce_fn = nn.CrossEntropyLoss().to(self.device)

    def train_teacher(
        self,
        epochs=20,
        plot_losses=True,
        save_model=True,
        save_model_pth="./models/teacher.pt",
    ):
        """
        Function that will be training the teacher

        :param epochs (int): Number of epochs you want to train the teacher
        :param plot_losses (bool): True if you want to plot the losses
        :param save_model (bool): True if you want to save the teacher model
        :param save_model_pth (str): Path where you want to store the teacher model
        """
        self.teacher_model.train()
        loss_arr = []
        length_of_dataset = len(self.train_loader.dataset)
        best_acc = 0.0
        self.best_teacher_model_weights = deepcopy(self.teacher_model.state_dict())

        save_dir = os.path.dirname(save_model_pth)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        print("Training Teacher... ")

        for ep in range(epochs):
            epoch_loss = 0.0
            correct = 0
            for (data, label) in self.train_loader:
                data = data.to(self.device)
                label = label.to(self.device)
                out = self.teacher_model(data)

                if isinstance(out, tuple):
                    out = out[0]

                pred = out.argmax(dim=1, keepdim=True)
                correct += pred.eq(label.view_as(pred)).sum().item()

                loss = self.ce_fn(out, label)

                self.optimizer_teacher.zero_grad()
                loss.backward()
                self.optimizer_teacher.step()

                epoch_loss += loss.item()

            epoch_acc = correct / length_of_dataset

            epoch_val_acc = self.evaluate(teacher=True)

            if epoch_val_acc > best_acc:
                best_acc = epoch_val_acc
                self.best_teacher_model_weights = deepcopy(
                    self.teacher_model.state_dict()
                )

            if self.log:
                self.writer.add_scalar("Training loss/Teacher", epoch_loss, epochs)
                self.writer.add_scalar("Training accuracy/Teacher", epoch_acc, epochs)
                self.writer.add_scalar(
                    "Validation accuracy/Teacher", epoch_val_acc, epochs
                )

            loss_arr.append(epoch_loss)
            print(
                "Epoch: {}, Loss: {}, Accuracy: {}".format(
                    ep + 1, epoch_loss, epoch_acc
                )
            )

            self.post_epoch_call(ep)

        self.teacher_model.load_state_dict(self.best_teacher_model_weights)
        if save_model:
            self.save_checkpoint(self.teacher_model, epoch_acc, epochs, save_model_pth)
        if plot_losses:
            plt.plot(loss_arr)

    def _train_student(
        self,
        epochs=10,
        plot_losses=True,
        save_model=True,
        save_model_pth="./models/student.pt",
    ):
        """
        Function to train student model - for internal use only.

        :param epochs (int): Number of epochs you want to train the teacher
        :param plot_losses (bool): True if you want to plot the losses
        :param save_model (bool): True if you want to save the student model
        :param save_model_pth (str): Path where you want to save the student model
        """
        self.teacher_model.eval()
        self.student_model.train()
        loss_arr = []
        length_of_dataset = len(self.train_loader.dataset)
        best_acc = 0.0
        self.best_student_model_weights = deepcopy(self.student_model.state_dict())

        save_dir = os.path.dirname(save_model_pth)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        print("Training Student...")

        for ep in range(epochs):
            epoch_loss = 0.0
            correct = 0

            for (data, label) in self.train_loader:

                data = data.to(self.device)
                label = label.to(self.device)

                student_out = self.student_model(data)
                teacher_out = self.teacher_model(data)

                loss = self.calculate_kd_loss(student_out, teacher_out, label)

                if isinstance(student_out, tuple):
                    student_out = student_out[0]

                pred = student_out.argmax(dim=1, keepdim=True)
                correct += pred.eq(label.view_as(pred)).sum().item()

                self.optimizer_student.zero_grad()
                loss.backward()
                self.optimizer_student.step()

                epoch_loss += loss.item()

            epoch_acc = correct / length_of_dataset

            _, epoch_val_acc = self._evaluate_model(self.student_model, verbose=True)

            if epoch_val_acc > best_acc:
                best_acc = epoch_val_acc
                self.best_student_model_weights = deepcopy(
                    self.student_model.state_dict()
                )

            if self.log:
                self.writer.add_scalar("Training loss/Student", epoch_loss, epochs)
                self.writer.add_scalar("Training accuracy/Student", epoch_acc, epochs)
                self.writer.add_scalar(
                    "Validation accuracy/Student", epoch_val_acc, epochs
                )

            loss_arr.append(epoch_loss)
            print(
                "Epoch: {}, Loss: {}, Accuracy: {}".format(
                    ep + 1, epoch_loss, epoch_acc
                )
            )

        self.student_model.load_state_dict(self.best_student_model_weights)
        if save_model:
            self.save_checkpoint(self.student_model, epoch_acc, epochs, save_model_pth)
        if plot_losses:
            plt.plot(loss_arr)

    def save_checkpoint(net, acc, epoch, save_model_pth):
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        torch.save(state, save_model_pth)

    def train_student(
        self,
        epochs=10,
        plot_losses=True,
        save_model=True,
        save_model_pth="./models/student.pt",
    ):
        """
        Function that will be training the student

        :param epochs (int): Number of epochs you want to train the teacher
        :param plot_losses (bool): True if you want to plot the losses
        :param save_model (bool): True if you want to save the student model
        :param save_model_pth (str): Path where you want to save the student model
        """
        self._train_student(epochs, plot_losses, save_model, save_model_pth)

    def calculate_kd_loss(self, y_pred_student, y_pred_teacher, y_true):
        """
        Custom loss function to calculate the KD loss for various implementations

        :param y_pred_student (Tensor): Predicted outputs from the student network
        :param y_pred_teacher (Tensor): Predicted outputs from the teacher network
        :param y_true (Tensor): True labels
        """

        raise NotImplementedError

    def _evaluate_model(self, model, verbose=True):
        """
        Evaluate the given model's accuaracy over val set.
        For internal use only.

        :param model (nn.Module): Model to be used for evaluation
        :param verbose (bool): Display Accuracy
        """
        model.eval()
        length_of_dataset = len(self.val_loader.dataset)
        correct = 0
        outputs = []

        with torch.no_grad():
            for data, target in self.val_loader:
                data = data.to(self.device)
                target = target.to(self.device)
                output = model(data)

                if isinstance(output, tuple):
                    output = output[0]
                outputs.append(output)

                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        accuracy = correct / length_of_dataset

        if verbose:
            print("-" * 80)
            print("Validation Accuracy: {}".format(accuracy))
        return outputs, accuracy

    def evaluate(self, teacher=False):
        """
        Evaluate method for printing accuracies of the trained network

        :param teacher (bool): True if you want accuracy of the teacher network
        """
        if teacher:
            model = deepcopy(self.teacher_model).to(self.device)
        else:
            model = deepcopy(self.student_model).to(self.device)
        _, accuracy = self._evaluate_model(model)

        return accuracy

    def get_parameters(self):
        """
        Get the number of parameters for the teacher and the student network
        """
        teacher_params = sum(p.numel() for p in self.teacher_model.parameters())
        student_params = sum(p.numel() for p in self.student_model.parameters())

        print("-" * 80)
        print("Total parameters for the teacher network are: {}".format(teacher_params))
        print("Total parameters for the student network are: {}".format(student_params))

    def post_epoch_call(self, epoch):
        """
        Any changes to be made after an epoch is completed.

        :param epoch (int) : current epoch number
        :return            : nothing (void)
        """

        pass


## KD Losses

In [76]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VanillaKD(BaseClass):
    """
    Original implementation of Knowledge distillation from the paper "Distilling the
    Knowledge in a Neural Network" https://arxiv.org/pdf/1503.02531.pdf

    :param teacher_model (torch.nn.Module): Teacher model
    :param student_model (torch.nn.Module): Student model
    :param train_loader (torch.utils.data.DataLoader): Dataloader for training
    :param val_loader (torch.utils.data.DataLoader): Dataloader for validation/testing
    :param optimizer_teacher (torch.optim.*): Optimizer used for training teacher
    :param optimizer_student (torch.optim.*): Optimizer used for training student
    :param loss_fn (torch.nn.Module):  Calculates loss during distillation
    :param temp (float): Temperature parameter for distillation
    :param distil_weight (float): Weight paramter for distillation loss
    :param device (str): Device used for training; 'cpu' for cpu and 'cuda' for gpu
    :param log (bool): True if logging required
    :param logdir (str): Directory for storing logs
    """

    def __init__(
        self,
        teacher_model,
        student_model,
        train_loader,
        val_loader,
        optimizer_teacher,
        optimizer_student,
        loss_fn=nn.MSELoss(),
        temp=20.0,
        distil_weight=0.5,
        device="cpu",
        log=False,
        logdir="./Experiments",
    ):
        super(VanillaKD, self).__init__(
            teacher_model,
            student_model,
            train_loader,
            val_loader,
            optimizer_teacher,
            optimizer_student,
            loss_fn,
            temp,
            distil_weight,
            device,
            log,
            logdir,
        )

    def calculate_kd_loss(self, y_pred_student, y_pred_teacher, y_true):
        """
        Function used for calculating the KD loss during distillation

        :param y_pred_student (torch.FloatTensor): Prediction made by the student model
        :param y_pred_teacher (torch.FloatTensor): Prediction made by the teacher model
        :param y_true (torch.FloatTensor): Original label
        """

        soft_teacher_out = F.softmax(y_pred_teacher / self.temp, dim=1)
        soft_student_out = F.softmax(y_pred_student / self.temp, dim=1)

        loss = (1 - self.distil_weight) * F.cross_entropy(y_pred_student, y_true)
        loss += (self.distil_weight * self.temp * self.temp) * self.loss_fn(
            soft_teacher_out, soft_student_out
        )

        return loss


class GatedKD(BaseClass):
    """
    Original implementation of Knowledge distillation from the paper "Distilling the
    Knowledge in a Neural Network" https://arxiv.org/pdf/1503.02531.pdf

    :param teacher_model (torch.nn.Module): Teacher model
    :param student_model (torch.nn.Module): Student model
    :param train_loader (torch.utils.data.DataLoader): Dataloader for training
    :param val_loader (torch.utils.data.DataLoader): Dataloader for validation/testing
    :param optimizer_teacher (torch.optim.*): Optimizer used for training teacher
    :param optimizer_student (torch.optim.*): Optimizer used for training student
    :param loss_fn (torch.nn.Module):  Calculates loss during distillation
    :param temp (float): Temperature parameter for distillation
    :param distil_weight (float): Weight paramter for distillation loss
    :param device (str): Device used for training; 'cpu' for cpu and 'cuda' for gpu
    :param log (bool): True if logging required
    :param logdir (str): Directory for storing logs
    """

    def __init__(
        self,
        teacher_model,
        student_model,
        train_loader,
        val_loader,
        optimizer_teacher,
        optimizer_student,
        loss_fn=nn.MSELoss(),
        temp=20.0,
        distil_weight=0.5,
        gating_weight=0.01,
        gate_layers=[],
        device="cpu",
        log=False,
        logdir="./Experiments",
    ):
        super(GatedKD, self).__init__(
            teacher_model,
            student_model,
            train_loader,
            val_loader,
            optimizer_teacher,
            optimizer_student,
            loss_fn,
            temp,
            distil_weight,
            device,
            log,
            logdir,
        )
        self.gating_weight = gating_weight
        self.gate_layers = gate_layers

    def l1_penalty(self):
        loss = 0
        for g in self.gate_layers :
            values = torch.cat([x.view(-1) for x in g[0].parameters()])
            loss += torch.norm(values, 1)
        loss = loss / len(self.gate_layers)
        return loss

    def std_penalty(self):
        loss = 0
        for g in self.gate_layers :
            values = torch.cat([x.view(-1) for x in g[0].parameters()])
            loss += (1/(torch.std(values)+1e-4)) / 1e3
        loss = loss / len(self.gate_layers)
        return loss

    def kd_penalty(self, rho=0.05):
        loss = 0
        for g in self.gate_layers :
            p_hat = torch.cat([g[0].parameters()])#torch.cat([x.view(-1) for x in g[0].parameters()])
            print(p_hat)
            funcs = nn.Sigmoid()
            p_hat = torch.mean(funcs(p_hat),1)
            p_tensor = torch.Tensor([rho] * len(p_hat)).to(device)
            loss += torch.sum(p_tensor * torch.log(p_tensor) - p_tensor * torch.log(p_hat) + (1 - p_tensor) * torch.log(1 - p_tensor) - (1 - p_tensor) * torch.log(1 - p_hat))
        return loss

    def calculate_kd_loss(self, y_pred_student, y_pred_teacher, y_true):
        """
        Function used for calculating the KD loss during distillation

        :param y_pred_student (torch.FloatTensor): Prediction made by the student model
        :param y_pred_teacher (torch.FloatTensor): Prediction made by the teacher model
        :param y_true (torch.FloatTensor): Original label
        """

        soft_teacher_out = F.softmax(y_pred_teacher / self.temp, dim=1)
        soft_student_out = F.softmax(y_pred_student / self.temp, dim=1)

        l1_penalty = self.gating_weight * self.l1_penalty()
        std_penalty = self.gating_weight * self.std_penalty()

        # kd_penalty = self.gating_weight * self.kd_penalty(rho=0.05)
        loss = (1 - self.distil_weight) * F.cross_entropy(y_pred_student, y_true)
        loss += (self.distil_weight * self.temp * self.temp) * self.loss_fn(
            soft_teacher_out, soft_student_out
        )
        loss += l1_penalty + std_penalty

        return loss


## NAS

In [77]:
# GatedConvolution Layer

class GatedConvolution(nn.Module):
    """ Custom Linear layer but mimics a standard linear layer """
    def __init__(self, n_channels, channel_first=True):
        super().__init__()
        self.n_channels = n_channels
        self.out_channels = self.n_channels
        if channel_first:
            weight = torch.ones(self.n_channels, 1, 1)
        else:
            weight = torch.ones(1, 1, self.n_channels)
        self.weight = nn.Parameter(weight)  # nn.Parameter is a Tensor that's a module parameter.
        # initialize weight
        nn.init.ones_(self.weight) # weight init
        self.weight_tranformation = lambda x:x #self.sigmoid_squeezed #torch.sigmoid

    def sigmoid_squeezed(self, x):
        c1 = 10
        c2 = 0.5
        return 1/(1+torch.exp(c1*(-x+c2)))

    def transformed_weight(self):
        with torch.no_grad():
            return self.weight_tranformation(self.weight) 

    def forward(self, x):
        a = torch.mul(x, self.weight_tranformation(self.weight))  # w times x + b
        return a

    def __repr__(self):
        return f'GatedConvolution(n_channels{self.n_channels})'

In [78]:
import random
class NAS_GatedConvolution:
    def __init__(self, verbose=False):
        self.verbose = verbose
        self.gate_layers = []
        self.apply_to_layer_types = nn.Conv2d
        self.max_recursion_level = 25
        self.gating_threshold = 1e-4
        self.ready_to_prune = []

    def factory_inject_layer(self, module):
        n_channels = module.out_channels
        gate = GatedConvolution(n_channels)
        return gate, nn.Sequential(module,gate)

    def apply_gates_to_model(self, model, apply_to_layer_types=[], _level=0): 
        if _level >= self.max_recursion_level: 
            print("|ERR| Maximum level of recursion reached ", self.max_recursion_level )
            return []
        if hasattr(model, 'gating_layers'): print("|WARNING| NAS already applied to model!")
        if self.verbose: print("-"*50, " Level ",_level)
        # if self.verbose: print("Step model", model)
        modules = list(model.named_children())
        self.gate_layers = []
        for i, module in enumerate(modules):
            module_name, module_obj = module
            module_children = list(module_obj.children())
        
            if any([isinstance(module_obj, x) for x in apply_to_layer_types]):
                gate, gated_module = self.factory_inject_layer(module_obj)
                if isinstance(model, nn.Sequential):
                    model[i] = gated_module
                else:
                    setattr(model, module_name, gated_module)
                setattr(gated_module, '_is_gated', True)
                if self.verbose: print("Added new Layer to module ", module_name)
                self.gate_layers.append((gate, module_name, model))
            elif len(module_children) and not hasattr(module_obj, '_is_gated'):
                self.gate_layers += self.apply_gates_to_model(module_obj, apply_to_layer_types=apply_to_layer_types, _level=_level+1)   
        return self.gate_layers

    def estimate_required_channels(self, use_mean = .8 ):
        gates_zeros_idxs = []
        with torch.no_grad():
            for i, layer in enumerate(self.gate_layers):
                gate, _, _ = layer
                w = gate.transformed_weight()
                threshold = torch.mean(w) * use_mean if use_mean else self.gating_threshold
                are_zeros = torch.where(w < threshold, 0., 1.)
                zeros_idxs = torch.squeeze((are_zeros-1).nonzero())[:,0]
                gates_zeros_idxs.append(zeros_idxs.detach().cpu().tolist())
                zeros = torch.squeeze(torch.count_nonzero(are_zeros -1 , dim=0)).item()
                if self.verbose: print(i, "Layer index ", i ," | Zeros", zeros , "/", w.shape[0], " | Mean", torch.mean(w, dim=0).item(), " | Std", torch.std(w, dim=0).item())
        return gates_zeros_idxs

    def project_gates_on_model(self, use_mean=.8):
        gates_zeros_idxs = self.estimate_required_channels(use_mean=use_mean)
        with torch.no_grad():
            for i, item in enumerate(zip(self.gate_layers, gates_zeros_idxs)):
                layer, zeros_idxs = item
                gate, module_name, model = layer
                submodules = [x for x in model.named_children() if hasattr(x[1], "_is_gated")]
                for submodule in submodules:
                    submodule_name, submodule_obj = submodule
                    if isinstance(submodule_obj, nn.Conv2d) and submodule_name == module_name:
                        target_module = submodule_obj
                        for i in zeros_idxs:
                            target_module.weight[i,:,:] = 0
    
    def remove_gates_from_model(self):
        with torch.no_grad():
            for i, layer in enumerate(self.gate_layers):
                gate, module_name, model = layer
                submodules = [x for x in model.named_children() if hasattr(x[1], "_is_gated")]
                for submodule in submodules:
                    submodule_name, submodule_obj = submodule
                    if isinstance(submodule_obj[0], nn.Conv2d) and submodule_name == module_name:
                        target_module = submodule_obj[0]
                        print("Replacing ", module_name, " with ", target_module)
                        setattr(model, module_name, target_module)
                        self.ready_to_prune.append((module_name, target_module))

    def optimize(self, model, use_mean=0.8):
        gates_zeros_idxs = self.estimate_required_channels(use_mean=use_mean)
        self.project_gates_on_model(use_mean=use_mean)
        self.remove_gates_from_model()
        self.prune_model_channels(model, amount=0.5, pruning_idxs=gates_zeros_idxs)
        return gates_zeros_idxs

    def calc_improvement(self, base_model , gated_model):
        base_model_parameters = sum(p.numel() for p in base_model.parameters() if p.requires_grad)
        gated_model_parameters = sum(p.numel() for p in gated_model.parameters() if p.requires_grad)
        return gated_model_parameters / base_model_parameters

    def set_zeros(self, amount=0.5):
        with torch.no_grad():
            for i, layer in enumerate(self.gate_layers):
                gate, _, _ = layer
                w = gate.weight
                size = int(w.shape[0]*amount)
                idxs_samples = random.sample(range(0, w.shape[0]), size)
                for i in idxs_samples:
                    w[i,:,:] = 0

    def prune_model_channels(self, model, amount=0.4, pruning_idxs=None, ):
        for i, module in enumerate(NAS.ready_to_prune):
            _, module_obj = module
            DG = tp.DependencyGraph()
            DG.build_dependency(model, example_inputs=torch.randn(1,3,32,32))
            
            if pruning_idxs is None:
                strategy = tp.strategy.L1Strategy() 
                idxs = strategy(module_obj.weight, amount=0.4)
            else: idxs = pruning_idxs[i]

            pruning_plan = DG.get_pruning_plan( module_obj, tp.prune_conv, idxs=idxs )
            if self.verbose:
                print(pruning_plan)
            pruning_plan.exec()




In [79]:
# NAS  = NAS_GatedConvolution(verbose=False)
# gated_model = ResNet18()
# NAS.apply_gates_to_model(gated_model, apply_to_layer_types=[nn.Conv2d])    
# NAS.estimate_required_channels(use_mean= 0.8)   
# NAS.project_gates_on_model(use_mean= 0.8)   
# NAS.remove_gates_from_model()   

In [80]:
# print(ResNet18())

## Model

In [92]:
args['lr'] = 1e-3
args['lambda_penalty'] = 1e-3
args['distil_weight'] = 0.8

NAS  = NAS_GatedConvolution(verbose=False)

# TEACHER
teacher_model = ResNet18()
resume_checkpoint(teacher_model, 'resnet18_net_e_199', map_location=torch.device(device))

# STUDENT
student_model = ResNet18()
#resume_checkpoint(student_model, 'resnet18_net_e_199', map_location=torch.device(device))
# freezeAllLayers(student_model)
NAS.apply_gates_to_model(student_model, apply_to_layer_types=[nn.Conv2d])    

# DISTILL
teacher_optimizer = optim.SGD(teacher_model.parameters(), lr=args['lr'],
                      momentum=0.7, weight_decay=5e-4)
student_optimizer = optim.SGD(student_model.parameters(), lr=args['lr'],
                      momentum=0.7, weight_decay=5e-4)


distiller = GatedKD(teacher_model, student_model, trainloader, testloader, 
                    teacher_optimizer, student_optimizer, loss_fn=nn.KLDivLoss(), 
                    gating_weight=args['lambda_penalty'], gate_layers=NAS.gate_layers, 
                    distil_weight=args['distil_weight'], log=True, logdir="./logs", device=device)  

## Training student

In [None]:
distiller.train_student(epochs=100, save_model_pth='./checkpoint/student_e100.pth')

Training Student...


  "reduction: 'mean' divides the total loss by both the batch size and the support size."


--------------------------------------------------------------------------------
Validation Accuracy: 0.3025
Epoch: 1, Loss: -25890.24652862549, Accuracy: 0.23678
--------------------------------------------------------------------------------
Validation Accuracy: 0.3858
Epoch: 2, Loss: -29792.415840148926, Accuracy: 0.33348
--------------------------------------------------------------------------------
Validation Accuracy: 0.4358
Epoch: 3, Loss: -29822.436561584473, Accuracy: 0.39696
--------------------------------------------------------------------------------
Validation Accuracy: 0.4633
Epoch: 4, Loss: -29828.41371154785, Accuracy: 0.43654
--------------------------------------------------------------------------------
Validation Accuracy: 0.4744
Epoch: 5, Loss: -29832.67473602295, Accuracy: 0.4614
--------------------------------------------------------------------------------
Validation Accuracy: 0.509
Epoch: 6, Loss: -29836.16623687744, Accuracy: 0.47956
----------------------

In [65]:
NAS.verbose=True
NAS.gating_threshold = 1e-2
idxs = NAS.estimate_required_channels(use_mean=0.5)   
NAS.verbose=False

0 Layer index  0  | Zeros 0 / 64  | Mean 0.993452787399292  | Std 0.0010497854091227055
1 Layer index  1  | Zeros 0 / 64  | Mean 0.9933662414550781  | Std 0.0008041745168156922
2 Layer index  2  | Zeros 0 / 64  | Mean 0.9902085661888123  | Std 0.0009862545412033796
3 Layer index  3  | Zeros 0 / 64  | Mean 0.9933662414550781  | Std 0.0008041745168156922
4 Layer index  4  | Zeros 0 / 64  | Mean 0.9902139902114868  | Std 0.0011929869651794434
5 Layer index  5  | Zeros 0 / 128  | Mean 0.9918797016143799  | Std 0.0005994968232698739
6 Layer index  6  | Zeros 0 / 128  | Mean 0.9933137893676758  | Std 0.0005438151420094073
7 Layer index  7  | Zeros 0 / 128  | Mean 0.9918797016143799  | Std 0.0005994917009957135
8 Layer index  8  | Zeros 0 / 128  | Mean 0.9933137893676758  | Std 0.0005437806248664856
9 Layer index  9  | Zeros 0 / 128  | Mean 0.9918797016143799  | Std 0.0005994464736431837
10 Layer index  10  | Zeros 0 / 256  | Mean 0.9932596683502197  | Std 0.00036435603396967053
11 Layer inde

## Old Training Step

In [None]:
NAS  = NAS_GatedConvolution(verbose=False)

print('==> Praparing base model..')
gated_model = ResNet18()
resume_checkpoint(gated_model, 'resnet18_net_e_199', map_location=torch.device(device))
freezeAllLayers(gated_model)
NAS.apply_gates_to_model(gated_model, apply_to_layer_types=[nn.Conv2d])    
#NAS.estimate_required_channels(use_mean= 0.8)    

==> Praparing base model..


[(GatedConvolution(n_channels64), 'conv1', ResNet(
    (conv1): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): GatedConvolution(n_channels64)
    )
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GatedConvolution(n_channels64)
        )
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): GatedConvolution(n_channels64)
        )
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
      (1): BasicBlock(
        (conv1): Sequential(
         

In [None]:
# Model
print('==> Building model..')
net = gated_model.to(device)

if device == 'cuda':
    # net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

if args['resume']:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    resume_checkpoint(net, args['resume'])

==> Building model..


In [None]:
criterion = nn.CrossEntropyLoss()
lamba_gating_criterion = 1
optimizer = optim.SGD(net.parameters(), lr=0.1,
                      momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

In [None]:
# Training
def train(epoch, gating_regularization=False, lamba_gating_criterion=0.1):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets) 
        if gating_regularization:
            gating_loss = torch.sum(torch.Tensor([torch.norm(g[0].weight,1) for g in NAS.gate_layers]))
            loss += lamba_gating_criterion * gating_loss
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))


def test(epoch, save_ckpt='ckpt'):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Test Loss: %.3f | Test Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, f'./checkpoint/{save_ckpt}.pth')
        best_acc = acc


for epoch in range(start_epoch, start_epoch+50):
    train(epoch, gating_regularization=True, lamba_gating_criterion=lamba_gating_criterion)
    test(epoch, save_ckpt=f'pretrained_gatednet_e_{epoch}')
    scheduler.step()


Epoch: 0
Batch 390/391 : Loss: 4378.415 | Acc: 99.996% (49998/50000)
Batch 99/100 : Test Loss: 0.177 | Test Acc: 95.330% (9533/10000)
Saving..

Epoch: 1
Batch 390/391 : Loss: 3598.443 | Acc: 99.998% (49999/50000)
Batch 99/100 : Test Loss: 0.177 | Test Acc: 95.410% (9541/10000)
Saving..

Epoch: 2
Batch 390/391 : Loss: 2958.613 | Acc: 99.998% (49999/50000)
Batch 99/100 : Test Loss: 0.176 | Test Acc: 95.340% (9534/10000)

Epoch: 3
Batch 390/391 : Loss: 2436.752 | Acc: 99.994% (49997/50000)
Batch 99/100 : Test Loss: 0.172 | Test Acc: 95.210% (9521/10000)

Epoch: 4
Batch 390/391 : Loss: 2022.416 | Acc: 99.994% (49997/50000)
Batch 99/100 : Test Loss: 0.172 | Test Acc: 94.930% (9493/10000)

Epoch: 5
Batch 390/391 : Loss: 1738.249 | Acc: 99.978% (49989/50000)
Batch 99/100 : Test Loss: 0.176 | Test Acc: 94.790% (9479/10000)

Epoch: 6
Batch 390/391 : Loss: 1595.579 | Acc: 99.974% (49987/50000)
Batch 99/100 : Test Loss: 0.183 | Test Acc: 94.690% (9469/10000)

Epoch: 7
Batch 390/391 : Loss: 1518.

In [None]:
gate_weight = ([g[0].weight for g in NAS.gate_layers])
old_l = NAS.gate_layers
NAS.verbose=True
NAS.gating_threshold = 1e-2
idxs = NAS.estimate_required_channels(use_mean=0.3)   
NAS.verbose=False
#NAS.calc_improvement(ResNet18(), gated_model)
# new_model = ResNet18()
# NAS.prune_model_channels(new_model, pruning_idxs=idxs)

0 Layer index  0  | Zeros 0 / 64  | Mean 0.010490129701793194  | Std 0.005531997419893742
1 Layer index  1  | Zeros 5 / 64  | Mean 0.056963805109262466  | Std 0.026103349402546883
2 Layer index  2  | Zeros 9 / 64  | Mean 0.07979716360569  | Std 0.03785810247063637
3 Layer index  3  | Zeros 25 / 64  | Mean 0.03719443082809448  | Std 0.033819276839494705
4 Layer index  4  | Zeros 24 / 64  | Mean 0.0835062637925148  | Std 0.0638878270983696
5 Layer index  5  | Zeros 5 / 128  | Mean 0.04898415505886078  | Std 0.017910894006490707
6 Layer index  6  | Zeros 3 / 128  | Mean 0.09449909627437592  | Std 0.03006417490541935
7 Layer index  7  | Zeros 55 / 128  | Mean 0.054971545934677124  | Std 0.04958360642194748
8 Layer index  8  | Zeros 23 / 128  | Mean 0.05743551254272461  | Std 0.03022054024040699
9 Layer index  9  | Zeros 34 / 128  | Mean 0.10885466635227203  | Std 0.07703981548547745
10 Layer index  10  | Zeros 0 / 256  | Mean 0.07427909970283508  | Std 0.01823461428284645
11 Layer index  1

KeyError: ignored

In [None]:
net = gated_model.to(device)
test(epoch, save_ckpt=f'pretrained_gatednet_e_prova')

Batch 99/100 : Test Loss: 239542481526456.312 | Test Acc: 10.470% (1047/10000)


In [None]:
NAS.calc_improvement(ResNet18(), gated_model)

0.0

In [None]:
# TRAIN BASE
# Epoch: 4
# Batch 390/391 : Loss: 0.895 | Acc: 68.284% (34142/50000)
# Batch 99/100 : Test Loss: 0.918 | Test Acc: 67.270% (6727/10000)

# Epoch: 5
# Batch 390/391 : Loss: 0.768 | Acc: 73.060% (36530/50000)
# Batch 99/100 : Test Loss: 0.735 | Test Acc: 74.720% (7472/10000)

# ...

# Epoch: 10
# Batch 390/391 : Loss: 0.507 | Acc: 82.666% (41333/50000)
# Batch 99/100 : Test Loss: 0.672 | Test Acc: 77.910% (7791/10000)

# Epoch: 11
# Batch 390/391 : Loss: 0.487 | Acc: 83.324% (41662/50000)
# Batch 99/100 : Test Loss: 0.723 | Test Acc: 75.570% (7557/10000)

# ....

# Epoch: 23
# Batch 390/391 : Loss: 0.375 | Acc: 87.314% (43657/50000)
# Batch 99/100 : Test Loss: 0.508 | Test Acc: 83.310% (8331/10000)

# Epoch: 24
# Batch 390/391 : Loss: 0.370 | Acc: 87.426% (43713/50000)
# Batch 99/100 : Test Loss: 0.502 | Test Acc: 83.690% (8369/10000)
# Saving..

# Epoch: 25
# Batch 390/391 : Loss: 0.370 | Acc: 87.196% (43598/50000)
# Batch 99/100 : Test Loss: 0.468 | Test Acc: 84.700% (8470/10000)
# Saving..