In [8]:
import os
import sys
sys.path.append('/'+'/'.join(os.getcwd().split('/')[:-1])+'/')

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim


from model.Conversion import Start,Stop
from model.activations import ReLU,PACT
from model.linear import Linear
from model.wrapped import FlattenM,MaxPool2d,Dropout
from model.blocks import ConvBnA,LinBnA
from model.sequential import Sequential
from model.Quantizer import LinQuantExpScale
from model.QuantizationMethods.MinMSE import MinMSE_convolution_weight_quantization, MinMSE_linear_weight_quantization

from tqdm import tqdm

path = 'cifa10_vgg'


if not os.path.exists('./runs'):
    os.mkdir('./runs')
if not os.path.exists(f'./runs/{path}'):
    os.mkdir(f'./runs/{path}')

In [9]:
batch_size = 80
epochs = 50

In [10]:
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (.3, .3, .3))])
transform_test = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (.3, .3, .3))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [11]:
weight_quant = MinMSE_convolution_weight_quantization

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.start = Start(bits=8,size=(1,3,1,1),mode="auto",auto_runs=2)
        self.stop = Stop(size=(1,10))
        
        self.seq = Sequential(
            ConvBnA(  3, 64,3,1,1,activation=PACT,weight_quant=weight_quant),
            ConvBnA( 64, 64,3,1,1,activation=PACT,weight_quant=weight_quant),
            MaxPool2d(2,2),
            Dropout(0.1),
            ConvBnA( 64,128,3,1,1,activation=PACT,weight_quant=weight_quant),
            ConvBnA(128,128,3,1,1,activation=PACT,weight_quant=weight_quant),
            MaxPool2d(2,2),
            Dropout(0.1),
            ConvBnA(128,256,3,1,1,activation=PACT,weight_quant=weight_quant),
            ConvBnA(256,256,3,1,1,activation=PACT,weight_quant=weight_quant),
            Dropout(0.1),
            MaxPool2d(2,2),
            FlattenM(1),
            LinBnA(256*4*4,10,weight_quant=MinMSE_linear_weight_quantization,weight_quant_channel_wise=True,activation=LinQuantExpScale,affine=False),
        )

    def forward(self, x):
        x = self.start(x)
        x = self.seq(x)
        x = self.stop(x)
        return x

In [12]:
def eval(net, pr=True):
    global testloader,device
    global best
    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    net.eval()
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images = images.to(device,non_blocking=True)
            labels = labels.to(device,non_blocking=True)
            # calculate outputs by running images through the network
            outputs = net(images)
            # loss = criterion(outputs, target)

            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    net.train()
    if pr:
        print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:4.1f} %')
    if best < 100 * correct / total:
        best = 100 * correct / total

    return 100 * correct / total, best


In [13]:
net = Net()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = net.to(device)


criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9,weight_decay=5e-4)
sched= optim.lr_scheduler.CosineAnnealingLR(optimizer,epochs,1e-5,verbose=False)

best = 0

In [14]:
for epoch in range(epochs):  # loop over the dataset multiple times

    running_loss = 0.0
    training_running_correct = 0
    with tqdm(enumerate(trainloader, 0),total=len(trainloader),disable=False) as t:
        for i, (data,target) in t:
            # get the inputs; data is a list of [inputs, labels]
            target = target.to(device,non_blocking=True)
            data = data.to(device,non_blocking=True)

            # zero the parameter gradients
            optimizer.zero_grad(True)

            # forward + backward + optimize
            outputs = net(data)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()


            # print statistics
            running_loss += loss.item()
            # if i == 0:    # print every 2000 mini-batches
            _, preds = torch.sort(outputs.data.clone().detach(), 1,descending=True)
            training_running_correct += (preds[:,0].view(-1,1) == target.view(-1,1)).any(dim=1).sum().detach().item()
            t.set_postfix({'loss': running_loss/(i+1),'acc':training_running_correct/((i+1)*batch_size)})
    ev = eval(net,pr=False)
    print(f'[{epoch + 1:3d}, {i + 1:5d}] loss: {running_loss*batch_size/len(trainloader):6.3f},Train Acc: {training_running_correct/len(trainloader):3.1f}, Test Acc:{ev[0]:3.1f}%, Best test Acc:{ev[1]:3.1f}%')
    running_loss = 0.0
    torch.save(net.state_dict(),f"./runs/{path}/ckp.pt")
    sched.step()
            
