This example copies a lot from [a pytorch tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)


In [1]:
import os
import sys

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim


# To be able to import model
sys.path.append('/'+'/'.join(os.getcwd().split('/')[:-1])+'/')

from model.Conversion import Start,Stop
from model.activations import ReLU,PACT
from model.linear import Linear
from model.wrapped import FlattenM,MaxPool2d,Dropout
from model.blocks import ConvBnA,LinBnA
from model.sequential import Sequential
from model.Quantizer import LinQuantExpScale
from model.QuantizationMethods.MinMSE import MinMSE_convolution_weight_quantization, MinMSE_linear_weight_quantization

from model.DataWrapper import DataWrapper
from model.convolution.weight_quantization import LinQuantWeight
from model.Quantizer import FakeQuant, Quant
from model.logger import logger_init, logger_forward
from types import FunctionType
from typing import Tuple, Optional

from tqdm import tqdm

path = 'cifa10_vgg_custom_act_weightqaunt_rescaler'


if not os.path.exists('./runs'):
    os.mkdir('./runs')
if not os.path.exists(f'./runs/{path}'):
    os.mkdir(f'./runs/{path}')

In [2]:
batch_size = 80
epochs = 50

In [3]:
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (.3, .3, .3))])
transform_test = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (.3, .3, .3))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [4]:

# This is a fixpoint quantization
class CustomConvWeightQuant(LinQuantWeight):
    @logger_init
    def __init__(self, bits: int = 8, size: tuple = (-1,), rounding_mode: str = "round", layer_wise=False) -> None:
        super(CustomConvWeightQuant,self).__init__(bits, size, rounding_mode,layer_wise)

        # self min and max is set to -2**(bits-1) and 2**(bits-1)-1

        weight_range = 1
        nn.init.constant_(self.delta_in,2*weight_range / (2.0**self.bits - 1))
        self.delta_out = self.delta_in

    @logger_forward
    def forward(self, x: Tensor, rexp_mean: Tensor, rexp_diff: Tensor, fact_fun: FunctionType) -> Tensor:
        with torch.no_grad():
            fact = fact_fun((self.delta_out.view(1,-1,1,1) * rexp_mean).log2()).view(-1, 1, 1, 1)

            self.delta_for_quant = self.delta_in.div(rexp_diff.view(*self.rexp_view)).div_(fact)

            # clipping the weights, improves performance
            x.data.clamp_(self.delta_for_quant*(self.min-0.5),
                        self.delta_for_quant*(self.max+0.5))

        return FakeQuant(
                x=x.clone(),
                delta_in=self.delta_for_quant,
                delta_out=self.delta_for_quant,
                training=self.training,
                min_quant=self.min,
                max_quant=self.max,
                rounding_mode=self.rounding_mode,
            )


# custom back propagation for relu6
class RELU6_back_function(torch.autograd.Function):
    @staticmethod
    def forward(ctx, val: Tensor, m: Tensor) -> Tensor:
        ctx.save_for_backward(val >= m, val >= 0)
        return val.clone()

    @staticmethod
    def backward(ctx, grad_outputs: Tensor) -> Tuple[Tensor, Tensor]:
        m_cmp, zero_cmp = ctx.saved_tensors
        val_gard = grad_outputs * torch.logical_and(zero_cmp,~m_cmp)
        return val_gard, None


#custom RELU6 activation function
class CustomActivationRelu6(Quant):
    def __init__(self, bits, size=(-1,), rounding_mode: str = "floor", use_enforced_quant_level: bool = False):
        super(CustomActivationRelu6, self).__init__(bits, size, rounding_mode, use_enforced_quant_level)
        self.bits = bits
        
        nn.init.constant_(self.delta_in,6/(2**bits - 1))
        self.delta_out = self.delta_in

        nn.init.constant_(self.min, 0)
        nn.init.constant_(self.max, 2**bits - 1)

    def forward(self, x: torch.Tensor, fake: bool = False, metadata: Optional[DataWrapper] = None,*args,**kargs):
        if self.training:
            with torch.no_grad():
                if self.use_enforced_quant_level and metadata is not None:
                    self.use_quant(metadata)
                if self.use_enforced_quant_level and metadata is None:
                    raise ValueError("Quantization function desired but metadata not passed")

            x = RELU6_back_function.apply(x,6)
        return super(CustomActivationRelu6,self).forward(x, fake)

