In [1]:
# Environment and GPU sanity checks.
# These environment variables only take effect if they are set BEFORE TensorFlow
# is imported; on a kernel where TF is already loaded, restart the kernel to change them.
import os, sys
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # choose GPU
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'  # '0' shows all TF C++ logs; raise it to reduce noise
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

import tensorflow as tf

# Set to True to have TensorFlow log the device every op is placed on
# (useful to confirm that work really runs on the GPU).
LOG_DEVICE_PLACEMENT = False

gpus = tf.config.list_physical_devices('GPU')
print('Python:', sys.version)
print('TF version:', tf.__version__)
print('Physical GPUs:', gpus)

# Enable memory growth early: set_memory_growth raises once the GPU runtime has
# been initialized (e.g. when this cell is re-run), so report instead of failing.
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except Exception as e:
        print('Memory growth warning:', e)

if LOG_DEVICE_PLACEMENT:
    tf.debugging.set_log_device_placement(True)


2025-09-13 05:01:04.377890: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1757727064.394928 2313882 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1757727064.400122 2313882 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1757727064.414159 2313882 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1757727064.414173 2313882 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1757727064.414175 2313882 computation_placer.cc:177] computation placer alr

Python: 3.10.12 (main, Aug 15 2025, 14:32:43) [GCC 11.4.0]
TF version: 2.19.0
Physical GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [2]:
# Build the XTTS training configuration (model / training / data sections).
# Every value a reader is likely to tune lives in the "Tunables" section below.
import os

from myxtts.config.config import XTTSConfig, ModelConfig, DataConfig, TrainingConfig
from myxtts.utils.performance import start_performance_monitoring
start_performance_monitoring()

# Dataset paths (relative to the notebook's working directory)
train_data_path = '../dataset/dataset_train'
val_data_path = '../dataset/dataset_eval'
print('Train path exists:', os.path.exists(train_data_path))
print('Val path exists  :', os.path.exists(val_data_path))

