In [None]:
import tensorflow as tf
from tensorflow.keras import models
import tensorflow_datasets as tfds
from tensorflow.keras.applications.resnet import preprocess_input
from tensorflow.keras import Model, Input
physical_devices = tf.config.list_physical_devices('GPU')
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)


from tensorflow.keras.layers import Conv2D, ReLU, BatchNormalization, InputLayer, Layer

#used to fix bug in keras preprocessing scope
temp = tf.zeros([4, 32, 32, 3])  # Or tf.zeros
preprocess_input(temp)
print("processed")

In [2]:
IMAGE_SIZE = (64, 64)
TRAIN_SIZE = 50000
VALIDATION_SIZE = 10000
BATCH_SIZE = 256
NUM_CLASSES = 10

In [3]:
@tf.function
def normalize(input_image):
  return preprocess_input(input_image)

@tf.function
def load_image_test(datapoint):
  input_image, label = tf.image.resize(datapoint["image"], IMAGE_SIZE), datapoint['label']
  #input_image = preprocess_input(input_image)

  input_image = normalize(input_image)

  return input_image, tf.one_hot(label, depth=NUM_CLASSES)

dataset, info = tfds.load('cifar10', with_info=True)
train = dataset['train'].map(load_image_test, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train.batch(BATCH_SIZE)
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)


test_dataset = dataset['test'].map(load_image_test, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE)
test_dataset = test_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

In [4]:
model = models.load_model('/tf/notebooks/cifar10.h5')

In [5]:
model.freeze = True

In [6]:
#model.summary()

In [7]:
output_layer_names = [ "conv4_block2_2_relu", "conv4_block2_3_conv"]
outputs = []
for layer in model.layers:
    if layer.name in output_layer_names:
        outputs.append(layer.output)

In [8]:
output_model = Model(inputs=model.input, outputs=outputs)

In [9]:
#output_model.summary()

In [17]:
output_model.save('/tmp/output_model.h5')
del model
tf.keras.backend.clear_session()
output_model = tf.keras.models.load_model('/tmp/output_model.h5')





In [10]:
prune_layers = [
                "conv4_block2_2_conv",
                "conv4_block2_2_bn",
                "conv4_block2_2_relu",
                "conv4_block2_3_conv"
                ]
layer_dict = {}
for layer in output_model.layers:
    if layer.name in prune_layers:
        layer_dict[layer.name] = {"weights":layer.get_weights()}

In [11]:
for x_batch, y_batch in train_dataset.take(1):
    x, y = output_model(x_batch)
    print(x.shape, y.shape)

(256, 4, 4, 256) (256, 4, 4, 1024)


In [12]:

model_config = output_model.get_config()
for layer in prune_layers:
    for layer_config in model_config['layers']:
        if layer_config['config']['name'] == layer:
            layer_dict[layer]['config'] =  layer_config

In [13]:
input_tensor = Input(shape=(4, 4, 256))
layer_class = getattr(tf.keras.layers, layer_dict[prune_layers[0]]['config']['class_name'])
x = layer_class.from_config(layer_dict[prune_layers[0]]['config']['config'])(input_tensor)
x_0 = None
for layer in prune_layers[1::]:
    layer_class = getattr(tf.keras.layers, layer_dict[layer]['config']['class_name'])

    x = layer_class.from_config(layer_dict[layer]['config']['config'])(x)

sub_model = Model(inputs=input_tensor, outputs=[ x])

