This example copies a lot from [a pytorch tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)


In [1]:
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision
import torch.nn.functional as F
import torch.nn as nn
import torch
import os
import sys

# To be able to import model
sys.path.append('/'+'/'.join(os.getcwd().split('/')[:-1])+'/')

from model.wrapped import FlattenM, MaxPool2d, Dropout, AdaptiveAvgPool2d
from model.blocks import ConvBnA, LinBnA, BasicBlock
from model.activations import ReLU, PACT
from model.Conversion import Start, Stop
from tqdm import tqdm
from model.QuantizationMethods.MinMSE import MinMSE_convolution_weight_quantization, MinMSE_linear_weight_quantization
from model.Quantizer import LinQuantExpScale
from model.sequential import Sequential
from model.linear import Linear


path = 'cifa10_resnet'


if not os.path.exists('./runs'):
    os.mkdir('./runs')
if not os.path.exists(f'./runs/{path}'):
    os.mkdir(f'./runs/{path}')

In [2]:
batch_size = 80
epochs = 50

In [3]:
# Getting the dataset

transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (.3, .3, .3))])
transform_test = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (.3, .3, .3))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [4]:
weight_quant = MinMSE_convolution_weight_quantization

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Start and Stop modules convert the Float values to the fake quantized domain
        # and during inference to the integer domain expressed by a float (necessary due to NVIDIA)
        self.start = Start(bits=8,size=(1,3,1,1),mode="auto",auto_runs=2)
        # mode="auto" simply measures min and max of the input and quantized to them in a symmetric manner 
        self.stop = Stop(size=(1,10))
        
        self.seq = Sequential(
            ConvBnA(  3, 128,3,1,1,activation=PACT,weight_quant=weight_quant),
            BasicBlock(128,128,1),
            Dropout(0.1),
            
            BasicBlock(128,256,2),
            Dropout(0.1),

            BasicBlock(256,512,2),
            Dropout(0.1),

            AdaptiveAvgPool2d((1,1)),
            FlattenM(1),
            LinBnA(512,10,weight_quant=MinMSE_linear_weight_quantization,weight_quant_channel_wise=True,activation=LinQuantExpScale,affine=False),
        )

    def forward(self, x):
        x = self.start(x)
        x = self.seq(x)
        x = self.stop(x)
        return x

In [5]:
def eval(net, pr=True):
    global testloader,device
    global best
    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    net.eval()
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images = images.to(device,non_blocking=True)
            labels = labels.to(device,non_blocking=True)
            # calculate outputs by running images through the network
            outputs = net(images)
            # loss = criterion(outputs, target)

            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    net.train()
    if pr:
        print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:4.1f} %')
    if best < 100 * correct / total:
        best = 100 * correct / total

    return 100 * correct / total, best


In [6]:
net = Net()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = net.to(device)


criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9,weight_decay=5e-4)
sched= optim.lr_scheduler.CosineAnnealingLR(optimizer,epochs,1e-5,verbose=False)

best = 0

In [7]:
for epoch in range(epochs):  # loop over the dataset multiple times

    running_loss = 0.0
    training_running_correct = 0
    with tqdm(enumerate(trainloader, 0),total=len(trainloader),disable=False) as t:
        for i, (data,target) in t:
            # get the inputs; data is a list of [inputs, labels]
            target = target.to(device,non_blocking=True)
            data = data.to(device,non_blocking=True)

            # zero the parameter gradients
            optimizer.zero_grad(True)

            # forward + backward + optimize
            outputs = net(data)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()


            # print statistics
            running_loss += loss.item()

            _, preds = torch.sort(outputs.data.clone().detach(), 1,descending=True)
            training_running_correct += (preds[:,0].view(-1,1) == target.view(-1,1)).any(dim=1).sum().detach().item()
            t.set_postfix({'loss': running_loss/(i+1),'acc':training_running_correct/((i+1)*batch_size)})
    ev = eval(net,pr=False)
    print(f'[{epoch + 1:3d}, {i + 1:5d}] loss: {running_loss*batch_size/len(trainloader):6.3f},Train Acc: {training_running_correct/len(trainloader):3.1f}, Test Acc:{ev[0]:3.1f}%, Best test Acc:{ev[1]:3.1f}%')
    running_loss = 0.0
    torch.save(net.state_dict(),f"./runs/{path}/ckp.pt")
    sched.step()
            
