In [None]:
import torch
from torch.autograd.function import InplaceFunction, Function
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
import math
import numpy as np
import torchvision.transforms as transforms
import torch.nn.init as init
import torchvision
from tqdm import tqdm
import torch.optim as optim

In [None]:
def quantize_model(model, quantize = False, bits = 8, qmode = "dynamic"):
    """Switch fake quantization on or off for every quantizable layer of `model`.

    A layer is treated as quantizable when it exposes a callable ``mode()`` and a
    ``change_mod(value, bits, mode)`` method (the Conv2d / Linear wrappers in this
    notebook). Other modules are skipped.

    Only layers whose current state differs from the request are touched: turning
    quantization on leaves layers that are already quantized (including their
    bit-width and mode) unchanged. To change bits or mode, call once with
    ``quantize=False`` first.

    Args:
        model: nn.Module whose ``modules()`` are scanned.
        quantize: True to enable quantization, False to disable it.
        bits: bit-width for newly quantized layers.
        qmode: quantizer mode passed to ``change_mod`` ("dynamic" or "static").

    Returns:
        The same `model` object, modified in place.
    """
    print("Quantize mode on" if quantize else "Quantize mode off")
    for layer in model.modules():
        get_mode = getattr(layer, "mode", None)
        # Quant stores `mode` as a plain string and ordinary modules have no such
        # attribute; only the quantizable wrappers have a callable mode().
        if not callable(get_mode) or not hasattr(layer, "change_mod"):
            continue
        is_quantized = get_mode()
        if quantize and not is_quantized:
            layer.change_mod(True, bits, qmode)
        elif not quantize and is_quantized:
            layer.change_mod(False, 0)
    return model

In [None]:
def qsin_activation_mode(model):
    """Enable the QSin activation penalty on every layer that supports it.

    A layer is affected when it has both a callable ``mode()`` and a
    ``qsinmode()`` method (the Conv2d / Linear wrappers in this notebook); all
    other modules are skipped.

    Returns:
        The same `model` object, modified in place.
    """
    print("QSIN activation mode on")
    for layer in model.modules():
        # Explicit capability check instead of a bare `except:`; errors raised by
        # qsinmode() itself are no longer silently swallowed.
        if callable(getattr(layer, "mode", None)) and callable(getattr(layer, "qsinmode", None)):
            layer.qsinmode()
    return model

In [None]:
class MyRound(torch.autograd.Function):
    """Round to the nearest integer with a straight-through gradient.

    Forward applies ``round`` (PyTorch's round-half-to-even); backward hands the
    incoming gradient back unchanged, i.e. rounding is treated as the identity
    for differentiation purposes.
    """

    @staticmethod
    def forward(ctx, input):
        rounded = input.round()
        return rounded

    @staticmethod
    def backward(ctx, grad_output):
        passthrough = grad_output
        return passthrough

In [None]:
def Quantize_tensor(input_tensor, max_abs_val = None, num_bits = 8):
    """Fake-quantize `input_tensor` onto a symmetric signed `num_bits` grid.

    With qmin = -2**num_bits / 2 and qmax = -qmin - 1, the step size is
    scale = max_abs_val / ((qmax - qmin) / 2). The result is
    clamp(round(x / scale), qmin, qmax) * scale, where rounding uses MyRound
    (straight-through gradient). When `max_abs_val` is a tensor that requires
    grad, gradients flow into it through `scale`.

    Args:
        input_tensor: tensor to quantize.
        max_abs_val: clipping range (tensor or Python number). Defaults to the
            max absolute value of `input_tensor`.
        num_bits: bit-width of the integer grid.

    Returns:
        Tensor with the shape of `input_tensor`. A zero range yields all zeros
        (previously 0/0 produced NaN).
    """
    my_round = MyRound.apply
    qmin = -1.0 * (2**num_bits) / 2
    qmax = -qmin - 1
    if max_abs_val is None:
        max_abs_val = torch.max(torch.abs(input_tensor))
    elif not torch.is_tensor(max_abs_val):
        max_abs_val = torch.tensor(float(max_abs_val), dtype=input_tensor.dtype,
                                   device=input_tensor.device)
    scale = max_abs_val / ((qmax - qmin) / 2)
    # Divide by 1 where the range is zero so 0/0 cannot create NaN; the final
    # multiplication by the (zero) scale still maps those outputs to 0. Nonzero
    # scales (positive or negative) take exactly the original path and gradient.
    safe_scale = torch.where(scale != 0, scale, torch.ones_like(scale))
    input_tensor = torch.div(input_tensor, safe_scale)
    input_tensor = my_round(input_tensor)
    input_tensor = torch.clamp(input_tensor, qmin, qmax)
    return torch.mul(input_tensor, scale)

