In [1]:
import Ipynb_importer
from a_basic_quant import *
from b_model import *
from c_train_and_test import *

importing Jupyter notebook from a_basic_quant.ipynb
importing Jupyter notebook from b_model.ipynb
importing Jupyter notebook from c_train_and_test.ipynb


In [2]:
import time

# Target device shared by every function below: the GPU when CUDA is available,
# otherwise the CPU.
# NOTE(review): `torch` is not imported explicitly in the visible cells; it is
# presumably brought into scope by one of the star imports above - confirm.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Uncomment the next line to force CPU execution.
# device = 'cpu'

In [3]:
def full_inference(model, test_loader):
    """Evaluate the full-precision model and print its wall time and accuracy.

    Args:
        model: Callable mapping a batch of inputs to per-class scores; the
            predicted class is the ``argmax`` over dimension 1.
        test_loader: Iterable yielding ``(datas, targets)`` tensor batches.

    Returns:
        float: Top-1 accuracy in percent, computed over the samples that were
        actually evaluated (``0.0`` when the loader yields nothing).
    """
    # torch.cuda.synchronize() raises on machines without CUDA, so it is only
    # called when a GPU is present; this keeps the CPU fallback of `device`
    # usable.
    use_cuda = torch.cuda.is_available()

    correct = 0
    total = 0

    if use_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()

    # Pure inference: no autograd bookkeeping is needed.
    with torch.no_grad():
        for datas, targets in test_loader:
            datas, targets = datas.to(device), targets.to(device)
            output = model(datas)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(targets.view_as(pred)).sum().item()
            # Count what was really evaluated instead of trusting
            # len(test_loader.dataset), which is wrong when the loader does
            # not cover the whole dataset and divides by zero when it is empty.
            total += pred.size(0)

    if use_cuda:
        torch.cuda.synchronize()
    end = time.perf_counter()

    accuracy = 100. * correct / total if total else 0.0
    print("full inference time: ", end - start)
    print('\nTest set: Full Model Accuracy: {:.0f}%\n'.format(accuracy))
    return accuracy


In [4]:
def direct_quantize(model, test_loader, max_batches=500):
    """Run calibration batches through the quantized forward pass, then freeze.

    Args:
        model: Object exposing ``quantize_forward(datas)`` and ``freeze()``.
            NOTE(review): ``quantize_forward`` presumably collects the
            statistics that ``freeze`` turns into fixed quantization
            parameters - confirm against the model definition.
        test_loader: Iterable yielding ``(datas, targets)`` batches; only the
            inputs are used.
        max_batches: Upper bound on the number of batches fed to
            ``quantize_forward``. Defaults to 500, the previously hard-coded
            limit. ``None`` consumes the whole loader.

    Raises:
        ValueError: If ``max_batches`` is not ``None`` and is smaller than 1.
    """
    if max_batches is not None and max_batches < 1:
        raise ValueError('max_batches must be None or >= 1, got %r' % (max_batches,))

    # The outputs are discarded, so there is no reason to record autograd state.
    with torch.no_grad():
        for idx, (datas, _) in enumerate(test_loader, 1):
            model.quantize_forward(datas.to(device))
            if max_batches is not None and idx >= max_batches:
                break

    model.freeze()
    print('direct quantization finish')

In [5]:
def quantize_inference(model, test_loader):
    """Evaluate the quantized model and print its wall time and accuracy.

    Args:
        model: Object exposing ``quantize_inference(datas)`` that returns
            per-class scores; the predicted class is the ``argmax`` over
            dimension 1.
        test_loader: Iterable yielding ``(datas, targets)`` tensor batches.

    Returns:
        float: Top-1 accuracy in percent, computed over the samples that were
        actually evaluated (``0.0`` when the loader yields nothing).
    """
    # torch.cuda.synchronize() raises on machines without CUDA, so it is only
    # called when a GPU is present; this keeps the CPU fallback of `device`
    # usable.
    use_cuda = torch.cuda.is_available()

    correct = 0
    total = 0

    if use_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()

    # Pure inference: no autograd bookkeeping is needed.
    with torch.no_grad():
        for datas, targets in test_loader:
            datas, targets = datas.to(device), targets.to(device)
            output = model.quantize_inference(datas)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(targets.view_as(pred)).sum().item()
            # Count what was really evaluated instead of trusting
            # len(test_loader.dataset), which is wrong when the loader does
            # not cover the whole dataset and divides by zero when it is empty.
            total += pred.size(0)

    if use_cuda:
        torch.cuda.synchronize()
    end = time.perf_counter()

    accuracy = 100. * correct / total if total else 0.0
    print("quantize inference time: ", end - start)
    print('\nTest set: Quant Model Accuracy: {:.0f}%\n'.format(accuracy))
    return accuracy


In [6]:
def main(batch_size=64, test_batch_size=64, using_bn=True, num_bits=8):
    """Compare full-precision and post-training-quantized inference.

    Loads a pretrained checkpoint, evaluates it with ``full_inference``,
    initialises quantization, calibrates on the training loader with
    ``direct_quantize`` and finally evaluates with ``quantize_inference``.

    Args:
        batch_size: First argument passed to ``dataset_loader``.
        test_batch_size: Second argument passed to ``dataset_loader``.
        using_bn: When true, build ``NetBN`` and load
            ``ckpt/mnist_cnnbn.pt``; otherwise build ``Net`` and load
            ``ckpt/mnist_cnn.pt``.
        num_bits: Bit width passed to ``model.quantize_init``.

    NOTE(review): the run recorded with this notebook printed 99% accuracy for
    the full model but only 51% for the quantized one - worth investigating
    before relying on the quantized path.
    """
    train_loader, test_loader = dataset_loader(batch_size, test_batch_size)

    if using_bn:
        model_cls, ckpt_path = NetBN, 'ckpt/mnist_cnnbn.pt'
    else:
        model_cls, ckpt_path = Net, 'ckpt/mnist_cnn.pt'

    model = model_cls().to(device)
    # map_location lets a checkpoint that was saved on a GPU load on a
    # CPU-only machine, matching the CPU fallback of `device`.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))

    model.eval()
    # NOTE(review): the evaluation runs twice; the first pass presumably serves
    # as a warm-up so the second timing is representative - confirm.
    full_inference(model, test_loader)
    full_inference(model, test_loader)

    print('Quantization bit: %d' % num_bits)
    model.quantize_init(num_bits=num_bits)

    # NOTE(review): eval() is applied again after quantize_init, presumably
    # because modules created there start in training mode - confirm.
    model.eval()
    # Calibration uses the training loader; evaluation uses the test loader.
    direct_quantize(model, train_loader)
    quantize_inference(model, test_loader)
    
#     from torchsummary import summary 
#     summary(model.to('cuda'), (1,28,28))

In [7]:
# Entry point: run the experiment only when this code executes as the main
# module, not when it is imported from elsewhere.
if __name__ == "__main__":
    main()

full inference time:  1.1617331504821777

Test set: Full Model Accuracy: 99%

full inference time:  1.1561315059661865

Test set: Full Model Accuracy: 99%

Quantization bit: 8
direct quantization finish
quantize inference time:  1.186352014541626

Test set: Quant Model Accuracy: 51%