print('Finished Training')


100%|██████████| 625/625 [00:52<00:00, 11.80it/s, loss=1.62, acc=0.499]


reduce autorun by 1: 1 min/max -1.6666666269302368 / 1.6666666269302368
[  1,   625] loss: 129.693,Train Acc: 40.0, Test Acc:57.2%, Best test Acc:57.2%


100%|██████████| 625/625 [00:51<00:00, 12.06it/s, loss=1.34, acc=0.662]


reduce autorun by 1: 0 min/max -1.6666666269302368 / 1.6666666269302368
[  2,   625] loss: 107.023,Train Acc: 52.9, Test Acc:63.3%, Best test Acc:63.3%


100%|██████████| 625/625 [00:51<00:00, 12.16it/s, loss=1.21, acc=0.733]


[  3,   625] loss: 96.648,Train Acc: 58.7, Test Acc:69.6%, Best test Acc:69.6%


100%|██████████| 625/625 [00:53<00:00, 11.65it/s, loss=1.12, acc=0.779]


[  4,   625] loss: 89.714,Train Acc: 62.3, Test Acc:73.0%, Best test Acc:73.0%


100%|██████████| 625/625 [00:53<00:00, 11.61it/s, loss=1.07, acc=0.807]


[  5,   625] loss: 85.303,Train Acc: 64.6, Test Acc:77.1%, Best test Acc:77.1%


100%|██████████| 625/625 [00:53<00:00, 11.59it/s, loss=1.04, acc=0.822]


[  6,   625] loss: 83.047,Train Acc: 65.8, Test Acc:80.1%, Best test Acc:80.1%


100%|██████████| 625/625 [00:54<00:00, 11.48it/s, loss=1.01, acc=0.836] 


[  7,   625] loss: 80.690,Train Acc: 66.8, Test Acc:80.8%, Best test Acc:80.8%


100%|██████████| 625/625 [00:56<00:00, 11.11it/s, loss=0.991, acc=0.847]


[  8,   625] loss: 79.300,Train Acc: 67.8, Test Acc:79.1%, Best test Acc:80.8%


100%|██████████| 625/625 [00:56<00:00, 11.10it/s, loss=0.98, acc=0.853] 


[  9,   625] loss: 78.361,Train Acc: 68.2, Test Acc:79.9%, Best test Acc:80.8%


100%|██████████| 625/625 [00:53<00:00, 11.59it/s, loss=0.972, acc=0.857]


[ 10,   625] loss: 77.765,Train Acc: 68.6, Test Acc:82.1%, Best test Acc:82.1%


100%|██████████| 625/625 [00:55<00:00, 11.20it/s, loss=0.973, acc=0.857]


[ 11,   625] loss: 77.874,Train Acc: 68.6, Test Acc:78.7%, Best test Acc:82.1%


100%|██████████| 625/625 [00:59<00:00, 10.46it/s, loss=0.971, acc=0.856]


[ 12,   625] loss: 77.710,Train Acc: 68.5, Test Acc:76.9%, Best test Acc:82.1%


100%|██████████| 625/625 [00:54<00:00, 11.50it/s, loss=0.972, acc=0.858]


[ 13,   625] loss: 77.770,Train Acc: 68.7, Test Acc:76.8%, Best test Acc:82.1%


100%|██████████| 625/625 [00:55<00:00, 11.26it/s, loss=0.967, acc=0.86] 


[ 14,   625] loss: 77.347,Train Acc: 68.8, Test Acc:77.9%, Best test Acc:82.1%


100%|██████████| 625/625 [00:55<00:00, 11.18it/s, loss=0.967, acc=0.861]


[ 15,   625] loss: 77.383,Train Acc: 68.9, Test Acc:84.2%, Best test Acc:84.2%


100%|██████████| 625/625 [00:52<00:00, 11.81it/s, loss=0.96, acc=0.865] 


