<a href="https://colab.research.google.com/github/iny045/Quantization-Aware-Training/blob/main/MNIST_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import math


import torch
from torch.autograd import Function

def compute_ternary_params_groupwise(w: torch.Tensor,
                                     group_size: int = 128,
                                     m: float = 0.75):
    """Compute the per-group ternary threshold and scale of a weight tensor.

    The weight is flattened into consecutive groups of ``group_size``
    elements. For every group the threshold is ``m * mean(|w|)`` and the
    scale is the mean of ``|w|`` over the elements whose magnitude exceeds
    the threshold.

    Args:
        w: Weight tensor with at least two dimensions, e.g. a linear weight
            ``(out, in)`` or a conv weight ``(out, in, kH, kW)``.
        group_size: Number of consecutive elements sharing one threshold
            and one scale.
        m: Multiplier applied to the mean magnitude to obtain the threshold.

    Returns:
        ``(delta, scale)``, two 1-D tensors with one entry per group.

    Raises:
        AssertionError: If ``w`` has fewer than two dimensions,
            ``group_size`` is not positive, or the number of elements per
            output unit is not divisible by ``group_size``.
    """
    assert w.dim() >= 2, "w must have at least 2 dimensions"
    assert group_size > 0, "group_size must be positive"

    # Validate against the flattened length of one output unit (in_features
    # for a linear weight, in_channels * kH * kW for a conv weight) so that
    # no group straddles two output units. Checking w.shape[1] alone would
    # wrongly reject conv weights such as (32, 1, 3, 3) with group_size=9.
    per_output = w.shape[1:].numel()
    assert per_output % group_size == 0, \
        "elements per output unit must be divisible by group_size"

    w_abs = w.abs().reshape(-1, group_size)          # (groups, group_size)

    # threshold Δ per group
    delta = m * w_abs.mean(dim=1)                    # (groups,)

    # mask of weights that survive the threshold
    keep = w_abs > delta.unsqueeze(1)                # broadcast

    # scale α = mean(|w|) over the kept weights; falls back to Δ when no
    # weight in the group survives the threshold
    keep_cnt = keep.sum(dim=1)                       # (groups,)
    keep_sum = (w_abs * keep).sum(dim=1)
    scale = torch.where(keep_cnt > 0,
                        # clamp keeps the unselected branch free of 0/0
                        keep_sum / keep_cnt.clamp(min=1),
                        delta.detach())

    return delta, scale


def quantize_ternary_groupwise(w: torch.Tensor,
                               delta: torch.Tensor,
                               group_size: int = 128):
    """Map every weight to -1, 0 or +1 using a per-group threshold.

    Args:
        w: Weight tensor; it is viewed as consecutive groups of
            ``group_size`` elements.
        delta: 1-D tensor holding one threshold per group.
        group_size: Number of consecutive elements sharing one threshold.

    Returns:
        Tensor shaped like ``w`` holding ``+1`` where ``w > delta``, ``-1``
        where ``w < -delta`` and ``0`` elsewhere. For floating-point ``w``
        the result keeps ``w``'s dtype and device.
    """
    w_r = w.reshape(-1, group_size)                  # (groups, group_size)
    delta = delta.unsqueeze(1)                       # (groups, 1)

    # Build the three levels as 0-dim tensors in the weight's own dtype.
    # Python scalars inside torch.where produce the default dtype instead,
    # which silently turns float64/float16 weights into float32.
    dtype = w.dtype if w.is_floating_point() else torch.get_default_dtype()
    one = torch.ones((), dtype=dtype, device=w.device)
    zero = torch.zeros((), dtype=dtype, device=w.device)

    qw_r = torch.where(
        w_r > delta, one,
        torch.where(w_r < -delta, -one, zero)
    )
    return qw_r.reshape_as(w)

def dequantize_ternary_groupwise(qw: torch.Tensor,
                                 scale: torch.Tensor,
                                 group_size: int = 128):
    """Rescale ternary codes back to weight magnitudes, group by group.

    Args:
        qw: Tensor of ternary codes; viewed as consecutive groups of
            ``group_size`` elements.
        scale: 1-D tensor holding one scale per group.
        group_size: Number of consecutive elements sharing one scale.

    Returns:
        Tensor shaped like ``qw`` with every code multiplied by the scale
        of its group.
    """
    per_group = scale.unsqueeze(1)                   # (groups, 1)
    grouped = qw.reshape(-1, group_size)             # (groups, group_size)
    return torch.mul(grouped, per_group).reshape_as(qw)

