In [2]:
# Show the kernel's working directory; the relative paths used by later cells
# (e.g. '../temp/dbg.pkl', '../kokoro.tflite') are resolved against it.
!pwd

/rhome/eingerman/Projects/DeepLearning/TTS/Kokoro-LiteRT/kokoro_litert


In [12]:
import tensorflow as tf
from ai_edge_litert.interpreter import Interpreter
import numpy as np

In [35]:
import pickle
# Fixed test setting passed to the model's `speed` input further down.
test_speed = 1

# Restore the pickled debug dictionary that holds the captured model inputs.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# open files that were produced locally.
with open('../temp/dbg.pkl', 'rb') as f:
    dbg = pickle.load(f)

# The stored values expose .numpy(); convert them to NumPy arrays for the
# interpreter.
test_input_ids = dbg['input_ids'].numpy()
test_ref_s = dbg['ref_s'].numpy()

# Length of axis 1 of input_ids.
test_n_inputs = test_input_ids.shape[1]

In [45]:
def _pad_or_truncate(np_input_ids: np.ndarray, target_len: int) -> np.ndarray:
    """Pad with zeros or truncate tokens to match the target sequence length."""
    current_len = np_input_ids.shape[1]
    if current_len == target_len:
        return np_input_ids
    if current_len > target_len:
        print(f"[INFO] Truncating input_ids from {current_len} to {target_len} tokens")
        return np_input_ids[:, :target_len]
    pad_width = target_len - current_len
    print(f"[INFO] Padding input_ids from {current_len} to {target_len} tokens")
    pad_block = np.zeros((np_input_ids.shape[0], pad_width), dtype=np_input_ids.dtype)
    return np.concatenate([np_input_ids, pad_block], axis=1)

# Keep only the first 50 tokens, then zero-pad to the 510-token input length the
# model reports. The slice must be taken on axis 1 (tokens): `test_input_ids[0:50]`
# sliced axis 0 (rows) and therefore left the token sequence untouched.
test_input_ids = _pad_or_truncate(test_input_ids[:, :50], target_len=510)

In [16]:
# Load ../kokoro.tflite into an ai_edge_litert Interpreter. Tensors are not
# allocated here; that happens in a later cell via allocate_tensors().
interpreter = Interpreter(model_path="../kokoro.tflite")

I0000 00:00:1764660985.320293   90359 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 22366 MB memory:  -> device: 0, name: NVIDIA RTX A5000, pci bus id: 0000:3b:00.0, compute capability: 8.6


In [36]:
# Allocate the interpreter's tensors and log the model's input signature.
input_details = interpreter.get_input_details()
# NOTE(review): target_shapes is not referenced again in this cell; it looks
# like a leftover from an input-resizing step — confirm before removing.
target_shapes = [
    [1, 510],
    [1, 256],
    [1, 1],
    [1, 1],
]
interpreter.allocate_tensors()
# Fetch final tensor metadata for logging.
# (input_details is re-read here, after allocation, replacing the value fetched above.)
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# One line per input: name, shape and dtype as reported by the interpreter.
print(f"TFLite model inputs: {len(input_details)}")
for i, detail in enumerate(input_details):
    print(f"  Input {i}: {detail['name']}, shape={detail['shape']}, dtype={detail['dtype']}")



TFLite model inputs: 4
  Input 0: input_ids, shape=[  1 510], dtype=<class 'numpy.int32'>
  Input 1: ref_s, shape=[  1 256], dtype=<class 'numpy.float32'>
  Input 2: n_inputs, shape=[1 1], dtype=<class 'numpy.int32'>
  Input 3: speed, shape=[1 1], dtype=<class 'numpy.float32'>


In [47]:
# Value fed to the model's `n_inputs` input. This overrides the value derived
# earlier from the width of the unpadded input_ids array.
test_n_inputs=50

In [48]:
# Values for the model inputs, listed in the same positional order as
# input_details (input_ids, ref_s, n_inputs, speed per the logged signature).
feeds = (
    test_input_ids.astype(np.int32),
    test_ref_s,
    np.array([[test_n_inputs]], dtype=np.int32),
    np.array([[test_speed]], dtype=np.float32),
)
for slot, value in enumerate(feeds):
    interpreter.set_tensor(input_details[slot]['index'], value)

# Run a single forward pass with the inputs set above.
interpreter.invoke()

In [49]:
# Read back the model's first output tensor; it is treated as the audio
# waveform by the playback cell.
audio = interpreter.get_tensor(output_details[0]['index'])

In [50]:
import scipy.io.wavfile as wavfile
import IPython.display as ipd

# Assume audio is mono and 24000 Hz sample rate (standard for Kokoro)
# NOTE(review): neither the shape nor the sample rate of `audio` is checked
# here — confirm both against the model's output before relying on playback.
sample_rate = 24000

# Play the audio
# (last expression of the cell, so the notebook renders it as an inline player)
ipd.Audio(audio, rate=sample_rate)

In [5]:
# Print the model's subgraph/op listing; gpu_compatibility=True additionally
# asks the analyzer to report GPU compatibility of the ops.
tf.lite.experimental.Analyzer.analyze("../kokoro.tflite", gpu_compatibility=True)

=== ../kokoro.tflite ===

Your TFLite model has '25' subgraph(s). In the subgraph description below,
T# represents the Tensor numbers. For example, in Subgraph#0, the RESHAPE op takes
tensor #2 and tensor #462 as input and produces tensor #577 as output.

