In [1]:
import tensorflow as tf
from tensorflow.keras import models
import tensorflow_datasets as tfds
from tensorflow.keras.applications.resnet import preprocess_input
from tensorflow.keras import Model, Input
physical_devices = tf.config.list_physical_devices('GPU')
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)


from tensorflow.keras.layers import Conv2D, ReLU, BatchNormalization, InputLayer, Layer

#used to fix bug in keras preprocessing scope
temp = tf.zeros([4, 32, 32, 3])  # Or tf.zeros
preprocess_input(temp)
print("processed")

processed


In [2]:
IMAGE_SIZE = (64, 64)
TRAIN_SIZE = 50000
VALIDATION_SIZE = 10000
BATCH_SIZE = 256
NUM_CLASSES = 10

In [3]:
@tf.function
def normalize(input_image):
  return preprocess_input(input_image)

@tf.function
def load_image_test(datapoint):
  input_image, label = tf.image.resize(datapoint["image"], IMAGE_SIZE), datapoint['label']
  #input_image = preprocess_input(input_image)

  input_image = normalize(input_image)

  return input_image, tf.one_hot(label, depth=NUM_CLASSES)

dataset, info = tfds.load('cifar10', with_info=True)
train = dataset['train'].map(load_image_test, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train.batch(BATCH_SIZE)
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)


test_dataset = dataset['test'].map(load_image_test, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE)
test_dataset = test_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

In [4]:
model = models.load_model('/tf/notebooks/cifar10.h5')

In [5]:
model.freeze = True

In [6]:
output_layer_names = ["pool1_pool", "conv2_block1_2_relu", "conv2_block1_3_bn"]
outputs = []
for layer in model.layers:
    if layer.name in output_layer_names:
        outputs.append(layer.output)

In [7]:
output_model = Model(inputs=model.input, outputs=outputs)

