In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torchprofile import profile_macs

from alexnet import AlexNet
from evaluate_util import compute_conv_flops
from final_pruning import final_unstruct_pruning, final_struct_pruning
from lenet import LeNet5
from resnet20 import ResNetCIFAR
from summary import summary
from train_util import train, test, train_gsm_unstructured, train_gsm_structured

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [2]:
# Training configuration shared by the training cells below.
EPOCHS = 100  # epochs for both the dense baseline and the structured-GSM run
SEED = 0      # fixed seed so Restart & Run All reproduces the same training run

# Seed the RNGs the training code may draw from.
# torch.manual_seed also seeds CUDA devices.
np.random.seed(SEED)
torch.manual_seed(SEED)

### Base model: ResNet-56 trained with SGD

In [3]:
# Dense ResNet-56 (CIFAR variant); its trained weights are the starting point for GSM pruning.
net = ResNetCIFAR(num_layers=56)
net = net.to(device)

# Train the dense baseline only when no checkpoint exists. Restart & Run All then
# reuses saved weights instead of depending on a manual comment/uncomment step.
# NOTE(review): assumes `train` writes its checkpoint to saved_models/<net_name>,
# which is where the next cell loads it from. Confirm in train_util.
if not os.path.exists("saved_models/resnet_56_base.pt"):
    train(net, epochs=EPOCHS, batch_size=128, lr=0.1, reg=1e-4, net_name='resnet_56_base.pt')

In [4]:
# Evaluate the dense baseline: test metrics, per-layer parameter/sparsity report, conv FLOPs.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
net.load_state_dict(torch.load("saved_models/resnet_56_base.pt", map_location=device))
test(net)
summary(net)
# Derive the FLOP counter's `cuda` flag from the active device instead of hard-coding True.
# NOTE(review): assumes `cuda` only selects the device for the profiling pass. Confirm in evaluate_util.
compute_conv_flops(net, cuda=(device == 'cuda'), prune=True)

Files already downloaded and verified
Test Loss=0.3370, Test accuracy=0.9312
Layer id	Type		Parameter	Non-zero parameter	Sparsity(\%)
1		Convolutional_Param	864		864			0.000000
1		Convolutional_Filter	32		32			0.000000
2		BatchNorm	N/A		N/A			N/A
3		ReLU		N/A		N/A			N/A
4		Convolutional_Param	4608		4608			0.000000
4		Convolutional_Filter	16		16			0.000000
5		BatchNorm	N/A		N/A			N/A
6		ReLU		N/A		N/A			N/A
7		Convolutional_Param	2304		2304			0.000000
7		Convolutional_Filter	16		16			0.000000
8		BatchNorm	N/A		N/A			N/A
9		Convolutional_Param	512		512			0.000000
9		Convolutional_Filter	16		16			0.000000
10		BatchNorm	N/A		N/A			N/A
11		ReLU		N/A		N/A			N/A
12		Convolutional_Param	2304		2304			0.000000
12		Convolutional_Filter	16		16			0.000000
13		BatchNorm	N/A		N/A			N/A
14		ReLU		N/A		N/A			N/A
15		Convolutional_Param	2304		2304			0.000000
15		Convolutional_Filter	16		16			0.000000
16		BatchNorm	N/A		N/A			N/A
17		ReLU		N/A		N/A			N/A
18		Convolutional_Param	2304		2304			0.000000
18		

129073792.0

### ResNet-56 trained with structured GSM SGD

In [5]:
# Passed as `nonzero_ratio` to both structured-GSM training and final structured pruning.
# Presumably the fraction of weights/filters kept nonzero. TODO confirm in train_util / final_pruning.
NON_ZERO_RATIO: float = 0.5

In [6]:
# Start structured-GSM training from the trained dense ResNet-56 baseline.
net = ResNetCIFAR(num_layers=56)
net = net.to(device)
net.load_state_dict(torch.load("saved_models/resnet_56_base.pt", map_location=device))

# Long-running run of EPOCHS epochs. Skip it when the pre-pruning checkpoint already
# exists, instead of asking the reader to comment the call out by hand.
# NOTE(review): assumes train_gsm_structured saves to saved_models/<net_name>,
# which is where the pruning cell loads it from. Confirm in train_util.
if not os.path.exists("saved_models/resnet_56_struct_gsm_before_pruning.pt"):
    train_gsm_structured(net, epochs=EPOCHS, batch_size=64, lr=0.005,
                         nonzero_ratio=NON_ZERO_RATIO, reg=1e-4,
                         net_name='resnet_56_struct_gsm_before_pruning.pt')

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified

