In [None]:
!pip install speechbrain datasets==2.17.0

In [None]:
import os
import torch
import matplotlib.pyplot as plt
from datasets import load_dataset, Audio
from transformers import (AutoProcessor
                          , SpeechT5ForTextToSpeech
                          , Seq2SeqTrainingArguments
                          , Seq2SeqTrainer
                          , SpeechT5HifiGan)
from collections import defaultdict
from speechbrain.inference.speaker import EncoderClassifier
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from functools import partial

from huggingface_hub import notebook_login
from IPython.display import Audio as play_audio

In [None]:
# Log in to the Hugging Face Hub; the training arguments further down set
# push_to_hub=True and the trained model is pushed at the end.
notebook_login()

In [None]:
# Stream the "lt" configuration of VoxPopuli (train split) instead of
# materializing it up front.
dataset = load_dataset(
    "facebook/voxpopuli",
    "lt",
    split="train",
    trust_remote_code=True,
    streaming=True,
)

In [None]:
# Base checkpoint: the processor here and the model further down are both
# loaded from it.
checkpoint = "microsoft/speecht5_tts"
processor = AutoProcessor.from_pretrained(checkpoint)

In [None]:
# Cast the audio column to the sampling rate used by the processor's
# feature extractor.
sampling_rate = processor.feature_extractor.sampling_rate
dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))

In [None]:
# Keep a handle on the tokenizer so its vocabulary can be compared with the
# characters found in the dataset.
tokenizer = processor.tokenizer

In [None]:
def extract_all_chars(batch):
    """Collect the distinct characters used by a batch of transcriptions.

    Every entry of ``batch["normalized_text"]`` is joined with single spaces.
    The result holds the unique characters of that joined string under
    ``"vocab"`` and the joined string itself under ``"all_text"``; each value
    is wrapped in a one-element list so a batched ``map`` produces one output
    row per input batch.
    """
    joined = " ".join(batch["normalized_text"])
    unique_chars = [*set(joined)]
    return dict(vocab=[unique_chars], all_text=[joined])


# Extract the character inventory in batched mode and drop the original
# columns so only "vocab" and "all_text" remain.
# NOTE(review): batch_size=-1 is presumably meant to feed the whole stream as
# a single batch — confirm against the `datasets` version in use.
vocabs = dataset.map(
    extract_all_chars,
    batched=True,
    batch_size=-1,
    remove_columns=dataset.column_names,
)

# Characters seen in the transcriptions vs. tokens known to the tokenizer.
dataset_vocab = {char for row in vocabs for char in row["vocab"]}
tokenizer_vocab = set(tokenizer.get_vocab())

In [None]:
# Characters that occur in the dataset but are not tokens of the tokenizer.
dataset_vocab - tokenizer_vocab

In [None]:
# Lithuanian letters with diacritics, each paired with the plain-Latin letter
# that stands in for it in the transcriptions.
replacements = [
    ("ą", "a"),
    ("č", "c"),
    ("ė", "e"),
    ("ę", "e"),
    ("į", "i"),  # i-ogonek is a form of "i"; it used to be mapped to "j"
    ("š", "s"),
    ("ū", "u"),
    ("ų", "u"),
    ("ž", "z"),
]


def cleanup_text(inputs):
    """Replace the letters listed in ``replacements`` in one example's text.

    Args:
        inputs: Mapping with a ``"normalized_text"`` string entry.

    Returns:
        The same ``inputs`` object, with ``"normalized_text"`` rewritten so
        that every source character from ``replacements`` is substituted by
        its target. All other entries are left untouched.
    """
    text = inputs["normalized_text"]
    for src, dst in replacements:
        text = text.replace(src, dst)
    # Write back once instead of re-assigning the entry on every iteration.
    inputs["normalized_text"] = text
    return inputs

# Apply the character replacements to every example's transcription.
dataset = dataset.map(cleanup_text)

In [None]:
# Tally how many examples each speaker contributes.
speaker_counts = defaultdict(int)
for example in dataset:
    speaker_id = example["speaker_id"]
    speaker_counts[speaker_id] += 1

In [None]:
# Histogram of the number of examples per speaker.
fig, ax = plt.subplots()
ax.hist(speaker_counts.values(), bins=20)
ax.set_ylabel("Speakers")
ax.set_xlabel("Examples")
plt.show()