class GroupTernaryFakeQuant(Function):
    """Group-wise ternary fake quantization with a straight-through gradient.

    The forward pass replaces the weight by its quantize -> dequantize
    round trip; the backward pass hands the incoming gradient to the weight
    unchanged.
    """

    @staticmethod
    def forward(ctx, w: torch.Tensor,
                group_size: int = 128,
                m: float = 0.75):
        """Return the fake-quantized version of ``w``.

        Args:
            ctx: Autograd context (nothing is stored on it).
            w: Weight tensor to fake-quantize.
            group_size: Number of elements sharing one threshold/scale.
            m: Threshold multiplier passed to
                ``compute_ternary_params_groupwise``.

        Returns:
            Tensor shaped like ``w`` holding the dequantized ternary weights.
        """
        delta, scale = compute_ternary_params_groupwise(
            w, group_size, m=m
        )
        qw = quantize_ternary_groupwise(w, delta, group_size)
        dqw = dequantize_ternary_groupwise(qw, scale, group_size)

        # Nothing is saved for backward: the straight-through estimator
        # below does not read any forward-pass tensor.
        return dqw

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator: pass the gradient through to ``w``.

        Returns one entry per ``forward`` input; ``group_size`` and ``m``
        are not differentiable and therefore get ``None``.
        """
        return grad_output, None, None

def _signed_qmax(bitwidth: int) -> int:
    """Return the largest magnitude of a symmetric signed ``bitwidth``-bit code."""
    # A signed b-bit integer spans [-(2**(b-1)), 2**(b-1) - 1]; the symmetric
    # range used here is therefore +/-(2**(b-1) - 1), i.e. +/-127 for 8 bits.
    return 2 ** (bitwidth - 1) - 1


def compute_int_scale(x: torch.Tensor, bitwidth: int):
    """Compute the per-tensor absmax scale for symmetric integer quantization.

    Args:
        x: Tensor to be quantized.
        bitwidth: Number of bits of the signed integer representation.

    Returns:
        0-dim tensor ``max(|x|) / qmax``. When ``x`` is entirely zero the
        scale falls back to 1 so that the later division stays finite.
    """
    qmax = _signed_qmax(bitwidth)
    scale = x.abs().max() / qmax
    # An all-zero tensor would give scale == 0 and turn x / scale into NaN.
    return torch.where(scale > 0, scale, torch.ones_like(scale))

def quantize_int(x: torch.Tensor, scale: torch.Tensor, bitwidth: int):
    """Quantize ``x`` to integer-valued codes in ``[-qmax, qmax]``.

    Args:
        x: Tensor to quantize.
        scale: Step size, normally the result of ``compute_int_scale``.
        bitwidth: Number of bits of the signed integer representation.

    Returns:
        Tensor shaped like ``x`` holding rounded and clamped codes (still
        stored in a floating-point dtype).
    """
    qmax = _signed_qmax(bitwidth)
    qx = torch.round(x / scale)
    return torch.clamp(qx, -qmax, qmax)

def dequantize_int(qx: torch.Tensor, scale: torch.Tensor):
    """Map integer codes back to real values by multiplying with ``scale``.

    Args:
        qx: Tensor of quantized codes.
        scale: Step size the codes were produced with.

    Returns:
        The element-wise product of ``qx`` and ``scale``.
    """
    return torch.mul(qx, scale)


class PerTensorInt8QFakeQuant(Function):
    """Per-tensor symmetric 8-bit fake quantization with a straight-through gradient.

    The forward pass replaces ``x`` by its quantize -> dequantize round
    trip using a single scale for the whole tensor; the backward pass
    hands the incoming gradient through unchanged.
    """

    @staticmethod
    def forward(ctx, x):
        """Return ``x`` after an 8-bit quantize/dequantize round trip.

        Args:
            ctx: Autograd context (nothing is stored on it).
            x: Tensor to fake-quantize.

        Returns:
            Tensor shaped like ``x`` with values snapped to the 8-bit grid.
        """
        scale = compute_int_scale(x, bitwidth=8)
        qx = quantize_int(x, scale, bitwidth=8)
        # Nothing is saved for backward: the straight-through estimator
        # below does not read any forward-pass tensor.
        return dequantize_int(qx, scale)

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator: pass the gradient through to ``x``."""
        return grad_output


