# Fine-tuning Parler-TTS

## Goal of this notebook

In the following notebook, we'll fine-tune [Parler-TTS Mini v0.1](https://huggingface.co/parler-tts/parler_tts_mini_v0.1) on a 6h subset of the [Jenny TTS dataset](https://github.com/dioco-group/jenny-tts-dataset), a 30-hour, high-quality, mono-speaker TTS dataset recorded by an Irish female speaker named Jenny.

In particular, we'll:
- Annotate the Jenny dataset with natural language speech descriptions using [Data-Speech](https://github.com/huggingface/dataspeech).
- Fine-tune Parler-TTS with the created dataset.

**You should be able to adapt this notebook to your own datasets quite easily.**





## Prepare the Environment

Throughout this tutorial, we'll use a GPU. The runtime is already configured to use the free 16GB T4 GPU provided through Google Colab Free Tier, so all you need to do is hit "Connect T4" in the top right-hand corner of the screen.

##### <a name="installation"></a> We'll install Parler-TTS and Data-Speech from source in order to train our model.

In [2]:
!git clone https://github.com/huggingface/dataspeech.git

Cloning into 'dataspeech'...
remote: Enumerating objects: 496, done.
remote: Counting objects: 100% (124/124), done.
remote: Compressing objects: 100% (48/48), done.
remote: Total 496 (delta 81), reused 76 (delta 76), pack-reused 372
Receiving objects: 100% (496/496), 116.45 KiB | 2.43 MiB/s, done.
Resolving deltas: 100% (307/307), done.


In [3]:
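# "!cd" runs in a throwaway subshell and does not change the notebook's directory,
# so install the requirements via their path relative to the current directory.
!pip install --quiet -r ./dataspeech/requirements.txt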

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
fastai 2.7.14 requires torch<2.3,>=1.10, but you have torch 2.3.0 which is incompatible.
spacy 3.7.3 requires typer<0.10.0,>=0.3.0, but you have typer 0.12.3 which is incompatible.
weasel 0.3.4 requires typer<0.10.0,>=0.3.0, but you have typer 0.12.3 which is incompatible.

In [4]:
!git clone https://github.com/huggingface/parler-tts.git
%cd parler-tts
!pip install --quiet -e .[train]

Cloning into 'parler-tts'...
remote: Enumerating objects: 815, done.
remote: Counting objects: 100% (201/201), done.
remote: Compressing objects: 100% (66/66), done.
remote: Total 815 (delta 162), reused 144 (delta 135), pack-reused 614
Receiving objects: 100% (815/815), 261.18 KiB | 3.00 MiB/s, done.
Resolving deltas: 100% (499/499), done.
/kaggle/working/parler-tts
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cudf 23.8.0 requires cubinlinker, which is not installed.
cudf 23.8.0 requires cupy-cuda11x>=12.0.0, which is not installed.
cudf 23.8.0 requires ptxcompiler, which is not installed.
cuml 23.8.0 requires cupy-cuda11x>=12.0.0, which is not installed.
dask-cudf 23.8.0 requires cupy-cuda11x>=12.0.0, which is not installed.
tensorflow-decision-forests 1.8.1 requires wurlitzer, which is not installed.
apache-beam 2.46.0 requires dill<0.3.2

On Colab, we need to run an additional setup step, which you can skip if you're on your local machine.

In [5]:
!pip install --upgrade protobuf wandb==0.16.6

Collecting protobuf
  Downloading protobuf-5.27.0-cp38-abi3-manylinux2014_x86_64.whl.metadata (592 bytes)
  Downloading protobuf-4.25.3-cp37-abi3-manylinux2014_x86_64.whl.metadata (541 bytes)
Downloading protobuf-4.25.3-cp37-abi3-manylinux2014_x86_64.whl (294 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m294.6/294.6 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: protobuf
  Attempting uninstall: protobuf
    Found existing installation: protobuf 3.19.6
    Uninstalling protobuf-3.19.6:
      Successfully uninstalled protobuf-3.19.6
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cudf 23.8.0 requires cubinlinker, which is not installed.
cudf 23.8.0 requires cupy-cuda11x>=12.0.0, which is not installed.
cudf 23.8.0 requires ptxcompiler, which is not installed.
cuml 23.8.0 requires cupy-

You should link your Hugging Face account so that you can push model repositories to the Hub. This will allow you to save your trained models on the Hub and share them with the community.

Run the command below and then enter an authentication token from https://huggingface.co/settings/tokens. Create a new token if you do not have one already. You should make sure that this token has "write" privileges.

In [6]:
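# !git config --global credential.helper store
# !huggingface-cli login

from huggingface_hub import login
# Without arguments, login() prompts for your token; you can also pass it explicitly,
# e.g. login(token=os.environ["HF_TOKEN"]), instead of hard-coding it in the notebook.
login()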

Token has not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


## 1. Creating our fine-tuning dataset


The aim here is to create an annotated version of Jenny TTS, in order to fine-tune the [Parler-TTS v0.1 checkpoint](https://huggingface.co/parler-tts/parler_tts_mini_v0.1) on this dataset.

Using a [script similar to what's described in the Data-Speech FAQ](https://github.com/huggingface/dataspeech?tab=readme-ov-file#how-do-i-use-datasets-that-i-have-with-this-repository), we've uploaded the dataset to the Hugging Face Hub under the name [reach-vb/jenny_tts_dataset](https://huggingface.co/datasets/reach-vb/jenny_tts_dataset).

Since this notebook is meant as a demonstration, we've pushed a 6h subset of the dataset that we'll work with: [ylacombe/jenny-tts-6h](https://huggingface.co/datasets/ylacombe/jenny-tts-6h).

Feel free to follow the link above to listen to some samples of the Jenny TTS dataset in the Hub's dataset viewer.

> Refer to the [Data-Speech README](https://github.com/huggingface/dataspeech?tab=readme-ov-file#data-speech) for more detailed explanations of what's going on under-the-hood.

We'll:
1. Annotate the Jenny dataset with continuous variables that measure the speech characteristics.
2. Map those annotations to text bins that describe these characteristics.
3. Create natural language descriptions from those text bins.

In [7]:
%cd ../dataspeech

/kaggle/working/dataspeech


But first, let's look at a few samples from the Jenny dataset!

In [8]:
from datasets import load_dataset
dataset = load_dataset("ylacombe/jenny-tts-6h")

Downloading readme:   0%|          | 0.00/420 [00:00<?, ?B/s]

Downloading data: 100%|██████████| 326M/326M [00:02<00:00, 157MB/s]  
Downloading data: 100%|██████████| 371M/371M [00:02<00:00, 164MB/s]  
Downloading data: 100%|██████████| 319M/319M [00:02<00:00, 152MB/s]  


Generating train split:   0%|          | 0/4000 [00:00<?, ? examples/s]

In [9]:
from IPython.display import Audio
print(dataset["train"][0]["transcription"])
Audio(dataset["train"][0]["audio"]["array"], rate=dataset["train"][0]["audio"]["sampling_rate"])

It was a bright cold day in April, and the clocks were striking thirteen.


In [10]:
from IPython.display import Audio
print(dataset["train"][1]["transcription"])
Audio(dataset["train"][1]["audio"]["array"], rate=dataset["train"][1]["audio"]["sampling_rate"])

'I wonder if I shall ever be happy enough to have real lace on my clothes and bows on my caps?'


In [11]:
del dataset


### 1. Annotate the Jenny dataset

We'll use [`main.py`](https://github.com/huggingface/dataspeech/blob/main/main.py) to get the following continuous variables:
- Speaking rate `(nb_phonemes / utterance_length)`
- Signal-to-noise ratio (SNR)
- Reverberation
- Speech monotony
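To make the first two quantities concrete, here is a minimal sketch. It assumes the phoneme count has already been obtained from a grapheme-to-phoneme tool (Data-Speech handles this internally and we don't reproduce it here), and it uses the textbook power-ratio definition of SNR in decibels rather than whatever estimator `main.py` relies on. `speaking_rate` and `snr_db` are hypothetical helper names.

```python
import numpy as np


def speaking_rate(num_phonemes: int, audio: np.ndarray, sampling_rate: int) -> float:
    """Hypothetical helper: phonemes per second, i.e. nb_phonemes / utterance_length."""
    utterance_length = len(audio) / sampling_rate  # duration in seconds
    return num_phonemes / utterance_length


def snr_db(signal: np.ndarray, noise: np.ndarray) -> float:
    """Hypothetical helper: textbook SNR in dB from separate signal and noise arrays."""
    signal_power = np.mean(signal ** 2)
    noise_power = np.mean(noise ** 2)
    return float(10.0 * np.log10(signal_power / noise_power))


sr = 16_000
dummy_utterance = np.zeros(3 * sr)  # a silent 3-second placeholder waveform
print(speaking_rate(45, dummy_utterance, sr))  # 15.0 phonemes per second
print(snr_db(np.full(1000, 10.0), np.full(1000, 1.0)))  # ≈ 20.0 dB
```

In practice the signal and noise are not available separately, which is why tools like Data-Speech estimate SNR from the mixed recording instead.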


In [12]:
# !python main.py "ylacombe/jenny-tts-6h" \
#   --configuration "default" \
#   --text_column_name "transcription" \
#   --audio_column_name "audio" \
#   --cpu_num_workers 2 \
#   --num_workers_per_gpu_for_pitch 2 \
#   --rename_column \
#   --repo_id "jenny-tts-tags-6h"

The whole process took under 10 minutes!

The resulting dataset will be pushed to the Hugging Face Hub under your Hugging Face handle, here `Cintin/jenny-tts-tags-6h` (see also [ylacombe/jenny-tts-tags-6h](https://huggingface.co/datasets/ylacombe/jenny-tts-tags-6h)).

Let's see what the new dataset looks like:

In [13]:
from datasets import load_dataset
dataset = load_dataset("Cintin/jenny-tts-tags-6h")
print("SNR 1st sample", dataset["train"][0]["snr"])
print("C50 1st sample", dataset["train"][0]["c50"])
del dataset

Downloading readme:   0%|          | 0.00/728 [00:00<?, ?B/s]

Downloading data: 100%|██████████| 933k/933k [00:00<00:00, 2.21MB/s]


Generating train split:   0%|          | 0/4000 [00:00<?, ? examples/s]

SNR 1st sample 54.890892028808594
C50 1st sample 59.73095703125


As you can see, the current annotations are continuous variables. To use them with Parler-TTS, we need to convert them into textual descriptions, which the next two steps take care of.

### 2. Map annotations to text bins

Since the ultimate goal here is to fine-tune the [Parler-TTS v0.1 checkpoint](https://huggingface.co/parler-tts/parler_tts_mini_v0.1) on the Jenny dataset, we want to stay consistent with the text bins of the datasets on which the latter model was trained.

This is easy to do thanks to the following:

In [14]:
# !python ./scripts/metadata_to_text.py \
#     "Cintin/jenny-tts-tags-6h" \
#     --repo_id "jenny-tts-tags-6h" \
#     --configuration "default" \
#     --cpu_num_workers 2 \
#     --path_to_bin_edges "./examples/tags_to_annotations/v01_bin_edges.json" \
#     --avoid_pitch_computation

Thanks to [`v01_bin_edges.json`](https://github.com/huggingface/dataspeech/blob/main/examples/tags_to_annotations/v01_bin_edges.json), we don't need to recompute bins from scratch and the above script takes a few seconds.

The resulting dataset will be pushed to the Hugging Face Hub under your Hugging Face handle, here [Cintin/jenny-tts-tags-6h](https://huggingface.co/datasets/Cintin/jenny-tts-tags-6h).
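Conceptually, this step locates each continuous value among a set of bin edges and picks the matching label. The sketch below assumes `len(labels) + 1` edges per feature and uses made-up edges (`HYPOTHETICAL_SNR_EDGES`). The SNR labels are assumed to mirror Data-Speech's and should be checked against `metadata_to_text.py`; the exact layout of `v01_bin_edges.json` is not reproduced here.

```python
import numpy as np

# Assumed SNR labels, ordered from noisiest to cleanest
# (verify against Data-Speech's metadata_to_text.py).
SNR_LABELS = [
    "very noisy", "quite noisy", "slightly noisy", "moderate ambient sound",
    "slightly clear", "quite clear", "very clear",
]

# Made-up edges for illustration only: len(SNR_LABELS) + 1 edges delimit len(SNR_LABELS) bins.
HYPOTHETICAL_SNR_EDGES = np.linspace(0.0, 70.0, len(SNR_LABELS) + 1)


def value_to_text_bin(value: float, edges: np.ndarray, labels: list) -> str:
    """Map a continuous value to a text label by locating its bin among `edges`."""
    # Digitizing against the inner edges yields an index in [0, len(labels) - 1],
    # so values below the first edge or above the last one land in the extreme bins.
    index = int(np.digitize(value, edges[1:-1]))
    return labels[index]


for snr in [5.0, 33.0, 54.9, 90.0]:
    print(snr, "->", value_to_text_bin(snr, HYPOTHETICAL_SNR_EDGES, SNR_LABELS))
```

Reusing fixed edges, rather than recomputing them from the new data, is what keeps the labels comparable to the ones the base checkpoint saw during training.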

Notice that text bins such as `quite noisy` and `very fast` have been added to the samples.

In [15]:
from datasets import load_dataset
dataset = load_dataset("Cintin/jenny-tts-tags-6h")
print("Noise 1st sample:", dataset["train"][0]["noise"])
print("Speaking rate 1st sample:", dataset["train"][0]["speaking_rate"])
del dataset

Noise 1st sample: slightly clear
Speaking rate 1st sample: very fast



### 3. Create natural language descriptions from those text bins

Now that we have text bins associated with the Jenny dataset, the next step is to create natural language descriptions from the few features we created.

Here, we decided to create prompts that use the name `Jenny`, which look like the following:
`In a very expressive voice, Jenny pronounces her words incredibly slowly. There's some background noise in this room with a bit of echo.`

This step generally demands more resources and time, and should run on one or more GPUs.

The following command shows how to do it using the [2B version of Google's Gemma model](https://huggingface.co/google/gemma-2b-it), which should run in about 15 minutes on the free Colab T4.


As usual, we specify the dataset name and configuration we want to annotate. `model_name_or_path` should point to a `transformers` model for prompt annotation. You can find a list of such models [here](https://huggingface.co/models?pipeline_tag=text-generation&library=transformers&sort=trending).

**Note** how we've been able to specify that the dataset is mono-speaker and that we should name the voice Jenny thanks to the flags:


`--speaker_name "Jenny" --is_single_speaker`.
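Under the hood, the script asks the language model to turn each row's text bins into a sentence. The helper below is a hypothetical, simplified illustration of that idea: `build_description_request` is our own name, and its wording is not Data-Speech's actual template in `run_prompt_creation.py`. It only shows how a speaker name and a single-speaker switch can steer the instruction.

```python
def build_description_request(bins: dict, speaker_name: str = "", is_single_speaker: bool = False) -> str:
    """Hypothetical helper: build an LLM instruction from Data-Speech-style text bins."""
    keywords = ", ".join(f"'{value}'" for value in bins.values())
    instruction = (
        "Write one natural-language sentence describing a speech recording "
        f"using these keywords: {keywords}."
    )
    # For a mono-speaker dataset, ask the model to refer to the voice by name.
    if is_single_speaker and speaker_name:
        instruction += f" The speaker is named {speaker_name}; refer to them by name."
    return instruction


example_bins = {"noise": "slightly clear", "speaking_rate": "very fast"}
print(build_description_request(example_bins, speaker_name="Jenny", is_single_speaker=True))
```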


In [16]:
# !python ./scripts/run_prompt_creation.py \
#   --speaker_name "Jenny" \
#   --is_single_speaker \
#   --dataset_name "Cintin/jenny-tts-tags-6h" \
#   --output_dir "./tmp_jenny" \
#   --dataset_config_name "default" \
#   --model_name_or_path "google/gemma-2b-it" \
#   --per_device_eval_batch_size 12 \
#   --attn_implementation "sdpa" \
#   --dataloader_num_workers 2 \
#   --push_to_hub \
#   --hub_dataset_id "jenny-tts-6h-tagged" \
#   --preprocessing_num_workers 2

Let's take a look at some created prompts:

In [17]:
from datasets import load_dataset
dataset = load_dataset("Cintin/jenny-tts-6h-tagged")
print("1st sample:", dataset["train"][0]["text_description"])
print("2nd sample:", dataset["train"][1]["text_description"])
del dataset

Downloading readme:   0%|          | 0.00/774 [00:00<?, ?B/s]

Downloading data: 100%|██████████| 1.05M/1.05M [00:00<00:00, 2.86MB/s]


Generating train split:   0%|          | 0/4000 [00:00<?, ? examples/s]

1st sample: 'Jenny speaks with a very monotone tone of voice, and the recording is quite clear with minimal background noise.'
2nd sample: 'The speech sample is very clear but slightly muffled due to some background noise, and the pace is quite fast.'


**Observation:** The second sample unfortunately doesn't mention Jenny. This is probably because we used a smaller, and thus less precise, model than the one we would have picked if this notebook had more resources (e.g. we used [Mistral 7B Instruct v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) to create the Parler-TTS training dataset). This shouldn't prevent our model from learning what we want, though.
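To quantify how often the speaker's name was dropped, one could count the descriptions that mention it. Below is a small sketch with a hypothetical helper, `fraction_mentioning`; the commented lines show how it might be applied to the tagged dataset's `text_description` column (loading it downloads the data).

```python
def fraction_mentioning(descriptions, name: str) -> float:
    """Hypothetical helper: share of descriptions that contain `name` (case-sensitive)."""
    descriptions = list(descriptions)
    if not descriptions:
        return 0.0
    return sum(name in text for text in descriptions) / len(descriptions)


samples = [
    "Jenny speaks with a very monotone tone of voice, and the recording is quite clear with minimal background noise.",
    "The speech sample is very clear but slightly muffled due to some background noise, and the pace is quite fast.",
]
print(fraction_mentioning(samples, "Jenny"))  # 0.5

# from datasets import load_dataset
# dataset = load_dataset("Cintin/jenny-tts-6h-tagged", split="train")
# print(fraction_mentioning(dataset["text_description"], "Jenny"))
```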

## 2. Fine-tuning Parler-TTS



In [18]:
%cd ../parler-tts

/kaggle/working/parler-tts


In [19]:
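import os
import wandb
# Read the W&B API key from the environment (e.g. a Kaggle/Colab secret) rather than hard-coding it;
# if the variable is unset, wandb uses stored credentials or prompts for a key.
wandb.login(key=os.environ.get("WANDB_API_KEY"))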

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [20]:
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mee23s061[0m. Use [1m`wandb login --relogin`[0m to force relogin


We can now fully focus on fine-tuning Parler-TTS. Luckily, [the Parler-TTS library](https://github.com/huggingface/parler-tts) has a training script available [here](https://github.com/huggingface/parler-tts/tree/main/training) that can be used with just a few arguments.


> **Note:** when prompted, you need to enter your choice concerning WandB. If you don't have an account, you can enter `3` to skip logging to WandB. Otherwise, you can log in to follow how your model trains.

In [21]:
!accelerate launch ./training/run_parler_tts_training.py \
    --model_name_or_path "parler-tts/parler_tts_mini_v0.1" \
    --feature_extractor_name "parler-tts/dac_44khZ_8kbps" \
    --description_tokenizer_name "parler-tts/parler_tts_mini_v0.1" \
    --prompt_tokenizer_name "parler-tts/parler_tts_mini_v0.1" \
    --report_to "wandb" \
    --wandb_project ptts \
    --wandb_run_name ptts_run \
    --overwrite_output_dir true \
    --train_dataset_name "ylacombe/jenny-tts-6h" \
    --train_metadata_dataset_name "Cintin/jenny-tts-6h-tagged" \
    --train_dataset_config_name "default" \
    --train_split_name "train" \
    --eval_dataset_name "ylacombe/jenny-tts-6h" \
    --eval_metadata_dataset_name "Cintin/jenny-tts-6h-tagged" \
    --eval_dataset_config_name "default" \
    --eval_split_name "train" \
    --max_eval_samples 8 \
    --per_device_eval_batch_size 8 \
    --target_audio_column_name "audio" \
    --description_column_name "text_description" \
    --prompt_column_name "text" \
    --max_duration_in_seconds 20 \
    --min_duration_in_seconds 2.0 \
    --max_text_length 400 \
    --preprocessing_num_workers 2 \
    --do_train true \
    --num_train_epochs 2 \
    --gradient_accumulation_steps 18 \
    --gradient_checkpointing true \
    --per_device_train_batch_size 2 \
    --learning_rate 0.00008 \
    --adam_beta1 0.9 \
    --adam_beta2 0.99 \
    --weight_decay 0.01 \
    --lr_scheduler_type "constant_with_warmup" \
    --warmup_steps 50 \
    --logging_steps 2 \
    --freeze_text_encoder true \
    --audio_encoder_per_device_batch_size 4 \
    --dtype "float16" \
    --seed 456 \
    --output_dir "./output_dir_training/" \
    --temporary_save_to_disk "./audio_code_tmp/" \
    --save_to_disk "./tmp_dataset_audio/" \
    --dataloader_num_workers 2 \
    --do_eval \
    --evaluation_strategy=epoch \
    --predict_with_generate \
    --include_inputs_for_metrics \
    --group_by_length true

  warn(
2024-05-26 13:51:31.289737: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-26 13:51:31.289734: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-26 13:51:31.289829: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-26 13:51:31.289856: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-26 13:51:31.449328: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register
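To put the hyperparameters above in perspective, the effective batch size is `per_device_train_batch_size × gradient_accumulation_steps × number of GPUs`. The back-of-the-envelope calculation below assumes a single GPU and the 4,000 examples of the `ylacombe/jenny-tts-6h` train split; the step count is a rough upper bound, since the duration filters (`--min_duration_in_seconds`, `--max_duration_in_seconds`) may drop some samples.

```python
import math

per_device_train_batch_size = 2
gradient_accumulation_steps = 18
num_gpus = 1  # assumption: a single T4
num_train_examples = 4000  # size of the train split before duration filtering
num_train_epochs = 2
warmup_steps = 50

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
steps_per_epoch = math.ceil(num_train_examples / effective_batch_size)
total_steps = steps_per_epoch * num_train_epochs

print(f"effective batch size: {effective_batch_size}")  # 36
print(f"optimizer steps (upper bound): {total_steps}")  # 224
print(f"share of steps spent warming up: {warmup_steps / total_steps:.0%}")
```

Gradient accumulation is what lets a 16GB T4 reach a reasonably large effective batch while only holding 2 samples in memory at a time.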

## Inference

The full training on the free T4 from Google Colab took about an hour.
Now, let's see how to do inference with the newly fine-tuned model!

First install the Parler-TTS library:

In [22]:
# !pip install git+https://github.com/huggingface/parler-tts.git

Then:

In [23]:
# from parler_tts import ParlerTTSForConditionalGeneration
# from transformers import AutoTokenizer
# import torch

# device = "cuda:0" if torch.cuda.is_available() else "cpu"
# # float16 is only reliable on GPU; fall back to float32 on CPU
# torch_dtype = torch.float16 if device.startswith("cuda") else torch.float32

# # Load the checkpoint written by the training script to --output_dir (relative to ./parler-tts)
# model = ParlerTTSForConditionalGeneration.from_pretrained("./output_dir_training", torch_dtype=torch_dtype).to(device)
# tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")

# prompt = "Hey, how are you doing today?"
# description = "Jenny delivers her words quite expressively, in a very confined sounding environment with clear audio quality. She speaks fast."

# # The description conditions the voice; the prompt is the text to be spoken
# input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
# prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
# audio_arr = generation.cpu().float().numpy().squeeze()

In [24]:
# from IPython.display import Audio
# Audio(audio_arr, rate=model.config.sampling_rate)
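If you'd rather keep the generated audio than only listen to it in the notebook, you can write it to a WAV file. The sketch below uses only NumPy and Python's built-in `wave` module, converting the float waveform to 16-bit PCM. `save_wav` is a hypothetical helper, and it assumes a mono float array in roughly [-1, 1], like `audio_arr`.

```python
import wave

import numpy as np


def save_wav(path: str, audio: np.ndarray, sampling_rate: int) -> None:
    """Hypothetical helper: write a mono float waveform to a 16-bit PCM WAV file."""
    audio = np.asarray(audio, dtype=np.float32)
    # Clip to [-1, 1] to avoid integer overflow, then scale to little-endian int16 as WAV expects.
    pcm = (np.clip(audio, -1.0, 1.0) * 32767).astype("<i2")
    with wave.open(path, "wb") as wav_file:
        wav_file.setnchannels(1)  # mono
        wav_file.setsampwidth(2)  # 2 bytes = 16 bits per sample
        wav_file.setframerate(int(sampling_rate))
        wav_file.writeframes(pcm.tobytes())


# Usage with the variables from the cells above:
# save_wav("jenny_sample.wav", audio_arr, model.config.sampling_rate)
```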

In [25]:
# prompt = "Wow, I've really got the same voice as Jenny, huh?"

# prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
# audio_arr = generation.cpu().numpy().squeeze()

# Audio(audio_arr, rate=model.config.sampling_rate)

In [26]:
# prompt = "What a time to be alive!"
# description = "Jenny's speech is very clear, and she speaks in a very monotone voice, really slowly and with minimal variation in speed."

# input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
# prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
# audio_arr = generation.cpu().numpy().squeeze()

# Audio(audio_arr, rate=model.config.sampling_rate)

This is great! As you can see, the model now produces a **consistent** voice throughout generation that sounds like **Jenny**!

Since we're quite happy with it, let's push it to the Hub so we can reuse it!

In [27]:
# model.push_to_hub("parler-tts-mini-Jenny-colab")
# tokenizer.push_to_hub("parler-tts-mini-Jenny-colab")

In [28]:
# model = ParlerTTSForConditionalGeneration.from_pretrained("Cintin/parler-tts-mini-Jenny-colab").to(device)
# tokenizer = AutoTokenizer.from_pretrained("Cintin/parler-tts-mini-Jenny-colab")

You'll now be able to load the model and the tokenizer directly from your model's repository id, i.e. `<your_HF_handle>/parler-tts-mini-Jenny-colab`.

```python
model = ParlerTTSForConditionalGeneration.from_pretrained("<your_HF_handle>/parler-tts-mini-Jenny-colab").to(device)
tokenizer = AutoTokenizer.from_pretrained("<your_HF_handle>/parler-tts-mini-Jenny-colab")
```



## Conclusion

To conclude, we've shown here:
1. how to annotate a 6-hour, single-speaker dataset
2. how to fine-tune Parler-TTS Mini v0.1 on this newly created dataset!

**If you want to fine-tune the model on your own dataset, you can follow and/or adapt the current notebook to make it work! Don't forget to check how to push your own local dataset to the Hugging Face Hub using a [script similar to what's described in the Data-Speech FAQ](https://github.com/huggingface/dataspeech?tab=readme-ov-file#how-do-i-use-datasets-that-i-have-with-this-repository)!**

In [29]:
!python helpers/model_init_scripts/init_model_600M.py ./parler-tts-untrained-600M --text_model "google/flan-t5-base" --audio_model "parler-tts/dac_44khZ_8kbps"

2024-05-26 14:37:54.463524: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-26 14:37:54.463586: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-26 14:37:54.465199: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
config.json: 100%|█████████████████████████| 1.40k/1.40k [00:00<00:00, 7.56MB/s]
config.json: 100%|█████████████████████████████| 247/247 [00:00<00:00, 1.54MB/s]
num_codebooks 9
model.safetensors: 100%|██████████████████████| 990M/990M [00:03<00:00, 272MB/s]
model.safetensors: 100%|██████████████████████| 307M/307M [00:01<00:00, 175MB/s]


In [30]:
!ls

LICENSE		helpers			   parler_tts.egg-info	training
Makefile	output_dir_training	   pyproject.toml	wandb
README.md	parler-tts-untrained-600M  setup.py
audio_code_tmp	parler_tts		   tmp_dataset_audio


In [31]:
# !accelerate launch ./training/run_parler_tts_training.py \
# --model_name_or_path "./parler-tts-untrained-600M/parler-tts-untrained-600M/" \
# --save_to_disk "./tmp_dataset_audio/" \
# --temporary_save_to_disk "./audio_code_tmp/" \
# --feature_extractor_name "ylacombe/dac_44khZ_8kbps" \
# --description_tokenizer_name "google/flan-t5-base" \
# --prompt_tokenizer_name "google/flan-t5-base" \
# --report_to "wandb" \
# --wandb_project ptts \
# --wandb_run_name ptts_run_full \
# --overwrite_output_dir true \
# --output_dir "./output_dir_training" \
# --train_dataset_name "blabble-io/libritts_r+blabble-io/libritts_r+blabble-io/libritts_r+parler-tts/mls_eng_10k" \
# --train_metadata_dataset_name "parler-tts/libritts_r_tags_tagged_10k_generated+parler-tts/libritts_r_tags_tagged_10k_generated+parler-tts/libritts_r_tags_tagged_10k_generated+parler-tts/mls-eng-10k-tags_tagged_10k_generated" \
# --train_dataset_config_name "clean+clean+other+default" \
# --train_split_name "train.clean.360+train.clean.100+train.other.500+train" \
# --eval_dataset_name "blabble-io/libritts_r+parler-tts/mls_eng_10k" \
# --eval_metadata_dataset_name "parler-tts/libritts_r_tags_tagged_10k_generated+parler-tts/mls-eng-10k-tags_tagged_10k_generated" \
# --eval_dataset_config_name "other+default" \
# --eval_split_name "test.other+test" \
# --target_audio_column_name "audio" \
# --description_column_name "text_description" \
# --prompt_column_name "text" \
# --max_eval_samples 96 \
# --max_duration_in_seconds 30 \
# --min_duration_in_seconds 2.0 \
# --max_text_length 400 \
# --group_by_length true \
# --add_audio_samples_to_wandb true \
# --id_column_name "id" \
# --preprocessing_num_workers 8 \
# --do_train true \
# --num_train_epochs 40 \
# --gradient_accumulation_steps 8 \
# --gradient_checkpointing false \
# --per_device_train_batch_size 3 \
# --learning_rate 0.00095 \
# --adam_beta1 0.9 \
# --adam_beta2 0.99 \
# --weight_decay 0.01 \
# --lr_scheduler_type "constant_with_warmup" \
# --warmup_steps  20000 \
# --logging_steps 1000 \
# --freeze_text_encoder true \
# --do_eval true \
# --predict_with_generate true \
# --include_inputs_for_metrics true \
# --evaluation_strategy "steps" \
# --eval_steps 10000 \
# --save_steps 10000 \
# --per_device_eval_batch_size 12 \
# --audio_encoder_per_device_batch_size 20 \
# --dtype "bfloat16" \
# --seed 456 \
# --dataloader_num_workers 8

In [32]:
# from parler_tts import ParlerTTSForConditionalGeneration
# from transformers import AutoTokenizer
# import torch

# device = "cuda:0" if torch.cuda.is_available() else "cpu"

# model = ParlerTTSForConditionalGeneration.from_pretrained("./output_dir_training", torch_dtype=torch.float16).to(device)
# tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")

# model.push_to_hub("parler-tts-fulltune")
# tokenizer.push_to_hub("parler-tts-fulltune")