## Prune a pre-trained model

This notebook follows the TensorFlow Model Optimization [pruning with Keras tutorial](https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras) to prune a pre-trained Keras model.

In [1]:
import numpy as np
import tensorflow_model_optimization as tfmot

import tempfile

from keras.models import load_model
from keras.optimizers import RMSprop
from keras.callbacks import ModelCheckpoint, TerminateOnNaN

import keras.backend as K
import tensorflow as tf

### Get Data and Load Model

In [2]:
# Load the pre-computed training / validation arrays in a single unpacking step.
X_train, X_val, y_train, y_val = (
    np.load(npy_path) for npy_path in (
        '../train_and_val/X_train_ext.npy',
        '../train_and_val/X_val_ext.npy',
        '../train_and_val/y_train_ext.npy',
        '../train_and_val/y_val_ext.npy',
    )
)

In [3]:
# Report the largest value found in the last feature/target column of each array
# (labelled "scaled duration" by the messages below).
print('Maximum Scaled Duration for X_train: {}'.format(np.max(X_train[:, :, -1])))
print('Maximum Scaled Duration for X_val: {}'.format(np.max(X_val[:, :, -1])))
print('Maximum Scaled Duration for y_train: {}'.format(np.max(y_train[:, -1])))
print('Maximum Scaled Duration for y_val: {}'.format(np.max(y_val[:, -1])))

Maximum Scaled Duration for X_train: 0.860215053763441
Maximum Scaled Duration for X_val: 0.5516975308641976
Maximum Scaled Duration for y_train: 1.0
Maximum Scaled Duration for y_val: 0.9166666666666667


In [4]:
# Compare the train and validation targets: ratio of the means and of the
# standard deviations of the last target column.
print('Train-Validation Ratio of the Mean of the Scaled Duration: ',
      np.mean(y_train[:, -1]) / np.mean(y_val[:, -1]))
print('Train-Validation Ratio of the Stdv of the Scaled Duration: ',
      np.std(y_train[:, -1]) / np.std(y_val[:, -1]))

Train-Validation Ratio of the Mean of the Scaled Duration:  0.9898257021922111
Train-Validation Ratio of the Stdv of the Scaled Duration:  0.9209054376139967


In [5]:
def maestro_loss_wr(harshness, n_dur_nodes): 
    """Build the 'maestro' loss: binary cross-entropy plus a capped duration penalty.

    Every output except the last ``n_dur_nodes`` is scored with binary
    cross-entropy. The last ``n_dur_nodes`` outputs are averaged per sample
    into one duration value, and the relative error between the true and the
    predicted duration, ``2 * |t - p| / (t + p)``, is added to the loss scaled
    by ``harshness`` (how strict the maestro is about timing).

    Args:
        harshness: Proportionality constant of the duration-error term.
        n_dur_nodes: Number of trailing output nodes whose mean represents
            the duration.

    Returns:
        A function ``maestro_loss(ytrue, ypred)`` returning a scalar loss.
    """
    def maestro_loss(ytrue, ypred):
        true_cls = ytrue[:, :-n_dur_nodes]
        # Keep predictions strictly inside (0, 1): an output saturated at exactly
        # 0 or 1 would otherwise yield log(0) = -inf and turn the loss into NaN.
        pred_cls = K.clip(ypred[:, :-n_dur_nodes], K.epsilon(), 1 - K.epsilon())

        # Standard binary cross-entropy over the non-duration nodes
        bce_loss = - K.mean(true_cls * K.log(pred_cls) + (1 - true_cls) * K.log(1 - pred_cls))

        # Duration error term (K.epsilon() guards against division by zero)
        true_dur = K.mean(ytrue[:, -n_dur_nodes:], axis = 1)
        pred_dur = K.mean(ypred[:, -n_dur_nodes:], axis = 1)
        dur_loss = 2 * harshness * K.mean(K.abs(true_dur - pred_dur) / \
                                          (true_dur + pred_dur + K.epsilon()))

        # The duration term can spike (the original author's note attributes this
        # to rests, where target values are zero), so it is limited to never
        # exceed bce_loss. K.minimum expresses the same cap as the former
        # `if dur_loss > bce_loss: return bce_loss * 2` without branching on a
        # tensor in Python.
        return bce_loss + K.minimum(dur_loss, bce_loss)

    return maestro_loss
