In [1]:
import argparse
import os
import time
import shutil

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

#from tensorboardX import SummaryWriter      

import torchvision
import torchvision.transforms as transforms

from models import *
#from models import vgg_quant

import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="0"

global best_prec
use_gpu = torch.cuda.is_available()
print('=> Building model...')
    
    
batch_size = 128
model_name = "VGG16_quant"
model = VGG16_quant()
print(model)
print(model.features[26])


normalize = transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262])


train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)


test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ]))

testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)


print_freq = 100 # every 100 batches, accuracy printed. Here, each batch includes "batch_size" data points
# CIFAR10 has 50,000 training data, and 10,000 validation data.

def train(trainloader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.train()

    end = time.time()
    for i, (input, target) in enumerate(trainloader):
        # measure data loading time
        data_time.update(time.time() - end)

        input, target = input.cuda(), target.cuda()

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        prec = accuracy(output, target)[0]
        losses.update(loss.item(), input.size(0))
        top1.update(prec.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()


        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec {top1.val:.3f}% ({top1.avg:.3f}%)'.format(
                   epoch, i, len(trainloader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1))

            

def validate(val_loader, model, criterion ):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
         
            input, target = input.cuda(), target.cuda()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec = accuracy(output, target)[0]
            losses.update(loss.item(), input.size(0))
            top1.update(prec.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:  # This line shows how frequently print out the status. e.g., i%5 => every 5 batch, prints out
                print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec {top1.val:.3f}% ({top1.avg:.3f}%)'.format(
                   i, len(val_loader), batch_time=batch_time, loss=losses,
                   top1=top1))

    print(' * Prec {top1.avg:.3f}% '.format(top1=top1))
    return top1.avg


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

        
def save_checkpoint(state, is_best, fdir):
    filepath = os.path.join(fdir, 'checkpoint.pth')
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(fdir, 'model_best.pth.tar'))


def adjust_learning_rate(optimizer, epoch):
    """For resnet, the lr starts from 0.1, and is divided by 10 at 80 and 120 epochs"""
    adjust_list = [150, 225]
    #adjust_list = [40, 80]
    if epoch in adjust_list:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.1        

#model = nn.DataParallel(model).cuda()
#all_params = checkpoint['state_dict']
#model.load_state_dict(all_params, strict=False)
#criterion = nn.CrossEntropyLoss().cuda()
#validate(testloader, model, criterion)

=> Building model...
VGG_quant(
  (features): Sequential(
    (0): QuantConv2d(
      3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
      (weight_quant): weight_quantize_fn()
    )
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): QuantConv2d(
      64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
      (weight_quant): weight_quantize_fn()
    )
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): QuantConv2d(
      64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
      (weight_quant): weight_quantize_fn()
    )
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): QuantConv2d(
      128, 128, kernel_size=(3, 3), stride

In [2]:
# This cell is from the website

lr = 1e-2

weight_decay = 1e-4
epochs = 200
best_prec = 0

model = model.cuda()
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
# weight decay: for regularization to prevent overfitting

if not os.path.exists('result'):
    os.makedirs('result')
    
fdir = 'result/'+str(model_name)

if not os.path.exists(fdir):
    os.makedirs(fdir)
        

for epoch in range(0, epochs):
    adjust_learning_rate(optimizer, epoch)

    train(trainloader, model, criterion, optimizer, epoch)
    
    # evaluate on test set
    print("Validation starts")
    prec = validate(testloader, model, criterion)

    # remember best precision and save checkpoint
    is_best = prec > best_prec
    best_prec = max(prec,best_prec)
    print('best acc: {:1f}'.format(best_prec))
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec': best_prec,
        'optimizer': optimizer.state_dict(),
    }, is_best, fdir)
    

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Epoch: [0][0/391]	Time 0.372 (0.372)	Data 0.204 (0.204)	Loss 2.3849 (2.3849)	Prec 15.625% (15.625%)
Epoch: [0][100/391]	Time 0.057 (0.060)	Data 0.002 (0.004)	Loss 2.1634 (2.5237)	Prec 18.750% (12.028%)
Epoch: [0][200/391]	Time 0.055 (0.058)	Data 0.002 (0.003)	Loss 2.3282 (2.3663)	Prec 11.719% (14.179%)
Epoch: [0][300/391]	Time 0.057 (0.058)	Data 0.002 (0.003)	Loss 2.2102 (2.2900)	Prec 14.062% (15.293%)
Validation starts
Test: [0/79]	Time 0.213 (0.213)	Loss 2.1143 (2.1143)	Prec 21.875% (21.875%)
 * Prec 19.110% 
best acc: 19.110000
Epoch: [1][0/391]	Time 0.240 (0.240)	Data 0.196 (0.196)	Loss 2.1457 (2.1457)	Prec 17.969% (17.969%)
Epoch: [1][100/391]	Time 0.054 (0.059)	Data 0.002 (0.004)	Loss 2.2627 (2.0683)	Prec 12.500% (19.732%)
Epoch: [1][200/391]	Time 0.060 (0.058)	Data 0.002 (0.003)	Loss 2.2487 (2.0731)	Prec 8.594% (19.450%)
Epoch: [1][300/391]	Time 0.060 (0.057)	Data 0.002 (0.003)	Loss 2.1032 (2.0684)	Prec 21.094% (19.614%)
Validation starts
Test: [0/79]	Time 0.224 (0.224)	Loss 1.9

