This example copies a lot from [a pytorch tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)


In [1]:
import os
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim


# To be able to import model
sys.path.append('/'+'/'.join(os.getcwd().split('/')[:-1])+'/')

from model.Conversion import Start,Stop
from model.activations import ReLU,PACT
from model.linear import Linear
from model.wrapped import FlattenM,MaxPool2d,Dropout
from model.blocks import ConvBnA,LinBnA
from model.sequential import Sequential
from model.Quantizer import LinQuantExpScale
from model.QuantizationMethods.MinMSE import MinMSE_convolution_weight_quantization, MinMSE_linear_weight_quantization

from tqdm import tqdm

path = 'cifa10_vgg'


if not os.path.exists('./runs'):
    os.mkdir('./runs')
if not os.path.exists(f'./runs/{path}'):
    os.mkdir(f'./runs/{path}')

In [2]:
batch_size = 80
epochs = 50

In [3]:
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (.3, .3, .3))])
transform_test = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (.3, .3, .3))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [4]:
weight_quant = MinMSE_convolution_weight_quantization

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Start and Stop modules convert the Float values to the fake quantized domain
        # and during inference to the integer domain expressed by a float (necessary due to NVIDIA)
        self.start = Start(bits=8,size=(1,3,1,1),mode="auto",auto_runs=2)
        # mode="auto" simply measures min and max of the input and quantized to them in a symmetric manner 
        self.stop = Stop(size=(1,10))
        
        self.seq = Sequential(
            ConvBnA(  3, 64,3,1,1,activation=PACT,weight_quant=weight_quant),
            ConvBnA( 64, 64,3,1,1,activation=PACT,weight_quant=weight_quant),
            MaxPool2d(2,2),
            Dropout(0.1),
            ConvBnA( 64,128,3,1,1,activation=PACT,weight_quant=weight_quant),
            ConvBnA(128,128,3,1,1,activation=PACT,weight_quant=weight_quant),
            MaxPool2d(2,2),
            Dropout(0.1),
            ConvBnA(128,256,3,1,1,activation=PACT,weight_quant=weight_quant),
            ConvBnA(256,256,3,1,1,activation=PACT,weight_quant=weight_quant),
            Dropout(0.1),
            MaxPool2d(2,2),
            FlattenM(1),
            LinBnA(256*4*4,10,weight_quant=MinMSE_linear_weight_quantization,weight_quant_channel_wise=True,activation=LinQuantExpScale,affine=False),
        )

    def forward(self, x):
        x = self.start(x)
        x = self.seq(x)
        x = self.stop(x)
        return x

This example copies a lot from [a pytorch tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)


In [5]:
def eval(net, pr=True):
    global testloader,device
    global best
    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    net.eval()
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images = images.to(device,non_blocking=True)
            labels = labels.to(device,non_blocking=True)
            # calculate outputs by running images through the network
            outputs = net(images)
            # loss = criterion(outputs, target)

            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    net.train()
    if pr:
        print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:4.1f} %')
    if best < 100 * correct / total:
        best = 100 * correct / total

    return 100 * correct / total, best


In [6]:
net = Net()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = net.to(device)


criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9,weight_decay=5e-4)
sched= optim.lr_scheduler.CosineAnnealingLR(optimizer,epochs,1e-5,verbose=False)

best = 0

In [7]:
for epoch in range(epochs):  # loop over the dataset multiple times

    running_loss = 0.0
    training_running_correct = 0
    with tqdm(enumerate(trainloader, 0),total=len(trainloader),disable=False) as t:
        for i, (data,target) in t:
            # get the inputs; data is a list of [inputs, labels]
            target = target.to(device,non_blocking=True)
            data = data.to(device,non_blocking=True)

            # zero the parameter gradients
            optimizer.zero_grad(True)

            # forward + backward + optimize
            outputs = net(data)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()


            # print statistics
            running_loss += loss.item()

            _, preds = torch.sort(outputs.data.clone().detach(), 1,descending=True)
            training_running_correct += (preds[:,0].view(-1,1) == target.view(-1,1)).any(dim=1).sum().detach().item()
            t.set_postfix({'loss': running_loss/(i+1),'acc':training_running_correct/((i+1)*batch_size)})
    ev = eval(net,pr=False)
    print(f'[{epoch + 1:3d}, {i + 1:5d}] loss: {running_loss*batch_size/len(trainloader):6.3f},Train Acc: {training_running_correct/len(trainloader):3.1f}, Test Acc:{ev[0]:3.1f}%, Best test Acc:{ev[1]:3.1f}%')
    running_loss = 0.0
    torch.save(net.state_dict(),f"./runs/{path}/ckp.pt")
    sched.step()
            
