In [None]:
import torch 
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import random_split
from torchvision import models
from torchsummary import summary
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
# Colab-only: mount Google Drive so the checkpoint paths under /content/drive
# used by the later cells resolve.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Default device for every model in this notebook: CUDA when present, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

########## Teacher Definition #############


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` whose padding equals ``dilation``."""
    conv = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return conv


def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` (no padding)."""
    conv = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return conv


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut connection.

    Only ``groups=1``, ``base_width=64`` and ``dilation<=1`` are accepted;
    other values raise ``ValueError`` / ``NotImplementedError`` respectively.
    ``downsample``, when given, is applied to the input to produce the shortcut.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock")
        # When stride != 1 the first convolution (and the downsample branch)
        # is what reduces the spatial resolution.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run conv-bn-relu-conv-bn, add the shortcut, then apply the final ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Identity shortcut unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """Residual block of three convolutions: 1x1, then 3x3, then 1x1.

    The inner width is ``int(planes * (base_width / 64)) * groups`` and the
    block emits ``planes * expansion`` channels. ``downsample``, when given,
    is applied to the input to produce the shortcut.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        out_planes = planes * self.expansion
        width = int(planes * (base_width / 64.)) * groups
        # When stride != 1 the 3x3 convolution (and the downsample branch)
        # is what reduces the spatial resolution.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv-bn stages, add the shortcut, then apply the final ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Identity shortcut unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class ResNet(nn.Module):
    """ResNet: a conv/bn/relu/maxpool stem, four residual stages, global pooling and a linear head.

    Args:
        block: residual block class; must expose ``expansion`` and accept the
            constructor arguments used in ``_make_layer``.
        layers: number of blocks in each of the four stages.
        num_classes: output size of the final linear layer.
        zero_init_residual: if true, zero the weight of the last norm layer of
            every ``Bottleneck`` / ``BasicBlock``.
        groups, width_per_group: forwarded to every block.
        replace_stride_with_dilation: ``None`` or three flags (stages 2-4);
            a set flag trades that stage's stride for dilation.
        norm_layer: normalisation layer factory, ``nn.BatchNorm2d`` by default.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have length 3.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        self._norm_layer = norm_layer

        # Running channel count / dilation, updated by _make_layer as stages are built.
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # One flag per strided stage: keep the stride-2 downsampling everywhere.
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            message = ("replace_stride_with_dilation should be None "
                       "or a 3-element tuple, got {}")
            raise ValueError(message.format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        # Stem.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages: 64, 128, 256 and 512 planes.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])

        # Head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Kaiming-normal for convolutions; unit weight / zero bias for norm layers.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        if zero_init_residual:
            # Zero the last norm weight of each residual branch.
            for module in self.modules():
                if isinstance(module, Bottleneck):
                    nn.init.constant_(module.bn3.weight, 0)
                elif isinstance(module, BasicBlock):
                    nn.init.constant_(module.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks and advance ``self.inplanes``."""
        norm_layer = self._norm_layer
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation in this stage.
            self.dilation *= stride
            stride = 1

        out_planes = planes * block.expansion
        if stride != 1 or self.inplanes != out_planes:
            # Shortcut must be projected to match the block output.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride),
                norm_layer(out_planes),
            )
        else:
            downsample = None

        # Only the first block changes resolution / channel count.
        stage = [block(self.inplanes, planes, stride, downsample, self.groups,
                       self.base_width, previous_dilation, norm_layer)]
        self.inplanes = out_planes
        stage.extend(
            block(self.inplanes, planes, groups=self.groups,
                  base_width=self.base_width, dilation=self.dilation,
                  norm_layer=norm_layer)
            for _ in range(1, blocks)
        )

        return nn.Sequential(*stage)

    def _forward_impl(self, x):
        """Stem -> layer1..layer4 -> global average pool -> flatten -> fc."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)

    def forward(self, x):
        return self._forward_impl(x)

def resnet_18(num_classes):
    """Build a ResNet-18 (``BasicBlock``, stages ``[2, 2, 2, 2]``) with a 3x3, stride-1 stem.

    Args:
        num_classes: output size of the final linear layer.

    Returns:
        The constructed ``ResNet`` instance.
    """
    # Forward the caller's num_classes; it was previously ignored in favour of
    # a hard-coded 10.
    model = ResNet(block=BasicBlock, layers=[2, 2, 2, 2], num_classes=num_classes)

    # Replace the 7x7 / stride-2 stem built by ResNet with a 3x3 / stride-1
    # convolution. The layer is created after ResNet's init loop has run, so it
    # keeps nn.Conv2d's default initialisation.
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1,
                            bias=False)

    return model


In [None]:
########## New Student ResNet block ####################

class ResNet1(nn.Module):
    """Narrower student network: same layout as ``ResNet`` with 32/64/128/256-plane stages.

    Args:
        block: residual block class; must expose ``expansion`` and accept the
            constructor arguments used in ``_make_layer``.
        layers: number of blocks in each of the four stages.
        num_classes: output size of the final linear layer.
        zero_init_residual: if true, zero the weight of the last norm layer of
            every ``Bottleneck`` / ``BasicBlock``.
        groups, width_per_group: forwarded to every block.
        replace_stride_with_dilation: ``None`` or three flags (stages 2-4);
            a set flag trades that stage's stride for dilation.
        norm_layer: normalisation layer factory, ``nn.BatchNorm2d`` by default.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have length 3.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        self._norm_layer = norm_layer

        # The stem still emits 64 channels; layer1 projects them down to 32 planes.
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # One flag per strided stage: keep the stride-2 downsampling everywhere.
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            message = ("replace_stride_with_dilation should be None "
                       "or a 3-element tuple, got {}")
            raise ValueError(message.format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        # Stem.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages: 32, 64, 128 and 256 planes.
        self.layer1 = self._make_layer(block, 32, layers[0])
        self.layer2 = self._make_layer(
            block, 64, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(
            block, 128, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(
            block, 256, layers[3], stride=2, dilate=replace_stride_with_dilation[2])

        # Head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256 * block.expansion, num_classes)

        # Kaiming-normal for convolutions; unit weight / zero bias for norm layers.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        if zero_init_residual:
            # Zero the last norm weight of each residual branch.
            for module in self.modules():
                if isinstance(module, Bottleneck):
                    nn.init.constant_(module.bn3.weight, 0)
                elif isinstance(module, BasicBlock):
                    nn.init.constant_(module.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks and advance ``self.inplanes``."""
        norm_layer = self._norm_layer
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation in this stage.
            self.dilation *= stride
            stride = 1

        out_planes = planes * block.expansion
        if stride != 1 or self.inplanes != out_planes:
            # Shortcut must be projected to match the block output.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_planes, stride),
                norm_layer(out_planes),
            )
        else:
            downsample = None

        # Only the first block changes resolution / channel count.
        stage = [block(self.inplanes, planes, stride, downsample, self.groups,
                       self.base_width, previous_dilation, norm_layer)]
        self.inplanes = out_planes
        stage.extend(
            block(self.inplanes, planes, groups=self.groups,
                  base_width=self.base_width, dilation=self.dilation,
                  norm_layer=norm_layer)
            for _ in range(1, blocks)
        )

        return nn.Sequential(*stage)

    def _forward_impl(self, x):
        """Stem -> layer1..layer4 -> global average pool -> flatten -> fc."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)

    def forward(self, x):
        return self._forward_impl(x)


In [None]:
class Cifar10SearchDataset(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose transform is invoked keyword-style.

    The transform is called as ``transform(image=...)`` and its ``"image"``
    entry is returned, which is the calling convention of the albumentations
    pipelines built elsewhere in this notebook.
    """

    def __init__(self, root="~/data/cifar10", train=True, download=True, transform=None):
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``; the image is transformed when a transform is set."""
        image = self.data[index]
        label = self.targets[index]

        if self.transform is None:
            return image, label

        return self.transform(image=image)["image"], label

In [None]:
def student_resnet(num_classes, layer):
    """Build a full-width ``ResNet`` student with a 3x3, stride-1 stem.

    Args:
        num_classes: output size of the final linear layer.
        layer: number of ``BasicBlock``s in each of the four stages,
            e.g. ``[2, 2, 2, 1]``.

    Returns:
        The constructed ``ResNet`` instance.
    """
    # Forward the caller's num_classes; it was previously ignored in favour of
    # a hard-coded 10.
    model = ResNet(block=BasicBlock, layers=layer, num_classes=num_classes)

    # Replace the 7x7 / stride-2 stem built by ResNet with a 3x3 / stride-1 one.
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1,
                            bias=False)

    return model

In [None]:
def student_resnet1(num_classes, layer):
    """Build a narrow ``ResNet1`` student with a 3x3, stride-1 stem.

    Args:
        num_classes: output size of the final linear layer.
        layer: number of ``BasicBlock``s in each of the four stages,
            e.g. ``[2, 2, 1, 1]``.

    Returns:
        The constructed ``ResNet1`` instance.
    """
    # Forward the caller's num_classes; it was previously ignored in favour of
    # a hard-coded 10.
    model = ResNet1(block=BasicBlock, layers=layer, num_classes=num_classes)

    # Replace the 7x7 / stride-2 stem built by ResNet1 with a 3x3 / stride-1 one.
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1,
                            bias=False)

    return model

In [None]:
# run test set

criterion = nn.CrossEntropyLoss()

def test(model, criterion, test_loader, use_gpu=True):
  test_size = len(test_loader.dataset)
  device = torch.device( "cuda:0" if use_gpu else "cpu" )
  test_loss = 0.0
  test_accuracy = 0
  correct = 0
  model.eval()
  with torch.no_grad():
      for i, data in enumerate(test_loader):
          # get the inputs; data is a list of [inputs, labels]
          inputs, labels = data[0].to(device), data[1].to(device)

          # forward + backward + optimize
          outputs = model(inputs).to(device)
          loss = criterion(outputs, labels)
          
          test_loss += loss * inputs.size(0)
          
          # val accuracy
          _, predicted = torch.max(outputs.data, 1)
          correct += (predicted == labels).sum().item()

      
      test_loss = test_loss/test_size
      test_accuracy = correct/test_size;

  return test_loss, test_accuracy

Baseline

In [None]:
# Restore the trained baseline (teacher) network.
baseline = resnet_18(num_classes = 10).to(device)

# map_location remaps the saved tensors onto `device`, so the checkpoint also
# loads on a CPU-only runtime (the case `device` already falls back to).
checkpoint = torch.load('/content/drive/Shareddrives/RDNN/Weights/Copy of resnet18_baseline_cifar10.pth',
                        map_location=device)
baseline.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

Knowledge Distillation

In [None]:
# Restore the distilled students. The suffix of each variable is the per-stage
# block count passed to the builder. map_location remaps the saved tensors onto
# `device`, so the checkpoints also load on a CPU-only runtime.
student_model_2221 = student_resnet(10, [2,2,2,1]).to(device)
checkpoint = torch.load('/content/drive/Shareddrives/RDNN/Weights/student_resnet_50.pth',
                        map_location=device)
student_model_2221.load_state_dict(checkpoint['state_dict'])

student_model_2212 = student_resnet(10, [2,2,1,2]).to(device)
checkpoint = torch.load('/content/drive/Shareddrives/RDNN/Weights/student_resnet_50_2212.pth',
                        map_location=device)
student_model_2212.load_state_dict(checkpoint['state_dict'])

student_model_1111 = student_resnet(10, [1,1,1,1]).to(device)
checkpoint = torch.load('/content/drive/Shareddrives/RDNN/Weights/student_resnet_50_1111.pth',
                        map_location=device)
student_model_1111.load_state_dict(checkpoint['state_dict'])

# Narrow (ResNet1) students.
student_model1_2222 = student_resnet1(10, [2,2,2,2]).to(device)
checkpoint = torch.load('/content/drive/Shareddrives/RDNN/Weights/student_resnet1_50_2222.pth',
                        map_location=device)
student_model1_2222.load_state_dict(checkpoint['state_dict'])

student_model1_2211 = student_resnet1(10, [2,2,1,1]).to(device)
checkpoint = torch.load('/content/drive/Shareddrives/RDNN/Weights/student_resnet1_50_2211.pth',
                        map_location=device)
student_model1_2211.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

Pruning

In [None]:
# Global Unstructured
# One ResNet-18 per checkpoint file; the variable suffix mirrors the sparsity
# value in the file name. map_location remaps the saved tensors onto `device`,
# so the checkpoints also load on a CPU-only runtime.
unstruct_global_01 = resnet_18(num_classes = 10).to(device)
checkpoint = torch.load('/content/drive/Shareddrives/RDNN/Pruned_Models/global_unstruct_sparsity_0.1.pth',
                        map_location=device)
unstruct_global_01.load_state_dict(checkpoint['state_dict'])

unstruct_global_04 = resnet_18(num_classes = 10).to(device)
checkpoint = torch.load('/content/drive/Shareddrives/RDNN/Pruned_Models/global_unstruct_sparsity_0.4.pth',
                        map_location=device)
unstruct_global_04.load_state_dict(checkpoint['state_dict'])

unstruct_global_06 = resnet_18(num_classes = 10).to(device)
checkpoint = torch.load('/content/drive/Shareddrives/RDNN/Pruned_Models/global_unstruct_sparsity_0.6.pth',
                        map_location=device)
unstruct_global_06.load_state_dict(checkpoint['state_dict'])

unstruct_global_075 = resnet_18(num_classes = 10).to(device)
checkpoint = torch.load('/content/drive/Shareddrives/RDNN/Pruned_Models/global_unstruct_sparsity_0.75.pth',
                        map_location=device)
unstruct_global_075.load_state_dict(checkpoint['state_dict'])

unstruct_global_09 = resnet_18(num_classes = 10).to(device)
checkpoint = torch.load('/content/drive/Shareddrives/RDNN/Pruned_Models/global_unstruct_sparsity_0.9.pth',
                        map_location=device)
unstruct_global_09.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [None]:
# local Unstructured
# One ResNet-18 per checkpoint file; the variable suffix mirrors the sparsity
# value in the file name. map_location remaps the saved tensors onto `device`,
# so the checkpoints also load on a CPU-only runtime.

unstruct_local_01 = resnet_18(num_classes = 10).to(device)
checkpoint = torch.load('/content/drive/Shareddrives/RDNN/Pruned_Models/local_unstruct_sparsity_0.1.pth',
                        map_location=device)
unstruct_local_01.load_state_dict(checkpoint['state_dict'])

unstruct_local_04 = resnet_18(num_classes = 10).to(device)
checkpoint = torch.load('/content/drive/Shareddrives/RDNN/Pruned_Models/local_unstruct_sparsity_0.4.pth',
                        map_location=device)
unstruct_local_04.load_state_dict(checkpoint['state_dict'])

unstruct_local_06 = resnet_18(num_classes = 10).to(device)
checkpoint = torch.load('/content/drive/Shareddrives/RDNN/Pruned_Models/local_unstruct_sparsity_0.6.pth',
                        map_location=device)
unstruct_local_06.load_state_dict(checkpoint['state_dict'])

unstruct_local_075 = resnet_18(num_classes = 10).to(device)
checkpoint = torch.load('/content/drive/Shareddrives/RDNN/Pruned_Models/local_unstruct_sparsity_0.75.pth',
                        map_location=device)
unstruct_local_075.load_state_dict(checkpoint['state_dict'])

unstruct_local_09 = resnet_18(num_classes = 10).to(device)
checkpoint = torch.load('/content/drive/Shareddrives/RDNN/Pruned_Models/local_unstruct_sparsity_0.9.pth',
                        map_location=device)
unstruct_local_09.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [None]:
# Structured
# One ResNet-18 per checkpoint file; the variable suffix mirrors the sparsity
# value in the file name. map_location remaps the saved tensors onto `device`,
# so the checkpoints also load on a CPU-only runtime.

structured_01 = resnet_18(num_classes = 10).to(device)
checkpoint = torch.load('/content/drive/Shareddrives/RDNN/Pruned_Models/structured_sparsity_0.1.pth',
                        map_location=device)
structured_01.load_state_dict(checkpoint['state_dict'])

structured_04 = resnet_18(num_classes = 10).to(device)
checkpoint = torch.load('/content/drive/Shareddrives/RDNN/Pruned_Models/structured_sparsity_0.4.pth',
                        map_location=device)
structured_04.load_state_dict(checkpoint['state_dict'])

structured_06 = resnet_18(num_classes = 10).to(device)
checkpoint = torch.load('/content/drive/Shareddrives/RDNN/Pruned_Models/structured_sparsity_0.6.pth',
                        map_location=device)
structured_06.load_state_dict(checkpoint['state_dict'])

structured_075 = resnet_18(num_classes = 10).to(device)
checkpoint = torch.load('/content/drive/Shareddrives/RDNN/Pruned_Models/structured_sparsity_0.75.pth',
                        map_location=device)
structured_075.load_state_dict(checkpoint['state_dict'])

structured_09 = resnet_18(num_classes = 10).to(device)
checkpoint = torch.load('/content/drive/Shareddrives/RDNN/Pruned_Models/structured_sparsity_0.9.pth',
                        map_location=device)
structured_09.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [None]:
# Per-channel normalisation constants shared by every evaluation pipeline
# (presumably the CIFAR-10 channel mean / std; confirm against the training code).
_NORM_MEAN = (0.49139968, 0.48215841, 0.44653091)
_NORM_STD = (0.24703223, 0.24348513, 0.26158784)


def _make_test_transform(*corruptions):
    """Compose the given corruption(s), then normalisation and tensor conversion."""
    return A.Compose(
        [*corruptions,
         A.Normalize(_NORM_MEAN, _NORM_STD),
         ToTensorV2()])


# Clean reference pipeline: normalise and convert only.
test_transform_no_noise = _make_test_transform()

# Each pipeline below applies exactly one corruption with probability 1.
test_transform_gaussian = _make_test_transform(
    A.GaussNoise(var_limit = (0,2), p=1))

test_transform_randbright = _make_test_transform(
    A.RandomBrightnessContrast (brightness_limit=0.2, contrast_limit=0.2, brightness_by_max=True, always_apply=False, p=1))

test_transform_isonoise = _make_test_transform(
    A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), always_apply=False, p=1))

test_transform_rgbshift = _make_test_transform(
    A.RGBShift(r_shift_limit=20, g_shift_limit=20, b_shift_limit=20, always_apply=False, p=1))

test_transform_pixeldrop = _make_test_transform(
    A.PixelDropout (dropout_prob=0.01, per_channel=False, drop_value=0, mask_drop_value=None, always_apply=False, p=1))

test_transform_randomfog = _make_test_transform(
    A.RandomFog (fog_coef_lower=0.3, fog_coef_upper=1, alpha_coef=0.08, always_apply=False, p=1))

In [None]:
# Evaluation pipelines, clean one first. The order has to line up with
# `noise_labels`, which the evaluation loop indexes in parallel.
noise = [
    test_transform_no_noise,
    test_transform_gaussian,
    test_transform_randbright,
    test_transform_isonoise,
    test_transform_rgbshift,
    test_transform_pixeldrop,
    test_transform_randomfog,
]

In [None]:
# Names written to the CSV files, in the same order as the `noise` list.
noise_labels = [
    'test_transform_no_noise',
    'test_transform_gaussian',
    'test_transform_randbright',
    'test_transform_isonoise',
    'test_transform_rgbshift',
    'test_transform_pixeldrop',
    'test_transform_randomfog',
]

In [None]:
def load_test(noise_label, batch_size=256):
  """Build the CIFAR-10 test split and an unshuffled loader for one transform.

  Args:
    noise_label: despite its name, this is the transform object itself (the
      evaluation loop passes entries of `noise`); it is forwarded to
      Cifar10SearchDataset as `transform`.
    batch_size: loader batch size; defaults to the previously hard-coded 256.

  Returns:
    `(test_ds, test_loader)`.
  """
  test_ds = Cifar10SearchDataset(root='./data', train=False,
                                 download=True, transform=noise_label)
  # NOTE(review): this seeds torch's global RNG only. Whether the transforms
  # draw their randomness from torch is not visible here; confirm before
  # relying on the corrupted images being reproducible.
  torch.manual_seed(50)

  test_loader = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=False)
  return test_ds, test_loader

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Every model under evaluation, in the same order as `model_labels`.
# NOTE(review): this rebinds `models`, which the import cell bound to
# torchvision's `models` module.
models = [
    baseline,
    # knowledge-distillation students
    student_model_2221,
    student_model_2212,
    student_model_1111,
    student_model1_2222,
    student_model1_2211,
    # global unstructured pruning
    unstruct_global_01,
    unstruct_global_04,
    unstruct_global_06,
    unstruct_global_075,
    unstruct_global_09,
    # local unstructured pruning
    unstruct_local_01,
    unstruct_local_04,
    unstruct_local_06,
    unstruct_local_075,
    unstruct_local_09,
    # structured pruning
    structured_01,
    structured_04,
    structured_06,
    structured_075,
    structured_09,
]