def precision_mod_wr(n_dur_nodes):
    """Create a precision metric that ignores the trailing ``n_dur_nodes`` outputs."""
    def precision_mod(ytrue, ypred):
        """Precision computed on every node except the last n_dur_nodes ones
        (those are duration nodes, not classification nodes)."""
        cls_true = ytrue[:, :-n_dur_nodes]
        cls_pred = ypred[:, :-n_dur_nodes]

        # Count of rounded (true * predicted) products vs. rounded predictions.
        hits = K.sum(K.round(cls_true * cls_pred))
        flagged = K.sum(K.round(cls_pred))
        return hits / (flagged + K.epsilon())
    return precision_mod

def recall_mod_wr(n_dur_nodes):
    """Create a recall metric that ignores the trailing ``n_dur_nodes`` outputs."""
    def recall_mod(ytrue, ypred):
        """Recall computed on every node except the last n_dur_nodes ones
        (those are duration nodes, not classification nodes)."""
        cls_true = ytrue[:, :-n_dur_nodes]
        cls_pred = ypred[:, :-n_dur_nodes]

        # Count of rounded (true * predicted) products vs. the sum of the targets.
        hits = K.sum(K.round(cls_true * cls_pred))
        actual = K.sum(cls_true)
        return hits / (actual + K.epsilon())
    return recall_mod

def f1_score_mod_wr(n_dur_nodes):
    """Create an F1 metric that ignores the trailing ``n_dur_nodes`` outputs."""
    def f1_score_mod(ytrue, ypred):
        """Harmonic mean of the modified precision and recall, i.e. F1 over every
        node except the last n_dur_nodes ones (which are not classification nodes)."""
        p = precision_mod_wr(n_dur_nodes)(ytrue, ypred)
        r = recall_mod_wr(n_dur_nodes)(ytrue, ypred)

        numerator = 2 * (p * r)
        denominator = p + r + K.epsilon()  # epsilon avoids 0 / 0 when both are zero
        return numerator / denominator
    return f1_score_mod

def dur_error_wr(n_dur_nodes):
    """Create a metric reporting only the error of the duration prediction."""
    def dur_error(ytrue, ypred):
        """Mean relative duration error, 2 * |t - p| / (t + p), where t and p are
        the per-sample means of the last n_dur_nodes true / predicted outputs."""
        true_dur = K.mean(ytrue[:, -n_dur_nodes:], axis = 1)
        pred_dur = K.mean(ypred[:, -n_dur_nodes:], axis = 1)

        relative_gap = (true_dur - pred_dur) / (true_dur + pred_dur + K.epsilon())
        return 2 * K.mean(K.abs(relative_gap))
    return dur_error

def maestro_dur_loss_wr(harshness, n_dur_nodes):
    """Create a metric equal to the duration term of the maestro loss (before
    any capping), so the two loss components can be inspected separately."""
    def maestro_dur_loss(ytrue, ypred):
        # Per-sample duration = mean of the last n_dur_nodes outputs.
        true_dur = K.mean(ytrue[:, -n_dur_nodes:], axis = 1)
        pred_dur = K.mean(ypred[:, -n_dur_nodes:], axis = 1)

        relative_gap = (true_dur - pred_dur) / (true_dur + pred_dur + K.epsilon())
        return 2 * harshness * K.mean(K.abs(relative_gap))
    return maestro_dur_loss

In [6]:
# Weight of the duration-error term in the maestro loss; this value is captured
# as the default `harshness` argument of load_model_from_file when it is defined.
harshness = 0.05