Epoch: [15][200/391]	Time 0.057 (0.058)	Data 0.002 (0.003)	Loss 1.7149 (1.7389)	Prec 28.125% (33.283%)
Epoch: [15][300/391]	Time 0.057 (0.057)	Data 0.002 (0.003)	Loss 1.6826 (1.7377)	Prec 45.312% (33.539%)
Validation starts
Test: [0/79]	Time 0.210 (0.210)	Loss 1.7006 (1.7006)	Prec 32.031% (32.031%)
 * Prec 35.700% 
best acc: 36.500000
Epoch: [16][0/391]	Time 0.235 (0.235)	Data 0.193 (0.193)	Loss 1.6405 (1.6405)	Prec 35.156% (35.156%)
Epoch: [16][100/391]	Time 0.054 (0.058)	Data 0.002 (0.004)	Loss 1.7085 (1.7065)	Prec 38.281% (35.241%)
Epoch: [16][200/391]	Time 0.056 (0.058)	Data 0.002 (0.003)	Loss 1.7677 (1.7135)	Prec 35.938% (34.966%)
Epoch: [16][300/391]	Time 0.059 (0.058)	Data 0.002 (0.003)	Loss 1.6826 (1.7162)	Prec 29.688% (34.780%)
Validation starts
Test: [0/79]	Time 0.201 (0.201)	Loss 1.6441 (1.6441)	Prec 37.500% (37.500%)
 * Prec 36.020% 
best acc: 36.500000
Epoch: [17][0/391]	Time 0.300 (0.300)	Data 0.250 (0.250)	Loss 1.7119 (1.7119)	Prec 33.594% (33.594%)
Epoch: [17][100/391]	

Epoch: [30][300/391]	Time 0.055 (0.057)	Data 0.002 (0.003)	Loss 1.4396 (1.5141)	Prec 45.312% (42.938%)
Validation starts
Test: [0/79]	Time 0.205 (0.205)	Loss 1.5452 (1.5452)	Prec 40.625% (40.625%)
 * Prec 42.400% 
best acc: 43.840000
Epoch: [31][0/391]	Time 0.282 (0.282)	Data 0.233 (0.233)	Loss 1.5001 (1.5001)	Prec 42.969% (42.969%)
Epoch: [31][100/391]	Time 0.057 (0.059)	Data 0.002 (0.004)	Loss 1.5881 (1.4869)	Prec 35.156% (44.121%)
Epoch: [31][200/391]	Time 0.057 (0.058)	Data 0.002 (0.003)	Loss 1.4293 (1.4944)	Prec 46.875% (43.699%)
Epoch: [31][300/391]	Time 0.055 (0.058)	Data 0.002 (0.003)	Loss 1.4329 (1.4936)	Prec 44.531% (43.823%)
Validation starts
Test: [0/79]	Time 0.213 (0.213)	Loss 1.4283 (1.4283)	Prec 51.562% (51.562%)
 * Prec 44.640% 
best acc: 44.640000
Epoch: [32][0/391]	Time 0.272 (0.272)	Data 0.220 (0.220)	Loss 1.4462 (1.4462)	Prec 41.406% (41.406%)
Epoch: [32][100/391]	Time 0.059 (0.059)	Data 0.002 (0.004)	Loss 1.7539 (1.4663)	Prec 41.406% (44.933%)
Epoch: [32][200/391]	

Validation starts
Test: [0/79]	Time 0.226 (0.226)	Loss 1.2599 (1.2599)	Prec 52.344% (52.344%)
 * Prec 50.710% 
best acc: 50.710000
Epoch: [46][0/391]	Time 0.260 (0.260)	Data 0.213 (0.213)	Loss 1.2546 (1.2546)	Prec 57.031% (57.031%)
Epoch: [46][100/391]	Time 0.059 (0.059)	Data 0.002 (0.004)	Loss 1.3332 (1.3117)	Prec 47.656% (51.617%)
Epoch: [46][200/391]	Time 0.056 (0.059)	Data 0.002 (0.003)	Loss 1.4339 (1.3255)	Prec 45.312% (51.236%)
Epoch: [46][300/391]	Time 0.057 (0.058)	Data 0.002 (0.003)	Loss 1.3428 (1.3257)	Prec 49.219% (51.100%)
Validation starts
Test: [0/79]	Time 0.198 (0.198)	Loss 1.3043 (1.3043)	Prec 57.031% (57.031%)
 * Prec 48.680% 
