In [1]:
import warnings
# Suppress every FutureWarning for the rest of the kernel session
# (presumably to quiet deprecation noise from the TF/Keras stack imported below).
warnings.filterwarnings("ignore", category=FutureWarning)

In [2]:
import os
import numpy as np
import random
import librosa
import h5py
import tensorflow as tf
import keras
from keras.optimizers import Adam
from l3embedding.audio import pcm2float
from resampy import resample
import pescador
from skimage import img_as_float

Using TensorFlow backend.


In [3]:
def cycle_shuffle(iterable, shuffle=True):
    """Yield the items of ``iterable`` forever, optionally reshuffling between passes.

    The iterable is materialized into a list once. The first pass yields the
    items in their original order; when ``shuffle`` is true the list is
    shuffled in place with the module-level ``random`` generator after each
    complete pass, so later passes come out in a new random order.

    Args:
        iterable: Finite iterable to cycle over.
        shuffle: Whether to reshuffle the items after each full pass.

    Yields:
        Items of ``iterable``, endlessly. If ``iterable`` is empty the
        generator finishes immediately.
    """
    items = list(iterable)
    if not items:
        # Without this guard `while True: yield from []` busy-loops forever
        # without ever producing a value.
        return
    while True:
        yield from items
        if shuffle:
            random.shuffle(items)

In [4]:
def amplitude_to_db(S, amin=1e-10, dynamic_range=80.0):
    """Convert an amplitude spectrogram to decibels relative to its own peak.

    Squares ``|S|`` to obtain power, converts to dB with ``10 * log10``, shifts
    the result so the loudest bin sits at 0 dB, and clips everything below
    ``-dynamic_range``. The input is not modified.

    Args:
        S: Array-like of amplitude (magnitude) values; complex input is
            reduced to its absolute value first.
        amin: Floor applied to the power before taking the log (avoids log(0)).
        dynamic_range: Width in dB of the range kept below the peak.

    Returns:
        Float ndarray shaped like ``S``; for a non-negative ``dynamic_range``
        its values lie in ``[-dynamic_range, 0]``.
    """
    power = np.square(np.abs(S))
    log_spec = 10.0 * np.log10(np.maximum(amin, power))
    # Reference the spectrogram's own maximum so the peak maps to 0 dB.
    log_spec -= log_spec.max()
    return np.maximum(log_spec, -dynamic_range)

In [5]:
def get_melspectrogram(frame, n_fft=2048, mel_hop_length=242, samp_rate=48000, n_mels=256, fmax=None):
    """Compute a peak-normalized log-mel spectrogram (in dB) for one audio frame.

    Args:
        frame: 1-D audio signal.
        n_fft: FFT window size used for the STFT and the mel filterbank.
        mel_hop_length: Hop, in samples, between successive STFT columns.
        samp_rate: Sample rate of ``frame`` in Hz.
        n_mels: Number of mel bands.
        fmax: Highest mel frequency in Hz; ``None`` lets librosa pick its default.

    Returns:
        2-D array of mel-band dB values produced by ``amplitude_to_db``
        (librosa's layout is presumably bands x frames).
    """
    stft_magnitude = np.abs(librosa.core.stft(
        frame,
        n_fft=n_fft,
        hop_length=mel_hop_length,
        window='hann',
        center=True,
        pad_mode='constant',
    ))
    # A precomputed magnitude spectrogram is handed to librosa via S=...;
    # amplitude_to_db then squares and log-scales the mel output.
    mel_magnitude = librosa.feature.melspectrogram(
        sr=samp_rate,
        S=stft_magnitude,
        n_fft=n_fft,
        n_mels=n_mels,
        fmax=fmax,
        power=1.0,
        htk=True,
    )
    return amplitude_to_db(np.array(mel_magnitude))

