In [None]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import qtorch
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets
from torchvision import transforms

### Hyperparameters

In [None]:
seed = 0                 # global RNG seed: makes weight init, gamma init and loader shuffling reproducible
torch.manual_seed(seed)  # seeds the CPU and all CUDA generators

batch_size = 100         # mini-batch size of both data loaders
hidden_size = 128        # width of the hidden layers of LinearModule
lr = 0.01                # learning rate (SGD for the float model, Adam for the binarized one)
epochs = 50
momentum = 0.9           # SGD momentum (float model)
rounding = "stochastic"  # NOTE(review): not referenced in the visible cells

### Common Variables

In [None]:
# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
file_name = "./test_binarizzazione_grafici/"  # NOTE(review): not referenced in the visible cells

### Load Dataset

In [None]:
training_data = datasets.MNIST("./data", train=True, transform=transforms.ToTensor(), download=True)
test_data = datasets.MNIST("./data", train=False, transform=transforms.ToTensor(), download=True)

training_loader = torch.utils.data.DataLoader(dataset=training_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

# All images share the same shape, so the number of elements of the first one
# (channels included) is the flattened input size of the networks, as a Python int.
input_size = training_data[0][0].numel()
output_size = 10  # number of MNIST digit classes


### Neural Network

In [None]:
class LinearModule(nn.Module):
    """Bias-free MLP: two ReLU hidden layers of width `hidden_size`, then a linear output layer.

    All linear layers are created directly on the global `device`.
    """

    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        # Attribute names and creation order fix the state_dict keys and the order of
        # model.parameters(), which BinarizedModel iterates layer by layer.
        self.l1 = nn.Linear(input_size, hidden_size, bias=False, device=device)
        self.act1 = nn.ReLU()
        self.l2 = nn.Linear(hidden_size, hidden_size, bias=False, device=device)
        self.act2 = nn.ReLU()
        self.output_layer = nn.Linear(hidden_size, output_size, bias=False, device=device)

    def forward(self, x):
        """Map a (batch, input_size) tensor to (batch, output_size) unnormalized scores."""
        hidden = self.act1(self.l1(x))
        hidden = self.act2(self.l2(hidden))
        return self.output_layer(hidden)

model = LinearModule(input_size, output_size, hidden_size)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), momentum=momentum, lr=lr)

# torch.save does not create missing directories, so make sure ./models exists first.
# NOTE(review): this writes the freshly constructed weights and overwrites any checkpoint
# already at this path; the "Load model" cell reloads this same file - confirm intended.
os.makedirs("./models", exist_ok=True)
torch.save(model.state_dict(), "./models/mnist_float32_model")

### Model Test