best acc: 50.710000
Epoch: [47][0/391]	Time 0.280 (0.280)	Data 0.232 (0.232)	Loss 1.5344 (1.5344)	Prec 48.438% (48.438%)
Epoch: [47][100/391]	Time 0.058 (0.059)	Data 0.001 (0.004)	Loss 1.4471 (1.3600)	Prec 45.312% (49.799%)
Epoch: [47][200/391]	Time 0.057 (0.058)	Data 0.002 (0.003)	Loss 1.3338 (1.3554)	Prec 52.344% (50.323%)
Epoch: [47][300/391]	

 * Prec 57.120% 
best acc: 57.120000
Epoch: [61][0/391]	Time 0.297 (0.297)	Data 0.252 (0.252)	Loss 1.1352 (1.1352)	Prec 60.938% (60.938%)
Epoch: [61][100/391]	Time 0.056 (0.059)	Data 0.002 (0.004)	Loss 1.2572 (1.2067)	Prec 50.000% (55.902%)
Epoch: [61][200/391]	Time 0.056 (0.058)	Data 0.002 (0.003)	Loss 1.1994 (1.2123)	Prec 57.812% (55.838%)
Epoch: [61][300/391]	Time 0.057 (0.057)	Data 0.002 (0.003)	Loss 1.0444 (1.2061)	Prec 65.625% (56.172%)
Validation starts
Test: [0/79]	Time 0.234 (0.234)	Loss 1.2448 (1.2448)	Prec 52.344% (52.344%)
 * Prec 56.370% 
best acc: 57.120000
Epoch: [62][0/391]	Time 0.272 (0.272)	Data 0.225 (0.225)	Loss 1.2905 (1.2905)	Prec 53.125% (53.125%)
Epoch: [62][100/391]	Time 0.056 (0.059)	Data 0.002 (0.004)	Loss 1.3815 (1.2017)	Prec 52.344% (56.026%)
Epoch: [62][200/391]	Time 0.059 (0.058)	Data 0.002 (0.003)	Loss 1.1956 (1.1925)	Prec 52.344% (56.405%)
Epoch: [62][300/391]	Time 0.063 (0.058)	Data 0.002 (0.003)	Loss 1.3129 (1.1897)	Prec 52.344% (56.471%)
Validation s

Epoch: [76][100/391]	Time 0.057 (0.059)	Data 0.002 (0.004)	Loss 1.2448 (1.1256)	Prec 53.906% (59.916%)
Epoch: [76][200/391]	Time 0.057 (0.058)	Data 0.002 (0.003)	Loss 1.1218 (1.1309)	Prec 57.031% (59.573%)
Epoch: [76][300/391]	Time 0.059 (0.058)	Data 0.002 (0.003)	Loss 1.1927 (1.1290)	Prec 58.594% (59.430%)
Validation starts
Test: [0/79]	Time 0.209 (0.209)	Loss 1.0547 (1.0547)	Prec 64.844% (64.844%)
 * Prec 60.440% 
best acc: 60.440000
Epoch: [77][0/391]	Time 0.274 (0.274)	Data 0.226 (0.226)	Loss 1.1108 (1.1108)	Prec 54.688% (54.688%)
Epoch: [77][100/391]	Time 0.057 (0.059)	Data 0.002 (0.004)	Loss 1.1241 (1.1097)	Prec 58.594% (59.916%)
Epoch: [77][200/391]	Time 0.058 (0.058)	Data 0.002 (0.003)	Loss 1.1569 (1.1152)	Prec 58.594% (59.748%)
Epoch: [77][300/391]	Time 0.057 (0.058)	Data 0.002 (0.003)	Loss 1.1912 (1.1178)	Prec 56.250% (59.681%)
Validation starts
Test: [0/79]	Time 0.212 (0.212)	Loss 1.1210 (1.1210)	Prec 59.375% (59.375%)
 * Prec 59.780% 
best acc: 60.440000
Epoch: [78][0/391]	

Epoch: [91][200/391]	Time 0.057 (0.058)	Data 0.002 (0.003)	Loss 0.9496 (1.0276)	Prec 67.969% (63.211%)
Epoch: [91][300/391]	Time 0.057 (0.058)	Data 0.002 (0.003)	Loss 0.8932 (1.0328)	Prec 71.875% (63.133%)
Validation starts
Test: [0/79]	Time 0.214 (0.214)	Loss 1.0682 (1.0682)	Prec 62.500% (62.500%)
 * Prec 61.680% 
best acc: 64.420000
Epoch: [92][0/391]	Time 0.265 (0.265)	Data 0.217 (0.217)	Loss 1.0545 (1.0545)	Prec 64.062% (64.062%)
Epoch: [92][100/391]	Time 0.056 (0.059)	Data 0.002 (0.004)	Loss 0.8778 (1.0397)	Prec 71.094% (62.910%)
Epoch: [92][200/391]	Time 0.061 (0.058)	Data 0.002 (0.003)	Loss 0.9451 (1.0408)	Prec 67.969% (62.613%)
Epoch: [92][300/391]	Time 0.056 (0.057)	Data 0.002 (0.003)	Loss 0.8585 (1.0342)	Prec 67.188% (62.866%)
Validation starts
Test: [0/79]	Time 0.198 (0.198)	Loss 0.9539 (0.9539)	Prec 64.062% (64.062%)
 * Prec 64.840% 