In [5]:
# cross layer rescaler definition 
count = 0
ns = []
def fun():
    points = 1
    qa = torch.tensor([q/(points+1) for q in list(range(1,points+1))],device='cuda')
    print(qa)
    global count 
    global ns
    def calculate_n_a_fixed(weight,mean,var,out_quant,rexp,):
        with torch.no_grad():
            n = torch.log2(weight.abs() / (out_quant * torch.sqrt(var + 1e-5)))
            n = torch.nan_to_num(n,nan=0,posinf=0,neginf=-32).add(rexp.view(-1)).clip(min=-32,max=0)
            if count+1>len(ns):
                ns.append(n.detach().clone().view(-1))
                nr = n.median() * torch.ones_like(n)
                nr = torch.ceil(nr)
            else:
                ns[count] = n.detach().clone().view(-1)
                data_points = torch.concat(ns).quantile(qa)
                nr = data_points[0] * torch.ones_like(n)
                dist = torch.abs(n-data_points[0])
                for i in range(1,points):
                    nr = torch.where(torch.abs(n-data_points[i]) > dist,nr,data_points[i])
                    dist = torch.where(torch.abs(n-data_points[i]) > dist,dist,torch.abs(n-data_points[i]))
                nr = torch.ceil(nr)

            alpha = (torch.sign(weight)+1e-5).sign() * torch.exp2(n - nr)
            return nr, alpha
    count += 1
    return calculate_n_a_fixed

In [6]:



weight_quant = CustomConvWeightQuant

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Start and Stop modules convert the Float values to the fake quantized domain
        # and during inference to the integer domain expressed by a float (necessary due to NVIDIA)
        self.start = Start(bits=8,size=(1,3,1,1),mode="auto",auto_runs=2)
        # mode="auto" simply measures min and max of the input and quantized to them in a symmetric manner 
        self.stop = Stop(size=(1,10))
        
        self.seq = Sequential(
            ConvBnA(  3, 64,3,1,1,activation=CustomActivationRelu6,weight_quant=weight_quant),
            ConvBnA( 64, 64,3,1,1,activation=CustomActivationRelu6,weight_quant=weight_quant,BN_shift_alpha_function=fun()),
            MaxPool2d(2,2),
            Dropout(0.1),
            ConvBnA( 64,128,3,1,1,activation=CustomActivationRelu6,weight_quant=weight_quant,BN_shift_alpha_function=fun()),
            ConvBnA(128,128,3,1,1,activation=CustomActivationRelu6,weight_quant=weight_quant,BN_shift_alpha_function=fun()),
            MaxPool2d(2,2),
            Dropout(0.1),
            ConvBnA(128,256,3,1,1,activation=CustomActivationRelu6,weight_quant=weight_quant,BN_shift_alpha_function=fun()),
            ConvBnA(256,256,3,1,1,activation=CustomActivationRelu6,weight_quant=weight_quant,BN_shift_alpha_function=fun()),
            Dropout(0.1),
            MaxPool2d(2,2),
            FlattenM(1),
            LinBnA(256*4*4,10,weight_quant=MinMSE_linear_weight_quantization,weight_quant_channel_wise=True,activation=LinQuantExpScale,affine=False,BN_shift_alpha_function=fun()),
        )

    def forward(self, x):
        x = self.start(x)
        x = self.seq(x)
        x = self.stop(x)
        return x

This example copies a lot from [a pytorch tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)


In [7]:
def eval(net, pr=True):
    global testloader,device
    global best
    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    net.eval()
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images = images.to(device,non_blocking=True)
            labels = labels.to(device,non_blocking=True)
            # calculate outputs by running images through the network
            outputs = net(images)
            # loss = criterion(outputs, target)

            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    net.train()
    if pr:
        print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:4.1f} %')
    if best < 100 * correct / total:
        best = 100 * correct / total

    return 100 * correct / total, best


In [8]:
net = Net()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = net.to(device)


criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9,weight_decay=5e-4)
sched= optim.lr_scheduler.CosineAnnealingLR(optimizer,epochs,1e-5,verbose=False)

best = 0

tensor([0.5000], device='cuda:0')
tensor([0.5000], device='cuda:0')
tensor([0.5000], device='cuda:0')
tensor([0.5000], device='cuda:0')
tensor([0.5000], device='cuda:0')
tensor([0.5000], device='cuda:0')
tensor([0.5000], device='cuda:0')