In [None]:
def getAccuracy(model, loader=None):
    """Return the classification accuracy of `model`, in percent.

    Args:
        model: callable mapping a (batch, input_size) tensor to (batch, classes) scores;
            the predicted class is the argmax over dim 1.
        loader: iterable of (images, labels) batches. Defaults to the global `test_loader`.

    Returns:
        100 * (#correct predictions) / (#samples seen), as a float.
    """
    if loader is None:
        loader = test_loader

    correct = 0
    total = 0
    with torch.inference_mode():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            output = model(images.reshape(-1, input_size))
            predictions = torch.argmax(output, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    return 100 * correct / total

In [None]:
# Test-set accuracy (%) of the float32 model
getAccuracy(model=model)

## Load model

In [None]:
model = LinearModule(input_size, output_size, hidden_size)
# map_location lets a checkpoint written on a GPU be restored on a CPU-only machine (and vice versa).
model.load_state_dict(torch.load("./models/mnist_float32_model", map_location=device))
model.eval()

___
## Binarization

In [None]:
class Binarize(torch.autograd.Function):
    """Step binarization to {0, 1} with a straight-through gradient estimator.

    Forward: 1 where x >= 0, 0 elsewhere.
    Backward: the incoming gradient is passed through unchanged where |x| <= 1 and
    zeroed elsewhere, so saturated inputs stop receiving updates.
    """

    @staticmethod
    def forward(ctx, x):
        # backward needs the real-valued input to gate the gradient
        ctx.save_for_backward(x)
        return torch.where(torch.ge(x, 0.), 1., 0.)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Straight-through estimator: surrogate derivative is 1 on |x| <= 1, 0 outside.
        return grad_output * (x.abs() <= 1).to(grad_output.dtype)


binarize = Binarize.apply

# Approximation to the nearest power of two
def AP2(x):
    """Replace every element of `x` by sign(x) * 2**round(log2(|x|)); zeros map to 0."""
    return torch.sign(x) * torch.pow(2, torch.round(torch.log2(torch.abs(x))))


def batchNorm(x, gamma, beta, inference=False, E=None, Var=None, lastLevel=False):
    """Shift-based batch normalization: AP2(w) * x + b.

    With w = gamma / sqrt(Var + eps) and b = beta - E * gamma / sqrt(Var + eps).
    The scale w is rounded to a power of two by AP2 (so the multiply can be a shift),
    while the bias b uses the exact w.

    Args:
        x: pre-activations, reduced over dim 0 when batch statistics are needed.
        gamma, beta: scale and shift.
        inference: unused; kept for call compatibility.
        E, Var: mean and variance to normalize with. When omitted (None), the mean and
            biased variance of `x` over dim 0 (the current batch) are used.
        lastLevel: unused; kept for call compatibility.
    """
    eps = 1e-5
    if E is None:
        E = torch.mean(x, dim=0)
    if Var is None:
        Var = torch.var(x, dim=0, unbiased=False)

    inv_std = torch.sqrt(Var + eps).pow(-1)
    b = beta - (E * gamma) * inv_std
    w = gamma * inv_std
    y = AP2(w) * x + b              # multiplication by a power of two -> shift register

    return y
    
def lossFunction(a, a_star):
    """Mean squared error between `a` and `a_star`, averaged over all elements."""
    return (a - a_star).pow(2).mean()

### Binarized model

In [None]:
class BinarizedModel(nn.Module):
    """Binarized network built from the bias-free linear layers of `model`.

    Every parameter of `model` (expected to be a 2-D weight of shape (out, in)) becomes one
    layer. forward binarizes the inputs, the weights and the hidden activations with
    `binarize` and normalizes each layer's pre-activation s[k] with `batchNorm`, using one
    scalar gamma/beta per layer.

    Workflow: train with `inference = False` (batch statistics are used and recorded), call
    `frozeParameter` to turn the recorded statistics into EP/VarP, then set
    `inference = True` to evaluate with those frozen statistics.
    """

    def __init__(self, model):
        super().__init__()
        self.total_weights = 0              # total number of weights in the network
        self.hidden_layer = -1              # number of hidden layers
        self.number_perceptron_layer = []   # number of perceptrons in each layer
        self.in_perceptron_layer = []       # number of inputs of each perceptron of layer i

        weights = []
        for param in model.parameters():
            n_out, n_in = param.shape[0], param.shape[1]
            self.in_perceptron_layer.append(n_in)
            self.number_perceptron_layer.append(n_out)
            self.total_weights += n_out * n_in
            self.hidden_layer += 1

            # Stored transposed, (inputs, perceptrons), so forward can compute ab @ W.
            w = torch.empty(n_in, n_out, device=device)
            w.copy_(param.detach().t())
            weights.append(w)

        self.L = self.hidden_layer + 1      # number of layers of the network

        # Batch-normalization parameters: one scalar gamma and beta per layer
        gamma = torch.empty(self.L, device=device)
        nn.init.uniform_(gamma)
        beta = torch.zeros(self.L, device=device)

        # Trainable parameters
        self.weights = nn.ParameterList(weights)   # network weights
        self.beta = nn.Parameter(beta)             # batch-norm shift per layer
        self.gamma = nn.Parameter(gamma)           # batch-norm scale per layer

        # Non-trainable parameters: population mean / variance of each layer's
        # pre-activation s[k] (one value per perceptron), filled in by frozeParameter.
        self.EP = nn.ParameterList(
            [nn.Parameter(torch.zeros(n, device=device), requires_grad=False)
             for n in self.number_perceptron_layer])
        self.VarP = nn.ParameterList(
            [nn.Parameter(torch.ones(n, device=device), requires_grad=False)
             for n in self.number_perceptron_layer])

        # Activation function (not used by forward)
        self.actFun = nn.ReLU()

        self.init()

    # Initialize the forward state and the training statistics.
    # Call before starting training.
    def init(self):
        """Reset the state used by forward and the recorded statistics; sets inference = False."""
        # State used by forward (kept on self for inspection).
        # NOTE: assigning the ParameterList to W also registers it as submodule "W",
        # so "W.*" keys appear in state_dict next to "weights.*".
        self.W = self.weights
        self.Wb = [None] * (self.L)
        self.s = [None] * (self.L)
        self.a = [None] * (self.L + 1)
        self.ab = [None] * (self.L + 1)

        # Per-batch statistics of s[k] recorded during training (consumed by frozeParameter)
        self.EPList = [[] for _ in range(self.L)]
        self.VarList = [[] for _ in range(self.L)]

        # inference = False
        #   normalize with the current batch statistics and record them;
        #   set before the training phase
        # inference = True
        #   normalize with the frozen population statistics EP / VarP;
        #   set before the inference phase
        self.inference = False

    def forward(self, x_vect):
        """Run the network on `x_vect` (one flattened sample per row) and return a[L-1].

        In training mode each layer is normalized with the mean and biased variance of
        the current batch, and those statistics are recorded (detached) for frozeParameter.
        """
        # Shift inputs by 0.5 so that binarize thresholds them at 0.5.
        self.a[-1] = x_vect - 0.5
        self.ab[-1] = binarize(self.a[-1])
        ab = self.ab[-1]

        for k in range(self.L):
            self.Wb[k] = binarize(self.W[k])    # binarized weights
            self.s[k] = torch.matmul(ab, self.Wb[k])

            if self.inference:
                mean_k, var_k = self.EP[k], self.VarP[k]
            else:
                # Batch statistics over the batch dimension (biased variance, as in batch norm)
                mean_k = torch.mean(self.s[k], dim=0)
                var_k = torch.var(self.s[k], dim=0, unbiased=False)
                # Detach so these lists do not keep every batch's autograd graph alive.
                self.EPList[k].append(mean_k.detach())
                self.VarList[k].append(var_k.detach())

            self.a[k] = batchNorm(self.s[k], self.gamma[k], self.beta[k], self.inference, mean_k, var_k)
            if k < self.L - 1:
                self.ab[k] = binarize(self.a[k])
            ab = self.ab[k]

        return self.a[self.L - 1]

    def frozeParameter(self, batch_size):
        """Compute the population statistics EP / VarP from the statistics recorded in training.

        EP[k] is the mean of the recorded batch means. VarP[k] is the mean of the recorded
        (biased) batch variances times batch_size / (batch_size - 1), i.e. Bessel's
        correction applied once. The recorded lists are cleared afterwards.
        """
        with torch.no_grad():
            for k in range(self.L):
                batch_means = torch.stack(self.EPList[k], dim=0).to(device)
                batch_vars = torch.stack(self.VarList[k], dim=0).to(device)
                self.EP[k].copy_(torch.mean(batch_means, dim=0))
                self.VarP[k].copy_(torch.mean(batch_vars, dim=0) * (batch_size / (batch_size - 1)))

        # Discard the training statistics used to compute the population statistics
        self.EPList = [[] for _ in range(self.L)]
        self.VarList = [[] for _ in range(self.L)]

# Binarized model initialised from the float model's weights.
bmodel = BinarizedModel(model)
loss_function = nn.MSELoss()
# EP / VarP are registered with requires_grad=False, so they receive no gradient.
optimizer = torch.optim.Adam(params=bmodel.parameters(), lr=lr)

### Load Binarized Model

In [None]:
bmodel = BinarizedModel(model)
# map_location lets a checkpoint written on a GPU be restored on a CPU-only machine (and vice versa).
bmodel.load_state_dict(torch.load("./models/mnist_binarized_model", map_location=device))
bmodel.eval()

In [None]:
# Normalize with the frozen population statistics (EP / VarP) instead of batch statistics.
bmodel.inference = True
getAccuracy(model=bmodel)

### Binarized Model Embedded

In [None]:
def batchNormE(x, a, b):
    """Folded inference-time batch normalization: scale `x` by `a`, then shift by `b`."""
    scaled = x * a
    return scaled + b

def XnorPopCount(x, w):
    """XNOR/popcount product between every row of `x` and every row of `w`.

    Elements are treated as bits (nonzero -> 1, zero -> 0). For row i of x and row j of w
    the result is (#positions where the bits agree) - (#positions where they differ),
    i.e. the dot product of their +/-1 encodings.

    Args:
        x: (batch, n) tensor of bits, one sample per row.
        w: (out, n) tensor of bits, one perceptron per row (same device as x).

    Returns:
        (batch, out) tensor in torch's default floating dtype, on the inputs' device.
    """
    x_bits = x.bool().unsqueeze(1)      # (batch, 1, n)
    w_bits = w.bool().unsqueeze(0)      # (1, out, n)
    agree = torch.logical_not(torch.logical_xor(x_bits, w_bits))   # (batch, out, n)
    n_bits = agree.shape[-1]
    # (#agree) - (#differ) = 2 * #agree - n
    popcount = 2 * agree.sum(dim=-1) - n_bits
    return popcount.to(torch.get_default_dtype())


class BinarizedModelEmbedded:
    """Inference-only version of a trained BinarizedModel based on XNOR/popcount.

    The weights are binarized once at construction. Each layer's batch normalization is
    folded into a per-perceptron affine map a * s + b built from bmodel's frozen
    statistics EP / VarP.

    NOTE(review): two differences from BinarizedModel.forward in inference mode need checking:
      * BinarizedModel computes s with torch.matmul on {0, 1} activations and weights, while
        XnorPopCount returns the +/-1 agreement count - not the same quantity;
      * BinarizedModel's batchNorm scales by AP2(gamma / sqrt(Var + eps)), while the folded
        scale below is gamma / sqrt(Var + eps) without power-of-two rounding.
    The accuracy of this model can therefore differ from that of bmodel.
    """

    def __init__(self, bmodel):
        self.L = bmodel.L
        eps = 1e-5
        n_levels = len(bmodel.gamma)

        # Inference only: build everything outside autograd so no graph of bmodel is retained.
        with torch.no_grad():
            # binarize keeps bmodel's (weights, perceptrons) layout; transpose to
            # [level][perceptron][weight] so each row holds one perceptron's weight bits.
            self.weights = [binarize(w).t() for w in bmodel.weights]

            self.batch_norm_a = [bmodel.gamma[k] / torch.sqrt(bmodel.VarP[k] + eps)
                                 for k in range(n_levels)]
            self.batch_norm_b = [bmodel.beta[k] - bmodel.gamma[k] * bmodel.EP[k] / torch.sqrt(bmodel.VarP[k] + eps)
                                 for k in range(n_levels)]

    def forward(self, x_vect):
        """Return the last layer's output a[L-1] for `x_vect` (one flattened image per row).

        Intermediate values (s, a, ab) are kept on self for inspection.
        """
        self.Wb = self.weights
        self.s = [None] * (self.L)
        self.a = [None] * (self.L + 1)
        self.ab = [None] * (self.L + 1)

        # Shift inputs by 0.5 so that binarize thresholds them at 0.5.
        self.a[-1] = x_vect - 0.5
        self.ab[-1] = binarize(self.a[-1])
        ab = self.ab[-1]            # one image per row

        for k in range(self.L):
            self.s[k] = XnorPopCount(ab, self.Wb[k])
            self.a[k] = batchNormE(self.s[k], self.batch_norm_a[k], self.batch_norm_b[k])
            if k < self.L - 1:
                self.ab[k] = binarize(self.a[k])
            ab = self.ab[k]

        return self.a[self.L - 1]

    def __call__(self, x):
        return self.forward(x)

# XNOR/popcount inference model derived from the trained binarized model
bmodelE = BinarizedModelEmbedded(bmodel=bmodel)

In [None]:
# Sanity check of XnorPopCount on a hand-computed example: each entry is
# (#matching bits) - (#differing bits) between a row of x and a row of w.
x = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=int)
w = torch.tensor([[1, 0, 0], [1, 0, 1]], dtype=int)
xnor_counts = XnorPopCount(x, w)
assert xnor_counts.tolist() == [[-1, 1], [1, -1]], xnor_counts
xnor_counts

In [None]:
# Test-set accuracy (%) of the XNOR/popcount model
getAccuracy(model=bmodelE)