best acc: 64.840000
Epoch: [93][0/391]	Time 0.268 (0.268)	Data 0.221 (0.221)	Loss 1.0055 (1.0055)	Prec 62.500% (62.500%)
Epoch: [93][100/391]	

Epoch: [106][300/391]	Time 0.057 (0.057)	Data 0.002 (0.002)	Loss 1.0168 (0.9805)	Prec 63.281% (64.932%)
Validation starts
Test: [0/79]	Time 0.210 (0.210)	Loss 0.9407 (0.9407)	Prec 69.531% (69.531%)
 * Prec 63.370% 
best acc: 66.080000
Epoch: [107][0/391]	Time 0.266 (0.266)	Data 0.218 (0.218)	Loss 0.8885 (0.8885)	Prec 69.531% (69.531%)
Epoch: [107][100/391]	Time 0.057 (0.059)	Data 0.002 (0.004)	Loss 0.9024 (0.9924)	Prec 71.094% (64.581%)
Epoch: [107][200/391]	Time 0.058 (0.058)	Data 0.002 (0.003)	Loss 1.1210 (0.9718)	Prec 60.938% (65.372%)
Epoch: [107][300/391]	Time 0.059 (0.057)	Data 0.002 (0.003)	Loss 1.0459 (0.9693)	Prec 59.375% (65.376%)
Validation starts
Test: [0/79]	Time 0.207 (0.207)	Loss 0.8481 (0.8481)	Prec 72.656% (72.656%)
 * Prec 64.850% 