In [None]:
# Pretrained SpeechBrain speaker encoder used to compute speaker embeddings.
spk_model_name = "speechbrain/spkrec-xvect-voxceleb"

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# The encoder's files are stored under /tmp/<spk_model_name>.
speaker_model = EncoderClassifier.from_hparams(
    source=spk_model_name,
    run_opts={"device": device},
    savedir=os.path.join("/tmp", spk_model_name),
)


def create_speaker_embedding(waveform):
    """Compute a normalized speaker embedding for one waveform.

    The waveform is turned into a tensor and encoded with the module-level
    ``speaker_model`` without tracking gradients. The encoding is normalized
    along dimension 2, squeezed, moved to the CPU and returned as a NumPy
    array.
    """
    with torch.no_grad():
        encoded = speaker_model.encode_batch(torch.tensor(waveform))
        unit_norm = torch.nn.functional.normalize(encoded, dim=2)
    return unit_norm.squeeze().cpu().numpy()

In [None]:
def prepare_dataset(example):
    """Turn one raw example into the features used for training.

    The transcription (``"normalized_text"``) and the audio are passed to the
    module-level ``processor``; the audio is supplied as the target. The
    batch dimension of the resulting ``"labels"`` is removed and a speaker
    embedding computed from the same waveform is stored under
    ``"speaker_embeddings"``.

    Returns:
        The processor output with ``"labels"`` and ``"speaker_embeddings"``
        set as described above.
    """
    audio = example["audio"]
    waveform, rate = audio["array"], audio["sampling_rate"]

    features = processor(
        text=example["normalized_text"],
        audio_target=waveform,
        sampling_rate=rate,
        return_attention_mask=False,
    )

    # The processor returns batched labels; keep the single entry.
    features["labels"] = features["labels"][0]

    # Speaker embedding from the SpeechBrain encoder.
    features["speaker_embeddings"] = create_speaker_embedding(waveform)

    return features

In [None]:
# Try the preprocessing on the first streamed example and list the fields it
# produces.
processed_example = prepare_dataset(next(iter(dataset)))
list(processed_example.keys())

In [None]:
# Shape of the speaker embedding computed for that example.
processed_example["speaker_embeddings"].shape

In [None]:
# Show the example's labels as an image (transposed for display).
# NOTE(review): presumably a spectrogram with time along the x-axis after the
# transpose — confirm.
plt.figure(figsize = (16, 9))
plt.imshow(processed_example["labels"].T)
plt.show()

In [None]:
# Apply the feature preparation example by example and keep only the columns
# it produces.
dataset = dataset.map(
    prepare_dataset,
    remove_columns=dataset.column_names,
)

In [None]:
# Shuffle with a fixed seed, then carve a 90/10 train/test split out of the
# stream.
dataset = dataset.shuffle(seed=0)

# Reuse the tally from the speaker-count pass instead of iterating the stream
# again. At this point every iteration runs `prepare_dataset` (processor plus
# a speaker-embedding forward pass) on each example, which is far too
# expensive just for counting. The `map` and `shuffle` steps applied since the
# tally neither add nor drop examples, so the total is the same.
length = sum(speaker_counts.values())
train_ratio = 0.9
train_size = round(train_ratio * length)
train_dataset = dataset.take(train_size)
test_dataset = dataset.skip(train_size)

