In [3]:
from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
from transformers.trainer_pt_utils import LengthGroupedSampler
from transformers.optimization import get_scheduler

In [31]:
from parler_tts import (
    ParlerTTSConfig,
    ParlerTTSForConditionalGeneration,
    build_delay_pattern_mask,
)

from training.utils import (
    get_last_checkpoint,
    rotate_checkpoints,
    log_pred,
    log_metric,
    load_all_codec_checkpoints,
    save_codec_checkpoint,
    get_last_codec_checkpoint_step,
)
from accelerate import Accelerator, skip_first_batches
from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin
from accelerate.utils.memory import release_memory
from training.arguments import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments
from training.data import load_multiple_datasets, DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding
from training.eval import clap_similarity, wer, si_sdr
from datasets import Dataset, IterableDataset, concatenate_datasets, interleave_datasets, load_dataset
import torch

In [11]:
# Precision settings: `mixed_precision` is handed to the Accelerator constructor
# further down. NOTE(review): `torch_dtype` is not referenced by any visible cell.
mixed_precision = "bf16"
torch_dtype = torch.bfloat16

In [13]:
# Load the feature extractor stored with the 'parler-tts/parler-tts-mini-v1'
# checkpoint and keep its sampling rate; later cells use it to size the
# maximum audio length and to decode the audio column.
feature_extractor = AutoFeatureExtractor.from_pretrained(
    'parler-tts/parler-tts-mini-v1'
)
sampling_rate = feature_extractor.sampling_rate

preprocessor_config.json:   0%|          | 0.00/234 [00:00<?, ?B/s]

In [15]:
# load prompt tokenizer (same 'google/flan-t5-large' vocabulary as the
# description tokenizer below, but configured to pad on the left)
prompt_tokenizer = AutoTokenizer.from_pretrained(
    'google/flan-t5-large',
    padding_side='left',
)

# load description tokenizer
description_tokenizer = AutoTokenizer.from_pretrained(
    'google/flan-t5-large',
)

tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

In [16]:
# Load the model configuration published with the pretrained checkpoint.
config = ParlerTTSConfig.from_pretrained(
    'parler-tts/parler-tts-mini-v1',
)

config.json:   0%|          | 0.00/6.93k [00:00<?, ?B/s]

In [18]:
# Load the pretrained Parler-TTS model, requesting the 'sdpa' attention
# implementation. Later cells read `model.generation_config`, `model.decoder`
# and `model.audio_encoder` from this object.
model = ParlerTTSForConditionalGeneration.from_pretrained(
    'parler-tts/parler-tts-mini-v1',
    attn_implementation='sdpa',
)

  WeightNorm.apply(module, name, dim)


generation_config.json:   0%|          | 0.00/265 [00:00<?, ?B/s]

In [20]:
# Shell escape: print the current GPU status of the machine running the kernel.
!nvidia-smi

Sat Sep 21 20:01:09 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.02              Driver Version: 555.42.02      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3090 Ti     Off |   00000000:01:00.0 Off |                  Off |
| 30%   34C    P8             22W /  350W |   12628MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 3090 Ti     Off |   00

In [24]:
import accelerate

In [29]:
# Build the Accelerator with no gradient accumulation (1 step) and the
# mixed-precision mode chosen in the precision cell ("bf16").
accelerator = Accelerator(
    gradient_accumulation_steps=1,
    mixed_precision=mixed_precision,
)

In [45]:
# Load the train split of the annotated corpus; inside this context the local
# main process performs the load ahead of the other processes.
with accelerator.local_main_process_first():
    raw_datasets = load_dataset('mesolitica/tts-combine-annotated', split='train')

# First drop rows missing a prompt or a transcription, then drop rows whose
# prompt is 500 characters or longer.
raw_datasets = raw_datasets.filter(
    lambda row: row['prompt'] is not None and row['transcription'] is not None
).filter(
    lambda row: len(row['prompt']) < 500
)

Filter:   0%|          | 0/360248 [00:00<?, ? examples/s]

In [50]:
import soundfile as sf

def check_len(f):
    """Return the duration of an audio file in seconds.

    Args:
        f: anything ``soundfile.read`` accepts (the notebook passes file paths).

    Returns:
        Number of decoded frames divided by the file's sample rate.
    """
    samples, sample_rate = sf.read(f)
    duration_seconds = len(samples) / sample_rate
    return duration_seconds

