In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import copy

In [2]:
class QuantizedLinearLayer(nn.Module):
    """Inference-only replacement for ``nn.Linear`` that stores quantized parameters.

    Weight and bias are kept as integer codes, each with its own affine
    ``(scale, zero_point)`` pair, and are dequantized on every forward pass as
    ``(q - zero_point) / scale``. All six tensors are registered as buffers, so
    the layer has no trainable parameters while its tensors still follow
    ``.to(device)`` and are saved in ``state_dict()``.

    Args:
        input_dim: number of input features.
        output_dim: number of output features.
        weight: quantized weight codes, laid out like ``nn.Linear.weight``
            (it is transposed before the matmul in ``forward``).
        weight_scale: scale used when the weight was quantized.
        weight_zero_point: zero point used when the weight was quantized.
        bias: quantized bias codes, or ``None`` for a layer without bias.
        bias_scale: scale used when the bias was quantized (``None`` if no bias).
        bias_zero_point: zero point used when the bias was quantized
            (``None`` if no bias).
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        weight,
        weight_scale,
        weight_zero_point,
        bias,
        bias_scale,
        bias_zero_point,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        # register_buffer is used instead of nn.parameter.Buffer because the
        # latter only exists in recent PyTorch releases.
        buffers = (
            ("weight", weight),
            ("bias", bias),
            ("weight_scale", weight_scale),
            ("weight_zero_point", weight_zero_point),
            ("bias_scale", bias_scale),
            ("bias_zero_point", bias_zero_point),
        )
        for name, value in buffers:
            self.register_buffer(name, self._as_buffer(value))

    @staticmethod
    def _as_buffer(value):
        """Return ``value`` as a tensor without autograd history (``None`` stays ``None``)."""
        if value is None:
            return None
        # Detaching keeps the layer independent of whatever graph produced the
        # value; buffers holding non-leaf tensors cannot be deep-copied.
        return torch.as_tensor(value).detach()

    def forward(self, x):
        """Apply the dequantized affine map ``x @ W.T + b`` (bias added only if present)."""
        # x.shape = (batch_size, input_features)

        # dequantize weight; .float() first so the subtraction is done in
        # floating point rather than in the integer storage dtype
        weight = (self.weight.float() - self.weight_zero_point) / self.weight_scale
        output = x @ weight.T

        if self.bias is None:
            return output

        # dequantize bias
        bias = (self.bias.float() - self.bias_zero_point) / self.bias_scale
        return output + bias

In [3]:
def _quantize_tensor_uint8(tensor):
    """Affine-quantize ``tensor`` to uint8 and return ``(quantized, scale, zero_point)``."""
    # Quantization is a one-off conversion: run it outside autograd so the
    # returned scale does not keep a graph that references the source tensor.
    with torch.no_grad():
        # extend interval to include zero
        min_val = tensor.min().clamp(max=0)
        max_val = tensor.max().clamp(min=0)

        value_range = max_val - min_val
        if value_range == 0:
            # All values are exactly zero: any finite scale is lossless, and
            # 255 / 0 would produce an inf scale and a NaN zero point.
            scale = torch.ones_like(value_range)
        else:
            scale = 255 / value_range

        zero_point = (-min_val * scale).round().clamp(0, 255).to(torch.uint8)
        quantized = (
            (tensor * scale + zero_point).round().clamp(0, 255).to(torch.uint8)
        )

    return quantized, scale, zero_point


def quantize_linear(linear_layer):
    """Quantize the weight and bias of a linear layer to unsigned 8-bit integers.

    Each tensor is quantized independently as
    ``q = clamp(round(x * scale + zero_point), 0, 255)`` with
    ``scale = 255 / (max - min)``, where the ``[min, max]`` interval is first
    extended to contain zero. The inverse is ``x ~ (q - zero_point) / scale``.

    Args:
        linear_layer: module exposing ``weight`` and ``bias`` tensors. If
            ``bias`` is ``None`` an all-zero bias with one entry per weight
            row is quantized instead, so six tensors are always returned.

    Returns:
        Tuple ``(weight_quantized, weight_scale, weight_zero_point,
        bias_quantized, bias_scale, bias_zero_point)``. The quantized tensors
        and zero points are ``torch.uint8``; none of the returned tensors
        require grad.
    """
    weight = linear_layer.weight
    bias = linear_layer.bias
    if bias is None:
        # A zero bias dequantizes back to exactly zero, so the result is
        # equivalent to having no bias at all.
        bias = weight.new_zeros(weight.shape[0])

    weight_quantized, weight_scale, weight_zero_point = _quantize_tensor_uint8(weight)
    bias_quantized, bias_scale, bias_zero_point = _quantize_tensor_uint8(bias)

    return (
        weight_quantized,
        weight_scale,
        weight_zero_point,
        bias_quantized,
        bias_scale,
        bias_zero_point,
    )

In [4]:
def quantize_model(model, exclude_layers, copy_model=True):
    """Quantize a model by replacing its ``nn.Linear`` submodules.

    The module tree is walked recursively; every ``nn.Linear`` child is
    replaced by a ``QuantizedLinearLayer`` built from ``quantize_linear``.

    Args:
        model: the model to quantize. Only child modules are replaced: if
            ``model`` itself is an ``nn.Linear`` it has no children and is
            returned unchanged.
        exclude_layers: collection of child names to skip. Names are matched
            against the local child name at every nesting level, and a skipped
            child is not descended into.
        copy_model: if True, the model is copied before quantization.
            If False, the model is quantized in place.

    Returns:
        the quantized model (the deep copy when ``copy_model`` is True,
        otherwise ``model`` itself).
    """
    if copy_model:
        model = copy.deepcopy(model)

    for name, layer in model.named_children():
        if name in exclude_layers:
            continue

        if isinstance(layer, nn.Linear):
            # quantize_linear returns its values in the order the
            # QuantizedLinearLayer constructor expects after the two sizes
            quantized_layer = QuantizedLinearLayer(
                layer.in_features,
                layer.out_features,
                *quantize_linear(layer),
            )
            # replace layer with quantized version
            setattr(model, name, quantized_layer)
        else:
            # Recurse in place: `layer` already belongs to the tree being
            # modified. Copying here would quantize a throw-away clone and
            # leave nested linear layers untouched.
            quantize_model(layer, exclude_layers, copy_model=False)

    return model

# Example

In [5]:
# Float reference layer and a random input batch for the example
linear_layer = nn.Linear(10, 20)
x = torch.randn(5, 10)

# quantize_linear returns its six values in the order the
# QuantizedLinearLayer constructor expects after the two dimensions
quantized_params = quantize_linear(linear_layer)
(
    weight_quantized,
    weight_scale,
    weight_zero_point,
    bias_quantized,
    bias_scale,
    bias_zero_point,
) = quantized_params

quantized_linear_layer = QuantizedLinearLayer(10, 20, *quantized_params)

In [None]:
# Original float bias of the example layer (reference for the dequantized bias)
linear_layer.bias

In [None]:
# Dequantize the bias by hand: (q - zero_point) / scale
(bias_quantized - bias_zero_point.float()) / bias_scale

In [None]:
# Display all six values returned by quantize_linear for the example layer
weight_quantized, weight_scale, weight_zero_point, bias_quantized, bias_scale, bias_zero_point

In [None]:
# Output of the original float layer on the example batch
linear_layer(x)

In [None]:
# Scratch check of mixed-dtype arithmetic: a float tensor minus a uint8 tensor,
# the same combination used when dequantizing with a uint8 zero point
torch.tensor(1.0) - torch.tensor(255, dtype=torch.uint8)

# Test on a real network

In [11]:
from torchvision.datasets import MNIST

In [12]:
# Load the MNIST train and test splits from ./data (download=True for both)
mnist_train, mnist_test = (
    MNIST(root="data", download=True, train=is_train) for is_train in (True, False)
)

In [13]:
# Divide the raw image data by 255; targets are used as provided
x_train, y_train = mnist_train.data / 255.0, mnist_train.targets
x_test, y_test = mnist_test.data / 255.0, mnist_test.targets

In [14]:
# MLP with dropout: Flatten -> two (Linear, ReLU, Dropout(0.2)) stages -> Linear head
model = nn.Sequential(
    nn.Flatten(),
    *(
        module
        for in_dim, out_dim in ((784, 128), (128, 64))
        for module in (nn.Linear(in_dim, out_dim), nn.ReLU(), nn.Dropout(0.2))
    ),
    nn.Linear(64, 10),
)

In [15]:
# Pair each split's inputs with its targets
train_dataset, test_dataset = (
    torch.utils.data.TensorDataset(features, labels)
    for features, labels in ((x_train, y_train), (x_test, y_test))
)

In [16]:
# --- Hyperparameters ---
BATCH_SIZE = 64
EPOCHS = 10
LEARNING_RATE = 0.001
HIDDEN_SIZE = 128
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Data loaders: only the training split is shuffled ---
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# --- Model, loss and optimizer initialization ---
# The model is moved to DEVICE before the optimizer is built from its parameters.
model.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)


def train_epoch(model, train_loader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``train_loader``.

    Puts the model in training mode, performs one optimizer step per batch and
    prints running statistics every 100 batches.

    Args:
        model: module being trained.
        train_loader: iterable of ``(data, target)`` batches that supports ``len()``.
        criterion: loss callable applied as ``criterion(output, target)``.
        optimizer: optimizer updating the model's parameters.
        device: device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean loss per batch, accuracy in percent)`` for the epoch.
    """
    model.train()
    running_loss = 0.0
    correct = total = 0

    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # one optimization step: clear grads, forward, loss, backward, update
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # running statistics; the prediction is the index of the largest output
        running_loss += loss.item()
        predicted = logits.max(dim=1).indices
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        if batch_idx % 100 != 0:
            continue
        print(
            f"Batch: {batch_idx}/{len(train_loader)}, "
            f"Loss: {loss.item():.4f}, "
            f"Accuracy: {100.*correct/total:.2f}%"
        )

    mean_loss = running_loss / len(train_loader)
    epoch_accuracy = 100.0 * correct / total
    return mean_loss, epoch_accuracy


