In [1]:
import argparse
import os
import time
import shutil

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
     

import torchvision
import torchvision.transforms as transforms

from models import *


global best_prec
use_gpu = torch.cuda.is_available()
print('=> Building model...')
    
    
    
batch_size = 128
model_name = "resnet20_cifar_project"
model = resnet20_quant()

print(model)

normalize = transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262])


train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)


test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ]))

testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)


print_freq = 100 # every 100 batches, accuracy printed. Here, each batch includes "batch_size" data points
# CIFAR10 has 50,000 training data, and 10,000 validation data.

def train(trainloader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.train()

    end = time.time()
    for i, (input, target) in enumerate(trainloader):
        # measure data loading time
        data_time.update(time.time() - end)

        input, target = input.cuda(), target.cuda()

        # compute output
        output = model(input)
        loss = criterion(output, target)
        #loss1 = criterion(output, target)
        #loss2 = model.conv1.weight.abs().sum() + model.conv2.weight.abs().sum()
        #loss = loss1
        #loss = loss1 + 0.05*loss2

        # measure accuracy and record loss
        prec = accuracy(output, target)[0]
        losses.update(loss.item(), input.size(0))
        top1.update(prec.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()


        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec {top1.val:.3f}% ({top1.avg:.3f}%)'.format(
                   epoch, i, len(trainloader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1))

            

def validate(val_loader, model, criterion ):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
         
            input, target = input.cuda(), target.cuda()

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec = accuracy(output, target)[0]
            losses.update(loss.item(), input.size(0))
            top1.update(prec.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:  # This line shows how frequently print out the status. e.g., i%5 => every 5 batch, prints out
                print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec {top1.val:.3f}% ({top1.avg:.3f}%)'.format(
                   i, len(val_loader), batch_time=batch_time, loss=losses,
                   top1=top1))

    print(' * Prec {top1.avg:.3f}% '.format(top1=top1))
    return top1.avg


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

        
def save_checkpoint(state, is_best, fdir):
    filepath = os.path.join(fdir, 'checkpoint.pth')
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(fdir, 'model_best.pth.tar'))


def adjust_learning_rate(optimizer, epoch):
    """For resnet, the lr starts from 0.1, and is divided by 10 at 80 and 120 epochs"""
    adjust_list = [150, 225]
    if epoch in adjust_list:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.1        

#model = nn.DataParallel(model).cuda()
#all_params = checkpoint['state_dict']
#model.load_state_dict(all_params, strict=False)
#criterion = nn.CrossEntropyLoss().cuda()
#validate(testloader, model, criterion)

=> Building model...
ResNet_Cifar(
  (conv1): QuantConv2d(
    3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
    (weight_quant): weight_quantize_fn()
  )
  (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): QuantConv2d(
        8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (weight_quant): weight_quantize_fn()
      )
      (conv2): QuantConv2d(
        8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (weight_quant): weight_quantize_fn()
      )
      (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): QuantConv2d(
        8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (weight_quant): weight_quantize_fn()
      )
      (conv2): QuantConv2d(
        8, 

In [4]:
# This cell won't be given, but students will complete the training

lr = 4e-3
weight_decay = 1e-4
epochs = 200
best_prec = 0

#model = nn.DataParallel(model).cuda()
model.cuda()
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
#cudnn.benchmark = True

if not os.path.exists('result'):
    os.makedirs('result')
fdir = 'result/'+str(model_name)
if not os.path.exists(fdir):
    os.makedirs(fdir)
        

for epoch in range(0, epochs):
    adjust_learning_rate(optimizer, epoch)

    train(trainloader, model, criterion, optimizer, epoch)
    
    # evaluate on test set
    print("Validation starts")
    prec = validate(testloader, model, criterion)

    # remember best precision and save checkpoint
    is_best = prec > best_prec
    best_prec = max(prec,best_prec)
    print('best acc: {:1f}'.format(best_prec))
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec': best_prec,
        'optimizer': optimizer.state_dict(),
    }, is_best, fdir)

