In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import  torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

In [2]:
# Preprocessing applied to every image: resize, centre-crop to 224x224,
# convert to a tensor and normalise each channel with the mean/std below.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [3]:
# NOTE: despite its name this variable holds CIFAR-10, not ImageNet; the name
# is kept because later cells refer to it.
# The dataset root can be overridden with the DATASET_ROOT environment
# variable so the notebook is not tied to one machine's drive layout; the
# original location remains the default.
imagenet_data = torchvision.datasets.CIFAR10(
    os.environ.get('DATASET_ROOT', 'D:\\datasets\\'),
    download=True,
    transform=transform,
)

Files already downloaded and verified


In [4]:
# Fixed-order (unshuffled) batches of two images, loaded by four workers.
dataloader = torch.utils.data.DataLoader(
    imagenet_data,
    batch_size=2,
    shuffle=False,
    num_workers=4,
)

In [5]:
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    The padding is set equal to ``dilation``, so with ``stride=1`` the
    spatial size of the input is preserved.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d`` (no padding)."""
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut.

    The main branch is conv -> norm -> ReLU -> conv -> norm. Its output is
    added to the shortcut (the input itself, or ``downsample(input)`` when a
    ``downsample`` module is given) and passed through a final ReLU.

    Args:
        inplanes: channels of the incoming tensor.
        planes: channels produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input before the addition.
        groups, base_width: accepted for signature compatibility; only
            ``groups=1`` and ``base_width=64`` are allowed.
        dilation: must not exceed 1.
        norm_layer: normalisation layer factory; ``nn.BatchNorm2d`` if None.

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """
    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # When stride != 1 the first convolution shrinks the feature map; the
        # downsample module has to do the same so the shapes agree at the sum.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))

        shortcut = x if self.downsample is None else self.downsample(x)

        branch += shortcut
        return self.relu(branch)