Epoch: 0
[Step=50]	Loss=2.6674	acc=0.1220	1099.6 examples/second
[Step=100]	Loss=2.5845	acc=0.1275	1292.6 examples/second
[Step=150]	Loss=2.4998	acc=0.1421	1297.3 examples/second
[Step=200]	Loss=2.4191	acc=0.1558	1271.1 examples/second
[Step=250]	Loss=2.3580	acc=0.1661	1227.7 examples/second
[Step=300]	Loss=2.3108	acc=0.1737	1237.3 examples/second
[Step=350]	Loss=2.2739	acc=0.1793	1299.0 examples/second
Test Loss=2.1173, Test acc=0.2056
Saving...

Epoch: 1
[Step=400]	Loss=1.9425	acc=0.2431	821.3 examples/second
[Step=450]	Loss=1.9113	acc=0.2663	1273.4 examples/second
[Step=500]	Loss=1.8956	acc=0.2724	1425.9 examples/second
[Step=550]	Loss=1.8833	acc=0.2787	1321.1 examples/second
[Step=600]	Loss=1.8655	acc=0.2842	1260.8 examples/second
[Step=650]	Loss=1.8499	acc=0.2909	1532.1 examples/second
[Step=700]	Loss=1.8340	acc=0.2968	1327.3 examples/second
[Step=750]	Loss=1.8163	acc=0.3035	1335.5 ex

[Step=6400]	Loss=0.5344	acc=0.8182	1402.4 examples/second
[Step=6450]	Loss=0.5420	acc=0.8158	1344.0 examples/second
[Step=6500]	Loss=0.5408	acc=0.8151	1274.5 examples/second
[Step=6550]	Loss=0.5430	acc=0.8151	1370.9 examples/second
[Step=6600]	Loss=0.5438	acc=0.8154	1375.1 examples/second
Test Loss=0.8348, Test acc=0.7350

Epoch: 17
[Step=6650]	Loss=0.5269	acc=0.8203	876.8 examples/second
[Step=6700]	Loss=0.5248	acc=0.8218	1355.2 examples/second
[Step=6750]	Loss=0.5244	acc=0.8211	1324.0 examples/second
[Step=6800]	Loss=0.5317	acc=0.8186	1311.3 examples/second
[Step=6850]	Loss=0.5306	acc=0.8180	1269.3 examples/second
[Step=6900]	Loss=0.5317	acc=0.8161	1400.9 examples/second
[Step=6950]	Loss=0.5280	acc=0.8175	1377.2 examples/second
[Step=7000]	Loss=0.5304	acc=0.8170	1356.6 examples/second
Test Loss=0.6036, Test acc=0.7968
Saving...

Epoch: 18
[Step=7050]	Loss=0.4927	acc=0.8288	801.7 examples/second
[Step=7100]	Loss=0.4981	acc=0.8300	1259.7 examples/second
[Step=7150]	Loss=0.5085	acc=0.82

[Step=12800]	Loss=0.3936	acc=0.8639	1279.0 examples/second
[Step=12850]	Loss=0.3973	acc=0.8629	1422.5 examples/second
[Step=12900]	Loss=0.3957	acc=0.8632	1424.2 examples/second
Test Loss=0.4487, Test acc=0.8521
Saving...

Epoch: 33
[Step=12950]	Loss=0.3788	acc=0.8675	906.2 examples/second
[Step=13000]	Loss=0.3701	acc=0.8716	1357.2 examples/second
[Step=13050]	Loss=0.3751	acc=0.8710	1230.3 examples/second
[Step=13100]	Loss=0.3791	acc=0.8691	1224.4 examples/second
[Step=13150]	Loss=0.3808	acc=0.8691	1206.6 examples/second
[Step=13200]	Loss=0.3823	acc=0.8683	1489.0 examples/second
[Step=13250]	Loss=0.3831	acc=0.8686	1394.3 examples/second
Test Loss=0.5781, Test acc=0.8201

Epoch: 34
[Step=13300]	Loss=0.3584	acc=0.8776	864.5 examples/second
[Step=13350]	Loss=0.3654	acc=0.8707	1315.5 examples/second
[Step=13400]	Loss=0.3680	acc=0.8717	1294.2 examples/second
[Step=13450]	Loss=0.3648	acc=0.8734	1387.0 examples/second
[Step=13500]	Loss=0.3709	acc=0.8715	1406.0 examples/second
[Step=13550]	Loss

