In [None]:
import time
import copy
import sys
import os
from collections import OrderedDict

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt


from models import ConvNet
from utils import parse_args, make_criterion


# training function
def train(net, trainloader, criterion, optimizer, epoch, lmbda, train_loss_tracker, train_acc_tracker):
    """Run one training epoch, recording per-batch loss and the epoch accuracy.

    Args:
        net: model to optimise; it is switched to training mode.
        trainloader: iterable yielding ``(inputs, targets)`` batches.
        criterion: callable invoked as ``criterion(outputs, targets, net)``
            that returns a scalar loss tensor.
        optimizer: optimizer; ``zero_grad`` and ``step`` are called once per
            batch.
        epoch: epoch index, used only in the progress line.
        lmbda: not used by this function; kept so existing callers still work.
        train_loss_tracker: list that receives one loss value per batch.
        train_acc_tracker: list that receives the running accuracy (percent)
            after the last batch, once per call. ``0.0`` is recorded when the
            loader yields no batches.

    Reads the module-level ``device`` to place each batch.
    """
    net.train()
    running_loss = 0.0
    correct = 0
    total = 0
    # Default so that an empty loader records an entry instead of raising
    # UnboundLocalError after the loop.
    acc = 0.0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        # NOTE(review): gradients w.r.t. the inputs are requested here;
        # presumably the criterion (or later analysis) relies on it - confirm
        # before removing.
        inputs.requires_grad = True
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets, net)
        loss.backward()
        # update optimizer state
        optimizer.step()

        # Read the scalar once: each .item() call synchronises with the device.
        batch_loss = loss.item()
        running_loss += batch_loss
        train_loss_tracker.append(batch_loss)
        avg_loss = running_loss / (batch_idx + 1)

        # compute running accuracy over all samples seen so far
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        acc = 100. * correct / total

        # Print status (carriage return keeps the progress on a single line)
        sys.stdout.write(f'\rEpoch {epoch}: Train Loss: {avg_loss:.3f}' +
                         f'| Train Acc: {acc:.3f}')
        sys.stdout.flush()
    train_acc_tracker.append(acc)
    sys.stdout.flush()

# testing function 
def test(net, testloader, criterion, epoch, lmbda, test_loss_tracker, test_acc_tracker):
    """Evaluate ``net`` on ``testloader`` and checkpoint it when it is the best so far.

    Args:
        net: model to evaluate; it is switched to eval mode.
        testloader: iterable yielding ``(inputs, targets)`` batches.
        criterion: callable invoked as ``criterion(outputs, targets, net)``
            that returns a scalar loss tensor.
        epoch: epoch index stored in the checkpoint.
        lmbda: not used by this function; kept so existing callers still work.
        test_loss_tracker: list that receives one loss value per batch.
        test_acc_tracker: list that receives the overall accuracy (percent),
            once per call. ``0.0`` is recorded when no samples were seen.

    Side effects:
        Maintains the module-level ``best_acc`` across calls and writes
        ``./checkpoint/ckpt.pth`` whenever the accuracy exceeds it.
        Reads the module-level ``device`` to place each batch.
    """
    global best_acc
    # Keep the best accuracy from earlier calls; start from 0 only the first
    # time. Resetting it on every call would make every evaluation with
    # acc > 0 overwrite the checkpoint.
    best_acc = globals().get('best_acc', 0)
    net.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    avg_loss = 0.0

    # NOTE(review): the forward pass runs with autograd enabled. Wrapping the
    # loop in torch.no_grad() would save memory, but is only safe if the
    # criterion does not need gradients - confirm before changing.
    for batch_idx, (inputs, targets) in enumerate(testloader):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = net(inputs)
        loss = criterion(outputs, targets, net)

        # Read the scalar once: each .item() call synchronises with the device.
        batch_loss = loss.item()
        running_loss += batch_loss
        test_loss_tracker.append(batch_loss)
        avg_loss = running_loss / (batch_idx + 1)

        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    # Guard against an empty loader (no samples -> no division by zero).
    acc = 100. * correct / total if total else 0.0
    sys.stdout.write(f' | Test Loss: {avg_loss:.3f} | Test Acc: {acc:.3f}\n')
    sys.stdout.flush()

    # Save checkpoint.
    test_acc_tracker.append(acc)
    if acc > best_acc:
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc


def quantizer(input, nbit):
    '''
    Snap a tensor onto the uniform grid {0, 1/L, ..., 1} with L = 2**nbit - 1.

    input: full precision tensor in the range [0, 1]
    nbit: number of bits of the quantized representation
    return: tensor of the same shape whose entries are multiples of 1/L
    '''
    levels = 2**nbit - 1
    # scale up to integer steps, round to the nearest step, scale back down
    return torch.round(input * levels) / levels