Epoch: [0][0/391]	Time 0.280 (0.280)	Data 0.233 (0.233)	Loss 0.1761 (0.1761)	Prec 94.531% (94.531%)
Epoch: [0][100/391]	Time 0.058 (0.050)	Data 0.002 (0.004)	Loss 0.2710 (0.2781)	Prec 89.844% (90.091%)
Epoch: [0][200/391]	Time 0.041 (0.047)	Data 0.001 (0.003)	Loss 0.2797 (0.2787)	Prec 88.281% (90.306%)
Epoch: [0][300/391]	Time 0.072 (0.049)	Data 0.002 (0.003)	Loss 0.1764 (0.2740)	Prec 95.312% (90.555%)
Validation starts
Test: [0/79]	Time 0.192 (0.192)	Loss 0.2695 (0.2695)	Prec 88.281% (88.281%)
 * Prec 87.230% 
best acc: 87.230000
Epoch: [1][0/391]	Time 0.289 (0.289)	Data 0.232 (0.232)	Loss 0.1209 (0.1209)	Prec 98.438% (98.438%)
Epoch: [1][100/391]	Time 0.049 (0.058)	Data 0.002 (0.004)	Loss 0.3106 (0.2615)	Prec 87.500% (90.896%)
Epoch: [1][200/391]	Time 0.049 (0.056)	Data 0.002 (0.003)	Loss 0.1471 (0.2564)	Prec 94.531% (91.088%)
Epoch: [1][300/391]	Time 0.054 (0.055)	Data 0.002 (0.003)	Loss 0.1994 (0.2545)	Prec 92.969% (91.157%)
Validation starts
Test: [0/79]	Time 0.202 (0.202)	Loss 0.

Epoch: [15][200/391]	Time 0.062 (0.052)	Data 0.002 (0.003)	Loss 0.1826 (0.2130)	Prec 92.969% (92.572%)
Epoch: [15][300/391]	Time 0.046 (0.050)	Data 0.002 (0.003)	Loss 0.1954 (0.2111)	Prec 91.406% (92.595%)
Validation starts
Test: [0/79]	Time 0.213 (0.213)	Loss 0.2667 (0.2667)	Prec 92.188% (92.188%)
 * Prec 87.770% 
best acc: 88.140000
Epoch: [16][0/391]	Time 0.263 (0.263)	Data 0.209 (0.209)	Loss 0.2802 (0.2802)	Prec 89.844% (89.844%)
Epoch: [16][100/391]	Time 0.058 (0.045)	Data 0.002 (0.004)	Loss 0.2250 (0.2007)	Prec 92.188% (92.961%)
Epoch: [16][200/391]	Time 0.048 (0.049)	Data 0.002 (0.003)	Loss 0.1411 (0.2063)	Prec 96.094% (92.759%)
Epoch: [16][300/391]	Time 0.038 (0.047)	Data 0.002 (0.003)	Loss 0.2132 (0.2106)	Prec 92.969% (92.618%)
Validation starts
Test: [0/79]	Time 0.227 (0.227)	Loss 0.2306 (0.2306)	Prec 91.406% (91.406%)
 * Prec 87.670% 
best acc: 88.140000
Epoch: [17][0/391]	Time 0.270 (0.270)	Data 0.215 (0.215)	Loss 0.2719 (0.2719)	Prec 89.844% (89.844%)
Epoch: [17][100/391]	

Epoch: [30][300/391]	Time 0.062 (0.047)	Data 0.003 (0.003)	Loss 0.1934 (0.1933)	Prec 92.188% (93.210%)
Validation starts
Test: [0/79]	Time 0.199 (0.199)	Loss 0.2405 (0.2405)	Prec 89.844% (89.844%)
 * Prec 87.420% 