In [None]:
# NOTE(review): this cell overrides the full `models` list from the previous
# cell, so the evaluation loop only covers these two models. It looks like a
# debugging shortcut; confirm whether it should stay.
models = [
    baseline,
    student_model_2221,
]

In [None]:
# CSV file stems, one per model, in the same order as the full `models` list.
model_labels = [
    'baseline',
    # knowledge-distillation students
    'student_model_2221',
    'student_model_2212',
    'student_model_1111',
    'student_model1_2222',
    'student_model1_2211',
    # global unstructured pruning
    'unstruct_global_01',
    'unstruct_global_04',
    'unstruct_global_06',
    'unstruct_global_075',
    'unstruct_global_09',
    # local unstructured pruning
    'unstruct_local_01',
    'unstruct_local_04',
    'unstruct_local_06',
    'unstruct_local_075',
    'unstruct_local_09',
    # structured pruning
    'structured_01',
    'structured_04',
    'structured_06',
    'structured_075',
    'structured_09',
]

In [None]:
import csv

In [None]:
# Test the various models on different noises and save the results:
# one CSV per model with one "<noise name>, <accuracy>" row per noise.

# Ask for the GPU only when one exists, mirroring how `device` was chosen;
# a hard-coded use_gpu=True fails on CPU-only runtimes.
use_gpu = torch.cuda.is_available()

# The datasets/loaders do not depend on the model, so build them once instead
# of re-reading the test split for every (model, noise) pair. The transforms
# are still applied on the fly each time a loader is iterated.
noise_loaders = [(noise_name, load_test(transform)[1])
                 for transform, noise_name in zip(noise, noise_labels)]

for model, model_name in zip(models, model_labels):
  model_name_dict = dict()
  for noise_name, test_loader in noise_loaders:
    _, test_accuracy = test(model, criterion, test_loader, use_gpu=use_gpu)
    model_name_dict[noise_name] = test_accuracy

  with open(f'/content/{model_name}.csv','w') as f:
    for key, accuracy in model_name_dict.items():
      f.write("%s, %f\n"%(key, accuracy))