# Model Sparsification Tutorial
This tutorial walks through the following key steps:
1. prepare and load data
2. prepare and load model
3. train model (may be skipped if pre-trained)
4. prune model

advanced-user walkthroughs:

5. assessing FLOPS per layer
6. targeting sparsity - sensitivity analysis
7. assessing learning rate sensitivity

## Load data
This example uses the Imagenette/Imagewoof datasets from fast.ai, provided under the [Apache License 2.0](https://github.com/fastai/imagenette/blob/master/LICENSE).

In [1]:
from neuralmagicML.datasets import *
from torch.utils.data import Dataset, DataLoader

# ---- dataset configuration ----
dataset_type = 'imagenette'
dataset_path = '../data/imagenette-160/'
dataset_early_stop = -1  # not read anywhere in this notebook

# Device to run on: 'cpu' / 'cuda:0'.
device = 'cpu'

# Training batches are smaller than the evaluation batches used by the two test loaders.
train_batch_size, test_batch_size = 128, 256
loader_workers = 8  # worker processes per DataLoader

# Training images with random augmentation: used for optimization.
train_dataset = ImagenetteDataset(dataset_path, train=True, rand_trans=True)
train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size,
                               shuffle=True, num_workers=loader_workers)

# Held-out validation images without augmentation.
val_dataset = ImagenetteDataset(dataset_path, train=False, rand_trans=False)
val_data_loader = DataLoader(val_dataset, batch_size=test_batch_size,
                             shuffle=True, num_workers=loader_workers)

# Training images without augmentation: used to measure training-set metrics.
train_test_dataset = ImagenetteDataset(dataset_path, train=True, rand_trans=False)
train_test_data_loader = DataLoader(train_test_dataset, batch_size=test_batch_size,
                                    shuffle=True, num_workers=loader_workers)




already downloaded imagenette of size ImagenetteSize.s160
already downloaded imagenette of size ImagenetteSize.s160
already downloaded imagenette of size ImagenetteSize.s160


## Let's get familiar with this dataset a bit:
The authors, much like ourselves, were interested in a dataset with properties similar to ImageNet's that would still allow rapid iteration. For more details about the degree of 'similarity' see [fastai/imagenette](https://github.com/fastai/imagenette).

It includes 10 of the 1000 ImageNet classes:
tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute.

This is **not** a replacement for the actual dataset at hand (e.g. ImageNet). We use it here for demonstration purposes only, owing to its much shorter training time and the ease of testing (almost) identical models. Switching to ImageNet is as simple as changing the dataset argument (and the corresponding model).

With that in mind, let us take a closer look at the actual data (as fed to the model):

In [2]:
import matplotlib.pyplot as plt
import torch

def renormalize(tens: torch.Tensor) -> torch.Tensor:
    """Linearly rescale *tens* so its values span [0, 1] (e.g. for ``plt.imshow``).

    A constant tensor (max == min) has no range to stretch; instead of dividing
    0 by 0 and producing NaNs, it is mapped to all zeros.
    """
    lowest = tens.min()
    span = tens.max() - lowest
    if span == 0:
        return torch.zeros_like(tens)
    return (tens - lowest) / span

data_loader = train_data_loader

# Walk the whole training loader once: count the images of every class and keep
# the first `samples_to_collect` images of each class for display.
labels = {}             # class label -> number of training images seen
samples = {}            # class label -> list of kept image tensors
samples_collected = {}  # class label -> number of images kept in `samples`
samples_to_collect = 3
for *x_feature, y_lab in data_loader:
    # The images are only counted and plotted here, so they are deliberately
    # left on the CPU: matplotlib cannot draw tensors living on a CUDA device.
    images = x_feature[0]
    for image_idx, label in enumerate(y_lab.tolist()):
        labels[label] = labels.get(label, 0) + 1
        kept = samples.setdefault(label, [])
        if len(kept) < samples_to_collect:
            # clone() so the kept image does not keep its whole batch alive.
            kept.append(images[image_idx].clone())
            samples_collected[label] = len(kept)

n_labels = len(labels)
tot_images = n_labels * samples_to_collect

print('Total number of classes: {}'.format(n_labels))
print('number of images per class: ')
# The loader shuffles, so first-seen order is random; sort for stable output.
for label, num in sorted(labels.items()):
    print('label: {}, number of images: {}'.format(label, num))

print('sample images: row = class label, column = sample')
plt.figure(figsize=(12, 30))
for row, label in enumerate(sorted(labels)):
    for col, image in enumerate(samples[label]):
        plt.subplot(n_labels, samples_to_collect, row * samples_to_collect + col + 1)
        plt.imshow(renormalize(image).permute(1, 2, 0))



