In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

Simple model architecture for solving MNIST, that uses 2 conv and 2 fully connected layers

In [3]:
class Net(nn.Module):
    def __init__(self, mnist=True):
      
        super(Net, self).__init__()
          
        #conv1
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)
        
        #conv2
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
        
        #fully connected layer 1
        self.fc1 = nn.Linear(in_features=4*4*50, out_features=500)
        
        #fully connected layer 2
        self.fc2 = nn.Linear(in_features=500, out_features=10)

    #forward function  
    def forward(self, x):
        
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)   
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return F.log_softmax(x, dim=1)
    
    

Simple training script

In [4]:
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % args["log_interval"] == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            
def testQuant(model, test_loader, quant=False, stats=None):
    device = 'cpu'
    
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            if quant:
                output = quantForward(model, data, stats)
            else:
                output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [5]:
def main():
 
    batch_size = 64
    test_batch_size = 64
    epochs = 3
    lr = 0.01
    momentum = 0.5
    seed = 1
    log_interval = 500
    save_model = False
    no_cuda = False
    
    use_cuda = not no_cuda and torch.cuda.is_available()

    torch.manual_seed(seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True, **kwargs)
    
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=test_batch_size, shuffle=True, **kwargs)
    
  
    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    args = {}
    args["log_interval"] = log_interval
    for epoch in range(1, epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        testQuant(model, test_loader)

    if (save_model):
        torch.save(model.state_dict(),"mnist_cnn.pt")
    
    return model

Using downloaded and verified file: ./data/ILSVRC2012_devkit_t12.tar.gz
Extracting ./data/ILSVRC2012_devkit_t12.tar.gz to /var/folders/th/1db1p8gj5jx_vgvgmsfdx57h0000gn/T/tmpulg8rl6x
Downloading http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar to ./data/ILSVRC2012_img_val.tar


HTTPError: HTTP Error 404: Not Found

In [6]:
model = main()


Test set: Average loss: 0.1017, Accuracy: 9661/10000 (97%)


Test set: Average loss: 0.0610, Accuracy: 9826/10000 (98%)


Test set: Average loss: 0.0550, Accuracy: 9815/10000 (98%)



In [89]:
model

Parameter containing:
tensor([[[[ 9.1196e-02, -1.1278e-01, -1.0681e-01, -3.7622e-03, -2.4457e-01],
          [ 1.6207e-01,  2.9669e-02,  1.3495e-01,  4.6878e-02, -9.9357e-03],
          [ 1.0880e-01,  1.1285e-01,  1.8251e-01,  2.5795e-02,  5.4897e-02],
          [-3.3242e-02,  5.4917e-02,  4.2395e-02,  2.3614e-01,  1.4821e-01],
          [-1.3611e-01, -1.7517e-01, -1.0584e-01, -1.2284e-01, -2.2899e-02]]],


        [[[ 7.0385e-02,  2.5570e-01,  1.9098e-01, -2.3319e-01,  4.3958e-02],
          [ 1.3056e-01,  3.2718e-01,  1.7801e-01, -2.5750e-01, -2.8356e-01],
          [ 8.2280e-03,  3.1927e-01, -1.4386e-02, -2.8260e-05, -1.8262e-01],
          [ 2.8211e-01,  1.7638e-02,  1.5929e-01, -4.8542e-02, -1.4374e-01],
          [ 1.8625e-01, -1.9564e-02,  8.1077e-02, -6.1486e-02, -1.8115e-02]]],


        [[[-2.3365e-01, -1.2033e-01,  1.2521e-01,  3.4815e-02,  3.0283e-01],
          [ 1.2494e-01,  7.7094e-03, -3.3054e-02,  2.7073e-01,  1.6790e-01],
          [-1.2619e-01,  1.8656e-01,  2.3734e-

# A Whitepaper: Uniform Affine Quantizer

Formula: <br>

<center><i>x_int = round(x/scale) + zero_point</i></center>
<center><i>x_Q = clamp(0, qmax-1, x_int)</i></center>

De-quantization, <center><i>x = x_Q * scale<i></center>

In [52]:
from collections import namedtuple
QTensor = namedtuple('QTensor', ['tensor', 'scale', 'zero_point'])

In [53]:
def calcScaleZeroPoint(min_val, max_val,num_bits=8):
    # Calc Scale and zero point of next 
    qmin = 0.
    qmax = 2.**num_bits - 1.

    scale = (max_val - min_val) / (qmax - qmin)

    initial_zero_point = qmin - min_val / scale

    zero_point = 0
    if initial_zero_point < qmin:
        zero_point = qmin
    elif initial_zero_point > qmax:
        zero_point = qmax
    else:
        zero_point = initial_zero_point

    zero_point = int(zero_point)

    return scale, zero_point

def quantize_tensor(x, num_bits=8, min_val=None, max_val=None):
    
    if not min_val and not max_val: 
        min_val, max_val = x.min(), x.max()

    qmin = 0.
    qmax = 2.**num_bits - 1.

    scale, zero_point = calcScaleZeroPoint(min_val, max_val, num_bits)
    q_x = zero_point + x / scale
    q_x.clamp_(qmin, qmax).round_()
    q_x = q_x.round().byte()
    
    return QTensor(tensor=q_x, scale=scale, zero_point=zero_point)

def dequantize_tensor(q_x):
    return q_x.scale * (q_x.tensor.float() - q_x.zero_point)

In [170]:
def quantizeLayer(x, layer, stat, scale_x, zp_x):
    # for both conv and linear layers

    # cache old values
    W = layer.weight.data
    B = layer.bias.data

    # quantise weights, activations are already quantised
    w = quantize_tensor(layer.weight.data) 
    b = quantize_tensor(layer.bias.data)

    layer.weight.data = w.tensor.float()
    layer.bias.data = b.tensor.float()

    # This is Quantisation Artihmetic
    scale_w = w.scale
    zp_w = w.zero_point
    scale_b = b.scale
    zp_b = b.zero_point
  
    scale_next, zero_point_next = calcScaleZeroPoint(min_val=stat['min'], max_val=stat['max'])

    # Preparing input by shifting
    X = scale_x*(x.float() - zp_x)
    layer.weight.data = scale_w*(layer.weight.data - zp_w)
    #layer.weight.data = scale_w*(layer.weight.data - zp_w)
    layer.bias.data = scale_b*(layer.bias.data - zp_b)

    # All int computation
    x = (layer(X)/ scale_next).round() + zero_point_next 

    # Perform relu too
    x = F.relu(x)

    # Reset weights for next forward pass
    layer.weight.data = W
    layer.bias.data = B
  
    return x, scale_next, zero_point_next

### Get Max and Min Stats for Quantising Activations of Network.
This is done by running the network with around 1000 examples and getting the average min and max activation values before and after each layer.

In [198]:
# Get Min and max of x tensor, and stores it
def updateStats(x, stats, key):
    max_val, _ = torch.max(x, dim=1)
    min_val, _ = torch.min(x, dim=1)
  
  
    if key not in stats:
        stats[key] = {"max": int(max_val.sum()), "min": int(min_val.sum())}
    else:
        if stats[key]['max'] < int(max_val.sum().item()):
            stats[key]['max'] = int(max_val.sum().item())
        if stats[key]['min'] > min_val.sum().item():
            stats[key]['min'] = int(min_val.sum().item())
  
    return stats

# Reworked Forward Pass to access activation Stats through updateStats function
def gatherActivationStats(model, x, stats):

    stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv1') 
    
    x = F.relu(model.conv1(x))

    x = F.max_pool2d(x, 2, 2)
  
    stats = updateStats(x.clone().view(x.shape[0], -1), stats, 'conv2')
  
    x = F.relu(model.conv2(x))

    x = F.max_pool2d(x, 2, 2)

    x = x.view(-1, 4*4*50)
  
    stats = updateStats(x, stats, 'fc1')

    x = F.relu(model.fc1(x))
  
    stats = updateStats(x, stats, 'fc2')

    x = model.fc2(x)

    return stats

# Entry function to get stats of all functions.
def gatherStats(model, test_loader):
    device = 'cpu'
    
    model.eval()
    stats = {}
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            stats = gatherActivationStats(model, data, stats)
    
    final_stats = {}
    for key, value in stats.items():
          final_stats[key] = { "max" : value["max"], "min" : value["min"]}
    return final_stats

In [56]:
def quantForward(model, x, stats):
  
    # Quantise before inputting into incoming layers
    x = quantize_tensor(x, min_val=stats['conv1']['min'], max_val=stats['conv1']['max'])

    x, scale_next, zero_point_next = quantizeLayer(x.tensor, model.conv1, stats['conv2'], x.scale, x.zero_point)

    x = F.max_pool2d(x, 2, 2)
  
    x, scale_next, zero_point_next = quantizeLayer(x, model.conv2, stats['fc1'], scale_next, zero_point_next)

    x = F.max_pool2d(x, 2, 2)

    x = x.view(-1, 4*4*50)

    x, scale_next, zero_point_next = quantizeLayer(x, model.fc1, stats['fc2'], scale_next, zero_point_next)
  
    # Back to dequant for final layer
    x = dequantize_tensor(QTensor(tensor=x, scale=scale_next, zero_point=zero_point_next))
   
    x = model.fc2(x)

    return F.log_softmax(x, dim=1)

# Start Quantizing

In [155]:
import copy
q_model = copy.deepcopy(model)

In [156]:
kwargs = {'num_workers': 1, 'pin_memory': True}
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])), 
    batch_size=64, shuffle=True, **kwargs)

In [157]:
testQuant(q_model, test_loader, quant=False)


Test set: Average loss: 0.0550, Accuracy: 9815/10000 (98%)



### Stats of Activations

In [191]:
testQuant(q_model, test_loader, quant=True, stats=stats)


Test set: Average loss: 0.2590, Accuracy: 9771/10000 (98%)



In [199]:
stats = gatherStats(q_model, test_loader)
print(stats)

{'conv1': {'max': 180, 'min': -27}, 'conv2': {'max': 531, 'min': 0}, 'fc1': {'max': 866, 'min': 0}, 'fc2': {'max': 465, 'min': 0}}


In [200]:
testQuant(q_model, test_loader, quant=True, stats=stats)


Test set: Average loss: 0.0617, Accuracy: 9800/10000 (98%)

