In [None]:
import math
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torchsummary import summary
from tqdm.autonotebook import tqdm

In [None]:
# Mount Google Drive so checkpoints written under /content/drive outlive the Colab VM.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT, force_remount=True)

Mounted at /content/drive


In [None]:
# Per-channel normalisation statistics shared by the train and test pipelines.
# NOTE(review): (0.2023, 0.1994, 0.2010) is a widely copied value that is not the
# per-pixel channel std of the CIFAR-10 training set (~(0.2470, 0.2435, 0.2616));
# it is kept unchanged so existing checkpoints see identically normalised inputs.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)
BATCH_SIZE = 128
NUM_WORKERS = 2

# Training-time augmentation: random 32x32 crop from a 4-pixel zero-padded image plus horizontal flip.
transformTrain = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation uses no augmentation, only tensor conversion and the same normalisation.
transformTest = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transformTrain)
trainLoader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=NUM_WORKERS)

# Named `testset` (parallel to `trainset`) rather than `test`, which the
# `def test(epoch)` evaluation function further down would otherwise shadow.
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transformTest)
testLoader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=NUM_WORKERS)

# Label names, indexed by CIFAR-10 class id.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
# Public constructors for the CIFAR ResNets below (depth = 6n + 2, n blocks per stage).
__all__ = ['resnet%d_cifar' % depth for depth in (20, 32, 44, 56)]

# Number of CIFAR-10 label classes (matches the `classes` tuple in the data cell).
NUM_CLASSES = 10

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding 1 and no bias.

    With stride 1 the spatial size is preserved; with stride 2 it becomes
    ceil(H / 2) x ceil(W / 2). The bias is omitted because every use in this
    notebook feeds the result straight into a BatchNorm2d.
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class EltwiseAdd(nn.Module):
    """Element-wise sum of all positional inputs, wrapped as an ``nn.Module``.

    Defined here because the notebook never imports Distiller's
    ``distiller.modules.EltwiseAdd``; without this class ``BasicBlock`` raises
    NameError on a fresh kernel. The sum is computed out of place, so none of
    the input tensors are modified.
    """

    def forward(self, *inputs):
        total = inputs[0]
        for tensor in inputs[1:]:
            total = total + tensor
        return total


class BasicBlock(nn.Module):
    """ResNet basic block: two 3x3 conv/BN stages plus a shortcut, followed by ReLU.

    Args:
        block_gates: two-element sequence read on every forward call. A falsy
            ``block_gates[0]`` skips conv1/bn1/relu1; a falsy ``block_gates[1]``
            skips conv2/bn2. The owner may toggle the entries after construction.
            NOTE(review): skipping a stage is only shape-safe when that stage
            does not change the channel count or stride.
        inplanes: input channels.
        planes: output channels of both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to form the shortcut
            when shapes differ; identity shortcut when None.
    """
    expansion = 1  # output channels = planes * expansion

    def __init__(self, block_gates, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.block_gates = block_gates
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=False)  # To enable layer removal inplace must be False
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=False)
        self.downsample = downsample
        self.stride = stride
        # Residual add kept as a named submodule rather than a bare `+`.
        self.residual_eltwiseadd = EltwiseAdd()

    def forward(self, x):
        residual = out = x

        if self.block_gates[0]:
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu1(out)

        if self.block_gates[1]:
            out = self.conv2(out)
            out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out = self.residual_eltwiseadd(residual, out)
        out = self.relu2(out)

        return out


