In [None]:
!pip install tensorflow==2.10.0
!pip install transformers

## Load TFWhisperForConditionalGeneration

In [None]:
import wave

import numpy as np
import tensorflow as tf
from datasets import load_dataset
from scipy.signal import resample
from transformers import WhisperProcessor, WhisperFeatureExtractor, TFWhisperForConditionalGeneration

# Load the processor and the TensorFlow model from the same checkpoint.
checkpoint = "openai/whisper-tiny"
processor = WhisperProcessor.from_pretrained(checkpoint)
model = TFWhisperForConditionalGeneration.from_pretrained(checkpoint)

# Decoder prompt ids for Korean transcription; displayed as the cell output.
forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="ko")
forced_decoder_ids

## Load sample data

In [None]:
# Read the WAV file and turn it into a mono float32 waveform sampled at
# `target_sample_rate`, which is what the feature-extraction step consumes.
with wave.open('/content/output.wav', 'rb') as wav_file:
    n_channels = wav_file.getnchannels()
    sample_width = wav_file.getsampwidth()
    sr = wav_file.getframerate()
    n_frames = wav_file.getnframes()
    print("channel:", n_channels)
    print("sample rate:", sr)
    print("frames:", n_frames)
    frames = wav_file.readframes(n_frames)

# The int16 decoding below is only valid for 16-bit PCM; fail loudly instead
# of silently reinterpreting bytes of another sample width.
if sample_width != 2:
    raise ValueError(f"expected 16-bit PCM audio, got {8 * sample_width}-bit samples")

audio_data = np.frombuffer(frames, dtype=np.int16)
if n_channels > 1:
    # WAV frames interleave the channels; average them into one mono track.
    audio_data = audio_data.reshape(-1, n_channels).mean(axis=1)

# Always produce float32 scaled to [-1, 1], so the dtype and range no longer
# depend on whether the resampling branch below runs.
audio_data = audio_data.astype(np.float32) / 32768.0

target_sample_rate = 16000
if sr != target_sample_rate:
    number_of_samples = round(len(audio_data) * float(target_sample_rate) / sr)
    audio_data = resample(audio_data, number_of_samples).astype(np.float32)

# Convert the waveform into model input features as NumPy arrays;
# normalization is requested through do_normalize=True.
feature_kwargs = dict(
    sampling_rate=target_sample_rate,
    return_tensors="np",
    do_normalize=True,
)
inputs = processor(audio_data, **feature_kwargs)

input_features = inputs.input_features
print(input_features.shape)

## Test generating a Korean transcription

In [None]:
# Sanity check in eager mode: run one encoder/decoder pass on the prompt and
# show the first prompt token followed by the greedy (argmax) prediction at
# every prompt position.
token_map = [50258, 50264, 50359, 50363] # korean, transcribe
decoder_input_ids = np.array([token_map])

encoder_outputs = model.model.encoder(input_features)
dout = model.model.decoder(input_ids=decoder_input_ids, encoder_hidden_states=encoder_outputs[0])
# `.weights` is a list; take the weight tensor itself (index 0), exactly as
# GenerateModel.serving does, instead of relying on tf.matmul packing the list.
lm_logits = tf.matmul(dout[0], model.get_output_embeddings().weights[0], transpose_b=True)
np.concatenate((decoder_input_ids[:,:1], np.argmax(lm_logits, axis = -1)), axis = -1)

## Save GenerateModel as a SavedModel

