In [1]:
import torch
import sys

sys.path.append("..") # so we can import the layer_sim library

from layer_sim import networks
from layer_sim import datasets
from layer_sim.train import train_net, test_net
from layer_sim.pruning.IMP import imp_lrr

## Dataset and NN preparation

In [8]:
# Seed the global torch RNG before building the network so that weight initialisation
# (and any random shuffling done by the data loaders) is reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

train_batch = 128
test_batch = 128
# MNIST train/test loaders with data rooted at ../data
trainloader, testloader = datasets.MNIST("../data", train_batch, test_batch, num_workers=4)
net = networks.LeNet5(num_classes=10)

### TODO: insert image of LeNet5

## NN training and testing

In [6]:
# Hyper-parameters for the SGD optimiser (momentum + weight decay) used in the training cell
lr_init = 0.1          # starting learning rate
momentum = 0.9
weight_decay = 1e-4    # weight_decay argument of torch.optim.SGD
epochs = 15            # number of training epochs
device = "cpu"         # device string forwarded to train_net / test_net / imp_lrr

In [7]:
# Step LR decay handed to train_net: the LR is divided by `lr_annealing_rate` at the epochs
# listed in `lr_annealing_schedule` (the training log shows 0.1 -> 0.01 at the first step).
lr_annealing_schedule = [10, 12]
lr_annealing_rate = 10

In [10]:
# Loss and optimiser for the dense (unpruned) training run
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=lr_init,
    momentum=momentum,
    weight_decay=weight_decay,
)

# Train in place; the displayed result appears to be the final epoch's (loss, performance) pair
train_net(
    net, epochs, criterion, optimizer, trainloader,
    device=device,
    lr_annealing_factor=lr_annealing_rate,
    epochs_annealing=lr_annealing_schedule,
)

===> Epoch 1/15 ### Loss 0.32814517916639646 ### Performance 0.8958666666666667
===> Epoch 2/15 ### Loss 0.07053277128438155 ### Performance 0.9790166666666666
===> Epoch 3/15 ### Loss 0.048630800982316334 ### Performance 0.9855166666666667
===> Epoch 4/15 ### Loss 0.04039812043358882 ### Performance 0.98765
===> Epoch 5/15 ### Loss 0.03454992857140799 ### Performance 0.9894166666666667
===> Epoch 6/15 ### Loss 0.03000268583421906 ### Performance 0.9911
===> Epoch 7/15 ### Loss 0.02701919445786625 ### Performance 0.9917166666666667
===> Epoch 8/15 ### Loss 0.02189559004077067 ### Performance 0.99315
===> Epoch 9/15 ### Loss 0.022885738459710654 ### Performance 0.9929833333333333
===> Epoch 10/15 ### Loss 0.01977641623740395 ### Performance 0.9937833333333334
===> Epoch 11/15 ### Loss 0.02014222490272174 ### Performance 0.9936833333333334
LR annealed: previous 0.1, current 0.01
===> Epoch 12/15 ### Loss 0.007557284027454443 ### Performance 0.99775
===> Epoch 13/15 ### Loss 0.00476630337

(0.0036606240635427335, 0.9991166666666667)

In [11]:
# Evaluate the trained network on the test loader; the cell shows (loss, performance)
test_net(net, testloader, criterion,
         device=device)

===> TEST ### Loss 0.027738256379961967 ### Performance 0.9924


(tensor(0.0277), 0.9924)

## Store complete model's representation

In [None]:
# Build a dataloader for the representation pass with train=False (presumably the MNIST test split).
# NOTE(review): unlike the earlier call, only one batch size (128) is passed positionally and the
#   second returned value is discarded — TODO confirm datasets.MNIST's signature/return in this mode.
reprloader, _ = datasets.MNIST("../data", 128, train=False, num_workers=4)
# FIXME(review): `reprloader` is never used — forward_with_hooks() is called with no data, so it is
#   unclear which inputs the stored representation is computed on. It presumably should receive
#   `reprloader`; TODO confirm forward_with_hooks' signature. This cell has never been run (In [None]).
compl = net.forward_with_hooks() # TODO: limit datapoints

## IMP application

Tip: `net.state_dict().keys()` lists the names of the network's parameters, which helps identify the layers we want to prune.

In [13]:
# Parameter names of the network; the pruning patterns below are regexes matched against these keys
param_names = net.state_dict().keys()
param_names

odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 'features.3.bias', 'classifier.0.weight', 'classifier.0.bias', 'classifier.2.weight', 'classifier.2.bias', 'classifier.4.weight', 'classifier.4.bias'])

In this example we prune the weights and biases of all conv layers and of all fully-connected layers except the last one.
Layers are selected by matching regular expressions against the `state_dict` keys:

* a generic `"features"` pattern catches all of the conv layers;
* a more specific regex, `"classifier\.[02]\."`, catches the first two fully-connected layers but not the last one, whose index in the `state_dict` keys is `4`.

Note: if the network had BatchNorm layers, we would need to be more careful when choosing the layers to prune. Inside a `torch.nn.Sequential`, BN parameters get the same container prefix as the neighbouring conv layers (e.g. `features.1.weight`), so a broad pattern such as `"features"` would select them too.

In [15]:
layers_to_prune = [
    "features",             # every parameter under `features` (the conv layers)
    r"classifier\.[02]\.",  # first two fully-connected layers; classifier.4 (output layer) is left unpruned
]

In [16]:
imp_iterations = 2
# Checkpoints will be saved in this folder as IMP_checkpoint_n.pt, where `n` is the iteration number
save_path = "../models/LeNet5/IMP"

# `optimizer` went through train_net's LR annealing (see the "LR annealed: 0.1 -> 0.01" log above;
# presumably applied in place) and also holds momentum buffers from the dense run. Give IMP a fresh
# optimizer with the original hyper-parameters so that, assuming imp_lrr trains with the optimizer it
# is given, retraining starts from lr_init instead of the decayed LR.
# NOTE(review): the annealing schedule is not forwarded to imp_lrr — confirm whether it accepts one.
imp_optimizer = torch.optim.SGD(net.parameters(), lr=lr_init, momentum=momentum, weight_decay=weight_decay)

# NOTE(review): a previous run failed inside imp_lrr with "SyntaxError: invalid syntax (<string>, line 1)".
#   The "<string>" filename means the failing code was compiled from a string (eval/exec/compile)
#   inside the library — TODO investigate before relying on this cell.
imp_lrr(net, epochs, criterion, imp_optimizer, trainloader, imp_iterations, device=device,
        testloader=testloader, save_path=save_path, layer_ids_to_prune=layers_to_prune)

=====> Iteration of IMP: 1/2


SyntaxError: invalid syntax (<string>, line 1)