In [1]:
import sys
sys.path.append("./utils/")
import torch
from torchvision import datasets
from torchvision import transforms
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import torch.nn.utils.prune as prune
import utils.util as util

### Hyperparameters

In [2]:
# --- Data loading ---
batch_size = 100

# --- Architecture ---
hidden_size = 256
hidden_level = 4  # not referenced by any cell visible in this notebook

# --- Optimisation ---
epochs = 50
lr = 0.001
momentum = 0.9

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Load Dataset

In [3]:
# Both MNIST splits are stored under ./data (downloaded on first use) and every
# image is converted to a tensor.
training_data = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_data = datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# Only the training batches are shuffled; evaluation keeps the dataset order.
training_loader = torch.utils.data.DataLoader(
    training_data,
    shuffle=True,
    batch_size=batch_size,
)
test_loader = torch.utils.data.DataLoader(
    test_data,
    shuffle=False,
    batch_size=batch_size,
)

# All images share one size, so the first sample is enough: multiply every
# dimension after the leading one to get the flattened input width.
input_size = np.prod(training_data[0][0].shape[1:])
output_size = 10


### Network

In [3]:
class MnistNetwork(torch.nn.Module):
    """Fully connected classifier made of five bias-free linear layers.

    Every linear layer (including the last one) is followed by a ReLU.
    The layers are exposed as ``l1`` .. ``l5`` and the activations as
    ``a1`` .. ``a5``; these attribute names are kept stable so that
    previously saved checkpoints keep loading.

    Args:
        input_size: Number of features of one flattened input sample.
        output_size: Number of output units (classes).
        hidden_size: Width of each of the four hidden layers.
        device: Device on which the layer weights are created and which is
            forwarded to ``util.train`` by :meth:`trainModel`.
    """

    def __init__(self, input_size, output_size, hidden_size, device="cpu"):
        # Initialise nn.Module before touching any attribute so that the
        # parameter / sub-module registries exist from the start.
        super().__init__()

        self.output_size = output_size
        self.hidden_size = hidden_size
        self.device = device

        self.l1 = torch.nn.Linear(input_size, hidden_size, bias=False, device=device)
        self.a1 = torch.nn.ReLU()

        self.l2 = torch.nn.Linear(hidden_size, hidden_size, bias=False, device=device)
        self.a2 = torch.nn.ReLU()

        self.l3 = torch.nn.Linear(hidden_size, hidden_size, bias=False, device=device)
        self.a3 = torch.nn.ReLU()

        self.l4 = torch.nn.Linear(hidden_size, hidden_size, bias=False, device=device)
        self.a4 = torch.nn.ReLU()

        self.l5 = torch.nn.Linear(hidden_size, output_size, bias=False, device=device)
        self.a5 = torch.nn.ReLU()

    def forward(self, x):
        """Run ``x`` through the five linear + ReLU stages and return the result.

        Args:
            x: Input tensor whose last dimension equals ``input_size``.

        Returns:
            Tensor whose last dimension equals ``output_size``. Because the
            final stage is also a ReLU, all returned values are >= 0.
        """
        # NOTE(review): the ReLU after the output layer clamps the scores that
        # trainModel feeds to CrossEntropyLoss. It is kept because changing it
        # would alter the behaviour of already trained models -- confirm intent.
        stages = (
            (self.l1, self.a1),
            (self.l2, self.a2),
            (self.l3, self.a3),
            (self.l4, self.a4),
            (self.l5, self.a5),
        )
        out = x
        for linear, activation in stages:
            out = activation(linear(out))
        return out

    def trainModel(self, training_loader, test_loader, epochs=100, lr=0.01, writer=None, PATH=None, momentum=0.9):
        """Train this network with SGD and cross-entropy loss via ``util.train``.

        Args:
            training_loader: Loader yielding the training batches.
            test_loader: Loader forwarded to ``util.train`` for evaluation.
            epochs: Number of epochs forwarded to ``util.train``.
            lr: Learning rate used for the optimizer (also forwarded as ``lr``).
            writer: Optional summary writer forwarded to ``util.train``.
            PATH: Optional checkpoint path forwarded to ``util.train``.
            momentum: SGD momentum. Defaults to 0.9, the value the notebook
                previously read from a global variable.
        """
        loss_function = torch.nn.CrossEntropyLoss()
        # Optimise *this* instance's parameters. The previous code used the
        # notebook-level name `model`, which fails on a fresh kernel and would
        # silently train a different network if that name pointed elsewhere.
        optimizer = torch.optim.SGD(self.parameters(), lr=lr, momentum=momentum)

        util.train(
            model=self,
            training_loader=training_loader,
            test_loader=test_loader,
            loss_function=loss_function,
            optimizer=optimizer,
            epochs=epochs,
            lr=lr,
            dim=1,
            device=self.device,
            writer=writer,
            PATH=PATH,
        )