def fakequantize_weight_ternary(w,
                                group_size: int = 128,
                                m: float = 0.75):
    """Apply group-wise ternary fake quantization to a weight tensor.

    Thin functional wrapper around ``GroupTernaryFakeQuant``.

    Args:
        w: Weight tensor to fake-quantize.
        group_size: Number of elements sharing one threshold/scale.
        m: Threshold multiplier.

    Returns:
        The result of ``GroupTernaryFakeQuant.apply(w, group_size, m)``.
    """
    fake_quant = GroupTernaryFakeQuant.apply
    return fake_quant(w, group_size, m)

def fakequantize_activation_int8(w):
    """Apply per-tensor 8-bit fake quantization to a tensor.

    Thin functional wrapper around ``PerTensorInt8QFakeQuant``. The
    parameter is named ``w`` but the callers in this file pass activations.

    Args:
        w: Tensor to fake-quantize.

    Returns:
        The result of ``PerTensorInt8QFakeQuant.apply(w)``.
    """
    fake_quant = PerTensorInt8QFakeQuant.apply
    return fake_quant(w)


class TernaryConv2d(nn.Module):
    """2-D convolution whose weight is ternary fake-quantized on every forward.

    The full-precision parameters live in ``self.conv`` (an ``nn.Conv2d``);
    ``forward`` fake-quantizes ``self.conv.weight`` and runs ``F.conv2d``
    with the stored stride, padding, dilation and groups.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size; an int is expanded to a square kernel.
        stride: Convolution stride.
        padding: Convolution padding.
        dilation: Convolution dilation.
        groups: Number of blocked connections between input and output
            channels.
        bias: Whether the underlying convolution has a bias term.
        group_size: Quantization group size forwarded to
            ``fakequantize_weight_ternary``.
        m: Threshold multiplier forwarded to ``fakequantize_weight_ternary``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=True, group_size=128, m=0.75):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.group_size = group_size
        self.m = m

        # dilation and groups must reach nn.Conv2d as well: groups decides
        # the weight shape (out, in // groups, kH, kW), and forward() hands
        # both values to F.conv2d. Leaving them out creates a weight that
        # F.conv2d rejects as soon as groups > 1.
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding,
            dilation=dilation, groups=groups, bias=bias
        )

    def forward(self, x):
        """Convolve ``x`` with the fake-quantized weight and the stored bias."""
        quant_weight = fakequantize_weight_ternary(
            self.conv.weight,
            group_size=self.group_size,
            m=self.m
        )

        return F.conv2d(
            x, quant_weight, self.conv.bias,
            self.stride, self.padding,
            self.dilation, self.groups
        )

class TernaryLinear(nn.Module):
    """Fully connected layer whose weight is ternary fake-quantized on every forward.

    The full-precision parameters live in ``self.linear`` (an
    ``nn.Linear``); ``forward`` fake-quantizes ``self.linear.weight`` and
    applies ``F.linear`` with the unquantized bias.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the underlying linear layer has a bias term.
        group_size: Quantization group size forwarded to
            ``fakequantize_weight_ternary``.
        m: Threshold multiplier forwarded to ``fakequantize_weight_ternary``.
    """

    def __init__(self, in_features, out_features, bias=True, group_size=128, m=0.75):
        super().__init__()
        # Layer geometry.
        self.in_features, self.out_features = in_features, out_features
        # Quantization settings.
        self.group_size, self.m = group_size, m

        # Holder of the full-precision weight and bias.
        self.linear = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        """Apply the linear map using the fake-quantized weight."""
        layer = self.linear
        ternary_weight = fakequantize_weight_ternary(
            layer.weight, group_size=self.group_size, m=self.m
        )
        return F.linear(x, ternary_weight, layer.bias)

class SimpleConvNet(nn.Module):
    """Small conv net built from ternary-weight layers with int8 fake-quantized activations.

    Layout: two ``TernaryConv2d`` 3x3 layers (1 -> 32 -> 64 channels), each
    followed by ReLU and 2x2 max pooling, then two ``TernaryLinear`` layers
    (64*5*5 -> 128 -> ``num_classes``). The 64*5*5 input of ``fc1``
    corresponds to 28x28 single-channel images (28 -> 26 -> 13 -> 11 -> 5).

    Args:
        num_classes: Number of output logits.
        group_size: Quantization group size passed to every ternary layer.
        m: Threshold multiplier passed to every ternary layer.
    """

    def __init__(self, num_classes=10, group_size=128, m=0.75):
        super().__init__()
        quant_cfg = dict(group_size=group_size, m=m)

        # Feature extractor.
        self.conv1 = TernaryConv2d(1, 32, (3, 3), 1, **quant_cfg)
        self.conv2 = TernaryConv2d(32, 64, (3, 3), 1, **quant_cfg)

        # Classifier head.
        self.fc1 = TernaryLinear(64*5*5, 128, **quant_cfg)
        self.fc2 = TernaryLinear(128, num_classes, **quant_cfg)

        # Parameter-free helpers shared by several stages.
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        # The network input is fake-quantized as well.
        x = fakequantize_activation_int8(x)

        # conv -> ReLU -> pool -> activation fake-quant, once per conv stage.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
            x = fakequantize_activation_int8(x)

        x = self.dropout(torch.flatten(x, 1))

        x = fakequantize_activation_int8(F.relu(self.fc1(x)))
        x = self.dropout(x)

        # The logits are left unquantized.
        return self.fc2(x)


def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch and print the loss every 100 batches.

    Args:
        model: Module to optimize; it is switched to training mode.
        device: Device every batch is moved to.
        train_loader: Iterable of ``(data, target)`` batches; its length and
            its ``dataset`` attribute are used for the progress message.
        optimizer: Optimizer stepping the model parameters.
        epoch: Epoch number, used only in the progress message.
    """
    model.train()
    for batch_idx, batch in enumerate(train_loader):
        data, target = (tensor.to(device) for tensor in batch)

        # Standard step: clear gradients, forward, cross-entropy, backward, update.
        optimizer.zero_grad()
        loss = F.cross_entropy(model(data), target)
        loss.backward()
        optimizer.step()

        # Only every 100th batch (starting with the first) is reported.
        if batch_idx % 100 != 0:
            continue
        print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
              f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