best acc: 66.080000
Epoch: [108][0/391]	Time 0.264 (0.264)	Data 0.213 (0.213)	Loss 0.9648 (0.9648)	Prec 65.625% (65.625%)
Epoch: [108][100/391]	Time 0.055 (0.059)	Data 0.002 (0.004)	Loss 1.0666 (1.0259)	Prec 60.156% (62.833%)
Epoch: [108][2

Validation starts
Test: [0/79]	Time 0.282 (0.282)	Loss 0.8945 (0.8945)	Prec 66.406% (66.406%)
 * Prec 66.700% 
best acc: 67.580000
Epoch: [122][0/391]	Time 0.225 (0.225)	Data 0.185 (0.185)	Loss 0.8446 (0.8446)	Prec 71.875% (71.875%)
Epoch: [122][100/391]	Time 0.059 (0.059)	Data 0.002 (0.004)	Loss 1.0417 (0.9471)	Prec 64.062% (66.252%)
Epoch: [122][200/391]	Time 0.057 (0.058)	Data 0.002 (0.003)	Loss 0.9645 (0.9542)	Prec 68.750% (66.130%)
Epoch: [122][300/391]	Time 0.059 (0.058)	Data 0.002 (0.002)	Loss 0.9108 (0.9400)	Prec 66.406% (66.580%)
Validation starts
Test: [0/79]	Time 0.204 (0.204)	Loss 0.8306 (0.8306)	Prec 70.312% (70.312%)
 * Prec 67.650% 
best acc: 67.650000
Epoch: [123][0/391]	Time 0.302 (0.302)	Data 0.249 (0.249)	Loss 0.9504 (0.9504)	Prec 71.094% (71.094%)
Epoch: [123][100/391]	Time 0.059 (0.059)	Data 0.002 (0.004)	Loss 0.9112 (0.9371)	Prec 66.406% (66.507%)
Epoch: [123][200/391]	Time 0.057 (0.058)	Data 0.002 (0.003)	Loss 0.8702 (0.9304)	Prec 64.062% (66.772%)
Epoch: [123][3

Test: [0/79]	Time 0.203 (0.203)	Loss 0.7785 (0.7785)	Prec 74.219% (74.219%)
 * Prec 67.050% 
best acc: 68.930000
Epoch: [137][0/391]	Time 0.267 (0.267)	Data 0.216 (0.216)	Loss 0.9829 (0.9829)	Prec 65.625% (65.625%)
Epoch: [137][100/391]	Time 0.057 (0.059)	Data 0.002 (0.004)	Loss 0.8910 (0.8968)	Prec 69.531% (68.479%)
Epoch: [137][200/391]	Time 0.059 (0.058)	Data 0.002 (0.003)	Loss 1.0698 (0.9045)	Prec 65.625% (68.272%)
Epoch: [137][300/391]	Time 0.057 (0.057)	Data 0.002 (0.003)	Loss 0.9343 (0.9014)	Prec 64.062% (68.218%)
Validation starts
Test: [0/79]	Time 0.195 (0.195)	Loss 0.7719 (0.7719)	Prec 73.438% (73.438%)
 * Prec 69.050% 
best acc: 69.050000
Epoch: [138][0/391]	Time 0.243 (0.243)	Data 0.200 (0.200)	Loss 0.7815 (0.7815)	Prec 71.875% (71.875%)
Epoch: [138][100/391]	Time 0.056 (0.059)	Data 0.002 (0.004)	Loss 0.9026 (0.8803)	Prec 67.969% (68.851%)
Epoch: [138][200/391]	Time 0.060 (0.058)	Data 0.002 (0.003)	Loss 0.7647 (0.8672)	Prec 70.312% (69.181%)
Epoch: [138][300/391]	Time 0.057

 * Prec 72.910% 
best acc: 72.910000
Epoch: [152][0/391]	Time 0.261 (0.261)	Data 0.218 (0.218)	Loss 0.7342 (0.7342)	Prec 75.781% (75.781%)
Epoch: [152][100/391]	Time 0.059 (0.060)	Data 0.002 (0.004)	Loss 0.7376 (0.7536)	Prec 71.875% (73.352%)
Epoch: [152][200/391]	Time 0.057 (0.059)	Data 0.002 (0.003)	Loss 0.7750 (0.7598)	Prec 74.219% (72.870%)
Epoch: [152][300/391]	Time 0.057 (0.058)	Data 0.002 (0.003)	Loss 0.8527 (0.7613)	Prec 67.969% (72.934%)
Validation starts
Test: [0/79]	Time 0.207 (0.207)	Loss 0.6548 (0.6548)	Prec 72.656% (72.656%)
 * Prec 73.580% 
best acc: 73.580000
Epoch: [153][0/391]	Time 0.243 (0.243)	Data 0.195 (0.195)	Loss 0.7395 (0.7395)	Prec 71.875% (71.875%)
Epoch: [153][100/391]	Time 0.056 (0.059)	Data 0.002 (0.004)	Loss 0.6550 (0.7467)	Prec 71.875% (73.360%)
Epoch: [153][200/391]	Time 0.057 (0.058)	Data 0.002 (0.003)	Loss 0.8375 (0.7634)	Prec 69.531% (72.870%)
Epoch: [153][300/391]	Time 0.057 (0.057)	Data 0.001 (0.002)	Loss 0.7881 (0.7568)	Prec 71.875% (73.204%)
Vali

Epoch: [167][0/391]	Time 0.257 (0.257)	Data 0.215 (0.215)	Loss 0.6394 (0.6394)	Prec 75.781% (75.781%)
Epoch: [167][100/391]	Time 0.060 (0.059)	Data 0.002 (0.004)	Loss 0.6926 (0.6939)	Prec 76.562% (75.124%)
Epoch: [167][200/391]	Time 0.059 (0.059)	Data 0.002 (0.003)	Loss 0.7719 (0.6958)	Prec 71.094% (75.463%)
Epoch: [167][300/391]	Time 0.061 (0.059)	Data 0.002 (0.003)	Loss 0.8119 (0.6964)	Prec 71.875% (75.395%)
Validation starts
Test: [0/79]	Time 0.219 (0.219)	Loss 0.6340 (0.6340)	Prec 79.688% (79.688%)
 * Prec 75.990% 
best acc: 75.990000
Epoch: [168][0/391]	Time 0.302 (0.302)	Data 0.248 (0.248)	Loss 0.6187 (0.6187)	Prec 75.000% (75.000%)
Epoch: [168][100/391]	Time 0.057 (0.060)	Data 0.002 (0.004)	Loss 0.8528 (0.7002)	Prec 73.438% (75.155%)
Epoch: [168][200/391]	Time 0.057 (0.058)	Data 0.002 (0.003)	Loss 0.8873 (0.7083)	Prec 67.969% (74.848%)
Epoch: [168][300/391]	Time 0.055 (0.058)	Data 0.002 (0.003)	Loss 0.7705 (0.7037)	Prec 69.531% (75.223%)
Validation starts
Test: [0/79]	Time 0.188

Epoch: [182][100/391]	Time 0.057 (0.059)	Data 0.002 (0.004)	Loss 0.6125 (0.6775)	Prec 78.125% (76.253%)
Epoch: [182][200/391]	Time 0.059 (0.058)	Data 0.002 (0.003)	Loss 0.7986 (0.6765)	Prec 71.094% (76.158%)
Epoch: [182][300/391]	Time 0.057 (0.058)	Data 0.002 (0.003)	Loss 0.5470 (0.6753)	Prec 82.812% (76.319%)
Validation starts
Test: [0/79]	Time 0.195 (0.195)	Loss 0.5770 (0.5770)	Prec 82.031% (82.031%)
 * Prec 76.240% 
best acc: 76.380000
Epoch: [183][0/391]	Time 0.262 (0.262)	Data 0.220 (0.220)	Loss 0.5486 (0.5486)	Prec 82.812% (82.812%)
Epoch: [183][100/391]	Time 0.056 (0.059)	Data 0.002 (0.004)	Loss 0.6405 (0.6693)	Prec 78.906% (76.477%)
Epoch: [183][200/391]	Time 0.056 (0.058)	Data 0.002 (0.003)	Loss 0.5931 (0.6715)	Prec 78.125% (76.356%)
Epoch: [183][300/391]	Time 0.057 (0.058)	Data 0.002 (0.003)	Loss 0.6334 (0.6739)	Prec 75.781% (76.235%)
Validation starts
Test: [0/79]	Time 0.195 (0.195)	Loss 0.5765 (0.5765)	Prec 82.031% (82.031%)
 * Prec 76.510% 
best acc: 76.510000
Epoch: [184]

Epoch: [197][200/391]	Time 0.058 (0.059)	Data 0.002 (0.003)	Loss 0.6288 (0.6567)	Prec 75.000% (76.621%)
Epoch: [197][300/391]	Time 0.056 (0.058)	Data 0.002 (0.003)	Loss 0.6744 (0.6532)	Prec 78.125% (76.814%)
Validation starts
Test: [0/79]	Time 0.213 (0.213)	Loss 0.6193 (0.6193)	Prec 76.562% (76.562%)
 * Prec 76.840% 
best acc: 76.980000
Epoch: [198][0/391]	Time 0.297 (0.297)	Data 0.248 (0.248)	Loss 0.5786 (0.5786)	Prec 79.688% (79.688%)
Epoch: [198][100/391]	Time 0.054 (0.059)	Data 0.002 (0.004)	Loss 0.6037 (0.6472)	Prec 78.125% (76.818%)
Epoch: [198][200/391]	Time 0.055 (0.058)	Data 0.002 (0.003)	Loss 0.6246 (0.6552)	Prec 75.781% (76.889%)
Epoch: [198][300/391]	Time 0.056 (0.058)	Data 0.002 (0.003)	Loss 0.6572 (0.6610)	Prec 77.344% (76.710%)
Validation starts
Test: [0/79]	Time 0.203 (0.203)	Loss 0.5410 (0.5410)	Prec 78.906% (78.906%)
 * Prec 76.700% 
best acc: 76.980000
Epoch: [199][0/391]	Time 0.320 (0.320)	Data 0.272 (0.272)	Loss 0.6200 (0.6200)	Prec 80.469% (80.469%)
Epoch: [199][1

In [5]:
# This cell is from the website
PATH = "result/VGG16_quant/model_best.pth.tar"
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['state_dict'])
device = torch.device("cuda") 

lr = 1e-2

weight_decay = 1e-4
epochs = 100
best_prec = 0

model = model.cuda()
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
# weight decay: for regularization to prevent overfitting

if not os.path.exists('result'):
    os.makedirs('result')
    
fdir = 'result/'+str(model_name)

if not os.path.exists(fdir):
    os.makedirs(fdir)
        

for epoch in range(0, epochs):
    adjust_learning_rate(optimizer, epoch)

    train(trainloader, model, criterion, optimizer, epoch)
    
    # evaluate on test set
    print("Validation starts")
    prec = validate(testloader, model, criterion)

    # remember best precision and save checkpoint
    is_best = prec > best_prec
    best_prec = max(prec,best_prec)
    print('best acc: {:1f}'.format(best_prec))
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec': best_prec,
        'optimizer': optimizer.state_dict(),
    }, is_best, fdir)
    

