In [None]:
import tensorflow as tf

### Consume the TFRecord data

*  https://www.tensorflow.org/programmers_guide/datasets
*  https://github.com/Tony607/Keras_catVSdog_tf_estimator/blob/master/keras_estimator_vgg16-cat_vs_dog-TFRecord.ipynb

In [None]:
# Paths of the three source mp3 chapters (numbered 01..03).
filenames = ['./librivox/guidetomen_%02d_rowland_64kb.mp3' % chapter
             for chapter in range(1, 4)]
filenames

In [None]:
# Number of time steps per example, and the lead-in length
# (presumably the warm-up context the model needs -- TODO confirm where it is used).
steps_total = 1024
steps_leadin = 64

# Channel counts of the mel input and of each spectra component.
mel_bins = 80
spectra_bins = 1025

# Training-loop settings.
batch_size = 8
num_epochs = 10

In [None]:
def dataset_from_mp3(filenames, stub='train'):
    """Build a `tf.data` dataset of (mel, spectra) pairs from TFRecord files.

    Every `xxx.mp3` path in `filenames` is mapped to `xxx_<stub>.tfrecords`,
    and each serialized example in those files is parsed into two tensors.

    Args:
        filenames: iterable of mp3 paths; only used to derive the record paths.
        stub: suffix selecting which record file to read (e.g. 'train').

    Returns:
        A `tf.data.Dataset` whose elements are `(mel, spectra)` with shapes
        `(steps_total, mel_bins)` and `(steps_total, spectra_bins, 2)`; the
        last spectra axis holds (real, imag).
    """
    # Format only the suffix, so a '%' elsewhere in the path cannot break the
    # string interpolation.
    record_files = [f.replace('.mp3', '_%s.tfrecords' % stub) for f in filenames]
    dataset = tf.data.TFRecordDataset(record_files)

    features = {
      "mel":          tf.FixedLenFeature([mel_bins*steps_total], tf.float32),
      "spectra_real": tf.FixedLenFeature([spectra_bins*steps_total], tf.float32),
      "spectra_imag": tf.FixedLenFeature([spectra_bins*steps_total], tf.float32),
    }

    def _parse_function(example_proto):
        parsed_features = tf.parse_single_example(example_proto, features)
        # tf.Tensor has no .reshape() method -- tf.reshape() is required.
        # NOTE(review): assumes the records were flattened from
        # (bins, steps_total) arrays -- confirm against the writer.
        mel          = tf.reshape(parsed_features["mel"], (mel_bins, steps_total))
        spectra_real = tf.reshape(parsed_features["spectra_real"], (spectra_bins, steps_total))
        spectra_imag = tf.reshape(parsed_features["spectra_imag"], (spectra_bins, steps_total))
        # Stack on the LAST axis to get (spectra_bins, steps_total, 2); the
        # default axis=0 would give (2, spectra_bins, steps_total) instead.
        spectra = tf.stack([spectra_real, spectra_imag], axis=-1)
        # Put time first: (steps_total, mel_bins) and (steps_total, spectra_bins, 2)
        return tf.transpose(mel), tf.transpose(spectra, perm=[1,0,2])

    dataset = dataset.map(_parse_function)
    return dataset

def input_fn_from(filenames, stub='train', batch_size=1, shuffle=False, repeats=1):
    """Create the (features, labels) batch tensors for one pass over the records.

    Args:
        filenames: mp3 paths, forwarded to `dataset_from_mp3`.
        stub: which record-file suffix to read (forwarded to `dataset_from_mp3`).
        batch_size: number of examples per batch.
        shuffle: when true, shuffle examples with a 100-element buffer
            before batching.
        repeats: how many times the batched dataset is repeated.

    Returns:
        `(batch_features, batch_labels)` -- the tensors produced by the
        one-shot iterator's `get_next()`.
    """
    dataset = dataset_from_mp3(filenames, stub=stub)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=100)
    # Was `dataset_train`, which is not defined anywhere (NameError).
    dataset = dataset.batch(batch_size).repeat(repeats)

    # Was bound to `iterator_train` but read back as `iterator` (NameError).
    iterator = dataset.make_one_shot_iterator()

    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels

### Keras model 

Keras-WaveNet : 


*  https://github.com/usernaamee/keras-wavenet/blob/master/simple-generative-model-regressor.py 
*  Beware of the GPL3 license!  Write the code independently rather than copying from it...
*  Also, the number of convolutional channels there seems fairly arbitrary, so it may not be fully fleshed out yet

TF-WaveNet :
*  https://github.com/ibab/tensorflow-wavenet
*  MIT license : Feel free to look + adapt


In [None]:
# See : https://github.com/tensorflow/tensorflow/issues/14933
#   to understand how broken Google is

from tensorflow.python import keras
from tensorflow.python.keras import backend as K

# Use 'real keras' to get the actual documented functionality for padding='causal'
#import keras
#from keras import backend as K

def wavenet_layer(channels, hidden_channels, kernel_size, dilation_rate, name):
    """Return a function that applies one gated, dilated WaveNet-style layer.

    Args:
        channels: channel count of the residual and skip outputs (the input
            must have the same count, since it is added to the residual path).
        hidden_channels: channel count of the filter/gate convolutions.
        kernel_size: kernel width of the dilated convolutions.
        dilation_rate: dilation of the filter/gate convolutions.
        name: suffix appended to every explicitly named sub-layer.

    Returns:
        A callable mapping `input_` to `(residual_out, skip_out)`.
    """
    # Settings shared by the two dilated ('valid') convolutions ...
    dilated_kwargs = dict(strides=1, dilation_rate=dilation_rate,
                          padding='valid', use_bias=True)
    # ... and by the two 1x1 projections.
    pointwise_kwargs = dict(padding='same', use_bias=True)
    # A 'valid' dilated convolution drops this many time steps.
    lost_steps = dilation_rate*(kernel_size-1)

    def apply_layer(input_):
        filter_out = keras.layers.Conv1D(hidden_channels, kernel_size,
                                         activation='tanh', name='filter_'+name,
                                         **dilated_kwargs)(input_)
        gate_out = keras.layers.Conv1D(hidden_channels, kernel_size,
                                       activation='sigmoid', name='gate_'+name,
                                       **dilated_kwargs)(input_)
        gated = keras.layers.Multiply(name='mult_'+name)( [filter_out, gate_out] )

        # Zero-pad on the left only, restoring the input's time length.
        # https://www.tensorflow.org/api_docs/python/tf/keras/layers/ZeroPadding1D
        gated_padded = keras.layers.ZeroPadding1D( (lost_steps, 0) )(gated)

        transformed = keras.layers.Conv1D(channels, 1,
                                          activation='linear', name='trans_'+name,
                                          **pointwise_kwargs)(gated_padded)
        skip_out = keras.layers.Conv1D(channels, 1,
                                       activation='relu', name='skip_'+name,
                                       **pointwise_kwargs)(gated_padded)

        residual_out = keras.layers.Add(name='resid_'+name)( [transformed, input_] )
        return residual_out, skip_out

    return apply_layer

# Layer computing log(max(0.00001, x)); the floor keeps the log finite for
# zero (or negative) inputs.
log_amplitude_with_minimum = keras.layers.Lambda(
    lambda t: K.log( K.maximum(0.00001, t) )
)