In [12]:
def load_model_from_file(file_path, harshness = harshness, n_dur_nodes = 20):
    """Load a saved Keras model that was compiled with the custom maestro loss
    and metrics.

    The custom loss / metric functions are rebuilt from ``harshness`` and
    ``n_dur_nodes`` and handed to ``load_model`` under the names they were
    registered with.
    """
    custom_fns = {
        'maestro_loss': maestro_loss_wr(harshness, n_dur_nodes),
        'f1_score_mod': f1_score_mod_wr(n_dur_nodes),
        'recall_mod': recall_mod_wr(n_dur_nodes),
        'precision_mod': precision_mod_wr(n_dur_nodes),
        'dur_error': dur_error_wr(n_dur_nodes),
        'maestro_dur_loss': maestro_dur_loss_wr(harshness, n_dur_nodes),
    }
    return load_model(file_path, custom_objects = custom_fns)

In [13]:
# Module-level RMSprop optimizer (default settings) that the pruning functions
# pass to compile().
# NOTE(review): the same optimizer instance is shared by every pruning run in
# this kernel — confirm that reusing it across runs is intended.
opt = RMSprop()

In [16]:
def prune_model_with_checkpoint(model, filename = 'best_pruned_maestro_model_ext20_2_1_1024_0pt4_mnv_2.h5', \
                                harshness = 0.05, n_dur_nodes = 20, batch_size = 512, epochs = 2, \
                                initial_sparsity = 0.5, final_sparsity = 0.8):
    """Prune ``model`` by weight magnitude, checkpointing the best epoch.

    The model is wrapped with ``prune_low_magnitude`` using a polynomial
    sparsity schedule that runs from ``initial_sparsity`` to
    ``final_sparsity`` over the whole training run, recompiled with the
    maestro loss / metrics and the module-level ``opt``, and trained on the
    module-level ``X_train`` / ``y_train`` with ``X_val`` / ``y_val`` as
    validation data. A ``ModelCheckpoint`` callback writes the epoch with the
    lowest ``val_loss`` to ``'../models/' + filename``.

    NOTE(review): the recorded run of this function in the notebook ended with
    ``RuntimeError: Unable to create link (name already exists)``; `prune_model`
    exists as a workaround that saves manually instead of via ModelCheckpoint.

    Args:
        model: Model to prune.
        filename: File name (inside ``../models/``) for the checkpoint.
        harshness: Weight of the duration term of the maestro loss.
        n_dur_nodes: Number of trailing output nodes that encode the duration.
        batch_size: Training batch size.
        epochs: Number of training epochs.
        initial_sparsity: Sparsity at the start of the schedule.
        final_sparsity: Sparsity at the end of the schedule.

    Returns:
        The ``model`` object that was passed in (not the wrapped model).
    """
    # The schedule ends after the last optimizer step. The step count is cast to
    # a plain int (as in the tfmot tutorial) instead of being left as a NumPy float.
    steps_per_epoch = int(np.ceil(X_train.shape[0] / batch_size))
    end_step = steps_per_epoch * epochs

    # Define model for pruning.
    pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity = initial_sparsity, final_sparsity = final_sparsity,
        begin_step = 0, end_step = end_step)
    model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(
        model, pruning_schedule = pruning_schedule)

    # `prune_low_magnitude` requires a recompile.
    metrics = [
        f1_score_mod_wr(n_dur_nodes),
        recall_mod_wr(n_dur_nodes),
        precision_mod_wr(n_dur_nodes),
        dur_error_wr(n_dur_nodes),
        maestro_dur_loss_wr(harshness, n_dur_nodes),
    ]
    model_for_pruning.compile(loss = maestro_loss_wr(harshness, n_dur_nodes),
                              optimizer = opt,
                              metrics = metrics)

    model_for_pruning.summary()

    # Pruning summaries are written to a fresh temporary directory.
    logdir = tempfile.mkdtemp()

    checkpoint = ModelCheckpoint('../models/' + filename, monitor = 'val_loss', mode = 'min',
                                 save_best_only = True, verbose = 1)
    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep(),
        tfmot.sparsity.keras.PruningSummaries(log_dir = logdir),
        checkpoint,
        TerminateOnNaN(),
    ]

    model_for_pruning.fit(X_train, y_train, batch_size = batch_size, epochs = epochs,
                          validation_data = (X_val, y_val), verbose = 2, callbacks = callbacks)

    return model