In [52]:
# Keep only rows whose audio file lasts more than 0 s and less than 30 s.
# check_len reads every file with soundfile, so the filter is spread over 10 processes.
raw_datasets = raw_datasets.filter(lambda x: 0 < check_len(x['audio_filename']) < 30, num_proc = 10)

Filter (num_proc=10):   0%|          | 0/358348 [00:00<?, ? examples/s]

In [56]:
def pass_through_processors(description, prompt):
    """Tokenize one (description, prompt) text pair.

    Args:
        description: text for ``description_tokenizer``; stripped of
            surrounding whitespace before tokenization.
        prompt: text for ``prompt_tokenizer``; stripped the same way.

    Returns:
        dict with ``"input_ids"`` (description token ids) and
        ``"prompt_input_ids"`` (prompt token ids).
    """
    description_ids = description_tokenizer(description.strip())["input_ids"]
    prompt_ids = prompt_tokenizer(prompt.strip())["input_ids"]
    return {"input_ids": description_ids, "prompt_input_ids": prompt_ids}

# Tokenize every row. `input_columns` is positional: the dataset's 'prompt'
# column is passed as the `description` argument and 'transcription' as the
# `prompt` argument of pass_through_processors. All original columns are
# removed, leaving only the two token-id columns.
with accelerator.local_main_process_first():
    vectorized_datasets = raw_datasets.map(
        pass_through_processors,
        remove_columns=raw_datasets.column_names,
        input_columns=['prompt', 'transcription'],
        num_proc=5,
        desc="preprocess datasets",
    )