In [None]:
class GenerateModel(tf.Module):
    """Wrap a model with an in-graph greedy decoding loop so it can be exported.

    Args:
        model: Object exposing ``model.encoder``, ``model.decoder`` and
            ``get_output_embeddings()`` (the notebook passes a
            ``TFWhisperForConditionalGeneration``).
        max_length: Optional upper bound on the total sequence length (prompt
            included). When omitted, ``model.config.max_target_positions`` is
            used if the model has it; otherwise the loop is unbounded, as before.
    """

    def __init__(self, model, max_length=None):
        super().__init__()
        self.model = model

        config = getattr(model, "config", None)
        # End-of-sequence id: prefer the model's config, fall back to the
        # value that was previously hard-coded here.
        eos_token_id = getattr(config, "eos_token_id", None)
        self.eos_token_id = 50257 if eos_token_id is None else int(eos_token_id)
        # Length cap that guarantees the decoding loop terminates even when
        # the end-of-sequence token is never predicted.
        if max_length is None:
            max_length = getattr(config, "max_target_positions", None)
        self.max_length = None if max_length is None else int(max_length)

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[1, 80, 3000], dtype=tf.float32, name="input_features"),
        tf.TensorSpec(shape=[1, None], dtype=tf.int32, name="decoder_input_ids")
    ])
    def serving(self, input_features, decoder_input_ids):
        """Greedily extend ``decoder_input_ids`` until EOS or the length cap.

        Args:
            input_features: float32 tensor of shape [1, 80, 3000].
            decoder_input_ids: int32 tensor of shape [1, prompt_len] holding
                the decoder prompt.

        Returns:
            Dict with key ``"seq"`` holding the prompt followed by the
            generated token ids.
        """
        # The encoder runs once; only the decoder is re-run inside the loop.
        encoder_outputs = self.model.model.encoder(input_features)
        lookup = self.model.get_output_embeddings().weights[0]

        def condition(decoder_input_ids):
            # Keep going while the last token is not EOS ...
            keep_going = tf.not_equal(decoder_input_ids[0, -1], self.eos_token_id)
            if self.max_length is not None:
                # ... and the sequence is still shorter than the cap.
                keep_going = tf.logical_and(
                    keep_going, tf.shape(decoder_input_ids)[1] < self.max_length
                )
            return keep_going

        def body(decoder_input_ids):
            # No cache is used: the decoder is re-run on the whole prefix each
            # step and only the prediction for the last position is appended.
            dout = self.model.model.decoder(input_ids=decoder_input_ids, encoder_hidden_states=encoder_outputs[0])
            lm_logits = tf.matmul(dout[0], lookup, transpose_b=True)
            predicted_ids = tf.argmax(lm_logits, axis=-1, output_type=tf.int32)
            decoder_input_ids = tf.concat([decoder_input_ids, predicted_ids[:, -1:]], axis=-1)
            return decoder_input_ids

        # The sequence grows along axis 1, hence the relaxed shape invariant.
        decoder_input_ids = tf.while_loop(condition, body, [decoder_input_ids], shape_invariants=[tf.TensorShape([1, None])])
        # NOTE(review): the loop result is passed to tf.identity exactly as
        # before so the exported output keeps its shape; the decode cell
        # indexes it as output_data[0, 0]. Do not unwrap it without updating
        # that cell.
        return {"seq": tf.identity(decoder_input_ids)}

# Export the wrapper as a SavedModel, registering `serving` under the
# "serving_default" signature key.
saved_model_dir = 'tf_whisper'
generate_model = GenerateModel(model=model)
export_signatures = {"serving_default": generate_model.serving}
tf.saved_model.save(generate_model, saved_model_dir, signatures=export_signatures)


## Convert saved model to TFLite model


In [None]:
# Convert the SavedModel to a TFLite flatbuffer and write it to disk.
tflite_model_path = 'whisper_tiny_gen.tflite'

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)

# Permit both TFLite builtin ops and selected TensorFlow ops in the model.
ops_set = tf.lite.OpsSet
converter.target_spec.supported_ops = [ops_set.TFLITE_BUILTINS, ops_set.SELECT_TF_OPS]
converter.experimental_new_converter = True
converter.optimizations = [tf.lite.Optimize.DEFAULT]

tflite_model = converter.convert()

with open(tflite_model_path, 'wb') as tflite_file:
    tflite_file.write(tflite_model)

## Generate from TFLite model

In [None]:
import numpy as np
import tensorflow as tf

# Run the converted model through its "serving_default" signature.
interpreter = tf.lite.Interpreter(model_path=tflite_model_path)

# Shapes as stored in the model file, printed for inspection only.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details[0]['shape'])
print(input_details[1]['shape'])
print(output_details[0]['shape'])

# decoder token with language, task
token_map = [50258, 50264, 50359, 50363]
decoder_input_ids = np.array([token_map]).astype(np.int32)

# Bind inputs by signature name rather than by position in `input_details`:
# the order of that list is decided by the converter, so indexing it with
# 0/1 can silently swap the two tensors. The signature runner also resizes
# the variable-length prompt input and allocates tensors before invoking.
runner = interpreter.get_signature_runner("serving_default")
outputs = runner(
    input_features=np.asarray(input_features, dtype=np.float32),
    decoder_input_ids=decoder_input_ids,
)

# "seq" is the output key returned by GenerateModel.serving.
output_data = outputs["seq"]
output_data

In [None]:
# Decode the generated token ids into text, dropping special tokens.
generated_ids = output_data[0,0]
transcription = processor.decode(generated_ids, skip_special_tokens=True)
transcription