In [1]:
import tensorflow as tf
import magenta_model as mm

In [2]:
def convert(model, name="model.tflite"):
    """Convert a Keras model to a TensorFlow Lite flatbuffer and save it.

    Args:
        model: Keras model handed to
            ``tf.lite.TFLiteConverter.from_keras_model``.
        name: Path of the output file. It is opened in ``'wb'`` mode, so an
            existing file at that path is overwritten.

    Returns:
        None. The serialized model is only written to ``name``.
    """
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()
    # Use a context manager so the handle is flushed and closed
    # deterministically, even if the write raises, instead of leaving the
    # anonymous file object to be cleaned up by the garbage collector.
    with open(name, 'wb') as tflite_file:
        tflite_file.write(tflite_model)
    
def num_frames(duration, sample_rate, win_size, hop_size):
    """Return how many complete analysis windows fit into a signal.

    The signal length in samples is ``duration * sample_rate``. The first
    window needs ``win_size`` samples and every further window needs another
    ``hop_size`` samples.

    Args:
        duration: Signal length, in the unit that ``sample_rate`` is
            expressed per (seconds when ``sample_rate`` is in Hz).
        sample_rate: Number of samples per unit of ``duration``.
        win_size: Window length in samples.
        hop_size: Step between consecutive windows in samples.

    Returns:
        The number of complete windows as a non-negative ``int``. A signal
        shorter than one window yields 0.

    Raises:
        ZeroDivisionError: If ``hop_size`` is 0.
    """
    signal_samples = duration * sample_rate
    frames = int(1 + (signal_samples - win_size) / hop_size)
    # Without the clamp, a signal more than one hop shorter than the window
    # produces a negative frame count, which is meaningless as a size.
    return max(0, frames)

def total_samples(frames, hop_size, win_size):
    """Return the sample count spanned by ``frames`` overlapping windows.

    One full window of ``win_size`` samples is counted once, and each of the
    remaining ``frames - 1`` windows adds ``hop_size`` samples.

    Args:
        frames: Number of windows.
        hop_size: Step between consecutive windows in samples.
        win_size: Window length in samples.

    Returns:
        The sample count truncated to ``int``.
    """
    hop_advance = hop_size * (frames - 1)
    span = win_size + hop_advance
    return int(span)

In [3]:
# Build the model returned by mm.waveform_to_melspectrum with win_size=2048,
# fft_size=2048 and mels=128.
model = mm.waveform_to_melspectrum(win_size=2048, fft_size=2048, mels=128)
# Export it to "melspectrum.tflite" via convert(), then print the layer summary.
convert(model,"melspectrum.tflite")
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input (InputLayer)              [(1, 2048)]          0                                            
__________________________________________________________________________________________________
tf_op_layer_Mul (TensorFlowOpLa [(1, 2048)]          0           input[0][0]                      
__________________________________________________________________________________________________
tf_op_layer_Pad (TensorFlowOpLa [(1, 2048)]          0           tf_op_layer_Mul[0][0]            
__________________________________________________________________________________________________
tf_op_layer_MatMul (TensorFlowO [(1, 1025)]          0           tf_op_layer_Pad[0][0]            
______________________________________________________________________________________________

In [4]:
# Fixed batch dimension used for the Keras input (and therefore the export).
tflite_batchsize = 1

# Per-example input shape: a leading axis of size 1, then the sample count
# that total_samples() gives for 128 frames at hop 1024 and window 2048.
input_shape = (1, total_samples(128, 1024, 2048))

# input_shape is already a tuple, so it can be appended to the batch axis
# directly.
x = tf.keras.layers.Input(
    batch_shape=(tflite_batchsize,) + input_shape,
    dtype='float32',
    name='input',
)
print("input_shape:", x.shape)

# The hop/window sizes passed here match the ones used to size the input above.
y = mm.waveform_to_mel_spectrogram(
    x,
    sample_rate=48000,
    win_hop=1024,
    win_length=2048,
    fft_size=2048,
    num_mel_bins=128,
    lower_edge_hertz=80,
    upper_edge_hertz=20000,
)

# Wrap the graph in a Keras model, export it to TFLite and show its layers.
model = tf.keras.Model(inputs=x, outputs=[y])
convert(model, "melspectrogram.tflite")
model.summary()

input_shape: (1, 1, 132096)
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input (InputLayer)              [(1, 1, 132096)]     0                                            
__________________________________________________________________________________________________
tf_op_layer_Slice (TensorFlowOp [(1, 1, 132096)]     0           input[0][0]                      
__________________________________________________________________________________________________
tf_op_layer_Reshape (TensorFlow [(1, 1, 129, 1024)]  0           tf_op_layer_Slice[0][0]          
__________________________________________________________________________________________________
tf_op_layer_GatherV2 (TensorFlo [(1, 1, 128, 2, 1024 0           tf_op_layer_Reshape[0][0]        
________________________________________________________________