In [1]:
import onnx
import torch
import torchvision
import torchvision.transforms as transforms
from torch.nn import Module
from QuantLenetV2 import *
from train_utils import *

In [2]:
DATASET_ROOT = "/workspace/finn/src/data/fashion"
BATCH_SIZE = 50

# percentage of training data
VAL_RATIO = 0.1

transform = transforms.Compose(
    [transforms.ToTensor()])

In [3]:
train_data = torchvision.datasets.FashionMNIST(DATASET_ROOT, download=True, train=True, transform=transform)
val_data, train_data = torch.utils.data.random_split(train_data, [50000, 10000])
test_data = torchvision.datasets.FashionMNIST(DATASET_ROOT, train=False, transform=transform)

trainloader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valloader = torch.utils.data.DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

classes = ('t-shirt/top', 'trouser', 'pullover', 'dress', \
           'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot')

# ONLY MODIFY CELL BELOW

In [4]:
INPUT_WIDTH = 8
WEIGHT_WIDTH = 1
ACT_WIDTH = 2

MAX_EPOCHS = 100
VERBOSE = True

# ONLY MODIFY CELL ABOVE

In [5]:
qnet = cnv(in_bit_width=INPUT_WIDTH, weight_bit_width=WEIGHT_WIDTH, act_bit_width=ACT_WIDTH, num_classes=10, in_channels=1)
path = f"./models/model_i{INPUT_WIDTH}_w{WEIGHT_WIDTH}_a{ACT_WIDTH}.pth"

In [6]:
best_model_weights = trainModel(qnet, MAX_EPOCHS, trainloader, valloader, path, verbose=VERBOSE)

[1,    50] loss: 2.77179463
[1,   100] loss: 2.67483214
[1,   150] loss: 2.60613252
[1,   200] loss: 2.49534774
[1,   250] loss: 2.44751607
[1,   300] loss: 2.37845330
[1,   350] loss: 2.30682276
[1,   400] loss: 2.27257607
[1,   450] loss: 2.18655979
[1,   500] loss: 2.11501820
[1,   550] loss: 2.06243105
[1,   600] loss: 1.99585778
[1,   650] loss: 1.97857795
[1,   700] loss: 1.94461713
[1,   750] loss: 1.91006961
[1,   800] loss: 1.88302904
[1,   850] loss: 1.88010846
[1,   900] loss: 1.87246749
[1,   950] loss: 1.80417371
[1,  1000] loss: 1.79565271
New Best Validation:[epoch #1] loss: 1.79563627
[2,    50] loss: 1.77117007
[2,   100] loss: 1.71927936
[2,   150] loss: 1.72136850
[2,   200] loss: 1.73461779
[2,   250] loss: 1.74659110
[2,   300] loss: 1.71657789
[2,   350] loss: 1.66049383
[2,   400] loss: 1.66081641
[2,   450] loss: 1.66563689
[2,   500] loss: 1.63538714
[2,   550] loss: 1.62003965
[2,   600] loss: 1.59578630
[2,   650] loss: 1.60579812
[2,   700] loss: 1.59054325


[14,   450] loss: 0.71134249
[14,   500] loss: 0.73087215
[14,   550] loss: 0.71138289
[14,   600] loss: 0.72299585
[14,   650] loss: 0.73350317
[14,   700] loss: 0.76401745
[14,   750] loss: 0.72389117
[14,   800] loss: 0.73455109
[14,   850] loss: 0.68994253
[14,   900] loss: 0.72835654
[14,   950] loss: 0.70390841
[14,  1000] loss: 0.68196176
New Best Validation:[epoch #14] loss: 0.71876691
[15,    50] loss: 0.70254345
[15,   100] loss: 0.74258262
[15,   150] loss: 0.68518803
[15,   200] loss: 0.70692953
[15,   250] loss: 0.71159285
[15,   300] loss: 0.72081983
[15,   350] loss: 0.66884381
[15,   400] loss: 0.70665709
[15,   450] loss: 0.70215374
[15,   500] loss: 0.68914142
[15,   550] loss: 0.68628213
[15,   600] loss: 0.66289281
[15,   650] loss: 0.70967675
[15,   700] loss: 0.72798713
[15,   750] loss: 0.68696715
[15,   800] loss: 0.72806739
[15,   850] loss: 0.72698184
[15,   900] loss: 0.70577168
[15,   950] loss: 0.69706685
[15,  1000] loss: 0.71509467
New Best Validation:[ep

[27,   500] loss: 0.62709532
[27,   550] loss: 0.58445076
[27,   600] loss: 0.55168829
[27,   650] loss: 0.59857000
[27,   700] loss: 0.55416981
[27,   750] loss: 0.58974083
[27,   800] loss: 0.59467333
[27,   850] loss: 0.58510928
[27,   900] loss: 0.57361726
[27,   950] loss: 0.58978474
[27,  1000] loss: 0.60115234
Converged after 27 epochs


In [7]:
test(path, qnet, testloader, BATCH_SIZE, classes)

Accuracy of the network on the 10000 test images: 76 %
--------------------------------------------------------------------------------
Accuracy of t-shirt/top : 78 %
Accuracy of trouser : 93 %
Accuracy of pullover : 59 %
Accuracy of dress : 79 %
Accuracy of  coat : 68 %
Accuracy of sandal : 89 %
Accuracy of shirt : 26 %
Accuracy of sneaker : 84 %
Accuracy of   bag : 92 %
Accuracy of ankle boot : 93 %