Total number of calsess: 10
number of images per class: 
label: 3, number of images: 1194
label: 8, number of images: 1300
label: 9, number of images: 1300
label: 4, number of images: 1300
label: 7, number of images: 1300
label: 6, number of images: 1300
label: 0, number of images: 1300
label: 5, number of images: 1300
label: 2, number of images: 1300
label: 1, number of images: 1300
sample images: row = class label, column = sample


## Define model
Let us initialize the standard ResNet-50 model, modifying its classification (FC) layer so that it accommodates 10 classes (instead of 1000).

In [3]:
from neuralmagicML.models import *
# ResNet-50 trained from scratch, with a 10-way classifier for Imagenette.
model = resnet50(pretrained=False, num_classes=10)


## Define loss and optimizer

In [4]:
from torch import optim
from neuralmagicML.utils import CrossEntropyLossCalc, TopKAccuracy

# ---- optimizer definitions ----
learning_rate = 0.1
lr_decay = 0.2      # multiply the lr by 0.2 (i.e. divide it by 5) ...
lr_decay_rate = 30  # ... once every `lr_decay_rate` epochs
momentum = 0.9
weight_decay = 1e-4

print('Creating optimizer with initial lr: {}, momentum: {}, weight decay: {}'.format(
    learning_rate, momentum, weight_decay))

optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum,
                      weight_decay=weight_decay, nesterov=True)

# Extra metrics reported by the loss calculator alongside the loss itself.
loss_extras = dict(top1acc=TopKAccuracy(1), top5acc=TopKAccuracy(5))
loss_calc = CrossEntropyLossCalc(loss_extras)
print('Created loss calc {} with extras {}'.format(loss_calc, ', '.join(loss_extras)))


