# Piper TTS Training on AWS SageMaker

This notebook is adapted for AWS SageMaker (Notebook Instances or Studio). It handles:
1.  **OS Detection**: Automatically uses `yum` (Amazon Linux) or `apt-get` (Ubuntu) for system dependencies.
2.  **Persistent Storage**: Builds dependencies in the local workspace so they survive kernel restarts (if using persistent storage).
3.  **Piper-Phonemize**: Building from source to resolve compatibility issues.

## 1. Setup & Environment
We determine the OS and install necessary build tools.

In [None]:
# @title üõ†Ô∏è System Dependencies & OS Check
import subprocess
import os
import sys

def install_system_deps():
    print("üîç Checking OS...")
    if os.path.exists("/usr/bin/yum"):
        print("üì¶ Amazon Linux detected. Using yum...")
        # Amazon Linux deps
        subprocess.check_call(["sudo", "yum", "groupinstall", "-y", "Development Tools"])
        subprocess.check_call(["sudo", "yum", "install", "-y", "cmake", "git", "libtool", "automake", "autoconf"])
    elif os.path.exists("/usr/bin/apt-get"):
        print("üì¶ Ubuntu/Debian detected. Using apt-get...")
        # Ubuntu deps
        subprocess.check_call(["sudo", "apt-get", "update", "-qq"])
        subprocess.check_call(["sudo", "apt-get", "install", "-y", "-qq", "build-essential", "cmake", "git", "autoconf", "automake", "libtool", "pkg-config"])
    else:
        print("‚ö†Ô∏è Unknown OS. Please install build-essential, cmake, git, automake manually.")

install_system_deps()
print("‚úÖ System dependencies installed.")

In [None]:
# @title üõ†Ô∏è Build and Install Piper Dependencies
import os
import sys
import shutil

# Define a workspace directory (current dir is usually persistent in SageMaker)
WORKSPACE_DIR = os.getcwd()
BUILD_DIR = os.path.join(WORKSPACE_DIR, "build_temp")
os.makedirs(BUILD_DIR, exist_ok=True)

print(f"üöÄ Building in: {BUILD_DIR}")

# 1. Build & Install Rhasspy's espeak-ng Fork
if not os.path.exists(os.path.join(BUILD_DIR, "espeak-ng")):
    print("\n‚¨áÔ∏è Cloning rhasspy/espeak-ng...")
    !cd {BUILD_DIR} && git clone https://github.com/rhasspy/espeak-ng.git

print("\nüî® Building espeak-ng... (This may take a few minutes)")
# Note: In SageMaker we can install to /usr/local if we have sudo, or a local prefix.
# We'll try /usr first as it simplifies linking.
!cd {BUILD_DIR}/espeak-ng && ./autogen.sh && ./configure --prefix=/usr --without-async --without-mbrola --without-sonic && make -j4 && sudo make install
print("‚úÖ espeak-ng installed!")

# 2. Setup Python Build Environment
print("\nüêç Setting up Python environment...")
!pip install -q cython numpy pybind11 setuptools wheel

# 3. Custom Build of piper-phonemize
if os.path.exists(os.path.join(BUILD_DIR, "piper-phonemize")):
    shutil.rmtree(os.path.join(BUILD_DIR, "piper-phonemize"))

!cd {BUILD_DIR} && git clone https://github.com/rhasspy/piper-phonemize.git