print('Finished Training')


100%|██████████| 625/625 [00:17<00:00, 35.57it/s, loss=1.42, acc=0.615]


reduce autorun by 1: 1 min/max -1.6666666269302368 / 1.6666666269302368
[  1,   625] loss: 113.609,Train Acc: 49.2, Test Acc:71.7%, Best test Acc:71.7%


100%|██████████| 625/625 [00:19<00:00, 32.80it/s, loss=1.17, acc=0.752]


reduce autorun by 1: 0 min/max -1.6666666269302368 / 1.6666666269302368
[  2,   625] loss: 93.857,Train Acc: 60.2, Test Acc:73.9%, Best test Acc:73.9%


100%|██████████| 625/625 [00:18<00:00, 34.28it/s, loss=1.1, acc=0.794] 


[  3,   625] loss: 87.647,Train Acc: 63.5, Test Acc:79.5%, Best test Acc:79.5%


100%|██████████| 625/625 [00:18<00:00, 33.02it/s, loss=1.04, acc=0.822]


[  4,   625] loss: 83.510,Train Acc: 65.8, Test Acc:81.5%, Best test Acc:81.5%


100%|██████████| 625/625 [00:17<00:00, 35.29it/s, loss=1.01, acc=0.842]


[  5,   625] loss: 80.530,Train Acc: 67.4, Test Acc:83.5%, Best test Acc:83.5%


100%|██████████| 625/625 [00:17<00:00, 34.95it/s, loss=0.977, acc=0.859]


[  6,   625] loss: 78.180,Train Acc: 68.8, Test Acc:84.3%, Best test Acc:84.3%


100%|██████████| 625/625 [00:18<00:00, 34.32it/s, loss=0.955, acc=0.869]


[  7,   625] loss: 76.439,Train Acc: 69.5, Test Acc:84.8%, Best test Acc:84.8%


100%|██████████| 625/625 [00:18<00:00, 34.69it/s, loss=0.933, acc=0.883]


[  8,   625] loss: 74.641,Train Acc: 70.6, Test Acc:84.8%, Best test Acc:84.8%


100%|██████████| 625/625 [00:17<00:00, 34.84it/s, loss=0.916, acc=0.89] 


[  9,   625] loss: 73.241,Train Acc: 71.2, Test Acc:86.0%, Best test Acc:86.0%


100%|██████████| 625/625 [00:17<00:00, 35.38it/s, loss=0.897, acc=0.901]


[ 10,   625] loss: 71.756,Train Acc: 72.1, Test Acc:87.1%, Best test Acc:87.1%


100%|██████████| 625/625 [00:17<00:00, 34.99it/s, loss=0.884, acc=0.908]


[ 11,   625] loss: 70.710,Train Acc: 72.7, Test Acc:87.1%, Best test Acc:87.1%


100%|██████████| 625/625 [00:17<00:00, 36.15it/s, loss=0.871, acc=0.915]


[ 12,   625] loss: 69.648,Train Acc: 73.2, Test Acc:87.0%, Best test Acc:87.1%


100%|██████████| 625/625 [00:16<00:00, 37.08it/s, loss=0.857, acc=0.92] 


[ 13,   625] loss: 68.546,Train Acc: 73.6, Test Acc:87.6%, Best test Acc:87.6%


100%|██████████| 625/625 [00:17<00:00, 35.98it/s, loss=0.846, acc=0.928]


[ 14,   625] loss: 67.646,Train Acc: 74.2, Test Acc:85.8%, Best test Acc:87.6%


100%|██████████| 625/625 [00:17<00:00, 36.64it/s, loss=0.834, acc=0.933]


[ 15,   625] loss: 66.741,Train Acc: 74.7, Test Acc:87.6%, Best test Acc:87.6%


100%|██████████| 625/625 [00:17<00:00, 36.67it/s, loss=0.825, acc=0.939]


