# This notebook prunes the PTM using tensorflow model optimization package

In [1]:
from keras.models import load_model
import tempfile
import os
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Layer
from tensorflow import keras

%load_ext tensorboard

-----------------------------------------
keras-unet init: TF version is >= 2.0.0 - using `tf.keras` instead of `Keras`
-----------------------------------------


In [2]:
#define custom objects
def jaccard_distance(y_true, y_pred, smooth=100):
        intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
        sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
        jac = (intersection + smooth) / (sum_ - intersection + smooth)
        return (1 - jac) * smooth

def dice_coef(y_true, y_pred):
    smooth = 1.0
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (
                K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_coef_loss(self, y_true, y_pred):
    loss = 1 - self._dice_coef(y_true, y_pred)
    return loss

class GELU(keras.layers.Layer):
    def __init__(self, **kwargs):
        super(GELU, self).__init__(**kwargs)
    
    def call(self, inputs):
        return keras.activations.gelu(inputs)

In [3]:
model = keras.models.load_model(r"C:\Users\UAB\Segmentation - Main1_CK\Human Model\Keras\runet_kid_best_train.h5", custom_objects={
                       'jaccard_distance': jaccard_distance,
                       'dice_coef_loss': dice_coef_loss,
                       'dice_coef': dice_coef}) #load the weights

### We load the PTM and start the model with 50% sparsity (50% zeros in weights) and end with 80% sparsity.

In [71]:
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after 2 epochs.
batch_size = 32
epochs = 2
validation_split = 0.1 # 10% of training set will be used for validation set. 

#num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(69000 / batch_size).astype(np.int32) * epochs

# Define model for pruning.
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                               final_sparsity=0.80,
                                                               begin_step=0,
                                                               end_step=end_step)
}

model_for_pruning = prune_low_magnitude(model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
              loss=keras_unet.losses.jaccard_distance,
              metrics=['dice_coef'])

