In [1]:
from resnet20 import ResNetCIFAR
from lenet import LeNet5
from alexnet import AlexNet
from train_util import train, test, train_gsm_unstructured, train_gsm_structured
from summary import summary
import torch
import numpy as np
from final_pruning import final_unstruct_pruning, final_struct_pruning
import torch.nn as nn
import matplotlib.pyplot as plt

from torchprofile import profile_macs
from evaluate_util import compute_conv_flops

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [2]:
# Epoch budget passed as `epochs=` to the (commented-out) train / train_gsm_structured calls.
EPOCHS: int = 50

### Baseline ResNet-56 trained with SGD

In [3]:
# Baseline ResNet-56 (CIFAR variant), placed on the selected device.
net = ResNetCIFAR(num_layers=56).to(device)

# Uncomment to train the baseline from scratch when no pretrained checkpoint is available.
# train(net, epochs=EPOCHS, batch_size=128, lr=0.1, reg=1e-4, net_name = 'resnet_56_base.pt')

In [4]:
# Load the pretrained baseline. map_location lets a checkpoint that was saved
# from GPU tensors also load on a CPU-only machine.
net.load_state_dict(torch.load("saved_models/resnet_56_base.pt", map_location=device))
test(net)
summary(net)
# Derive the flag from the active device instead of hard-coding cuda=True,
# which breaks on CPU-only machines. (Assumes `cuda` controls where
# compute_conv_flops places its profiling input — confirm in evaluate_util.)
compute_conv_flops(net, cuda=(device == 'cuda'), prune=True)

Files already downloaded and verified
Test Loss=0.3370, Test accuracy=0.9312
Layer id	Type		Parameter	Non-zero parameter	Sparsity(\%)
1		Convolutional_Param	864		864			0.000000
1		Convolutional_Filter	32		32			0.000000
2		BatchNorm	N/A		N/A			N/A
3		ReLU		N/A		N/A			N/A
4		Convolutional_Param	4608		4608			0.000000
4		Convolutional_Filter	16		16			0.000000
5		BatchNorm	N/A		N/A			N/A
6		ReLU		N/A		N/A			N/A
7		Convolutional_Param	2304		2304			0.000000
7		Convolutional_Filter	16		16			0.000000
8		BatchNorm	N/A		N/A			N/A
9		Convolutional_Param	512		512			0.000000
9		Convolutional_Filter	16		16			0.000000
10		BatchNorm	N/A		N/A			N/A
11		ReLU		N/A		N/A			N/A
12		Convolutional_Param	2304		2304			0.000000
12		Convolutional_Filter	16		16			0.000000
13		BatchNorm	N/A		N/A			N/A
14		ReLU		N/A		N/A			N/A
15		Convolutional_Param	2304		2304			0.000000
15		Convolutional_Filter	16		16			0.000000
16		BatchNorm	N/A		N/A			N/A
17		ReLU		N/A		N/A			N/A
18		Convolutional_Param	2304		2304			0.000000
18		

129073792.0

### ResNet-56 trained with structured GSM SGD, then pruned

In [5]:
# Passed as `nonzero_ratio` to train_gsm_structured / final_struct_pruning;
# presumably the fraction of weights/filters kept non-zero — confirm in train_util.
NON_ZERO_RATIO: float = 0.5

In [6]:
# Fresh ResNet-56 initialised from the baseline checkpoint; structured GSM
# training starts from these weights. map_location allows loading a
# GPU-saved checkpoint on a CPU-only machine.
net = ResNetCIFAR(num_layers=56)
net = net.to(device)
net.load_state_dict(torch.load("saved_models/resnet_56_base.pt", map_location=device))

# Uncomment to run structured GSM training (slow). It is not needed when the
# 'resnet_56_struct_gsm_before_pruning.pt' checkpoint already exists.
# train_gsm_structured(net, epochs=EPOCHS, batch_size=64, lr=0.005, nonzero_ratio = NON_ZERO_RATIO, 
#                      reg=1e-4, 
#                      net_name = 'resnet_56_struct_gsm_before_pruning.pt')

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified

Epoch: 0
[Step=50]	Loss=1.2306	acc=0.6134	522.7 examples/second
[Step=100]	Loss=0.9848	acc=0.6830	558.2 examples/second
[Step=150]	Loss=0.8708	acc=0.7165	598.3 examples/second
[Step=200]	Loss=0.8006	acc=0.7384	651.3 examples/second
[Step=250]	Loss=0.7468	acc=0.7551	577.1 examples/second
[Step=300]	Loss=0.7101	acc=0.7667	666.8 examples/second
[Step=350]	Loss=0.6815	acc=0.7753	649.2 examples/second
[Step=400]	Loss=0.6615	acc=0.7810	650.1 examples/second
[Step=450]	Loss=0.6414	acc=0.7876	620.6 examples/second
[Step=500]	Loss=0.6214	acc=0.7937	686.6 examples/second
[Step=550]	Loss=0.6102	acc=0.7965	693.7 examples/second
[Step=600]	Loss=0.5973	acc=0.8008	698.0 examples/second
[Step=650]	Loss=0.5870	acc=0.8038	638.5 examples/second
[Step=700]	Loss=0.5784	acc=0.8069	717.2 examples/second
[Step=750]	Loss=0.5683	acc=0.8099	678.9 examples/second
Test Loss=0.5096, Test acc=0.8331
Saving...

Epoch: 1