Epoch: [0][0/391]	Time 0.247 (0.247)	Data 0.206 (0.206)	Loss 0.5886 (0.5886)	Prec 78.125% (78.125%)
Epoch: [0][100/391]	Time 0.053 (0.058)	Data 0.002 (0.005)	Loss 0.8314 (0.7575)	Prec 69.531% (73.368%)
Epoch: [0][200/391]	Time 0.054 (0.057)	Data 0.002 (0.003)	Loss 0.9465 (0.7775)	Prec 68.750% (72.520%)
Epoch: [0][300/391]	Time 0.056 (0.056)	Data 0.002 (0.003)	Loss 0.6544 (0.7963)	Prec 78.125% (71.911%)
Validation starts
Test: [0/79]	Time 0.280 (0.280)	Loss 0.8550 (0.8550)	Prec 69.531% (69.531%)
 * Prec 68.830% 
best acc: 68.830000
Epoch: [1][0/391]	Time 0.279 (0.279)	Data 0.241 (0.241)	Loss 0.9467 (0.9467)	Prec 68.750% (68.750%)
Epoch: [1][100/391]	Time 0.058 (0.057)	Data 0.002 (0.005)	Loss 0.9046 (0.8584)	Prec 67.969% (69.732%)
Epoch: [1][200/391]	Time 0.058 (0.056)	Data 0.002 (0.003)	Loss 0.8734 (0.8604)	Prec 68.750% (69.807%)
Epoch: [1][300/391]	Time 0.054 (0.056)	Data 0.002 (0.003)	Loss 0.9504 (0.8556)	Prec 67.188% (69.970%)
Validation starts
Test: [0/79]	Time 0.215 (0.215)	Loss 0.

Epoch: [15][200/391]	Time 0.058 (0.056)	Data 0.002 (0.003)	Loss 0.7487 (0.8193)	Prec 74.219% (71.451%)
Epoch: [15][300/391]	Time 0.053 (0.056)	Data 0.002 (0.003)	Loss 0.8942 (0.8257)	Prec 67.188% (71.340%)
Validation starts
Test: [0/79]	Time 0.196 (0.196)	Loss 0.8383 (0.8383)	Prec 67.969% (67.969%)
 * Prec 70.570% 