def test(model, device, test_loader):
    """Evaluate the model and print the average loss and the accuracy.

    Args:
        model: Module to evaluate; it is switched to evaluation mode.
        device: Device every batch is moved to.
        test_loader: Iterable of ``(data, target)`` batches exposing a
            ``dataset`` attribute whose length is the number of samples.

    Returns:
        The accuracy in percent.
    """
    model.eval()
    correct = 0
    test_loss = 0
    # No gradients are needed while evaluating.
    with torch.no_grad():
        for batch in test_loader:
            data, target = (tensor.to(device) for tensor in batch)
            output = model(data)
            # Summed (not averaged) so the total can be divided by the
            # dataset size once all batches are processed.
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    dataset_size = len(test_loader.dataset)
    test_loss /= dataset_size
    accuracy = 100. * correct / dataset_size

    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({accuracy:.0f}%)\n')
    return accuracy


def main():
    """Train SimpleConvNet on MNIST and keep the best-scoring checkpoint.

    Downloads MNIST into ``./data``, trains for a fixed number of epochs
    with SGD, evaluates after every epoch, saves the state dict to
    ``ternary_convnet_best.pth`` whenever the test accuracy improves and
    finally prints the best accuracy.
    """
    # Optimization settings.
    batch_size = 64
    epochs = 10
    lr = 0.01
    momentum = 0.9

    # Quantization settings.
    # NOTE(review): with group_size = 1 every group holds a single weight,
    # so the threshold is m * |w| and the scale is |w|; for 0 < m < 1 the
    # fake-quantized weight then equals the original weight, i.e. the
    # weights are effectively not ternarized in this configuration.
    group_size = 1 # Changed group_size to 1
    m = 0.75

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Convert to tensors and normalize with the constants below.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('./data', train=False, transform=transform)

    # Batches are shuffled for training only.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = SimpleConvNet(group_size=group_size, m=m).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    best_accuracy = 0
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch)
        accuracy = test(model, device, test_loader)

        # Checkpoint only on a strict improvement.
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), "ternary_convnet_best.pth")

    print(f"Best accuracy: {best_accuracy:.2f}%")

# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()


Test set: Average loss: 0.0547, Accuracy: 9829/10000 (98%)


Test set: Average loss: 0.0393, Accuracy: 9866/10000 (99%)


Test set: Average loss: 0.0313, Accuracy: 9898/10000 (99%)


Test set: Average loss: 0.0267, Accuracy: 9918/10000 (99%)


Test set: Average loss: 0.0255, Accuracy: 9916/10000 (99%)


Test set: Average loss: 0.0241, Accuracy: 9928/10000 (99%)


Test set: Average loss: 0.0258, Accuracy: 9923/10000 (99%)


Test set: Average loss: 0.0245, Accuracy: 9916/10000 (99%)


Test set: Average loss: 0.0253, Accuracy: 9916/10000 (99%)


Test set: Average loss: 0.0232, Accuracy: 9926/10000 (99%)

Best accuracy: 99.28%