def _adjust_learning_rate(optimizer: optim.SGD, epoch: int, init_learning_rate: float,
                          lr_decay: float, lr_decay_rate: int):
    lr = init_learning_rate * (lr_decay ** ((epoch + 1) // lr_decay_rate))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr



Creating optimizer with initial lr: 0.1, momentum: 0.9, weight decay: 0.0001
Created loss calc <neuralmagicML.utils.loss_calc.CrossEntropyLossCalc object at 0x11f2a7150> with extras top1acc, top5acc


## Logging
Beyond the usual basic screen printouts, let's use TensorBoard's logging capabilities.
In this example we'll primarily use it to track scalars such as the loss and accuracy throughout training.

In [5]:
from tensorboardX import SummaryWriter

## Train

In [6]:
from torch.utils.data import DataLoader

from typing import Tuple, Dict
from torch import Tensor
import torch

def _test_datasets(model, train_data_loader: DataLoader, val_data_loader: DataLoader,
                   writer: SummaryWriter, epoch: int) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Evaluate *model* on the validation loader and then the training loader.

    Each metric produced by ``test_epoch`` is averaged over all collected values.
    It is logged to *writer* under ``Test/validation/<name>`` and
    ``Test/training/<name>`` at step *epoch*, and printed. Uses the notebook-level
    ``loss_calc`` and ``device``.

    Returns:
        ``(val_losses, train_losses)``: dicts mapping metric name to its mean.
    """
    def _run(split_name, split_loader):
        # One evaluation pass, bracketed by progress messages.
        print('Running test for {} dataset for epoch {}'.format(split_name, epoch))
        outcome = test_epoch(model, split_loader, loss_calc, device, epoch)
        print('Completed test for {} dataset for epoch {}'.format(split_name, epoch))
        return outcome

    val_results = _run('validation', val_data_loader)
    train_results = _run('train', train_data_loader)

    val_losses, train_losses = {}, {}
    # The metric names come from the validation results; the training results
    # are expected to contain the same keys.
    for name in val_results:
        val_losses[name] = torch.mean(torch.cat(val_results[name])).item()
        writer.add_scalar('Test/validation/{}'.format(name), val_losses[name], epoch)
        train_losses[name] = torch.mean(torch.cat(train_results[name])).item()
        writer.add_scalar('Test/training/{}'.format(name), train_losses[name], epoch)

    for header, table in (('validation set', val_losses), ('training set', train_losses)):
        pieces = ['{} - epoch: {} '.format(header, epoch)]
        pieces.extend(name + ': {0:.2f} '.format(value) for name, value in table.items())
        print(''.join(pieces))

    return val_losses, train_losses


def test_epoch(model: torch.nn.Module, data_loader: DataLoader, loss, device, epoch: int) -> Dict:
    model.eval()
    results = {}#ModuleTestResults()
    with torch.no_grad():
        for batch, (*x_feature, y_lab) in enumerate(data_loader):
            y_lab = y_lab.to(device)
            x_feature = tuple([dat.to(device) for dat in x_feature])
            batch_size = y_lab.shape[0]
            
            y_pred = model(*x_feature)

            losses = loss(x_feature, y_lab, y_pred)  # type: Dict[str, Tensor]
            for key, val in losses.items():
                if key not in results:
                    results[key] = []

                result = val.detach_().cpu()
                result = result.repeat(batch_size) #repeat tensor so that there is no dependency on batch size
                results[key].append(result)
#             results.append(losses, batch_size)
    return results

def train_epoch(model: torch.nn.Module, data_loader: DataLoader, optimizer, loss, device,
                data_counter: int, summary_writer=None) -> int:
    """Run one training epoch over *data_loader*, logging every batch's metrics.

    Args:
        model: network to train; it is switched to train mode.
        data_loader: yields ``(*inputs, labels)`` batches.
        optimizer: stepped once per batch.
        loss: callable ``loss(inputs, labels, predictions)`` returning a dict of
            scalar tensors; its ``'loss'`` entry is back-propagated.
        device: device the batch tensors are moved to.
        data_counter: number of training samples seen before this epoch; used
            as the logging step.
        summary_writer: object with ``add_scalar(tag, value, step)``; defaults
            to the notebook-level ``writer``.

    Returns:
        The updated sample counter. Integers are passed by value, so callers
        must use this return value for the step to keep advancing across epochs.
    """
    if summary_writer is None:
        summary_writer = writer
    model.train()

    for *x_feature, y_lab in data_loader:
        # copy next batch to the device we are using
        y_lab = y_lab.to(device)
        x_feature = tuple(dat.to(device) for dat in x_feature)

        optimizer.zero_grad()
        y_pred = model(*x_feature)
        losses = loss(x_feature, y_lab, y_pred)  # type: Dict[str, Tensor]
        losses['loss'].backward()
        optimizer.step()

        # log loss and accuracy against the number of samples seen so far
        data_counter += y_lab.shape[0]
        for name, value in losses.items():
            summary_writer.add_scalar('Train/{}'.format(name), value.item(), data_counter)

    return data_counter




In [7]:
import os

logs_dir = './logs'
model_dir = './saved_models'
model_name = 'resnet50-imagenette'  # file-name prefix for saved models
save_rate = 1                       # checkpoint every `save_rate` epochs; <= 0 disables
run_baseline_test = False           # set True to evaluate the untrained model first

os.makedirs(model_dir, exist_ok=True)

print('Creating summary writer in {}'.format(logs_dir))
# NOTE(review): tensorboardX appears to ignore `comment` when `logdir` is given — confirm.
writer = SummaryWriter(logdir=logs_dir, comment='imagenette training')

try:
    if run_baseline_test:
        print('Running baseline test...')
        _test_datasets(model, train_test_data_loader, val_data_loader, writer, epoch=-1)

    print('Training model')
    num_epochs = 1
    data_counter = 0

    for epoch in range(num_epochs):
        print('Starting epoch {}'.format(epoch))
        _adjust_learning_rate(optimizer, epoch, learning_rate, lr_decay, lr_decay_rate)
        train_epoch(model, train_data_loader, optimizer, loss_calc, device, data_counter)
        # train_epoch advances its logging step by one per sample; advance ours the
        # same way so the next epoch's TensorBoard steps continue rather than restart.
        data_counter += len(train_data_loader.dataset)
        val_losses, train_losses = _test_datasets(
            model, train_test_data_loader, val_data_loader, writer, epoch)

        if save_rate > 0 and epoch % save_rate == 0:
            save_path = os.path.join(model_dir, '{}-epoch={:03d}-val={:.4f}.pth'
                                     .format(model_name, epoch, val_losses['loss']))
            # NOTE(review): save_model is not defined in this notebook; presumably it
            # comes from one of the neuralmagicML star imports — confirm.
            save_model(save_path, model, optimizer, epoch)
            print('saved model checkpoint at {}'.format(save_path))
finally:
    # Export and close the writer even if training is interrupted (e.g. Ctrl-C),
    # so the logged scalars are flushed to disk.
    scalars_json_path = os.path.join(logs_dir, 'all_scalars.json')
    writer.export_scalars_to_json(scalars_json_path)
    writer.close()

save_path = os.path.join(model_dir, '{}-trained.pth'.format(model_name))
print('Finished training, saving model to {}'.format(save_path))
save_model(save_path, model)
print('Saved model')




Creating summary writer in ./logs
Running baseline test...
Training model
Starting epoch 0


Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/Cellar/python/3.7.4/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/usr/local/Cellar/python/3.7.4/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/local/Cellar/python/3.7.4/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/local/Cellar/python/3.7.4/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
  File "/usr/local/Cellar/python/3.7.4/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiproces

KeyboardInterrupt: 