best acc: 71.530000
Epoch: [16][0/391]	Time 0.263 (0.263)	Data 0.215 (0.215)	Loss 0.9081 (0.9081)	Prec 67.969% (67.969%)
Epoch: [16][100/391]	Time 0.053 (0.057)	Data 0.002 (0.004)	Loss 0.7558 (0.8184)	Prec 73.438% (71.279%)
Epoch: [16][200/391]	Time 0.058 (0.056)	Data 0.002 (0.003)	Loss 0.8654 (0.8184)	Prec 67.969% (71.591%)
Epoch: [16][300/391]	Time 0.054 (0.056)	Data 0.002 (0.003)	Loss 0.7577 (0.8104)	Prec 71.094% (71.789%)
Validation starts
Test: [0/79]	Time 0.229 (0.229)	Loss 0.8590 (0.8590)	Prec 71.094% (71.094%)
 * Prec 69.840% 
best acc: 71.530000
Epoch: [17][0/391]	Time 0.239 (0.239)	Data 0.190 (0.190)	Loss 0.7978 (0.7978)	Prec 69.531% (69.531%)
Epoch: [17][100/391]	

Epoch: [30][300/391]	Time 0.057 (0.056)	Data 0.002 (0.003)	Loss 0.7721 (0.8605)	Prec 73.438% (70.775%)
Validation starts
Test: [0/79]	Time 0.199 (0.199)	Loss 1.0215 (1.0215)	Prec 63.281% (63.281%)
 * Prec 66.030% 
best acc: 71.600000
Epoch: [31][0/391]	Time 0.256 (0.256)	Data 0.214 (0.214)	Loss 0.7942 (0.7942)	Prec 75.000% (75.000%)
Epoch: [31][100/391]	Time 0.054 (0.057)	Data 0.002 (0.004)	Loss 0.7934 (0.8727)	Prec 70.312% (69.377%)
Epoch: [31][200/391]	Time 0.051 (0.056)	Data 0.002 (0.003)	Loss 0.7520 (0.8763)	Prec 71.094% (69.209%)
Epoch: [31][300/391]	Time 0.055 (0.055)	Data 0.002 (0.003)	Loss 0.9515 (0.8705)	Prec 64.062% (69.632%)
Validation starts
Test: [0/79]	Time 0.187 (0.187)	Loss 0.8425 (0.8425)	Prec 67.969% (67.969%)
 * Prec 69.610% 
best acc: 71.600000
Epoch: [32][0/391]	Time 0.246 (0.246)	Data 0.200 (0.200)	Loss 0.7954 (0.7954)	Prec 71.875% (71.875%)
Epoch: [32][100/391]	Time 0.054 (0.058)	Data 0.002 (0.004)	Loss 0.8681 (0.8236)	Prec 71.875% (71.635%)
Epoch: [32][200/391]	

KeyboardInterrupt: 

In [6]:
PATH = "result/VGG16_quant/model_best.pth.tar"
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['state_dict'])
device = torch.device("cuda") 

model.cuda()
model.eval()

test_loss = 0
correct = 0

with torch.no_grad():
    for data, target in testloader:
        data, target = data.to(device), target.to(device) # loading to GPU
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)  
        correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(testloader.dataset)

print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(testloader.dataset),
        100. * correct / len(testloader.dataset)))


Test set: Accuracy: 7187/10000 (72%)



In [43]:
## Send an image and use prehook to grab the inputs of all the QuantConv2d layers

class SaveOutput:
    def __init__(self):
        self.outputs = []
    def __call__(self, module, module_in):
        self.outputs.append(module_in)
    def clear(self):
        self.outputs = []  
        
######### Save inputs from selected layer ##########
save_output = SaveOutput()
i = 0

for layer in model.modules():
    i = i+1
    if isinstance(layer, QuantConv2d):
        print(i,"-th layer prehooked",str(layer))
        layer.register_forward_pre_hook(save_output)             
####################################################

dataiter = iter(testloader)
images, labels = dataiter.next()
images = images.to(device)
model.cuda()
out = model(images)

