In [1]:
import torch
import numpy as np
from copy import deepcopy
import torchvision.datasets as datasets
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.prune import l1_unstructured
from torch.nn.utils.prune import remove

use_cuda = True
use_cuda = use_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [2]:
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        
        #self.conv1 = nn.Conv2d(1, 64, 3)   #CIFAR10 --> (3,64,3)
        #self.conv2 = nn.Conv2d(64, 128, 3)
        self.fc1 = nn.Linear(784, 128)   #MNIST_batches --> 3200       ; #CIFAR10 --> (4608, 128)
        self.fc2 = nn.Linear(128, 10)
        
        self.importances = [(k, torch.zeros_like(p).to(p.device)) for k, p in self.named_parameters()]
        
    def forward(self, x):
        #x =  F.max_pool2d(F.relu(self.conv1(x)), (2,2)) 
        #x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x,1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = x.view(-1,10)
        x = x.to(device)
        return x
    
    def reset(self):
        # reinitialize network
        for layer in self.children():
            layer.reset_parameters()

        # clear importance
        for _,imp in self.importances:
            imp.fill_(0.)

#MODEL
model = SimpleModel().to(device)

modello = deepcopy(model)
pre_w = deepcopy(model.state_dict())


#HYPERPARAMETERS
lr = 0.01
epochs = 3
batch_size = 100
criterion = nn.CrossEntropyLoss()
optim1 = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
optim2 = torch.optim.Adam(model.parameters(), lr=0.0001)

In [3]:
# Loading Dataset
name = 'MNIST'
m = getattr(datasets, name)

ds_train = m("./data", train=True, download=True, transform=transforms.ToTensor())
ds_test = m("./data", train=False, download=True, transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(ds_test,batch_size=batch_size, shuffle=False)

num_iters = int(np.ceil(train_loader.dataset.data.shape[0] * epochs / batch_size))

### Normal Train & evaluation

In [4]:
idx = 1
running_loss = 0.0
for epoch in range(epochs):
    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device); labels = labels.to(device)
        
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        optim2.zero_grad()
        loss.backward()
        optim2.step()
        
        running_loss += loss.item()
        if idx % 180 == 0:
                print(f'\rTraining {idx}/{num_iters} iterations done.     Loss: {running_loss /180:.3f}')#, end='')
                running_loss = 0.0
        idx += 1

Training 180/1800 iterations done.     Loss: 1.775
Training 360/1800 iterations done.     Loss: 0.942
Training 540/1800 iterations done.     Loss: 0.632
Training 720/1800 iterations done.     Loss: 0.506
Training 900/1800 iterations done.     Loss: 0.432
Training 1080/1800 iterations done.     Loss: 0.400
Training 1260/1800 iterations done.     Loss: 0.376
Training 1440/1800 iterations done.     Loss: 0.347
Training 1620/1800 iterations done.     Loss: 0.331
Training 1800/1800 iterations done.     Loss: 0.320


In [5]:
accuracy = []
with torch.no_grad():
    for j, (inputs,labels) in enumerate(test_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        out = model(inputs)
        lost = F.cross_entropy(out, labels) 

        _, preds = torch.max(out, dim=1)
        acc = torch.tensor(torch.sum(preds == labels).item() / len(preds))
        accuracy.append(acc)

        if j % 20 == 0:
            print('test_loss : ', lost, ' --- test_acc: ', acc)

test_loss :  tensor(0.2565)  --- test_acc:  tensor(0.9300)
test_loss :  tensor(0.4951)  --- test_acc:  tensor(0.8800)
test_loss :  tensor(0.3804)  --- test_acc:  tensor(0.8900)
test_loss :  tensor(0.4049)  --- test_acc:  tensor(0.9000)
test_loss :  tensor(0.2202)  --- test_acc:  tensor(0.9300)


### Calculate IMPORTANCES of weights (Fisher Information)

In [12]:
inputs = train_loader.dataset.data

accumulators = [torch.zeros_like(p, device=device) for p in model.parameters()]

for i in range(len(inputs)):
    for accum, v in zip(accumulators, model.parameters()):
        accum.add_(v.grad.square())

for accum, (_,imp) in zip(accumulators, model.importances):
    imp.add_(accum / len(inputs))



new_model = deepcopy(model)

### Masking: threshold method

In [7]:
thresh = 0.000005
new_imp = model.importances
masks = {}
perc = []
print('Threshold', thresh)
for k in range(len(new_imp)):
    bool_tensor = new_imp[k][1] > thresh
    mask = torch.where(bool_tensor == True, 1., 0.)
    masks[tuple(new_imp[k][0].split(".")) ] = mask
    print(new_imp[k][0])
    true_values = bool_tensor.masked_select(bool_tensor == True)
    false_values = bool_tensor.masked_select(bool_tensor == False)
    true_num = len(true_values)
    false_num = len(false_values)
    print('Number of weights that changes much: ',true_num)
    print('Number of weights that changes less: ',false_num)
    perc.append(false_num/(false_num+true_num))
    print('We can prune the', round(false_num/(false_num+true_num)*100, 4), '% of the weights')

Threshold 5e-06
fc1.weight
Number of weights that changes much:  3157
Number of weights that changes less:  97195
We can prune the 96.8541 % of the weights
fc1.bias
Number of weights that changes much:  51
Number of weights that changes less:  77
We can prune the 60.1562 % of the weights
fc2.weight
Number of weights that changes much:  1062
Number of weights that changes less:  218
We can prune the 17.0312 % of the weights
fc2.bias
Number of weights that changes much:  10
Number of weights that changes less:  0
We can prune the 0.0 % of the weights


In [8]:
for pr, (lay,imp) in zip(perc,new_model.importances):
        name =  tuple(lay.split("."))
        module = getattr(new_model, name[0])
        l1_unstructured(module, name[1],  amount = pr, importance_scores=imp)

### Masking: percentile method

In [13]:
for lay,imp in new_model.importances:
        name =  tuple(lay.split("."))
        module = getattr(new_model, name[0])
        l1_unstructured(module, name[1],  amount = 0.3, importance_scores=imp)

### Pruning
E' meglio un pruning locale o globale?

In [14]:
#To make pruning permanent this function reassigns the parameter weight to the model parameters, in its pruned version.
def remove_params(model):
    for module_name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            try:
                remove(module, "weight")
            except:
                pass
            try:
                remove(module, "bias")
            except:
                pass
        elif isinstance(module, torch.nn.Linear):
            try:
                remove(module, "weight")
            except:
                pass
            try:
                remove(module, "bias")
            except:
                pass
    return model

pruned_model = remove_params(new_model)

### Re-training the pruned_model

Ho bisogno della penalty da sommare alla loss?

Ho bisogno di lambda per pesare l'importanza dell'EWC?

Utilizzo lo stesso train_dataset o faccio lo shuffle?

In [15]:
idx = 1
running_loss = 0.0
#lmbda = 1000
#penalty = {}
#star_params = [p.clone().detach() for p in pruned_model.parameters()]

#train_loader = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=True)

for epoch in range(epochs):
    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device); labels = labels.to(device)
        
        outputs = pruned_model(inputs)
        loss = criterion(outputs, labels)
        
        #for (name,p), (_,imp), p_star in zip(pruned_model.named_parameters(), pruned_model.importances, star_params):
        #        penalty[name] = lmbda * torch.sum(imp * torch.square(p - p_star))
        #        loss += lmbda * torch.sum(imp * torch.square(p - p_star))  
     
        optim2.zero_grad()
        loss.backward()
        optim2.step()
        
        running_loss += loss.item()
        if idx % 180 == 0:
                print(f'\rTraining {idx}/{num_iters} iterations done.     Loss: {running_loss /180:.3f}')#, end='')
                running_loss = 0.0
        idx += 1