In [None]:
@dataclass
class TTSDataCollatorWithPadding:
    """Collate preprocessed text-to-speech examples into one padded batch.

    Attributes:
        processor: Object whose ``pad`` method turns the per-example
            ``input_ids`` and label features into padded tensors.
        reduction_factor: Optional reduction factor used to trim the label
            length. When left as ``None`` it is read from the module-level
            ``model.config.reduction_factor`` at call time, so ``model`` must
            exist before the collator is first called.
    """

    processor: Any
    reduction_factor: Union[int, None] = None

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Build a batch from ``features``.

        Each feature must provide ``"input_ids"``, ``"labels"`` and
        ``"speaker_embeddings"``.

        Returns:
            The padded batch with ``"labels"`` masked with -100 where the
            decoder attention mask is not 1, without
            ``"decoder_attention_mask"``, and with a ``"speaker_embeddings"``
            tensor added.
        """
        input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
        label_features = [{"input_values": feature["labels"]} for feature in features]
        speaker_features = [feature["speaker_embeddings"] for feature in features]

        # Collate the inputs and targets into a batch. Use the processor this
        # collator was constructed with, not whatever global happens to be
        # named `processor`.
        batch = self.processor.pad(
            input_ids=input_ids, labels=label_features, return_tensors="pt"
        )

        # Replace padding with -100 so the loss ignores those positions.
        batch["labels"] = batch["labels"].masked_fill(
            batch.decoder_attention_mask.unsqueeze(-1).ne(1), -100
        )

        # Not used during fine-tuning.
        del batch["decoder_attention_mask"]

        reduction_factor = self.reduction_factor
        if reduction_factor is None:
            reduction_factor = model.config.reduction_factor

        # Round each target length down to a multiple of the reduction factor
        # and trim the labels to the longest rounded length.
        if reduction_factor > 1:
            max_length = max(
                n_frames - n_frames % reduction_factor
                for n_frames in (
                    len(feature["input_values"]) for feature in label_features
                )
            )
            batch["labels"] = batch["labels"][:, :max_length]

        # Also add in the speaker embeddings.
        batch["speaker_embeddings"] = torch.tensor(speaker_features)

        return batch

In [None]:
# Collator instance handed to the trainer below.
data_collator = TTSDataCollatorWithPadding(processor=processor)

In [None]:
# Load the model to fine-tune from the same checkpoint as the processor.
model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint)

In [None]:
# disable cache during training since it's incompatible with gradient checkpointing
model.config.use_cache = False

# re-enable the cache for generation by pre-binding use_cache=True on model.generate
model.generate = partial(model.generate, use_cache=True)

In [None]:
# Name the output directory (and Hub repository) after the base checkpoint.
model_name = checkpoint.split("/")[-1]
hf_dir = f"{model_name}_finetuned_voxpopuli_lt"
training_args = Seq2SeqTrainingArguments(
    output_dir=hf_dir,
    auto_find_batch_size=True,
    gradient_accumulation_steps=2,
    learning_rate=5e-5,
    warmup_steps=50,
    # The datasets are streamed, so training is bounded by a step count.
    max_steps=400,
    gradient_checkpointing=True,
    # Only request fp16 mixed precision when CUDA is present. The rest of the
    # notebook falls back to the CPU, where fp16=True is rejected.
    fp16=torch.cuda.is_available(),
    evaluation_strategy="steps",
    save_steps=100,
    eval_steps=100,
    logging_steps=100,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    greater_is_better=False,
    label_names=["labels"],
    push_to_hub=True,
)

In [None]:
# Wire the model, the train/test splits and the padding collator into the
# trainer; the processor is passed in the `tokenizer` slot.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=data_collator,
    tokenizer=processor,
)

In [None]:
# Run fine-tuning with the arguments configured above.
trainer.train()

In [None]:
# Upload the result to the Hugging Face Hub.
trainer.push_to_hub()

In [None]:
# Reload a fine-tuned checkpoint from the Hub for inference.
# NOTE(review): the repository id is hard-coded to one account; it presumably
# corresponds to `hf_dir` pushed above — adjust when running under another user.
model = SpeechT5ForTextToSpeech.from_pretrained(
    "jaymanvirk/speecht5_tts_finetuned_voxpopuli_lt"
)

In [None]:
# Take the speaker embedding of the first test example and add a leading
# batch dimension.
example = next(iter(test_dataset))
speaker_embeddings = torch.tensor(example["speaker_embeddings"]).unsqueeze(0)

In [None]:
# Sentence to synthesize with the fine-tuned model.
text = "Šis yra bandymo pranešimas. Tikrinama 400 žingsnių smulkiu nustatymu modelio kokybė."

In [None]:
# The training transcriptions were passed through `cleanup_text`, so give the
# inference text the same treatment; otherwise its diacritics reach the
# tokenizer unchanged. The replacement table only lists lower-case letters,
# so lower-case first to also catch capitals such as "Š".
# NOTE(review): assumes the training transcriptions are lower-case — confirm.
normalized_text = cleanup_text({"normalized_text": text.lower()})["normalized_text"]
inputs = processor(text=normalized_text, return_tensors="pt")

In [None]:
# Load the HiFi-GAN vocoder checkpoint and generate speech for the tokenized
# text, conditioned on the chosen speaker embedding.
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)

In [None]:
# Play the generated audio inline.
# NOTE(review): the rate is hard-coded; presumably it equals `sampling_rate`
# defined earlier in the notebook — confirm.
play_audio(speech.numpy(), rate=16000)