# Experiment Tutorial

Check that we're using the correct Python interpreter:

In [None]:
import sys
sys.executable

In [None]:
from shrinkbench.experiment import PruningExperiment

The `DATAPATH` and `WEIGHTSPATH` environment variables are used to tell the framework where to look for datasets and pretrained weights respectively.

MNIST models are not available through [PyTorch Hub](https://pytorch.org/docs/stable/hub.html). Pretrained MNIST weights do exist elsewhere ([e.g.](https://github.com/csinva/gan-vae-pretrained-pytorch/blob/master/mnist_classifier/weights/lenet_epoch%3D12_test_acc%3D0.991.pth)), but that checkpoint is for a different architecture, so it does not work here.

Pretrained weights might work with a different model that _is_ in Hub. Otherwise, just pass `pretrained=False` when running the strategies below.

In [None]:
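import os

# Datasets are stored in a 'data' directory two levels above the current working directory
data_path = os.path.join(os.getcwd(), '..', '..', 'data')
# weights_path = ''  # for pretrained models

os.environ['DATAPATH'] = data_path
# os.environ['WEIGHTSPATH'] = weights_path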

In [None]:
import torchvision.datasets as datasets
datasets.MNIST(data_path, train=True, download=True)
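Once the dataset has been downloaded, it can be worth confirming that `DATAPATH` really points at an existing directory before launching long experiments. The sketch below is a generic check that relies only on the environment variable itself; the helper name `check_env_dir` is hypothetical and not part of ShrinkBench.

```python
import os
from pathlib import Path


def check_env_dir(var):
    """Return the directory named by environment variable `var` (hypothetical helper).

    Raises KeyError if the variable is unset, FileNotFoundError if the
    directory it names does not exist.
    """
    value = os.environ.get(var)
    if value is None:
        raise KeyError(f"{var} is not set")
    path = Path(value).expanduser().resolve()
    if not path.is_dir():
        raise FileNotFoundError(f"{var}={value!r} is not an existing directory")
    return path
```

For example, `check_env_dir('DATAPATH')` returns the resolved data directory, and `check_env_dir('WEIGHTSPATH')` raises `KeyError` while that line stays commented out.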

In [None]:
from IPython.display import clear_output
clear_output()

In [None]:
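# If a previous run failed midway, leftover result folders can cause confusing errors.
# Remove them to start clean (ignore_errors avoids a crash when 'results' doesn't exist yet).
import shutil
shutil.rmtree('results', ignore_errors=True)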

We now run experiments on our MNIST network for each pruning strategy over logarithmically spaced compression ratios (powers of two from 1 to 64).
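To make the ratios concrete, here is a small sketch of what each one means, assuming compression is defined as total weights divided by nonzero weights after pruning (the helper `kept_fraction` is hypothetical). Successive ratios double, so their base-2 logarithms are evenly spaced.

```python
import math


def kept_fraction(compression):
    """Fraction of weights left nonzero at a given compression ratio.

    Assumes compression = total weights / nonzero weights.
    """
    return 1 / compression


ratios = [2 ** k for k in range(7)]  # 1, 2, 4, ..., 64: same values as the loop below
for c in ratios:
    # log2(c) grows by exactly 1 per step, i.e. the ratios are evenly spaced on a log scale
    print(f"{c:>2}x  log2={math.log2(c):.0f}  keep {kept_fraction(c):7.2%} of weights")
```

At 64x, for instance, only about 1.56% of the weights remain.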

In [None]:
# One experiment per (strategy, compression) pair; results are collected from 'results' below
for strategy in ['RandomPruning', 'GlobalMagWeight', 'LayerMagWeight']:
    for c in [1, 2, 4, 8, 16, 32, 64]:
        exp = PruningExperiment(dataset='MNIST',
                                model='MnistNet',
                                pretrained=False,
                                strategy=strategy,
                                compression=c,
                                train_kwargs={'epochs': 10})
        exp.run()
        clear_output()

We then collect the outputs from the experiment folders into a DataFrame, which makes it easy to plot the different metrics.

In [None]:
from shrinkbench.plot import df_from_results, plot_df

In [None]:
df = df_from_results('results')

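With the provided plotting helpers, comparing strategies is straightforward. Here we plot top-5 accuracy against compression, with `pre_acc5` as dashed lines and `post_acc5` as solid lines.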

In [None]:
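# Dashed lines: pre_acc5 for each strategy (labels get the ' - pre' suffix)
plot_df(df, 'compression', 'pre_acc5', markers='strategy', line='--', colors='strategy', suffix=' - pre')
# Solid lines: post_acc5, drawn on the same figure (fig=False)
plot_df(df, 'compression', 'post_acc5', markers='strategy', fig=False, colors='strategy')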

We can also plot accuracy against theoretical speedup. Layerwise pruning provides larger FLOP speedups because it prunes every layer, convolutional layers included, by the same fraction.
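Why uniform pruning helps FLOPs can be seen with a toy calculation. Convolutional weights are reused at every spatial position, so each one costs many FLOPs, while a fully connected weight is used once per input. The sketch below uses a hypothetical two-layer network with made-up sizes, and approximates FLOPs as weights times uses per weight. It compares layerwise pruning with a hypothetical extreme of global pruning in which every removed weight comes from the large fully connected layer. Real global magnitude pruning falls somewhere between these cases, depending on the weight magnitudes.

```python
def flop_speedup(layers, kept):
    """Theoretical speedup = dense FLOPs / FLOPs of the weights that remain.

    layers: name -> (num_weights, uses_per_weight); kept: name -> weights kept.
    FLOPs are approximated as weights * uses per weight (a simplification).
    """
    dense = sum(n * uses for n, uses in layers.values())
    sparse = sum(kept[name] * uses for name, (_, uses) in layers.items())
    return dense / sparse


# Hypothetical toy network: few but heavily reused conv weights, many fc weights used once
layers = {"conv": (5_000, 500), "fc": (500_000, 1)}
total = sum(n for n, _ in layers.values())  # 505,000 weights
target = 4                                  # 4x compression keeps 126,250 weights

# Layerwise: every layer keeps the same fraction 1/target
layerwise = {name: n / target for name, (n, _) in layers.items()}

# Hypothetical extreme: all pruning lands in fc, conv stays dense
concentrated = {"conv": layers["conv"][0], "fc": total / target - layers["conv"][0]}

print(f"layerwise speedup:    {flop_speedup(layers, layerwise):.2f}x")
print(f"concentrated speedup: {flop_speedup(layers, concentrated):.2f}x")
```

Both variants keep the same number of weights (4x compression), yet the layerwise one yields a 4.00x FLOP speedup and the concentrated one only about 1.14x, because the expensive conv weights were left untouched.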

In [None]:
import numpy as np
import matplotlib.pyplot as plt

plot_df(df, 'speedup', 'post_acc5', colors='strategy', markers='strategy')
plt.ylim(0.5, 0.999)
# Label the x-axis at powers of two: 1, 2, 4, ..., 64
ticks = 2 ** np.arange(7)
plt.xticks(ticks, [str(t) for t in ticks])

We can check whether the achieved compression matches the requested one by looking at the relative error between `real_compression` and `compression`. As expected, random pruning deviates the most.
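The relative error computed in the next cell is (achieved - requested) / requested. A negative value means the model ended up less compressed than requested. Here is a minimal sketch with hypothetical weight counts, assuming achieved compression is total weights divided by nonzero weights (the helper `compression_error` is illustrative only):

```python
def compression_error(total_weights, nonzero_weights, target):
    """Relative error between achieved and requested compression.

    Assumes achieved compression = total weights / nonzero weights.
    """
    real = total_weights / nonzero_weights
    return (real - target) / target


# Hypothetical counts: 100,000 weights with 12,800 left nonzero, 8x requested
err = compression_error(100_000, 12_800, 8)
print(f"{err:+.2%}")  # -2.34%: slightly less compressed than requested
```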

In [None]:
df['compression_err'] = (df['real_compression'] - df['compression'])/df['compression']

In [None]:
plot_df(df, 'compression', 'compression_err', colors='strategy', markers='strategy')