[ 16,   625] loss: 76.804,Train Acc: 69.2, Test Acc:81.1%, Best test Acc:84.2%


100%|██████████| 625/625 [00:46<00:00, 13.42it/s, loss=0.953, acc=0.868]


[ 17,   625] loss: 76.226,Train Acc: 69.4, Test Acc:81.0%, Best test Acc:84.2%


100%|██████████| 625/625 [00:46<00:00, 13.53it/s, loss=0.943, acc=0.872]


[ 18,   625] loss: 75.446,Train Acc: 69.8, Test Acc:74.4%, Best test Acc:84.2%


100%|██████████| 625/625 [00:45<00:00, 13.80it/s, loss=0.937, acc=0.875]


[ 19,   625] loss: 74.983,Train Acc: 70.0, Test Acc:82.8%, Best test Acc:84.2%


100%|██████████| 625/625 [00:44<00:00, 13.97it/s, loss=0.929, acc=0.881]


[ 20,   625] loss: 74.333,Train Acc: 70.5, Test Acc:83.3%, Best test Acc:84.2%


100%|██████████| 625/625 [00:44<00:00, 13.93it/s, loss=0.919, acc=0.885]


[ 21,   625] loss: 73.528,Train Acc: 70.8, Test Acc:85.0%, Best test Acc:85.0%


100%|██████████| 625/625 [00:44<00:00, 13.97it/s, loss=0.91, acc=0.891] 


[ 22,   625] loss: 72.829,Train Acc: 71.2, Test Acc:82.9%, Best test Acc:85.0%


100%|██████████| 625/625 [00:44<00:00, 14.00it/s, loss=0.901, acc=0.896]


[ 23,   625] loss: 72.082,Train Acc: 71.7, Test Acc:84.9%, Best test Acc:85.0%


100%|██████████| 625/625 [00:44<00:00, 13.98it/s, loss=0.892, acc=0.902]


[ 24,   625] loss: 71.358,Train Acc: 72.2, Test Acc:86.5%, Best test Acc:86.5%


100%|██████████| 625/625 [00:44<00:00, 13.98it/s, loss=0.882, acc=0.906]


[ 25,   625] loss: 70.555,Train Acc: 72.5, Test Acc:86.0%, Best test Acc:86.5%


100%|██████████| 625/625 [00:44<00:00, 13.98it/s, loss=0.874, acc=0.91] 


[ 26,   625] loss: 69.918,Train Acc: 72.8, Test Acc:87.8%, Best test Acc:87.8%


100%|██████████| 625/625 [00:44<00:00, 13.98it/s, loss=0.862, acc=0.916]


[ 27,   625] loss: 68.990,Train Acc: 73.3, Test Acc:85.1%, Best test Acc:87.8%


100%|██████████| 625/625 [00:46<00:00, 13.48it/s, loss=0.851, acc=0.922]


[ 28,   625] loss: 68.062,Train Acc: 73.8, Test Acc:87.3%, Best test Acc:87.8%


100%|██████████| 625/625 [00:44<00:00, 13.94it/s, loss=0.837, acc=0.93] 


[ 29,   625] loss: 66.994,Train Acc: 74.4, Test Acc:87.5%, Best test Acc:87.8%


100%|██████████| 625/625 [00:49<00:00, 12.61it/s, loss=0.827, acc=0.934]


[ 30,   625] loss: 66.148,Train Acc: 74.7, Test Acc:87.8%, Best test Acc:87.8%


100%|██████████| 625/625 [00:51<00:00, 12.11it/s, loss=0.813, acc=0.942]


[ 31,   625] loss: 65.052,Train Acc: 75.4, Test Acc:88.7%, Best test Acc:88.7%


100%|██████████| 625/625 [00:51<00:00, 12.23it/s, loss=0.802, acc=0.948]


[ 32,   625] loss: 64.167,Train Acc: 75.8, Test Acc:86.7%, Best test Acc:88.7%


100%|██████████| 625/625 [00:48<00:00, 12.81it/s, loss=0.786, acc=0.956]


