In [1]:
import torch

import Ipynb_importer
from a_basic_quant import *
from b_model import *
from c_train_and_test import *

importing Jupyter notebook from a_basic_quant.ipynb
importing Jupyter notebook from b_model.ipynb
importing Jupyter notebook from c_train_and_test.ipynb


In [2]:
def full_inference(model, test_loader):
    """Evaluate the full-precision model and print its top-1 accuracy.

    Args:
        model: Callable mapping a batch of inputs to per-sample class scores
            (classes along dim 1, since predictions are taken with
            ``argmax(dim=1)``).
        test_loader: Iterable of ``(datas, targets)`` batches that also
            exposes a ``dataset`` attribute; ``len(test_loader.dataset)`` is
            used as the total sample count.

    Returns:
        None. The accuracy is printed as a whole-number percentage.
    """
    correct = 0
    # Pure evaluation: disable autograd so no graph is built per batch.
    # This saves memory and time without changing the predictions.
    with torch.no_grad():
        for datas, targets in test_loader:
            output = model(datas)
            pred = output.argmax(dim=1, keepdim=True)
            # Reshape targets to match pred before the element-wise compare.
            correct += pred.eq(targets.view_as(pred)).sum().item()
    print('\nTest set: Full Model Accuracy: {:.0f}%\n'.format(100. * correct / len(test_loader.dataset)))


In [3]:
def direct_quantize(model, test_loader, max_batches=500):
    """Run calibration batches through the model, then freeze it.

    Each batch is passed to ``model.quantize_forward``; once the batches are
    exhausted or the cap is reached, ``model.freeze()`` is called.

    Args:
        model: Object providing ``quantize_forward(datas)`` and ``freeze()``.
        test_loader: Iterable of ``(datas, targets)`` batches used for
            calibration. Targets are ignored. (The name is kept for
            compatibility; callers in this notebook pass the training
            loader.)
        max_batches: Maximum number of batches to feed through
            ``quantize_forward``. Defaults to 500, the previously hard-coded
            limit. Pass ``None`` to consume the whole loader.

    Raises:
        ValueError: If ``max_batches`` is not ``None`` and is less than 1.
    """
    if max_batches is not None and max_batches < 1:
        raise ValueError('max_batches must be a positive integer or None')

    # NOTE(review): calibration is assumed not to need gradients, so autograd
    # is disabled here -- confirm quantize_forward has no backward dependency.
    with torch.no_grad():
        for idx, (datas, _targets) in enumerate(test_loader, 1):
            model.quantize_forward(datas)
            if max_batches is not None and idx >= max_batches:
                break
    model.freeze()
    print('direct quantization finish')

In [4]:
def quantize_inference(model, test_loader):
    """Evaluate the quantized model and print its top-1 accuracy.

    Args:
        model: Object providing ``quantize_inference(datas)``, which returns
            per-sample class scores (classes along dim 1, since predictions
            are taken with ``argmax(dim=1)``).
        test_loader: Iterable of ``(datas, targets)`` batches that also
            exposes a ``dataset`` attribute; ``len(test_loader.dataset)`` is
            used as the total sample count.

    Returns:
        None. The accuracy is printed as a whole-number percentage.
    """
    correct = 0
    # Pure evaluation: disable autograd so no graph is built per batch.
    with torch.no_grad():
        for datas, targets in test_loader:
            output = model.quantize_inference(datas)
            pred = output.argmax(dim=1, keepdim=True)
            # Reshape targets to match pred before the element-wise compare.
            correct += pred.eq(targets.view_as(pred)).sum().item()
    print('\nTest set: Quant Model Accuracy: {:.0f}%\n'.format(100. * correct / len(test_loader.dataset)))


In [8]:
# Demo configuration.
batch_size = 64
test_batch_size = 64
using_bn = False

train_loader, test_loader = dataset_loader(batch_size, test_batch_size)

# The model is never moved to another device in this cell, so map the
# checkpoint tensors to CPU on load. This lets a checkpoint that was saved
# from a GPU run be restored on a CPU-only machine.
if using_bn:
    model = NetBN()
    model.load_state_dict(torch.load('ckpt/mnist_cnnbn.pt', map_location='cpu'))
else:
    model = Net()
    model.load_state_dict(torch.load('ckpt/mnist_cnn.pt', map_location='cpu'))

# Baseline: accuracy of the full-precision model on the test set.
model.eval()
full_inference(model, test_loader)

num_bits = 8
print('Quantization bit: %d' % num_bits)
model.quantize_init(num_bits=num_bits)

# NOTE(review): eval() is called again after quantize_init -- presumably
# because quantize_init may register new submodules; confirm in b_model.
model.eval()
# Calibrate on training batches, then score the quantized model on test data.
direct_quantize(model, train_loader)
quantize_inference(model, test_loader)


Test set: Full Model Accuracy: 98%

Quantization bit: 8
direct quantization finish

Test set: Quant Model Accuracy: 98%



In [5]:
def main():
    """Run the post-training quantization demo end to end.

    Loads the MNIST checkpoint selected by ``using_bn``, prints the accuracy
    of the full-precision model, initialises quantization at ``num_bits``,
    calibrates on the training loader, and prints the accuracy of the
    quantized model.
    """
    batch_size = 64
    test_batch_size = 64
    using_bn = False

    train_loader, test_loader = dataset_loader(batch_size, test_batch_size)

    # The model is never moved to another device here, so map the checkpoint
    # tensors to CPU on load. This lets a checkpoint that was saved from a
    # GPU run be restored on a CPU-only machine.
    if using_bn:
        model = NetBN()
        model.load_state_dict(torch.load('ckpt/mnist_cnnbn.pt', map_location='cpu'))
    else:
        model = Net()
        model.load_state_dict(torch.load('ckpt/mnist_cnn.pt', map_location='cpu'))

    # Baseline: accuracy of the full-precision model on the test set.
    model.eval()
    full_inference(model, test_loader)

    num_bits = 8
    print('Quantization bit: %d' % num_bits)
    model.quantize_init(num_bits=num_bits)

    # NOTE(review): eval() is called again after quantize_init -- presumably
    # because quantize_init may register new submodules; confirm in b_model.
    model.eval()
    # Calibrate on training batches, then score the quantized model on test data.
    direct_quantize(model, train_loader)
    quantize_inference(model, test_loader)
    
#     from torchsummary import summary 
#     summary(model.to('cuda'), (1,28,28))

In [6]:
# Run the demo only when this code executes as the top-level module,
# not when it is imported from elsewhere.
if __name__ == "__main__":
    main()


Test set: Full Model Accuracy: 98%

Quantization bit: 8
direct quantization finish

Test set: Quant Model Accuracy: 98%