In [17]:
# Load the pre-trained model and prune it with ModelCheckpoint-based saving.
# The recorded run of this cell stopped with the RuntimeError shown in its output.
model = load_model_from_file('../models/best_maestro_model_ext20_2_1_1024_0pt4_mnv_2.h5')
pruned_model = prune_model_with_checkpoint(model)

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
prune_low_magnitude_lstm_4 ( (None, 16, 1024)          9277443   
_________________________________________________________________
prune_low_magnitude_dropout_ (None, 16, 1024)          1         
_________________________________________________________________
prune_low_magnitude_lstm_5 ( (None, 1024)              16781315  
_________________________________________________________________
prune_low_magnitude_dropout_ (None, 1024)              1         
_________________________________________________________________
prune_low_magnitude_dense_4  (None, 512)               1049090   
_________________________________________________________________
prune_low_magnitude_activati (None, 512)               1         
_________________________________________________________________
prune_low_magnitude_dropout_ (None, 512)              

RuntimeError: Unable to create link (name already exists)

In [18]:
def prune_model(model, filename = 'best_pruned_maestro_model_ext20_2_1_1024_0pt4_mnv_2.h5', harshness = 0.05, \
                n_dur_nodes = 20, batch_size = 512, epochs = 50, initial_sparsity = 0.5, final_sparsity = 0.8):
    """Prune ``model`` by weight magnitude, saving manually on every val_loss improvement.

    Same set-up as `prune_model_with_checkpoint` (polynomial sparsity schedule
    from ``initial_sparsity`` to ``final_sparsity``, maestro loss / metrics,
    module-level ``opt`` and ``X_train`` / ``y_train`` / ``X_val`` / ``y_val``),
    but without ``ModelCheckpoint``, which raised ``RuntimeError: Unable to
    create link (name already exists)``. Instead the model is trained one epoch
    at a time and ``model`` is saved to ``'../models/' + filename`` after the
    first epoch and after every later epoch that lowers ``val_loss``. Training
    stops early if ``val_loss`` becomes NaN. No training happens when
    ``epochs`` is less than 1.

    NOTE(review): the object that gets saved and returned is ``model`` (the
    argument), not the wrapped ``model_for_pruning``; presumably the wrapped
    layers share their weights with ``model`` — confirm, or export with
    ``tfmot.sparsity.keras.strip_pruning``.
    NOTE(review): sparsity increases during training, so the lowest val_loss may
    occur before ``final_sparsity`` is reached — confirm the saved file has the
    intended sparsity.

    Args:
        model: Model to prune.
        filename: File name (inside ``../models/``) the best model is saved to.
        harshness: Weight of the duration term of the maestro loss.
        n_dur_nodes: Number of trailing output nodes that encode the duration.
        batch_size: Training batch size.
        epochs: Maximum number of training epochs.
        initial_sparsity: Sparsity at the start of the schedule.
        final_sparsity: Sparsity at the end of the schedule.

    Returns:
        The ``model`` object that was passed in.
    """
    # The schedule ends after the last optimizer step. The step count is cast to
    # a plain int (as in the tfmot tutorial) instead of being left as a NumPy float.
    steps_per_epoch = int(np.ceil(X_train.shape[0] / batch_size))
    end_step = steps_per_epoch * epochs

    # Define model for pruning.
    pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity = initial_sparsity, final_sparsity = final_sparsity,
        begin_step = 0, end_step = end_step)
    model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(
        model, pruning_schedule = pruning_schedule)

    # `prune_low_magnitude` requires a recompile.
    metrics = [
        f1_score_mod_wr(n_dur_nodes),
        recall_mod_wr(n_dur_nodes),
        precision_mod_wr(n_dur_nodes),
        dur_error_wr(n_dur_nodes),
        maestro_dur_loss_wr(harshness, n_dur_nodes),
    ]
    model_for_pruning.compile(loss = maestro_loss_wr(harshness, n_dur_nodes),
                              optimizer = opt,
                              metrics = metrics)

    # Pruning summaries are written to a fresh temporary directory.
    logdir = tempfile.mkdtemp()

    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep(),
        tfmot.sparsity.keras.PruningSummaries(log_dir = logdir),
        TerminateOnNaN(),
    ]
    filepath = '../models/' + filename

    # Manual replacement for ModelCheckpoint(save_best_only = True): fit one epoch
    # at a time and save whenever the validation loss reaches a new minimum.
    min_val_loss = None
    for epoch in range(1, epochs + 1):
        print('Epoch {}/{}'.format(epoch, epochs))
        history = model_for_pruning.fit(X_train, y_train, batch_size = batch_size, epochs = 1,
                                        validation_data = (X_val, y_val), verbose = 2,
                                        callbacks = callbacks)
        val_loss = history.history['val_loss'][0]

        if np.isnan(val_loss):  # NaN failure: stop, keeping whatever was saved so far
            break

        if min_val_loss is None:
            # The first epoch is always saved and becomes the baseline.
            print('val_loss is {a:2.5f}, saving model to {b}'.format(a = val_loss, b = filepath))
        elif val_loss < min_val_loss:
            print('val_loss improved from {a:2.5f} to {b:2.5f}, saving model to {c}'.format(
                a = min_val_loss, b = val_loss, c = filepath))
        else:
            continue  # no improvement: nothing to save

        model.save(filepath, save_format = 'h5')
        min_val_loss = val_loss

    return model