In [8]:
output_model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, None, None, 3)]   0         
_________________________________________________________________
conv1_pad (ZeroPadding2D)    (None, None, None, 3)     0         
_________________________________________________________________
conv1_conv (Conv2D)          (None, None, None, 64)    9472      
_________________________________________________________________
conv1_bn (BatchNormalization (None, None, None, 64)    256       
_________________________________________________________________
conv1_relu (Activation)      (None, None, None, 64)    0         
_________________________________________________________________
pool1_pad (ZeroPadding2D)    (None, None, None, 64)    0         
_________________________________________________________________
pool1_pool (MaxPooling2D)    (None, None, None, 64)    0     

In [9]:
output_model.save('/tmp/output_model.h5')
del model
tf.keras.backend.clear_session()
output_model = tf.keras.models.load_model('/tmp/output_model.h5')





In [10]:
prune_layers = ["conv2_block1_1_conv", 
                "conv2_block1_1_bn", 
                "conv2_block1_1_relu",
                "conv2_block1_2_conv",
                "conv2_block1_2_bn",
                "conv2_block1_2_relu",
                "conv2_block1_3_conv",
                "conv2_block1_3_bn"]
layer_dict = {}
for layer in output_model.layers:
    if layer.name in prune_layers:
        layer_dict[layer.name] = {"weights":layer.get_weights()}

In [11]:
for x_batch, y_batch in train_dataset.take(1):
    x, y_1, y_2 = output_model(x_batch)
    print(x.shape, y_1.shape, y_2.shape)

(256, 16, 16, 64) (256, 16, 16, 64) (256, 16, 16, 256)


In [12]:

model_config = output_model.get_config()
for layer in prune_layers:
    for layer_config in model_config['layers']:
        if layer_config['config']['name'] == layer:
            layer_dict[layer]['config'] =  layer_config

In [13]:
input_tensor = Input(shape=(16, 16, 64))
layer_class = getattr(tf.keras.layers, layer_dict[prune_layers[0]]['config']['class_name'])
x = layer_class.from_config(layer_dict[prune_layers[0]]['config']['config'])(input_tensor)
x_0 = None
for layer in prune_layers[1::]:
    layer_class = getattr(tf.keras.layers, layer_dict[layer]['config']['class_name'])
    if layer == 'conv2_block1_2_relu':
        x_0 = layer_class.from_config(layer_dict[layer]['config']['config'])(x)
    else:
        x = layer_class.from_config(layer_dict[layer]['config']['config'])(x)

sub_model = Model(inputs=input_tensor, outputs=[x_0, x])

In [14]:
sub_model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 16, 16, 64)] 0                                            
__________________________________________________________________________________________________
conv2_block1_1_conv (Conv2D)    (None, 16, 16, 64)   4160        input_1[0][0]                    
__________________________________________________________________________________________________
conv2_block1_1_bn (BatchNormali (None, 16, 16, 64)   256         conv2_block1_1_conv[0][0]        
__________________________________________________________________________________________________
conv2_block1_1_relu (Activation (None, 16, 16, 64)   0           conv2_block1_1_bn[0][0]          
______________________________________________________________________________________________

In [15]:
mse = tf.keras.losses.MeanSquaredError()
mse_metric_0 = tf.keras.metrics.MeanSquaredError()
mse_metric_1 = tf.keras.metrics.MeanSquaredError()

In [16]:
for layer in prune_layers:
    for rep_layer in sub_model.layers:
        if layer == rep_layer.name:
            rep_layer.set_weights(layer_dict[layer]['weights'])

In [17]:
@tf.function
def calculate_mse(mse_metric_0, mse_metric_1, mse, train_dataset, output_model, sub_model):

    for  (x_batch, y_batch) in train_dataset:

        x, y_0, y_1 = output_model(x_batch)
        pre_0, pre_1 = sub_model(x)
        mse_metric_0.update_state(y_0, pre_0)
        mse_metric_1.update_state(y_1, pre_1)

In [18]:
sub_model.layers

[<tensorflow.python.keras.engine.input_layer.InputLayer at 0x7fe3104811d0>,
 <tensorflow.python.keras.layers.convolutional.Conv2D at 0x7fe3104813c8>,
 <tensorflow.python.keras.layers.normalization_v2.BatchNormalization at 0x7fe3104816d8>,
 <tensorflow.python.keras.layers.core.Activation at 0x7fe31047d860>,
 <tensorflow.python.keras.layers.convolutional.Conv2D at 0x7fe31047d780>,
 <tensorflow.python.keras.layers.normalization_v2.BatchNormalization at 0x7fe31049fc50>,
 <tensorflow.python.keras.layers.convolutional.Conv2D at 0x7fe3104a3e10>,
 <tensorflow.python.keras.layers.core.Activation at 0x7fe3104a3f28>,
 <tensorflow.python.keras.layers.normalization_v2.BatchNormalization at 0x7fe44ffda278>]

In [19]:
%%time
import numpy as np
# save original weights
og_conv_weights = sub_model.layers[1].get_weights()
og_bn_weights = sub_model.layers[2].get_weights()

filter_scores = []
#prune and score each filter by MSE error
for filters in range(64):
    
    # get inital weights
    conv_weights = sub_model.layers[1].get_weights()
    bn_weights = sub_model.layers[2].get_weights()
    
    #prun conv layer
    conv_weights[0][:,:,:,filters] = np.zeros(conv_weights[0][:,:,:,filters].shape)
    conv_weights[1][filters] = 0
    
    #prune batch norm variables
    for weight in bn_weights:
        weight[filters] = 0
    
    # reassign into submodel
    sub_model.layers[1].set_weights(conv_weights)
    sub_model.layers[2].set_weights(bn_weights)
    
    #calculate MSE over dataset and show result
    calculate_mse(mse_metric_0, mse_metric_1, mse, train_dataset, output_model, sub_model)
    l2_norm = np.linalg.norm(og_conv_weights[0][:,:,:,filters])
    result_0 = mse_metric_0.result().numpy()
    result_1 = mse_metric_1.result().numpy()
    filter_scores.append({'filter': filters, 'mse': result_0, 'mse_1': result_1, 'L2': l2_norm})
    mse_metric_0.reset_states()
    mse_metric_1.reset_states()


CPU times: user 19min 48s, sys: 1min 42s, total: 21min 30s
Wall time: 1min 50s


In [20]:
model = tf.keras.models.load_model('../cifar10.h5')
m = tf.keras.metrics.CategoricalCrossentropy()
acc = tf.keras.metrics.Accuracy()

In [21]:
for (x_batch, y_batch) in test_dataset:
    logits = model(x_batch)
    m.update_state(logits, y_batch)
    
    prediction = tf.argmax(logits, axis=1)
    labels = tf.argmax(y_batch, axis=1)
    acc.update_state(prediction, labels)

In [22]:
f"Loss {m.result().numpy()}, Accuracy: {acc.result()}"

'Loss 1.6843411922454834, Accuracy: 0.9035999774932861'

In [26]:
import copy
import math
mse_sorted_filters = sorted(filter_scores, key=lambda x: x['mse'])
mse_prune_filters = [x['filter'] for x in mse_sorted_filters[:math.floor(64*.50)]]

mse_sorted_filters_1 = sorted(filter_scores, key=lambda x: x['mse_1'])
mse_prune_filters_1 = [x['filter'] for x in mse_sorted_filters_1[:math.floor(64*.60)]]

mse_sorted_filters_2 = sorted(filter_scores, key=lambda x: x['mse_1'] + x['mse'])
mse_prune_filters_2 = [x['filter'] for x in mse_sorted_filters_2[:math.floor(64*.60)]]

In [27]:
mse_conv_weights = copy.deepcopy(og_conv_weights)
mse_bn_weights = copy.deepcopy(og_bn_weights)
for index in mse_prune_filters_2:
    mse_conv_weights[0][:,:,:, index] = np.zeros(conv_weights[0][:,:,:,index].shape)
    mse_conv_weights[1][index] = 0
    
    for weight in mse_bn_weights:
        weight[index] = 0

In [28]:
m.reset_states()
acc.reset_states()

model.layers[7].set_weights(mse_conv_weights)

model.layers[8].set_weights(mse_bn_weights)

for (x_batch, y_batch) in train_dataset:
    logits = model(x_batch)
    m.update_state(logits, y_batch)
    
    prediction = tf.argmax(logits, axis=1)
    labels = tf.argmax(y_batch, axis=1)
    acc.update_state(prediction, labels)

print(f"Loss {m.result().numpy()}, Accuracy: {acc.result()}")
m.reset_states()
acc.reset_states()

for (x_batch, y_batch) in test_dataset:
    logits = model(x_batch)
    m.update_state(logits, y_batch)
    
    prediction = tf.argmax(logits, axis=1)
    labels = tf.argmax(y_batch, axis=1)
    acc.update_state(prediction, labels)
print(f"Loss {m.result().numpy()}, Accuracy: {acc.result()}")

model.load_weights("../cifar10.h5")

Loss 0.8693946599960327, Accuracy: 0.9563800096511841
Loss 2.928887128829956, Accuracy: 0.8289999961853027


In [29]:
import math
l2_sorted_filters = sorted(filter_scores, key=lambda x: x['L2'])
l2_prune_filters = [x['filter'] for x in l2_sorted_filters[:math.floor(64*.60)]]

l2_conv_weights = copy.deepcopy(og_conv_weights)
l2_bn_weights = copy.deepcopy(og_bn_weights)
for index in l2_prune_filters:
    l2_conv_weights[0][:,:,:, index] = np.zeros(conv_weights[0][:,:,:,index].shape)
    l2_conv_weights[1][index] = 0
    
    for weight in l2_bn_weights:
        weight[index] = 0

In [30]:
m.reset_states()
acc.reset_states()

model.layers[7].set_weights(l2_conv_weights)

model.layers[8].set_weights(l2_bn_weights)

for (x_batch, y_batch) in train_dataset:
    logits = model(x_batch)
    m.update_state(logits, y_batch)
    
    prediction = tf.argmax(logits, axis=1)
    labels = tf.argmax(y_batch, axis=1)
    acc.update_state(prediction, labels)

print(f"Loss {m.result().numpy()}, Accuracy: {acc.result()}")
m.reset_states()
acc.reset_states()

for (x_batch, y_batch) in test_dataset:
    logits = model(x_batch)
    m.update_state(logits, y_batch)
    
    prediction = tf.argmax(logits, axis=1)
    labels = tf.argmax(y_batch, axis=1)
    acc.update_state(prediction, labels)
print(f"Loss {m.result().numpy()}, Accuracy: {acc.result()}")



Loss 0.1719799041748047, Accuracy: 0.993340015411377
Loss 2.1451756954193115, Accuracy: 0.875