[Step=19150]	Loss=0.3299	acc=0.8863	1380.2 examples/second
Test Loss=0.5620, Test acc=0.8319

Epoch: 49
[Step=19200]	Loss=0.3103	acc=0.8923	863.0 examples/second
[Step=19250]	Loss=0.3220	acc=0.8915	1329.6 examples/second
[Step=19300]	Loss=0.3237	acc=0.8907	1294.9 examples/second
[Step=19350]	Loss=0.3219	acc=0.8905	1300.1 examples/second
[Step=19400]	Loss=0.3271	acc=0.8890	1238.4 examples/second
[Step=19450]	Loss=0.3327	acc=0.8866	1238.1 examples/second
[Step=19500]	Loss=0.3332	acc=0.8861	1218.5 examples/second
[Step=19550]	Loss=0.3333	acc=0.8854	1257.6 examples/second
Test Loss=0.4812, Test acc=0.8501

Epoch: 50
[Step=19600]	Loss=0.2642	acc=0.9103	891.7 examples/second
[Step=19650]	Loss=0.2543	acc=0.9144	1311.3 examples/second
[Step=19700]	Loss=0.2481	acc=0.9161	1276.8 examples/second
[Step=19750]	Loss=0.2402	acc=0.9191	1422.8 examples/second
[Step=19800]	Loss=0.2362	acc=0.9203	1407.2 examples/second
[Step=19850]	Loss=0.2325	acc=0.9212	1326.3 examples/second
[Step=19900]	Loss=0.2314	ac

Test Loss=0.3538, Test acc=0.8944

Epoch: 65
[Step=25450]	Loss=0.1412	acc=0.9491	847.4 examples/second
[Step=25500]	Loss=0.1346	acc=0.9515	1254.0 examples/second
[Step=25550]	Loss=0.1382	acc=0.9508	1313.6 examples/second
[Step=25600]	Loss=0.1379	acc=0.9511	1357.8 examples/second
[Step=25650]	Loss=0.1390	acc=0.9510	1297.5 examples/second
[Step=25700]	Loss=0.1384	acc=0.9513	1330.6 examples/second
[Step=25750]	Loss=0.1377	acc=0.9511	1281.0 examples/second
[Step=25800]	Loss=0.1380	acc=0.9510	1261.4 examples/second
Test Loss=0.3558, Test acc=0.8989

Epoch: 66
[Step=25850]	Loss=0.1266	acc=0.9506	951.9 examples/second
[Step=25900]	Loss=0.1310	acc=0.9529	1294.8 examples/second
[Step=25950]	Loss=0.1333	acc=0.9530	1230.7 examples/second
[Step=26000]	Loss=0.1344	acc=0.9524	1241.9 examples/second
[Step=26050]	Loss=0.1383	acc=0.9512	1254.5 examples/second
[Step=26100]	Loss=0.1365	acc=0.9517	1243.9 examples/second
[Step=26150]	Loss=0.1359	acc=0.9522	1213.9 examples/second
Test Loss=0.3706, Test acc=

[Step=31750]	Loss=0.0941	acc=0.9683	1293.8 examples/second
[Step=31800]	Loss=0.0927	acc=0.9687	1284.0 examples/second
[Step=31850]	Loss=0.0946	acc=0.9677	1281.7 examples/second
[Step=31900]	Loss=0.0963	acc=0.9671	1307.0 examples/second
[Step=31950]	Loss=0.0993	acc=0.9660	1239.8 examples/second
[Step=32000]	Loss=0.0996	acc=0.9661	1224.7 examples/second
[Step=32050]	Loss=0.0984	acc=0.9666	1257.8 examples/second
Test Loss=0.3681, Test acc=0.8993

Epoch: 82
[Step=32100]	Loss=0.1018	acc=0.9653	875.6 examples/second
[Step=32150]	Loss=0.1006	acc=0.9664	1286.7 examples/second
[Step=32200]	Loss=0.1016	acc=0.9653	1340.5 examples/second
[Step=32250]	Loss=0.0985	acc=0.9663	1258.1 examples/second
[Step=32300]	Loss=0.0999	acc=0.9657	1268.2 examples/second
[Step=32350]	Loss=0.0991	acc=0.9664	1287.8 examples/second
[Step=32400]	Loss=0.0991	acc=0.9660	1166.7 examples/second
[Step=32450]	Loss=0.0983	acc=0.9663	1244.5 examples/second
Test Loss=0.3682, Test acc=0.8999
Saving...