print('Finished Training')


100%|██████████| 625/625 [00:19<00:00, 32.50it/s, loss=1.53, acc=0.55] 


reduce autorun by 1: 1 min/max -1.6666666269302368 / 1.6666666269302368
[  1,   625] loss: 122.563,Train Acc: 44.0, Test Acc:64.4%, Best test Acc:64.4%


100%|██████████| 625/625 [00:17<00:00, 36.26it/s, loss=1.23, acc=0.722]


reduce autorun by 1: 0 min/max -1.6666666269302368 / 1.6666666269302368
[  2,   625] loss: 98.073,Train Acc: 57.8, Test Acc:72.1%, Best test Acc:72.1%


100%|██████████| 625/625 [00:17<00:00, 36.42it/s, loss=1.12, acc=0.777]


[  3,   625] loss: 89.292,Train Acc: 62.1, Test Acc:77.7%, Best test Acc:77.7%


100%|██████████| 625/625 [00:16<00:00, 36.78it/s, loss=1.06, acc=0.806]


[  4,   625] loss: 84.867,Train Acc: 64.5, Test Acc:74.8%, Best test Acc:77.7%


100%|██████████| 625/625 [00:16<00:00, 37.14it/s, loss=1.03, acc=0.819]


[  5,   625] loss: 82.503,Train Acc: 65.5, Test Acc:78.5%, Best test Acc:78.5%


100%|██████████| 625/625 [00:17<00:00, 36.60it/s, loss=1.01, acc=0.825] 


[  6,   625] loss: 81.133,Train Acc: 66.0, Test Acc:78.3%, Best test Acc:78.5%


100%|██████████| 625/625 [00:16<00:00, 37.59it/s, loss=1, acc=0.832]    


[  7,   625] loss: 80.032,Train Acc: 66.6, Test Acc:78.0%, Best test Acc:78.5%


100%|██████████| 625/625 [00:16<00:00, 36.86it/s, loss=1, acc=0.83]     


[  8,   625] loss: 79.968,Train Acc: 66.4, Test Acc:78.3%, Best test Acc:78.5%


100%|██████████| 625/625 [00:16<00:00, 37.23it/s, loss=1, acc=0.827]    


[  9,   625] loss: 80.236,Train Acc: 66.2, Test Acc:73.4%, Best test Acc:78.5%


100%|██████████| 625/625 [00:16<00:00, 37.52it/s, loss=1.02, acc=0.819] 


[ 10,   625] loss: 81.474,Train Acc: 65.5, Test Acc:78.5%, Best test Acc:78.5%


100%|██████████| 625/625 [00:16<00:00, 37.24it/s, loss=1.04, acc=0.809]


[ 11,   625] loss: 83.134,Train Acc: 64.8, Test Acc:71.8%, Best test Acc:78.5%


100%|██████████| 625/625 [00:16<00:00, 37.42it/s, loss=1.05, acc=0.804]


[ 12,   625] loss: 83.918,Train Acc: 64.3, Test Acc:75.3%, Best test Acc:78.5%


100%|██████████| 625/625 [00:16<00:00, 37.06it/s, loss=1.05, acc=0.8]  


[ 13,   625] loss: 84.247,Train Acc: 64.0, Test Acc:76.3%, Best test Acc:78.5%


100%|██████████| 625/625 [00:16<00:00, 36.98it/s, loss=1.05, acc=0.802]


[ 14,   625] loss: 83.750,Train Acc: 64.1, Test Acc:73.7%, Best test Acc:78.5%


100%|██████████| 625/625 [00:16<00:00, 37.33it/s, loss=1.05, acc=0.803]


[ 15,   625] loss: 83.637,Train Acc: 64.2, Test Acc:76.8%, Best test Acc:78.5%


