In [14]:
from datasets import load_dataset
path = "EN/*001.tar"
hfdataset = load_dataset("amphion/Emilia-Dataset", data_files={"en": path}, split="en")

EN-B000001.tar:   0%|          | 0.00/2.03G [00:00<?, ?B/s]

EN-B001001.tar:   0%|          | 0.00/105M [00:00<?, ?B/s]

Generating en split: 0 examples [00:00, ? examples/s]

In [15]:
from pathlib import Path

from torch.optim import AdamW

from f5_tts.model.duration import DurationPredictor, DurationTransformer
from f5_tts.model.trainer import DurationTrainer
from f5_tts.model.dataset import TextAudioDataset, HFDataset

vocab_path = "data/de/vocab.txt"
vocab = {v: i for i, v in enumerate(Path(vocab_path).read_text().split("\n"))}

duration_predictor = DurationPredictor(
    transformer=DurationTransformer(
        dim=512,
        depth=8,
        heads=8,
        text_dim=512,
        ff_mult=2,
        conv_layers=2,
        text_num_embeds=len(vocab) - 1,
    ),
    vocab_char_map=vocab,
)
print(
    f"Trainable parameters: {sum(p.numel() for p in duration_predictor.parameters() if p.requires_grad)}"
)

optimizer = AdamW(duration_predictor.parameters(), lr=7.5e-5)

trainer = DurationTrainer(
    duration_predictor,
    optimizer,
    num_warmup_steps=5000,
    # accelerate_kwargs = {"mixed_precision": "fp16", "log_with": "wandb"}
    accelerate_kwargs={"mixed_precision": "no"},
)

epochs = 25
max_batch_tokens = 16_000

# train_dataset = TextAudioDataset(
#     folder = Path("LibriTTS_R").expanduser(),
#     audio_extensions = ["wav"],
#     max_duration = 44
# )
train_dataset = HFDataset(hfdataset)
print("Training...")


Trainable parameters: 21852160
Training...


In [16]:
# trainer.load_checkpoint("f5tts_duration_3000.pt")
trainer.load_checkpoint("../f5tts_duration.pt")

0

In [17]:

dl = trainer.train(train_dataset, epochs, max_batch_tokens, num_workers=0, save_step=1000)

In [18]:
from einops import rearrange
batch = next(iter(dl))
m = trainer.accelerator.unwrap_model(trainer.model)
text_inputs = batch["text"]
mel_spec = rearrange(batch["mel"], "b d n -> b n d")
mel_lengths = batch["mel_lengths"]

loss = m(
    mel_spec, text=text_inputs, lens=mel_lengths, return_loss=False
)

In [7]:
from safetensors.torch import save_file, save_model
import torch

save_model(trainer.accelerator.unwrap_model(trainer.model), "duration_v2.safetensors")

In [8]:
SAMPLE_RATE = 24_000
HOP_LENGTH = 256
SAMPLES_PER_SECOND = SAMPLE_RATE / HOP_LENGTH
loss*SAMPLES_PER_SECOND

tensor([72.9162, 67.5215, 68.1405, 79.3960, 72.2256, 75.3841, 68.6786, 68.8755,
        74.7844, 69.5419, 67.1713, 72.2198, 70.6283, 71.7877, 75.4355, 67.6825,
        71.8702, 68.8977, 67.4177], device='mps:0', grad_fn=<MulBackward0>)

In [19]:
batch['mel_lengths']

tensor([ 352,  600,  378,  644,  656,  706,  289, 1680, 1797,  538,  698, 1479,
         761,  500,  282,  363,  425,  528,  656, 1400], device='mps:0')

In [20]:
loss

tensor([0.5553, 0.6148, 0.5615, 0.6153, 0.6178, 0.6223, 0.5385, 0.6636, 0.6613,
        0.6009, 0.6223, 0.6581, 0.6228, 0.5864, 0.5168, 0.5600, 0.5809, 0.5982,
        0.6100, 0.6644], device='mps:0', grad_fn=<ViewBackward0>)