In [6]:
def _pcm_batch_to_melspec(pcm_audio, samp_rate, n_fft, n_mels, mel_hop_length, fmax):
    """Turn a batch of 48 kHz PCM clips into a stacked log-mel array with a trailing channel axis."""
    audio = pcm2float(pcm_audio, dtype='float32')
    if samp_rate != 48000:
        # Stored audio is 48 kHz; resample only when a different rate is requested.
        audio = resample(audio, sr_orig=48000, sr_new=samp_rate)
    spectrograms = [get_melspectrogram(clip.flatten(), n_fft=n_fft, mel_hop_length=mel_hop_length,
                                       samp_rate=samp_rate, n_mels=n_mels, fmax=fmax)
                    for clip in audio]
    return np.array(spectrograms)[:, :, :, np.newaxis]


def data_generator(data_dir, emb_dir, batch_size=512, samp_rate=48000, n_fft=2048,
                   n_mels=256, mel_hop_length=242, hop_size=0.1, fmax=None,
                   random_state=20180123, start_batch_idx=None, keys=None, test=False):
    """Yield batches of log-mel spectrograms paired with target L3 embeddings.

    Every file name listed in ``data_dir`` is opened both from ``data_dir``
    (dataset ``'audio'``) and from ``emb_dir`` (dataset ``'l3_embedding'``),
    and rows are read from the two with identical slice indices. Rows are
    packed into batches of exactly ``batch_size``; a batch may span several
    files. For each full batch the PCM audio is converted to float32,
    resampled from 48 kHz when ``samp_rate`` differs, turned into log-mel
    spectrograms, and yielded as a dict with keys ``'audio'`` (stacked
    spectrograms with a trailing channel axis of size 1) and ``'label'``
    (the matching embedding rows).

    With ``test=False`` the file list is cycled forever and reshuffled after
    each pass (see ``cycle_shuffle``), so the generator does not end unless
    ``data_dir`` is empty. With ``test=True`` each file is read once and a
    trailing partial batch is dropped.

    Args:
        data_dir: Directory of HDF5 files holding an ``'audio'`` dataset.
        emb_dir: Directory of same-named HDF5 files holding ``'l3_embedding'``.
            The number of rows per file is taken from the audio dataset;
            presumably the embedding dataset has the same length.
        batch_size: Rows per yielded batch; must be at least 1.
        samp_rate: Target sample rate in Hz.
        n_fft, n_mels, mel_hop_length, fmax: Forwarded to ``get_melspectrogram``.
        hop_size: Accepted for backward compatibility; unused.
        random_state: Seed applied to the global ``random`` module on the
            first iteration (drives the reshuffling in training mode).
        start_batch_idx: If given, batches with a smaller index are counted
            but neither read from disk nor yielded (for resuming).
        keys: Accepted for backward compatibility; unused.
        test: Single-pass mode (see above); also prints ``'Testing phase'``.

    Yields:
        ``{'audio': ndarray, 'label': ndarray}`` dicts.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (on first iteration).
    """
    if batch_size < 1:
        # A non-positive batch size makes the slicing loop below spin forever.
        raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))

    random.seed(random_state)
    emb_key = 'l3_embedding'

    batch = None
    curr_batch_size = 0
    batch_idx = 0

    if test:
        print('Testing phase')
        data_list = os.listdir(data_dir)
    else:
        data_list = cycle_shuffle(os.listdir(data_dir))

    for fname in data_list:
        data_batch_path = os.path.join(data_dir, fname)
        emb_batch_path = os.path.join(emb_dir, fname)

        # Context managers close both files even when the blob is empty or the
        # consumer abandons the generator mid-file (GeneratorExit at a yield).
        with h5py.File(data_batch_path, 'r') as data_blob, \
                h5py.File(emb_batch_path, 'r') as emb_blob:
            audio_ds = data_blob['audio']
            emb_ds = emb_blob[emb_key]
            blob_size = len(audio_ds)
            blob_start_idx = 0

            while blob_start_idx < blob_size:
                # Take only as many rows as are needed to fill the current batch.
                blob_end_idx = min(blob_start_idx + batch_size - curr_batch_size, blob_size)

                # When resuming from start_batch_idx, skip reading earlier batches.
                if start_batch_idx is None or batch_idx >= start_batch_idx:
                    audio_chunk = audio_ds[blob_start_idx:blob_end_idx]
                    label_chunk = emb_ds[blob_start_idx:blob_end_idx]
                    if batch is None:
                        batch = {'audio': audio_chunk, 'label': label_chunk}
                    else:
                        batch['audio'] = np.concatenate([batch['audio'], audio_chunk])
                        batch['label'] = np.concatenate([batch['label'], label_chunk])

                curr_batch_size += blob_end_idx - blob_start_idx
                blob_start_idx = blob_end_idx

                if curr_batch_size == batch_size:
                    if start_batch_idx is None or batch_idx >= start_batch_idx:
                        batch['audio'] = _pcm_batch_to_melspec(batch['audio'], samp_rate, n_fft,
                                                               n_mels, mel_hop_length, fmax)
                        yield batch

                    batch_idx += 1
                    curr_batch_size = 0
                    batch = None

