## Install Transformers and Datasets

In [None]:
!pip install transformers
!pip install datasets

## Generate the TF SavedModel

In [None]:
import tensorflow as tf

from datasets import load_dataset
from transformers import WhisperProcessor, WhisperFeatureExtractor, TFWhisperForConditionalGeneration, WhisperTokenizer

# Checkpoint and export location used throughout this cell.
MODEL_ID = "openai/whisper-small.en"
saved_model_dir = '/content/tf_whisper_saved'

feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL_ID)
tokenizer = WhisperTokenizer.from_pretrained(MODEL_ID, predict_timestamps=True)
processor = WhisperProcessor(feature_extractor, tokenizer)
model = TFWhisperForConditionalGeneration.from_pretrained(MODEL_ID)

# Loading dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# Index the dataset once: `datasets` decodes audio lazily on row access, so
# reusing one row avoids decoding the same file twice (waveform + sampling rate).
sample_audio = ds[0]["audio"]

inputs = feature_extractor(
    sample_audio["array"], sampling_rate=sample_audio["sampling_rate"], return_tensors="tf"
)
input_features = inputs.input_features

# Sanity check: transcribe the sample with the eager TF model before exporting it.
generated_ids = model.generate(input_features=input_features)
print(generated_ids)
transcription = processor.tokenizer.decode(generated_ids[0])
print(transcription)
model.save(saved_model_dir)

## Convert the SavedModel to a TFLite model

In [None]:
import tensorflow as tf

# Input: the SavedModel written by `model.save(...)` in the previous cell.
# Output: a TFLite flatbuffer on local disk.
saved_model_dir = '/content/tf_whisper_saved'
tflite_model_path = 'whisper.tflite'

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
# Let the converter apply its default optimizations.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Prefer TFLite builtin kernels, but fall back to full TensorFlow ops for
# anything that has no builtin equivalent.
converter.target_spec.supported_ops = [
  tf.lite.OpsSet.TFLITE_BUILTINS,
  tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()

# Persist the serialized model bytes.
with open(tflite_model_path, 'wb') as tflite_file:
    tflite_file.write(tflite_model)

## Create a generation-enabled TFLite model

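The model exported above captures only the model's forward pass, not the autoregressive `generate()` loop. The solution is to define a module whose serving function *is* the generation call, and to export that module instead. Here's an example of how to do it: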

In [None]:
class GenerateModel(tf.Module):
  """Wrap a TF model so that its exported serving signature runs `generate()`.

  Exporting this module, rather than the Keras model itself, puts the whole
  decoding loop into the saved graph. The converted TFLite model can then map
  input features directly to generated token ids.

  Args:
    model: a TF model exposing `generate(...)` (here
      `TFWhisperForConditionalGeneration`).
    max_new_tokens: cap on tokens generated per call. It is read when `serving`
      is traced, so set it before exporting. Defaults to the notebook's
      original value of 450.
  """

  def __init__(self, model, max_new_tokens=450):
    super().__init__()
    self.model = model
    self.max_new_tokens = max_new_tokens

  @tf.function(
    # shouldn't need static batch size, but throws exception without it (needs to be fixed)
    input_signature=[
      tf.TensorSpec((1, 80, 3000), tf.float32, name="input_features"),
    ],
  )
  def serving(self, input_features):
    """Generate token ids for one example.

    Args:
      input_features: float32 tensor of shape (1, 80, 3000), as fixed by the
        input signature.

    Returns:
      A dict whose "sequences" entry holds the generated token ids.
    """
    # NOTE(review): Whisper decoders are typically configured with a fixed
    # max_target_positions (check model.config). A budget above that minus the
    # prompt length may be unreachable or index past the position embeddings,
    # so confirm before raising max_new_tokens.
    outputs = self.model.generate(
      input_features,
      max_new_tokens=self.max_new_tokens,
      return_dict_in_generate=True,
    )
    return {"sequences": outputs["sequences"]}

# Export the generation wrapper to its own directory so it does not overwrite
# the plain SavedModel written by `model.save(...)` earlier in the notebook.
# Otherwise, re-running the earlier conversion cell would silently convert this
# graph instead.
saved_model_dir = '/content/tf_whisper_generate_saved'
# NOTE(review): the checkpoint loaded earlier is whisper-small.en, but the output
# is named "whisper-tiny.en.tflite". The inference cell below reads this same
# name, so rename both together if you change it.
tflite_model_path = 'whisper-tiny.en.tflite'

generate_model = GenerateModel(model=model)
# Register `serving` under the "serving_default" signature key so the converter
# exports the generate() graph rather than the plain forward pass.
tf.saved_model.save(generate_model, saved_model_dir, signatures={"serving_default": generate_model.serving})

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
  tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
  tf.lite.OpsSet.SELECT_TF_OPS # fall back to TensorFlow ops where no TFLite builtin exists.
]
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the model
with open(tflite_model_path, 'wb') as f:
    f.write(tflite_model)

In [None]:
# Reload the generation-enabled TFLite model and run it end to end. The exported
# signature wraps generate(), so one call turns input features into token ids.
tflite_model_path = 'whisper-tiny.en.tflite'
interpreter = tf.lite.Interpreter(tflite_model_path)

# No signature key is passed, so the interpreter's default signature is used.
tflite_generate = interpreter.get_signature_runner()
tflite_outputs = tflite_generate(input_features=input_features)
generated_ids = tflite_outputs["sequences"]

# Drop special tokens and keep the single example in the batch.
decoded_batch = processor.batch_decode(generated_ids, skip_special_tokens=True)
transcription = decoded_batch[0]
transcription