def evaluate(model, test_loader, criterion, device):
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    Puts the model in eval mode, accumulates loss and correct predictions over
    all batches and prints a one-line summary.

    Args:
        model: module to evaluate.
        test_loader: iterable of ``(data, target)`` batches that supports ``len()``.
        criterion: loss callable applied as ``criterion(output, target)``.
        device: device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean loss per batch, accuracy in percent)``.
    """
    model.eval()
    test_loss = 0
    correct = total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs)
            test_loss += criterion(logits, labels).item()

            # the prediction is the index of the largest output
            predicted = logits.max(dim=1).indices
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # the loss is averaged over batches, the accuracy over samples
    test_loss /= len(test_loader)
    accuracy = 100.0 * correct / total

    print(f"\nTest set: Average loss: {test_loss:.4f}, " f"Accuracy: {accuracy:.2f}%\n")

    return test_loss, accuracy

In [None]:
# Training: evaluate after every epoch and checkpoint the weights whenever the
# test accuracy improves on the best value seen so far
best_accuracy = 0.0
for epoch in range(EPOCHS):
    print(f"\nEpoch: {epoch+1}/{EPOCHS}")
    train_loss, train_acc = train_epoch(
        model=model,
        train_loader=train_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=DEVICE,
    )
    test_loss, test_acc = evaluate(
        model=model, test_loader=test_loader, criterion=criterion, device=DEVICE
    )

    if not test_acc > best_accuracy:
        continue

    # Save the best model
    best_accuracy = test_acc
    torch.save(model.state_dict(), "mlp_mnist_best.pth")

print(f"Training completato! Miglior accuratezza: {best_accuracy:.2f}%")

In [None]:
# Baseline: loss and accuracy of the trained float model on the test set
evaluate(model, test_loader, criterion, DEVICE)

In [None]:
# Quantize a copy of the trained model (copy_model defaults to True, no layers
# excluded), then evaluate it on the same test set for comparison
quantized_model = quantize_model(model, [])

evaluate(quantized_model, test_loader, criterion, DEVICE)