best acc: 88.140000
Epoch: [31][0/391]	Time 0.247 (0.247)	Data 0.196 (0.196)	Loss 0.2037 (0.2037)	Prec 94.531% (94.531%)
Epoch: [31][100/391]	Time 0.053 (0.053)	Data 0.002 (0.004)	Loss 0.2548 (0.1911)	Prec 91.406% (93.386%)
Epoch: [31][200/391]	Time 0.044 (0.051)	Data 0.002 (0.003)	Loss 0.1844 (0.1874)	Prec 92.969% (93.385%)
Epoch: [31][300/391]	Time 0.043 (0.050)	Data 0.002 (0.003)	Loss 0.2192 (0.1903)	Prec 92.188% (93.246%)
Validation starts
Test: [0/79]	Time 0.249 (0.249)	Loss 0.2677 (0.2677)	Prec 90.625% (90.625%)
 * Prec 87.450% 
best acc: 88.140000
Epoch: [32][0/391]	Time 0.284 (0.284)	Data 0.226 (0.226)	Loss 0.1861 (0.1861)	Prec 92.969% (92.969%)
Epoch: [32][100/391]	Time 0.060 (0.055)	Data 0.002 (0.004)	Loss 0.1587 (0.1788)	Prec 96.094% (93.580%)
Epoch: [32][200/391]	

Validation starts
Test: [0/79]	Time 0.218 (0.218)	Loss 0.2651 (0.2651)	Prec 89.844% (89.844%)
 * Prec 87.630% 
best acc: 88.140000
Epoch: [46][0/391]	Time 0.271 (0.271)	Data 0.214 (0.214)	Loss 0.0764 (0.0764)	Prec 96.875% (96.875%)
Epoch: [46][100/391]	Time 0.050 (0.049)	Data 0.002 (0.004)	Loss 0.1404 (0.1715)	Prec 92.969% (93.959%)
Epoch: [46][200/391]	Time 0.039 (0.046)	Data 0.002 (0.003)	Loss 0.1721 (0.1745)	Prec 92.188% (93.855%)
Epoch: [46][300/391]	Time 0.038 (0.046)	Data 0.002 (0.003)	Loss 0.1301 (0.1738)	Prec 96.875% (93.838%)
Validation starts
Test: [0/79]	Time 0.197 (0.197)	Loss 0.2282 (0.2282)	Prec 92.969% (92.969%)
 * Prec 87.550% 
best acc: 88.140000
Epoch: [47][0/391]	Time 0.278 (0.278)	Data 0.221 (0.221)	Loss 0.1350 (0.1350)	Prec 96.094% (96.094%)
Epoch: [47][100/391]	Time 0.049 (0.050)	Data 0.001 (0.004)	Loss 0.1753 (0.1674)	Prec 95.312% (93.843%)
Epoch: [47][200/391]	Time 0.040 (0.050)	Data 0.002 (0.003)	Loss 0.1391 (0.1690)	Prec 95.312% (93.898%)
Epoch: [47][300/391]	

 * Prec 87.150% 
best acc: 88.140000
Epoch: [61][0/391]	Time 0.236 (0.236)	Data 0.190 (0.190)	Loss 0.1234 (0.1234)	Prec 94.531% (94.531%)
Epoch: [61][100/391]	Time 0.042 (0.049)	Data 0.002 (0.004)	Loss 0.2262 (0.1566)	Prec 95.312% (94.593%)
Epoch: [61][200/391]	Time 0.046 (0.048)	Data 0.003 (0.003)	Loss 0.3035 (0.1604)	Prec 88.281% (94.411%)
Epoch: [61][300/391]	Time 0.043 (0.048)	Data 0.002 (0.003)	Loss 0.1708 (0.1577)	Prec 93.750% (94.505%)
Validation starts
Test: [0/79]	Time 0.208 (0.208)	Loss 0.2839 (0.2839)	Prec 90.625% (90.625%)
 * Prec 87.420% 