In [14]:
sub_model.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 4, 4, 256)]       0         
_________________________________________________________________
conv4_block2_2_conv (Conv2D) (None, 4, 4, 256)         590080    
_________________________________________________________________
conv4_block2_2_bn (BatchNorm (None, 4, 4, 256)         1024      
_________________________________________________________________
conv4_block2_2_relu (Activat (None, 4, 4, 256)         0         
_________________________________________________________________
conv4_block2_3_conv (Conv2D) (None, 4, 4, 1024)        263168    
Total params: 854,272
Trainable params: 853,760
Non-trainable params: 512
_________________________________________________________________


In [15]:
mse = tf.keras.losses.MeanSquaredError()
mse_metric_0 = tf.keras.metrics.MeanSquaredError()

In [16]:
for layer in prune_layers:
    for rep_layer in sub_model.layers:
        if layer == rep_layer.name:
            rep_layer.set_weights(layer_dict[layer]['weights'])

In [17]:
@tf.function
def calculate_mse(mse_metric_0, mse, train_dataset, output_model, sub_model):

    for  (x_batch, y_batch) in train_dataset:

        x, y_0= output_model(x_batch)
        pre_0 = sub_model(x)
        mse_metric_0.update_state(y_0, pre_0)
        

In [18]:
sub_model.layers

[<tensorflow.python.keras.engine.input_layer.InputLayer at 0x7fd8b00446a0>,
 <tensorflow.python.keras.layers.convolutional.Conv2D at 0x7fd8b0044320>,
 <tensorflow.python.keras.layers.normalization_v2.BatchNormalization at 0x7fd8b004b470>,
 <tensorflow.python.keras.layers.core.Activation at 0x7fd86404fd30>,
 <tensorflow.python.keras.layers.convolutional.Conv2D at 0x7fd85c7db630>]

In [29]:
%%time
import numpy as np
# save original weights
og_conv_weights = sub_model.layers[1].get_weights()
og_bn_weights = sub_model.layers[2].get_weights()

filter_scores = []
#prune and score each filter by MSE error
for filters in range(256):
    
    # get inital weights
    conv_weights = sub_model.layers[1].get_weights()
    bn_weights = sub_model.layers[2].get_weights()
    
    #prun conv layer
    conv_weights[0][:,:,:,filters] = np.zeros(conv_weights[0][:,:,:,filters].shape)
    conv_weights[1][filters] = 0
    
    #prune batch norm variables
    for weight in bn_weights:
        weight[filters] = 0
    
    # reassign into submodel
    sub_model.layers[1].set_weights(conv_weights)
    sub_model.layers[2].set_weights(bn_weights)
    
    #calculate MSE over dataset and show result
    calculate_mse(mse_metric_0,  mse, train_dataset, output_model, sub_model)
    l2_norm = np.linalg.norm(og_conv_weights[0][:,:,:,filters])
    result_0 = mse_metric_0.result().numpy()
    filter_scores.append({'filter': filters, 'mse': result_0, 'L2': l2_norm})
    mse_metric_0.reset_states()


CPU times: user 1h 31min 30s, sys: 7min 33s, total: 1h 39min 3s
Wall time: 20min 9s


In [30]:
model = tf.keras.models.load_model('../cifar10.h5')
m = tf.keras.metrics.CategoricalCrossentropy()
acc = tf.keras.metrics.Accuracy()

In [31]:
for (x_batch, y_batch) in test_dataset:
    logits = model(x_batch)
    m.update_state(logits, y_batch)
    
    prediction = tf.argmax(logits, axis=1)
    labels = tf.argmax(y_batch, axis=1)
    acc.update_state(prediction, labels)

In [32]:
f"Loss {m.result().numpy()}, Accuracy: {acc.result()}"

'Loss 1.6843410730361938, Accuracy: 0.9035999774932861'

In [40]:
import copy
import math
mse_sorted_filters = sorted(filter_scores, key=lambda x: x['mse'])
mse_prune_filters = [x['filter'] for x in mse_sorted_filters[:math.floor(64*.95)]]



In [41]:
mse_conv_weights = copy.deepcopy(og_conv_weights)
mse_bn_weights = copy.deepcopy(og_bn_weights)
for index in mse_prune_filters:
    mse_conv_weights[0][:,:,:, index] = np.zeros(conv_weights[0][:,:,:,index].shape)
    mse_conv_weights[1][index] = 0
    
    for weight in mse_bn_weights:
        weight[index] = 0

In [42]:
m.reset_states()
acc.reset_states()

for layer in model.layers:
    if layer.name == "conv4_block2_2_conv":
        layer.set_weights(mse_conv_weights)
    elif layer.name == "conv4_block2_2_bn":
        layer.set_weights(mse_bn_weights)

for (x_batch, y_batch) in train_dataset:
    logits = model(x_batch)
    m.update_state(logits, y_batch)
    
    prediction = tf.argmax(logits, axis=1)
    labels = tf.argmax(y_batch, axis=1)
    acc.update_state(prediction, labels)

print(f"Loss {m.result().numpy()}, Accuracy: {acc.result()}")
m.reset_states()
acc.reset_states()

for (x_batch, y_batch) in test_dataset:
    logits = model(x_batch)
    m.update_state(logits, y_batch)
    
    prediction = tf.argmax(logits, axis=1)
    labels = tf.argmax(y_batch, axis=1)
    acc.update_state(prediction, labels)
print(f"Loss {m.result().numpy()}, Accuracy: {acc.result()}")

model.load_weights("../cifar10.h5")

Loss 0.002582801505923271, Accuracy: 1.0
Loss 1.7206816673278809, Accuracy: 0.9010000228881836


In [43]:
import math
l2_sorted_filters = sorted(filter_scores, key=lambda x: x['L2'])
l2_prune_filters = [x['filter'] for x in l2_sorted_filters[:math.floor(64*.95)]]

l2_conv_weights = copy.deepcopy(og_conv_weights)
l2_bn_weights = copy.deepcopy(og_bn_weights)
for index in l2_prune_filters:
    l2_conv_weights[0][:,:,:, index] = np.zeros(conv_weights[0][:,:,:,index].shape)
    l2_conv_weights[1][index] = 0
    
    for weight in l2_bn_weights:
        weight[index] = 0

In [44]:
m.reset_states()
acc.reset_states()

for layer in model.layers:
    if layer.name == "conv4_block2_2_conv":
        layer.set_weights(l2_conv_weights)
    elif layer.name == "conv4_block2_2_bn":
        layer.set_weights(l2_bn_weights)


for (x_batch, y_batch) in train_dataset:
    logits = model(x_batch)
    m.update_state(logits, y_batch)
    
    prediction = tf.argmax(logits, axis=1)
    labels = tf.argmax(y_batch, axis=1)
    acc.update_state(prediction, labels)

print(f"Loss {m.result().numpy()}, Accuracy: {acc.result()}")
m.reset_states()
acc.reset_states()

for (x_batch, y_batch) in test_dataset:
    logits = model(x_batch)
    m.update_state(logits, y_batch)
    
    prediction = tf.argmax(logits, axis=1)
    labels = tf.argmax(y_batch, axis=1)
    acc.update_state(prediction, labels)
print(f"Loss {m.result().numpy()}, Accuracy: {acc.result()}")



Loss 0.0021090859081596136, Accuracy: 1.0
Loss 1.7103685140609741, Accuracy: 0.901199996471405