def dorefa_g(w, nbit, adaptive_scale=None):
    '''
    Stochastically quantize ``w`` to ``nbit`` bits and map it back to its
    original range (quantize, then dequantize, with noise added in between).

    w: a floating-point weight tensor to quantize
    nbit: the number of bits in the quantized representation
    adaptive_scale: the maximum scale value. if None, it is set to be the
                    absolute maximum value in w.
    return: tuple ``(w_q, adaptive_scale)`` - the quantized tensor and the
            scale that was used.
    '''
    if adaptive_scale is None:
        adaptive_scale = torch.max(torch.abs(w))

    # A zero scale (e.g. an all-zero tensor) would divide by zero below and
    # fill the result with NaNs. With scale 0 the quantization grid collapses
    # to the single value 0, so return zeros directly.
    if bool(torch.all(torch.as_tensor(adaptive_scale) == 0)):
        return torch.zeros_like(w), adaptive_scale

    # Uniform noise in [-0.5, 0.5) quantization steps, generated on w's device.
    noise_tensor = (torch.rand(w.shape, device=w.device) - 0.5) / (2**nbit - 1)
    # Shift and scale w from [-scale, scale] into [0, 1], add the noise, and
    # snap the result onto the nbit grid.
    intermediate = quantizer(noise_tensor + 0.5 + w / (2*adaptive_scale), nbit)
    # Undo the shift and scale to return to the original range.
    w_q = 2 * adaptive_scale * (intermediate - 0.5)

    return w_q, adaptive_scale


def quantize_model(model, nbit):
    '''
    Quantize, in place, the weights and biases of every Conv2d / Linear layer
    of ``model`` (used for the ConvNet model).

    model: module whose sub-modules are walked with ``model.modules()``
    nbit: the number of bits passed to ``dorefa_g``

    Each quantized layer also gets an ``adaptive_scale`` attribute holding the
    scale returned for its weights; the bias, when present, is quantized with
    that same scale.
    '''
    # The parameters are overwritten through ``.data``, so no autograd graph
    # is needed. Without no_grad the stored ``adaptive_scale`` would carry a
    # grad_fn that references the layer's weights.
    with torch.no_grad():
        for m in model.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                m.weight.data, m.adaptive_scale = dorefa_g(m.weight, nbit)
                if m.bias is not None:
                    m.bias.data, _ = dorefa_g(m.bias, nbit, m.adaptive_scale)


# Per-channel mean / std handed to Normalize for both splits.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)

# Training split: random crop + horizontal flip augmentation, then tensor
# conversion and normalisation; batches are shuffled.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

# Test split: no augmentation, only tensor conversion and normalisation;
# batch order is fixed.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=128,
    shuffle=False,
)


# python script
if __name__ == "__main__":
    # Run configuration comes from utils.parse_args.
    args = parse_args()

    # Seed the numpy and torch RNGs with the same value.
    SEED = args.seed
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # Unpack the settings used below. `device` is also read as a global by
    # train() and test().
    device, epochs, lmbda, lr = args.device, args.n_epochs, args.lmbda, args.lr

    # Model, optimizer and loss.
    net = ConvNet().to(device)
    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=lr,
        momentum=0.9,
        weight_decay=5e-4,
    )
    criterion = make_criterion(args)
    # No learning-rate schedule is configured; the loop below steps one if set.
    scheduler = None

    # Histories filled in by train() / test().
    train_loss_tracker = []
    train_acc_tracker = []
    test_loss_tracker = []
    test_acc_tracker = []

    start_time = time.time()
    for epoch in range(epochs):
        epoch_start_time = time.time()

        # One pass over the training data, then an evaluation pass.
        train(net, trainloader, criterion, optimizer, epoch, lmbda,
              train_loss_tracker, train_acc_tracker)
        test(net, testloader, criterion, epoch, lmbda,
             test_loss_tracker, test_acc_tracker)

        if scheduler:
            scheduler.step()

        epoch_end_time = time.time()
        epoch_total_time = epoch_end_time - epoch_start_time
        print(f"Epoch runtime: {epoch_total_time}")

    total_time = time.time() - start_time
    print('Total runtime: {} seconds'.format(total_time))

In [None]:
# Post-training quantization check: overwrite the trained network's
# Conv2d / Linear weights and biases in place with their n_bits-bit
# quantized values, then re-evaluate on the test set.
n_bits = 8
quantize_model(net, n_bits)
# `epoch` is the loop variable left over from the training loop above.
# This call appends to the same test trackers used during training, and
# test() may also write ./checkpoint/ckpt.pth, so the saved checkpoint can
# end up holding the quantized weights.
test(net, testloader, criterion, epoch, lmbda, test_loss_tracker, test_acc_tracker)