## Prune a pre-trained model

Following this [tutorial](https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras)

In [1]:
import numpy as np
import tensorflow_model_optimization as tfmot

import tempfile

from keras.models import load_model
from keras.optimizers import RMSprop

import keras.backend as K
import tensorflow as tf

### Get Data and Load Model

In [2]:
X_train = np.load('../train_and_val/X_train.npy')
X_val = np.load('../train_and_val/X_val.npy')
y_train = np.load('../train_and_val/y_train.npy')
y_val = np.load('../train_and_val/y_val.npy')

In [3]:
print('Maximum Scaled Duration for X_train: {}'.format(X_train[:, :, -1].max()))
print('Maximum Scaled Duration for X_val: {}'.format(X_val[:, :, -1].max()))
print('Maximum Scaled Duration for y_train: {}'.format(y_train[:, -1].max()))
print('Maximum Scaled Duration for y_val: {}'.format(y_val[:, -1].max()))

Maximum Scaled Duration for X_train: 0.860215053763441
Maximum Scaled Duration for X_val: 0.5516975308641976
Maximum Scaled Duration for y_train: 1.0
Maximum Scaled Duration for y_val: 0.9166666666666667


In [4]:
print('Train-Validation Ratio of the Mean of the Scaled Duration: ', y_train[:, -1].mean() / y_val[:, -1].mean())
print('Train-Validation Ratio of the Stdv of the Scaled Duration: ', y_train[:, -1].std() / y_val[:, -1].std())

Train-Validation Ratio of the Mean of the Scaled Duration:  0.9898257021922111
Train-Validation Ratio of the Stdv of the Scaled Duration:  0.9209054376139967


In [5]:
def maestro_loss_wr(harshness): 
    """A loss function which, in addition to penalizing for misclassification on the 
    first n_keys_piano elements, includes a term proportional to the relative
    error in the prediction of the last element (which repesents the duration). 
    The proportionality constant is the 'harshness' of the maestro in regards to
    timing."""
    def maestro_loss(ytrue, ypred):
        # Standard binary cross-entropy
        bce_loss = - K.mean(ytrue[:, :-1] * K.log(ypred[:, :-1]) + (1 - ytrue[:, :-1]) * \
                     K.log(1 - ypred[:, :-1]))

        # Duration error term
        dur_loss = 2 * harshness * K.mean(K.abs((ytrue[:, -1] - ypred[:, -1]) / \
                                      (ytrue[:, -1] + ypred[:, -1] + K.epsilon())))
        
        if (dur_loss > bce_loss):   # Often times, ytrue[:, -1] elements will be zero
            return bce_loss * 2     # This may spike dur_loss. To control, I limit it
                                    # so that it never exceeds the bce_loss.
        return bce_loss + dur_loss
    
    return maestro_loss

def precision_mod(ytrue, ypred):
    """Just a modified precision excluding the last element (which is not a classification)"""

    true_positives = K.sum(K.round(ytrue[:, :-1] * ypred[:, :-1]))
    pred_positives = K.sum(K.round(ypred[:, :-1]))
    return true_positives / (pred_positives + K.epsilon())

def recall_mod(ytrue, ypred):
    """Just a modified recall excluding the last element (which is not a classification)"""

    true_positives = K.sum(K.round(ytrue[:, :-1] * ypred[:, :-1]))
    poss_positives = K.sum(ytrue[:, :-1])
    return true_positives / (poss_positives + K.epsilon())

def f1_score_mod(ytrue, ypred):
    """Just a modified f1_score excluding the last element (which is not a classification)"""

    precision = precision_mod(ytrue, ypred)
    recall = recall_mod(ytrue, ypred)   
    return 2 * (precision * recall) / (precision + recall + K.epsilon())

def dur_error(ytrue, ypred):
    """A new metric that only gives information on the error in duration predictions"""
    
    return 2 * K.mean(K.abs((ytrue[:, -1] - ypred[:, -1]) / (ytrue[:, -1] + ypred[:, -1] + \
                                                         K.epsilon())))

def maestro_dur_loss_wr(harshness):
    """The second term of the maestro loss, based purely on error in duration predictions.
    To be used as a metric in order to decompose the loss components during analysis"""
    def maestro_dur_loss(ytrue, ypred):

        return 2 * harshness * K.mean(K.abs((ytrue[:, -1] - ypred[:, -1]) / \
                                      (ytrue[:, -1] + ypred[:, -1] + K.epsilon())))
    return maestro_dur_loss