# Build the float network from the sizes defined in the configuration and
# data-loading cells (those cells must have been run first).
model = MnistNetwork(
    input_size=input_size,
    output_size=output_size,
    hidden_size=hidden_size,
    device=device,
)

NameError: name 'input_size' is not defined

### Training modello

In [None]:
# Checkpoint location and TensorBoard log directory for the float model.
PATH = "./models/MnistNetwork_model"
writer = SummaryWriter("./test_grafici/MnistNetwork")

# The learning rate is given explicitly here (0.1) rather than taken from the
# `lr` hyperparameter defined in the configuration cell.
model.trainModel(
    training_loader,
    test_loader,
    epochs=epochs,
    lr=0.1,
    writer=writer,
    PATH=PATH,
)

### Load Model

In [None]:
# Load the checkpoint stored at ./models/MnistNetwork_model into `model`
# (presumably the file written by the training cell -- confirm), then compute
# the accuracy on the test loader.
# NOTE(review): util.load_model / util.getAccuracy are defined in utils/util.py,
# which is not visible here; whether getAccuracy prints or returns is unverified.
util.load_model(model, "./models/MnistNetwork_model")
util.getAccuracy(model, test_loader, device)

### Binarizzazione modello

In [None]:
# Wrap the float network in the project's binarized model class.
# NOTE(review): MNISTBinarizedModel is defined in utils/binarizedModel.py (not
# visible here); how it uses `model` is an assumption to confirm.
import utils.binarizedModel as bModel
bmodel = bModel.MNISTBinarizedModel(model, device)

#### Training modello binarizzato

In [None]:
# Train the binarized model with its own TensorBoard directory and checkpoint
# path. Note that `writer` and `PATH` are rebound here, replacing the values
# set for the float model.
# The fourth positional argument (0.01) is presumably the learning rate --
# TODO confirm against MNISTBinarizedModel.trainModel.
writer = SummaryWriter("./test_grafici/MnistNetworkBinarized")
PATH = "./models/MnistNetwork_binarized_model" 
bmodel.trainModel(training_loader, test_loader, epochs, 0.01, writer=writer, PATH=PATH)

#### Load modello binarizzato

In [None]:
# Load the binarized checkpoint into `bmodel` and evaluate it on the test loader.
# NOTE(review): getAccuracy here is a method of the binarized model (not
# util.getAccuracy); its return/print behaviour is not visible in this file.
util.load_model(bmodel, "./models/MnistNetwork_binarized_model")
bmodel.getAccuracy(test_loader)

### Pruning del modello binarizzato

Preparazione dei dati

In [None]:
# Rebuild both models from checkpoints and store a copy of the binarized model
# under the "_wp" name that the pruning cells read and write.
# NOTE(review): "./models/MnistNetworkPruned_model" is not written by any cell
# visible in this notebook -- confirm the file exists / where it comes from.
import utils.binarizedModel as bModel
model = MnistNetwork(input_size, output_size, hidden_size, device)
util.load_model(model, "./models/MnistNetworkPruned_model")
bmodel = bModel.MNISTBinarizedModel(model, device)
# Overwrite the freshly built binarized model with the trained binarized weights.
util.load_model(bmodel, "./models/MnistNetwork_binarized_model")
# Save it under the "_wp" path used as the starting point of the pruning run.
util.save_model(bmodel, "./models/MnistNetworkPruned_binarized_model_wp")

Training del modello binarizzato. Il primo parametro è il threshold. Si utilizza `prune` poiché viene usata la binarizzazione ternaria (-1, 0, 1).

In [None]:
# Pruning-aware training of the binarized model. Per the notebook text the
# first argument (5) is the threshold; the remaining positional arguments are
# presumably loaders, epochs, learning rate, writer (None) and checkpoint path
# -- TODO confirm against MNISTBinarizedModel.prune.
bmodel.prune(5, training_loader, test_loader, epochs, 0.01, None, "./models/MnistNetworkPruned_binarized_model_wp")