In [7]:
def single_epoch_data_generator(data_dir, emb_dir, epoch_size, **kwargs):
    """Yield batches from ``data_generator`` in epochs of ``epoch_size`` batches.

    A fresh ``data_generator(data_dir, emb_dir, **kwargs)`` is created for
    every epoch, so each epoch restarts from the generator's initial state
    (``data_generator`` re-seeds ``random`` with its ``random_state``). If a
    fresh generator runs out before ``epoch_size`` batches, a new one is
    started. The output is therefore endless, except when a freshly created
    generator yields nothing at all; in that case this generator stops
    instead of spinning forever.

    Args:
        data_dir: Forwarded to ``data_generator``.
        emb_dir: Forwarded to ``data_generator``.
        epoch_size: Number of batches per epoch. Values below 1 never trigger
            a restart, so the underlying generator is consumed without limit.
        **kwargs: Extra keyword arguments forwarded to ``data_generator``.

    Yields:
        Batch dicts produced by ``data_generator``.
    """
    while True:
        data_gen = data_generator(data_dir, emb_dir, **kwargs)
        produced = 0
        for item in data_gen:
            yield item
            produced += 1
            # Once all batches for an epoch are out, restart the generator.
            if produced == epoch_size:
                break
        # Release the abandoned generator (and any files it holds open) now
        # rather than whenever it is garbage-collected.
        data_gen.close()
        if produced == 0:
            # An empty source would otherwise make this loop spin forever.
            return

In [8]:
def initialize_uninitialized_variables(sess, verbose=True):
    """Run initializers for the default graph's variables that are still uninitialized.

    Variables already tagged with ``_keras_initialized`` (the attribute the
    Keras TF backend uses for the same bookkeeping) are skipped. The remaining
    candidates are checked in ``sess`` with ``tf.is_variable_initialized``, so
    variables that already hold values are not reset. Every checked variable
    is tagged afterwards.

    Args:
        sess: TF1 session whose graph should be the current default graph.
        verbose: When true, print the list of variables being initialized.

    Returns:
        List of the variables whose initializers were run (possibly empty).
    """
    if hasattr(tf, 'global_variables'):
        variables = tf.global_variables()
    else:
        variables = tf.all_variables()

    candidates = [v for v in variables if not getattr(v, '_keras_initialized', False)]

    uninitialized_variables = []
    if candidates:
        # Ask the session instead of assuming: an untagged variable may already
        # hold loaded weights, and rerunning its initializer would overwrite them.
        flags = sess.run([tf.is_variable_initialized(v) for v in candidates])
        for is_initialized, v in zip(flags, candidates):
            if not is_initialized:
                uninitialized_variables.append(v)
            v._keras_initialized = True

    if verbose:
        print(uninitialized_variables)
    if uninitialized_variables:
        if hasattr(tf, 'variables_initializer'):
            sess.run(tf.variables_initializer(uninitialized_variables))
        else:
            sess.run(tf.initialize_variables(uninitialized_variables))
    return uninitialized_variables