In [None]:
class Quant(nn.Module):
    """Fake-quantization module with a dynamic or calibrated (static) range.

    Modes:
      * "dynamic": every call quantizes with the max |input| of that call.
      * "static": the first `static_count` calls return the input unchanged and
        record its max |input|. The next call stores the mean of those records in
        the learnable parameter `max_abs_tr`. That call and all later calls
        quantize with `max_abs_tr`.

    Any other mode raises ValueError on forward (it previously returned None).
    """

    def __init__(self, num_bits=8, mode = "dynamic", static_count = 30):
        super(Quant, self).__init__()
        self.num_bits = num_bits
        self.mode = mode
        self.static_count = static_count
        self.static_cur = 0      # number of static-mode calls seen so far
        self.stat_values = []    # max |input| recorded during static calibration
        self.max_abs = 0         # most recently used / calibrated range
        # Learnable range used in static mode. It is an empty tensor until
        # calibration finishes (it is never filled in dynamic mode).
        self.max_abs_tr = nn.Parameter(torch.zeros(0), requires_grad=True) # IMPORTANT

    def forward(self, input):
        if self.mode == "dynamic":
            self.max_abs = torch.max(torch.abs(input))
            return Quantize_tensor(input, self.max_abs, self.num_bits)

        if self.mode != "static":
            raise ValueError(
                f"Unknown quantization mode {self.mode!r}; expected 'dynamic' or 'static'")

        if self.static_cur < self.static_count:
            # Calibration phase: record the range and pass the input through.
            self.static_cur += 1
            self.stat_values.append(input.detach().abs().max().item())
            return input

        if self.static_cur == self.static_count:
            # First call after calibration: fix the learnable range.
            if self.stat_values:
                self.max_abs = float(np.mean(self.stat_values))
            else:
                # static_count == 0: no statistics were collected, use this input.
                self.max_abs = input.detach().abs().max().item()
            self.max_abs_tr.data = torch.tensor(self.max_abs, dtype=torch.float,
                                                device=self.max_abs_tr.device)
            self.static_cur += 1
        return Quantize_tensor(input, self.max_abs_tr, self.num_bits)

In [None]:
def QSin(x, num_bits = 8):
    """Sinusoidal quantization penalty for `x` expressed in quantization steps.

    With lo = -2**num_bits / 2 and hi = -lo - 1:
      * elements in [lo, hi] contribute sin^2(pi * x), which is zero on integers;
      * elements below lo / above hi contribute pi^2 * (distance to that bound)^2.
    Returns the sum over all contributions as a 0-dim tensor.
    """
    pi = torch.tensor(np.pi)
    lo = -1.0 * (2**num_bits) / 2
    hi = -lo - 1

    in_range = torch.logical_and(x >= lo, x <= hi)
    below = x < lo
    above = x > hi

    penalty = torch.sum(torch.square(torch.sin(torch.mul(pi, x[in_range]))))
    penalty = penalty + torch.sum(torch.mul(torch.square(pi), torch.square(x[below] - lo)))
    penalty = penalty + torch.sum(torch.mul(torch.square(pi), torch.square(x[above] - hi)))
    return penalty