100%|██████████| 625/625 [00:16<00:00, 37.06it/s, loss=1.04, acc=0.807]


[ 16,   625] loss: 83.261,Train Acc: 64.5, Test Acc:76.8%, Best test Acc:78.5%


100%|██████████| 625/625 [00:16<00:00, 37.34it/s, loss=1.03, acc=0.808]


[ 17,   625] loss: 82.679,Train Acc: 64.6, Test Acc:79.2%, Best test Acc:79.2%


100%|██████████| 625/625 [00:16<00:00, 38.19it/s, loss=1.03, acc=0.812]


[ 18,   625] loss: 82.236,Train Acc: 65.0, Test Acc:76.4%, Best test Acc:79.2%


100%|██████████| 625/625 [00:14<00:00, 42.09it/s, loss=1.02, acc=0.813]


[ 19,   625] loss: 81.836,Train Acc: 65.1, Test Acc:78.0%, Best test Acc:79.2%


100%|██████████| 625/625 [00:14<00:00, 42.21it/s, loss=1.02, acc=0.817]


[ 20,   625] loss: 81.225,Train Acc: 65.4, Test Acc:78.9%, Best test Acc:79.2%


100%|██████████| 625/625 [00:14<00:00, 42.25it/s, loss=1.01, acc=0.822] 


[ 21,   625] loss: 80.581,Train Acc: 65.7, Test Acc:81.5%, Best test Acc:81.5%


100%|██████████| 625/625 [00:14<00:00, 42.03it/s, loss=0.995, acc=0.827]


[ 22,   625] loss: 79.638,Train Acc: 66.1, Test Acc:78.4%, Best test Acc:81.5%


100%|██████████| 625/625 [00:14<00:00, 42.23it/s, loss=0.987, acc=0.83] 


[ 23,   625] loss: 78.973,Train Acc: 66.4, Test Acc:77.8%, Best test Acc:81.5%


100%|██████████| 625/625 [00:14<00:00, 42.17it/s, loss=0.984, acc=0.831]


[ 24,   625] loss: 78.759,Train Acc: 66.5, Test Acc:78.8%, Best test Acc:81.5%


100%|██████████| 625/625 [00:14<00:00, 42.16it/s, loss=0.976, acc=0.835]


[ 25,   625] loss: 78.044,Train Acc: 66.8, Test Acc:78.2%, Best test Acc:81.5%


100%|██████████| 625/625 [00:14<00:00, 42.08it/s, loss=0.965, acc=0.84] 


[ 26,   625] loss: 77.219,Train Acc: 67.2, Test Acc:80.6%, Best test Acc:81.5%


100%|██████████| 625/625 [00:14<00:00, 41.92it/s, loss=0.954, acc=0.847]


[ 27,   625] loss: 76.293,Train Acc: 67.8, Test Acc:82.1%, Best test Acc:82.1%


100%|██████████| 625/625 [00:14<00:00, 42.10it/s, loss=0.946, acc=0.848]


[ 28,   625] loss: 75.691,Train Acc: 67.9, Test Acc:82.1%, Best test Acc:82.1%


100%|██████████| 625/625 [00:14<00:00, 42.10it/s, loss=0.937, acc=0.854]


[ 29,   625] loss: 74.989,Train Acc: 68.3, Test Acc:82.7%, Best test Acc:82.7%


100%|██████████| 625/625 [00:14<00:00, 41.86it/s, loss=0.926, acc=0.86] 


[ 30,   625] loss: 74.094,Train Acc: 68.8, Test Acc:82.1%, Best test Acc:82.7%


100%|██████████| 625/625 [00:14<00:00, 41.93it/s, loss=0.916, acc=0.864]


[ 31,   625] loss: 73.319,Train Acc: 69.1, Test Acc:83.5%, Best test Acc:83.5%


100%|██████████| 625/625 [00:14<00:00, 42.03it/s, loss=0.905, acc=0.869]


[ 32,   625] loss: 72.429,Train Acc: 69.5, Test Acc:82.5%, Best test Acc:83.5%


100%|██████████| 625/625 [00:14<00:00, 42.09it/s, loss=0.892, acc=0.875]