In [9]:
def load_l3_audio_model(model_path):
    """Load a saved L3 audio model and collapse its output into a flat embedding.

    The saved model's output is max-pooled (``padding='same'``) with a window
    equal to dims 1-2 of the ``audio_embedding_layer`` output shape
    (presumably its spatial axes, assuming channels_last), then flattened.
    A new functional ``Model`` sharing the original input is returned.

    NOTE(review): the pool size comes from ``audio_embedding_layer`` but is
    applied to ``model.output``; this assumes they are the same tensor.

    Args:
        model_path: Path to a model file readable by ``keras.models.load_model``.

    Returns:
        ``keras.models.Model`` mapping the original input to the pooled,
        flattened embedding.
    """
    base_model = keras.models.load_model(model_path)
    embedding_shape = base_model.get_layer('audio_embedding_layer').get_output_shape_at(0)
    pooling = keras.layers.MaxPooling2D(pool_size=tuple(embedding_shape[1:3]), padding='same')
    pooled = pooling(base_model.output)
    flat_embedding = keras.layers.Flatten()(pooled)
    return keras.models.Model(inputs=base_model.input, outputs=flat_embedding)

In [10]:
def train_quantized_model(model_path, output_path, train=False, train_gen=None,
                          train_epoch_size=64, num_epochs=5, val_gen=None,
                          validation_epoch_size=64, learning_rate=1e-5, quant_delay=100):
    """Build a quantization-aware training graph for an L3 audio model and checkpoint it.

    The model from ``load_l3_audio_model(model_path)`` is loaded into a fresh
    CPU-only graph/session, rewritten with
    ``tf.contrib.quantize.create_training_graph``, optionally fine-tuned
    against target embeddings (MSE loss, MAE metric), and saved with
    ``tf.train.Saver`` under the prefix ``<output_path>/checkpoints``.

    Args:
        model_path: Path to the saved Keras model.
        output_path: Output directory; the checkpoint prefix is
            ``os.path.join(output_path, 'checkpoints')``.
        train: Whether to fine-tune before saving.
        train_gen: Training generator of ``(inputs, targets)``; required if ``train``.
        train_epoch_size: Steps per training epoch.
        num_epochs: Number of training epochs.
        val_gen: Validation generator; required if ``train``.
        validation_epoch_size: Validation steps per epoch.
        learning_rate: Adam learning rate (formerly read from the notebook
            global ``learning_rate``, whose value there is 1e-5).
        quant_delay: Forwarded to ``create_training_graph`` (formerly
            hard-coded to 100).

    Returns:
        The Keras ``History`` from ``fit_generator`` when ``train`` is true,
        otherwise ``None``.

    Raises:
        ValueError: If ``train`` is true and either generator is missing.
    """
    if train and (train_gen is None or val_gen is None):
        raise ValueError('Invalid data (train/valid) generator')

    # Hide all GPUs so the graph is built and trained on CPU.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    # Write under the caller-supplied directory, not a notebook-global one.
    checkpoint_prefix = os.path.join(output_path, 'checkpoints')

    train_graph = tf.Graph()
    session_conf = tf.ConfigProto(device_count={'GPU': 0},
                                  allow_soft_placement=True,
                                  log_device_placement=False)
    train_sess = tf.Session(config=session_conf, graph=train_graph)
    keras.backend.set_session(train_sess)

    history = None
    with train_graph.as_default():
        # The model is a standalone-Keras model, so use the standalone-Keras
        # Adam; its compile() only wraps tf.train optimizers, not tf.keras ones.
        optimizer = keras.optimizers.Adam(lr=learning_rate)
        model = load_l3_audio_model(model_path)
        # Insert fake-quant ops before compile() adds the training ops.
        tf.contrib.quantize.create_training_graph(input_graph=train_graph, quant_delay=quant_delay)
        initialize_uninitialized_variables(train_sess)
        if train:
            model.compile(optimizer,
                          loss='mean_squared_error',
                          metrics=['mae'])
            history = model.fit_generator(train_gen,
                                          steps_per_epoch=train_epoch_size,
                                          epochs=num_epochs,
                                          validation_data=val_gen,
                                          validation_steps=validation_epoch_size,
                                          verbose=1, initial_epoch=0)

        # Save graph variables (including quantization ranges) as a checkpoint.
        saver = tf.train.Saver()
        saver.save(train_sess, checkpoint_prefix)

    return history