In [None]:
class Conv2d(nn.Conv2d):
    """nn.Conv2d with optional fake quantization of its input and weights.

    When `quantize` is on, the input and the weight tensor each pass through their
    own `Quant` module before the convolution. When `QsinA` is also on, every such
    forward stores a QSin penalty of the *unquantized* input in `qsin_loss_A`.
    That penalty's step size comes from `Quantize_input.max_abs_tr`.
    NOTE(review): `max_abs_tr` is an empty tensor until a static Quant finishes
    calibration, so QsinA presumably must only be enabled after static calibration.
    """

    def __init__(self, inCh: int, outCh: int, kernel_size: int = 4, stride: int = 1, padding: int = 0,
                 bias: bool = None, quantization: bool = False, q_bits: int = 8, qsin_activation = False):
        super().__init__(inCh, outCh, kernel_size, stride=stride, padding=padding, bias=bias)

        self.quantize = bool(quantization)
        self.QsinA = bool(qsin_activation)

        if not self.quantize:
            self.bits = 'FP'
        else:
            self.bits = q_bits
            # Weight quantizer is registered before the input quantizer.
            self.Quantize_weights = Quant(self.bits)
            self.Quantize_input = Quant(self.bits)

        if self.QsinA:
            self.qsin_loss_A = 0

    def init(self, input):
        """Store `input.shape` in `inputW` (not read anywhere in the visible code)."""
        self.inputW = input.shape

    def change_mod(self, value, bits = 8, mode = "dynamic"):
        """Set the quantize flag and install fresh `bits`-bit quantizers in `mode`."""
        self.quantize, self.bits = value, bits
        self.Quantize_weights = Quant(bits, mode)
        self.Quantize_input = Quant(bits, mode)

    def qsinmode(self):
        """Turn on the QSin activation penalty and reset its stored value."""
        self.QsinA, self.qsin_loss_A = True, 0

    def mode(self):
        """Return whether quantization is currently enabled."""
        return self.quantize

    def forward(self, input):
        if not self.quantize:
            return nn.functional.conv2d(input, self.weight, self.bias, self.stride,
                                        self.padding, self.dilation, self.groups)

        quant_in = self.Quantize_input(input)
        quant_w = self.Quantize_weights(self.weight)

        # QSin penalty on the raw activation, measured in quantization steps.
        if self.QsinA:
            self.qsin_loss_A = 0
            low = -1.0 * (2**self.bits) / 2
            high = -low - 1
            step = self.Quantize_input.max_abs_tr / ((high - low) / 2)
            self.qsin_loss_A = torch.mul(torch.square(step), QSin(torch.div(input, step), self.bits))

        return nn.functional.conv2d(quant_in, quant_w, self.bias, self.stride,
                                    self.padding, self.dilation, self.groups)

In [None]:
class Linear(nn.Linear):
    """nn.Linear with optional fake quantization of its input and weights.

    When `quantize` is on, the input and the weight matrix each pass through their
    own `Quant` module before the affine map. When `QsinA` is also on, every such
    forward stores a QSin penalty of the *unquantized* input in `qsin_loss_A`.
    That penalty's step size comes from `Quantize_input.max_abs_tr`.
    NOTE(review): `max_abs_tr` is an empty tensor until a static Quant finishes
    calibration, so QsinA presumably must only be enabled after static calibration.
    """

    def __init__(self, inFeatures: int, outFeatures: int, bias: bool = True, quantization: bool = False, q_bits: int = 8, qsin_activation = False):
        super().__init__(inFeatures, outFeatures, bias)

        self.quantize = bool(quantization)
        self.QsinA = bool(qsin_activation)

        if not self.quantize:
            self.bits = 'FP'
        else:
            self.bits = q_bits
            # Weight quantizer is registered before the input quantizer.
            self.Quantize_weights = Quant(self.bits)
            self.Quantize_input = Quant(self.bits)

        if self.QsinA:
            self.qsin_loss_A = 0

    def init(self, input):
        """Store `input.shape` in `inputW` (not read anywhere in the visible code)."""
        self.inputW = input.shape

    def change_mod(self, value, bits = 8, mode = "dynamic"):
        """Set the quantize flag and install fresh `bits`-bit quantizers in `mode`."""
        self.quantize, self.bits = value, bits
        self.Quantize_weights = Quant(bits, mode)
        self.Quantize_input = Quant(bits, mode)

    def qsinmode(self):
        """Turn on the QSin activation penalty and reset its stored value."""
        self.QsinA, self.qsin_loss_A = True, 0

    def mode(self):
        """Return whether quantization is currently enabled."""
        return self.quantize

    def forward(self, input):
        if not self.quantize:
            return nn.functional.linear(input, self.weight, self.bias)

        quant_in = self.Quantize_input(input)
        quant_w = self.Quantize_weights(self.weight)

        # QSin penalty on the raw activation, measured in quantization steps.
        if self.QsinA:
            self.qsin_loss_A = 0
            low = -1.0 * (2**self.bits) / 2
            high = -low - 1
            step = self.Quantize_input.max_abs_tr / ((high - low) / 2)
            self.qsin_loss_A = torch.mul(torch.square(step), QSin(torch.div(input, step), self.bits))

        return nn.functional.linear(quant_in, quant_w, self.bias)

