In [1]:
# Some standard imports
import io
import numpy as np

from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx
import torchvision.transforms as transforms

import torchvision.datasets as datasets
import netron
import time

### part 1.   The cost of doing a convolution
This is the detail in the paper where we do the loop restructuring to improve the locality of the computation and then reduce the computation to low-level primitives.   

The formula below is for a convolution.  

$$
Output(i, k, j) ~=~ \sum^{C_{in}}_{t=1} \sum^3_{u=1}\sum^3_{v=1}W(i,t, u,v) \cdot Input(t,k+u-2,j+v-2) 
$$

Here are some of the sample parameters used in the table.

In [3]:
# Sample problem size: Cout output channels, Cin input channels, width x height image.
Cout = 4
Cin = 64
width = 10
height = 10

# Result buffer, one width x height plane per output channel.
Output = torch.zeros(Cout, width, height)

# The image itself is filled with 2s. The padded Input starts as all 1s, so the
# one-pixel border keeps the value 1 once the image is copied into the interior.
Image = torch.full((Cin, width, height), 2.0)
Input = torch.ones(Cin, width + 2, height + 2)
Input[:, 1:width + 1, 1:height + 1] = Image

# 3x3 filters for every (output channel, input channel) pair, all weights 1.
W = torch.ones(Cout, Cin, 3, 3)

In [243]:
import time

In [4]:
# Baseline: the direct six-deep loop nest, one scalar multiply-add per iteration.
# The loop accumulates with +=, so start from a fresh zero buffer; otherwise
# re-running this cell would add a second copy of the result onto the first.
Output = torch.zeros(Cout, width, height)
t0 = time.time()
for i in range(0,Cout):
    for k in range(0, width):
        for j in range(0, height):
            for t in range(0,Cin):
                for u in range(0,3):
                    for v in range(0,3):
                        # Input is padded by one pixel, so offsets u, v in 0..2
                        # cover the 3x3 neighbourhood of image position (k, j).
                        Output[i,k,j]+= W[i,t,u,v]*Input[t, k+u, j+v]
elapse = time.time()-t0
print(elapse)

7.816189765930176


The first transformation pulls out the inner three loops and does the sum locally rather than writing it back to memory in the array each time.

In [5]:
def kernel(W,Input, i, k, j):
    """Compute one output element of the convolution with scalar multiply-adds.

    The sum is accumulated in a local variable instead of being written back
    to the output array on every iteration.

    Args:
        W: 4-D weight tensor indexed as W[out_channel, in_channel, u, v].
        Input: 3-D tensor indexed as Input[in_channel, row, col]; it must be
            large enough that row k+u and column j+v exist for every filter tap.
        i: output-channel index selecting the filter W[i].
        k: row offset of the window in Input.
        j: column offset of the window in Input.

    Returns:
        The accumulated sum of W[i, t, u, v] * Input[t, k+u, j+v] over all
        t, u, v. This is a 0-dim tensor, or the float 0.0 if the filter has
        no elements.
    """
    # Take the loop bounds from the weights that were passed in rather than
    # from a notebook-global Cin and a hard-coded 3x3 filter size.
    _, n_in, k_rows, k_cols = W.shape
    Outv = 0.0
    for t in range(n_in):
        for u in range(k_rows):
            for v in range(k_cols):
                Outv += W[i, t, u, v] * Input[t, k + u, j + v]
    return Outv
# Time the variant that accumulates each output element in a local scalar
# (kernel) and stores it into Output once per element.
t0 = time.time()
for i in range(Cout):
    for k in range(width):
        for j in range(height):
            Output[i, k, j] = kernel(W, Input, i, k, j)
elapse = time.time() - t0
print(elapse)

4.252971172332764


Second transformation: in the kernel function, move the t loop innermost and replace it with a dot product.

In [7]:
def kernel2(W,Input, i, k, j):
    """Compute one output element using a channel-wise dot product per filter tap.

    For each of the 3x3 taps (u, v), the weights W[i, :, u, v] are dotted with
    the input values Input[:, k+u, j+v]; the nine dot products are summed.

    Args:
        W: weight tensor indexed as W[out_channel, in_channel, u, v].
        Input: tensor indexed as Input[in_channel, row, col].
        i: output-channel index.
        k: row offset of the 3x3 window in Input.
        j: column offset of the 3x3 window in Input.

    Returns:
        0-dim tensor holding the sum of the nine dot products.
    """
    tap_products = (
        torch.dot(W[i, :, u, v], Input[:, k + u, j + v])
        for u in range(3)
        for v in range(3)
    )
    # Start from 0.0 so the additions happen in the same order as a running total.
    return sum(tap_products, 0.0)