In [11]:
def restore_quantized_model(model_path, output_dir):
    """Freeze the quantized eval graph of an L3 audio model into a ``.pb`` file.

    Rebuilds the model from ``model_path`` in a fresh CPU-only graph and
    rewrites it with ``tf.contrib.quantize.create_eval_graph``. Variable values
    are restored from the checkpoint prefix ``<output_dir>/checkpoints``, the
    same prefix the training step saves under for that output directory.
    Variables are then folded into constants, and the frozen ``GraphDef`` is
    written to ``<output_dir>/frozen_model.pb``.

    Args:
        model_path: Path to the saved Keras model.
        output_dir: Directory containing the ``checkpoints`` prefix; the
            frozen graph is written here too.

    Returns:
        Path of the written frozen graph file.
    """
    # Hide all GPUs so the graph is built on CPU.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    checkpoint_prefix = os.path.join(output_dir, 'checkpoints')
    output_path = os.path.join(output_dir, 'frozen_model.pb')

    eval_graph = tf.Graph()
    session_conf = tf.ConfigProto(device_count={'GPU': 0},
                                  allow_soft_placement=True,
                                  log_device_placement=False)
    eval_sess = tf.Session(config=session_conf, graph=eval_graph)
    keras.backend.set_session(eval_sess)

    with eval_graph.as_default():
        # Put Keras layers in inference mode before building the model.
        keras.backend.set_learning_phase(0)
        eval_model = load_l3_audio_model(model_path)

        tf.contrib.quantize.create_eval_graph(input_graph=eval_graph)
        # Snapshot the graph before the Saver adds its save/restore ops.
        eval_graph_def = eval_graph.as_graph_def()
        saver = tf.train.Saver()
        saver.restore(eval_sess, checkpoint_prefix)

        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            eval_sess,
            eval_graph_def,
            [eval_model.output.op.name]
        )

    with open(output_path, 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())
    return output_path

In [None]:
# Model hyper-parameters are parsed from the model file name, which is assumed
# to look like  l3_audio_<...>_<samp_rate>_<n_mels>_<mel_hop_length>_<n_fft>.h5
model_path = '/scratch/sk7898/l3pruning/embedding/fixed/reduced_input/l3_audio_original_48000_256_252_2048.h5'
# splitext removes exactly the extension; str.strip('.h5') would also eat any
# trailing '5', 'h' or '.' characters of the stem (e.g. "..._2045.h5" -> "..._204").
model_name = os.path.splitext(os.path.basename(model_path))[0]
splits = model_name.split('_')
samp_rate = int(splits[3])
n_mels = int(splits[4])
mel_hop_length = int(splits[5])
n_fft = int(splits[-1])

output_dir = os.path.join("/scratch/sk7898/quantization", model_name)
os.makedirs(output_dir, exist_ok=True)

num_epochs = 1
learning_rate = 0.00001
train_data_dir = '/beegfs/work/AudioSetSamples/music_train'
validation_data_dir = '/beegfs/work/AudioSetSamples/music_valid'

train_batch_size = 64
train_epoch_size = 64

validation_epoch_size = 3
validation_batch_size = 10

train_emb_dir = '/scratch/sk7898/orig_l3_embeddings/music_train'
val_emb_dir = '/scratch/sk7898/orig_l3_embeddings/music_valid'

# Generators are lazy: nothing is read from disk until a batch is requested.
train_gen = data_generator(train_data_dir, train_emb_dir, batch_size=train_batch_size,
                           samp_rate=samp_rate, n_fft=n_fft, n_mels=n_mels,
                           mel_hop_length=mel_hop_length)
train_gen = pescador.maps.keras_tuples(train_gen, 'audio', 'label')

val_gen = single_epoch_data_generator(validation_data_dir, val_emb_dir, validation_epoch_size,
                                      batch_size=validation_batch_size, samp_rate=samp_rate,
                                      n_fft=n_fft, n_mels=n_mels, mel_hop_length=mel_hop_length)
val_gen = pescador.maps.keras_tuples(val_gen, 'audio', 'label')

# Load and quantize the model. The model takes precomputed mel spectrograms
# (data_generator computes them outside the model). Flip these flags to run.
RUN_QUANTIZE_TRAIN = False
RUN_FREEZE_QUANTIZED = False

if RUN_QUANTIZE_TRAIN:
    train_quantized_model(model_path, output_dir)

if RUN_FREEZE_QUANTIZED:
    restore_quantized_model(model_path, output_dir)