# Tunables
TRAIN_FRAC = 0.1  # e.g., 0.1 = 10% of train
EVAL_FRAC  = 0.1  # e.g., 0.1 = 10% of eval
BATCH_SIZE = 8  # increase if VRAM allows (e.g., 48)
NUM_WORKERS = max(2, (os.cpu_count() or 8)//4)  # a quarter of the CPU cores, never fewer than 2
SAMPLE_RATE = 22050
N_MELS = 80
MAX_MEL_FRAMES = 320
MIXED_PRECISION = True
# NOTE(review): the trainer output recorded in this notebook also printed
# "XLA JIT compilation enabled" even though this flag is False — confirm the
# flag is actually honored by the trainer.
ENABLE_XLA = False

model_config = ModelConfig(
    text_encoder_dim=256,
    decoder_dim=512,
    n_mels=N_MELS,
    use_voice_conditioning=True
)
training_config = TrainingConfig(
    epochs=200,
    learning_rate=5e-5,
    optimizer='adamw',
    warmup_steps=2000,
    multi_gpu=False,
    visible_gpus=None,
    log_step=100
)
data_config = DataConfig(
    train_subset_fraction=TRAIN_FRAC,
    eval_subset_fraction=EVAL_FRAC,
    batch_size=BATCH_SIZE,
    metadata_train_file='metadata_train.csv',
    metadata_eval_file='metadata_eval.csv',
    wavs_train_dir='wavs',
    wavs_eval_dir='wavs',
    sample_rate=SAMPLE_RATE,
    normalize_audio=True,
    num_workers=NUM_WORKERS,
    enable_memory_mapping=True,
    prefetch_buffer_size=1,
    shuffle_buffer_multiplier=20,
    max_mel_frames=MAX_MEL_FRAMES,
    enable_xla=ENABLE_XLA,
    prefetch_to_gpu=False,
    mixed_precision=MIXED_PRECISION
)

config = XTTSConfig(model=model_config, data=data_config, training=training_config)
print('Batch size:', config.data.batch_size, '| Workers:', config.data.num_workers)


Performance monitoring started
Train path exists: True
Val path exists  : True
Batch size: 8 | Workers: 16


In [3]:
# Optional: one-time cache precompute to remove CPU/I-O bottlenecks.
# Builds the train and val datasets, fills their mel and token caches
# (overwrite=False keeps entries that already exist), then verifies/repairs the
# caches and reports how many items are usable. Each phase runs on train first
# and then on val before the next phase starts.
PRECOMPUTE = True
if PRECOMPUTE:
    from myxtts.data.ljspeech import LJSpeechDataset
    print('Precomputing caches...')
    datasets = [
        LJSpeechDataset(train_data_path, config.data, subset='train', download=False, preprocess=True),
        LJSpeechDataset(val_data_path,   config.data, subset='val',   download=False, preprocess=True),
    ]
    # Report labels, index-aligned with `datasets` (padded so values line up).
    verify_labels = ['Train verify:', 'Val verify  :']
    usable_labels = ['Train usable:', 'Val usable  :']

    for ds in datasets:
        ds.precompute_mels(num_workers=config.data.num_workers, overwrite=False)
    for ds in datasets:
        ds.precompute_tokens(num_workers=config.data.num_workers, overwrite=False)

    print('Verifying caches...')
    for label, ds in zip(verify_labels, datasets):
        print(label, ds.verify_and_fix_cache(fix=True))
    for label, ds in zip(usable_labels, datasets):
        print(label, ds.filter_items_by_cache())

    # Drop every reference (including loop variables) so the datasets can be freed.
    del datasets, ds, label, verify_labels, usable_labels


Precomputing caches...
Loaded 20509 items for train subset
Loaded 2591 items for val subset
Precomputing mel spectrograms to ../dataset/dataset_train/processed/mels_sr22050_n80_hop256 (overwrite=False)...
All mel spectrograms already cached.
Precomputing mel spectrograms to ../dataset/dataset_eval/processed/mels_sr22050_n80_hop256 (overwrite=False)...
All mel spectrograms already cached.
Verifying caches...
Train verify: {'checked': 20509, 'fixed': 0, 'failed': 0}
Val verify  : {'checked': 2591, 'fixed': 0, 'failed': 0}
Train usable: 20509
Val usable  : 2591


In [4]:
# Training with GPU monitoring and timing.
# Builds the model and trainer from `config`, prepares the datasets, then trains.
# The wall-clock time is printed even if the run is stopped early (e.g. Kernel ->
# Interrupt); the interruption itself still propagates, so "Run All" halts here.
# NOTE(review): the recorded run shows ~30 s/step with data_ms of a few ms and
# comp_ms ~28 s, i.e. the training step itself (not input loading) dominates —
# worth profiling (e.g. tf.function retracing on variable-length batches)
# before tuning the data pipeline further.
import time
from myxtts import get_xtts_model, get_trainer, get_inference_engine

model = get_xtts_model()(config.model)
trainer = get_trainer()(config, model)
train_dataset, val_dataset = trainer.prepare_datasets(train_data_path=train_data_path, val_data_path=val_data_path)
print('Train samples:', getattr(trainer, 'train_dataset_size', 'n/a'))
print('Val samples  :', getattr(trainer, 'val_dataset_size', 'n/a'))

train_start = time.perf_counter()
try:
    trainer.train(train_dataset, val_dataset)
finally:
    train_minutes = (time.perf_counter() - train_start) / 60
    print(f'Training wall time: {train_minutes:.1f} min')
    

I0000 00:00:1757727077.312539 2313882 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 22135 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4090, pci bus id: 0000:01:00.0, compute capability: 8.9
2025-09-13 05:01:18,437 - MyXTTS - INFO - Using device: GPU
2025-09-13 05:01:18,437 - MyXTTS - INFO - Using strategy: OneDeviceStrategy
2025-09-13 05:01:18,438 - MyXTTS - INFO - Mixed precision enabled
2025-09-13 05:01:18,438 - MyXTTS - INFO - XLA compilation disabled
2025-09-13 05:01:18,445 - MyXTTS - INFO - Wrapped optimizer with LossScaleOptimizer for mixed precision


Enabled memory growth for PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')
Mixed precision policy enabled for memory optimization
XLA JIT compilation enabled for memory optimization
Physical GPUs: 1, Logical GPUs: 1
Using single GPU strategy


2025-09-13 05:01:19,547 - MyXTTS - INFO - Dataset preprocessing mode: auto
2025-09-13 05:01:19,547 - MyXTTS - INFO - Preprocessing mode: AUTO - Attempting to precompute with graceful fallback


Loaded 20509 items for train subset
Loaded 2591 items for val subset
Precomputing mel spectrograms to ../dataset/dataset_train/processed/mels_sr22050_n80_hop256 (overwrite=False)...
All mel spectrograms already cached.
Precomputing mel spectrograms to ../dataset/dataset_eval/processed/mels_sr22050_n80_hop256 (overwrite=False)...
All mel spectrograms already cached.


2025-09-13 05:01:25,844 - MyXTTS - INFO - Cache verify: train {'checked': 20509, 'fixed': 0, 'failed': 0}, val {'checked': 2591, 'fixed': 0, 'failed': 0}
2025-09-13 05:01:27,051 - MyXTTS - INFO - Using cached items - train: 20509, val: 2591
2025-09-13 05:01:28,280 - MyXTTS - INFO - Training samples: 20509
2025-09-13 05:01:28,280 - MyXTTS - INFO - Validation samples: 2591
2025-09-13 05:01:28,280 - MyXTTS - INFO - Data loading performance:
2025-09-13 05:01:28,281 - MyXTTS - INFO - === Data Loading Profile ===

Cache Efficiency: 0.0%
  Hits: 0
  Misses: 0
  Errors: 0


2025-09-13 05:01:28,281 - MyXTTS - INFO - Starting training for 200 epochs
2025-09-13 05:01:28,281 - MyXTTS - INFO - Current step: 0


Train samples: 20509
Val samples  : 2591
Performance monitoring started


Epoch 0:   0%|          | 4/2564 [02:22<22:52:35, 32.17s/it, loss=219.5270, step=4, data_ms=3.1, comp_ms=28240.1, mel=4.87, stop=0.601] 



Epoch 0:   0%|          | 5/2564 [02:49<21:34:36, 30.35s/it, loss=214.8343, step=5, data_ms=1.5, comp_ms=27131.8, mel=4.76, stop=0.585]



Epoch 0:   1%|          | 14/2564 [06:59<21:13:09, 29.96s/it, loss=186.0670, step=14, data_ms=1.9, comp_ms=29886.8, mel=4.13, stop=0.285] 


KeyboardInterrupt: 

In [None]:
# Inference demo: load a trained checkpoint, synthesize one sentence, write a WAV file.
import glob
from myxtts import get_inference_engine

CHECKPOINT_PATH = './checkpoints/best'
OUTPUT_WAV = 'output.wav'

# Training may have been interrupted before any "best" checkpoint was written.
# Warn up front instead of failing somewhere inside the engine. Matching
# CHECKPOINT_PATH + '*' covers both a directory and prefix-style files
# (whether the engine expects a directory or a prefix is not visible here).
if not glob.glob(CHECKPOINT_PATH + '*'):
    print(f'Warning: no checkpoint found at {CHECKPOINT_PATH!r}; run training first.')

inference = get_inference_engine()(config, checkpoint_path=CHECKPOINT_PATH)
result = inference.synthesize('Hello world!')
inference.save_audio(result['audio'], OUTPUT_WAV)