In [21]:
# Reload the pre-trained model and prune it for 20 epochs with the manual-save
# workaround. prune_model returns the object it was given, so `pruned_model`
# refers to the same object as `model`.
model = load_model_from_file('../models/best_maestro_model_ext20_2_1_1024_0pt4_mnv_2.h5')
pruned_model = prune_model(model, epochs = 20)

Epoch 1/20
50/50 - 579s - loss: 0.0578 - f1_score_mod: 0.6506 - recall_mod: 0.5454 - precision_mod: 0.8072 - dur_error: 0.1638 - maestro_dur_loss: 0.0082 - val_loss: 0.0774 - val_f1_score_mod: 0.5764 - val_recall_mod: 0.4769 - val_precision_mod: 0.7287 - val_dur_error: 0.1716 - val_maestro_dur_loss: 0.0086
val_loss is 0.07739, saving model to ../models/best_pruned_maestro_model_ext20_2_1_1024_0pt4_mnv_2.h5
Epoch 2/20
50/50 - 578s - loss: 0.0640 - f1_score_mod: 0.5692 - recall_mod: 0.4340 - precision_mod: 0.8344 - dur_error: 0.1823 - maestro_dur_loss: 0.0091 - val_loss: 0.0764 - val_f1_score_mod: 0.5611 - val_recall_mod: 0.4450 - val_precision_mod: 0.7595 - val_dur_error: 0.1676 - val_maestro_dur_loss: 0.0084
val_loss improved from 0.07739 to 0.07642, saving model to ../models/best_pruned_maestro_model_ext20_2_1_1024_0pt4_mnv_2.h5
Epoch 3/20
50/50 - 587s - loss: 0.0559 - f1_score_mod: 0.6445 - recall_mod: 0.5240 - precision_mod: 0.8377 - dur_error: 0.1619 - maestro_dur_loss: 0.0081 - va