# Download ONNX Runtime (needed for headers)
print("\n‚¨áÔ∏è Downloading ONNX Runtime...")
onnx_ver = "1.14.1"
onnx_file = f"onnxruntime-linux-x64-{onnx_ver}.tgz"
!cd {BUILD_DIR}/piper-phonemize && wget -q https://github.com/microsoft/onnxruntime/releases/download/v{onnx_ver}/{onnx_file}
!cd {BUILD_DIR}/piper-phonemize && tar -xzf {onnx_file}
!mkdir -p {BUILD_DIR}/piper-phonemize/lib/Linux-x86_64/onnxruntime/include
!mkdir -p {BUILD_DIR}/piper-phonemize/lib/Linux-x86_64/onnxruntime/lib
!cp -r {BUILD_DIR}/piper-phonemize/onnxruntime-linux-x64-{onnx_ver}/include/* {BUILD_DIR}/piper-phonemize/lib/Linux-x86_64/onnxruntime/include/
!cp -r {BUILD_DIR}/piper-phonemize/onnxruntime-linux-x64-{onnx_ver}/lib/* {BUILD_DIR}/piper-phonemize/lib/Linux-x86_64/onnxruntime/lib/

# Patch setup.py
print("\nüîß Patching setup.py...")
# Added /usr/local/include to include_dirs to find espeak-ng headers if installed there
setup_content = """
import platform
from pathlib import Path
from pybind11.setup_helpers import Pybind11Extension, build_ext
from setuptools import setup

_DIR = Path(__file__).parent
_ONNXRUNTIME_DIR = _DIR / "lib" / f"Linux-{platform.machine()}" / "onnxruntime"

__version__ = "1.2.0"

ext_modules = [
    Pybind11Extension(
        "piper_phonemize_cpp",
        [
            "src/python.cpp",
            "src/phonemize.cpp",
            "src/phoneme_ids.cpp",
            "src/tashkeel.cpp",
        ],
        define_macros=[("VERSION_INFO", __version__)],
        include_dirs=["/usr/include", "/usr/local/include", str(_ONNXRUNTIME_DIR / "include")],
        library_dirs=["/usr/lib", "/usr/local/lib", str(_ONNXRUNTIME_DIR / "lib")],
        libraries=["espeak-ng", "onnxruntime"],
    ),
]

setup(
    name="piper_phonemize",
    version=__version__,
    packages=["piper_phonemize"],
    package_data={
        "piper_phonemize": [
            str(p) for p in (_DIR / "piper_phonemize" / "espeak-ng-data").rglob("*")
        ] + [str(_DIR / "libtashkeel_model.ort")]
    },
    include_package_data=True,
    ext_modules=ext_modules,
    cmdclass={"build_ext": build_ext},
    zip_safe=False,
)
"""
with open(os.path.join(BUILD_DIR, "piper-phonemize", "setup.py"), "w") as f:
    f.write(setup_content)

# Build and Install
print("\nüêç Compiling python extension...")
!cd {BUILD_DIR}/piper-phonemize && python setup.py build_ext --inplace
!cd {BUILD_DIR}/piper-phonemize && python setup.py install

print("\n‚úÖ Build Process Complete!")

In [None]:
# @title üöë Runtime Fix: Link Libraries
import sys
import glob
import shutil
import os

BUILD_DIR = os.path.join(os.getcwd(), "build_temp")
print("üîß Applying Runtime Fixes...")

# 1. Copy missing ONNX Runtime library to /usr/lib
src_lib = os.path.join(BUILD_DIR, "piper-phonemize/lib/Linux-x86_64/onnxruntime/lib/libonnxruntime.so.1.14.1")
if os.path.exists(src_lib):
    try:
        # Try copying to /usr/lib if we have permission (sudo)
        !sudo cp {src_lib} /usr/lib/libonnxruntime.so.1.14.1
        !sudo ln -fs /usr/lib/libonnxruntime.so.1.14.1 /usr/lib/libonnxruntime.so
        !sudo ldconfig
        print("‚úÖ Copied libonnxruntime.so to /usr/lib")
    except Exception as e:
        print(f"‚ö†Ô∏è Cannot copy to /usr/lib ({e}). Trying LD_LIBRARY_PATH approach...")
        # Fallback for non-sudo environments: Add to LD_LIBRARY_PATH (requires restart usually, or runtime hack)
        lib_dir = os.path.dirname(src_lib)
        os.environ["LD_LIBRARY_PATH"] = f"{lib_dir}:{os.environ.get('LD_LIBRARY_PATH', '')}"
        print(f"‚ö†Ô∏è Added {lib_dir} to LD_LIBRARY_PATH. You might need to restart kernel if this fails.")
else:
    print("‚ö†Ô∏è Warning: Could not find downloaded ONNX runtime source.")

# 2. Find and add the installed egg to python path if needed
import site
# Reload site packages to find newly installed egg
from importlib import reload
reload(site)

try:
    import piper_phonemize
    print("\nüéâ SUCCESS: piper_phonemize is working!")
except ImportError as e:
    print(f"\n‚ö†Ô∏è Import check failed: {e}. Trying to find egg manually...")
    # Standard fallback
    # In SageMaker/Conda, it might install to local site-packages
    pass

## 2. Piper Setup & Data Preparation

In [None]:
# @title Install Python Dependencies (Torch 2.5.1)
# Using specific versions to avoid PyTorch 2.6 'weights_only' issues
!pip install -q torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1
!pip install -q lightning==2.4.0
!pip install -q librosa<1 numba==0.62.1
!pip install -q onnx onnxruntime tensorboard tensorboardX
!pip install -q pysilero-vad>=2.1 pathvalidate>=3
!pip install -q phonemizer Unidecode tqdm inflect matplotlib pandas

In [None]:
# @title Clone Piper Repository & Patch Code
import os

# CLONE REPO
if not os.path.exists("piper_repo"):
    !git clone https://github.com/rhasspy/piper.git piper_repo
    print("‚úÖ Cloned piper repository")
else:
    print("‚úÖ piper_repo already exists")

# --- PATCHING ---
print("\nü©π Applying PyTorch Lightning 2.x Patches...")

# 1. Fix __main__.py (Trainer compatibility)
main_py_path = "piper_repo/src/python/piper_train/__main__.py"
main_py_content = """import argparse
import json
import logging
from pathlib import Path

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from .vits.lightning import VitsModel

_LOGGER = logging.getLogger(__package__)


def main():
    logging.basicConfig(level=logging.DEBUG)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset-dir", required=True, help="Path to pre-processed dataset directory"
    )
    parser.add_argument(
        "--checkpoint-epochs",
        type=int,
        help="Save checkpoint every N epochs (default: 1)",
    )
    parser.add_argument(
        "--quality",
        default="medium",
        choices=("x-low", "medium", "high"),
        help="Quality/size of model (default: medium)",
    )
    parser.add_argument(
        "--resume_from_single_speaker_checkpoint",
        help="For multi-speaker models only. Converts a single-speaker checkpoint to multi-speaker and resumes training",
    )
    
    # Manually add PL 2.x arguments that we use
    parser.add_argument("--max_epochs", type=int, default=1000)
    parser.add_argument("--accelerator", default="auto")
    parser.add_argument("--devices", default="auto")
    parser.add_argument("--precision", default="32-true")
    parser.add_argument("--default_root_dir", type=str, default=None)
    parser.add_argument("--resume_from_checkpoint", type=str, default=None)

    # Trainer.add_argparse_args(parser) # Removed in PL 2.0
    VitsModel.add_model_specific_args(parser)
    parser.add_argument("--seed", type=int, default=1234)
    args = parser.parse_args()
    _LOGGER.debug(args)

    args.dataset_dir = Path(args.dataset_dir)
    if not args.default_root_dir:
        args.default_root_dir = str(args.dataset_dir) # Must be string for Trainer explicitly
    
    torch.backends.cudnn.benchmark = True
    torch.manual_seed(args.seed)

    config_path = args.dataset_dir / "config.json"
    dataset_path = args.dataset_dir / "dataset.jsonl"

    with open(config_path, "r", encoding="utf-8") as config_file:
        config = json.load(config_file)
        num_symbols = int(config["num_symbols"])
        num_speakers = int(config["num_speakers"])
        sample_rate = int(config["audio"]["sample_rate"])


    callbacks = []
    if args.checkpoint_epochs is not None:
        callbacks = [ModelCheckpoint(every_n_epochs=args.checkpoint_epochs)]
        _LOGGER.debug(
            "Checkpoints will be saved every %s epoch(s)", args.checkpoint_epochs
        )

    # Instantiate Trainer explicitly
    trainer = Trainer(
        max_epochs=args.max_epochs,
        accelerator=args.accelerator,
        devices=int(args.devices) if isinstance(args.devices, str) and args.devices.isdigit() else args.devices,
        precision=args.precision,
        default_root_dir=args.default_root_dir,
        callbacks=callbacks
    )

    dict_args = vars(args)
    if args.quality == "x-low":
        dict_args["hidden_channels"] = 96
        dict_args["inter_channels"] = 96
        dict_args["filter_channels"] = 384
    elif args.quality == "high":
        dict_args["resblock"] = "1"
        dict_args["resblock_kernel_sizes"] = (3, 7, 11)
        dict_args["resblock_dilation_sizes"] = (
            (1, 3, 5),
            (1, 3, 5),
            (1, 3, 5),
        )
        dict_args["upsample_rates"] = (8, 8, 2, 2)
        dict_args["upsample_initial_channel"] = 512
        dict_args["upsample_kernel_sizes"] = (16, 16, 4, 4)

    model = VitsModel(
        num_symbols=num_symbols,
        num_speakers=num_speakers,
        sample_rate=sample_rate,
        dataset=[dataset_path],
        **dict_args,
    )

    if args.resume_from_single_speaker_checkpoint:
        assert (
            num_speakers > 1
        ), "--resume_from_single_speaker_checkpoint is only for multi-speaker models. Use --resume_from_checkpoint for single-speaker models."

        # Load single-speaker checkpoint
        _LOGGER.debug(
            "Resuming from single-speaker checkpoint: %s",
            args.resume_from_single_speaker_checkpoint,
        )
        model_single = VitsModel.load_from_checkpoint(
            args.resume_from_single_speaker_checkpoint,
            dataset=None,
        )
        g_dict = model_single.model_g.state_dict()
        for key in list(g_dict.keys()):
            # Remove keys that can't be copied over due to missing speaker embedding
            if (
                key.startswith("dec.cond")
                or key.startswith("dp.cond")
                or ("enc.cond_layer" in key)
            ):
                g_dict.pop(key, None)

        # Copy over the multi-speaker model, excluding keys related to the
        # speaker embedding (which is missing from the single-speaker model).
        load_state_dict(model.model_g, g_dict)
        load_state_dict(model.model_d, model_single.model_d.state_dict())
        _LOGGER.info(
            "Successfully converted single-speaker checkpoint to multi-speaker"
        )

    ckpt_path = args.resume_from_checkpoint
    if args.resume_from_single_speaker_checkpoint:
        ckpt_path = None # We manually loaded weights, start fresh

    trainer.fit(model, ckpt_path=ckpt_path)


def load_state_dict(model, saved_state_dict):
    state_dict = model.state_dict()
    new_state_dict = {}

    for k, v in state_dict.items():
        if k in saved_state_dict:
            # Use saved value
            new_state_dict[k] = saved_state_dict[k]
        else:
            # Use initialized value
            _LOGGER.debug("%s is not in the checkpoint", k)
            new_state_dict[k] = v

    model.load_state_dict(new_state_dict)


if __name__ == "__main__":
    main()
"""
with open(main_py_path, "w") as f:
    f.write(main_py_content)
print("‚úÖ Fixed piper_train/__main__.py")

# 2. Fix lightning.py (Manual Optimization)
lightning_py_path = "piper_repo/src/python/piper_train/vits/lightning.py"
lightning_py_content = """import logging
from pathlib import Path
from typing import List, Optional, Tuple, Union

import pytorch_lightning as pl
import torch
from torch import autocast
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, random_split

from .commons import slice_segments
from .dataset import Batch, PiperDataset, UtteranceCollate
from .losses import discriminator_loss, feature_loss, generator_loss, kl_loss
from .mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from .models import MultiPeriodDiscriminator, SynthesizerTrn

_LOGGER = logging.getLogger("vits.lightning")


class VitsModel(pl.LightningModule):
    def __init__(
        self,
        num_symbols: int,
        num_speakers: int,
        # audio
        resblock="2",
        resblock_kernel_sizes=(3, 5, 7),
        resblock_dilation_sizes=(
            (1, 2),
            (2, 6),
            (3, 12),
        ),
        upsample_rates=(8, 8, 4),
        upsample_initial_channel=256,
        upsample_kernel_sizes=(16, 16, 8),
        # mel
        filter_length: int = 1024,
        hop_length: int = 256,
        win_length: int = 1024,
        mel_channels: int = 80,
        sample_rate: int = 22050,
        sample_bytes: int = 2,
        channels: int = 1,
        mel_fmin: float = 0.0,
        mel_fmax: Optional[float] = None,
        # model
        inter_channels: int = 192,
        hidden_channels: int = 192,
        filter_channels: int = 768,
        n_heads: int = 2,
        n_layers: int = 6,
        kernel_size: int = 3,
        p_dropout: float = 0.1,
        n_layers_q: int = 3,
        use_spectral_norm: bool = False,
        gin_channels: int = 0,
        use_sdp: bool = True,
        segment_size: int = 8192,
        # training
        dataset: Optional[List[Union[str, Path]]] = None,
        learning_rate: float = 2e-4,
        betas: Tuple[float, float] = (0.8, 0.99),
        eps: float = 1e-9,
        batch_size: int = 1,
        lr_decay: float = 0.999875,
        init_lr_ratio: float = 1.0,
        warmup_epochs: int = 0,
        c_mel: int = 45,
        c_kl: float = 1.0,
        grad_clip: Optional[float] = None,
        num_workers: int = 1,
        seed: int = 1234,
        num_test_examples: int = 5,
        validation_split: float = 0.1,
        max_phoneme_ids: Optional[int] = None,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        
        # Lightning 2.x requires manual optimization for multiple optimizers
        self.automatic_optimization = False

        if (self.hparams.num_speakers > 1) and (self.hparams.gin_channels <= 0):
            # Default gin_channels for multi-speaker model
            self.hparams.gin_channels = 512

        # Set up models
        self.model_g = SynthesizerTrn(
            n_vocab=self.hparams.num_symbols,
            spec_channels=self.hparams.filter_length // 2 + 1,
            segment_size=self.hparams.segment_size // self.hparams.hop_length,
            inter_channels=self.hparams.inter_channels,
            hidden_channels=self.hparams.hidden_channels,
            filter_channels=self.hparams.filter_channels,
            n_heads=self.hparams.n_heads,
            n_layers=self.hparams.n_layers,
            kernel_size=self.hparams.kernel_size,
            p_dropout=self.hparams.p_dropout,
            resblock=self.hparams.resblock,
            resblock_kernel_sizes=self.hparams.resblock_kernel_sizes,
            resblock_dilation_sizes=self.hparams.resblock_dilation_sizes,
            upsample_rates=self.hparams.upsample_rates,
            upsample_initial_channel=self.hparams.upsample_initial_channel,
            upsample_kernel_sizes=self.hparams.upsample_kernel_sizes,
            n_speakers=self.hparams.num_speakers,
            gin_channels=self.hparams.gin_channels,
            use_sdp=self.hparams.use_sdp,
        )
        self.model_d = MultiPeriodDiscriminator(
            use_spectral_norm=self.hparams.use_spectral_norm
        )

        # Dataset splits
        self._train_dataset: Optional[Dataset] = None
        self._val_dataset: Optional[Dataset] = None
        self._test_dataset: Optional[Dataset] = None
        self._load_datasets(validation_split, num_test_examples, max_phoneme_ids)

        # State kept between training optimizers
        self._y = None
        self._y_hat = None

    def _load_datasets(
        self,
        validation_split: float,
        num_test_examples: int,
        max_phoneme_ids: Optional[int] = None,
    ):
        if self.hparams.dataset is None:
            _LOGGER.debug("No dataset to load")
            return

        full_dataset = PiperDataset(
            self.hparams.dataset, max_phoneme_ids=max_phoneme_ids
        )
        valid_set_size = int(len(full_dataset) * validation_split)
        train_set_size = len(full_dataset) - valid_set_size - num_test_examples

        self._train_dataset, self._test_dataset, self._val_dataset = random_split(
            full_dataset, [train_set_size, num_test_examples, valid_set_size]
        )

    def forward(self, text, text_lengths, scales, sid=None):
        noise_scale = scales[0]
        length_scale = scales[1]
        noise_scale_w = scales[2]
        audio, *_ = self.model_g.infer(
            text,
            text_lengths,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
            sid=sid,
        )

        return audio

    def train_dataloader(self):
        return DataLoader(
            self._train_dataset,
            collate_fn=UtteranceCollate(
                is_multispeaker=self.hparams.num_speakers > 1,
                segment_size=self.hparams.segment_size,
            ),
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
        )

    def val_dataloader(self):
        return DataLoader(
            self._val_dataset,
            collate_fn=UtteranceCollate(
                is_multispeaker=self.hparams.num_speakers > 1,
                segment_size=self.hparams.segment_size,
            ),
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
        )

    def test_dataloader(self):
        return DataLoader(
            self._test_dataset,
            collate_fn=UtteranceCollate(
                is_multispeaker=self.hparams.num_speakers > 1,
                segment_size=self.hparams.segment_size,
            ),
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
        )

    def training_step(self, batch: Batch, batch_idx: int):
        # Manual optimization for Lightning 2.x with multiple optimizers
        opt_g, opt_d = self.optimizers()
        
        # Train Generator
        loss_gen_all = self.training_step_g(batch)
        opt_g.zero_grad()
        self.manual_backward(loss_gen_all)
        opt_g.step()
        
        # Train Discriminator
        loss_disc_all = self.training_step_d(batch)
        opt_d.zero_grad()
        self.manual_backward(loss_disc_all)
        opt_d.step()
        
        # Step learning rate schedulers
        sch_g, sch_d = self.lr_schedulers()
        if self.trainer.is_last_batch:
             sch_g.step()
             sch_d.step()


    def training_step_g(self, batch: Batch):
        x, x_lengths, y, _, spec, spec_lengths, speaker_ids = (
            batch.phoneme_ids,
            batch.phoneme_lengths,
            batch.audios,
            batch.audio_lengths,
            batch.spectrograms,
            batch.spectrogram_lengths,
            batch.speaker_ids if batch.speaker_ids is not None else None,
        )
        (
            y_hat,
            l_length,
            _attn,
            ids_slice,
            _x_mask,
            z_mask,
            (_z, z_p, m_p, logs_p, _m_q, logs_q),
        ) = self.model_g(x, x_lengths, spec, spec_lengths, speaker_ids)
        self._y_hat = y_hat

        mel = spec_to_mel_torch(
            spec,
            self.hparams.filter_length,
            self.hparams.mel_channels,
            self.hparams.sample_rate,
            self.hparams.mel_fmin,
            self.hparams.mel_fmax,
        )
        y_mel = slice_segments(
            mel,
            ids_slice,
            self.hparams.segment_size // self.hparams.hop_length,
        )
        y_hat_mel = mel_spectrogram_torch(
            y_hat.squeeze(1),
            self.hparams.filter_length,
            self.hparams.mel_channels,
            self.hparams.sample_rate,
            self.hparams.hop_length,
            self.hparams.win_length,
            self.hparams.mel_fmin,
            self.hparams.mel_fmax,
        )
        y = slice_segments(
            y,
            ids_slice * self.hparams.hop_length,
            self.hparams.segment_size,
        )  # slice

        # Save for training_step_d
        self._y = y

        _y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.model_d(y, y_hat)

        with autocast(self.device.type, enabled=False):
            # Generator loss
            loss_dur = torch.sum(l_length.float())
            loss_mel = F.l1_loss(y_mel, y_hat_mel) * self.hparams.c_mel
            loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * self.hparams.c_kl

            loss_fm = feature_loss(fmap_r, fmap_g)
            loss_gen, _losses_gen = generator_loss(y_d_hat_g)
            loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl

            self.log("loss_gen_all", loss_gen_all, prog_bar=True)

            return loss_gen_all

    def training_step_d(self, batch: Batch):
        # From training_step_g
        y = self._y
        y_hat = self._y_hat
        y_d_hat_r, y_d_hat_g, _, _ = self.model_d(y, y_hat.detach())

        with autocast(self.device.type, enabled=False):
            # Discriminator
            loss_disc, _losses_disc_r, _losses_disc_g = discriminator_loss(
                y_d_hat_r, y_d_hat_g
            )
            loss_disc_all = loss_disc

            self.log("loss_disc_all", loss_disc_all, prog_bar=True)

            return loss_disc_all

    def validation_step(self, batch: Batch, batch_idx: int):
        val_loss = self.training_step_g(batch) + self.training_step_d(batch)
        self.log("val_loss", val_loss, prog_bar=True)

        # Generate audio examples
        for utt_idx, test_utt in enumerate(self._test_dataset):
            text = test_utt.phoneme_ids.unsqueeze(0).to(self.device)
            text_lengths = torch.LongTensor([len(test_utt.phoneme_ids)]).to(self.device)
            scales = [0.667, 1.0, 0.8]
            sid = (
                test_utt.speaker_id.to(self.device)
                if test_utt.speaker_id is not None
                else None
            )
            test_audio = self(text, text_lengths, scales, sid=sid).detach()

            # Scale to make louder in [-1, 1]
            test_audio = test_audio * (1.0 / max(0.01, abs(test_audio.max())))

            tag = test_utt.text or str(utt_idx)
            self.logger.experiment.add_audio(
                tag, test_audio, sample_rate=self.hparams.sample_rate
            )

        return val_loss

    def configure_optimizers(self):
        optimizers = [
            torch.optim.AdamW(
                self.model_g.parameters(),
                lr=self.hparams.learning_rate,
                betas=self.hparams.betas,
                eps=self.hparams.eps,
            ),
            torch.optim.AdamW(
                self.model_d.parameters(),
                lr=self.hparams.learning_rate,
                betas=self.hparams.betas,
                eps=self.hparams.eps,
            ),
        ]
        schedulers = [
            torch.optim.lr_scheduler.ExponentialLR(
                optimizers[0], gamma=self.hparams.lr_decay
            ),
            torch.optim.lr_scheduler.ExponentialLR(
                optimizers[1], gamma=self.hparams.lr_decay
            ),
        ]

        return [
            {"optimizer": optimizers[0], "lr_scheduler": schedulers[0]},
            {"optimizer": optimizers[1], "lr_scheduler": schedulers[1]},
        ]

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("VitsModel")
        parser.add_argument("--batch-size", type=int, required=True)
        parser.add_argument("--validation-split", type=float, default=0.1)
        parser.add_argument("--num-test-examples", type=int, default=5)
        parser.add_argument(
            "--max-phoneme-ids",
            type=int,
            help="Exclude utterances with phoneme id lists longer than this",
        )
        #
        parser.add_argument("--hidden-channels", type=int, default=192)
        parser.add_argument("--inter-channels", type=int, default=192)
        parser.add_argument("--filter-channels", type=int, default=768)
        parser.add_argument("--n-layers", type=int, default=6)
        parser.add_argument("--n-heads", type=int, default=2)
        #
        return parent_parser
"""
with open(lightning_py_path, "w") as f:
    f.write(lightning_py_content)
print("‚úÖ Fixed piper_train/vits/lightning.py")

# 3. Fix monotonic_align/__init__.py
monotonic_init = "piper_repo/src/python/piper_train/vits/monotonic_align/__init__.py"
if os.path.exists(monotonic_init):
    with open(monotonic_init, "r") as f:
        content = f.read()
    content = content.replace("from .monotonic_align.core", "from .core")
    with open(monotonic_init, "w") as f:
        f.write(content)
    print("‚úÖ Fixed monotonic_align import")

In [None]:
# @title Build Monotonic Align (Required for VITS)
import os
import shutil

piper_src_path = os.path.abspath("piper_repo/src/python")
monotonic_align_src = os.path.join(piper_src_path, "piper_train/vits/monotonic_align")

# Build directly in place since we have full permissions usually
# But to be safe and clean, use temp
temp_build_dir = "monotonic_align_build"
if os.path.exists(temp_build_dir):
    shutil.rmtree(temp_build_dir)
os.makedirs(temp_build_dir, exist_ok=True)

for filename in ["core.pyx", "setup.py"]:
    shutil.copy(os.path.join(monotonic_align_src, filename), os.path.join(temp_build_dir, filename))

%cd {temp_build_dir}
print("üî® Building monotonic_align...")
!python setup.py build_ext --inplace

import glob
so_files = glob.glob("core*.so")
if so_files:
    dest = os.path.join(monotonic_align_src, so_files[0])
    shutil.copy(so_files[0], dest)
    print(f"‚úÖ Installed compiled extension to: {dest}")
else:
    print("‚ùå Build failed, no .so file found")

%cd ..

## 3. Preprocessing

In [None]:
# @title Run Preprocessing
import os
import sys

piper_src_path = os.path.abspath("piper_repo/src/python")
DATASET_PATH = os.path.abspath("english") # Ensure this matches your uploaded folder name

print(f"üìÇ Dataset: {DATASET_PATH}")

# Add piper_phonemize to path if needed (though 'setup.py install' should have handled it)
# We strictly set PYTHONPATH to include piper src

!PYTHONPATH="{piper_src_path}" python -m piper_train.preprocess \
  --language en \
  --input-dir "{DATASET_PATH}" \
  --output-dir training_dir \
  --dataset-format ljspeech \
  --single-speaker \
  --sample-rate 22050

print("\n‚úÖ Preprocessing complete (if no errors above)")

## 4. Training

In [None]:
# @title Download Base Checkpoint
import urllib.request

os.makedirs("checkpoints", exist_ok=True)
checkpoint_url = "https://huggingface.co/datasets/rhasspy/piper-checkpoints/resolve/main/en/en_US/lessac/medium/epoch%3D2164-step%3D1355540.ckpt"
checkpoint_path = "checkpoints/epoch=2164-step=1355540.ckpt"

if not os.path.exists(checkpoint_path):
    print("üì• Downloading base checkpoint...")
    urllib.request.urlretrieve(checkpoint_url, checkpoint_path)
    print("‚úÖ Downloaded!")
else:
    print("‚úÖ Checkpoint already exists")

In [None]:
# @title Start Training

# Determine Accelerator (SageMaker might have different GPU count)
import torch
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
devices = torch.cuda.device_count() if torch.cuda.is_available() else 1
print(f"üöÄ Training on {devices} {accelerator}(s)")

# Note: Batch size 16 is safe for T4 (g4dn.xlarge). If using A10g (g5.xlarge), you can try 32 or 64.
!PYTHONPATH="{piper_src_path}" python -m piper_train \
  --dataset-dir training_dir \
  --accelerator {accelerator} \
  --devices {devices} \
  --batch-size 16 \
  --validation-split 0.0 \
  --num-test-examples 0 \
  --max_epochs 10000 \
  --resume_from_checkpoint "{checkpoint_path}" \
  --checkpoint-epochs 1 \
  --quality medium \
  --precision 32