[ 33,   625] loss: 62.909,Train Acc: 76.4, Test Acc:88.3%, Best test Acc:88.7%


100%|██████████| 625/625 [00:50<00:00, 12.29it/s, loss=0.772, acc=0.963]


[ 34,   625] loss: 61.752,Train Acc: 77.0, Test Acc:89.0%, Best test Acc:89.0%


100%|██████████| 625/625 [00:51<00:00, 12.05it/s, loss=0.758, acc=0.969]


[ 35,   625] loss: 60.605,Train Acc: 77.6, Test Acc:89.6%, Best test Acc:89.6%


100%|██████████| 625/625 [00:52<00:00, 11.94it/s, loss=0.738, acc=0.979]


[ 36,   625] loss: 59.004,Train Acc: 78.3, Test Acc:90.6%, Best test Acc:90.6%


100%|██████████| 625/625 [00:51<00:00, 12.03it/s, loss=0.725, acc=0.983]


[ 37,   625] loss: 58.029,Train Acc: 78.6, Test Acc:90.6%, Best test Acc:90.6%


100%|██████████| 625/625 [00:51<00:00, 12.05it/s, loss=0.712, acc=0.987]


[ 38,   625] loss: 56.950,Train Acc: 79.0, Test Acc:91.2%, Best test Acc:91.2%


100%|██████████| 625/625 [00:52<00:00, 11.94it/s, loss=0.698, acc=0.992]


[ 39,   625] loss: 55.852,Train Acc: 79.4, Test Acc:91.8%, Best test Acc:91.8%


100%|██████████| 625/625 [00:53<00:00, 11.61it/s, loss=0.687, acc=0.995]


[ 40,   625] loss: 54.941,Train Acc: 79.6, Test Acc:91.7%, Best test Acc:91.8%


100%|██████████| 625/625 [00:54<00:00, 11.50it/s, loss=0.677, acc=0.996]


[ 41,   625] loss: 54.198,Train Acc: 79.7, Test Acc:92.2%, Best test Acc:92.2%


100%|██████████| 625/625 [00:54<00:00, 11.48it/s, loss=0.668, acc=0.998]


[ 42,   625] loss: 53.411,Train Acc: 79.8, Test Acc:92.4%, Best test Acc:92.4%


100%|██████████| 625/625 [00:55<00:00, 11.29it/s, loss=0.664, acc=0.999]


[ 43,   625] loss: 53.140,Train Acc: 79.9, Test Acc:92.2%, Best test Acc:92.4%


100%|██████████| 625/625 [00:56<00:00, 11.01it/s, loss=0.658, acc=0.999]


[ 44,   625] loss: 52.662,Train Acc: 79.9, Test Acc:92.6%, Best test Acc:92.6%


100%|██████████| 625/625 [00:55<00:00, 11.28it/s, loss=0.655, acc=0.999]


[ 45,   625] loss: 52.406,Train Acc: 79.9, Test Acc:92.5%, Best test Acc:92.6%


100%|██████████| 625/625 [00:54<00:00, 11.48it/s, loss=0.651, acc=1]   


[ 46,   625] loss: 52.044,Train Acc: 80.0, Test Acc:92.5%, Best test Acc:92.6%


100%|██████████| 625/625 [00:53<00:00, 11.64it/s, loss=0.65, acc=1] 


[ 47,   625] loss: 51.968,Train Acc: 80.0, Test Acc:92.6%, Best test Acc:92.6%


100%|██████████| 625/625 [00:55<00:00, 11.27it/s, loss=0.648, acc=1]    


[ 48,   625] loss: 51.872,Train Acc: 80.0, Test Acc:92.8%, Best test Acc:92.8%


100%|██████████| 625/625 [00:54<00:00, 11.37it/s, loss=0.647, acc=1]    


[ 49,   625] loss: 51.778,Train Acc: 80.0, Test Acc:92.8%, Best test Acc:92.8%


100%|██████████| 625/625 [00:53<00:00, 11.71it/s, loss=0.647, acc=1]


[ 50,   625] loss: 51.784,Train Acc: 80.0, Test Acc:92.7%, Best test Acc:92.8%
Finished Training