3 -th layer prehooked QuantConv2d(
  3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
  (weight_quant): weight_quantize_fn()
)
7 -th layer prehooked QuantConv2d(
  64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
  (weight_quant): weight_quantize_fn()
)
12 -th layer prehooked QuantConv2d(
  64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
  (weight_quant): weight_quantize_fn()
)
16 -th layer prehooked QuantConv2d(
  128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
  (weight_quant): weight_quantize_fn()
)
21 -th layer prehooked QuantConv2d(
  128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
  (weight_quant): weight_quantize_fn()
)
25 -th layer prehooked QuantConv2d(
  256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
  (weight_quant): weight_quantize_fn()
)
29 -th layer prehooked QuantConv2d(
  256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bia

In [44]:
#print(model.features[26].weight)
w_bits = 4
weight_q = model.features[26].weight_q # quantized value is stored during
w_alpha = model.features[26].weight_quant.wgt_alpha
w_delta = w_alpha/(2**(w_bits-1)-1)
weight_int = weight_q/w_delta
print(weight_int) # you should see clean integer numbers

tensor([[[[ 0.0000,  0.0000, -0.0000],
          [-0.0000, -0.0000, -1.0000],
          [ 0.0000, -0.0000, -0.0000]],

         [[ 0.0000, -1.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000],
          [ 1.0000, -0.0000, -0.0000]],

         [[-1.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000]],

         [[ 0.0000, -0.0000, -1.0000],
          [ 0.0000, -0.0000, -1.0000],
          [-0.0000, -0.0000,  0.0000]],

         [[ 0.0000, -0.0000, -0.0000],
          [ 0.0000, -1.0000, -1.0000],
          [-0.0000, -0.0000, -0.0000]],

         [[ 0.0000,  0.0000,  0.0000],
          [ 4.0000,  7.0000,  3.0000],
          [ 1.0000,  0.0000,  0.0000]],

         [[-1.0000, -0.0000, -1.0000],
          [-0.0000, -0.0000, -0.0000],
          [-1.0000, -0.0000, -1.0000]],

         [[-0.0000, -0.0000,  1.0000],
          [ 0.0000, -0.0000, -0.0000],
          [ 1.0000,  0.0000,  0.0000]]],


        [[[-0.0000, -0.0000,  0.0000],
       

In [45]:
x_bit = 4
x = save_output.outputs[8][0] # input of the 2nd conv layer
x_alpha = model.features[26].act_alpha
x_delta = x_alpha/(2**x_bit-1) # resolution
act_quant_fn = act_quantization(x_bit) # define the quantization function
x_q = act_quant_fn(x, x_alpha) # create the quantized value for x
x_int = x_q/x_delta
print(x_int) # you should see clean integer numbers

tensor([[[[15.0000, 15.0000, 15.0000,  0.0000],
          [15.0000, 15.0000, 15.0000,  0.0000],
          [15.0000, 15.0000, 15.0000, 15.0000],
          [15.0000, 15.0000, 15.0000,  0.0000]],

         [[ 0.0000, 15.0000, 15.0000, 15.0000],
          [ 0.0000, 15.0000, 15.0000, 15.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000]],

         [[15.0000, 15.0000, 15.0000,  0.0000],
          [15.0000, 15.0000, 15.0000,  0.0000],
          [15.0000, 15.0000, 15.0000,  0.0000],
          [15.0000, 15.0000, 15.0000, 15.0000]],

         ...,

         [[15.0000, 15.0000, 15.0000, 15.0000],
          [ 0.0000, 15.0000, 15.0000, 15.0000],
          [ 0.0000,  0.0000, 15.0000, 15.0000],
          [ 0.0000,  0.0000,  0.0000, 15.0000]],

         [[ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000]],

         [[15.0

In [61]:
conv_int = torch.nn.Conv2d(in_channels = 8, out_channels=8, kernel_size = 3, padding=1, bias=False)
conv_int.to(device)
conv_int.weight = torch.nn.parameter.Parameter(weight_int)
output_int = conv_int(x_int)
output_recovered = output_int*x_delta*w_delta
print(output_recovered)

tensor([[[[ 5.3364e+00,  8.8939e+00,  1.1562e+01,  1.1562e+01],
          [-8.8939e-01,  3.5576e+00,  8.0045e+00,  8.8939e+00],
          [-4.4470e+00, -2.6682e+00,  3.3797e+00,  7.1152e+00],
          [-4.4470e+00, -4.4470e+00, -8.8939e-01,  4.4470e+00]],

         [[ 1.7788e+00,  4.4470e+00,  5.5142e+00,  7.1152e+00],
          [ 1.7788e+00,  8.8939e-01,  7.2930e+00,  7.6488e+00],
          [ 0.0000e+00,  8.8939e-01,  7.8267e+00,  1.0495e+01],
          [-8.8939e-01,  0.0000e+00, -1.1309e-07, -3.5576e+00]],

         [[-7.1152e+00, -8.8939e+00, -8.8939e-01,  5.1585e+00],
          [-1.0673e+01, -1.0673e+01, -1.7788e+00,  6.2258e+00],
          [-1.0673e+01, -1.1562e+01, -5.3364e+00,  4.4470e+00],
          [-8.8939e+00, -9.7833e+00, -9.7833e+00, -2.6682e+00]],

         ...,

         [[ 7.1152e+00,  1.7788e+00, -3.3797e+00, -7.1152e+00],
          [ 5.3364e+00, -1.7788e+00, -6.4036e+00, -9.7833e+00],
          [ 8.0045e+00,  5.3364e+00,  5.5142e+00,  2.6682e+00],
          [ 7.1152e

In [60]:
print(output_recovered.shape)
print(save_output.outputs[8][0].shape)
relu_output_recovered = model.features[27](output_recovered) #RElU
difference = abs(save_output.outputs[9][0] - relu_output_recovered )  ##Difference between prehooked input of next layer
print(difference.mean()) ## It should be small

torch.Size([128, 8, 4, 4])
torch.Size([128, 8, 4, 4])
tensor(5.6099e-07, device='cuda:0', grad_fn=<MeanBackward0>)