best acc: 88.140000
Epoch: [62][0/391]	Time 0.247 (0.247)	Data 0.199 (0.199)	Loss 0.1883 (0.1883)	Prec 93.750% (93.750%)
Epoch: [62][100/391]	Time 0.041 (0.047)	Data 0.002 (0.004)	Loss 0.1618 (0.1639)	Prec 96.094% (94.098%)
Epoch: [62][200/391]	Time 0.042 (0.047)	Data 0.002 (0.003)	Loss 0.1496 (0.1631)	Prec 92.969% (94.166%)
Epoch: [62][300/391]	Time 0.048 (0.047)	Data 0.002 (0.003)	Loss 0.1065 (0.1656)	Prec 96.094% (94.059%)
Validation s

Epoch: [76][100/391]	Time 0.070 (0.047)	Data 0.002 (0.004)	Loss 0.2742 (0.1442)	Prec 90.625% (94.895%)
Epoch: [76][200/391]	Time 0.043 (0.046)	Data 0.002 (0.003)	Loss 0.2659 (0.1443)	Prec 91.406% (94.866%)
Epoch: [76][300/391]	Time 0.046 (0.047)	Data 0.002 (0.003)	Loss 0.1314 (0.1466)	Prec 95.312% (94.825%)
Validation starts
Test: [0/79]	Time 0.209 (0.209)	Loss 0.3199 (0.3199)	Prec 89.062% (89.062%)
 * Prec 87.940% 
best acc: 88.140000
Epoch: [77][0/391]	Time 0.278 (0.278)	Data 0.213 (0.213)	Loss 0.1287 (0.1287)	Prec 94.531% (94.531%)
Epoch: [77][100/391]	Time 0.052 (0.049)	Data 0.001 (0.004)	Loss 0.1494 (0.1491)	Prec 93.750% (94.477%)
Epoch: [77][200/391]	Time 0.069 (0.049)	Data 0.003 (0.003)	Loss 0.0883 (0.1496)	Prec 96.875% (94.613%)
Epoch: [77][300/391]	Time 0.048 (0.048)	Data 0.002 (0.002)	Loss 0.1087 (0.1485)	Prec 96.094% (94.664%)
Validation starts
Test: [0/79]	Time 0.207 (0.207)	Loss 0.1948 (0.1948)	Prec 92.188% (92.188%)
 * Prec 87.340% 
best acc: 88.140000
Epoch: [78][0/391]	

Epoch: [91][200/391]	Time 0.039 (0.048)	Data 0.001 (0.003)	Loss 0.1479 (0.1407)	Prec 95.312% (95.009%)
Epoch: [91][300/391]	Time 0.039 (0.046)	Data 0.002 (0.003)	Loss 0.1394 (0.1401)	Prec 96.094% (95.014%)
Validation starts
Test: [0/79]	Time 0.206 (0.206)	Loss 0.3160 (0.3160)	Prec 89.844% (89.844%)
 * Prec 87.520% 
best acc: 88.140000
Epoch: [92][0/391]	Time 0.257 (0.257)	Data 0.201 (0.201)	Loss 0.1043 (0.1043)	Prec 96.094% (96.094%)
Epoch: [92][100/391]	Time 0.036 (0.046)	Data 0.002 (0.004)	Loss 0.1359 (0.1334)	Prec 96.094% (95.359%)
Epoch: [92][200/391]	Time 0.055 (0.047)	Data 0.002 (0.003)	Loss 0.2187 (0.1354)	Prec 92.969% (95.305%)
Epoch: [92][300/391]	Time 0.037 (0.047)	Data 0.002 (0.002)	Loss 0.1652 (0.1380)	Prec 95.312% (95.167%)
Validation starts
Test: [0/79]	Time 0.214 (0.214)	Loss 0.2472 (0.2472)	Prec 92.969% (92.969%)
 * Prec 87.540% 
best acc: 88.140000
Epoch: [93][0/391]	Time 0.267 (0.267)	Data 0.206 (0.206)	Loss 0.1043 (0.1043)	Prec 97.656% (97.656%)
Epoch: [93][100/391]	

Epoch: [106][300/391]	Time 0.036 (0.044)	Data 0.002 (0.002)	Loss 0.1357 (0.1314)	Prec 93.750% (95.328%)
Validation starts
Test: [0/79]	Time 0.198 (0.198)	Loss 0.2407 (0.2407)	Prec 91.406% (91.406%)
 * Prec 87.740% 
