In [None]:
import os
import numpy as np
import random
import librosa
import h5py
import tensorflow as tf
import keras
from keras.optimizers import Adam
from l3embedding.audio import pcm2float
from resampy import resample
import pescador
from skimage import img_as_float

In [None]:
import keras


In [None]:

from keras.optimizers import Adam

In [None]:
def cycle_shuffle(iterable, shuffle=True):
    """Yield the items of ``iterable`` forever, one full pass at a time.

    The first pass follows the input order. When ``shuffle`` is true, the items
    are reshuffled in place with :func:`random.shuffle` before every later pass,
    so the global ``random`` seed controls the order.

    An empty ``iterable`` yields nothing and the generator finishes immediately.
    Previously this case spun forever without ever producing a value.
    """
    items = list(iterable)
    if not items:
        return
    while True:
        yield from items
        if shuffle:
            random.shuffle(items)

In [None]:
def amplitude_to_db(S, amin=1e-10, dynamic_range=80.0):
    """Convert an amplitude spectrogram to decibels relative to its peak power.

    The processing steps are:

    1. Square the input to power.
    2. Apply ``10 * log10``, flooring the power at ``amin`` to avoid ``log(0)``.
    3. Shift so the loudest bin is 0 dB.
    4. Clip from below at ``-dynamic_range`` dB.

    Parameters
    ----------
    S : array_like
        Amplitude values (real or complex); only ``abs(S)`` is used. ``S`` is
        not modified.
    amin : float
        Minimum *power* value before taking the logarithm.
    dynamic_range : float
        Largest attenuation kept below the peak, in dB.

    Returns
    -------
    np.ndarray
        Same shape as ``S``. For non-negative ``dynamic_range``, values lie in
        ``[-dynamic_range, 0]``.

    Raises
    ------
    ValueError
        If ``S`` is empty, because there is no peak to reference.
    """
    magnitude = np.abs(S)
    # Integer input (e.g. raw PCM) would overflow if squared in its own dtype.
    if not np.issubdtype(magnitude.dtype, np.floating):
        magnitude = magnitude.astype(np.float64)
    power = np.square(magnitude)

    log_spec = 10.0 * np.log10(np.maximum(amin, power))
    # Reference the peak so the loudest bin sits at 0 dB.
    log_spec -= log_spec.max()

    return np.maximum(log_spec, -dynamic_range)

In [None]:
def get_melspectrogram(frame, n_fft=2048, mel_hop_length=242, samp_rate=48000, n_mels=256, fmax=None):
    """Compute a dB-scaled, HTK-style mel spectrogram of one audio frame.

    Steps:

    - Take the magnitude STFT with a Hann window, centered, using constant
      (zero) padding.
    - Project it onto ``n_mels`` mel bands (up to ``fmax``).
    - Convert to dB relative to the peak with ``amplitude_to_db``.

    NOTE(review): the default ``mel_hop_length`` (242) differs from the 252
    that ``data_generator`` uses by default. Callers in this notebook always
    pass it explicitly; confirm which default is intended.
    """
    stft_matrix = librosa.core.stft(
        frame,
        n_fft=n_fft,
        hop_length=mel_hop_length,
        window='hann',
        center=True,
        pad_mode='constant',
    )
    magnitude = np.abs(stft_matrix)

    mel = librosa.feature.melspectrogram(
        sr=samp_rate,
        S=magnitude,
        n_mels=n_mels,
        fmax=fmax,
        power=1.0,
        htk=True,
    )
    return amplitude_to_db(np.array(mel))