class ResNetCifar(nn.Module):
    """ResNet for CIFAR-sized images with per-block layer gates.

    A 3x3 stem (16 channels) is followed by three stages of ``block`` modules with
    16, 32 and 64 planes; stages 2 and 3 use stride 2. Global average pooling and
    a linear layer produce the logits.

    ``layer_gates[stage][i]`` is a ``[bool, bool]`` list handed to the i-th block
    of that stage. The blocks keep references to these lists, so flipping an entry
    after construction changes which sub-layers that block runs (see
    ``BasicBlock.forward``).

    Args:
        block: block class with an ``expansion`` attribute and the constructor
            signature ``block(block_gates, inplanes, planes, stride=1, downsample=None)``.
        layers: number of blocks in each of the three stages.
        num_classes: size of the classifier output.
    """

    def __init__(self, block, layers, num_classes=NUM_CLASSES):
        # nn.Module.__init__ must run before any attribute is assigned so that
        # nn.Module.__setattr__ can register submodules and parameters.
        super(ResNetCifar, self).__init__()
        self.nlayers = 0  # NOTE(review): never updated in this notebook
        # One [conv1_gate, conv2_gate] pair per block, one list per stage; True = enabled.
        self.layer_gates = [[[True, True] for _ in range(layers[stage])]
                            for stage in range(3)]

        self.inplanes = 16  # channel count tracked by _make_layer while stages are built
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(self.layer_gates[0], block, 16, layers[0])
        self.layer2 = self._make_layer(self.layer_gates[1], block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(self.layer_gates[2], block, 64, layers[2], stride=2)
        # Adaptive pooling reduces any spatial size to 1x1. For 32x32 inputs the
        # stage-3 map is 8x8, so this equals AvgPool2d(8, stride=1) there, but it
        # no longer breaks the fc input size for other resolutions.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He init: N(0, 2 / n) with n = k_h * k_w * out_channels (fan_out).
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def _make_layer(self, layer_gates, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks with ``planes`` planes each.

        Only the first block may change stride or width; when it does, its
        shortcut is a 1x1 conv + BatchNorm projection. Updates ``self.inplanes``
        to the stage's output channel count.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(layer_gates[0], self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(layer_gates[i], self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of 3-channel images to logits of shape (batch, num_classes)."""
        x = self.relu(self.bn1(self.conv1(x)))

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)


def _resnet_cifar(depth, **kwargs):
    """Build a CIFAR ResNet of total ``depth`` from BasicBlocks.

    depth = 6n + 2, where n is the number of BasicBlocks in each of the three
    stages (20 -> 3, 32 -> 5, 44 -> 7, 56 -> 9). Extra keyword arguments are
    forwarded to ``ResNetCifar``.
    """
    blocks_per_stage = (depth - 2) // 6
    return ResNetCifar(BasicBlock, [blocks_per_stage] * 3, **kwargs)


def resnet20_cifar(**kwargs):
    """ResNet-20 for CIFAR: 3 BasicBlocks per stage."""
    return _resnet_cifar(20, **kwargs)


def resnet32_cifar(**kwargs):
    """ResNet-32 for CIFAR: 5 BasicBlocks per stage."""
    return _resnet_cifar(32, **kwargs)


def resnet44_cifar(**kwargs):
    """ResNet-44 for CIFAR: 7 BasicBlocks per stage."""
    return _resnet_cifar(44, **kwargs)


def resnet56_cifar(**kwargs):
    """ResNet-56 for CIFAR: 9 BasicBlocks per stage."""
    return _resnet_cifar(56, **kwargs)

In [None]:
# Seed torch's global RNG before the model is built so weight initialisation,
# DataLoader shuffling and the worker seeds used by the random augmentations are
# repeatable across runs. NOTE(review): cudnn.benchmark (next cell) can still
# select nondeterministic kernels.
SEED = 0
torch.manual_seed(SEED)

net = resnet56_cifar()
name = "resnet-56-distiller"  # prefix of checkpoint file names written by test()
epochs = 200

In [None]:
# Let cuDNN benchmark conv algorithms once and reuse the fastest; input size is fixed at 32x32.
torch.backends.cudnn.benchmark = True

# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

net = net.to(device)
print(device)

cuda


In [None]:
criterion = nn.CrossEntropyLoss()

# SGD hyper-parameters. 0.1 is the usual starting rate for CIFAR ResNets at
# batch size 128; noticeably higher rates such as 0.3 make the first epochs
# unstable (loss above ln(10), i.e. worse than uniform guessing).
LEARNING_RATE = 0.1
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-4
LR_STEP_EPOCHS = 45  # scheduler.step() is called once per epoch in the training loop
LR_GAMMA = 0.1       # lr is multiplied by this every LR_STEP_EPOCHS epochs

optimizer = torch.optim.SGD(net.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM,
                            weight_decay=WEIGHT_DECAY, nesterov=True)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEP_EPOCHS, gamma=LR_GAMMA)