class ResNet(nn.Module):
    """ResNet assembled from ``BasicBlock`` with a fixed [3, 4, 6, 3] layout.

    Structure: 7x7 stride-2 convolution, norm, ReLU, 3x3 stride-2 max-pool,
    four residual stages (64, 128, 256, 512 planes), global average pooling
    and a linear classifier with ``num_classes`` outputs.

    Args:
        num_classes: size of the classifier output.
        zero_init_residual: if True, the weight of the second norm layer of
            every block is initialised to zero.
        groups, width_per_group: forwarded to every block.
        replace_stride_with_dilation: None or three booleans; a True entry
            makes the corresponding stage (2, 3, 4) use stride 1 and multiply
            the running dilation by 2 instead.
        norm_layer: normalisation layer factory; ``nn.BatchNorm2d`` if None.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have exactly
            three elements.
    """

    def __init__(self, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):

        block = BasicBlock
        layers = [3, 4, 6, 3]

        super().__init__()
        self._norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        norm_layer = self._norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # One flag per strided stage: True swaps that stage's stride of 2
            # for a dilated convolution.
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        # Stem.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages; stages 2-4 halve the resolution unless dilated.
        self.layer1 = self._make_layer(block, 64, layers[0])
        for idx, (attr, width) in enumerate((('layer2', 128), ('layer3', 256), ('layer4', 512))):
            stage = self._make_layer(block, width, layers[idx + 1], stride=2,
                                     dilate=replace_stride_with_dilation[idx])
            setattr(self, attr, stage)

        # Head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        # Zero-initialising the last norm layer of each residual branch makes
        # every block start out as an identity mapping; reported to improve
        # accuracy by 0.2~0.3% in https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            residual_blocks = (m for m in self.modules() if isinstance(m, BasicBlock))
            for blk in residual_blocks:
                nn.init.constant_(blk.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        Only the first block receives ``stride`` and, when the shape changes,
        a 1x1-conv + norm projection for its shortcut. ``self.inplanes`` (and
        ``self.dilation`` when ``dilate`` is True) are updated as a side effect.
        """
        norm_layer = self._norm_layer
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1

        out_channels = planes * block.expansion
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_channels, stride),
                norm_layer(out_channels),
            )
        else:
            downsample = None

        head = block(self.inplanes, planes, stride, downsample, self.groups,
                     self.base_width, previous_dilation, norm_layer)
        self.inplanes = out_channels
        tail = [
            block(self.inplanes, planes, groups=self.groups,
                  base_width=self.base_width, dilation=self.dilation,
                  norm_layer=norm_layer)
            for _ in range(1, blocks)
        ]

        return nn.Sequential(head, *tail)

    def _forward(self, x):
        """Run stem, the four stages, pooling and the classifier; return logits."""
        pipeline = (
            self.conv1, self.bn1, self.relu, self.maxpool,
            self.layer1, self.layer2, self.layer3, self.layer4,
            self.avgpool,
        )
        for module in pipeline:
            x = module(x)

        x = torch.flatten(x, 1)
        return self.fc(x)

    # Keep _forward reachable from subclasses that override forward.
    forward = _forward


In [6]:
def measure_time(x, layer):
    """Run ``layer(x)`` and measure how long the call takes.

    The timing method is chosen from the device of ``x`` itself rather than
    from a notebook-level ``device`` global, so the function works regardless
    of which cells have been executed.

    Args:
        x: input passed to ``layer``; a tensor on a CUDA device selects
            CUDA-event timing, anything else selects wall-clock timing.
        layer: callable applied to ``x``.

    Returns:
        ``(out, elapsed_ms)`` where ``out`` is ``layer(x)`` and ``elapsed_ms``
        is the elapsed time in milliseconds.
    """
    if getattr(x, 'is_cuda', False):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        out = layer(x)
        end.record()
        # Kernels are launched asynchronously; wait for them to finish before
        # reading the elapsed time between the two events.
        torch.cuda.synchronize()
        return out, start.elapsed_time(end)

    # perf_counter is monotonic and has a finer resolution than time.time().
    start = time.perf_counter()
    out = layer(x)
    return out, (time.perf_counter() - start) * 1000
    

class modifiedBasicBlock(nn.Module):
    """Residual block that also reports how long its convolutions take.

    The layers and the computation match a two-convolution residual block;
    the difference is that each convolution is run through ``measure_time``
    and ``forward`` returns the timings alongside the activation.

    Args:
        inplanes: channels of the incoming tensor.
        planes: channels produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional iterable module (e.g. ``nn.Sequential``) applied
            to the input before the residual addition.
        groups, base_width: only ``groups=1`` and ``base_width=64`` are allowed.
        dilation: must not exceed 1.
        norm_layer: normalisation layer factory; ``nn.BatchNorm2d`` if None.

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """
    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # When stride != 1 the first convolution shrinks the feature map; the
        # downsample module has to do the same so the shapes agree at the sum.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``(out, time_arr)``.

        ``time_arr`` lists the measured times, as reported by
        ``measure_time``, for conv1, conv2 and then every ``nn.Conv2d`` found
        inside ``downsample`` (if any), in that order.
        """
        out, conv1_ms = measure_time(x, self.conv1)
        out = self.relu(self.bn1(out))

        out, conv2_ms = measure_time(out, self.conv2)
        out = self.bn2(out)

        time_arr = [conv1_ms, conv2_ms]

        identity = x
        if self.downsample is not None:
            # Walk the shortcut layer by layer so only its convolutions are timed.
            for sub_layer in self.downsample:
                if not isinstance(sub_layer, nn.Conv2d):
                    identity = sub_layer(identity)
                    continue
                identity, shortcut_ms = measure_time(identity, sub_layer)
                time_arr.append(shortcut_ms)

        out += identity
        return self.relu(out), time_arr

class modifiedResNet(nn.Module):
    """ResNet ([3, 4, 6, 3] ``modifiedBasicBlock`` stages) that times its convolutions.

    The layer names and shapes follow the plain ResNet layout (conv1, bn1,
    layer1..layer4, fc). ``forward`` returns the logits together with a list
    of per-convolution timings gathered through ``measure_time``.

    Args:
        num_classes: size of the classifier output.
        zero_init_residual: if True, the weight of the second norm layer of
            every block is initialised to zero.
        groups, width_per_group: forwarded to every block.
        replace_stride_with_dilation: None or three booleans; a True entry
            makes the corresponding stage (2, 3, 4) use stride 1 and multiply
            the running dilation by 2 instead.
        norm_layer: normalisation layer factory; ``nn.BatchNorm2d`` if None.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have exactly
            three elements.
    """

    def __init__(self, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):

        block = modifiedBasicBlock
        layers = [3, 4, 6, 3]

        super().__init__()
        self._norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        norm_layer = self._norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # One flag per strided stage: True swaps that stage's stride of 2
            # for a dilated convolution.
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        # Stem.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages; stages 2-4 halve the resolution unless dilated.
        self.layer1 = self._make_layer(block, 64, layers[0])
        strided_stages = (('layer2', 128), ('layer3', 256), ('layer4', 512))
        for idx, (attr, width) in enumerate(strided_stages):
            stage = self._make_layer(block, width, layers[idx + 1], stride=2,
                                     dilate=replace_stride_with_dilation[idx])
            setattr(self, attr, stage)

        # Head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        # Zero-initialising the last norm layer of each residual branch makes
        # every block start out as an identity mapping; reported to improve
        # accuracy by 0.2~0.3% in https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            timed_blocks = (m for m in self.modules() if isinstance(m, modifiedBasicBlock))
            for blk in timed_blocks:
                nn.init.constant_(blk.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        Only the first block receives ``stride`` and, when the shape changes,
        a 1x1-conv + norm projection for its shortcut. ``self.inplanes`` (and
        ``self.dilation`` when ``dilate`` is True) are updated as a side effect.
        """
        norm_layer = self._norm_layer
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1

        out_channels = planes * block.expansion
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_channels, stride),
                norm_layer(out_channels),
            )
        else:
            downsample = None

        head = block(self.inplanes, planes, stride, downsample, self.groups,
                     self.base_width, previous_dilation, norm_layer)
        self.inplanes = out_channels
        tail = [
            block(self.inplanes, planes, groups=self.groups,
                  base_width=self.base_width, dilation=self.dilation,
                  norm_layer=norm_layer)
            for _ in range(1, blocks)
        ]

        return nn.Sequential(head, *tail)

    def _forward(self, x):
        """Return ``(logits, time_arr)``.

        ``time_arr`` starts with the timing of the stem convolution and is
        followed by the timings each block reports, in execution order.
        """
        x, stem_ms = measure_time(x, self.conv1)
        time_arr = [stem_ms]
        x = self.maxpool(self.relu(self.bn1(x)))

        # Blocks return (activation, timings), so the stages cannot be called
        # as plain nn.Sequential modules; iterate over their blocks instead.
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            for blk in stage:
                x, block_times = blk(x)
                time_arr.extend(block_times)

        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x), time_arr

    # Keep _forward reachable from subclasses that override forward.
    forward = _forward


In [7]:
# Reference torchvision ResNet-34 with pretrained weights; its state_dict is
# the source of the weights loaded into the instrumented model.
resnet = models.resnet34(pretrained=True)

In [11]:
# Snapshot the pretrained weights so they can be loaded into the
# timing-instrumented model, whose parameter names match the reference net.
best_model_wts = copy.deepcopy(resnet.state_dict())
model = modifiedResNet()
# Use the GPU when one is available, otherwise fall back to the CPU instead of
# failing on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.load_state_dict(best_model_wts)

<All keys matched successfully>

In [12]:
num_conv = 36        # number of timings modifiedResNet returns per forward pass
max_samples = 100    # stop once this many images have been evaluated

since = time.time()
model.eval()

timed_model = isinstance(model, modifiedResNet)
on_cuda = device.type == 'cuda'

if timed_model:
    tot_times = np.zeros(num_conv)
corrects = 0
done = 0             # batches actually processed
samples = 0          # images actually processed
end_to_end_time = 0.0

with torch.no_grad():
    with torch.autograd.profiler.profile(use_cuda=on_cuda) as prof:
        for inputs, labels in dataloader:
            # Check the limit before doing any work so that `done` and
            # `samples` only ever count batches that were really evaluated.
            if samples >= max_samples:
                break

            # End-to-end timing covers the host->device copy and the forward
            # pass. CUDA events are only usable on a GPU; use the wall clock
            # otherwise.
            if on_cuda:
                start_main = torch.cuda.Event(enable_timing=True)
                end_main = torch.cuda.Event(enable_timing=True)
                start_main.record()
            else:
                start_wall = time.perf_counter()

            inputs = inputs.to(device)
            labels = labels.to(device)

            if timed_model:
                outputs, exec_times = model(inputs)
                tot_times += exec_times
            else:
                outputs = model(inputs)

            if on_cuda:
                end_main.record()
                torch.cuda.synchronize()
                end_to_end_time += start_main.elapsed_time(end_main)
            else:
                end_to_end_time += (time.perf_counter() - start_wall) * 1000

            _, preds = torch.max(outputs, 1)
            corrects += torch.sum(preds == labels).item()

            done += 1
            samples += inputs.size(0)
            print(samples, end='\r')

if done == 0:
    raise RuntimeError('dataloader produced no batches; nothing was measured')

# NOTE: the model keeps its 1000-way pretrained head while the labels are
# CIFAR-10 class ids, so this number is not a meaningful accuracy (the
# recorded run printed 0.0000). It is kept only as a sanity print.
acc = corrects / samples
print('Acc: {:.4f}'.format(acc))

if timed_model:
    print(tot_times)              # summed per-convolution times over all batches
    tot_times = tot_times / done
    print(tot_times)              # mean per-convolution time per batch

time_elapsed = time.time() - since
print('Total time taken = {} seconds'.format(time_elapsed))

end_to_end_time = end_to_end_time / done
print('Avg End to End = {} ms'.format(end_to_end_time))



Acc: 0.0000
[783.56583387  20.00281602  20.610816    20.14288008  19.33350405
  18.20691189  18.13279995  18.18732804  17.97503996  15.74630402
  18.22339204  17.66812804  17.21609592  17.68863994  16.95247996
  17.65817595  19.35884792  20.35078397  14.580384    20.57139194
  20.20355195  19.93833607  20.06607997  20.01260796  20.10051203
  20.21478406  20.41337606  20.27904004  20.89686397  26.34995201
  35.0767681   14.82057597  33.50294399  32.64422393  32.77196777
  34.92768019]
[15.67131668  0.40005632  0.41221632  0.4028576   0.38667008  0.36413824
  0.362656    0.36374656  0.3595008   0.31492608  0.36446784  0.35336256
  0.34432192  0.3537728   0.3390496   0.35316352  0.38717696  0.40701568
  0.29160768  0.41142784  0.40407104  0.39876672  0.4013216   0.40025216
  0.40201024  0.40429568  0.40826752  0.4055808   0.41793728  0.52699904
  0.70153536  0.29641152  0.67005888  0.65288448  0.65543936  0.6985536 ]
Total time taken = 10.766541957855225 seconds
Avg End to End = 57.438330

In [None]:
# Print the profiler's per-operator averages, using the `prof` object
# captured around the evaluation loop.
print(prof.key_averages().table())

In [None]:
# NOTE(review): this cell was a bare tuple of undefined names and would raise
# NameError when run; it looks like a list of candidate models, kept as a note.
# TODO confirm intent: alexnet, vgg16, vgg19, lenet5, zfnet, resnet34