[ 16,   625] loss: 65.994,Train Acc: 75.1, Test Acc:88.2%, Best test Acc:88.2%


100%|██████████| 625/625 [00:18<00:00, 34.48it/s, loss=0.813, acc=0.944]


[ 17,   625] loss: 65.069,Train Acc: 75.6, Test Acc:88.3%, Best test Acc:88.3%


100%|██████████| 625/625 [00:18<00:00, 34.22it/s, loss=0.805, acc=0.948]


[ 18,   625] loss: 64.427,Train Acc: 75.9, Test Acc:88.6%, Best test Acc:88.6%


100%|██████████| 625/625 [00:18<00:00, 33.78it/s, loss=0.794, acc=0.954]


[ 19,   625] loss: 63.528,Train Acc: 76.3, Test Acc:89.2%, Best test Acc:89.2%


100%|██████████| 625/625 [00:18<00:00, 34.61it/s, loss=0.786, acc=0.959]


[ 20,   625] loss: 62.851,Train Acc: 76.7, Test Acc:89.1%, Best test Acc:89.2%


100%|██████████| 625/625 [00:18<00:00, 33.66it/s, loss=0.78, acc=0.963] 


[ 21,   625] loss: 62.396,Train Acc: 77.0, Test Acc:88.7%, Best test Acc:89.2%


100%|██████████| 625/625 [00:18<00:00, 34.27it/s, loss=0.77, acc=0.966] 


[ 22,   625] loss: 61.601,Train Acc: 77.3, Test Acc:89.3%, Best test Acc:89.3%


100%|██████████| 625/625 [00:19<00:00, 32.25it/s, loss=0.763, acc=0.971]


[ 23,   625] loss: 61.023,Train Acc: 77.6, Test Acc:89.3%, Best test Acc:89.3%


100%|██████████| 625/625 [00:19<00:00, 32.05it/s, loss=0.757, acc=0.972]


[ 24,   625] loss: 60.584,Train Acc: 77.7, Test Acc:89.8%, Best test Acc:89.8%


100%|██████████| 625/625 [00:19<00:00, 32.16it/s, loss=0.749, acc=0.976]


[ 25,   625] loss: 59.921,Train Acc: 78.1, Test Acc:89.5%, Best test Acc:89.8%


100%|██████████| 625/625 [00:20<00:00, 31.25it/s, loss=0.743, acc=0.979]


[ 26,   625] loss: 59.413,Train Acc: 78.3, Test Acc:89.7%, Best test Acc:89.8%


100%|██████████| 625/625 [00:18<00:00, 33.57it/s, loss=0.737, acc=0.981]


[ 27,   625] loss: 58.955,Train Acc: 78.5, Test Acc:90.1%, Best test Acc:90.1%


100%|██████████| 625/625 [00:18<00:00, 33.49it/s, loss=0.732, acc=0.983]


[ 28,   625] loss: 58.554,Train Acc: 78.6, Test Acc:89.5%, Best test Acc:90.1%


100%|██████████| 625/625 [00:18<00:00, 33.71it/s, loss=0.727, acc=0.986]


[ 29,   625] loss: 58.126,Train Acc: 78.9, Test Acc:90.2%, Best test Acc:90.2%


100%|██████████| 625/625 [00:18<00:00, 33.96it/s, loss=0.719, acc=0.988]


[ 30,   625] loss: 57.503,Train Acc: 79.0, Test Acc:89.8%, Best test Acc:90.2%


100%|██████████| 625/625 [00:18<00:00, 33.07it/s, loss=0.715, acc=0.989]


[ 31,   625] loss: 57.231,Train Acc: 79.1, Test Acc:90.4%, Best test Acc:90.4%


100%|██████████| 625/625 [00:17<00:00, 35.10it/s, loss=0.71, acc=0.991] 


[ 32,   625] loss: 56.808,Train Acc: 79.3, Test Acc:90.3%, Best test Acc:90.4%


100%|██████████| 625/625 [00:17<00:00, 34.81it/s, loss=0.705, acc=0.992]