In [None]:
def Qsin_W(model, bits = 8):
    """Sum the QSin weight penalty over all layers with a calibrated weight quantizer.

    For each module that has a `weight` and a `Quantize_weights.max_abs_tr` that is
    non-empty (i.e. static calibration has finished), the step size is
    scale = max_abs_tr / ((qmax - qmin) / 2), and the module contributes
    scale^2 * QSin(weight / scale, bits). Gradients reach both the weights and
    `max_abs_tr`.

    Args:
        model: nn.Module whose ``modules()`` are scanned.
        bits: bit-width of the grid the weights are pulled toward.

    Returns:
        The summed penalty (tensor), or the int 0 when no module qualifies.
    """
    qmin = -1.0 * (2**bits) / 2
    qmax = -qmin - 1
    loss = 0
    for layer in model.modules():
        weight = getattr(layer, "weight", None)
        quantizer = getattr(layer, "Quantize_weights", None)
        max_abs = getattr(quantizer, "max_abs_tr", None)
        # Skip layers without a weight quantizer and quantizers that were never
        # calibrated (empty range). Other errors now surface instead of being
        # swallowed by a bare `except:`.
        if weight is None or max_abs is None or max_abs.numel() == 0:
            continue
        scale = max_abs / ((qmax - qmin) / 2)
        QSin_w = QSin(torch.div(weight, scale), bits)
        loss = loss + torch.mul(torch.square(scale), QSin_w)
    return loss

In [None]:
def Qsin_A(model):
    """Sum the `qsin_loss_A` values stored on the modules of `model`.

    Modules without a `qsin_loss_A` attribute are skipped. The values are whatever
    each layer stored during its most recent forward pass.

    Returns:
        The sum (tensor or number), or 0 when no module carries the attribute.
    """
    loss = 0
    for layer in model.modules():
        layer_loss = getattr(layer, "qsin_loss_A", None)
        if layer_loss is not None:
            loss = loss + layer_loss
    return loss

# Model and training

In [None]:
def _weights_init(m):
    """Kaiming-normal initialization for the notebook's Linear / Conv2d wrappers.

    Intended for ``model.apply``; all other module types are left untouched.
    Uses the in-place ``kaiming_normal_`` (``kaiming_normal`` is deprecated).
    """
    if isinstance(m, (Linear, Conv2d)):
        init.kaiming_normal_(m.weight)


class LambdaLayer(nn.Module):
    """Wrap a two-argument callable ``lambd(x, planes)`` as an nn.Module."""

    def __init__(self, lambd, planes):
        super().__init__()
        self.lambd, self.planes = lambd, planes

    def forward(self, x):
        fn = self.lambd
        return fn(x, self.planes)


def pad_func(x, planes):
    """Parameter-free shortcut used by BasicBlock when the shape changes.

    Keeps every second element of the last two dimensions of the 4-D tensor `x`.
    It then zero-pads dimension 1 (channels, presumably) with planes // 4 entries
    on each side.
    """
    subsampled = x[:, :, ::2, ::2]
    pad = planes // 4
    return F.pad(subsampled, (0, 0, 0, 0, pad, pad), "constant", 0)