# Channel counts at the layer interfaces and inside the gated convolutions.
io_channels = 128
hidden_channels = 128
def model_mel_to_spec( input_shape=(steps_total, mel_bins) ):
    """Build the Keras model mapping a mel input to amplitude and phase outputs.

    Args:
        input_shape: per-example shape of the mel input, (time steps, mel bins).

    Returns:
        A `keras.models.Model` with one input ('MelInput') and two outputs,
        `[amp, phase]`, each with `spectra_bins` channels per time step.
    """
    mel = keras.layers.Input(shape=input_shape, name='MelInput')

    # Floored log of the mel input: (batch, T, channels)
    mel_log = log_amplitude_with_minimum(mel)

    # 'Resize' to make everything 'io_channels' big at the layer interfaces
    x = keras.layers.Conv1D(io_channels, 1,
                            padding='same', use_bias=True,
                            activation='linear', name='mel_log_expanded')(mel_log)

    # Stack of dilated layers named '1'..'5'; only the skip outputs are used
    # downstream, the residual output of the last layer is discarded.
    # Total footprint is ~64 steps (0.75secs according to the original note).
    skips = []
    for index, dilation in enumerate((1, 2, 4, 8, 16), start=1):
        x, skip = wavenet_layer(io_channels, hidden_channels*1, 3, dilation, str(index))(x)
        skips.append(skip)

    skip_overall = keras.layers.Concatenate( axis=-1 )( skips )

    log_amp = keras.layers.Conv1D(spectra_bins, 1, padding='same',
                                  activation='linear', name='log_amp')(skip_overall)
    phase   = keras.layers.Conv1D(spectra_bins, 1, padding='same',
                                  activation='linear', name='phase')(skip_overall)

    # Exponentiate inside a layer so the result stays a Keras tensor.
    amp = keras.layers.Lambda( lambda t: K.exp(t), name='amp')(log_amp)

    # TODO: possibly emit spec_real = amp*cos(phase) and spec_imag = amp*sin(phase)
    # instead of [amp, phase].
    return keras.models.Model(inputs=[mel], outputs=[amp, phase])

# Build the model with the default input shape and print its layer table.
keras_model = model_mel_to_spec()
keras_model.summary()

In [None]:
# Name of the model's first (and only) input layer.
input_name = keras_model.input_names[0]
input_name

In [None]:
def TODO_customLoss(spec_gold, spec_out):
    """Log-amplitude loss between a gold complex spectrum and a prediction.

    Still a work in progress (hence the name); the previous version could not
    run because its return referenced the undefined names `yTrue` / `yPred`.

    Args:
        spec_gold: gold spectra, indexed as `[..., 0]` = real and
            `[..., 1]` = imaginary part.
        spec_out: model output.
            NOTE(review): assumed to be the predicted *amplitude* (the 'amp'
            output of the model) -- confirm once the loss is wired up.

    Returns:
        Scalar tensor: sum of squared differences of the floored log-amplitudes.
        NOTE(review): the squared-difference form is a guess at the intent; the
        original `sum(log(true) - log(pred))` is unbounded below as a loss.
    """
    # Convert the stacked spectra components in spec_gold to an amplitude
    gold_real = spec_gold[..., 0]
    gold_imag = spec_gold[..., 1]
    gold_amp = K.sqrt(K.square(gold_real) + K.square(gold_imag))

    # Floor before the log so zeros cannot produce -inf / NaN.
    gold_log_amp = K.log(K.maximum(0.00001, gold_amp))
    pred_log_amp = K.log(K.maximum(0.00001, spec_out))

    return K.sum(K.square(gold_log_amp - pred_log_amp))

# Plain mean-squared-error for now, both as loss and as reported metric.
keras_model.compile(
    optimizer=keras.optimizers.RMSprop(lr=2e-5),
    loss='mse',
    metrics=['mse'],
)

In [None]:
import os
# Directory handed to the estimator as its model_dir:
# <cwd>/models/mel-to-complex-spectra
model_dir = os.path.join(
    os.getcwd(),
    'models',
    'mel-to-complex-spectra',
)
os.makedirs(model_dir, exist_ok=True)
print("model_dir: ", model_dir)

# Wrap the compiled Keras model as a tf.estimator.Estimator.
estimator = tf.keras.estimator.model_to_estimator(
    keras_model=keras_model,
    model_dir=model_dir,
)