Subgraph#0 main(T#0, T#1, T#2, T#3) -> [T#4210]
  Op#0 RESHAPE(T#2, T#462[]) -> [T#577]
  Op#1 PACK(T#484[1], T#577) -> [T#578]
  Op#2 SLICE(T#0, T#452[0, 0], T#578) -> [T#579]
  Op#3 SHAPE(T#579) -> [T#580]
  Op#4 STRIDED_SLICE(T#580, T#472[1], T#479[2], T#472[1]) -> [T#581]
  Op#5 PACK(T#484[1], T#581) -> [T#582]
  Op#6 FILL(T#582, T#484[1]) -> [T#583]
  Op#7 PACK(T#484[1], T#484[1], T#484[1], T#581) -> [T#584]
  Op#8 RESHAPE(T#583, T#584) -> [T#585]
  Op#9 CAST(T#585) -> [T#586]
  Op#10 SUB(T#481, T#586) -> [T#587]
  Op#11 MUL(T#587, T#451) -> [T#588]
  Op#12 GATHER(T#449, T#579) -> [T#589]
  Op#13 SHAPE(T#589) -> [T#590]
  Op#14 STRIDED_SLICE(T#590, T#472[1], T#479[2], T#472[1]) -> [T#591]
  Op#15 RANGE(T#483[0], T#591, T#484[1]) 

In [None]:
from pathlib import Path
import os
from datetime import datetime

In [None]:
# Locations of the model artifacts, relative to the working directory.
# Destination for the quantized flatbuffer produced by the conversion cell.
quantized_path = Path("../kokoro_int8.tflite")
# Existing float TFLite model, used as the baseline in the size comparison.
float_model_path = Path("../kokoro.tflite")
# Saved Keras model that the converter is built from.
keras_model_path = Path("../kokoro.keras")

In [None]:
def representative_dataset(max_samples=10):
    """Yield calibration samples built from the module-level ``dbg`` dictionary.

    Reads ``dbg['input_ids']`` and ``dbg['ref_s']`` (objects exposing
    ``.numpy()``) and ``dbg['speed']`` (anything ``np.asarray(...).item()``
    can turn into a scalar).

    Args:
        max_samples: Upper bound on the number of samples yielded; the actual
            count is also limited by the number of rows in ``input_ids``.

    Yields:
        A dict with keys ``input_ids`` (int32, one row), ``ref_s`` (float32,
        the matching row), ``n_inputs`` (int32 ``[[width of input_ids]]``) and
        ``speed`` (float32 ``[[speed]]``).
    """
    all_ids = dbg['input_ids'].numpy().astype(np.int32)
    all_ref_s = dbg['ref_s'].numpy().astype(np.float32)

    # The same speed array is shared by every sample.
    speed_scalar = float(np.asarray(dbg['speed']).item())
    speed = np.array([[speed_scalar]], dtype=np.float32)

    sample_count = min(max_samples, all_ids.shape[0])
    for row in range(sample_count):
        # Slice (rather than index) so each sample keeps its leading axis of size 1.
        ids = all_ids[row:row + 1]
        sample = {
            "input_ids": ids,
            "ref_s": all_ref_s[row:row + 1],
            "n_inputs": np.array([[ids.shape[1]]], dtype=np.int32),
            "speed": speed,
        }
        yield sample

In [None]:
# Convert the saved Keras model into a quantized TFLite flatbuffer and write it
# to quantized_path. Full int8 quantization is attempted first; if conversion
# raises, the cell falls back to dynamic-range quantization.
if not keras_model_path.exists():
    raise FileNotFoundError(f"Expected Keras model at {keras_model_path}")
keras_model = tf.keras.models.load_model(keras_model_path, compile=False)
print("Loaded Keras model")

try:
    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # Calibration samples used to pick activation ranges for int8.
    converter.representative_dataset = representative_dataset
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS_INT8,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    # NOTE(review): this requests int8 model inputs/outputs — confirm that is
    # wanted for this model's audio output.
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8
    print("Attempting full int8 quantization...")
    quantized_model = converter.convert()
    quant_mode = "int8"
except Exception as exc:
    # Deliberate best-effort: any conversion failure is reported, then a less
    # demanding quantization mode is tried with a fresh converter.
    print(f"[WARN] INT8 conversion failed: {exc}")
    print("Falling back to dynamic-range quantization")
    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # No representative_dataset on purpose: supplying one makes the converter
    # calibrate activations (integer quantization with float fallback) instead
    # of performing dynamic-range quantization, and would re-run the same
    # calibration that may have caused the failure above.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    quantized_model = converter.convert()
    quant_mode = "dynamic"

quantized_path.write_bytes(quantized_model)
print(f"Saved {quant_mode} quantized model to {quantized_path}")

In [None]:
def _mb(path: Path) -> float:
    return path.stat().st_size / (1024 ** 2)
# Report on-disk sizes. The float baseline is optional; the quantized model is
# always reported.
if not float_model_path.exists():
    print("Float baseline not found; skipping size comparison")
else:
    print(f"Float model size: {_mb(float_model_path):.2f} MB")
print(f"Quantized model size: {_mb(quantized_path):.2f} MB")

# Stamp the run with the local wall-clock time.
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print(f"Quantization completed at {timestamp}")