best acc: 88.190000
Epoch: [107][0/391]	Time 0.292 (0.292)	Data 0.230 (0.230)	Loss 0.1604 (0.1604)	Prec 95.312% (95.312%)
Epoch: [107][100/391]	Time 0.042 (0.048)	Data 0.002 (0.004)	Loss 0.1178 (0.1283)	Prec 95.312% (95.490%)
Epoch: [107][200/391]	Time 0.043 (0.047)	Data 0.002 (0.003)	Loss 0.0916 (0.1262)	Prec 96.875% (95.538%)
Epoch: [107][300/391]	Time 0.042 (0.047)	Data 0.002 (0.003)	Loss 0.0988 (0.1273)	Prec 95.312% (95.502%)
Validation starts
Test: [0/79]	Time 0.202 (0.202)	Loss 0.2917 (0.2917)	Prec 89.062% (89.062%)
 * Prec 87.370% 
best acc: 88.190000
Epoch: [108][0/391]	Time 0.259 (0.259)	Data 0.212 (0.212)	Loss 0.0931 (0.0931)	Prec 96.875% (96.875%)
Epoch: [108][100/391]	Time 0.041 (0.048)	Data 0.002 (0.004)	Loss 0.0758 (0.1203)	Prec 97.656% (95.506%)
Epoch: [108][2

Validation starts
Test: [0/79]	Time 0.207 (0.207)	Loss 0.2792 (0.2792)	Prec 91.406% (91.406%)
 * Prec 87.590% 
best acc: 88.190000
Epoch: [122][0/391]	Time 0.252 (0.252)	Data 0.202 (0.202)	Loss 0.0950 (0.0950)	Prec 97.656% (97.656%)
Epoch: [122][100/391]	Time 0.041 (0.051)	Data 0.002 (0.004)	Loss 0.1561 (0.1207)	Prec 93.750% (95.661%)
Epoch: [122][200/391]	Time 0.056 (0.048)	Data 0.002 (0.003)	Loss 0.1458 (0.1197)	Prec 95.312% (95.658%)
Epoch: [122][300/391]	Time 0.052 (0.047)	Data 0.001 (0.002)	Loss 0.1544 (0.1223)	Prec 93.750% (95.546%)
Validation starts
Test: [0/79]	Time 0.219 (0.219)	Loss 0.2937 (0.2937)	Prec 89.062% (89.062%)
 * Prec 87.330% 
best acc: 88.190000
Epoch: [123][0/391]	Time 0.288 (0.288)	Data 0.240 (0.240)	Loss 0.1193 (0.1193)	Prec 94.531% (94.531%)
Epoch: [123][100/391]	Time 0.047 (0.045)	Data 0.002 (0.004)	Loss 0.1238 (0.1185)	Prec 95.312% (95.730%)
Epoch: [123][200/391]	Time 0.039 (0.045)	Data 0.002 (0.003)	Loss 0.1358 (0.1196)	Prec 95.312% (95.756%)
Epoch: [123][3

Test: [0/79]	Time 0.212 (0.212)	Loss 0.1699 (0.1699)	Prec 93.750% (93.750%)
 * Prec 87.750% 
best acc: 88.190000
Epoch: [137][0/391]	Time 0.285 (0.285)	Data 0.219 (0.219)	Loss 0.0687 (0.0687)	Prec 97.656% (97.656%)
Epoch: [137][100/391]	Time 0.040 (0.049)	Data 0.001 (0.004)	Loss 0.0980 (0.1077)	Prec 96.875% (96.040%)
Epoch: [137][200/391]	Time 0.037 (0.046)	Data 0.001 (0.003)	Loss 0.0907 (0.1095)	Prec 97.656% (95.989%)
Epoch: [137][300/391]	Time 0.040 (0.046)	Data 0.001 (0.002)	Loss 0.0997 (0.1119)	Prec 96.094% (95.891%)
Validation starts
Test: [0/79]	Time 0.205 (0.205)	Loss 0.2471 (0.2471)	Prec 91.406% (91.406%)
 * Prec 87.440% 
best acc: 88.190000
Epoch: [138][0/391]	Time 0.281 (0.281)	Data 0.222 (0.222)	Loss 0.0918 (0.0918)	Prec 96.094% (96.094%)
Epoch: [138][100/391]	Time 0.043 (0.045)	Data 0.001 (0.004)	Loss 0.1435 (0.1126)	Prec 95.312% (95.831%)
Epoch: [138][200/391]	Time 0.039 (0.048)	Data 0.001 (0.003)	Loss 0.0927 (0.1137)	Prec 96.875% (95.837%)
Epoch: [138][300/391]	Time 0.038

 * Prec 88.090% 
best acc: 88.250000
Epoch: [152][0/391]	Time 0.265 (0.265)	Data 0.214 (0.214)	Loss 0.0736 (0.0736)	Prec 97.656% (97.656%)
Epoch: [152][100/391]	Time 0.041 (0.050)	Data 0.002 (0.004)	Loss 0.0664 (0.0821)	Prec 97.656% (97.231%)
Epoch: [152][200/391]	Time 0.042 (0.047)	Data 0.002 (0.003)	Loss 0.0587 (0.0839)	Prec 98.438% (97.073%)
Epoch: [152][300/391]	Time 0.046 (0.047)	Data 0.002 (0.002)	Loss 0.0462 (0.0847)	Prec 100.000% (97.051%)
Validation starts
Test: [0/79]	Time 0.202 (0.202)	Loss 0.1974 (0.1974)	Prec 91.406% (91.406%)
 * Prec 88.260% 
best acc: 88.260000
Epoch: [153][0/391]	Time 0.263 (0.263)	Data 0.209 (0.209)	Loss 0.0625 (0.0625)	Prec 97.656% (97.656%)
Epoch: [153][100/391]	Time 0.039 (0.051)	Data 0.002 (0.004)	Loss 0.0849 (0.0795)	Prec 96.094% (97.130%)
Epoch: [153][200/391]	Time 0.041 (0.048)	Data 0.002 (0.003)	Loss 0.1006 (0.0831)	Prec 96.875% (97.034%)
Epoch: [153][300/391]	Time 0.038 (0.047)	Data 0.002 (0.002)	Loss 0.0456 (0.0845)	Prec 99.219% (97.005%)
Val

Epoch: [167][0/391]	Time 0.277 (0.277)	Data 0.224 (0.224)	Loss 0.0833 (0.0833)	Prec 97.656% (97.656%)
Epoch: [167][100/391]	Time 0.042 (0.047)	Data 0.002 (0.004)	Loss 0.0764 (0.0727)	Prec 96.875% (97.625%)
Epoch: [167][200/391]	Time 0.041 (0.046)	Data 0.002 (0.003)	Loss 0.1282 (0.0721)	Prec 95.312% (97.606%)
Epoch: [167][300/391]	Time 0.039 (0.046)	Data 0.001 (0.002)	Loss 0.0647 (0.0722)	Prec 97.656% (97.610%)
Validation starts
Test: [0/79]	Time 0.215 (0.215)	Loss 0.2186 (0.2186)	Prec 92.969% (92.969%)
 * Prec 88.390% 
best acc: 88.570000
Epoch: [168][0/391]	Time 0.261 (0.261)	Data 0.216 (0.216)	Loss 0.0925 (0.0925)	Prec 96.875% (96.875%)
Epoch: [168][100/391]	Time 0.037 (0.048)	Data 0.002 (0.004)	Loss 0.0705 (0.0776)	Prec 97.656% (97.208%)
Epoch: [168][200/391]	Time 0.044 (0.052)	Data 0.001 (0.003)	Loss 0.0879 (0.0777)	Prec 96.094% (97.248%)
Epoch: [168][300/391]	Time 0.056 (0.050)	Data 0.002 (0.003)	Loss 0.1492 (0.0772)	Prec 95.312% (97.251%)
Validation starts
Test: [0/79]	Time 0.209

Epoch: [182][100/391]	Time 0.038 (0.046)	Data 0.001 (0.004)	Loss 0.0409 (0.0732)	Prec 99.219% (97.486%)
Epoch: [182][200/391]	Time 0.039 (0.045)	Data 0.002 (0.003)	Loss 0.0614 (0.0726)	Prec 97.656% (97.501%)
Epoch: [182][300/391]	Time 0.060 (0.046)	Data 0.002 (0.002)	Loss 0.1196 (0.0724)	Prec 96.094% (97.508%)
Validation starts
Test: [0/79]	Time 0.196 (0.196)	Loss 0.2453 (0.2453)	Prec 93.750% (93.750%)
 * Prec 88.450% 
best acc: 88.570000
Epoch: [183][0/391]	Time 0.279 (0.279)	Data 0.224 (0.224)	Loss 0.0750 (0.0750)	Prec 97.656% (97.656%)
Epoch: [183][100/391]	Time 0.045 (0.043)	Data 0.001 (0.004)	Loss 0.0692 (0.0726)	Prec 97.656% (97.679%)
Epoch: [183][200/391]	Time 0.049 (0.046)	Data 0.001 (0.003)	Loss 0.0874 (0.0729)	Prec 97.656% (97.625%)
Epoch: [183][300/391]	Time 0.037 (0.047)	Data 0.002 (0.002)	Loss 0.0597 (0.0742)	Prec 98.438% (97.552%)
Validation starts
Test: [0/79]	Time 0.207 (0.207)	Loss 0.2042 (0.2042)	Prec 93.750% (93.750%)
 * Prec 88.050% 
best acc: 88.570000
Epoch: [184]

Epoch: [197][200/391]	Time 0.041 (0.052)	Data 0.002 (0.003)	Loss 0.0450 (0.0727)	Prec 97.656% (97.547%)
Epoch: [197][300/391]	Time 0.045 (0.049)	Data 0.002 (0.003)	Loss 0.0283 (0.0720)	Prec 100.000% (97.555%)
Validation starts
Test: [0/79]	Time 0.198 (0.198)	Loss 0.2182 (0.2182)	Prec 91.406% (91.406%)
 * Prec 88.110% 
best acc: 88.590000
Epoch: [198][0/391]	Time 0.251 (0.251)	Data 0.199 (0.199)	Loss 0.0641 (0.0641)	Prec 99.219% (99.219%)
Epoch: [198][100/391]	Time 0.043 (0.048)	Data 0.002 (0.004)	Loss 0.0728 (0.0730)	Prec 98.438% (97.471%)
Epoch: [198][200/391]	Time 0.038 (0.045)	Data 0.001 (0.003)	Loss 0.0786 (0.0720)	Prec 96.094% (97.547%)
Epoch: [198][300/391]	Time 0.052 (0.049)	Data 0.002 (0.002)	Loss 0.0338 (0.0717)	Prec 100.000% (97.599%)
Validation starts
Test: [0/79]	Time 0.216 (0.216)	Loss 0.2256 (0.2256)	Prec 92.969% (92.969%)
 * Prec 88.400% 
best acc: 88.590000
Epoch: [199][0/391]	Time 0.277 (0.277)	Data 0.216 (0.216)	Loss 0.0945 (0.0945)	Prec 96.094% (96.094%)
Epoch: [199]

In [5]:
# HW

#  1. Train with 4 bits for both weight and activation to achieve >90% accuracy
#  2. Find x_int and w_int for the 2nd convolution layer
#  3. Check the recovered psum has similar value to the un-quantized original psum
#     (such as example 1 in W3S2)

In [2]:
PATH = "result/resnet20_cifar_project/model_best.pth.tar"
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['state_dict'])
device = torch.device("cuda") 

model.cuda()
model.eval()

test_loss = 0
correct = 0

with torch.no_grad():
    for data, target in testloader:
        data, target = data.to(device), target.to(device) # loading to GPU
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)  
        correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(testloader.dataset)