# Training
def train(epoch):
    """Run one optimisation pass over ``trainLoader``.

    Reads the notebook globals ``net``, ``trainLoader``, ``optimizer``,
    ``criterion`` and ``device``; puts ``net`` in training mode.

    Returns:
        (mean of the per-batch losses, training accuracy in percent).
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    for images, labels in trainLoader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = net(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        predicted = logits.max(1)[1]
        n_seen += labels.size(0)
        n_correct += predicted.eq(labels).sum().item()

    mean_loss = running_loss / len(trainLoader)
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy

# Where test() writes checkpoints (on the Google Drive mounted at the top of the notebook).
CHECKPOINT_DIR = '/content/drive/MyDrive/Colab Notebooks/6787 Notebooks/models/'

best_acc = 0  # best test accuracy seen so far; updated by test()


def test(epoch, checkpoint_dir=CHECKPOINT_DIR):
    """Evaluate ``net`` on ``testLoader``; checkpoint it whenever accuracy improves.

    Reads the notebook globals ``net``, ``testLoader``, ``optimizer``, ``device``
    and ``name``, and updates the global ``best_acc``.

    Args:
        epoch: epoch index, stored in the checkpoint and used in its file name.
        checkpoint_dir: directory the checkpoint is written to; created if missing.

    Returns:
        Test accuracy in percent.
    """
    global best_acc
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testLoader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    acc = 100. * correct / total
    if acc > best_acc:
        print('Saving..')
        checkpoint = {
            'net': net.state_dict(),
            # The model in this notebook is built by resnet56_cifar(), not a ResNet-50.
            'arch': 'resnet56_cifar',
            # NOTE(review): a class object in the checkpoint requires a full-pickle
            # load (torch.load(..., weights_only=False)); a string name would be safer.
            'optimizer_type': type(optimizer),
            'optimizer_state_dict': optimizer.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # torch.save does not create parent directories.
        os.makedirs(checkpoint_dir, exist_ok=True)
        filename = '{}-{}-{}'.format(name, epoch, acc)
        torch.save(checkpoint, os.path.join(checkpoint_dir, filename))
        best_acc = acc
    return acc


# Main loop: one training pass, one evaluation (which may checkpoint), then an lr-schedule step.
for epoch in range(epochs):
    epoch_start = time.perf_counter()
    trainLoss, trainAcc = train(epoch)
    testAcc = test(epoch)
    # Stepped once per epoch, after optimizer.step() has run for this epoch.
    scheduler.step()
    elapsed = time.perf_counter() - epoch_start
    minutes, seconds = divmod(int(elapsed), 60)
    # Separate fields explicitly; the time is printed as integer minutes and zero-padded seconds.
    print("Trn L: {:.4f} Trn A: {:.4f} Test A: {:.4f} Time: {}:{:02d}".format(
        trainLoss, trainAcc, testAcc, minutes, seconds))


Epoch: 0
Saving..
Trn L: 4.4565 Trn A: 10.4120 Test A: 13.00000.0:39

Epoch: 1
Saving..
Trn L: 2.1199 Trn A: 18.1840 Test A: 21.49000.0:37

Epoch: 2
Saving..
Trn L: 1.9348 Trn A: 24.7220 Test A: 26.70000.0:37

Epoch: 3
Saving..
Trn L: 1.8077 Trn A: 30.0740 Test A: 34.23000.0:36

Epoch: 4
Trn L: 1.7198 Trn A: 34.8440 Test A: 32.24000.0:36

Epoch: 5
Saving..
Trn L: 1.6295 Trn A: 39.2820 Test A: 41.38000.0:37

Epoch: 6
Saving..
Trn L: 1.5575 Trn A: 42.6480 Test A: 47.43000.0:36

Epoch: 7
Saving..
Trn L: 1.4922 Trn A: 45.5680 Test A: 47.52000.0:37

Epoch: 8
Saving..
Trn L: 1.4519 Trn A: 47.5040 Test A: 49.91000.0:37

Epoch: 9
Saving..
Trn L: 1.4240 Trn A: 48.3720 Test A: 50.71000.0:37

Epoch: 10
Saving..
Trn L: 1.3824 Trn A: 50.3360 Test A: 51.21000.0:37

Epoch: 11
Trn L: 1.3104 Trn A: 53.2580 Test A: 38.86000.0:36

Epoch: 12
Saving..
Trn L: 1.3102 Trn A: 53.1520 Test A: 54.13000.0:37

Epoch: 13
Saving..
Trn L: 1.2204 Trn A: 56.7000 Test A: 55.63000.0:38

Epoch: 14
Trn L: 1.1681 Trn A: 58

KeyboardInterrupt: ignored