In [1]:
!pip install transformers



## Load TFWhisperForConditionalGeneration

In [7]:
from transformers import WhisperProcessor, WhisperFeatureExtractor, TFWhisperForConditionalGeneration
import wave
import numpy as np
from scipy.signal import resample
import tensorflow as tf

processor = WhisperProcessor.from_pretrained("openai/whisper-base")
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
forced_decoder_ids = processor.get_decoder_prompt_ids(language="ko", task="transcribe")
forced_decoder_ids # check token number

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
All PyTorch model weights were used when initializing TFWhisperForConditionalGeneration.

All the weights of TFWhisperForConditionalGeneration were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFWhisperForConditionalGeneration for predictions without further training.


[(1, 50264), (2, 50359), (3, 50363)]

## Load sample data

In [13]:
# wav open
with wave.open('output.wav', 'rb') as wav_file:
    print("channel:", wav_file.getnchannels())
    print("sample sample rate:", wav_file.getframerate())
    print("frames:", wav_file.getnframes())
    sr = wav_file.getframerate()
    frames = wav_file.readframes(wav_file.getnframes())
    audio_data = np.frombuffer(frames, dtype=np.int16)

# resample sample rate
target_sample_rate = 16000
if sr != 16000:
    number_of_samples = round(len(audio_data) * float(target_sample_rate) / sr)
    audio_data = resample(audio_data, number_of_samples).astype(np.float32)

# whisper input
inputs = processor(audio_data, sampling_rate=target_sample_rate,
                   return_tensors="np",
                   do_normalize = True)

input_features = inputs.input_features
print(input_features.shape)

channel: 1
sample sample rate: 16000
frames: 79872
(1, 80, 3000)


## Test generate Korean Transcribe

In [8]:
token_map = [50258, 50264, 50359, 50363] # korean, transcribe
decoder_input_ids = np.array([token_map])

# forward encoder
encoder_outputs = model.model.encoder(input_features)
# forward decoder
dout = model.model.decoder(input_ids= decoder_input_ids, encoder_hidden_states=encoder_outputs[0])
# matmul embedding
lm_logits = tf.matmul(dout[0], model.get_output_embeddings().weights, transpose_b=True)
np.concatenate((decoder_input_ids[:,:1], np.argmax(lm_logits, axis = -1)), axis = -1)

array([[50258, 50264, 50358, 50363,  9491]])

## GenerateModel save

In [9]:
class GenerateModel(tf.Module):
    def __init__(self, model):
        super(GenerateModel, self).__init__()
        self.model = model
    # input_signature encoder, decoder
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[1, 80, 3000], dtype=tf.float32, name="input_features"),
        tf.TensorSpec(shape=[1, None], dtype=tf.int32, name="decoder_input_ids")
    ])
    # serving
    def serving(self, input_features, decoder_input_ids):
        encoder_outputs = self.model.model.encoder(input_features)
        lookup = self.model.get_output_embeddings().weights[0]

        def condition(decoder_input_ids):
            return tf.not_equal(decoder_input_ids[0, -1], 50257)

        def body(decoder_input_ids):
            dout = self.model.model.decoder(input_ids=decoder_input_ids, encoder_hidden_states=encoder_outputs[0])
            lm_logits = tf.matmul(dout[0], lookup, transpose_b=True)
            predicted_ids = tf.argmax(lm_logits, axis=-1, output_type=tf.int32)
            decoder_input_ids = tf.concat([decoder_input_ids, predicted_ids[:, -1:]], axis=-1)
            return decoder_input_ids

        decoder_input_ids = tf.while_loop(condition, body, [decoder_input_ids], shape_invariants=[tf.TensorShape([1, None])])
        return {"seq": tf.identity(decoder_input_ids)}
# save tf model
saved_model_dir = 'tf_whisper'
generate_model = GenerateModel(model=model)
tf.saved_model.save(generate_model, saved_model_dir, signatures={"serving_default": generate_model.serving})


## Convert saved model to TFLite model


In [16]:
tflite_model_path = 'whisper_base_gen.tflite'

# convert model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS
]
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.experimental_new_converter = True

tflite_model = converter.convert()
# save tflite model
with open(tflite_model_path, 'wb') as f:
    f.write(tflite_model)

## Generate from TFLite model

In [17]:
import numpy as np
import tensorflow as tf

# load interpreter
interpreter = tf.lite.Interpreter(model_path=tflite_model_path)

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details[0]['shape'])
print(input_details[1]['shape'])
print(output_details[0]['shape'])

# decoder token with language, task
token_map = [50258, 50264, 50359, 50363]
decoder_input_ids = np.array([token_map])

# generate
decoder_input_ids = np.array([token_map]).astype(np.int32)
interpreter.resize_tensor_input(input_details[0]['index'], decoder_input_ids.shape)
interpreter.allocate_tensors()
interpreter.set_tensor(input_details[1]['index'], input_features)
interpreter.set_tensor(input_details[0]['index'], decoder_input_ids)

interpreter.invoke()

output_data = interpreter.get_tensor(output_details[0]['index'])
output_data

[1 1]
[   1   80 3000]
[1 1 1]


array([[[50258, 50264, 50359, 50363,  9491,  9605,   242,   226,  3103,
          8941,   235,   116, 22339, 40547,  8941,   235, 16270, 25575,
         31253, 10134,  3833, 50257]]], dtype=int32)

In [18]:
# decode token
transcription = processor.decode(output_data[0,0], skip_special_tokens=True)
transcription

' 위스프어 모델 라이트 모델로 변환 예시'