# Time the variant whose channel loop has been replaced by torch.dot (kernel2).
t0 = time.time()
for i in range(Cout):
    for k in range(width):
        for j in range(height):
            Output[i, k, j] = kernel2(W, Input, i, k, j)
elapse = time.time() - t0
print(elapse)

0.10005521774291992


Now recognize that the function is just a sum of pointwise multiplies.

In [9]:
def kernel3(W,Input):
    """Return the sum of the elementwise product of W and Input.

    The two tensors are multiplied with the usual broadcasting rules and every
    element of the product is added up into a 0-dim tensor.
    """
    product = W * Input
    return product.sum()

# Time the fully vectorised per-element kernel: one multiply-and-sum over the
# whole (Cin, 3, 3) window for each output element.
t0 = time.time()
for i in range(0,Cout):
    for k in range(0, width):
        for j in range(0, height):
            # The window must span ALL input channels (':'), matching the
            # shape of W[i]. Indexing the channel axis with the output-channel
            # index i would pick a single channel and broadcast it, which is
            # only correct when every input channel holds identical data.
            Output[i,k,j] = kernel3(W[i,:,:,:],Input[:,k:k+3, j:j+3])

elapse = time.time()-t0

print(elapse)

0.02820301055908203


Now do it really fast.   10 invocations of the Conv2d function.

In [11]:
# Library convolution with the same geometry as the hand-written loops:
# 3x3 filter, one pixel of zero padding, no bias. The remaining arguments
# (stride, dilation, groups, padding_mode) are left at their default values.
q = nn.Conv2d(Cin, Cout, kernel_size=3, padding=1, bias=False)
Im = torch.zeros(1, Cin, width, height)

# Time ten forward passes, keeping every result in outl.
t0 = time.time()
outl = []
for i in range(10):
    outv = q(Im)
    outl.append(outv)
    outv = outv + 1

elapse = time.time() - t0
print(elapse)


0.007052421569824219


In [17]:
from PIL import Image
import torch.optim as optim
from tqdm import tqdm
from torchvision import transforms
import torch.nn.functional as F
import torch.nn as nn
import torchvision.datasets as dt
import torch
import os 

The following is a demonstration of generating ONNX from a simplified version of a residual network.

see https://towardsdatascience.com/residual-network-implementing-resnet-a7da63c7b278
for a more complete version

