In [278]:
import tensorflow as tf
import tensorflow_hub as hub ##for loading full tf model


import numpy as np
import librosa  
import soundfile as sf ##may not be needed
import os, json, re
from itertools import groupby

from scipy.io import wavfile

# Model Conversion
The only speech recognition TFLite model currently available is not very good, so we can try converting a better model to TFLite.  
I'll try [this one](https://tfhub.dev/vasudevgupta7/wav2vec2-960h/1), which is a fine-tuned wav2vec2 model.

In [112]:
# Load the full TensorFlow wav2vec2 model from TensorFlow Hub.
# This took about 1m 39s when the cell was last run.
wav2vec_url = "https://tfhub.dev/vasudevgupta7/wav2vec2-960h/1"
full_model = hub.load(wav2vec_url)

In [113]:
# Export the hub model in the SavedModel format, which is the recommended format
# to convert from. The 'tf_wave2vec' directory written here is the one the
# converter cell reads back (as './tf_wave2vec/').
tf.saved_model.save(full_model, 'tf_wave2vec')  # took about 1m 24s



INFO:tensorflow:Assets written to: wave2vec\assets


INFO:tensorflow:Assets written to: wave2vec\assets


In [114]:
# Convert the SavedModel written to './tf_wave2vec/' into a TFLite model.
# This took about 2m 30s when the cell was last run.
converter = tf.lite.TFLiteConverter.from_saved_model('./tf_wave2vec/')
# `tflite_model` holds the converted model; it is written to disk in binary mode below.
tflite_model = converter.convert()

In [115]:
# Write the converted model to 'wave2vec2-960h.tflite' in the working directory.
# Binary mode ('wb') is used because `tflite_model` is written as raw bytes.
with open('wave2vec2-960h.tflite', 'wb') as f:
  f.write(tflite_model)

In [343]:
# TODO: remove the intermediate SavedModel directory ('tf_wave2vec') once the .tflite file has been written.

# Testing the model

In [437]:
# Path of the converted TFLite model that the interpreter loads.
model_path = "./wave2vec2-960h.tflite"

# Test recording to transcribe.
wav_path = "./test_audio/recording.wav"
# Sample rate (Hz) passed to librosa.load as `sr` when the recording is read.
REQUIRED_SAMPLE_RATE = 16000
# Sequence length (in samples) quoted for the hub model; the same value also
# appears in normalize_pad as the padding target.
MAX_LENGTH = 246000

In [438]:
# Read the test recording as mono audio at REQUIRED_SAMPLE_RATE.
signal, sample_rate = librosa.load(wav_path, sr=REQUIRED_SAMPLE_RATE, mono=True)
# Display the number of samples loaded.
signal.shape

(153920,)

In [431]:
# Load the converted model and allocate its tensors before inspecting or invoking it.
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Show the input/output tensor metadata, separated by a blank line.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(f"{input_details = }", "", f"{output_details = }", sep="\n")

