In [1]:
# Standard library. `os` is needed further down for os.path.join when
# saving/loading checkpoints; without it those cells fail on a fresh kernel.
import os
import sys

import torch

sys.path.append("..") # so we can import the layer_sim library

# Project-local modules (resolved through the path entry appended above).
from layer_sim import networks
from layer_sim import datasets
from layer_sim.train import train_net, test_net
from layer_sim.pruning.IMP import imp_lrr

## Dataset and NN preparation

In [2]:
# Mini-batch sizes for the train and test loaders (kept as separate names
# because later cells may tune them independently).
train_batch = test_batch = 128

# Build the MNIST loaders rooted at ../data.
trainloader, testloader = datasets.MNIST(
    "../data",
    train_batch,
    test_batch,
    num_workers=4,
)

# LeNet5 with one output per MNIST digit class.
net = networks.LeNet5(num_classes=10)

### TODO: insert image of LeNet5

## NN training and testing

In [3]:
# SGD hyper-parameters: initial learning rate, L2 penalty and momentum.
lr_init, weight_decay, momentum = 0.1, 0.0001, 0.9

# Number of training epochs (also passed to imp_lrr later on).
epochs = 15

# Device string handed to the train/test/IMP routines.
device = "cpu"

In [4]:
# Epochs listed here are passed to train_net as `epochs_annealing`.
lr_annealing_schedule = [10, 12]

# Passed to train_net as `lr_annealing_factor`.
lr_annealing_rate = 10

In [5]:
# Classification loss.
criterion = torch.nn.CrossEntropyLoss()

# SGD over all network parameters, configured from the hyper-parameter cell.
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=lr_init,
    momentum=momentum,
    weight_decay=weight_decay,
)

In [5]:
# Train the full network; the returned loss/performance are stored in the
# checkpoint dictionary built further down.
tr_loss, tr_perf = train_net(
    net,
    epochs,
    criterion,
    optimizer,
    trainloader,
    device=device,
    lr_annealing_factor=lr_annealing_rate,
    epochs_annealing=lr_annealing_schedule,
)

===> Epoch 1/15 ### Loss 0.3291197659830252 ### Performance 0.8938
===> Epoch 2/15 ### Loss 0.06732262536038955 ### Performance 0.9794166666666667
===> Epoch 3/15 ### Loss 0.04838982847606142 ### Performance 0.9854


KeyboardInterrupt: 

In [6]:
# Evaluate the network on the test loader with the same criterion used for training.
te_loss, te_perf = test_net(
    net,
    testloader,
    criterion,
    device=device,
)

===> TEST ### Loss 2.3029794692993164 ### Performance 0.0972


#### Optional: save the NN

In the next cell we store the state_dict, along with some auxiliary data (train/test loss/performance), in a dictionary called `save_dict`. We then save this dict under `save_root`, which we will also use as the base folder for the IMP checkpoints.

The `save_dict` mimics the structure of IMP's checkpoint (minus the pruning mask, which is absent in the case of the complete model).

In [7]:
# Checkpoint payload: train/test metrics plus the network weights.
save_dict = dict(
    train_loss=tr_loss,
    train_perf=tr_perf,
    test_loss=te_loss,
    test_perf=te_perf,
    parameters=net.state_dict(),
)

NameError: name 'tr_loss' is not defined

In [6]:
# Folder and file name of the complete (unpruned) model checkpoint.
save_root, save_name = "../models/LeNet5", "complete_net.pt"

In [9]:
# torch.save does not create missing parent folders, so make sure the
# destination exists (no-op when it is already there).
os.makedirs(save_root, exist_ok=True)
torch.save(save_dict, os.path.join(save_root, save_name))

NameError: name 'save_dict' is not defined

In [7]:
# load the NN
# map_location remaps the stored tensors onto the configured device, so a
# checkpoint written on a GPU can still be restored on a CPU-only machine.
net.load_state_dict(
    torch.load(os.path.join(save_root, save_name), map_location=device)["parameters"]
)

<All keys matched successfully>

## Store complete model's representation

In [None]:
# Get the dataloader used for extracting representations, requesting train=False.
# NOTE(review): this call passes one batch size plus `train=False`, whereas the
# data-preparation cell calls datasets.MNIST(root, train_batch, test_batch, ...).
# Confirm the function accepts this form and still returns a pair.
reprloader, _ = datasets.MNIST("../data", 128, train=False, num_workers=4)
# NOTE(review): `reprloader` is not passed to forward_with_hooks(), which is
# called with no arguments — presumably it needs the data; confirm its signature.
compl = net.forward_with_hooks() # TODO: limit datapoints

## IMP application

Tip: `net.state_dict().keys()` lists the parameter names, which helps us identify the layers we'll be pruning

In [8]:
# Display the state_dict keys; the pruning patterns are matched against these names.
net.state_dict().keys()

odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 'features.3.bias', 'classifier.0.weight', 'classifier.0.bias', 'classifier.2.weight', 'classifier.2.bias', 'classifier.4.weight', 'classifier.4.bias'])

For this example, we wish to prune all of the weights & biases of the conv layers and of the f-c layers, except the last one.
The selection is made by matching regular expressions against the keys of the state_dict:

* we can use a generic `"features"` pattern to catch all of the conv layers
* we use a more specific regex `"classifier\.[02]\."` to catch the first 2 f-c layers (but not the last one which has ID `4` in the state_dict keys)

Note: if we had BatchNorm layers in `features`, we should be more careful in indicating the layers to prune, similarly to what we did in `classifier`

In [9]:
# Regex patterns matched against the state_dict keys to pick what gets pruned.
layers_to_prune = [
    "features",              # every key containing "features" (the conv layers)
    r"classifier\.[02]\.",   # classifier.0.* and classifier.2.*, but not classifier.4.*
]

In [10]:
# IMP settings: number of iterations and the pruning factor handed to imp_lrr.
imp_iterations = 2
pruning_rate = 0.5

# checkpoints will be saved in this folder as IMP_checkpoint_n.pt, where `n` is the iteration number
save_path = "../models/LeNet5/IMP"

imp_lrr(
    net,
    epochs,
    criterion,
    optimizer,
    trainloader,
    imp_iterations,
    device=device,
    testloader=testloader,
    save_path=save_path,
    layer_ids_to_prune=layers_to_prune,
    pruning_factor=pruning_rate,
)

=====> Iteration of IMP: 1/2
Proportion of parameters in mask 0.5095664700850853
===> Epoch 1/15 ### Loss 0.01041195536716841 ### Performance 0.99655
===> Epoch 2/15 ### Loss 0.012068730958690866 ### Performance 0.9958833333333333
===> Epoch 3/15 ### Loss 0.008705891965283081 ### Performance 0.9970666666666667
===> Epoch 4/15 ### Loss 0.011660329529092026 ### Performance 0.9962833333333333
===> Epoch 5/15 ### Loss 0.008792778549799308 ### Performance 0.99715
===> Epoch 6/15 ### Loss 0.00852779282204962 ### Performance 0.99725
===> Epoch 7/15 ### Loss 0.009600953032627391 ### Performance 0.9970666666666667
===> Epoch 8/15 ### Loss 0.010967730930063408 ### Performance 0.9962666666666666
===> Epoch 9/15 ### Loss 0.00826772124649336 ### Performance 0.9974666666666666
===> Epoch 10/15 ### Loss 0.006784967440033021 ### Performance 0.9976666666666667
===> Epoch 11/15 ### Loss 0.007287115667760372 ### Performance 0.9976
===> Epoch 12/15 ### Loss 0.007883973504335154 ### Performance 0.997366666