print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(testloader.dataset),
        100. * correct / len(testloader.dataset)))


Test set: Accuracy: 8861/10000 (89%)



In [3]:
#send an input and grap the value by using prehook like HW3
class SaveOutput:
    def __init__(self):
        self.outputs = []
    def __call__(self, module, module_in):
        self.outputs.append(module_in)
    def clear(self):
        self.outputs = []  
        
######### Save inputs from selected layer ##########
save_output = SaveOutput()

for layer in model.modules():
    if isinstance(layer, torch.nn.Conv2d):
        print("prehooked")
        layer.register_forward_pre_hook(save_output)       ## Input for the module will be grapped       
####################################################

dataiter = iter(trainloader)
images, labels = dataiter.next()
images = images.to(device)
out = model(images)
print("1st convolution's input size:", save_output.outputs[0][0].size())
print("2nd convolution's input size:", save_output.outputs[1][0].size())
test = model
print(model)

prehooked
prehooked
prehooked
prehooked
prehooked
prehooked
prehooked
prehooked
prehooked
prehooked
prehooked
prehooked
prehooked
prehooked
prehooked
prehooked
prehooked
prehooked
prehooked
prehooked
prehooked
1st convolution's input size: torch.Size([128, 3, 32, 32])
2nd convolution's input size: torch.Size([128, 8, 32, 32])
ResNet_Cifar(
  (conv1): QuantConv2d(
    3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
    (weight_quant): weight_quantize_fn()
  )
  (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): QuantConv2d(
        8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (weight_quant): weight_quantize_fn()
      )
      (conv2): QuantConv2d(
        8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (weight_quant): weight_quantize_fn()
      )
      (bn1): BatchNorm2d(8, eps=1e-05, mo

In [64]:
w_bit = 4
w_alpha = 4
weight_q = model.layer1[0].conv2.weight_q # quantized value is stored during the training
w_delta      = w_alpha/(2**(w_bit-1)-1)
w_int        = weight_q/w_delta
#print(w_int) # you should see clean integer numbers

In [94]:
x_bit = 4
x = save_output.outputs[1][0]  # input of the 2nd conv layer
x_alpha  = 4
x_delta = x_alpha/(2**x_bit-1)

act_quant_fn = act_quantization(x_bit) # define the quantization function
x_q = act_quant_fn(x, x_alpha)         # create the quantized value for x

x_int = x_q/x_delta
#print(x_int) # you should see clean integer numbers 
print(w_int.size())

torch.Size([8, 8, 3, 3])


In [66]:
conv_int = torch.nn.Conv2d(in_channels = 64, out_channels=64, kernel_size = 3, bias = False)
conv_int.weight = torch.nn.parameter.Parameter(w_int)

output_int =  conv_int(x_int)
output_recovered = output_int*w_delta*x_delta
#print(output_recovered)

In [87]:
#### input floating number / weight quantized version
print(model.layer1[0])
conv_ref = torch.nn.Conv2d(in_channels = 64, out_channels=64, kernel_size = 3, bias = False)
conv_ref.weight = model.layer1[0].conv2.weight_q

output_recovered = model.layer1[0].bn1(output_recovered)
output_recovered = model.layer1[0].relu(output_recovered)
#print(output_ref)

BasicBlock(
  (conv1): QuantConv2d(
    8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
    (weight_quant): weight_quantize_fn()
  )
  (conv2): QuantConv2d(
    8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
    (weight_quant): weight_quantize_fn()
  )
  (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
)


In [88]:
output_ref = save_output.outputs[2][0]

In [89]:
output_recovered.size()

torch.Size([128, 8, 30, 30])

In [90]:
output_ref.size()

torch.Size([128, 8, 32, 32])

In [91]:
difference = abs( output_ref - output_recovered )
print(difference.mean())  ## It should be small, e.g.,2.3 in my trainned model

RuntimeError: The size of tensor a (32) must match the size of tensor b (30) at non-singleton dimension 3