input_details = [{'name': 'serving_default_input_1:0', 'index': 0, 'shape': array([    1, 50000]), 'shape_signature': array([   -1, 50000]), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]

output_details = [{'name': 'StatefulPartitionedCall:0', 'index': 1347, 'shape': array([ 1,  1, 32]), 'shape_signature': array([-1, -1, 32]), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]


In [439]:
def normalize_pad(x, pad=True, max_length=246000):
  """
  Normalize and (optionally) zero-pad an input signal to match the model's preprocessing.
  Methodology from: https://github.com/thevasudevgupta/gsoc-wav2vec2/blob/main/src/wav2vec2/processor.py

  Args:
      x: audio samples as a 1-D array-like. Singleton dimensions are squeezed
          away after normalization, so a (1, n) input also yields a 1-D result.
      pad: when True, append zeros so the result holds at least `max_length` samples.
      max_length: padding target in samples. Defaults to 246000, the value
          quoted at https://tfhub.dev/vasudevgupta7/wav2vec2-960h/1

  Returns:
      The signal shifted to zero mean and scaled by sqrt(variance + 1e-5),
      followed by trailing zeros when padding was requested and needed.
      A signal already longer than `max_length` is returned at its own length
      (previously this case raised ValueError from a negative padding size).
  """
  x = np.asarray(x)
  # normalize: statistics are taken over the last axis (the sample axis)
  mean = np.mean(x, axis=-1, keepdims=True)
  var = np.var(x, axis=-1, keepdims=True)
  # the 1e-5 term keeps the division finite for constant (zero-variance) input
  x = np.squeeze((x - mean) / np.sqrt(var + 1e-5))
  # pad
  if pad:
    missing = max_length - x.shape[0]
    if missing > 0:
      # pad in the signal's own dtype so float32 audio is not upcast to float64
      x = np.concatenate((x, np.zeros(missing, dtype=x.dtype)))
  return x

def resize_input_seq(interpreter, speech):
  """
  Fit the signal to the sequence length the model's first input expects.

  The signal is flattened, then truncated when it is longer than the model
  input or zero-padded at the end when it is shorter. (np.resize, used
  previously, fills a too-short signal by repeating it from the start, which
  feeds duplicated audio to the model.)

  Args:
      interpreter: object whose get_input_details()[0]['shape'] is a
          two-element (batch, seq_length) shape.
      speech: array-like of audio samples.

  Returns:
      Array of shape (1, seq_length) with the same dtype as `speech`.
  """
  _, seq_length = interpreter.get_input_details()[0]['shape']
  seq_length = int(seq_length)
  samples = np.ravel(speech)
  resized = np.zeros((1, seq_length), dtype=samples.dtype)
  keep = min(samples.shape[0], seq_length)
  resized[0, :keep] = samples[:keep]
  return resized

def set_input_tensor(interpreter, speech):
  """Copy `speech` into the first row of the interpreter's first input tensor, in place."""
  first_input = interpreter.get_input_details()[0]
  # interpreter.tensor(index) hands back a callable; calling it gives the buffer to write into
  buffer = interpreter.tensor(first_input['index'])()
  buffer[0][:] = speech

def classify_speech(interpreter, speech):
  """
  Run the model on a raw signal and return the top-scoring id per output position.

  The signal is normalized and padded, fitted to the model's input length,
  written into the input tensor, and the interpreter is invoked. The first
  output tensor is then squeezed and reduced with argmax over its last axis.
  """
  prepared = resize_input_seq(interpreter, normalize_pad(speech, pad=True))
  set_input_tensor(interpreter, prepared)
  interpreter.invoke()

  scores_index = interpreter.get_output_details()[0]['index']
  scores = np.squeeze(interpreter.get_tensor(scores_index))
  best_ids = np.argmax(scores, axis=-1)
  return np.squeeze(best_ids)



# Run the TFLite model on the loaded recording; `output` holds the predicted id per output position.
output = classify_speech(interpreter, signal)

In [440]:
output

array([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0, 24,  0,  5,  0,  0,  0, 15,  0,  0,  0,  8, 18,  0,  4,
        4,  4,  4,  4,  6, 11,  0,  0,  0,  5,  0,  0, 13,  5,  0,  4,  4,
        4,  4,  4,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0], dtype=int64)

In [441]:
def get_vocab(vocab_path):
    """
    Load the vocabulary stored as JSON at `vocab_path`.

    Args:
        vocab_path: path of the JSON vocabulary file.

    Returns:
        The parsed JSON content (the notebook uses it as a token -> id dict).
    """
    # Decode explicitly as UTF-8 instead of relying on the platform default
    # encoding, which would mis-read non-ASCII tokens on some systems.
    with open(vocab_path, "r", encoding="utf-8") as vocab_file:
        return json.load(vocab_file)


# Token -> id table read from the vocabulary file, and its inverse (id -> token)
# used when turning predicted ids back into text.
token_to_id_mapping = get_vocab("./vocab.json")
id_to_token_mapping = {v: k for k, v in token_to_id_mapping.items()}

# Fallback token used by decode() for ids missing from the vocabulary.
unk_token = "<unk>"
unk_id = token_to_id_mapping[unk_token]

# "|" is the delimiter token; decode() turns it into a space.
# NOTE(review): "dimiliter" is a misspelling of "delimiter"; the names are kept
# because decode() reads `dimiliter_token`.
dimiliter_token = "|"
dimiliter_id = token_to_id_mapping[dimiliter_token]

# Ids of the tokens that decode() drops when skip_special_tokens is True.
special_tokens = ["<pad>"]
special_ids = [token_to_id_mapping[k] for k in special_tokens]


In [446]:
def decode(input_ids: list, skip_special_tokens=True, group_tokens=True):
    """
    Turn a sequence of predicted ids back into a string.

    Args:
        input_ids (:obj: `list`):
            ids to convert to text.
        skip_special_tokens (:obj: `bool`, `optional`):
            If set, ids listed in `special_ids` (e.g. `<pad>`) are dropped.
        group_tokens (:obj: `bool`, `optional`):
            If set, each run of consecutive identical ids is collapsed to one
            id before special tokens are removed.

    Returns:
        The joined tokens with every delimiter token replaced by a space and
        surrounding whitespace stripped. Ids missing from the vocabulary are
        rendered as the unknown token.
    """
    ids = input_ids
    if group_tokens:
        # groupby yields (key, run) pairs; the key is the repeated id
        ids = [run_id for run_id, _ in groupby(ids)]
    if skip_special_tokens:
        ids = [i for i in ids if i not in special_ids]

    pieces = []
    for i in ids:
        token = id_to_token_mapping.get(i, unk_token)
        pieces.append(" " if token == dimiliter_token else token)
    return "".join(pieces).strip()

# Transcribe the predicted ids: collapse repeats, drop padding, join into text.
decode(output.tolist(), skip_special_tokens=True, group_tokens=True)

'BELOW THERE'

This model works much better, but it is considerably larger — too large for me to upload to GitHub.