In [24]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages with a skip connection added before the final ReLU.

    The skip path is an empty nn.Sequential (identity) unless the block changes
    the spatial stride or the channel count, in which case it is a strided 1x1
    convolution followed by batch norm so both paths have matching shapes.
    """

    # Output channels are `expansion * outchannel`; this basic block does not widen.
    expansion = 1

    def __init__(self, inchannel, outchannel, stride=1):
        """Build the main path and, when shapes differ, the projection skip path.

        Args:
            inchannel: number of channels of the block input.
            outchannel: number of channels produced by both conv stages.
            stride: stride of the first 3x3 conv (and of the 1x1 skip conv).
        """
        super().__init__()
        widened = self.expansion * outchannel

        # First stage may downsample via `stride`; second stage keeps the size.
        self.conv1 = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
        )

        shapes_match = stride == 1 and inchannel == widened
        if shapes_match:
            # Empty Sequential passes its input through unchanged.
            self.skip = nn.Sequential()
        else:
            self.skip = nn.Sequential(
                nn.Conv2d(inchannel, widened, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(widened),
            )

    def forward(self, X):
        """Return relu(conv2(relu(conv1(X))) + skip(X))."""
        main = self.conv2(F.relu(self.conv1(X)))
        main += self.skip(X)
        return F.relu(main)


In [25]:
class Model(nn.Module):
    """Small residual network: a 7x7 stem, four stages of two blocks each, and a linear head.

    The stem maps 3 input channels to 64. The stages produce 64, 128, 256 and
    512 (times `block.expansion`) channels; stages 2-4 are built with stride 2.
    The final feature map is averaged over its whole spatial extent and fed to
    a fully connected layer with `num_classes` outputs.
    """

    def __init__(self, ResidualBlock, num_classes):
        """Assemble the network.

        Args:
            ResidualBlock: block class, called as block(in_channels, channels,
                stride) and exposing an integer `expansion` class attribute.
            num_classes: size of the output of the final linear layer.
        """
        super().__init__()
        # Running channel count fed into the next block; updated by make_layer.
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3, bias=False),
            nn.BatchNorm2d(64),
        )
        self.layer1 = self.make_layer(ResidualBlock, 64,  2, stride=1)
        self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)
        self.fc = nn.Linear(512*ResidualBlock.expansion, num_classes)

    def make_layer(self, block, channels, num_blocks, stride):
        """Build one stage of `num_blocks` blocks as an nn.Sequential.

        Only the first block uses `stride`; the remaining blocks use stride 1.
        `self.inchannel` is advanced to the stage's output channel count.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        # Separate loop name so the `stride` argument is not shadowed.
        for block_stride in strides:
            layers.append(block(self.inchannel, channels, block_stride))
            self.inchannel = channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on a batch x of shape (N, 3, H, W); returns (N, num_classes)."""
        out = F.relu(self.conv1(x))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Pool over the full (height, width) of the feature map. Using only the
        # width as a square kernel breaks as soon as the input is not square.
        pool_size = (int(out.size(2)), int(out.size(3)))
        out = F.avg_pool2d(out, pool_size)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out


In [26]:
# Build the two-class network and switch it to inference mode (eval() returns
# the module itself); the bare name on the last line displays its structure.
resnet = Model(ResidualBlock, num_classes=2).eval()
resnet

Model(
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (layer1): Sequential(
    (0): ResidualBlock(
      (conv1): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv2): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (skip): Sequential()
    )
    (1): ResidualBlock(
      (conv1): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv2): Sequential(
  

In [29]:
# Sample input: a batch of one all-zero 3-channel 32x32 frame.
inframe = torch.zeros((1, 3, 32, 32))

In [30]:
# Smoke-test the forward pass on the sample frame; the bare expression makes
# the notebook display the returned tensor.
resnet(inframe)

tensor([[ 0.0045, -0.0049]], grad_fn=<AddmmBackward>)

In [32]:
# Export the network to ONNX using the sample frame as the example input.
# verbose=True additionally prints a description of the exported graph.
torch.onnx.export(
    resnet,
    inframe,
    "resnet.onnx",
    verbose=True,
    input_names=['input'],
)



graph(%input : Float(1, 3, 32, 32),
      %conv1.0.weight : Float(64, 3, 7, 7),
      %conv1.1.weight : Float(64),
      %conv1.1.bias : Float(64),
      %conv1.1.running_mean : Float(64),
      %conv1.1.running_var : Float(64),
      %conv1.1.num_batches_tracked : Long(),
      %layer1.0.conv1.0.weight : Float(64, 64, 3, 3),
      %layer1.0.conv1.1.weight : Float(64),
      %layer1.0.conv1.1.bias : Float(64),
      %layer1.0.conv1.1.running_mean : Float(64),
      %layer1.0.conv1.1.running_var : Float(64),
      %layer1.0.conv1.1.num_batches_tracked : Long(),
      %layer1.0.conv2.0.weight : Float(64, 64, 3, 3),
      %layer1.0.conv2.1.weight : Float(64),
      %layer1.0.conv2.1.bias : Float(64),
      %layer1.0.conv2.1.running_mean : Float(64),
      %layer1.0.conv2.1.running_var : Float(64),
      %layer1.0.conv2.1.num_batches_tracked : Long(),
      %layer1.1.conv1.0.weight : Float(64, 64, 3, 3),
      %layer1.1.conv1.1.weight : Float(64),
      %layer1.1.conv1.1.bias : Float(64),





In [33]:
# Path of the ONNX file written by the export cell.
# NOTE(review): the name says "optimized", but no optimization step is visible
# in this notebook — confirm or rename.
optimized_model_path = './resnet.onnx'
# Open the file in the Netron model viewer (it reports a local URL it is serving on).
netron.start(optimized_model_path)

Serving './resnet.onnx' at http://localhost:8080