In [13]:
# Prune a different pre-trained model, starting from 60% sparsity (final sparsity
# stays at prune_model's default).
# NOTE(review): the recorded output reports a save path that differs from
# prune_model's current default `filename`, so this cell was presumably run
# against an earlier definition — re-run after Restart & Run All to confirm.
model = load_model_from_file('../models/best_maestro_model_2_1_512_0pt4_lr_5e-04_cv_0pt2.h5')
pruned_model = prune_model(model, epochs = 50, initial_sparsity = 0.6)

Epoch 1/50
50/50 - 257s - loss: 0.1844 - f1_score_mod: 0.1412 - recall_mod: 0.0783 - precision_mod: 0.7431 - dur_error: 1.7618 - maestro_dur_loss: 0.0881 - val_loss: 0.1770 - val_f1_score_mod: 0.1812 - val_recall_mod: 0.1016 - val_precision_mod: 0.8396 - val_dur_error: 1.7588 - val_maestro_dur_loss: 0.0879
val_loss is 0.17695, saving model to ../models/best_maestro_model_pruned_2_1_512_0pt4_lr_5e-04_cv_0pt2.h5
Epoch 2/50
50/50 - 247s - loss: 0.1776 - f1_score_mod: 0.1847 - recall_mod: 0.1053 - precision_mod: 0.7681 - dur_error: 1.7587 - maestro_dur_loss: 0.0879 - val_loss: 0.1734 - val_f1_score_mod: 0.2140 - val_recall_mod: 0.1230 - val_precision_mod: 0.8271 - val_dur_error: 1.7559 - val_maestro_dur_loss: 0.0878
val_loss improved from 0.17695 to 0.17344, saving model to ../models/best_maestro_model_pruned_2_1_512_0pt4_lr_5e-04_cv_0pt2.h5
Epoch 3/50
50/50 - 246s - loss: 0.1745 - f1_score_mod: 0.2019 - recall_mod: 0.1162 - precision_mod: 0.7764 - dur_error: 1.7562 - maestro_dur_loss: 0.0

In [14]:
# Same pre-trained model again, with a narrower sparsity range (60% -> 70%).
# NOTE(review): no `filename` is passed, so this run saves to the same default
# path as the previous pruning run and overwrites its result — confirm that is
# intended.
model = load_model_from_file('../models/best_maestro_model_2_1_512_0pt4_lr_5e-04_cv_0pt2.h5')
pruned_model = prune_model(model, epochs = 50, initial_sparsity = 0.6, final_sparsity = 0.7)

Epoch 1/50
50/50 - 289s - loss: 0.0865 - f1_score_mod: 0.4790 - recall_mod: 0.3529 - precision_mod: 0.7469 - dur_error: 0.3443 - maestro_dur_loss: 0.0172 - val_loss: 0.0920 - val_f1_score_mod: 0.4464 - val_recall_mod: 0.3265 - val_precision_mod: 0.7063 - val_dur_error: 0.2490 - val_maestro_dur_loss: 0.0125
val_loss is 0.09196, saving model to ../models/best_maestro_model_pruned_2_1_512_0pt4_lr_5e-04_cv_0pt2.h5
Epoch 2/50
50/50 - 277s - loss: 0.1139 - f1_score_mod: 0.3872 - recall_mod: 0.2791 - precision_mod: 0.7536 - dur_error: 0.7666 - maestro_dur_loss: 0.0383 - val_loss: 0.1787 - val_f1_score_mod: 0.1993 - val_recall_mod: 0.1135 - val_precision_mod: 0.8234 - val_dur_error: 1.7600 - val_maestro_dur_loss: 0.0880
Epoch 3/50
50/50 - 277s - loss: 0.1747 - f1_score_mod: 0.2171 - recall_mod: 0.1267 - precision_mod: 0.7701 - dur_error: 1.7603 - maestro_dur_loss: 0.0880 - val_loss: 0.1700 - val_f1_score_mod: 0.2602 - val_recall_mod: 0.1542 - val_precision_mod: 0.8345 - val_dur_error: 1.7585 -