## Homework 09_codec_models [15 points]

First, let's download the required files and packages.

In [None]:
# !wget https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_09_tts_tranformer/data.py
# !wget https://raw.githubusercontent.com/yandexdataschool/speech_course/main/week_09_tts_tranformer/model.py

In [None]:
# !pip install deep-phonemizer librosa matplotlib numpy pyannote.audio pyloudnorm torch torchaudio tqdm

In [None]:
import os
from urllib.parse import parse_qs, urlencode, urlparse

import requests


def download_file(public_link):
    """Download a file shared via a public Yandex.Disk link into the current directory."""
    base_url = 'https://cloud-api.yandex.net/v1/disk/public/resources/download?'
    final_url = base_url + urlencode(dict(public_key=public_link))
    response = requests.get(final_url)
    response.raise_for_status()
    download_url = response.json()['href']

    # The file name is passed in the `filename` query parameter of the download URL
    filename = parse_qs(urlparse(download_url).query)['filename'][0]
    file_response = requests.get(download_url)
    file_response.raise_for_status()
    file_path = os.path.join(os.getcwd(), filename)
    with open(file_path, 'wb') as ff:
        ff.write(file_response.content)

In [None]:
### To download the file, uncomment the following lines

# link_to_archive = "https://disk.yandex.ru/d/XvDaWCWch6hWTw"
# download_file(link_to_archive)
# !unzip lingware.zip
# !rm lingware.zip
# !mkdir -p ../data/09_tts_transformers
# !mv lingware ../data/09_tts_transformers

Lingware is a folder with everything needed for synthesis: tokenizer weights, g2p dictionaries, and so on.

### 0. Transformer

In this homework we will download a pretrained transformer and write inference for the model.

First, let's take a look at the required tools:

In [None]:
from pathlib import Path

import torch
import torch.nn.functional as F
import torchaudio
from IPython.display import Audio, display

In [None]:
%load_ext autoreload
%autoreload 2

from model import *
from data import *

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # change the GPU index to match your machine

In [None]:
### Paths
lingware_folder = Path("../data/09_tts_transformers/lingware")

ckpt_path = lingware_folder / "ckpt"

codec_model_path = lingware_folder / "codec" / "ckpt"
codec_config_path = lingware_folder / "codec" / "config.json"

phonemizer_path = lingware_folder / "phonemizer_en_us.pt"

dataset_url = "dev-clean"
data_path = Path("../data/09_tts_transformers")
data_path.mkdir(parents=True, exist_ok=True)

Let's download some data to play with our model.

In [None]:
tts_dataset = torchaudio.datasets.LIBRITTS(root=data_path, url=dataset_url, download=True)

Let's create a **phonemizer**. It is a simple phonemizer built on the `deep-phonemizer` library. It has 2 methods:
- `phonemize` - takes a text and returns its phonemized version as a sequence of phonemes.
- `tokenize` - takes a text, phonemizes it and returns a list of indices, one per phoneme.

In [None]:
phonemizer = Phonemizer(phonemizer_path)

Now let's create a **bioembedding model**. We will condition our model on its outputs to mimic the speaker during synthesis.

It has a `__call__` method that takes a waveform and returns the embedding of the speaker in that waveform.

In [None]:
bioemb_model = BioembModel(device=device)

Now let's create a **codec model**. It converts a waveform to codecs and back:
- `encode` - transforms a waveform of size `[Time]` into a sequence of codecs of size `[short_time, 4]`.
- `decode` - transforms a sequence of codecs of size `[short_time, 4]` back into a waveform.

Note that:
- `160` - index of end_token
- `161` - index of start_token
- `162` - index of pad_token
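
To make the role of these indices concrete, here is a minimal sketch in plain Python. It assumes that regular codec ids occupy the range 0 to 159 and that a sequence is framed by one row of start tokens and one row of end tokens. The helper `frame_codecs` and the padding scheme are illustrative assumptions and are not taken from `data.py`.

```python
END_TOKEN, START_TOKEN, PAD_TOKEN = 160, 161, 162
N_CODEBOOKS = 4


def frame_codecs(codecs, total_len):
    """Hypothetical helper: add start/end rows and pad to `total_len` rows."""
    framed = [[START_TOKEN] * N_CODEBOOKS] + codecs + [[END_TOKEN] * N_CODEBOOKS]
    n_pad = total_len - len(framed)
    if n_pad < 0:
        raise ValueError("total_len is too small for this sequence")
    return framed + [[PAD_TOKEN] * N_CODEBOOKS for _ in range(n_pad)]


codecs = [[12, 7, 159, 0], [3, 44, 18, 90]]  # shape [short_time=2, 4]
framed = frame_codecs(codecs, total_len=5)
for row in framed:
    print(row)
```

Under these assumptions, slicing off the first row (as `codecs[:, 1:, :]` does later in the notebook) removes the start tokens before decoding.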

In [None]:
codec_model = CodecApplier(
    config_path=codec_config_path,
    ckpt_path=codec_model_path,
    sample_rate=16000,
    device=device,
)

Now let's assemble everything into one dataset.

In [None]:
infer_dataset = CodecsDataset(
    dataset=tts_dataset,
    phonemizer=phonemizer,
    bioemb_model=bioemb_model,
    codec_model=codec_model,
)

### Model

We will work with a model that mimics the model from the [MQTTS paper](https://arxiv.org/abs/2302.04215). It consists of an encoder, a decoder and a sub-decoder.
- `Encoder` consists of several Self-Attention layers and Feed-Forward layers. It takes a sequence of phoneme embeddings and returns the encoded representation of the sequence.
- `Decoder` consists of several Self-Attention layers and Feed-Forward layers. It uses cross-attention to attend to the embeddings from the encoder. It takes a sequence of codecs, creates an embedding for each codec layer, concatenates them and uses the result as its input. Then for each codec it predicts an embedding, which is further used by the sub-decoder to predict the next codec.
- `SubDecoder` - a decoder-only transformer, which gets an embedding from the decoder and predicts 4 tokens. It makes 4 autoregressive steps to predict the 4 tokens of the codec.
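
The decoder input described above can be sketched as follows. This is a rough illustration rather than the code from `model.py`: the class name `CodecInputEmbedding` and the embedding size are made up, and only the idea of one embedding table per codec layer followed by concatenation comes from the description.

```python
import torch
import torch.nn as nn


class CodecInputEmbedding(nn.Module):
    """Hypothetical module: one embedding table per codebook, outputs concatenated."""

    def __init__(self, n_codes=163, n_codebooks=4, dim_per_codebook=8):
        super().__init__()
        self.tables = nn.ModuleList(
            [nn.Embedding(n_codes, dim_per_codebook) for _ in range(n_codebooks)]
        )

    def forward(self, codes):
        # codes: [B, L, N] integer codec ids, N = number of codebooks
        embedded = [table(codes[..., i]) for i, table in enumerate(self.tables)]
        # Concatenate along the feature axis: [B, L, N * dim_per_codebook]
        return torch.cat(embedded, dim=-1)


codes = torch.randint(0, 160, (2, 5, 4))  # batch of 2, 5 frames, 4 codebooks
inputs = CodecInputEmbedding()(codes)
print(inputs.shape)  # torch.Size([2, 5, 32])
```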

This model was discussed in the lecture; you can refer to the recording for a better understanding of what is happening.

This model was trained on the LibriTTS dataset.

In [None]:
ckpt = torch.load(ckpt_path, map_location=torch.device("cpu"))

model = TTSTransformer(
    n_phonemes=49,
    n_codes=163,
    n_codebooks=4,
)

model.load_state_dict(ckpt)
model = model.eval().to(device)

### Inference function

This function iterates over the first `n_samples` samples of the input dataset. For each sample it predicts tokens in a teacher-forcing regime, then decodes them back to a waveform with `codec_model` and plays the result.

Teacher forcing means that the ground-truth codecs are fed to the model to predict each next token. This is not how audio is generated in practice, since no ground truth is available at synthesis time, but it is the simplest setup for a sanity check.
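
The difference between teacher forcing and free-running generation can be seen on a toy example. The sketch below is not related to the TTS model: `toy_next_token` is a hypothetical predictor over integers that makes a single deliberate mistake, which is enough to show how errors behave in the two regimes.

```python
def toy_next_token(prev):
    """Hypothetical stand-in for a model: predicts prev + 1, but errs after token 2."""
    return prev + 2 if prev == 2 else prev + 1


ground_truth = [0, 1, 2, 3, 4, 5]

# Teacher forcing: every prediction is conditioned on the ground-truth previous token
teacher_forced = [toy_next_token(tok) for tok in ground_truth[:-1]]

# Free-running: every prediction is conditioned on the previous prediction
free_running = []
prev = ground_truth[0]
for _ in range(len(ground_truth) - 1):
    prev = toy_next_token(prev)
    free_running.append(prev)

print("target:        ", ground_truth[1:])  # [1, 2, 3, 4, 5]
print("teacher forced:", teacher_forced)    # [1, 2, 4, 4, 5]
print("free running:  ", free_running)      # [1, 2, 4, 5, 6]
```

With teacher forcing the mistake stays local to one position, while in the free-running regime it shifts every later prediction. This is why teacher-forced output tends to sound cleaner than real synthesis.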

In [None]:
def infer_teacher_forcing(model, dataset, codec_model, n_samples=1, sampling_fn=lambda x: x.argmax(dim=-1)):
    """
    model: TTSTransformer model, which has a `forward` method. It gets phoneme_ids, a speaker embedding and a codecs sequence and predicts logits for the next codec.
    dataset: Iterator over the CodecsDataset. On each iteration it should return a tuple (phonemes, phoneme_ids, codecs, bioemb)
    codec_model: Codec model, needed to decode codecs sequence back to the waveform
    n_samples: number of samples from the dataset which will be inferred
    sampling_fn: function which takes logits and returns the predicted labels. By default it returns the argmax of the logits
    """
    device = next(model.parameters()).device

    for idx, (phonemes, phoneme_ids, codecs, bioemb) in zip(range(n_samples), dataset):
        phoneme_ids = torch.tensor([phoneme_ids], device=device)
        codecs = torch.tensor([codecs], device=device)
        bioemb = torch.tensor([bioemb], device=device)

        phones_mask = torch.ones_like(phoneme_ids, dtype=torch.bool)
        codes_mask = torch.ones(codecs.shape[:2], dtype=torch.bool, device=device)

        prediction = model(
            phones=phoneme_ids,
            phones_mask=phones_mask, # [B, l]
            codes=codecs, # [B, L, N]
            codes_mask=codes_mask, # [B, L]
            speaker_embs=bioemb, # [B, d]
        )
        pred_labels = sampling_fn(prediction)

        # [:, 1:, :] is needed to remove the start tokens
        gt_wav = codec_model.decode(codecs[:, 1:, :], bioemb)

        # Clamp maps any end, start or padding token (ids 160-162) that emerges in the prediction to a valid codec id
        pred_labels = pred_labels.clamp(min=0, max=159)
        synthesized_wav = codec_model.decode(pred_labels, bioemb)

        print(f"Phonemes: {'_'.join(phonemes)}")
        print("Ground truth")
        display(Audio(gt_wav, rate=16000))
        print("Synthesized")
        display(Audio(synthesized_wav, rate=16000))

In [None]:
infer_teacher_forcing(model, dataset=infer_dataset, codec_model=codec_model, n_samples=5)

### 1. Sampling functions [3 points]
During inference our model predicts logits, and we need to sample from these logits to get the next token. We will use several functions to do that.
- `ArgmaxSampling` - greedy decoding: it returns the token with the highest logit (and therefore the highest probability).
- `MultinomialSampling` - samples indices of codecs from a multinomial distribution with probabilities `softmax(logits / temperature)`.
- `TopKSampling` - keeps only the top-k logits (the k tokens with the highest probabilities) and samples from them using multinomial sampling.

Each function gets a FloatTensor of size [\*, logits] and returns a LongTensor of size [\*], where \* is an arbitrary number of dimensions.

These functions from torch can be useful:
- [torch.multinomial](https://pytorch.org/docs/stable/generated/torch.multinomial.html)
- [torch.topk](https://pytorch.org/docs/stable/generated/torch.topk.html)
- [torch.gather](https://pytorch.org/docs/stable/generated/torch.gather.html#torch.gather)
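
Before implementing the classes, it may help to see how temperature and the listed torch functions behave on a tiny tensor. The snippet below is only an illustration with made-up values, not a solution to the assignment. Note that `torch.multinomial` accepts only 1D or 2D inputs, so tensors of size [\*, logits] with more leading dimensions need to be reshaped first.

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.0, -1.0]])  # shape [1, 4]

# Temperature rescales logits before softmax: T < 1 sharpens, T > 1 flattens
probs_by_temperature = {}
for temperature in (0.5, 1.0, 3.0):
    probs_by_temperature[temperature] = torch.softmax(logits / temperature, dim=-1)
    print(temperature, probs_by_temperature[temperature])

# torch.topk returns the k largest values and their positions along a dimension
top_values, top_indices = torch.topk(logits, k=2, dim=-1)
print(top_values, top_indices)

# torch.multinomial samples column indices in proportion to non-negative weights;
# here all the mass sits on index 1, so the draw is deterministic
choice = torch.multinomial(torch.tensor([[0.0, 1.0, 0.0]]), num_samples=1)
print(choice)

# torch.gather picks elements along a dimension by index
picked = torch.gather(torch.tensor([[10, 20, 30]]), dim=-1, index=torch.tensor([[2]]))
print(picked)
```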

In [None]:
# TODO: implement the following functions

class ArgmaxSampling:
    def __init__(self):
        pass

    def __call__(self, logits):
        return torch.argmax(logits, dim=-1)


class MultinomialSampling:
    def __init__(self, temperature=1.0):
        self.temperature = temperature

    def __call__(self, logits):
        # Your code here
        raise NotImplementedError("TODO: assignment")
        # ^^^^^^^^^^^^^^



class TopKSampling:
    def __init__(self, k, temperature=1.0):
        self.k = k
        self.temperature = temperature

    def __call__(self, logits):
        # Your code here
        raise NotImplementedError("TODO: assignment")
        # ^^^^^^^^^^^^^^


In [None]:
infer_teacher_forcing(
    model,
    dataset=infer_dataset,
    codec_model=codec_model,
    n_samples=1,
    sampling_fn=MultinomialSampling(temperature=1.0),
)

infer_teacher_forcing(
    model,
    dataset=infer_dataset,
    codec_model=codec_model,
    n_samples=1,
    sampling_fn=TopKSampling(k=3, temperature=1.0),
)

Let's listen to what we've got and how the hyperparameters influence the sampling.

In [None]:
sampling_functions_to_test = [
    (MultinomialSampling, {"temperature": 1.}),
    (MultinomialSampling, {"temperature": 3.}),
    (MultinomialSampling, {"temperature": 0.5}),
    (TopKSampling, {"k": 7, "temperature": 1.}),
    (TopKSampling, {"k": 20, "temperature": 1.}),
    (TopKSampling, {"k": 3, "temperature": 1.}),
]

for sampling_class, sampling_kwargs in sampling_functions_to_test:
    sampling_fn = sampling_class(**sampling_kwargs)
    print(f"======== Sampling function: {sampling_class.__name__} with kwargs {sampling_kwargs} ========")
    infer_teacher_forcing(
        model,
        dataset=infer_dataset,
        codec_model=codec_model,
        n_samples=2,
        sampling_fn=sampling_fn,
    )

Assignment:

What are your impressions of these different sampling methods? What is the difference between them? What are the advantages and disadvantages of each? Which one is preferable?


TODO

### 2. Autoregressive inference [12 points]

The autoregressive inference function. It does exactly the same as `infer_teacher_forcing`, but uses `model.autoregressive_sampling` instead of `model.forward`.

In [None]:
def infer_autoregressive(model, dataset, codec_model, n_samples=5, sampling_fn=lambda x: x.argmax(dim=-1)):
    device = next(model.parameters()).device

    for idx, (phonemes, phoneme_ids, codecs, bioemb) in zip(range(n_samples), dataset):
        phoneme_ids = torch.tensor([phoneme_ids], device=device)
        codecs = torch.tensor([codecs], device=device)
        bioemb = torch.tensor([bioemb], device=device)


        # This function is not supposed to use codecs for prediction
        pred_labels = model.autoregressive_sampling(
            phones=phoneme_ids,
            speaker_embs=bioemb,
            sampling_fn=sampling_fn,
        )

        print(f"{codecs.shape=}")
        # [:, 1:, :] is needed to remove the start tokens
        gt_wav = codec_model.decode(codecs[:, 1:, :], bioemb)

        # Clamp maps any end, start or padding token (ids 160-162) that emerges in the prediction to a valid codec id
        pred_labels = pred_labels.clamp(min=0, max=159)
        synthesized_wav = codec_model.decode(pred_labels, bioemb)

        print(f"Phonemes: {'_'.join(phonemes)}")
        print(f"Ground truth")
        display(Audio(gt_wav, rate=16000))
        print(f"Synthesized")
        display(Audio(synthesized_wav, rate=16000))

Assignment:

Go to the file `model.py` and implement the `SubDecoder.autoregressive_sampling` and `TTSTransformer.autoregressive_sampling` methods.

Notes:
- The model is an almost exact copy of the MQTTS model from the lecture, except that it doesn't use the trick with a window in encoder-decoder attention during inference.
- It is better not to modify the `__init__` and `forward` methods of each model, because the behaviour of the model can change.
- During autoregressive sampling, the model should not use the ground-truth codec sequence. Instead, it should generate the target sequence one token at a time, starting with a vector of 4 start tokens.
- You will need to figure out how the model works, so don't hesitate to print the shapes of the tensors you are working with.
- You will need to call the `SubDecoder.forward` and `TTSTransformer.forward` methods multiple times, but do not modify them.
- The synthesis has two conditions that end the generation of the target sequence:
    - The target sequence is longer than the maximum length.
    - The target sequence contains at least one end token.

In [None]:
infer_autoregressive(model, dataset=infer_dataset, codec_model=codec_model, sampling_fn=MultinomialSampling())

Let's say that you have implemented those methods successfully if the model generates comprehensible speech.

Now you can play with different types and hyperparameters of sampling in autoregressive synthesis.

Write down:
- Which hyperparameters have you played with?
- Compare them.
- How do they influence autoregressive synthesis?
- What is their effect on audio quality, intonation and speaker similarity?
- What are the optimal hyperparameters?

TODO:

### Rules for committing homework

- Remove debugging `print`s from your code
- You need to commit 2 files, `model.py` and `homework.ipynb`, plus some samples of your synthesis
- The resulting notebook includes audio and takes up a lot of space. You can use this tutorial to share the notebook: https://gist.github.com/yashika51/58d2b6d8d1048a1d0a9ea5949a8aa7f6 . Or you can try using Colab.