[ 33,   625] loss: 71.341,Train Acc: 70.0, Test Acc:82.7%, Best test Acc:83.5%


100%|██████████| 625/625 [00:14<00:00, 42.08it/s, loss=0.883, acc=0.881]


[ 34,   625] loss: 70.635,Train Acc: 70.5, Test Acc:86.1%, Best test Acc:86.1%


100%|██████████| 625/625 [00:14<00:00, 41.96it/s, loss=0.869, acc=0.888]


[ 35,   625] loss: 69.513,Train Acc: 71.0, Test Acc:82.8%, Best test Acc:86.1%


100%|██████████| 625/625 [00:14<00:00, 41.97it/s, loss=0.854, acc=0.892]


[ 36,   625] loss: 68.345,Train Acc: 71.4, Test Acc:86.2%, Best test Acc:86.2%


100%|██████████| 625/625 [00:14<00:00, 42.15it/s, loss=0.84, acc=0.902] 


[ 37,   625] loss: 67.169,Train Acc: 72.2, Test Acc:87.1%, Best test Acc:87.1%


100%|██████████| 625/625 [00:14<00:00, 42.06it/s, loss=0.826, acc=0.909]


[ 38,   625] loss: 66.070,Train Acc: 72.7, Test Acc:87.9%, Best test Acc:87.9%


100%|██████████| 625/625 [00:14<00:00, 42.05it/s, loss=0.814, acc=0.914]


[ 39,   625] loss: 65.138,Train Acc: 73.2, Test Acc:88.2%, Best test Acc:88.2%


100%|██████████| 625/625 [00:14<00:00, 42.05it/s, loss=0.798, acc=0.923]


[ 40,   625] loss: 63.847,Train Acc: 73.8, Test Acc:86.8%, Best test Acc:88.2%


100%|██████████| 625/625 [00:14<00:00, 42.11it/s, loss=0.785, acc=0.929]


[ 41,   625] loss: 62.809,Train Acc: 74.4, Test Acc:88.6%, Best test Acc:88.6%


100%|██████████| 625/625 [00:14<00:00, 42.08it/s, loss=0.772, acc=0.939]


[ 42,   625] loss: 61.733,Train Acc: 75.1, Test Acc:89.1%, Best test Acc:89.1%


100%|██████████| 625/625 [00:14<00:00, 42.22it/s, loss=0.758, acc=0.945]


[ 43,   625] loss: 60.622,Train Acc: 75.6, Test Acc:89.7%, Best test Acc:89.7%


100%|██████████| 625/625 [00:14<00:00, 42.02it/s, loss=0.744, acc=0.951]


[ 44,   625] loss: 59.522,Train Acc: 76.1, Test Acc:89.7%, Best test Acc:89.7%


100%|██████████| 625/625 [00:14<00:00, 42.07it/s, loss=0.734, acc=0.957]


[ 45,   625] loss: 58.698,Train Acc: 76.6, Test Acc:90.2%, Best test Acc:90.2%


100%|██████████| 625/625 [00:14<00:00, 42.06it/s, loss=0.727, acc=0.961]


[ 46,   625] loss: 58.156,Train Acc: 76.9, Test Acc:90.4%, Best test Acc:90.4%


100%|██████████| 625/625 [00:14<00:00, 41.96it/s, loss=0.72, acc=0.965] 


[ 47,   625] loss: 57.619,Train Acc: 77.2, Test Acc:90.4%, Best test Acc:90.4%


100%|██████████| 625/625 [00:14<00:00, 41.74it/s, loss=0.715, acc=0.967]


[ 48,   625] loss: 57.201,Train Acc: 77.4, Test Acc:90.6%, Best test Acc:90.6%


100%|██████████| 625/625 [00:14<00:00, 42.03it/s, loss=0.71, acc=0.969] 


[ 49,   625] loss: 56.834,Train Acc: 77.5, Test Acc:90.7%, Best test Acc:90.7%


100%|██████████| 625/625 [00:14<00:00, 41.97it/s, loss=0.71, acc=0.971] 


[ 50,   625] loss: 56.784,Train Acc: 77.6, Test Acc:90.7%, Best test Acc:90.7%
Finished Training
