In [1]:
# Import Necessary Libraries
import copy
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.quantization
from thop import profile

import warnings
warnings.filterwarnings('ignore')

%matplotlib inline

In [2]:
# Global Variables
# Seeds torch's default random number generator so that the weight
# initialisation and torch.randn calls further down are repeatable.
# NOTE(review): torch.manual_seed returns a torch.Generator, so `seed` is bound
# to that generator object rather than to the integer 29592.
seed = torch.manual_seed(29592)  # set the seed for reproducibility

## Model

In [3]:
class SimpleConvModel(torch.nn.Module):
    """Baseline CNN: two 3x3 convolutions with ReLU followed by one linear layer.

    Both convolutions use stride 1 and padding 1, so they preserve the spatial
    size of their input. The linear layer expects 16 * 32 * 32 flattened
    features (a 16-channel 32x32 feature map) and produces 1000 outputs.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and construction order are part of the model's
        # observable behaviour (state_dict keys, printed repr), so keep them.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(16 * 32 * 32, 1000)

    def forward(self, x):
        """Run conv1 -> ReLU -> conv2 -> ReLU -> flatten -> fc and return the result."""
        hidden = self.conv1(x)
        hidden = self.relu(hidden)
        hidden = self.conv2(hidden)
        hidden = self.relu(hidden)
        # Collapse everything except the leading (batch) dimension.
        batch = hidden.size(0)
        flat = hidden.reshape(batch, -1)
        return self.fc(flat)

## LPCV Techniques

In [4]:
def quantization(model, dtype):
    """Return a dynamically quantized version of ``model``.

    Args:
        model: the float module to quantize.
        dtype: target quantized dtype, forwarded to ``quantize_dynamic``
            (the notebook passes ``torch.qint8``).

    Returns:
        The module returned by ``torch.quantization.quantize_dynamic``.
    """
    # Both layer types are requested, but in this notebook's printed structure
    # only the Linear layer is replaced (by DynamicQuantizedLinear); the Conv2d
    # layers are shown unchanged.
    target_layers = {nn.Conv2d, nn.Linear}
    quantized_model = torch.quantization.quantize_dynamic(
        model,
        qconfig_spec=target_layers,
        dtype=dtype,
    )
    return quantized_model

In [5]:
def pruning(model, pruning_perc=0.5, inplace=True):
    """Apply structured L2-norm pruning to every Conv2d and Linear layer.

    For each matching layer, ``prune.ln_structured`` is applied to ``weight``
    with ``n=2`` and ``dim=0``, and ``prune.remove`` then makes the result
    permanent so the layer is left with a plain ``weight`` parameter. Pruned
    slices are zeroed rather than removed: in this notebook the pruned model
    reports the same parameter count as the base model.

    Args:
        model: module whose Conv2d/Linear layers should be pruned.
        pruning_perc: forwarded as ``amount`` to ``prune.ln_structured``.
        inplace: when True (the default, matching the historical behaviour)
            ``model`` itself is modified and returned. When False a deep copy
            is pruned and returned, leaving ``model`` untouched.

    Returns:
        The pruned module (``model`` itself when ``inplace`` is True).
    """
    target = model if inplace else copy.deepcopy(model)
    for module in target.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            prune.ln_structured(module, name='weight', amount=pruning_perc, n=2, dim=0)
            # Fold the mask into the weight and drop the weight_orig/weight_mask
            # reparameterisation so the module looks like an ordinary layer again.
            prune.remove(module, 'weight')
    return target

In [6]:
## Revising the base model to include depthwise convolutions to demonstrate layer compression
class CompressedModel(torch.nn.Module):
    """Compact counterpart of the base model built from a depthwise-separable block.

    Structure: a 3x3 conv + BatchNorm + ReLU stem (3 -> 16 channels), one
    depthwise-separable block (16 -> 16 channels), global average pooling to
    1x1, and a linear layer mapping the 16 pooled features to 1000 outputs.
    Because of the adaptive pooling, the linear layer's input size does not
    depend on the spatial size of the input.
    """

    def __init__(self):
        super(CompressedModel, self).__init__()
        self.model = nn.Sequential(
            self._conv_bn_relu(3, 16, 1),
            self._depthwise_separable(16, 16, 1),
            nn.AdaptiveAvgPool2d(output_size=1),
        )
        self.fc = torch.nn.Linear(in_features=16*1*1, out_features=1000)

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch, stride):
        """Build a bias-free 3x3 convolution followed by BatchNorm and ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(num_features=out_ch),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _depthwise_separable(in_ch, out_ch, stride):
        """Build a depthwise 3x3 conv followed by a pointwise 1x1 conv.

        The depthwise stage keeps ``in_ch`` channels (one filter group per
        input channel); only the pointwise stage changes the channel count to
        ``out_ch``. Each stage is followed by BatchNorm and ReLU.
        """
        return nn.Sequential(
            # depthwise convolution: output channels must equal in_ch so that
            # the BatchNorm2d(in_ch) and the pointwise conv below receive the
            # channel count they were built for, also when in_ch != out_ch.
            nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=stride, padding=1, groups=in_ch, bias=False),
            nn.BatchNorm2d(in_ch),
            nn.ReLU(inplace=True),

            # pointwise convolution
            nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run the convolutional stack, flatten the pooled features, apply fc."""
        x = self.model(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x

## Utilities

In [7]:
def print_model_structure(model, type):
    """Print a header line naming the model variant, then the model itself.

    Args:
        model: object to print (its string form is written on its own line).
        type: label interpolated into the header, e.g. ``'base'``.
    """
    print(f'Here is the {type} version of this module:', model, sep='\n')

In [8]:
def print_size_of_model(model, label=""):
    """Serialize the model's state_dict, print its on-disk size and return it.

    The state_dict is saved into a private temporary directory that is removed
    when the function finishes, including when saving or measuring raises. The
    current working directory is never written to, so an existing ``temp.p``
    there is neither overwritten nor deleted.

    Args:
        model: module providing ``state_dict()``.
        label: text printed next to the size to identify the model.

    Returns:
        Size of the serialized state_dict in bytes (the printed value is this
        number divided by 1e6).
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        # NOTE(review): the file is still named temp.p in case the serialized
        # size depends on the file name; confirm against torch.save.
        checkpoint = os.path.join(scratch_dir, "temp.p")
        torch.save(model.state_dict(), checkpoint)
        size = os.path.getsize(checkpoint)
    print("model: ",label,' \t','Size (MB):', size/1e6)
    return size