class BasicBlock(nn.Module):
    """Residual block: two 3x3 Conv2d + BatchNorm layers plus a shortcut.

    The shortcut is the identity (empty nn.Sequential) unless the stride or the
    number of planes changes. In that case it is `pad_func` (spatial subsampling
    plus zero channel padding).
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        shape_changes = stride != 1 or in_planes != planes
        self.shortcut = LambdaLayer(pad_func, planes) if shape_changes else nn.Sequential()

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        h += self.shortcut(x)
        return F.relu(h)


class ResNet(nn.Module):
    """CIFAR-style ResNet built from the notebook's Conv2d / Linear wrappers.

    The model has a 3x3 stem (16 planes) and three stages of `block` with 16 / 32 / 64
    planes (the last two downsample with stride 2). It ends with average pooling,
    whose kernel equals the feature-map width (global for square maps), and a
    Linear classifier. Weights are initialized with `_weights_init`.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 16

        self.conv1 = Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        # Registers layer1, layer2, layer3 in that order.
        stage_specs = ((16, 1), (32, 2), (64, 2))
        for idx, (planes, stage_stride) in enumerate(stage_specs):
            setattr(self, f"layer{idx + 1}",
                    self._make_layer(block, planes, num_blocks[idx], stride=stage_stride))
        self.linear = Linear(64, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack `num_blocks` blocks (only the first uses `stride`); updates self.in_planes."""
        blocks = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            blocks.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            h = stage(h)
        h = F.avg_pool2d(h, h.size(3))
        return self.linear(torch.flatten(h, 1))

In [None]:
# ResNet-20 for CIFAR-10: three stages of three BasicBlocks, 10 output classes.
model = ResNet(BasicBlock, num_blocks=[3, 3, 3], num_classes=10)

In [None]:
# Load the pretrained weights (stored under the 'state' key) and move the model to
# the available device.
checkpoint_path = '../input/resnet20/resnet_20_cifar10_91.73.pth'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['state'])
model.to(device)
print('loaded')

In [None]:
# NOTE(review): these mean/std values are the ImageNet statistics, not CIFAR-10's.
# Presumably they match how the pretrained checkpoint was trained, so confirm
# before changing them.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Evaluation: tensor conversion + normalization only.
transform_test = transforms.Compose([transforms.ToTensor(), normalize])

# Training: random horizontal flip and padded 32x32 random crop before normalization.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    normalize,
])

In [None]:
# CIFAR-10 splits, downloaded to ./data if missing; batches of 100 images.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                        transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True,
                                          num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
                                       transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False,
                                         num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
def train(net, epoch = 0, loader = None, loss_fn = None, opt = None, dev = None):
    """Run one training epoch and print/return the epoch's mean loss and accuracy.

    Args:
        net: model to train (switched to train mode).
        epoch: epoch number, used only in the printed messages.
        loader, loss_fn, opt, dev: data loader, loss function, optimizer and
            device. When None they default to the notebook globals `trainloader`,
            `criterion`, `optimizer` and `device`.

    Returns:
        (train_loss, accuracy_train): the loss averaged over every sample of
        ``loader.sampler``, and the fraction of correct predictions (0-dim tensor).

    Raises:
        ValueError: if the loader yields no samples.
    """
    loader = trainloader if loader is None else loader
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt
    dev = device if dev is None else dev

    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0.0
    correct_train = 0
    total_train = 0
    for data, target in loader:
        data = data.to(dev)
        target = target.to(dev)
        opt.zero_grad()
        output = net(data)
        loss = loss_fn(output, target)
        loss.backward()
        opt.step()
        train_loss += loss.item() * data.size(0)
        _, predicted = torch.max(output.detach(), 1)
        total_train += target.size(0)
        correct_train += (predicted == target).float().sum()

    if total_train == 0:
        raise ValueError("train(): the data loader yielded no samples")
    # Normalize once per epoch. The previous code divided on every batch, which
    # repeatedly shrank the running total and reported a meaningless loss.
    train_loss = train_loss / len(loader.sampler)
    accuracy_train = correct_train / total_train
    print('Epoch: {} \tTraining Loss: {:.6f} \tTraining Accuracy: {:.6f}'.format(epoch, train_loss, accuracy_train))
    return train_loss, accuracy_train


def test(net, epoch = 0):
    valid_loss = 0.0
    correct_valid = 0
    total_valid = 0
    net.eval()
    for data, target in testloader:
        data = data.to(device)
        target = target.to(device)
        output = net(data)
        loss = criterion(output, target)
        
        valid_loss += loss.item() * data.size(0)
        
        _, predicted = torch.max(output.data, 1)
        total_valid += target.size(0)
        correct_valid += (predicted == target).float().sum()
        valid_loss = valid_loss/len(testloader.sampler)
        accuracy_valid = correct_valid / total_valid
    print('Epoch: {} \tTest Loss: {:.6f} \tTest Accuracy: {:.6f}'.format(epoch, valid_loss, accuracy_valid))
    return accuracy_valid

In [None]:
# Baseline: evaluate the loaded model before any quantization is switched on
# (all Conv2d / Linear layers were built with quantization disabled).
criterion = nn.CrossEntropyLoss()
test(model)

In [None]:
# Dynamic 4-bit quantization (quantize_model's default qmode): each Quant uses the
# max |x| of the tensor it is given. Quantization is switched off first because
# quantize_model only changes layers whose on/off state differs, so already
# quantized layers would otherwise keep their old bit-width.
model = quantize_model(model, quantize=False, bits = 4)
model = quantize_model(model, quantize=True, bits = 4)
# change_mod created new Quant submodules, so move the model to `device` again.
model.to(device)
test(model)

In [None]:
from tqdm import tqdm
# Static 4-bit quantization. As above, switch off first so every layer gets fresh
# quantizers in "static" mode.
model = quantize_model(model, quantize=False, bits = 4)
model = quantize_model(model, quantize=True, bits = 4, qmode = "static")
model.to(device)

# Calibration: each static Quant passes its first `static_count` inputs (30 by
# default) through unchanged while recording max |input|, then fixes its range on
# the next call. CALIBRATION_BATCHES must therefore exceed static_count.
CALIBRATION_BATCHES = 34
model.eval()
with torch.no_grad():  # forward passes only; no autograd graph is needed
    for batch_idx, (data, _target) in enumerate(trainloader):
        if batch_idx >= CALIBRATION_BATCHES:
            break
        model(data.to(device))

test(model)

# QSIN (sinusoidal quantization regularization)

In [None]:
# Turn on the QSin activation penalty in every Conv2d / Linear wrapper. Each layer
# computes it only while its quantization is enabled.
qsin_activation_mode(model)
print()  # keeps the returned model's repr from being displayed as the cell output

In [None]:
# Current QSin weight penalty on the 4-bit grid, shown as the cell output.
Qsin_W(model, 4)

In [None]:
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=5e-4)

QSIN_EPOCHS = 70
QSIN_BITS = 4              # must match the bit-width used to quantize the model above
EVAL_EVERY = 50            # evaluate / log every N optimizer steps
lambda_w = 1 / 1000        # weight of the QSin weight penalty
lambda_a = 1 / 10000000    # weight of the QSin activation penalty

loss_w = []
loss_a = []
accuracy_all = []

j = 0
for i in tqdm(range(QSIN_EPOCHS)):
    model.train()
    for data, target in trainloader:
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        output = model(data)

        # Compute both penalties from THIS training batch before anything else
        # runs the model: test() overwrites every layer's qsin_loss_A. The weight
        # penalty previously used 8 bits although the model is quantized (and
        # the penalty logged) at 4 bits.
        qsin_w = Qsin_W(model, QSIN_BITS)
        qsin_a = Qsin_A(model)
        loss = criterion(output, target) + lambda_w * qsin_w + lambda_a * qsin_a
        loss.backward()
        optimizer.step()

        if j % EVAL_EVERY == 0:
            loss_w.append(qsin_w.cpu().detach())
            loss_a.append(qsin_a.cpu().detach())
            accuracy_all.append(test(model))
            model.train()  # test() puts the model in eval mode
        j += 1

# QAT (quantization-aware training)

In [None]:
# Plain cross-entropy fine-tuning of the quantized model. It continues from the
# model's current state (it is not reloaded), and `accuracy_all` is reset here.
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=5e-4)

QAT_EPOCHS = 70
EVAL_EVERY = 50    # evaluate every N optimizer steps

accuracy_all = []

j = 0
for i in tqdm(range(QAT_EPOCHS)):
    model.train()
    for data, target in trainloader:
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if j % EVAL_EVERY == 0:
            accuracy_all.append(test(model))
            # test() puts the model in eval mode; without this the rest of the
            # epoch trained with frozen BatchNorm statistics.
            model.train()
        j += 1