Epoch: 83
[Step=32500]	Los

[Step=38100]	Loss=0.0871	acc=0.9699	1245.5 examples/second
[Step=38150]	Loss=0.0879	acc=0.9694	1395.1 examples/second
[Step=38200]	Loss=0.0873	acc=0.9702	1274.8 examples/second
[Step=38250]	Loss=0.0886	acc=0.9693	1312.8 examples/second
[Step=38300]	Loss=0.0909	acc=0.9683	1288.3 examples/second
Test Loss=0.3789, Test acc=0.8981

Epoch: 98
[Step=38350]	Loss=0.0926	acc=0.9700	877.1 examples/second
[Step=38400]	Loss=0.0925	acc=0.9684	1419.5 examples/second
[Step=38450]	Loss=0.0888	acc=0.9702	1337.4 examples/second
[Step=38500]	Loss=0.0892	acc=0.9697	1244.5 examples/second
[Step=38550]	Loss=0.0902	acc=0.9698	1206.5 examples/second
[Step=38600]	Loss=0.0883	acc=0.9707	1247.6 examples/second
[Step=38650]	Loss=0.0893	acc=0.9706	1198.8 examples/second
[Step=38700]	Loss=0.0898	acc=0.9700	1208.6 examples/second
Test Loss=0.3800, Test acc=0.8980

Epoch: 99
[Step=38750]	Loss=0.0872	acc=0.9718	809.4 examples/second
[Step=38800]	Loss=0.0867	acc=0.9719	1261.2 examples/second
[Step=38850]	Loss=0.0898	ac

In [7]:
# Load the saved pre-pruning GSM checkpoint, then apply final structured pruning at
# NON_ZERO_RATIO and save the result under a separate name.
# NOTE(review): presumably final_struct_pruning writes to saved_models/<net_name>,
# since the next cell loads from there. Confirm in final_pruning.
net.load_state_dict(torch.load("saved_models/resnet_56_struct_gsm_before_pruning.pt",
                               map_location=device))
final_struct_pruning(net, nonzero_ratio=NON_ZERO_RATIO,
                     net_name="resnet_56_struct_gsm_after_pruning.pt")

In [8]:
# Evaluate the pruned model with the same report as the baseline, so accuracy,
# per-layer sparsity and conv FLOPs can be compared directly.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
net.load_state_dict(torch.load("saved_models/resnet_56_struct_gsm_after_pruning.pt",
                               map_location=device))
test(net)
summary(net)
# Derive the FLOP counter's `cuda` flag from the active device instead of hard-coding True.
# NOTE(review): assumes `cuda` only selects the device for the profiling pass. Confirm in evaluate_util.
compute_conv_flops(net, cuda=(device == 'cuda'), prune=True)

Files already downloaded and verified
Test Loss=0.3678, Test accuracy=0.8996
Layer id	Type		Parameter	Non-zero parameter	Sparsity(\%)
1		Convolutional_Param	864		783			0.093750
1		Convolutional_Filter	32		29			0.093750
2		BatchNorm	N/A		N/A			N/A
3		ReLU		N/A		N/A			N/A
4		Convolutional_Param	4608		3744			0.187500
4		Convolutional_Filter	16		13			0.187500
5		BatchNorm	N/A		N/A			N/A
6		ReLU		N/A		N/A			N/A
7		Convolutional_Param	2304		2016			0.125000
7		Convolutional_Filter	16		14			0.125000
8		BatchNorm	N/A		N/A			N/A
9		Convolutional_Param	512		480			0.062500
9		Convolutional_Filter	16		15			0.062500
10		BatchNorm	N/A		N/A			N/A
11		ReLU		N/A		N/A			N/A
12		Convolutional_Param	2304		2160			0.062500
12		Convolutional_Filter	16		15			0.062500
13		BatchNorm	N/A		N/A			N/A
14		ReLU		N/A		N/A			N/A
15		Convolutional_Param	2304		2007			0.128906
15		Convolutional_Filter	16		14			0.125000
16		BatchNorm	N/A		N/A			N/A
17		ReLU		N/A		N/A			N/A
18		Convolutional_Param	2304		2160			0.062500
18		

88292762.0