In [9]:
for epoch in range(epochs):  # loop over the dataset multiple times

    running_loss = 0.0
    training_running_correct = 0
    with tqdm(enumerate(trainloader, 0),total=len(trainloader),disable=False) as t:
        for i, (data,target) in t:
            # get the inputs; data is a list of [inputs, labels]
            target = target.to(device,non_blocking=True)
            data = data.to(device,non_blocking=True)

            # zero the parameter gradients
            optimizer.zero_grad(True)

            # forward + backward + optimize
            outputs = net(data)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()


            # print statistics
            running_loss += loss.item()

            _, preds = torch.sort(outputs.data.clone().detach(), 1,descending=True)
            training_running_correct += (preds[:,0].view(-1,1) == target.view(-1,1)).any(dim=1).sum().detach().item()
            t.set_postfix({'loss': running_loss/(i+1),'acc':training_running_correct/((i+1)*batch_size)})
    ev = eval(net,pr=False)
    print(f'[{epoch + 1:3d}, {i + 1:5d}] loss: {running_loss*batch_size/len(trainloader):6.3f},Train Acc: {training_running_correct/len(trainloader):3.1f}, Test Acc:{ev[0]:3.1f}%, Best test Acc:{ev[1]:3.1f}%')
    running_loss = 0.0
    torch.save(net.state_dict(),f"./runs/{path}/ckp.pt")
    sched.step()
            
print('Finished Training')


100%|██████████| 625/625 [00:17<00:00, 35.57it/s, loss=1.65, acc=0.487]


reduce autorun by 1: 1 min/max -1.6666666269302368 / 1.6666666269302368
[  1,   625] loss: 131.604,Train Acc: 38.9, Test Acc:50.9%, Best test Acc:50.9%


100%|██████████| 625/625 [00:16<00:00, 37.35it/s, loss=1.4, acc=0.629] 


reduce autorun by 1: 0 min/max -1.6666666269302368 / 1.6666666269302368
[  2,   625] loss: 111.977,Train Acc: 50.3, Test Acc:60.5%, Best test Acc:60.5%


100%|██████████| 625/625 [00:16<00:00, 38.92it/s, loss=1.33, acc=0.671]


[  3,   625] loss: 106.010,Train Acc: 53.7, Test Acc:46.0%, Best test Acc:60.5%


100%|██████████| 625/625 [00:16<00:00, 37.29it/s, loss=1.3, acc=0.683] 


[  4,   625] loss: 104.373,Train Acc: 54.7, Test Acc:46.2%, Best test Acc:60.5%


100%|██████████| 625/625 [00:16<00:00, 37.36it/s, loss=1.4, acc=0.631] 


[  5,   625] loss: 111.888,Train Acc: 50.5, Test Acc:47.0%, Best test Acc:60.5%


100%|██████████| 625/625 [00:16<00:00, 37.58it/s, loss=1.41, acc=0.623]


[  6,   625] loss: 112.948,Train Acc: 49.8, Test Acc:34.7%, Best test Acc:60.5%


100%|██████████| 625/625 [00:16<00:00, 37.07it/s, loss=1.41, acc=0.627]


[  7,   625] loss: 112.422,Train Acc: 50.1, Test Acc:22.7%, Best test Acc:60.5%


100%|██████████| 625/625 [00:16<00:00, 37.75it/s, loss=1.41, acc=0.625]


[  8,   625] loss: 112.571,Train Acc: 50.0, Test Acc:35.9%, Best test Acc:60.5%


100%|██████████| 625/625 [00:16<00:00, 37.18it/s, loss=1.41, acc=0.623]


[  9,   625] loss: 112.459,Train Acc: 49.9, Test Acc:29.9%, Best test Acc:60.5%


100%|██████████| 625/625 [00:16<00:00, 37.12it/s, loss=1.42, acc=0.614]


[ 10,   625] loss: 113.956,Train Acc: 49.2, Test Acc:32.7%, Best test Acc:60.5%


100%|██████████| 625/625 [00:16<00:00, 37.87it/s, loss=1.43, acc=0.612]


[ 11,   625] loss: 114.548,Train Acc: 49.0, Test Acc:24.1%, Best test Acc:60.5%


100%|██████████| 625/625 [00:16<00:00, 37.68it/s, loss=1.45, acc=0.605]


[ 12,   625] loss: 115.791,Train Acc: 48.4, Test Acc:23.5%, Best test Acc:60.5%


 81%|████████▏ | 509/625 [00:13<00:03, 38.34it/s, loss=1.44, acc=0.605]