In [6]:
harshness = 0.05

In [7]:
def load_model_from_file(file_path, harshness = harshness):
    
    custom_objects = { 'maestro_loss': maestro_loss_wr(harshness), \
        'f1_score_mod': f1_score_mod, 'recall_mod': recall_mod, \
        'precision_mod': precision_mod, 'dur_error': dur_error, \
        'maestro_dur_loss': maestro_dur_loss_wr(harshness)}

    return load_model(file_path, custom_objects = custom_objects)

In [8]:
opt = RMSprop(lr = 0.0005, clipvalue = 0.2)

In [9]:
def prune_model(model, batch_size = 512, epochs = 2, initial_sparsity = 0.5, final_sparsity = 0.8):
    
    prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
    end_step = np.ceil(X_train.shape[0] / batch_size) * epochs
    
    # Define model for pruning.
    pruning_params = {
          'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
              initial_sparsity = initial_sparsity, final_sparsity = final_sparsity,
              begin_step=0, end_step=end_step)
    }
    model_for_pruning = prune_low_magnitude(model, **pruning_params)
    
    # `prune_low_magnitude` requires a recompile.
    model_for_pruning.compile(loss = maestro_loss_wr(harshness), 
                          optimizer = opt, 
                          metrics = [f1_score_mod, recall_mod, precision_mod, \
                                     dur_error, maestro_dur_loss_wr(harshness)])

    model_for_pruning.summary()
    
    logdir = tempfile.mkdtemp()

    callbacks = [
      tfmot.sparsity.keras.UpdatePruningStep(),
      tfmot.sparsity.keras.PruningSummaries(log_dir = logdir),
    ]

    model_for_pruning.fit(X_train, y_train, batch_size = batch_size, epochs = epochs, 
                      validation_data = (X_val, y_val), verbose = 2, callbacks = callbacks)
    
    return model

In [10]:
model = load_model_from_file('../models/best_maestro_model_2_1_512_0pt4_lr_5e-04_cv_0pt2.h5')
pruned_model = prune_model(model)

Instructions for updating:
Please use `layer.add_weight` method instead.
Model: "sequential_10"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
prune_low_magnitude_lstm_20  (None, 16, 512)           2463747   
_________________________________________________________________
prune_low_magnitude_dropout_ (None, 16, 512)           1         
_________________________________________________________________
prune_low_magnitude_lstm_21  (None, 512)               4196355   
_________________________________________________________________
prune_low_magnitude_dropout_ (None, 512)               1         
_________________________________________________________________
prune_low_magnitude_dense_20 (None, 256)               262402    
_________________________________________________________________
prune_low_magnitude_activati (None, 256)               1         
______________________________________________

In [11]:
model = load_model_from_file('../models/best_maestro_model_2_1_512_0pt4_lr_5e-04_cv_0pt2.h5')
pruned_model = prune_model(model, epochs = 25)

Model: "sequential_10"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
prune_low_magnitude_lstm_20  (None, 16, 512)           2463747   
_________________________________________________________________
prune_low_magnitude_dropout_ (None, 16, 512)           1         
_________________________________________________________________
prune_low_magnitude_lstm_21  (None, 512)               4196355   
_________________________________________________________________
prune_low_magnitude_dropout_ (None, 512)               1         
_________________________________________________________________
prune_low_magnitude_dense_20 (None, 256)               262402    
_________________________________________________________________
prune_low_magnitude_activati (None, 256)               1         
_________________________________________________________________
prune_low_magnitude_dropout_ (None, 256)             

Epoch 24/25
50/50 - 244s - loss: nan - f1_score_mod: nan - recall_mod: nan - precision_mod: nan - dur_error: nan - maestro_dur_loss: nan - val_loss: nan - val_f1_score_mod: nan - val_recall_mod: nan - val_precision_mod: nan - val_dur_error: nan - val_maestro_dur_loss: nan
Epoch 25/25
50/50 - 246s - loss: nan - f1_score_mod: nan - recall_mod: nan - precision_mod: nan - dur_error: nan - maestro_dur_loss: nan - val_loss: nan - val_f1_score_mod: nan - val_recall_mod: nan - val_precision_mod: nan - val_dur_error: nan - val_maestro_dur_loss: nan