Pruning. Impostare il caricamento del modello `_wp` per effettuare il pruning.

In [None]:
# Start from the "_wp" checkpoint produced by the pruning-aware training.
util.load_model(bmodel, "./models/MnistNetworkPruned_binarized_model_wp")

# Pair each of the five fully connected layers with its "weight" tensor and
# prune them jointly with the project's threshold-based method.
parameters_to_prune = tuple((bmodel.fcLayers[idx], "weight") for idx in range(5))
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=util.ThresholdPruning,
    threshold=5,
)

# Drop the pruning re-parametrization so "weight" is a plain parameter again.
for child in bmodel.fcLayers:
    prune.remove(child, "weight")

# Measure sparsity: fraction of exactly-zero entries over all parameters that
# have more than one dimension.
tot = 0
size = 0
for g in bmodel.parameters():
    if len(g.shape) <= 1:
        continue
    print(g.shape)
    size += g.shape[0] * g.shape[1]
    r = torch.where(g == 0, 1., 0.)
    tot += r.sum().item()
print(f"tot / size = {tot / size}")

print(bmodel.getAccuracy(test_loader, True, 5))

Salvataggio del modello binarizzato pruned

In [None]:
#util.save_model(bmodel, "./models/MnistNetworkPruned_binarized_model")

Caricamento del modello binarizzato pruned

In [6]:
# Load a pruned binarized checkpoint from ./demo/models and run a single prediction.
# NOTE(review): this cell depends on `test_data` from the data-loading cell.
import sys
sys.path.append("./utils/")
import utils.binarizedModel as binModel
# NOTE(review): this rebinds `util` to the top-level module found through the
# sys.path entry above, replacing the earlier `import utils.util as util`.
import util
# Here the binarized model is built without arguments, unlike the earlier
# MNISTBinarizedModel(model, device) calls -- presumably it has defaults; confirm.
bmodel = binModel.MNISTBinarizedModel()
# Alternative checkpoint, currently disabled:
#util.load_model(bmodel, "./demo/models/mnist_pruned_84443_model")
util.load_model(bmodel, "./demo/models/mnist_pruned_96665_model")
# Flatten the first test image into a single row before predicting.
output = bmodel.prediction(test_data[0][0].reshape(1, -1), pruned=True)

In [5]:
# Train a binarized model and save it under ./demo/models.
# NOTE(review): this cell depends on `device`, `training_loader` and
# `test_loader` from earlier cells, and imports the helper modules as top-level
# modules (via the sys.path entry) rather than through the `utils` package.
import sys
sys.path.append("./utils/")
import binarizedModel as binModel
import util
import mnistModel
# NOTE(review): `model` is loaded from the float checkpoint but is not passed
# to MNISTBinarizedModel below -- confirm whether it is meant to be used.
model = mnistModel.MnistNetwork()
util.load_model(model, "./demo/models/mnist_model", device=device)
bmodel = binModel.MNISTBinarizedModel(device=device)
# Only PATH is given; epochs, learning rate and writer fall back to whatever
# defaults MNISTBinarizedModel.trainModel defines.
bmodel.trainModel(training_loader, test_loader, PATH="./demo/models/mnist_binarized_model")

epoch = 1/50, step = 100/600, loss = 0.5178994536399841
epoch = 1/50, step = 200/600, loss = 0.24029308557510376
epoch = 1/50, step = 300/600, loss = 0.27109551429748535
epoch = 1/50, step = 400/600, loss = 0.27022162079811096
epoch = 1/50, step = 500/600, loss = 0.30329039692878723
epoch = 1/50, step = 600/600, loss = 0.3294365704059601
Validation Accuracy = 80.66%
epoch = 2/50, step = 100/600, loss = 0.2317800372838974
epoch = 2/50, step = 200/600, loss = 0.1751854121685028
epoch = 2/50, step = 300/600, loss = 0.24878264963626862
epoch = 2/50, step = 400/600, loss = 0.1203133761882782
epoch = 2/50, step = 500/600, loss = 0.09030026197433472
epoch = 2/50, step = 600/600, loss = 0.180506631731987
Validation Accuracy = 91.18%
epoch = 3/50, step = 100/600, loss = 0.2352297157049179
epoch = 3/50, step = 200/600, loss = 0.2625238001346588
epoch = 3/50, step = 300/600, loss = 0.13193799555301666
epoch = 3/50, step = 400/600, loss = 0.24102115631103516
epoch = 3/50, step = 500/600, loss = 0.