[Step=6850]	Loss=0.2772	acc=0.9053	579.0 examples/second
[Step=6900]	Loss=0.2775	acc=0.9051	586.3 examples/second
[Step=6950]	Loss=0.2779	acc=0.9051	628.9 examples/second
[Step=7000]	Loss=0.2772	acc=0.9054	594.1 examples/second
Test Loss=0.3804, Test acc=0.8776
Saving...

Epoch: 9
[Step=7050]	Loss=0.2277	acc=0.9232	444.0 examples/second
[Step=7100]	Loss=0.2435	acc=0.9171	607.0 examples/second
[Step=7150]	Loss=0.2536	acc=0.9135	596.0 examples/second
[Step=7200]	Loss=0.2602	acc=0.9098	595.5 examples/second
[Step=7250]	Loss=0.2649	acc=0.9088	638.2 examples/second
[Step=7300]	Loss=0.2678	acc=0.9073	570.4 examples/second
[Step=7350]	Loss=0.2689	acc=0.9056	614.6 examples/second
[Step=7400]	Loss=0.2685	acc=0.9060	624.4 examples/second
[Step=7450]	Loss=0.2679	acc=0.9065	579.7 examples/second
[Step=7500]	Loss=0.2659	acc=0.9074	614.3 examples/second
[Step=7550]	Loss=0.2659	acc=0.9077	622.4 examples/second
[Step=7600]	Loss=0.2679	acc=0.9069	618.1 examples/second
[Step=7650]	Loss=0.2685	acc=0.9070

[Step=13600]	Loss=0.2304	acc=0.9212	589.7 examples/second
[Step=13650]	Loss=0.2302	acc=0.9210	594.3 examples/second
[Step=13700]	Loss=0.2278	acc=0.9215	624.6 examples/second
[Step=13750]	Loss=0.2280	acc=0.9217	634.8 examples/second
[Step=13800]	Loss=0.2289	acc=0.9214	601.4 examples/second
[Step=13850]	Loss=0.2295	acc=0.9211	584.8 examples/second
[Step=13900]	Loss=0.2313	acc=0.9206	628.4 examples/second
[Step=13950]	Loss=0.2336	acc=0.9198	570.2 examples/second
[Step=14000]	Loss=0.2326	acc=0.9201	586.7 examples/second
[Step=14050]	Loss=0.2324	acc=0.9199	581.5 examples/second
Test Loss=0.3824, Test acc=0.8824

Epoch: 18
[Step=14100]	Loss=0.2307	acc=0.9121	401.9 examples/second
[Step=14150]	Loss=0.2213	acc=0.9183	680.6 examples/second
[Step=14200]	Loss=0.2213	acc=0.9195	641.5 examples/second
[Step=14250]	Loss=0.2213	acc=0.9206	613.4 examples/second
[Step=14300]	Loss=0.2202	acc=0.9213	632.8 examples/second
[Step=14350]	Loss=0.2207	acc=0.9218	657.2 examples/second
[Step=14400]	Loss=0.2215	ac

KeyboardInterrupt: 

In [7]:
# Uncomment to prune the GSM-trained checkpoint down to NON_ZERO_RATIO and save
# it as 'resnet_56_struct_gsm_after_pruning.pt' (presumably under saved_models/,
# which is where the next cell loads it from — confirm in final_pruning).
# net.load_state_dict(torch.load("saved_models/resnet_56_struct_gsm_before_pruning.pt"))
# final_struct_pruning(net, nonzero_ratio = NON_ZERO_RATIO, 
#                      net_name = "resnet_56_struct_gsm_after_pruning.pt")

In [8]:
# Load the structurally pruned model. map_location lets a GPU-saved checkpoint
# also load on a CPU-only machine.
net.load_state_dict(torch.load("saved_models/resnet_56_struct_gsm_after_pruning.pt", map_location=device))
test(net)
summary(net)
# Derive the flag from the active device instead of hard-coding cuda=True.
# (Assumes `cuda` controls where compute_conv_flops places its profiling
# input — confirm in evaluate_util.)
compute_conv_flops(net, cuda=(device == 'cuda'), prune=True)

Files already downloaded and verified
Test Loss=0.3488, Test accuracy=0.8872
Layer id	Type		Parameter	Non-zero parameter	Sparsity(\%)
1		Convolutional_Param	864		729			0.156250
1		Convolutional_Filter	32		27			0.156250
2		BatchNorm	N/A		N/A			N/A
3		ReLU		N/A		N/A			N/A
4		Convolutional_Param	4608		3168			0.312500
4		Convolutional_Filter	16		11			0.312500
5		BatchNorm	N/A		N/A			N/A
6		ReLU		N/A		N/A			N/A
7		Convolutional_Param	2304		2304			0.000000
7		Convolutional_Filter	16		16			0.000000
8		BatchNorm	N/A		N/A			N/A
9		Convolutional_Param	512		512			0.000000
9		Convolutional_Filter	16		16			0.000000
10		BatchNorm	N/A		N/A			N/A
11		ReLU		N/A		N/A			N/A
12		Convolutional_Param	2304		2016			0.125000
12		Convolutional_Filter	16		14			0.125000
13		BatchNorm	N/A		N/A			N/A
14		ReLU		N/A		N/A			N/A
15		Convolutional_Param	2304		2160			0.062500
15		Convolutional_Filter	16		15			0.062500
16		BatchNorm	N/A		N/A			N/A
17		ReLU		N/A		N/A			N/A
18		Convolutional_Param	2304		2016			0.125000
18		

77547126.0