model_for_pruning.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 128, 128, 1  0           []                               
                                )]                                                                
                                                                                                  
 prune_low_magnitude_runetpp_do  (None, 128, 128, 64  1154       ['input_1[0][0]']                
 wn0_0 (PruneLowMagnitude)      )                                                                 
                                                                                                  
 prune_low_magnitude_runetpp_do  (None, 128, 128, 64  257        ['prune_low_magnitude_runetpp_dow
 wn0_0_bn (PruneLowMagnitude)   )                                n0_0[0][0]']                 

 wn3_conv_0 (PruneLowMagnitude)                                  n3_encode_maxpool[0][0]']        
                                                                                                  
 prune_low_magnitude_runetpp_do  (None, 16, 16, 512)  2049       ['prune_low_magnitude_runetpp_dow
 wn3_conv_0_bn (PruneLowMagnitu                                  n3_conv_0[0][0]']                
 de)                                                                                              
                                                                                                  
 prune_low_magnitude_runetpp_do  (None, 16, 16, 512)  1          ['prune_low_magnitude_runetpp_dow
 wn3_conv_0_activation (PruneLo                                  n3_conv_0_bn[0][0]']             
 wMagnitude)                                                                                      
                                                                                                  
 prune_low

 prune_low_magnitude_runetpp_fu  (None, 32, 32, 256)  1025       ['prune_low_magnitude_runetpp_fus
 sion_conv_0_0_bn (PruneLowMagn                                  ion_conv_0_0[0][0]']             
 itude)                                                                                           
                                                                                                  
 prune_low_magnitude_runetpp_fu  (None, 32, 32, 256)  1          ['prune_low_magnitude_runetpp_fus
 sion_conv_0_0_activation (Prun                                  ion_conv_0_0_bn[0][0]']          
 eLowMagnitude)                                                                                   
                                                                                                  
 prune_low_magnitude_runetpp_up  (None, 64, 64, 64)  589890      ['prune_low_magnitude_runetpp_dow
 _1_en0_trans_conv (PruneLowMag                                  n3_conv_1_activation[0][0]']     
 nitude)  

 eLowMagnitude)                                                                                   
                                                                                                  
 prune_low_magnitude_runetpp_up  (None, 128, 128, 64  589890     ['prune_low_magnitude_runetpp_dow
 _2_en0_trans_conv (PruneLowMag  )                               n3_conv_1_activation[0][0]']     
 nitude)                                                                                          
                                                                                                  
 prune_low_magnitude_runetpp_up  (None, 128, 128, 64  294978     ['prune_low_magnitude_runetpp_fus
 _2_en1_trans_conv (PruneLowMag  )                               ion_conv_0_0_activation[0][0]']  
 nitude)                                                                                          
                                                                                                  
 prune_low

 sion_conv_2_0_bn (PruneLowMagn  6)                              ion_conv_2_0[0][0]']             
 itude)                                                                                           
                                                                                                  
 prune_low_magnitude_runetpp_fu  (None, 128, 128, 25  1          ['prune_low_magnitude_runetpp_fus
 sion_conv_2_0_activation (Prun  6)                              ion_conv_2_0_bn[0][0]']          
 eLowMagnitude)                                                                                   
                                                                                                  
 prune_low_magnitude_runetpp_ou  (None, 128, 128, 2)  9220       ['prune_low_magnitude_runetpp_fus
 tput_final (PruneLowMagnitude)                                  ion_conv_2_0_activation[0][0]']  
                                                                                                  
 prune_low

### The above model is then fine tuned (trained again) with small learning rate using the code in the PTM Train and Prediction
Return to this notebook after fine tuning and accuracy evaluation


### More compessed model using strip function

In [36]:
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

_, pruned_keras_file = tempfile.mkstemp('.h5')
tf.keras.models.save_model(model_for_export, pruned_keras_file, include_optimizer=False)
print('Saved pruned Keras model to:', pruned_keras_file)

Saved pruned Keras model to: C:\Users\UAB\AppData\Local\Temp\tmp6igz8zlr.h5


In [37]:
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
pruned_tflite_model = converter.convert()

_, pruned_tflite_file = tempfile.mkstemp('.tflite')

with open(pruned_tflite_file, 'wb') as f:
  f.write(pruned_tflite_model)

print('Saved pruned TFLite model to:', pruned_tflite_file)

INFO:tensorflow:Assets written to: C:\Users\UAB\AppData\Local\Temp\tmp2fla3g3s\assets
Saved pruned TFLite model to: C:\Users\UAB\AppData\Local\Temp\tmpo80i0q7e.tflite


In [38]:
#Define a helper function to actually compress the models via gzip and measure the zipped size.
def get_gzipped_model_size(file):
  # Returns size of gzipped model, in bytes.
  import os
  import zipfile

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)

In [40]:
#print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped pruned Keras model: %.2f bytes" % (get_gzipped_model_size(pruned_keras_file)))
print("Size of gzipped pruned TFlite model: %.2f bytes" % (get_gzipped_model_size(pruned_tflite_file)))

Size of gzipped pruned Keras model: 29257481.00 bytes
Size of gzipped pruned TFlite model: 29168838.00 bytes


### Applying post-training quantization to the pruned model for additional benefits.

In [42]:
converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_and_pruned_tflite_model = converter.convert()

_, quantized_and_pruned_tflite_file = tempfile.mkstemp('.tflite')

with open(quantized_and_pruned_tflite_file, 'wb') as f:
  f.write(quantized_and_pruned_tflite_model)

print('Saved quantized and pruned TFLite model to:', quantized_and_pruned_tflite_file)

#print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped pruned and quantized TFlite model: %.2f bytes" % (get_gzipped_model_size(quantized_and_pruned_tflite_file)))

INFO:tensorflow:Assets written to: C:\Users\UAB\AppData\Local\Temp\tmpffhhsv0a\assets


INFO:tensorflow:Assets written to: C:\Users\UAB\AppData\Local\Temp\tmpffhhsv0a\assets


Saved quantized and pruned TFLite model to: C:\Users\UAB\AppData\Local\Temp\tmpj87xpfhp.tflite
Size of gzipped pruned and quantized TFlite model: 7009759.00 bytes


### Structured Model pruning for additional benefits

In [9]:
DLModel = model

In [12]:
import numpy as np

w = DLModel.get_weights()

def prune_weight_matrix(matrix, p):
    a = len(matrix)
    N = a * (1000 - p) // 1000
    N = int(round(N))  # N = no. of elements that should remain after pruning
    final_list = []
    l2 = matrix.flatten()

    for i in range(0, N):  # A loop to append final_list with N highest weights of the given matrix
        max1 = 0
        ind = 0

        for j in range(len(l2)):
            if abs(l2[j]) > abs(max1):
                max1 = l2[j]
                ind = j
        l2 = np.delete(l2, ind)
        final_list.append(max1)

    t = min(map(abs, final_list))  # t is the threshold i.e. the minimum absolute value of the weight which is allowed after pruning, else prune.
    for k in range(len(matrix.flatten())):
        if abs(matrix.flatten()[k]) < t:
            matrix.flatten()[k] = 0
    return matrix.reshape(matrix.shape)

# Main
encoder_weights = []
decoder_weights = []

# Separate encoder and decoder weights
for i, weight in enumerate(w):
    if i < len(w) // 2:  # Assuming half of the weights correspond to the encoder
        encoder_weights.append(weight)
    else:
        decoder_weights.append(weight)

pruned_encoder_weights = []
for enc_weight in encoder_weights:
    if len(enc_weight) > 100:  # Prune the weight matrix if it has more than 100 elements
        pruned_enc_weight = prune_weight_matrix(enc_weight, 99)
        pruned_encoder_weights.append(pruned_enc_weight)
    else:
        pruned_encoder_weights.append(enc_weight)

wprunedfinal = pruned_encoder_weights + decoder_weights

pmodel = DLModel
pmodel.set_weights(wprunedfinal)


### The above model is then fine tuned (trained again) with small learning rate using the code in the PTM Train and Prediction
Return to this notebook after fine tuning and accuracy evaluation

# Experimenting with pruning only certain layers (customized)

In [56]:
# Apply prune_low_magnitude to the desired layer(s) of the model
# Define the layers to be pruned
layers_to_prune = ['runetpp_down_from0_to0_0_bn', 'runetpp_down_from0_to1_0_bn', 'runetpp_down_from0_to2_0_bn', 'runetpp_down_from0_to0_0_activation', 'runetpp_down_from0_to1_0_activation', 'runetpp_down_from0_to2_0_activation', 'runetpp_concat_0', 'runetpp_fusion_conv_0_0', 'runetpp_fusion_conv_0_0_bn', 'runetpp_fusion_conv_0_0_activation', 'runetpp_up_1_en0_trans_conv', 'runetpp_up_1_en1_trans_conv', 'runetpp_up_1_en0_bn', 'runetpp_up_1_en1_bn']  # Replace with the names or indices of the layers you want to prune

# Apply pruning to the specified layers
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.50, 0),
    'block_size': (1, 1)
}

model_for_pruning = model
for layer_name in layers_to_prune:
    layer = model_for_pruning.get_layer(layer_name)
    pruned_layer = tfmot.sparsity.keras.prune_low_magnitude(layer, **pruning_params)
    model_for_pruning.get_layer(layer_name).set_weights(pruned_layer.get_weights())

# Compile the pruned model
model_for_pruning.compile(optimizer='adam',
              loss=keras_unet.losses.jaccard_distance,
              metrics=['dice_coef'])

model_for_pruning.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           multiple             0           []                               
                                                                                                  
 runetpp_down0_0 (Conv2D)       (None, 128, 128, 64  576         ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 runetpp_down0_0_bn (BatchNorma  (None, 128, 128, 64  256        ['runetpp_down0_0[0][0]']        
 lization)                      )                                                                 
                                                                                              

 runetpp_down3_conv_1_activatio  (None, 16, 16, 512)  0          ['runetpp_down3_conv_1_bn[0][0]']
 n (ReLU)                                                                                         
                                                                                                  
 runetpp_up_0_en0_trans_conv (C  (None, 32, 32, 64)  294976      ['runetpp_down3_conv_1_activation
 onv2DTranspose)                                                 [0][0]']                         
                                                                                                  
 runetpp_up_0_en0_bn (BatchNorm  (None, 32, 32, 64)  256         ['runetpp_up_0_en0_trans_conv[0][
 alization)                                                      0]']                             
                                                                                                  
 runetpp_up_0_en0_activation (R  (None, 32, 32, 64)  0           ['runetpp_up_0_en0_bn[0][0]']    
 eLU)     

 2D)                                                             [0][0]']                         
                                                                                                  
 runetpp_down_from1_to0_0_bn (B  (None, 64, 64, 64)  256         ['runetpp_down_from1_to0_0[0][0]'
 atchNormalization)                                              ]                                
                                                                                                  
 runetpp_down_from1_to1_0_bn (B  (None, 64, 64, 64)  256         ['runetpp_down_from1_to1_0[0][0]'
 atchNormalization)                                              ]                                
                                                                                                  
 runetpp_down_from1_to2_0_bn (B  (None, 64, 64, 64)  256         ['runetpp_down_from1_to2_0[0][0]'
 atchNormalization)                                              ]                                
          

                                                                                                  
 runetpp_down_from2_to1_0_activ  (None, 128, 128, 64  0          ['runetpp_down_from2_to1_0_bn[0][
 ation (ReLU)                   )                                0]']                             
                                                                                                  
 runetpp_down_from2_to2_0_activ  (None, 128, 128, 64  0          ['runetpp_down_from2_to2_0_bn[0][
 ation (ReLU)                   )                                0]']                             
                                                                                                  
 runetpp_concat_2 (Concatenate)  (None, 128, 128, 19  0          ['runetpp_down_from2_to0_0_activa
                                2)                               tion[0][0]',                     
                                                                  'runetpp_down_from2_to1_0_activa
          

In [59]:
# Define the model.
base_model = model
base_model.load_weights(pretrained_weights) 
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

_, keras_model_file = tempfile.mkstemp('.h5')

# Checkpoint: saving the optimizer is necessary (include_optimizer=True is the default).
model_for_pruning.save(keras_model_file, include_optimizer=True)





In [60]:
# Deserialize model.
with tfmot.sparsity.keras.prune_scope():
  loaded_model = tf.keras.models.load_model(keras_model_file)

loaded_model.summary()





Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 128, 128, 1  0           []                               
                                )]                                                                
                                                                                                  
 prune_low_magnitude_runetpp_do  (None, 128, 128, 64  1154       ['input_1[0][0]']                
 wn0_0 (PruneLowMagnitude)      )                                                                 
                                                                                                  
 prune_low_magnitude_runetpp_do  (None, 128, 128, 64  257        ['prune_low_magnitude_runetpp_dow
 wn0_0_bn (PruneLowMagnitude)   )                                n0_0[0][0]']                 

 wn3_conv_0 (PruneLowMagnitude)                                  n3_encode_maxpool[0][0]']        
                                                                                                  
 prune_low_magnitude_runetpp_do  (None, 16, 16, 512)  2049       ['prune_low_magnitude_runetpp_dow
 wn3_conv_0_bn (PruneLowMagnitu                                  n3_conv_0[0][0]']                
 de)                                                                                              
                                                                                                  
 prune_low_magnitude_runetpp_do  (None, 16, 16, 512)  1          ['prune_low_magnitude_runetpp_dow
 wn3_conv_0_activation (PruneLo                                  n3_conv_0_bn[0][0]']             
 wMagnitude)                                                                                      
                                                                                                  
 prune_low

 prune_low_magnitude_runetpp_fu  (None, 32, 32, 256)  1025       ['prune_low_magnitude_runetpp_fus
 sion_conv_0_0_bn (PruneLowMagn                                  ion_conv_0_0[0][0]']             
 itude)                                                                                           
                                                                                                  
 prune_low_magnitude_runetpp_fu  (None, 32, 32, 256)  1          ['prune_low_magnitude_runetpp_fus
 sion_conv_0_0_activation (Prun                                  ion_conv_0_0_bn[0][0]']          
 eLowMagnitude)                                                                                   
                                                                                                  
 prune_low_magnitude_runetpp_up  (None, 64, 64, 64)  589890      ['prune_low_magnitude_runetpp_dow
 _1_en0_trans_conv (PruneLowMag                                  n3_conv_1_activation[0][0]']     
 nitude)  

 eLowMagnitude)                                                                                   
                                                                                                  
 prune_low_magnitude_runetpp_up  (None, 128, 128, 64  589890     ['prune_low_magnitude_runetpp_dow
 _2_en0_trans_conv (PruneLowMag  )                               n3_conv_1_activation[0][0]']     
 nitude)                                                                                          
                                                                                                  
 prune_low_magnitude_runetpp_up  (None, 128, 128, 64  294978     ['prune_low_magnitude_runetpp_fus
 _2_en1_trans_conv (PruneLowMag  )                               ion_conv_0_0_activation[0][0]']  
 nitude)                                                                                          
                                                                                                  
 prune_low

 sion_conv_2_0_bn (PruneLowMagn  6)                              ion_conv_2_0[0][0]']             
 itude)                                                                                           
                                                                                                  
 prune_low_magnitude_runetpp_fu  (None, 128, 128, 25  1          ['prune_low_magnitude_runetpp_fus
 sion_conv_2_0_activation (Prun  6)                              ion_conv_2_0_bn[0][0]']          
 eLowMagnitude)                                                                                   
                                                                                                  
 prune_low_magnitude_runetpp_ou  (None, 128, 128, 2)  9220       ['prune_low_magnitude_runetpp_fus
 tput_final (PruneLowMagnitude)                                  ion_conv_2_0_activation[0][0]']  
                                                                                                  
 prune_low