# In [20]:
import kfp
from kfp import dsl
from kfp.dsl import Input, Output, Artifact

# Load Text from input
@dsl.component(base_image="python:3.10")
def load_text(text: str, text_artifact: Output[Artifact]):
    """Write the raw input text to an output artifact for downstream steps.

    No extra pip packages are requested for this component: the body only
    uses built-in file I/O, so installing the TTS stack here would just slow
    down container start-up.

    Args:
        text: The text to hand to the rest of the pipeline.
        text_artifact: Output artifact; ``text`` is written to its path as
            UTF-8.
    """
    with open(text_artifact.path, "w", encoding="utf-8") as f:
        f.write(text)

# Normalize Text
@dsl.component(
    base_image="python:3.10",
    packages_to_install=["nemo_text_processing"],
)
def normalize_text(text_artifact: Input[Artifact], normalized_artifact: Output[Artifact]):
    # Reads the UTF-8 text stored at text_artifact.path, runs it through the
    # nemo_text_processing Normalizer (English, cased input) and writes the
    # result to normalized_artifact.path as UTF-8.
    #
    # The import is local because the package is only installed inside the
    # component's container.
    from nemo_text_processing.text_normalization.normalize import Normalizer

    with open(text_artifact.path, "r", encoding="utf-8") as source_file:
        raw_text = source_file.read()

    english_normalizer = Normalizer(lang="en", input_case="cased")
    normalized_text = english_normalizer.normalize(raw_text, verbose=False)

    with open(normalized_artifact.path, "w", encoding="utf-8") as target_file:
        target_file.write(normalized_text)

# Grapheme-to-Phoneme (G2P)
@dsl.component(
    base_image="python:3.10",
    packages_to_install=["phonemizer"],
)
def g2p(normalized_artifact: Input[Artifact], phonemes_artifact: Output[Artifact]):
    # Converts the normalized text at normalized_artifact.path into a phoneme
    # string with phonemizer and writes it to phonemes_artifact.path (UTF-8).
    import subprocess

    # espeak-ng is installed with apt-get at runtime, before phonemizer is
    # imported; presumably the 'espeak' backend needs the system package.
    # check=True aborts the step if either apt-get command fails.
    for apt_args in (["update"], ["install", "-y", "espeak-ng"]):
        subprocess.run(["apt-get", *apt_args], check=True)

    from phonemizer import phonemize

    with open(normalized_artifact.path, "r", encoding="utf-8") as source_file:
        normalized_text = source_file.read()

    phonemizer_options = {
        "language": 'en-us',
        "backend": 'espeak',
        "strip": True,
        "preserve_punctuation": True,
        "with_stress": True,
    }
    phoneme_string = phonemize(normalized_text, **phonemizer_options)

    with open(phonemes_artifact.path, "w", encoding="utf-8") as target_file:
        target_file.write(phoneme_string)

# Synthesize Spectrogram
@dsl.component(
    # NOTE(review): the image tag is spelled "lastest"; confirm it matches the
    # tag that was actually pushed to the registry before "correcting" it.
    base_image="k3d-myregistry.localhost:12345/pytorch-btls:lastest",
    packages_to_install=["nemo_toolkit[all]", "torch"]
)
def synthesize_spectrogram(phonemes_artifact: Input[Artifact], mel_artifact: Output[Artifact]):
    """Generate a mel spectrogram from the phoneme string with FastPitch.

    The pretrained ``tts_en_fastpitch_ipa`` model is put in eval mode and run
    under ``torch.no_grad()`` so no autograd graph is built during inference
    and the saved tensor carries no gradient state.

    Args:
        phonemes_artifact: Input artifact whose file holds the phoneme string
            (UTF-8 text).
        mel_artifact: Output artifact; the spectrogram tensor is moved to CPU
            and stored at its path with ``torch.save``.
    """
    import torch
    from nemo.collections.tts.models import FastPitchModel

    # Use the first GPU when one is visible, otherwise fall back to CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    with open(phonemes_artifact.path, "r", encoding="utf-8") as f:
        ipa_phonemes = f.read().strip()  # Ensure no extra whitespace/newlines

    # Use the IPA-trained FastPitch model
    model = FastPitchModel.from_pretrained("tts_en_fastpitch_ipa").to(device)
    model.eval()

    with torch.no_grad():
        # Parse IPA phoneme string
        tokens = model.parse(ipa_phonemes)
        # Only move the parse result when it supports .to().
        tokens = tokens.to(device) if hasattr(tokens, 'to') else tokens

        # Generate mel spectrogram
        mel = model.generate_spectrogram(tokens=tokens)

    # Save from CPU; the audio step picks its own device via map_location.
    mel = mel.to('cpu')
    torch.save(mel, mel_artifact.path)


