In [131]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.nn.utils import prune
from collections import OrderedDict
import os
import flwr
from flwr.common import (parameters_to_ndarrays, ndarrays_to_parameters)

In [132]:
device = torch.device("cpu")


In [151]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])
cifar_trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(cifar_trainset, batch_size=10, shuffle=True)
cifar_testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(cifar_testset, batch_size=10, shuffle=False)

In [152]:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [153]:
class QuantizedNet(nn.Module):
    def __init__(self) -> None:
        super(QuantizedNet, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.quant(x)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = self.dequant(x)
        return x

In [136]:
def set_parameters(net, parameters):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)

In [137]:
def get_parameters(net):
    return [val.cpu().numpy() for _, val in net.state_dict().items()]

In [138]:
def get_quantized_parameters(net):
    parameters = []
    
    for name,module in net.named_modules():
        if isinstance(module, torch.nn.quantized.Linear) or isinstance(module, torch.nn.quantized.Conv2d):
            parameters.append(module.weight().int_repr().cpu().numpy())
            parameters.append(module.bias().detach().cpu().numpy())
            
       
    return parameters  # Convert to NumPy array

In [154]:
def train(net, trainloader, epochs: int):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # metrics
            epoch_loss += loss.item()
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch + 1}: loss={epoch_loss:.4f}, acc={epoch_acc:.4f}")

In [163]:
def test(net, testloader):
    """Evaluate the network on the entire test set."""
    correct, total = 0, 0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    return accuracy

In [160]:
model = Net()

In [164]:
train(model, train_loader, 1)
accuracy = test(model, test_loader)

Epoch 1: loss=0.1157, acc=0.5901


In [None]:
torch.ao.quantization.quantize_dynamic(model, {torch.nn.Linear, torch.nn.Conv2d}, dtype=torch.qint8)quantized_model = 
torch.save(quantized_model.state_dict(), "quantized_model.pth")
print(f"{os.path.getsize('quantized_model.pth')} bytes")

76934 bytes


In [None]:
new = Net()
torch.ao.quantization.quantize_dynamic(model, {torch.nn.Linear, torch.nn.Conv2d}, dtype=torch.qint8,inplace=True)
new.load_state_dict(torch.load("quantized_model.pth"))


RuntimeError: Error(s) in loading state_dict for Net:
	Missing key(s) in state_dict: "fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias", "fc3.weight", "fc3.bias". 
	Unexpected key(s) in state_dict: "fc1.scale", "fc1.zero_point", "fc1._packed_params.dtype", "fc1._packed_params._packed_params", "fc2.scale", "fc2.zero_point", "fc2._packed_params.dtype", "fc2._packed_params._packed_params", "fc3.scale", "fc3.zero_point", "fc3._packed_params.dtype", "fc3._packed_params._packed_params". 

In [None]:
train(new, train_loader, 1)