Skip to content

Commit

Permalink
Let data prep work with models from the distillation pipeline, which …
Browse files Browse the repository at this point in the history
…don't have 'sample rate' as an argument.

PiperOrigin-RevId: 372942927
  • Loading branch information
joel-shor authored and Copybara-Service committed May 10, 2021
1 parent c88f1d4 commit 4042809
Showing 1 changed file with 9 additions and 7 deletions.
Expand Up @@ -58,8 +58,11 @@ def _tfexample_audio_to_npfloat32(ex, audio_key):

def _samples_to_embedding_tfhub(model_input, sample_rate, mod, output_key):
"""Run inference to map audio samples to an embedding."""
tf_out = mod(
tf.constant(model_input, tf.float32), tf.constant(sample_rate, tf.int32))
# Models either take 2 args (input, sample_rate) or 1 arg (input). Try both.
try:
tf_out = mod(model_input, sample_rate)
except ValueError:
tf_out = mod(model_input)
return np.array(tf_out[output_key])


Expand Down Expand Up @@ -92,11 +95,10 @@ def _samples_to_embedding_tflite(model_input, sample_rate, interpreter,
interpreter.resize_tensor_input(input_details[0]['index'], model_input.shape)
interpreter.allocate_tensors()
interpreter.set_tensor(input_details[0]['index'], model_input)
# TODO(joelshor): When `sample_rate` gets added to the TFLite API, add it
# here.
del sample_rate
# interpreter.set_tensor(input_details[1]['index'],
# np.array(sample_rate).astype(np.int32))
# Models either take 2 args (input, sample_rate) or 1 arg (input). Try both.
if len(input_details) > 1:
interpreter.set_tensor(input_details[1]['index'],
np.array(sample_rate).astype(np.int32))

interpreter.invoke()
embedding_2d = interpreter.get_tensor(
Expand Down

0 comments on commit 4042809

Please sign in to comment.