[ 33,   625] loss: 56.424,Train Acc: 79.4, Test Acc:90.2%, Best test Acc:90.4%


100%|██████████| 625/625 [00:17<00:00, 35.03it/s, loss=0.704, acc=0.992]


[ 34,   625] loss: 56.306,Train Acc: 79.4, Test Acc:90.7%, Best test Acc:90.7%


100%|██████████| 625/625 [00:17<00:00, 35.43it/s, loss=0.699, acc=0.994]


[ 35,   625] loss: 55.944,Train Acc: 79.5, Test Acc:90.6%, Best test Acc:90.7%


100%|██████████| 625/625 [00:17<00:00, 35.63it/s, loss=0.696, acc=0.994]


[ 36,   625] loss: 55.703,Train Acc: 79.5, Test Acc:90.7%, Best test Acc:90.7%


100%|██████████| 625/625 [00:17<00:00, 34.81it/s, loss=0.693, acc=0.995]


[ 37,   625] loss: 55.476,Train Acc: 79.6, Test Acc:90.4%, Best test Acc:90.7%


100%|██████████| 625/625 [00:18<00:00, 34.40it/s, loss=0.69, acc=0.995] 


[ 38,   625] loss: 55.229,Train Acc: 79.6, Test Acc:90.5%, Best test Acc:90.7%


100%|██████████| 625/625 [00:18<00:00, 33.61it/s, loss=0.688, acc=0.996]


[ 39,   625] loss: 55.046,Train Acc: 79.7, Test Acc:90.8%, Best test Acc:90.8%


100%|██████████| 625/625 [00:18<00:00, 34.16it/s, loss=0.685, acc=0.996]


[ 40,   625] loss: 54.794,Train Acc: 79.7, Test Acc:90.8%, Best test Acc:90.8%


100%|██████████| 625/625 [00:18<00:00, 34.35it/s, loss=0.684, acc=0.997]


[ 41,   625] loss: 54.751,Train Acc: 79.8, Test Acc:90.6%, Best test Acc:90.8%


100%|██████████| 625/625 [00:18<00:00, 34.66it/s, loss=0.683, acc=0.997]


[ 42,   625] loss: 54.608,Train Acc: 79.8, Test Acc:90.8%, Best test Acc:90.8%


100%|██████████| 625/625 [00:18<00:00, 34.12it/s, loss=0.68, acc=0.997] 


[ 43,   625] loss: 54.425,Train Acc: 79.8, Test Acc:90.8%, Best test Acc:90.8%


100%|██████████| 625/625 [00:18<00:00, 34.50it/s, loss=0.679, acc=0.998]


[ 44,   625] loss: 54.324,Train Acc: 79.8, Test Acc:90.7%, Best test Acc:90.8%


100%|██████████| 625/625 [00:17<00:00, 35.60it/s, loss=0.678, acc=0.998]


[ 45,   625] loss: 54.252,Train Acc: 79.8, Test Acc:90.7%, Best test Acc:90.8%


100%|██████████| 625/625 [00:17<00:00, 34.81it/s, loss=0.678, acc=0.998]


[ 46,   625] loss: 54.220,Train Acc: 79.8, Test Acc:90.7%, Best test Acc:90.8%


100%|██████████| 625/625 [00:18<00:00, 34.72it/s, loss=0.677, acc=0.998]


[ 47,   625] loss: 54.197,Train Acc: 79.8, Test Acc:90.8%, Best test Acc:90.8%


100%|██████████| 625/625 [00:17<00:00, 34.91it/s, loss=0.677, acc=0.998]


[ 48,   625] loss: 54.137,Train Acc: 79.8, Test Acc:90.8%, Best test Acc:90.8%


100%|██████████| 625/625 [00:18<00:00, 34.21it/s, loss=0.677, acc=0.997]


[ 49,   625] loss: 54.187,Train Acc: 79.8, Test Acc:90.9%, Best test Acc:90.9%


100%|██████████| 625/625 [00:17<00:00, 35.48it/s, loss=0.677, acc=0.998]


[ 50,   625] loss: 54.126,Train Acc: 79.8, Test Acc:90.8%, Best test Acc:90.9%
Finished Training
