In [2]:
import numpy as np
import tensorflow as tf
import copy
from kerassurgeon.operations import delete_channels
from keras.models import Sequential
from keras.layers import Dense, Conv2D, Flatten, MaxPooling2D

In [26]:
# Load MNIST, then give every image an explicit single-channel axis and
# rescale the pixel values to float32 divided by 255.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train, x_test = (
    images.reshape(images.shape[0], 28, 28, 1).astype('float32') / 255.
    for images in (x_train, x_test)
)

In [22]:
# Per-sample shape of the inputs (axes 1-3, i.e. everything after the sample axis).
x_test.shape[1:4]

(28, 28, 1)

In [28]:
# Small CNN: two (3x3 convolution -> 2x2 max-pooling) stages, then a dense head
# ending in a 10-way softmax.
model = Sequential()
# Only the first layer of a Sequential model needs `input_shape`; later layers
# take their input shape from the previous layer's output.
model.add(Conv2D(28, kernel_size=(3, 3), input_shape=x_train.shape[1:4], activation="relu", padding="same"))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(28, kernel_size=(3, 3), activation="relu", padding="same"))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
# Activations are given by name throughout, matching the convolution layers.
model.add(Dense(32, activation="relu"))
model.add(Dense(10, activation="softmax"))

In [31]:
# Configure training (Adam, sparse categorical cross-entropy, accuracy metric)
# and fit for 10 epochs on the training split.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
model.fit(x_train, y_train, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.callbacks.History at 0xb43e68e48>

In [32]:
# Evaluate the trained model on the test split (loss followed by the compiled metric).
model.evaluate(x=x_test, y=y_test)



[0.03287297285744397, 0.9904000163078308]

In [38]:
# Total parameter count of the model built above.
model.count_params()

51630

In [35]:
def scoring(model, layer_index):
    """Return the normalized importance score of every filter in one layer.

    A filter's raw score is the L2 norm of its kernel slice plus the norm of
    its bias entry (when the layer has a bias).  The raw scores are divided by
    the largest one, so the strongest filter scores 1.0.

    Args:
        model: object exposing ``get_layer(index=...)``; the returned layer
            must expose ``get_weights()``.
        layer_index: index of the layer to score.  Its first weight array is
            indexed as ``[:, :, :, filter]`` (filters on axis 3).

    Returns:
        1-D ``numpy`` array with one score per filter.  If every raw score is
        zero the unscaled zeros are returned instead of dividing by zero.
    """
    # Read the weights once; they do not change between filters.
    weights = model.get_layer(index=layer_index).get_weights()
    kernel = weights[0]
    # Layers built without a bias only return the kernel.
    bias = weights[1] if len(weights) > 1 else None

    scores = []
    for i in range(kernel.shape[3]):
        score = np.linalg.norm(kernel[:, :, :, i])
        if bias is not None:
            score = score + np.linalg.norm(bias[i])
        scores.append(score)

    scores = np.array(scores)
    maxScore = scores.max()
    if maxScore == 0:
        # All filters are exactly zero: normalizing would produce NaNs.
        return scores
    return scores / maxScore

In [34]:
def pruning(model, threshold, epochs):
    """Iteratively delete the lowest-scoring convolution filter, fine-tuning after each deletion.

    Each iteration scores the filters of every ``Conv2D`` layer with
    ``scoring``, deletes the single filter with the smallest normalized score
    via ``delete_channels``, recompiles the resulting model, fine-tunes it on
    the notebook-level ``x_train``/``y_train`` and evaluates it on the
    notebook-level ``x_test``/``y_test``.  The loop stops when the evaluated
    accuracy is not above ``threshold`` or when no convolution layer has more
    than one filter left.

    NOTE(review): the test split drives the stopping decision, so the final
    test accuracy is not an unbiased estimate; a separate validation split
    would be cleaner.

    Args:
        model: trained Keras model to prune.
        threshold: accuracy a pruned model must exceed for its deletion to be kept.
        epochs: number of fine-tuning epochs run after each deletion.

    Returns:
        The last model whose accuracy exceeded ``threshold`` (the input model
        itself if the very first deletion is rejected).
    """
    iteration = 0
    while True:
        iteration += 1
        print("iteration #" + str(iteration))

        # Select layers by type instead of searching the layer's repr text,
        # which also matches unrelated layers whose repr contains "convolutional".
        filterLayers = [
            i for i, layer in enumerate(model.layers) if isinstance(layer, Conv2D)
        ]

        scoresOfLayers = {"l" + str(i): scoring(model, i) for i in filterLayers}

        # A layer can only lose a filter if it has more than one; a layer with
        # a single filter always scores 1.0 and must never be selected.
        prunableLayers = [
            i for i in filterLayers if scoresOfLayers["l" + str(i)].shape[0] > 1
        ]
        if not prunableLayers:
            print("All layers have only 1 filter left, stopping it !!!")
            break

        # Layer holding the globally smallest score (first one wins on ties),
        # then the position of that score inside the layer.
        pruningLayer = min(
            prunableLayers, key=lambda i: scoresOfLayers["l" + str(i)].min()
        )
        pruningFilter = scoresOfLayers["l" + str(pruningLayer)].argmin()

        print(scoresOfLayers)
        print("====================")
        print(pruningLayer, pruningFilter)
        print("====================")

        # Snapshot the accepted weights.  The candidate built by
        # delete_channels may share layer objects with `model` (assumption --
        # confirm against kerassurgeon); if so, fine-tuning the candidate also
        # alters `model`.  Restoring the snapshot is harmless otherwise.
        acceptedWeights = model.get_weights()
        model_new = delete_channels(model, model.layers[pruningLayer], [pruningFilter])

        model_new.compile(optimizer='adam',
                          loss='sparse_categorical_crossentropy',
                          metrics=['accuracy'])
        model_new.fit(x=x_train, y=y_train, epochs=epochs)

        print("test accuracy:")
        # Evaluate once and reuse the value for both the log and the decision.
        accuracy = model_new.evaluate(x_test, y_test)[1]
        print(accuracy)

        if accuracy > threshold:
            model = model_new
        else:
            print("accuracy dropped below threhold, stopping it !!!")
            model.set_weights(acceptedWeights)
            break

    return model


In [36]:
# Prune while test accuracy stays above 0.96, fine-tuning 5 epochs per deletion.
pruning(model, threshold=0.96, epochs=5)

iteration #1
{'l0': array([0.606814  , 0.9229351 , 0.7232515 , 0.36631894, 0.9145028 ,
       0.5678287 , 0.30092943, 0.7762076 , 0.45462617, 0.41574916,
       0.75885916, 0.5261007 , 0.8734554 , 0.718889  , 0.7583239 ,
       0.38806838, 1.        , 0.7461518 , 0.6514504 , 0.9481061 ,
       0.88325757, 0.4384668 , 0.4436274 , 0.9304094 , 0.44880658,
       0.93886197, 0.96065843, 0.9023028 ], dtype=float32), 'l2': array([1.        , 0.7934735 , 0.94585603, 0.8329287 , 0.82306945,
       0.74706537, 0.85043347, 0.94606376, 0.8607214 , 0.8293997 ,
       0.97740555, 0.7578261 , 0.7999112 , 0.85168225, 0.9665853 ,
       0.9070671 , 0.7775275 , 0.7413592 , 0.8106757 , 0.8239844 ,
       0.83023244, 0.8801974 , 0.9994074 , 0.8733474 , 0.7896383 ,
       0.90882677, 0.8787805 , 0.78733444], dtype=float32)}
0 6
Deleting 1/28 channels from layer: conv2d_11
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
test accuracy:
0.9894999861717224
iteration #2
{'l1': array([0.60135496, 0.88029784, 

Deleting 1/23 channels from layer: conv2d_11
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
test accuracy:
0.9911999702453613
iteration #7
{'l1': array([0.7363617 , 0.81606406, 0.570308  , 0.7879637 , 0.7132295 ,
       0.6954135 , 0.4791025 , 0.69660455, 0.5771819 , 0.7845995 ,
       0.7594074 , 0.87967855, 1.        , 0.7877985 , 0.5780875 ,
       0.83062553, 0.78096247, 0.9021224 , 0.55311394, 0.81246924,
       0.8421531 , 0.7907958 ], dtype=float32), 'l3': array([0.98729914, 0.7735613 , 0.8650233 , 0.8832152 , 0.8430207 ,
       0.79514897, 0.83727646, 0.9316677 , 0.86425126, 0.8366242 ,
       0.9311468 , 0.82328516, 0.78212255, 0.83393615, 0.96341354,
       0.84314483, 0.8772024 , 0.80050933, 0.81662554, 0.85308385,
       0.84152335, 0.86254334, 1.        , 0.8958082 , 0.8529572 ,
       0.91238225, 0.847303  , 0.84149945], dtype=float32)}
1 6
Deleting 1/22 channels from layer: conv2d_11
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
test accuracy:
0.9905999898910522
i

Deleting 1/17 channels from layer: conv2d_11
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
test accuracy:
0.9909999966621399
iteration #13
{'l1': array([0.86819273, 0.79950505, 0.66314363, 0.7459049 , 0.742827  ,
       0.79112417, 0.87211585, 0.9976168 , 1.        , 0.734746  ,
       0.7255802 , 0.70646095, 0.8468073 , 0.8199129 , 0.8366181 ,
       0.72825617], dtype=float32), 'l3': array([1.        , 0.7559235 , 0.7864129 , 0.8152912 , 0.84534776,
       0.7657463 , 0.8734094 , 0.8951296 , 0.7945451 , 0.81590503,
       0.8702991 , 0.79585755, 0.78301334, 0.8064138 , 0.90034336,
       0.79302084, 0.836674  , 0.81854   , 0.8360636 , 0.83941907,
       0.7738791 , 0.78148204, 0.9760267 , 0.8272246 , 0.854914  ,
       0.86259043, 0.7625965 , 0.8245966 ], dtype=float32)}
1 2
Deleting 1/16 channels from layer: conv2d_11
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
test accuracy:
0.989799976348877
iteration #14
{'l1': array([0.87852824, 0.8078149 , 0.72015476, 0.7428074 , 0.77

Deleting 1/11 channels from layer: conv2d_11
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
test accuracy:
0.9901999831199646
iteration #19
{'l1': array([0.84410363, 0.77391195, 0.8651618 , 0.801484  , 1.        ,
       0.9665672 , 0.78947496, 0.802144  , 0.86259156, 0.85398227],
      dtype=float32), 'l3': array([0.96387154, 0.79109913, 0.8264254 , 0.85653436, 0.8774328 ,
       0.7612279 , 0.975937  , 0.84462726, 0.8439843 , 0.82376736,
       0.9284361 , 0.84614116, 0.75374305, 0.75247407, 1.        ,
       0.78322655, 0.8669173 , 0.86003804, 0.9196192 , 0.8306095 ,
       0.74811155, 0.6978517 , 0.9239285 , 0.8713437 , 0.79196084,
       0.8577676 , 0.7810829 , 0.76050913], dtype=float32)}
3 21
Deleting 1/28 channels from layer: conv2d_12
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
test accuracy:
0.991100013256073
iteration #20
{'l1': array([0.8730273 , 0.7720053 , 0.88151884, 0.81539524, 1.        ,
       0.97351676, 0.78476024, 0.79402626, 0.8653108 , 0.842299  ],
   

Epoch 3/5
Epoch 4/5
Epoch 5/5
test accuracy:
0.9872999787330627
iteration #25
{'l1': array([0.8443758 , 0.80784196, 0.84463423, 1.        , 0.95604837,
       0.7648759 , 0.8068197 ], dtype=float32), 'l3': array([0.86784434, 0.80328053, 0.9424636 , 0.8624853 , 0.9011454 ,
       0.78698397, 0.9560017 , 0.8606216 , 0.8260157 , 1.        ,
       0.94602567, 0.9270304 , 0.7538585 , 0.79489774, 0.9910279 ,
       0.9064664 , 0.80176187, 0.93440706, 0.91998434, 0.8990109 ,
       0.93885446, 0.93099976, 0.9128716 , 0.8306569 , 0.8202466 ],
      dtype=float32)}
3 12
Deleting 1/25 channels from layer: conv2d_12
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
test accuracy:
0.9889000058174133
iteration #26
{'l1': array([0.795911  , 0.7783542 , 0.8053568 , 1.        , 0.9384936 ,
       0.7303897 , 0.74987084], dtype=float32), 'l3': array([0.86509514, 0.78831685, 0.9314662 , 0.8476908 , 0.8848767 ,
       0.80147225, 0.95886314, 0.85226685, 0.83205414, 0.9912721 ,
       0.96220946, 0.90487

Deleting 1/21 channels from layer: conv2d_12
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
test accuracy:
0.9879000186920166
iteration #32
{'l1': array([0.8158256 , 0.85256535, 1.        , 0.9592188 , 0.7904335 ],
      dtype=float32), 'l3': array([0.7795737 , 0.9105023 , 0.85891455, 0.8648841 , 0.7919098 ,
       0.9163253 , 0.86025804, 0.7938032 , 0.9375602 , 0.80888784,
       0.9196509 , 0.8959191 , 0.8498297 , 0.932925  , 0.8099004 ,
       0.92572963, 1.        , 0.91286343, 0.91847634, 0.8256356 ],
      dtype=float32)}
3 0
Deleting 1/20 channels from layer: conv2d_12
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
test accuracy:
0.9876999855041504
iteration #33
{'l1': array([0.80488306, 0.8232901 , 1.        , 0.91344476, 0.77865094],
      dtype=float32), 'l3': array([0.91595614, 0.848447  , 0.8590684 , 0.80865264, 0.9248463 ,
       0.8435749 , 0.8114841 , 0.9421591 , 0.8229401 , 0.9147073 ,
       0.8827177 , 0.8879384 , 0.9401445 , 0.80148375, 0.9275104 ,
       1.   

Epoch 3/5
Epoch 4/5
Epoch 5/5
test accuracy:
0.9860000014305115
iteration #39
{'l1': array([0.8914092 , 1.        , 0.93888414], dtype=float32), 'l3': array([0.92651707, 0.9022754 , 0.95650864, 0.7871371 , 0.82855946,
       0.9638234 , 0.9077628 , 0.89921093, 1.        , 0.85654294,
       0.8639742 , 0.867986  , 0.8448034 , 0.97076166, 0.7949864 ],
      dtype=float32)}
3 3
Deleting 1/15 channels from layer: conv2d_12
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
test accuracy:
0.9854999780654907
iteration #40
{'l1': array([0.8718964, 1.       , 0.9431001], dtype=float32), 'l3': array([0.9059062 , 0.8978187 , 0.9529773 , 0.7828682 , 0.9470806 ,
       0.88755643, 0.8693674 , 1.        , 0.8319657 , 0.89586455,
       0.8445294 , 0.8528954 , 0.9693573 , 0.80503404], dtype=float32)}
3 3
Deleting 1/14 channels from layer: conv2d_12
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
test accuracy:
0.9853000044822693
iteration #41
{'l1': array([0.878981  , 1.        , 0.91362184], dtyp

Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
test accuracy:
0.9797000288963318
iteration #47
{'l1': array([1.        , 0.71729106], dtype=float32), 'l3': array([0.8339063 , 1.        , 0.8350115 , 0.793896  , 0.97200465,
       0.9037127 , 0.98189116, 0.87107646], dtype=float32)}
1 1
Deleting 1/2 channels from layer: conv2d_11
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
test accuracy:
0.9616000056266785
iteration #48
{'l1': array([1.], dtype=float32), 'l3': array([0.5646408 , 0.99616975, 0.47898325, 0.49033818, 1.        ,
       0.8208236 , 0.88847065, 0.35049027], dtype=float32)}
3 7
Deleting 1/8 channels from layer: conv2d_12
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
test accuracy:
0.9707000255584717
iteration #49
{'l1': array([1.], dtype=float32), 'l3': array([0.53245145, 1.        , 0.54005915, 0.5472608 , 0.97807664,
       0.83249116, 0.834157  ], dtype=float32)}
3 0
Deleting 1/7 channels from layer: conv2d_12
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
test accur

In [None]:
# NOTE(review): no cell in this notebook assigns `modelDeepCopied` at notebook
# (global) scope, so on a fresh kernel this line raises NameError -- capture
# the pruned model from `pruning` and display that instead.
modelDeepCopied          
# Final PRUNED model, as recorded by the author:
#   count_params   = 3528
#   test accuracy  = 0.9663000106811523
#   inference time = 5s 345us