preprocess datasets (num_proc=5):   0%|          | 0/358348 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (557 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (557 > 512). Running this sequence through the model will result in indexing errors


In [57]:
# Display the tokenized dataset summary (features and row count).
vectorized_datasets

Dataset({
    features: ['input_ids', 'prompt_input_ids'],
    num_rows: 358348
})

In [65]:
from tqdm import tqdm
from collections import defaultdict

# Group complete rows by their 'speaker' value.
# Iterating the dataset directly materialises each row once; the previous
# indexed loop called `raw_datasets[i]` twice per row, doubling the work.
speakers = defaultdict(list)
for row in tqdm(raw_datasets):
    speakers[row['speaker']].append(row)

100%|███████████████████████████████████████████████████████████████████████| 358348/358348 [00:52<00:00, 6885.79it/s]


In [90]:
# Print 1% of each speaker's row count, in the insertion order of `speakers`.
for speaker_rows in speakers.values():
    print(len(speaker_rows) * 0.01)

1052.65
1052.98
309.39
309.01
309.46
308.55
117.02
124.42


In [91]:
import random

# Randomly split rows into train/test, speaker by speaker.
# A dedicated, seeded generator makes the split reproducible across kernel
# restarts; the unseeded module-level RNG produced a different split each run.
SPLIT_SEED = 42
TEST_FRACTION = 0.01  # probability that a row is assigned to the test split

split_rng = random.Random(SPLIT_SEED)

train, test = [], []
for speaker_rows in speakers.values():
    for row in tqdm(speaker_rows):
        # Each row is an independent draw, so every speaker contributes
        # roughly (not exactly) TEST_FRACTION of its rows to the test split.
        if split_rng.random() > TEST_FRACTION:
            train.append(row)
        else:
            test.append(row)

100%|████████████████████████████████████████████████████████████████████| 105265/105265 [00:00<00:00, 6430243.96it/s]
100%|████████████████████████████████████████████████████████████████████| 105298/105298 [00:00<00:00, 6622360.18it/s]
100%|██████████████████████████████████████████████████████████████████████| 30939/30939 [00:00<00:00, 6655429.86it/s]
100%|██████████████████████████████████████████████████████████████████████| 30901/30901 [00:00<00:00, 6826154.10it/s]
100%|██████████████████████████████████████████████████████████████████████| 30946/30946 [00:00<00:00, 6699888.07it/s]
100%|██████████████████████████████████████████████████████████████████████| 30855/30855 [00:00<00:00, 6713802.13it/s]
100%|██████████████████████████████████████████████████████████████████████| 11702/11702 [00:00<00:00, 6807454.29it/s]
100%|██████████████████████████████████████████████████████████████████████| 12442/12442 [00:00<00:00, 6328587.24it/s]


In [92]:
# Display the number of rows in each split.
len(train), len(test)

(354705, 3643)

In [93]:
# Display the last row assigned to the test split.
test[-1]

{'transcription': 'lidah yang diam dan mendengar sahaja itu Ketawa pula Lalu dia menerangkan',
 'speaker': 'Elina',
 'speaker_id': 7,
 'gender': 'female',
 'utterance_pitch_mean': 149.9508819580078,
 'utterance_pitch_std': 29.23186683654785,
 'snr': 70.31266021728516,
 'c50': 31.286865234375,
 'speech_duration': 6.142500000000002,
 'stoi': 0.9906082153320312,
 'si-sdr': 22.291345596313477,
 'pesq': 3.2068662643432617,
 'pitch': 'very low pitch',
 'speaking_rate': 'very slowly',
 'noise': 'very clear',
 'reverberation': 'quite roomy sounding',
 'speech_monotony': 'very monotone',
 'prompt': 'Elina, a female speaker delivers an expressive and animated speech in a room with slight background noise. Her voice is quite roomy-sounding, with some clarity present. She speaks very slowly with a very low-pitch, but a very monotone tone.',
 'audio_filename': 'combine-audio/360226.mp3'}

In [95]:
# `DatasetDict` is not among the names imported from `datasets` in the import
# cell, so import it here; otherwise this cell raises NameError on a fresh kernel.
from datasets import DatasetDict

# Wrap the two row lists as named splits of a single DatasetDict.
dataset_dict = DatasetDict({
    'train': Dataset.from_list(train),
    'test': Dataset.from_list(test)
})

In [96]:
# Upload both splits to the Hugging Face Hub dataset repository named below.
dataset_dict.push_to_hub('huseinzol05/processed-tts-combine-annotated')

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/355 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/1.09k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/datasets/huseinzol05/processed-tts-combine-annotated/commit/2cf4fac860b3d1382ae52ed3ea3020809746dfb9', commit_message='Upload dataset', commit_description='', oid='2cf4fac860b3d1382ae52ed3ea3020809746dfb9', pr_url=None, pr_revision=None, pr_num=None)

In [141]:
from training.data import load_multiple_datasets, DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding

# Settings used to turn raw audio into codec labels.
feature_extractor_input_name = feature_extractor.model_input_names[0]
max_target_length = 30 * sampling_rate  # number of samples in 30 seconds of audio
padding = "longest"
max_length = model.generation_config.max_length
num_codebooks = model.decoder.config.num_codebooks
audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id
# postprocess_dataset reads `audio_encoder_eos_token_id`, but no visible cell
# assigned it, so calling that function failed with NameError. Take the value
# from the decoder section of the loaded model config.
audio_encoder_eos_token_id = config.decoder.eos_token_id
bandwidth = 6
# Collator that reads the 'audio' column, truncates to `max_target_length`
# samples and pads according to `padding`.
encoder_data_collator = DataCollatorEncodecWithPadding(
    feature_extractor,
    audio_column_name='audio',
    feature_extractor_input_name=feature_extractor_input_name,
    max_length=max_target_length,
    padding=padding,
)

In [142]:
def apply_audio_decoder(batch):
    """Encode a collated audio batch into discrete codec labels.

    Args:
        batch: mapping that contains ``"len_audio"`` and ``"input_values"``.
            Every entry except ``"len_audio"`` is forwarded to
            ``audio_decoder.encode`` as a keyword argument. The mapping itself
            is left unmodified, so the same batch can be encoded again.

    Returns:
        dict with:
        - ``"len_audio"``: the batch's ``"len_audio"`` entry, passed through;
        - ``"labels"``: audio codes rearranged to (bsz, seq_len, codebooks);
        - ``"ratio"``: number of code frames divided by the reference audio
          length, with the same shape as ``len_audio``.
    """
    len_audio = batch["len_audio"]
    # Build the encoder kwargs without popping from the caller's batch: the
    # earlier in-place pop made a second call on the same batch raise KeyError.
    encoder_inputs = {key: value for key, value in batch.items() if key != "len_audio"}

    audio_decoder.to(batch["input_values"].device).eval()
    with torch.no_grad():
        labels = audio_decoder.encode(**encoder_inputs, bandwidth=bandwidth)["audio_codes"]

    output = {}
    output["len_audio"] = len_audio
    # (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
    output["labels"] = labels.squeeze(0).transpose(1, 2)

    # With "max_length" padding every clip is padded to `max_target_length`;
    # otherwise the longest clip of the batch sets the reference length.
    max_length = len_audio.max() if padding != "max_length" else max_target_length
    output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / max_length
    return output

# One BOS column, shape (1, codebooks, seq_len) with seq_len=1, that
# postprocess_dataset prepends to the label sequence.
bos_labels = torch.full((1, num_codebooks, 1), float(audio_encoder_bos_token_id))

def postprocess_dataset(labels):
    """Prepend BOS to the codec labels and apply the codebook delay pattern.

    Args:
        labels: codec labels accepted by ``torch.tensor``.

    Returns:
        dict with a single ``"labels"`` entry: the delay-pattern labels with
        every unfilled (-1) position replaced by the EOS id and the first
        column removed.
    """
    # add a leading batch axis -> (1, codebooks, seq_len), then prepend BOS
    codes = torch.tensor(labels).unsqueeze(0)
    codes_with_bos = torch.cat([bos_labels, codes], dim=-1)

    _, delay_pattern_mask = build_delay_pattern_mask(
        codes_with_bos,
        bos_token_id=audio_encoder_bos_token_id,
        pad_token_id=audio_encoder_eos_token_id,
        max_length=codes_with_bos.shape[-1] + num_codebooks,
        num_codebooks=num_codebooks,
    )

    # The leading ids of the delay pattern mask are exactly the labels; the
    # remaining -1 slots are filled with EOS so the result looks like:
    #  - [B, a, b, E, E, E, E]
    #  - [B, B, c, d, E, E, E]
    #  - [B, B, B, e, f, E, E]
    #  - [B, B, B, B, g, h, E]
    delayed = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)

    # The first timestep is a column made only of BOS, so it is dropped.
    return {"labels": delayed[:, 1:]}

