# TalkNet Training

This notebook is a guide to training TalkNet as part of the TTS pipeline. It contains the following three sections:
  1. **Introduction**: TalkNet in NeMo
  2. **Preprocessing**: how to prepare data for TalkNet
  3. **Training**: example of TalkNet training

# License

> Copyright 2020 NVIDIA. All Rights Reserved.
> 
> Licensed under the Apache License, Version 2.0 (the "License");
> you may not use this file except in compliance with the License.
> You may obtain a copy of the License at
> 
>     http://www.apache.org/licenses/LICENSE-2.0
> 
> Unless required by applicable law or agreed to in writing, software
> distributed under the License is distributed on an "AS IS" BASIS,
> WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
> See the License for the specific language governing permissions and
> limitations under the License.

In [None]:
"""
You can either run this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.
Instructions for setting up Colab are as follows:
1. Open a new Python 3 notebook.
2. Import this notebook from GitHub (File -> Upload Notebook -> "GITHUB" tab -> copy/paste GitHub URL)
3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select "GPU" for hardware accelerator)
4. Run this cell to set up dependencies.
"""
# # If you're using Colab and not running locally, uncomment and run this cell.
# !apt-get install sox libsndfile1 ffmpeg
# !pip install wget unidecode pysptk
# BRANCH = 'v1.0.0'
# !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]

In [None]:
import json
import nemo
import torch
import torchaudio
import numpy as np

from pysptk import sptk
from pathlib import Path
from tqdm.notebook import tqdm

# Introduction

TalkNet is a neural network that converts text characters into a mel spectrogram. For more details about the model, please refer to NVIDIA's TalkNet Model Card or the original [paper](https://arxiv.org/abs/2104.08189).

Like most NeMo models, TalkNet is defined as a LightningModule, which allows easy training via PyTorch Lightning. It is parameterized by a configuration that is currently defined in a YAML file and loaded with Hydra.

Let's look at NeMo's pretrained model and see how to use it to generate spectrograms.

In [None]:
# Load the TalkNetSpectModel
from nemo.collections.tts.models import TalkNetSpectModel
from nemo.collections.tts.models.base import SpectrogramGenerator

# Let's see what pretrained models are available
print(TalkNetSpectModel.list_available_models())

In [None]:
# We can load the pre-trained model as follows
pretrained_model = "tts_en_talknet"
model = TalkNetSpectModel.from_pretrained(pretrained_model)

# Load and attach durs and pitch predictors
from nemo.collections.tts.models import TalkNetPitchModel
pitch_model = TalkNetPitchModel.from_pretrained(pretrained_model)
from nemo.collections.tts.models import TalkNetDursModel
durs_model = TalkNetDursModel.from_pretrained(pretrained_model)
model.add_module('_pitch_model', pitch_model)
model.add_module('_durs_model', durs_model)

model.eval()

In [None]:
# TalkNet is a SpectrogramGenerator
assert isinstance(model, SpectrogramGenerator)

# SpectrogramGenerators in NeMo have two helper functions:
#   1. parse(text: str, **kwargs) which takes an English string and produces a token tensor
#   2. generate_spectrogram(tokens: 'torch.tensor', **kwargs) which takes the token tensor and generates a spectrogram
# Let's try it out
tokens = model.parse(text="Hey, this produces speech!")
spectrogram = model.generate_spectrogram(tokens=tokens)

# Now we can visualize the generated spectrogram
# If we want to generate speech, we have to use a vocoder in conjunction with a spectrogram generator.
# Refer to the TTS Inference notebook on how to convert spectrograms to speech.
from matplotlib.pyplot import imshow
from matplotlib import pyplot as plt
%matplotlib inline
imshow(spectrogram.cpu().detach().numpy()[0,...], origin="lower")
plt.show()

# Preprocessing

Now that we have looked at the TalkNet model, let's see how to prepare the data for training it.

First, let's download the necessary training scripts and configs.

In [None]:
!wget https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/tts/talknet_durs.py
!wget https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/tts/talknet_pitch.py
!wget https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/tts/talknet_spect.py

!mkdir -p conf && cd conf \
&& wget https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/tts/conf/talknet-durs.yaml \
&& wget https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/tts/conf/talknet-pitch.yaml \
&& wget https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/tts/conf/talknet-spect.yaml \
&& cd ..

We will show an example of preprocessing and training using a small part of the AN4 dataset. It consists of recordings of people spelling out addresses, names, telephone numbers, etc., one letter or number at a time, as well as their corresponding transcripts. Let's download the data and prepare the manifests.

*NOTE: The sample data is not enough to properly train TalkNet. Training on it will not produce a usable model; it serves only as an example.*
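A manifest is a text file with one JSON object per line. As a minimal sketch, the snippet below writes and reads such a file using only the standard library. The `audio_filepath` key is the one read later in this notebook; the `duration` and `text` keys, the file names, and all values are illustrative assumptions, not entries from the AN4 sample.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical entries: paths, durations and texts are made up for illustration.
entries = [
    {"audio_filepath": "wavs/sample_a.wav", "duration": 2.5, "text": "hello world"},
    {"audio_filepath": "wavs/sample_b.wav", "duration": 1.2, "text": "g o o d"},
]

manifest_path = Path(tempfile.mkdtemp()) / "toy_manifest.json"

# One JSON object per line, with no enclosing list.
with open(manifest_path, "w") as f:
    for entry in entries:
        f.write(json.dumps(entry) + "\n")

# Read the manifest back line by line.
with open(manifest_path) as f:
    loaded = [json.loads(line) for line in f if line.strip()]

# File stems are used as keys for the duration and f0 dictionaries in this notebook.
stems = [Path(e["audio_filepath"]).stem for e in loaded]
total_duration = sum(e["duration"] for e in loaded)
print(stems, total_duration)
```

Because each line is an independent JSON object, manifests can be merged by plain concatenation, which is what the `cat` command in the next cell does.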

In [None]:
!wget https://github.com/NVIDIA/NeMo/releases/download/v0.11.0/test_data.tar.gz && mkdir -p tests/data && tar xzf test_data.tar.gz -C tests/data

# Just like ASR, TalkNet requires .json files to define the training and validation data.
!cat tests/data/asr/an4_val.json
!cat tests/data/asr/an4_train.json tests/data/asr/an4_val.json > tests/data/asr/an4_all.json 

## Extracting phoneme ground truth durations

As part of the whole model, you will need to train a duration predictor. We will extract ground truth phoneme durations from an ASR model (QuartzNet5x5, trained on LibriTTS) using the forward-backward algorithm (see the paper for details). Let's download the pretrained ASR model and define auxiliary functions.
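To build intuition for how per-frame log-probabilities turn into durations, here is a simplified best-path (Viterbi-style) alignment in pure Python. It is not the extractor used in this notebook: the function name `best_path_durations`, the toy probabilities, and the rule that every symbol receives at least one frame are assumptions made to keep the sketch short.

```python
import math

def best_path_durations(log_probs, symbols):
    """Best monotonic alignment of `symbols` to frames; returns frames per symbol.

    log_probs[t][v] is the log-probability of vocabulary entry v at frame t.
    Simplification: each symbol gets at least one frame and none is skipped.
    """
    num_frames, num_symbols = len(log_probs), len(symbols)
    assert num_frames >= num_symbols, "need at least one frame per symbol"
    neg_inf = float("-inf")
    score = [[neg_inf] * num_symbols for _ in range(num_frames)]
    back = [[0] * num_symbols for _ in range(num_frames)]
    score[0][0] = log_probs[0][symbols[0]]
    for t in range(1, num_frames):
        for s in range(num_symbols):
            stay = score[t - 1][s]                            # keep emitting symbol s
            move = score[t - 1][s - 1] if s > 0 else neg_inf  # advance from symbol s - 1
            best, step = (move, 1) if move > stay else (stay, 0)
            if best > neg_inf:
                score[t][s] = best + log_probs[t][symbols[s]]
                back[t][s] = step
    # Walk back from the last symbol at the last frame, counting frames per symbol.
    durs = [0] * num_symbols
    s = num_symbols - 1
    for t in range(num_frames - 1, -1, -1):
        durs[s] += 1
        s -= back[t][s]
    return durs

# Four frames over a two-entry vocabulary: early frames favor 0, late frames favor 1.
probs = [[0.9, 0.1], [0.8, 0.2], [0.3, 0.7], [0.1, 0.9]]
log_probs = [[math.log(p) for p in frame] for frame in probs]
durations = best_path_durations(log_probs, symbols=[0, 1])
print(durations)  # [2, 2]
```

The durations always sum to the number of frames, which is the same invariant the notebook's `backward_extractor` asserts.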

In [None]:
from nemo.collections.asr.models import EncDecCTCModel
asr_model = EncDecCTCModel.from_pretrained(model_name="asr_talknet_aligner").cpu().eval()

In [None]:
def forward_extractor(tokens, log_probs, blank):
    """Computes states f and p."""
    n, m = len(tokens), log_probs.shape[0]
    # `f[s, t]` -- max sum of log probs when aligning the first `s` codes
    # to the first `t` timesteps, ending in `tokens[s - 1]`.
    f = np.empty((n + 1, m + 1), dtype=float)
    f.fill(-(10 ** 9))
    p = np.empty((n + 1, m + 1), dtype=int)
    f[0, 0] = 0.0  # Start
    for s in range(1, n + 1):
        c = tokens[s - 1]
        for t in range((s + 1) // 2, m + 1):
            f[s, t] = log_probs[t - 1, c]
            # Option #1: prev char is equal to current one.
            if s == 1 or c == blank or c == tokens[s - 3]:
                options = f[s : (s - 2 if s > 1 else None) : -1, t - 1]
            else:  # Is not equal to current one.
                options = f[s : (s - 3 if s > 2 else None) : -1, t - 1]
            f[s, t] += np.max(options)
            p[s, t] = np.argmax(options)
    return f, p


def backward_extractor(f, p):
    """Computes durs from f and p."""
    n, m = f.shape
    n -= 1
    m -= 1
    durs = np.zeros(n, dtype=int)
    if f[-1, -1] >= f[-2, -1]:
        s, t = n, m
    else:
        s, t = n - 1, m
    while s > 0:
        durs[s - 1] += 1
        s -= p[s, t]
        t -= 1
    assert durs.shape[0] == n
    assert np.sum(durs) == m
    assert np.all(durs[1::2] > 0)
    return durs

def preprocess_tokens(tokens, blank):
    new_tokens = [blank]
    for c in tokens:
        new_tokens.extend([c, blank])
    tokens = new_tokens
    return tokens

Now we can run the extraction and save the result.
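The extraction loop relies on two conventions: blanks are interleaved around every token, and the resulting duration vector is split by position into blank and token durations. The toy example below illustrates both. The helper `interleave_blanks` mirrors `preprocess_tokens`, and the token ids, blank id, and durations are made-up values.

```python
def interleave_blanks(tokens, blank):
    """Mirror of `preprocess_tokens`: put a blank before, between and after tokens."""
    out = [blank]
    for c in tokens:
        out.extend([c, blank])
    return out

# Toy values, assumed for illustration only.
tokens, blank = [5, 7], 9
expanded = interleave_blanks(tokens, blank)  # [9, 5, 9, 7, 9]

# A hypothetical duration vector with one entry per element of `expanded`.
durs = [2, 3, 1, 4, 0]

blank_durs = durs[::2]   # even positions are blanks -> [2, 1, 0]
token_durs = durs[1::2]  # odd positions are real tokens -> [3, 4]

# Every frame belongs to exactly one symbol, so durations sum to the frame count.
num_frames = sum(durs)
print(expanded, blank_durs, token_durs, num_frames)
```

For `n` tokens there are always `n + 1` blank durations, and a blank may have zero duration while a token may not.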

In [None]:
data_config = {
    'manifest_filepath': "tests/data/asr/an4_all.json",
    'sample_rate': 16000,
    'labels': asr_model.decoder.vocabulary,
    'batch_size': 1,
}

parser = nemo.collections.asr.data.audio_to_text.AudioToCharWithDursF0Dataset.make_vocab(
    notation='phonemes', punct=True, spaces=True, stresses=False, add_blank_at="last"
)

dataset = nemo.collections.asr.data.audio_to_text._AudioTextDataset(
    manifest_filepath=data_config['manifest_filepath'], sample_rate=data_config['sample_rate'], parser=parser,
)

dl = torch.utils.data.DataLoader(
    dataset=dataset, batch_size=data_config['batch_size'], collate_fn=dataset.collate_fn, shuffle=False,
)

blank_id = asr_model.decoder.num_classes_with_blank - 1

dur_data = {}
for sample_idx, test_sample in tqdm(enumerate(dl), total=len(dl)):
    log_probs, _, greedy_predictions = asr_model(
        input_signal=test_sample[0], input_signal_length=test_sample[1]
    )

    log_probs = log_probs[0].cpu().detach().numpy()
    seq_ids = test_sample[2][0].cpu().detach().numpy()

    target_tokens = preprocess_tokens(seq_ids, blank_id)

    f, p = forward_extractor(target_tokens, log_probs, blank_id)
    durs = backward_extractor(f, p)

    dur_key = Path(dl.dataset.collection[sample_idx].audio_file).stem
    dur_data[dur_key] = {
        'blanks': torch.tensor(durs[::2], dtype=torch.long).cpu().detach(), 
        'tokens': torch.tensor(durs[1::2], dtype=torch.long).cpu().detach()
    }

    del test_sample

torch.save(dur_data, "tests/data/asr/an4_durations.pt")

## Extracting ground truth f0

The second model that you will need to train before the spectrogram generator is the pitch predictor. As labels for the pitch predictor, we will use f0 extracted from audio with the `pysptk` library (see the paper for details). Let's extract f0, calculate its statistics (mean and std), and save everything.
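The statistics are computed over voiced frames only. The sketch below shows that idea on made-up f0 tracks using the standard library. It assumes unvoiced frames are marked with 0.0, and the `normalize` helper is one plausible use of the statistics, not a description of how TalkNet applies them internally.

```python
import statistics

# Hypothetical f0 tracks in Hz; 0.0 is assumed to mark unvoiced frames.
f0_tracks = {
    "utt_a": [0.0, 110.0, 120.0, 0.0],
    "utt_b": [130.0, 0.0, 140.0],
}

VOICED_THRESHOLD = 1e-5  # same cutoff as the `f0 >= 1e-5` mask used in this notebook

# Pool the voiced frames of all utterances before computing statistics.
voiced = [v for track in f0_tracks.values() for v in track if v >= VOICED_THRESHOLD]

f0_mean = statistics.mean(voiced)
# Sample standard deviation (n - 1 in the denominator), matching the default of torch.Tensor.std().
f0_std = statistics.stdev(voiced)

def normalize(track):
    """Standardize voiced frames and leave unvoiced frames at 0."""
    return [(v - f0_mean) / f0_std if v >= VOICED_THRESHOLD else 0.0 for v in track]

print(f0_mean, f0_std, normalize(f0_tracks["utt_a"]))
```

Including unvoiced zeros in the pool would drag the mean toward zero and inflate the standard deviation, which is why they are filtered out first.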

In [None]:
def extract_f0(audio_file, sample_rate=16000, hop_length=256):
    audio = torchaudio.load(audio_file)[0].squeeze().numpy()
    f0 = sptk.swipe(audio.astype(np.float64), sample_rate, hopsize=hop_length)
    # Hack to make f0 and mel lengths equal
    if len(audio) % hop_length == 0:
        f0 = np.pad(f0, pad_width=[0, 1])
    return torch.from_numpy(f0.astype(np.float32))

In [None]:
f0_data = {}
with open("tests/data/asr/an4_all.json") as f:
    for l in tqdm(f):
        audio_path = json.loads(l)["audio_filepath"]
        f0_data[Path(audio_path).stem] = extract_f0(audio_path)

# calculate f0 stats (mean & std) only for train set
with open("tests/data/asr/an4_train.json") as f:
    train_ids = {Path(json.loads(l)["audio_filepath"]).stem for l in f}
all_f0 = torch.cat([f0[f0 >= 1e-5] for f0_id, f0 in f0_data.items() if f0_id in train_ids])

F0_MEAN, F0_STD = all_f0.mean().item(), all_f0.std().item()        
torch.save(f0_data, "tests/data/asr/an4_f0s.pt")

# Training

Now we are ready to train our models! Let's train the TalkNet parts one after another.
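The training scripts take `key=value` overrides on the command line, and inside a notebook `{F0_MEAN}` is filled in from the Python variable of that name. Outside a notebook, the same command can be assembled in Python. The sketch below only builds and prints the argument list; the helper `build_command` and the placeholder statistics are assumptions for illustration.

```python
import shlex

def build_command(script, overrides):
    """Turn a dict of overrides into a `python script key=value ...` argument list."""
    return ["python", script] + [f"{key}={value}" for key, value in overrides.items()]

# Placeholder statistics; in the notebook these come from the f0 extraction step.
f0_mean, f0_std = 125.0, 12.9

overrides = {
    "sample_rate": 16000,
    "train_dataset": "tests/data/asr/an4_train.json",
    "validation_datasets": "tests/data/asr/an4_val.json",
    "durs_file": "tests/data/asr/an4_durations.pt",
    "f0_file": "tests/data/asr/an4_f0s.pt",
    "trainer.max_epochs": 3,
    "model.f0_mean": f0_mean,
    "model.f0_std": f0_std,
}

cmd = build_command("talknet_pitch.py", overrides)
print(shlex.join(cmd))
```

The resulting list could be passed to `subprocess.run(cmd, check=True)` once the script and data files are in place.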

In [None]:
!python talknet_durs.py sample_rate=16000 \
train_dataset=tests/data/asr/an4_train.json \
validation_datasets=tests/data/asr/an4_val.json \
durs_file=tests/data/asr/an4_durations.pt \
f0_file=tests/data/asr/an4_f0s.pt \
trainer.max_epochs=3 \
trainer.accelerator=null \
trainer.check_val_every_n_epoch=1 \
model.train_ds.dataloader_params.batch_size=6 \
model.train_ds.dataloader_params.num_workers=0 \
model.validation_ds.dataloader_params.num_workers=0

In [None]:
!python talknet_pitch.py sample_rate=16000 \
train_dataset=tests/data/asr/an4_train.json \
validation_datasets=tests/data/asr/an4_val.json \
durs_file=tests/data/asr/an4_durations.pt \
f0_file=tests/data/asr/an4_f0s.pt \
trainer.max_epochs=3 \
trainer.accelerator=null \
trainer.check_val_every_n_epoch=1 \
model.f0_mean={F0_MEAN} \
model.f0_std={F0_STD} \
model.train_ds.dataloader_params.batch_size=6 \
model.train_ds.dataloader_params.num_workers=0 \
model.validation_ds.dataloader_params.num_workers=0

In [None]:
!python talknet_spect.py sample_rate=16000 \
train_dataset=tests/data/asr/an4_train.json \
validation_datasets=tests/data/asr/an4_val.json \
durs_file=tests/data/asr/an4_durations.pt \
f0_file=tests/data/asr/an4_f0s.pt \
trainer.max_epochs=3 \
trainer.accelerator=null \
trainer.check_val_every_n_epoch=1 \
model.train_ds.dataloader_params.batch_size=6 \
model.train_ds.dataloader_params.num_workers=0 \
model.validation_ds.dataloader_params.num_workers=0

That's it!

To train TalkNet for real purposes, it is highly recommended to obtain high-quality speech data with the following properties:

* Sampling rate of 22050Hz or higher
* Single speaker
* Speech should cover a wide variety of phonemes
* Audio split into segments of 1-10 seconds
* Audio segments should not have silence at the beginning and end
* Audio segments should not contain long silences inside
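
Two of these properties, sampling rate and segment length, can be checked automatically. The following is a minimal sketch using Python's built-in `wave` module, assuming uncompressed PCM WAV files. The helper names `check_wav` and `write_tone` are hypothetical, and the other properties (single speaker, phoneme coverage, silences) are not covered.

```python
import math
import struct
import tempfile
import wave
from pathlib import Path

def check_wav(path, min_rate=22050, min_sec=1.0, max_sec=10.0):
    """Return a list of problems found in a PCM WAV file (empty if none)."""
    with wave.open(str(path), "rb") as w:
        rate = w.getframerate()
        duration = w.getnframes() / rate
    problems = []
    if rate < min_rate:
        problems.append(f"sample rate {rate} Hz is below {min_rate} Hz")
    if not min_sec <= duration <= max_sec:
        problems.append(f"duration {duration:.2f} s is outside {min_sec}-{max_sec} s")
    return problems

def write_tone(path, rate, seconds, freq=220.0):
    """Write a mono 16-bit sine tone, used here only to have a file to check."""
    samples = (int(10000 * math.sin(2 * math.pi * freq * i / rate)) for i in range(int(rate * seconds)))
    frames = b"".join(struct.pack("<h", s) for s in samples)
    with wave.open(str(path), "wb") as w:
        w.setnchannels(1)
        w.setsampwidth(2)
        w.setframerate(rate)
        w.writeframes(frames)

tmp_dir = Path(tempfile.mkdtemp())
good, bad = tmp_dir / "good.wav", tmp_dir / "bad.wav"
write_tone(good, rate=22050, seconds=2.0)
write_tone(bad, rate=16000, seconds=0.5)

good_problems = check_wav(good)
bad_problems = check_wav(bad)
print(good_problems, bad_problems)
```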