In [1]:
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms, utils
from torch.utils.data import TensorDataset, DataLoader
import torch.backends.cudnn as cudnn
import time
from pathlib import Path
import os

from art.estimators.classification import PyTorchClassifier
from art.utils import load_mnist

from quant_mnist_model import *
from QuantModules import *
from _quantUtils import *
from _utils import train, test


%matplotlib inline
%config InlineBackend.figure_format='retina'
# Probe CUDA once and derive both flags from the same answer, so `use_cuda`
# can never claim a GPU that `device` does not actually point at (the flag
# used to be hard-coded to True, which is wrong on CPU-only machines).
use_cuda = torch.cuda.is_available()
# First visible GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda:0" if use_cuda else "cpu")

In [2]:
# Load MNIST through ART's helper; it also returns the pixel value range
# (min_, max_) alongside the train/test splits.
(x_train, y_train), (x_test, y_test), min_, max_ = load_mnist()

# ART's load_mnist is documented to return channels-last images (N, H, W, C),
# while PyTorch conv layers expect channels-first (N, C, H, W).
# np.transpose with axes (0, 3, 1, 2) moves the channel axis to position 1
# while keeping height before width. The previous np.swapaxes(x, 1, 3)
# produced (N, C, W, H), i.e. every digit was spatially transposed.
# NOTE(review): checkpoints trained on the old W/H-swapped layout must be
# retrained to be used with this orientation.
x_train = np.transpose(x_train, (0, 3, 1, 2)).astype(np.float32)
x_test = np.transpose(x_test, (0, 3, 1, 2)).astype(np.float32)

train_dataset = TensorDataset(torch.Tensor(x_train), torch.Tensor(y_train))
# Reshuffle the training samples every epoch so mini-batches are not served
# in the same fixed order each time.
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)

test_dataset = TensorDataset(torch.Tensor(x_test), torch.Tensor(y_test))
# Evaluation order does not matter, so the test loader stays unshuffled.
test_dataloader = DataLoader(test_dataset, batch_size=1000)

In [8]:
# Train a freshly built classifier for a fixed number of epochs, evaluating
# after each one and writing the weights to disk whenever the test score
# improves on the best seen so far.
NUM_EPOCHS = 5
CHECKPOINT_PATH = 'mnist_4bit.pth'

model = classifier().to(device)

# NOTE(review): `criterion` is not passed to train()/test() in this cell;
# presumably those helpers compute their own loss — confirm before relying
# on this object.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

best_prec = 0

for e in range(NUM_EPOCHS):
    # Epochs are reported 1-based.
    print('\nEpoch: %d' % (e + 1))
    train(model, optimizer, train_dataloader)
    prec = test(model, test_dataloader)

    # Strict comparison: a tie with the current best does not overwrite
    # the existing checkpoint.
    is_best = prec > best_prec
    if is_best:
        torch.save(model.state_dict(), CHECKPOINT_PATH)
    best_prec = max(prec, best_prec)


Epoch: 1
Test accuracy:  97.19

Epoch: 2
Test accuracy:  98.06

Epoch: 3
Test accuracy:  98.46000000000001

Epoch: 4
Test accuracy:  98.64

Epoch: 5
Test accuracy:  98.7


In [4]:
# Bare expression: the notebook renders the model's repr as this cell's
# output, which makes its layer structure visible for inspection.
model

Classifier(
  (conv1): first_conv(1, 10, kernel_size=(5, 5), stride=(1, 1), bias=False)
  (conv2): QuantConv2d(
    10, 20, kernel_size=(5, 5), stride=(1, 1), bias=False
    (weight_quant): weight_quantize_fn()
  )
  (conv2_drop): Dropout2d(p=0.5, inplace=False)
  (fc1): last_fc(in_features=320, out_features=512, bias=True)
  (fc2): last_fc(in_features=512, out_features=200, bias=True)
  (fc3): last_fc(in_features=200, out_features=10, bias=True)
)