Training 180/1800 iterations done.     Loss: 1.053
Training 360/1800 iterations done.     Loss: 1.042
Training 540/1800 iterations done.     Loss: 1.031
Training 720/1800 iterations done.     Loss: 1.040
Training 900/1800 iterations done.     Loss: 1.042
Training 1080/1800 iterations done.     Loss: 1.043
Training 1260/1800 iterations done.     Loss: 1.041
Training 1440/1800 iterations done.     Loss: 1.032
Training 1620/1800 iterations done.     Loss: 1.059
Training 1800/1800 iterations done.     Loss: 1.026


perché la loss non scende?

In [16]:
accuracy = []
with torch.no_grad():
    for j, (inputs,labels) in enumerate(test_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        out = pruned_model(inputs)
        lost = F.cross_entropy(out, labels) 

        _, preds = torch.max(out, dim=1)
        acc = torch.tensor(torch.sum(preds == labels).item() / len(preds))
        accuracy.append(acc)

        if j % 20 == 0:
            print('test_loss : ', lost, ' --- test_acc: ', acc)

test_loss :  tensor(0.7444)  --- test_acc:  tensor(0.7500)
test_loss :  tensor(1.0868)  --- test_acc:  tensor(0.6700)
test_loss :  tensor(1.0407)  --- test_acc:  tensor(0.6800)
test_loss :  tensor(1.3259)  --- test_acc:  tensor(0.6500)
test_loss :  tensor(1.1741)  --- test_acc:  tensor(0.6700)


In [19]:
print(list(pruned_model.parameters()))

[Parameter containing:
tensor([[0., -0., -0.,  ..., 0., -0., -0.],
        [0., -0., -0.,  ..., -0., -0., 0.],
        [0., -0., -0.,  ..., 0., 0., 0.],
        ...,
        [0., -0., 0.,  ..., 0., -0., -0.],
        [0., -0., 0.,  ..., 0., 0., -0.],
        [0., -0., 0.,  ..., 0., -0., 0.]], requires_grad=True), Parameter containing:
tensor([ 0.0707,  0.0000,  0.0000,  0.0411,  0.0696,  0.0571,  0.0293,  0.0331,
         0.0517,  0.0725,  0.0089,  0.0000,  0.0000, -0.0397,  0.0432,  0.0503,
         0.0472,  0.0250, -0.0000, -0.0040,  0.0204,  0.0783, -0.0042,  0.0058,
         0.0466,  0.0477,  0.0119, -0.0000, -0.0184,  0.0309,  0.0233,  0.0468,
        -0.0263,  0.0500, -0.0036,  0.0421,  0.0000,  0.0000, -0.0051,  0.0465,
        -0.0426, -0.0000,  0.0465,  0.0431,  0.0355,  0.0126, -0.0077, -0.0423,
         0.0000, -0.0105,  0.0403,  0.0000,  0.0631,  0.0000, -0.0016,  0.0321,
         0.0747,  0.0263, -0.0043, -0.0000,  0.0000, -0.0336,  0.0000, -0.0000,
         0.0463,  0.042