In [9]:
def num_operations(model, input_size, label):
    """Profile one forward pass with thop and print FLOP and parameter counts.

    Args:
        model: module to profile.
        input_size: shape of the random input tensor to feed the model.
        label: name used for the model in the printed line.

    Note:
        In this notebook's recorded output for the dynamically quantized model
        no counter is registered for its quantized Linear layer, and the
        reported numbers are far smaller than for the base model, so counts
        for quantized models should not be compared directly.
    """
    # Build a random tensor of the requested shape instead of rebinding the
    # shape argument to a tensor.
    dummy_input = torch.randn(input_size)
    flops, params = profile(model, inputs=(dummy_input,))
    print(f"For model {label}, FLOPs: {flops}, Params: {params}")

In [10]:
def num_parameters(model, type):
    """Print the total number of elements across ``model.parameters()``.

    Args:
        model: object exposing ``parameters()`` whose items provide ``numel()``.
        type: label used for the model in the printed line.
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    print(f"Number of parameters for {type}: {total}")

## Main

In [11]:
# Baseline network; the quantized and pruned variants below are produced from it.
model = SimpleConvModel()
# One random 3-channel 32x32 input. The 32x32 size is what SimpleConvModel.fc
# is built for (in_features=16*32*32).
inputs = torch.randn(1,3,32,32)

### Base Model

In [12]:
print_model_structure(model,'base')

Here is the base version of this module:
SimpleConvModel(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu): ReLU()
  (fc): Linear(in_features=16384, out_features=1000, bias=True)
)


In [13]:
print_size_of_model(model,'base')

model:  base  	 Size (MB): 65.553445


65553445

In [14]:
%timeit model(inputs)

1.81 ms ± 6.21 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [15]:
num_operations(model, inputs.shape, 'base')

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
For model base, FLOPs: 19185664.0, Params: 16387768.0


### Quantized

In [16]:
# Dynamic int8 quantization of the base model. The structure printed in the next
# cell shows only `fc` replaced (DynamicQuantizedLinear); both Conv2d layers
# are listed unchanged.
quantized  = quantization(model, dtype=torch.qint8)

In [17]:
print_model_structure(quantized,'quantized')

Here is the quantized version of this module:
SimpleConvModel(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu): ReLU()
  (fc): DynamicQuantizedLinear(in_features=16384, out_features=1000, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)


In [18]:
print_size_of_model(quantized,'quantized')

model:  quantized  	 Size (MB): 16.402761


16402761

In [19]:
%timeit quantized(inputs)

759 µs ± 564 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [20]:
num_operations(quantized, inputs.shape, 'quantized')

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
For model quantized, FLOPs: 2801664.0, Params: 2768.0


### Pruning

In [21]:
# pruning() modifies the module it is handed and returns that same object, so
# prune a deep copy: `model` stays the unpruned baseline and the earlier
# base-model cells remain re-runnable with the same results.
pruned = pruning(copy.deepcopy(model), pruning_perc=0.2)

In [22]:
print_model_structure(pruned,'pruned')

Here is the pruned version of this module:
SimpleConvModel(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu): ReLU()
  (fc): Linear(in_features=16384, out_features=1000, bias=True)
)


In [23]:
print_size_of_model(pruned,'pruned')

model:  pruned  	 Size (MB): 65.553975


65553975

In [24]:
%timeit pruned(inputs)

1.84 ms ± 3.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [25]:
num_operations(pruned, inputs.shape, 'pruned')

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
For model pruned, FLOPs: 19185664.0, Params: 16387768.0


### Layer Compression

In [26]:
# Freshly constructed depthwise-separable architecture. Unlike `quantized` and
# `pruned`, it is not derived from `model` and shares no weights with it.
compressed = CompressedModel()

In [27]:
print_model_structure(compressed,'compressed')

Here is the compressed version of this module:
CompressedModel(
  (model): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (2): AdaptiveAvgPool2d(output_size=1)
  )
  (fc): Linear(in_features=16, out_features=1000, bias=True)
)


In [28]:
print_size_of_model(compressed,'compressed')

model:  compressed  	 Size (MB): 0.078755


78755

In [29]:
%timeit compressed(inputs)

679 µs ± 519 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [30]:
num_operations(compressed, inputs.shape, 'compressed')

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
For model compressed, FLOPs: 1080976.0, Params: 17928.0