In [None]:
def data_generator(data_dir, batch_size=512, samp_rate=48000,
                   n_fft=2048, n_mels=256, mel_hop_length=252, hop_size=0.1, fmax=None,
                   random_state=20180123, start_batch_idx=None, keys=None):
    """Stream preprocessed, fixed-size batches from the HDF5 files in ``data_dir`` forever.

    File order:

    - The first pass visits files in sorted name order.
    - Each later pass uses a reshuffled order (via ``cycle_shuffle``).

    Rows are taken in order from each file, and a batch may span the end of one
    file and the start of the next. Every yielded batch is a dict holding
    ``batch_size`` rows for each key in ``keys``. Its ``video`` and ``audio``
    entries are transformed by ``_preprocess_batch``.

    Parameters
    ----------
    data_dir : str
        Directory whose entries are all opened as HDF5 files. Each file must
        hold a ``label`` dataset plus every dataset named in ``keys``.
    batch_size : int
        Rows per yielded batch; must be >= 1.
    samp_rate : int
        Target audio sample rate. Stored audio is treated as 48 kHz and is
        resampled when ``samp_rate`` differs.
    n_fft, n_mels, mel_hop_length, fmax
        Forwarded to ``get_melspectrogram``. Previously ``fmax`` was silently
        dropped.
    hop_size : float
        Unused; kept for backward compatibility.
    random_state : int
        Seed for Python's global ``random`` module, which drives the file
        reshuffling.
    start_batch_idx : int or None
        When set, batches with a smaller index are counted but neither read nor
        preprocessed, so a run can resume mid-stream.
    keys : list of str or None
        Datasets to read. Defaults to ``['audio', 'video', 'label']``.

    Raises
    ------
    ValueError
        If ``batch_size`` < 1 or ``data_dir`` has no entries. Either case would
        otherwise loop forever without yielding.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))

    # Sorted so the first pass does not depend on filesystem listing order.
    fnames = sorted(os.listdir(data_dir))
    if not fnames:
        raise ValueError('no files found in {}'.format(data_dir))

    # Reseeds the *global* random module (used by cycle_shuffle).
    random.seed(random_state)

    # Limit keys to avoid producing batches with all of the metadata fields
    if not keys:
        keys = ['audio', 'video', 'label']

    batch = None
    curr_batch_size = 0
    batch_idx = 0

    for fname in cycle_shuffle(fnames):
        # `with` closes the file even when the consumer stops iterating mid-file
        # or the file is empty. The old close-on-last-slice missed both cases.
        with h5py.File(os.path.join(data_dir, fname), 'r') as blob:
            blob_size = len(blob['label'])
            blob_start_idx = 0

            while blob_start_idx < blob_size:
                blob_end_idx = min(blob_start_idx + batch_size - curr_batch_size, blob_size)
                # Batches before start_batch_idx are counted but never read.
                wanted = start_batch_idx is None or batch_idx >= start_batch_idx

                if wanted:
                    chunk = {k: blob[k][blob_start_idx:blob_end_idx] for k in keys}
                    if batch is None:
                        batch = chunk
                    else:
                        batch = {k: np.concatenate([batch[k], chunk[k]]) for k in keys}

                curr_batch_size += blob_end_idx - blob_start_idx
                blob_start_idx = blob_end_idx

                if curr_batch_size == batch_size:
                    if wanted:
                        yield _preprocess_batch(batch, samp_rate, n_fft, n_mels,
                                                mel_hop_length, fmax)
                    batch_idx += 1
                    curr_batch_size = 0
                    batch = None


def _preprocess_batch(batch, samp_rate, n_fft, n_mels, mel_hop_length, fmax):
    """Turn one raw batch read from HDF5 into model inputs (mutates and returns ``batch``).

    - ``video``: ``img_as_float`` then ``2*x - 1`` as float32. This gives
      ``[-1, 1]`` assuming ``img_as_float`` yields ``[0, 1]``.
    - ``audio``: converted with ``pcm2float`` to float32. It is resampled from
      48 kHz when ``samp_rate`` differs. Each example is then flattened and
      turned into a dB mel spectrogram, and a trailing channel axis is added.

    Keys that are absent from ``batch`` are left alone.
    """
    if 'video' in batch:
        # Preprocess video so samples are in [-1,1]
        batch['video'] = 2 * img_as_float(batch['video']).astype('float32') - 1

    if 'audio' in batch:
        audio = pcm2float(batch['audio'], dtype='float32')
        if samp_rate != 48000:
            audio = resample(audio, sr_orig=48000, sr_new=samp_rate)
        mels = [get_melspectrogram(example.flatten(), n_fft=n_fft,
                                   mel_hop_length=mel_hop_length,
                                   samp_rate=samp_rate, n_mels=n_mels, fmax=fmax)
                for example in audio]
        batch['audio'] = np.array(mels)[:, :, :, np.newaxis]

    return batch

In [None]:
def single_epoch_data_generator(data_dir, epoch_size, **kwargs):
    """Yield batches from ``data_generator`` forever, restarting it every ``epoch_size`` batches.

    Each restart builds a fresh ``data_generator(data_dir, **kwargs)``, which
    reseeds and starts again from the beginning. This is meant for a stream
    (e.g. validation) that should repeat per epoch. With ``epoch_size=None``
    the stream is never restarted.

    Raises
    ------
    ValueError
        If ``epoch_size`` is not None and < 1. Such a value never matched the
        restart condition, so the stream silently never restarted.
    """
    if epoch_size is not None and epoch_size < 1:
        raise ValueError('epoch_size must be >= 1 or None, got {}'.format(epoch_size))

    while True:
        data_gen = data_generator(data_dir, **kwargs)
        n_yielded = 0
        try:
            for item in data_gen:
                yield item
                n_yielded += 1
                # Once we generate all batches for an epoch, restart the generator
                if n_yielded == epoch_size:
                    break
        finally:
            # Close now, rather than at garbage collection, so the abandoned
            # inner generator releases whatever it holds open.
            data_gen.close()

        if n_yielded == 0:
            # The inner generator produced nothing; restarting would spin forever.
            return

In [None]:
# Experiment configuration: audio sample rate, epoch count, data locations and
# per-split batch / epoch sizes.
samp_rate = 48000
num_epochs = 1

train_data_dir = '/beegfs/work/AudioSetSamples/music_train'
train_batch_size = 64
train_epoch_size = 64

validation_data_dir = '/beegfs/work/AudioSetSamples/music_valid'
validation_batch_size = 32
validation_epoch_size = 64

# Training stream: one endless data_generator, reshaped into Keras
# (inputs, target) tuples with inputs ['video', 'audio'] and target 'label'.
train_gen = pescador.maps.keras_tuples(
    data_generator(train_data_dir, batch_size=train_batch_size, samp_rate=samp_rate),
    ['video', 'audio'],
    'label')

# Validation stream: a fresh data_generator every validation_epoch_size
# batches, reshaped the same way.
val_gen = pescador.maps.keras_tuples(
    single_epoch_data_generator(validation_data_dir,
                                validation_epoch_size,
                                batch_size=validation_batch_size,
                                samp_rate=samp_rate),
    ['video', 'audio'],
    'label')

In [None]:
# Fine-tune the Keras model from model_path as-is (no quantization rewrite).
# loss / metrics / learning_rate / model_path are reused by later cells.
loss = 'categorical_crossentropy'
metrics = ['accuracy']
learning_rate = 1e-5
model_path = '/scratch/sk7898/l3pruning/embedding/fixed/reduced_input/l3_audio_original_48000_256_252_2048.h5'

model = keras.models.load_model(model_path)
model.compile(Adam(lr=learning_rate), loss=loss, metrics=metrics)

history = model.fit_generator(
    train_gen,
    train_epoch_size,   # steps_per_epoch
    num_epochs,         # epochs
    validation_data=val_gen,
    validation_steps=validation_epoch_size,
    verbose=True,
    initial_epoch=0,
)

In [None]:
# Load and quantize the model (without melspectrogram) and run
# quantization-aware fine-tuning in a dedicated graph/session.
# splitext removes exactly the extension. `.strip('.h5')` removed any
# leading/trailing '.', 'h' or '5' characters, mangling e.g. "..._2045.h5".
model_name = os.path.splitext(os.path.basename(model_path))[0]
# Checkpoint prefix for the quantization-aware model.
out_path = os.path.join("/scratch/sk7898/quantization", model_name, "checkpoints")

train_graph = tf.Graph()
train_sess = tf.Session(graph=train_graph)

keras.backend.set_session(train_sess)
with train_graph.as_default():
    # compile=False: the graph must be rewritten before any optimizer/gradient
    # ops exist. The model is compiled below, after the rewrite.
    model = keras.models.load_model(model_path, compile=False)

    # Insert fake-quantization ops.
    # NOTE(review): quant_delay is measured in global steps. This run is only
    # train_epoch_size * num_epochs = 64 steps, and Keras optimizers may not
    # advance tf's global step at all. Quantization may therefore never activate
    # here; confirm the intended delay.
    tf.contrib.quantize.create_training_graph(input_graph=train_graph, quant_delay=100)

    # Initialize ONLY variables that are still uninitialized (e.g. the
    # quantization min/max variables just added). A blanket
    # global_variables_initializer() would overwrite the pretrained weights that
    # load_model just wrote into train_sess.
    global_vars = tf.global_variables()
    is_initialized = train_sess.run([tf.is_variable_initialized(v) for v in global_vars])
    uninitialized_vars = [v for v, ok in zip(global_vars, is_initialized) if not ok]
    train_sess.run(tf.variables_initializer(uninitialized_vars))

    model.compile(Adam(lr=learning_rate), loss=loss, metrics=metrics)
    history = model.fit_generator(
        train_gen,
        train_epoch_size,   # steps_per_epoch
        num_epochs,         # epochs
        validation_data=val_gen,
        validation_steps=validation_epoch_size,
        verbose=True,
        initial_epoch=0,
    )

In [None]:
# Save the quantized model (without melspectrogram) for freezing and TFLite:
# write the checkpoint, then the *eval* graph proto.
# Uses train_graph, train_sess, out_path and model_path, which are defined by
# the quantization-aware training cell and the cell before it. The names this
# cell used before (g, sess, eval_graph_file, checkpoint_name) were never
# defined, so it failed with NameError on a fresh kernel.
checkpoint_name = out_path
eval_graph_file = out_path + '_eval_graph.pbtxt'
os.makedirs(os.path.dirname(checkpoint_name), exist_ok=True)

# 1) Checkpoint the fine-tuned weights together with the fake-quant variables.
#    The Saver must be built inside train_graph to see its variables.
with train_graph.as_default():
    tf.train.Saver().save(train_sess, checkpoint_name)

# 2) Rebuild the same model in a fresh graph in inference mode, and rewrite it
#    with create_eval_graph (the inference counterpart of create_training_graph).
eval_graph = tf.Graph()
eval_sess = tf.Session(graph=eval_graph)
keras.backend.set_session(eval_sess)
with eval_graph.as_default():
    keras.backend.set_learning_phase(0)
    eval_model = keras.models.load_model(model_path, compile=False)
    tf.contrib.quantize.create_eval_graph(input_graph=eval_graph)

    # Text-format GraphDef, for freezing and providing to TFLite.
    with open(eval_graph_file, 'w') as f:
        f.write(str(eval_graph.as_graph_def()))

    # Load the quantization-aware weights into the eval graph.
    # NOTE(review): assumes variable names match between the training and eval
    # rewrites; confirm if restore reports missing keys.
    tf.train.Saver().restore(eval_sess, checkpoint_name)

In [None]:
# Convert the Keras model file to a TFLite flatbuffer.
model_path = '/scratch/sk7898/l3pruning/embedding/fixed/reduced_input/l3_audio_original_48000_256_252_2048.h5'
# splitext removes exactly the extension. `.strip('.h5')` would also eat any
# leading/trailing '.', 'h' or '5' characters of the name itself.
model_name = os.path.splitext(os.path.basename(model_path))[0]
tflite_model_file = os.path.join("/scratch/sk7898/quantization",
                                 "quantized_" + model_name + ".tflite")

# NOTE(review): no quantization option is set on the converter, so this
# presumably emits a float model despite the "quantized_" file prefix.
# Confirm the intent (e.g. enable post-training quantization for this TF
# version).
converter = tf.contrib.lite.TFLiteConverter.from_keras_model_file(model_path)
tflite_model = converter.convert()

os.makedirs(os.path.dirname(tflite_model_file), exist_ok=True)
with open(tflite_model_file, "wb") as f:
    f.write(tflite_model)

In [None]:
#Evaluate AVC