In [143]:
# Rename the file-path column to 'audio', the column name the encoder
# collator was configured with (audio_column_name='audio').
d = dataset_dict['train'].rename_column('audio_filename', 'audio')

In [144]:
# `Audio` is not imported by the import cell, so import it here; otherwise
# this cell raises NameError on a fresh kernel.
from datasets import Audio

# Declare the 'audio' column as an Audio feature at the feature extractor's sampling rate.
d = d.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))

In [145]:
# Collate the first ten training rows into one encoder batch and display it.
first_rows = [d[idx] for idx in range(10)]
batch = encoder_data_collator(first_rows)
batch

{'padding_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]], dtype=torch.int32), 'input_values': tensor([[[-0.0008, -0.0012, -0.0015,  ...,  0.0005,  0.0003,  0.0001]],

        [[-0.0008, -0.0012, -0.0015,  ...,  0.0005,  0.0003,  0.0001]],

        [[-0.0008, -0.0012, -0.0015,  ...,  0.0005,  0.0003,  0.0001]],

        ...,

        [[-0.0008, -0.0012, -0.0015,  ...,  0.0005,  0.0003,  0.0001]],

        [[-0.0008, -0.0012, -0.0015,  ...,  0.0005,  0.0003,  0.0001]],

        [[-0.0008, -0.0012, -0.0015,  ...,  0.0005,  0.0003,  0.0001]]]), 'len_audio': tensor([[314346],
        [314346],
        [314346],
        [314346],
        [314346],
        [314346],
        [314346],
        [314346],
        [314346],
        [314346]])}

In [146]:
# Despite the name, this is the model's audio *encoder* sub-module;
# apply_audio_decoder calls its `.encode(...)` to obtain audio codes.
audio_decoder = model.audio_encoder

In [147]:
# Encode the collated batch into codec labels.
generate_labels = apply_audio_decoder(batch)

In [148]:
# Display the encoder output (len_audio, labels, ratio).
generate_labels

{'len_audio': tensor([[314346],
         [314346],
         [314346],
         [314346],
         [314346],
         [314346],
         [314346],
         [314346],
         [314346],
         [314346]]),
 'labels': tensor([[[698, 710, 114,  ..., 496, 387, 348],
          [698, 249, 540,  ..., 947, 803, 902],
          [698, 578, 888,  ..., 428, 570, 683],
          ...,
          [568, 778, 771,  ..., 338, 378, 731],
          [698, 151, 229,  ..., 145, 954, 726],
          [698, 847, 408,  ..., 138,  83, 640]],
 
         [[698, 710, 114,  ..., 496, 387, 348],
          [698, 249, 540,  ..., 947, 803, 902],
          [698, 578, 888,  ..., 428, 570, 683],
          ...,
          [568, 778, 771,  ..., 338, 378, 731],
          [698, 151, 229,  ..., 145, 954, 726],
          [698, 847, 408,  ..., 138,  83, 640]],
 
         [[698, 710, 114,  ..., 496, 387, 348],
          [698, 249, 540,  ..., 947, 803, 902],
          [698, 578, 888,  ..., 428, 570, 683],
          ...,
          [568