# Synthesize Audio
@dsl.component(
    # NOTE(review): the image tag is spelled "lastest"; confirm it matches the
    # tag that was actually pushed to the registry before "correcting" it.
    base_image="k3d-myregistry.localhost:12345/pytorch-btls:lastest",
    packages_to_install=["nemo_toolkit[all]", "torch", "soundfile", "huggingface_hub"]
)
def synthesize_audio(mel_artifact: Input[Artifact], hf_token: str, audio_artifact: Output[Artifact]):
    """Convert a saved mel spectrogram into a WAV file with the HiFi-GAN vocoder.

    Args:
        mel_artifact: Input artifact whose file is a tensor written with
            ``torch.save``.
        hf_token: Hugging Face access token. Login is skipped when the token
            is empty (the pipeline's default), so an unset token does not
            abort the step before the model download is even attempted.
        audio_artifact: Output artifact; the WAV file is copied to its path.
    """
    import shutil
    import subprocess

    # Install the system-level libsndfile1 BEFORE importing soundfile (or
    # NeMo, which may pull it in): the shared library is looked up at import
    # time, so installing it afterwards is too late.
    subprocess.run(["apt-get", "update"], check=True)
    subprocess.run(["apt-get", "install", "-y", "libsndfile1"], check=True)

    import soundfile as sf
    import torch
    from huggingface_hub import login
    from nemo.collections.tts.models import HifiGanModel

    # NOTE(review): the token arrives as a plain pipeline parameter; consider
    # injecting it from a secret instead.
    if hf_token:
        login(token=hf_token)

    # Use the first GPU when one is visible, otherwise fall back to CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    vocoder = HifiGanModel.from_pretrained("nvidia/tts_hifigan").to(device)
    vocoder.eval()

    mel = torch.load(mel_artifact.path, map_location=device)
    mel = mel.to(device)

    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        audio = vocoder.convert_spectrogram_to_audio(spec=mel)
    audio = audio.to('cpu').detach()

    # Debug: print shape and dtype
    print("audio.shape:", audio.shape)
    print("audio.dtype:", audio.dtype)

    # Remove batch/channel dims if present
    audio_np = audio.numpy()
    if audio_np.ndim > 1:
        # Flatten to 1D if mono, or transpose to (n_samples, n_channels) if needed
        if audio_np.shape[0] == 1:
            audio_np = audio_np.squeeze(0)
        if audio_np.ndim > 1 and audio_np.shape[1] == 1:
            audio_np = audio_np.squeeze(1)
        # If still 2D, transpose to (n_samples, n_channels)
        if audio_np.ndim == 2 and audio_np.shape[0] < audio_np.shape[1]:
            audio_np = audio_np.T

    # Ensure float32
    audio_np = audio_np.astype('float32')

    # NOTE(review): 22050 Hz is hard-coded; confirm it matches the vocoder's
    # output sample rate.
    # The file is written locally first and then copied to the artifact path.
    local_wav = "output.wav"
    sf.write(local_wav, audio_np, 22050, format="WAV")
    shutil.copyfile(local_wav, audio_artifact.path)



# Pipeline Definition
@dsl.pipeline(
    name="",
    description="",
)
def tts_pipeline(
    input_text: str = "",
    hf_token: str = "",
):
    # Chains the five components: load_text -> normalize_text -> g2p ->
    # synthesize_spectrogram -> synthesize_audio. Each step consumes the
    # output artifact of the previous one; hf_token is forwarded to the
    # audio step only.
    load_step = load_text(text=input_text)
    normalize_step = normalize_text(
        text_artifact=load_step.outputs["text_artifact"],
    )
    g2p_step = g2p(
        normalized_artifact=normalize_step.outputs["normalized_artifact"],
    )
    spectrogram_step = synthesize_spectrogram(
        phonemes_artifact=g2p_step.outputs["phonemes_artifact"],
    )
    audio_step = synthesize_audio(
        mel_artifact=spectrogram_step.outputs["mel_artifact"],
        hf_token=hf_token,
    )

    # Both synthesis steps get a limit of one GPU each.
    for gpu_step in (spectrogram_step, audio_step):
        gpu_step.set_gpu_limit(1)

# Compile pipeline
from kfp import compiler

# The compiled pipeline package is written to tts_pipeline.yaml.
pipeline_compiler = compiler.Compiler()
pipeline_compiler.compile(
    pipeline_func=tts_pipeline,
    package_path='tts_pipeline.yaml',
)

# Run pipeline
# Replace the hf_token placeholder with a real token before submitting.
run_arguments = {
    "input_text": "Hello, this is a Text-to-Speech example",
    "hf_token": "<ADD_HUGGINGFACE_TOKEN>",
}
client = kfp.Client()
run = client.create_run_from_pipeline_func(tts_pipeline, arguments=run_arguments)
