# Inference Demo for Mellotron on Google Colab
Basic Colab modifications by [Hyungon Ryu](https://github.com/yhgon) | NVAITC Sr. Data Scientist | Center Lead @ NVIDIA


Modified from the original `inference.ipynb`.

This notebook requires a GPU runtime.
Select the menu option **"Runtime"** -> **"Change runtime type"**, set **"Hardware Accelerator"** to **"GPU"**, and click **"SAVE"**.


## Model Description

Mellotron: a multispeaker voice synthesis model based on Tacotron 2 GST that can make a voice emote and sing without emotive or singing training data.

By explicitly conditioning on rhythm and continuous pitch contours from an audio signal or music score, Mellotron is able to generate speech in a variety of styles ranging from read speech to expressive speech, from slow drawls to rap and from monotonous voice to singing voice.
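In practice, the pitch contour is consumed at the mel-frame rate. Below is a minimal sketch that assumes this notebook's `hop_length=256` and `sampling_rate=22050`. It rasterizes note events `(freq_hz, start_s, end_s)` into one f0 value per frame, using the same `int(t / frame_length)` indexing as `event2f0` in `mellotron_utils.py` below, with 0 marking unvoiced frames. `notes_to_f0` is a hypothetical helper, not part of Mellotron.

```python
def notes_to_f0(notes, hop_length=256, sampling_rate=22050):
    """Rasterize (freq_hz, start_s, end_s) notes, sorted by time, into a per-frame f0 list.

    Hypothetical helper; mirrors the int(t / frame_length) indexing used by event2f0.
    """
    frame_length = hop_length / sampling_rate      # seconds per mel frame (~11.6 ms)
    n_frames = int(notes[-1][2] / frame_length)    # the last note's end time sets the length
    f0 = [0.0] * n_frames                          # 0.0 marks unvoiced / rest frames
    for freq, start, end in notes:
        for i in range(int(start / frame_length), min(int(end / frame_length), n_frames)):
            f0[i] = freq
    return f0

# A 0.5 s A4 (440 Hz), a 0.25 s rest, then a 0.5 s A3 (220 Hz).
contour = notes_to_f0([(440.0, 0.0, 0.5), (220.0, 0.75, 1.25)])
print(len(contour), contour[0], contour[50], contour[-1])  # -> 107 440.0 0.0 220.0
```

The rest between the notes stays at 0 Hz, so the model sees an explicit unvoiced gap rather than an interpolated pitch.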


## DevOps for Google Colab
Install the required Python modules and NVIDIA Apex.

In [1]:
%tensorflow_version 1.x 
!nvidia-smi

TensorFlow 1.x selected.
Fri Feb 19 19:01:20 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.39       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   61C    P8    12W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+--------------------------------------------------------------

In [2]:
%%time
%%bash
pip install tensorflow==1.15 inflect==0.2.5 numba==0.48 librosa==0.6.0 scipy==1.3.1 Unidecode==1.0.22 pillow nltk==3.4.5 jamo==0.4.1  vamp

Collecting tensorflow==1.15
  Downloading https://files.pythonhosted.org/packages/3f/98/5a99af92fb911d7a88a0005ad55005f35b4c1ba8d75fba02df726cd936e6/tensorflow-1.15.0-cp36-cp36m-manylinux2010_x86_64.whl (412.3MB)
Collecting inflect==0.2.5
  Downloading https://files.pythonhosted.org/packages/66/15/2d176749884cbeda0c92e0d09e1303ff53a973eb3c6bb2136803b9d962c9/inflect-0.2.5-py2.py3-none-any.whl (58kB)
Collecting numba==0.48
  Downloading https://files.pythonhosted.org/packages/23/7f/dbe85f5f419dca88509d829df90dfa5aefa39c39f6b7020dfc206a386279/numba-0.48.0-1-cp36-cp36m-manylinux2014_x86_64.whl (3.5MB)
Collecting librosa==0.6.0
  Downloading https://files.pythonhosted.org/packages/6b/f4/422bfbefd581f74354ef05176aa48558c548243c87e359d91512d4b65523/librosa-0.6.0.tar.gz (1.5MB)
Collecting scipy==1.3.1
  Downloading https://files.pythonhosted.org/packages/29/50/a552a5aff252ae915f522e44642bb49a7b7b31677f9580cfd11bcc869976/scipy-1.3.1-cp36-cp36m-manylinux1_x86_64.whl (25.2MB)
Collecting Unidecode

ERROR: magenta 0.3.19 has requirement librosa>=0.6.2, but you'll have librosa 0.6.0 which is incompatible.
ERROR: umap-learn 0.5.0 has requirement numba>=0.49, but you'll have numba 0.48.0 which is incompatible.
ERROR: pynndescent 0.5.1 has requirement numba>=0.51.2, but you'll have numba 0.48.0 which is incompatible.
ERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.


CPU times: user 13.5 ms, sys: 4.59 ms, total: 18.1 ms
Wall time: 2min 4s
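pip reports the dependency conflicts above but still installs the pinned versions. One way to confirm what is actually installed in the runtime is `importlib.metadata`, which needs Python 3.8+. The `cp36` wheels above show that this Colab runtime ran Python 3.6, where the `importlib_metadata` backport would be needed instead. `check_pins` is a hypothetical helper, and its pin list is copied from the install cell.

```python
from importlib.metadata import version, PackageNotFoundError

# Pins copied from the pip install cell above (pillow and vamp were left unpinned).
PINS = {
    "tensorflow": "1.15",
    "inflect": "0.2.5",
    "numba": "0.48",
    "librosa": "0.6.0",
    "scipy": "1.3.1",
    "Unidecode": "1.0.22",
    "nltk": "3.4.5",
    "jamo": "0.4.1",
}

def check_pins(pins):
    """Return {package: (wanted, installed_or_None)} for every pin that is not satisfied."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        # A short pin such as "1.15" is treated as satisfied by "1.15.0".
        if installed is None or not (installed == wanted or installed.startswith(wanted + ".")):
            mismatches[name] = (wanted, installed)
    return mismatches

print(check_pins(PINS))  # an empty dict means every pin matched
```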


In [3]:
%%time
%%bash
git clone https://github.com/NVIDIA/apex
pip install --no-cache-dir ./apex  # Python-only build (no C++/CUDA extensions)

Processing ./apex
Building wheels for collected packages: apex
  Building wheel for apex (setup.py): started
  Building wheel for apex (setup.py): finished with status 'done'
  Created wheel for apex: filename=apex-0.1-cp36-none-any.whl size=194551 sha256=c1f9f710ecdb4be3ec7a58c3ca8c199c0528678a3c903b172250242f909d19a2
  Stored in directory: /tmp/pip-ephem-wheel-cache-pzt3au0z/wheels/b1/3a/aa/d84906eaab780ae580c7a5686a33bf2820d8590ac3b60d5967
Successfully built apex
Installing collected packages: apex
Successfully installed apex-0.1


Cloning into 'apex'...


CPU times: user 1.53 ms, sys: 7.08 ms, total: 8.61 ms
Wall time: 11 s


In [4]:
%%bash
git clone https://github.com/NVIDIA/mellotron.git
cd mellotron
git submodule init
git submodule update

Submodule path 'waveglow': checked out '2fd4e63e2918012f55eac2c8a8e75622a39741be'


Cloning into 'mellotron'...
Submodule 'waveglow' (https://github.com/NVIDIA/waveglow.git) registered for path 'waveglow'
Cloning into '/content/mellotron/waveglow'...


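## Download the official checkpoints
Download the pretrained Mellotron (LibriTTS) and WaveGlow checkpoints from Google Drive with the `gfile.py` helper script.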

In [5]:
%%bash
wget -N  -q https://raw.githubusercontent.com/yhgon/colab_utils/master/gfile.py
python gfile.py -u 'https://drive.google.com/open?id=1ZesPPyRRKloltRIuRnGZ2LIUEuMSVjkI' -f 'mellotron_libritts.pt'
python gfile.py -u 'https://drive.google.com/open?id=1Rm5rV5XaWWiUbIpg5385l5sh68z2bVOE' -f 'waveglow_256channels_v4.pt'

It took  2.85sec to download 121.9 MB mellotron_libritts.pt 
It took  16.57sec to download 1.3 GB waveglow_256channels_v4.pt 
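`gfile.py` is an external helper, so its internals are not shown here. For reference, the Drive file ID is the `id` query parameter of each shared link. The hypothetical `drive_file_id` below extracts it with the standard library and is not taken from `gfile.py`.

```python
from urllib.parse import urlparse, parse_qs

def drive_file_id(url):
    """Return the `id` query parameter of a Google Drive 'open?id=...' link."""
    ids = parse_qs(urlparse(url).query).get("id")
    if not ids:
        raise ValueError("no 'id' parameter in {!r}".format(url))
    return ids[0]

print(drive_file_id("https://drive.google.com/open?id=1ZesPPyRRKloltRIuRnGZ2LIUEuMSVjkI"))
# -> 1ZesPPyRRKloltRIuRnGZ2LIUEuMSVjkI
```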


## Patches
These are temporary workarounds for running inference on Colab (TODO: clean up the code):
- Point `cmudict_path` in `hparams.py` to `/content/mellotron/data/cmu_dictionary`.
- Avoid the distributed-training dependency by importing `load_model` from `train_utils.py` instead of `train.py`.
- Point `CMUDICT_PATH` to `/content/mellotron/data/cmu_dictionary` as well.
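The cell below rewrites `hparams.py` in full. If only the CMUDict path needs to change, a narrower alternative is to substitute the single `cmudict_path="..."` assignment in the source text. This notebook does not do that. The `patch_cmudict_path` helper is hypothetical and is shown here on an in-memory string.

```python
import re

def patch_cmudict_path(source, new_path):
    """Replace the value of the first `cmudict_path="..."` assignment in hparams source text."""
    pattern = r'cmudict_path\s*=\s*"[^"]*"'
    # A callable replacement avoids re treating backslashes in new_path as escapes.
    patched, n = re.subn(pattern, lambda m: 'cmudict_path="{}"'.format(new_path), source, count=1)
    if n == 0:
        raise ValueError("cmudict_path assignment not found")
    return patched

# Example input; not copied from the upstream file.
src = '        p_arpabet=1.0,\n        cmudict_path="data/cmu_dictionary",\n'
print(patch_cmudict_path(src, "/content/mellotron/data/cmu_dictionary"))
```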

In [6]:
#@title
%%file hparams.py
import tensorflow as tf
from text.symbols import symbols


def create_hparams(hparams_string=None, verbose=False):
    """Create model hyperparameters. Parse nondefault from given string."""

    hparams = tf.contrib.training.HParams(
        ################################
        # Experiment Parameters        #
        ################################
        epochs=50000,
        iters_per_checkpoint=500,
        seed=1234,
        dynamic_loss_scaling=True,
        fp16_run=False,
        distributed_run=False,
        dist_backend="nccl",
        dist_url="tcp://localhost:54321",
        cudnn_enabled=True,
        cudnn_benchmark=False,
        ignore_layers=['speaker_embedding.weight'],

        ################################
        # Data Parameters             #
        ################################
        training_files='filelists/ljs_audiopaths_text_sid_train_filelist.txt',
        validation_files='filelists/ljs_audiopaths_text_sid_val_filelist.txt',
        text_cleaners=['english_cleaners'],
        p_arpabet=1.0,
        cmudict_path="/content/mellotron/data/cmu_dictionary",

        ################################
        # Audio Parameters             #
        ################################
        max_wav_value=32768.0,
        sampling_rate=22050,
        filter_length=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=80,
        mel_fmin=0.0,
        mel_fmax=8000.0,
        f0_min=80,
        f0_max=880,
        harm_thresh=0.25,

        ################################
        # Model Parameters             #
        ################################
        n_symbols=len(symbols),
        symbols_embedding_dim=512,

        # Encoder parameters
        encoder_kernel_size=5,
        encoder_n_convolutions=3,
        encoder_embedding_dim=512,

        # Decoder parameters
        n_frames_per_step=1,  # currently only 1 is supported
        decoder_rnn_dim=1024,
        prenet_dim=256,
        prenet_f0_n_layers=1,
        prenet_f0_dim=1,
        prenet_f0_kernel_size=1,
        prenet_rms_dim=0,
        prenet_rms_kernel_size=1,
        max_decoder_steps=1000,
        gate_threshold=0.5,
        p_attention_dropout=0.1,
        p_decoder_dropout=0.1,
        p_teacher_forcing=1.0,

        # Attention parameters
        attention_rnn_dim=1024,
        attention_dim=128,

        # Location Layer parameters
        attention_location_n_filters=32,
        attention_location_kernel_size=31,

        # Mel-post processing network parameters
        postnet_embedding_dim=512,
        postnet_kernel_size=5,
        postnet_n_convolutions=5,

        # Speaker embedding
        n_speakers=123,
        speaker_embedding_dim=128,

        # Reference encoder
        with_gst=True,
        ref_enc_filters=[32, 32, 64, 64, 128, 128],
        ref_enc_size=[3, 3],
        ref_enc_strides=[2, 2],
        ref_enc_pad=[1, 1],
        ref_enc_gru_size=128,

        # Style Token Layer
        token_embedding_size=256,
        token_num=10,
        num_heads=8,

        ################################
        # Optimization Hyperparameters #
        ################################
        use_saved_learning_rate=False,
        learning_rate=1e-3,
        learning_rate_min=1e-5,
        learning_rate_anneal=50000,
        weight_decay=1e-6,
        grad_clip_thresh=1.0,
        batch_size=32,
        mask_padding=True,  # set model's padded outputs to padded values

    )

    if hparams_string:
        tf.compat.v1.logging.info('Parsing command line hparams: %s', hparams_string)
        hparams.parse(hparams_string)

    if verbose:
        tf.compat.v1.logging.info('Final parsed hparams: %s', hparams.values())

    return hparams


Writing hparams.py
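`create_hparams` accepts an optional comma-separated override string, for example `create_hparams("fp16_run=True,batch_size=16")`, which `HParams.parse` applies on top of the defaults. The sketch below is a simplified stand-in that uses only the standard library and is not TensorFlow's implementation. It handles scalar int, float, bool and str values but not lists, and it coerces each override to the type of its default.

```python
def parse_overrides(defaults, overrides):
    """Apply 'name=value,name=value' overrides to a dict of defaults (simplified sketch)."""
    params = dict(defaults)
    if not overrides:
        return params
    for item in overrides.split(","):
        name, _, raw = item.partition("=")
        name, raw = name.strip(), raw.strip()
        if name not in params:
            raise KeyError("unknown hparam: {}".format(name))
        default = params[name]
        # Check bool before int: bool is a subclass of int in Python.
        if isinstance(default, bool):
            if raw.lower() not in ("true", "false"):
                raise ValueError("expected true/false for {}: {!r}".format(name, raw))
            params[name] = raw.lower() == "true"
        elif isinstance(default, int):
            params[name] = int(raw)
        elif isinstance(default, float):
            params[name] = float(raw)
        else:
            params[name] = raw
    return params

defaults = {"fp16_run": False, "batch_size": 32, "learning_rate": 1e-3, "dist_backend": "nccl"}
print(parse_overrides(defaults, "fp16_run=True,batch_size=16"))
```

Splitting on commas is exactly why list-valued settings such as `ignore_layers` need more careful handling than this sketch provides.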


In [7]:
#@title
%%file train_utils.py
import os
import time
import argparse
import math
from numpy import finfo

import torch
#from distributed import apply_gradient_allreduce
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader

from model import Tacotron2
from data_utils import TextMelLoader, TextMelCollate
from loss_function import Tacotron2Loss
#from logger import Tacotron2Logger
from hparams import create_hparams


def reduce_tensor(tensor, n_gpus):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n_gpus
    return rt


def init_distributed(hparams, n_gpus, rank, group_name):
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."
    print("Initializing Distributed")

    # Set cuda device so everything is done on the right GPU.
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Initialize distributed communication
    dist.init_process_group(
        backend=hparams.dist_backend, init_method=hparams.dist_url,
        world_size=n_gpus, rank=rank, group_name=group_name)

    print("Done initializing distributed")


def prepare_dataloaders(hparams):
    # Get data, data loaders and collate function ready
    trainset = TextMelLoader(hparams.training_files, hparams)
    valset = TextMelLoader(hparams.validation_files, hparams,
                           speaker_ids=trainset.speaker_ids)
    collate_fn = TextMelCollate(hparams.n_frames_per_step)

    if hparams.distributed_run:
        train_sampler = DistributedSampler(trainset)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True

    train_loader = DataLoader(trainset, num_workers=1, shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=hparams.batch_size, pin_memory=False,
                              drop_last=True, collate_fn=collate_fn)
    return train_loader, valset, collate_fn, train_sampler


def prepare_directories_and_logger(output_directory, log_directory, rank):
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        logger = None    
        #logger = Tacotron2Logger(os.path.join(output_directory, log_directory))
    else:
        logger = None
    return logger


def load_model(hparams):
    model = Tacotron2(hparams).cuda()
    if hparams.fp16_run:
        model.decoder.attention_layer.score_mask_value = finfo('float16').min

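    if hparams.distributed_run:
        # apply_gradient_allreduce is not imported in this Colab-only module
        raise NotImplementedError("distributed_run is not supported in train_utils.py")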

    return model


def warm_start_model(checkpoint_path, model, ignore_layers):
    assert os.path.isfile(checkpoint_path)
    print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model_dict = checkpoint_dict['state_dict']
    if len(ignore_layers) > 0:
        model_dict = {k: v for k, v in model_dict.items()
                      if k not in ignore_layers}
        dummy_dict = model.state_dict()
        dummy_dict.update(model_dict)
        model_dict = dummy_dict
    model.load_state_dict(model_dict)
    return model


def load_checkpoint(checkpoint_path, model, optimizer):
    assert os.path.isfile(checkpoint_path)
    print("Loading checkpoint '{}'".format(checkpoint_path))
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint_dict['state_dict'])
    optimizer.load_state_dict(checkpoint_dict['optimizer'])
    learning_rate = checkpoint_dict['learning_rate']
    iteration = checkpoint_dict['iteration']
    print("Loaded checkpoint '{}' from iteration {}" .format(
        checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration


def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    print("Saving model and optimizer state at iteration {} to {}".format(
        iteration, filepath))
    torch.save({'iteration': iteration,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'learning_rate': learning_rate}, filepath)


def validate(model, criterion, valset, iteration, batch_size, n_gpus,
             collate_fn, logger, distributed_run, rank):
    """Handles all the validation scoring and printing"""
    model.eval()
    with torch.no_grad():
        val_sampler = DistributedSampler(valset) if distributed_run else None
        val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1,
                                shuffle=False, batch_size=batch_size,
                                pin_memory=False, collate_fn=collate_fn)

        val_loss = 0.0
        for i, batch in enumerate(val_loader):
            x, y = model.parse_batch(batch)
            y_pred = model(x)
            loss = criterion(y_pred, y)
            if distributed_run:
                reduced_val_loss = reduce_tensor(loss.data, n_gpus).item()
            else:
                reduced_val_loss = loss.item()
            val_loss += reduced_val_loss
        val_loss = val_loss / (i + 1)

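    model.train()
    if rank == 0:
        # Report the averaged loss, not just the last batch's loss
        print("Validation loss {}: {:9f}  ".format(iteration, val_loss))
        if logger is not None:  # logging is disabled in this Colab patch
            logger.log_validation(val_loss, model, y, y_pred, iteration)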



Writing train_utils.py
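`warm_start_model` keeps every checkpoint tensor except the names listed in `ignore_layers` (here `speaker_embedding.weight`). Those names take their values from the freshly initialized model. The sketch below reproduces that merge with plain dicts, using strings in place of tensors, so the behavior can be checked without PyTorch. `merge_state_dicts` is a hypothetical name.

```python
def merge_state_dicts(model_state, checkpoint_state, ignore_layers):
    """Mimic warm_start_model: checkpoint values win, except for ignored layers."""
    filtered = {k: v for k, v in checkpoint_state.items() if k not in ignore_layers}
    merged = dict(model_state)  # start from the model's own (freshly initialized) parameters
    merged.update(filtered)     # overwrite with everything the checkpoint provides
    return merged

model_state = {"encoder.weight": "init", "speaker_embedding.weight": "init"}
checkpoint = {"encoder.weight": "trained", "speaker_embedding.weight": "trained"}
print(merge_state_dicts(model_state, checkpoint, ["speaker_embedding.weight"]))
```

This is why warm starting can change the number of speakers. The speaker table is rebuilt from scratch, and everything else is inherited.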


In [8]:
#@title
%%file /content/mellotron/mellotron_utils.py
# code adapted from NVIDIA/mellotron/mellotron_utils.py
import numpy as np
import re
import torch

_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')

########################
#  CONSONANT DURATION  #
########################
PHONEMEDURATION = {
    'B': 0.05,
    'CH': 0.1,
    'D': 0.075,
    'DH': 0.05,
    'DX': 0.05,
    'EL': 0.05,
    'EM': 0.05,
    'EN': 0.05,
    'F': 0.1,
    'G': 0.05,
    'HH': 0.05,
    'JH': 0.05,
    'K': 0.05,
    'L': 0.05,
    'M': 0.15,
    'N': 0.15,
    'NG': 0.15,
    'NX': 0.05,
    'P': 0.05,
    'Q': 0.075,
    'R': 0.05,
    'S': 0.1,
    'SH': 0.05,
    'T': 0.075,
    'TH': 0.1,
    'V': 0.05,
    'Y': 0.05,
    'W': 0.05,
    'WH': 0.05,
    'Z': 0.05,
    'ZH': 0.05
}

valid_symbols = [
  'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
  'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
  'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
  'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
  'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
  'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
  'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
]

_arpabet = ['@' + s for s in valid_symbols]

_punctuation = '!\'",.:;? '
_math = '#%&*+-/[]()'
_special = '_@©°½—₩€$'
_accented = 'áçéêëñöøćž'
_numbers = '0123456789'
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'

symbols = list(_punctuation + _math + _special + _accented + _numbers + _letters) + _arpabet

_symbol_to_id = {s: i for i, s in enumerate(symbols)}

def text_to_sequence(text):
    sequence = []

    while len(text):
        m = _curly_re.match(text)
        if not m:
            sequence += _symbols_to_sequence(text)
            break

        sequence += text_to_sequence(m.group(1))
        sequence += _arpabet_to_sequence(m.group(2))
        text = m.group(3)

    return sequence

def _symbols_to_sequence(symbols):
    return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]


def _arpabet_to_sequence(text):
    return _symbols_to_sequence(['@' + s for s in text.split()])


def _should_keep_symbol(s):
    return s in _symbol_to_id and s != '_' and s != '~'


def add_space_between_events(events, connect=False):
    new_events = []
    for i in range(1, len(events)):
        token_a, freq_a, start_time_a, end_time_a = events[i-1][-1]
        token_b, freq_b, start_time_b, end_time_b = events[i][0]

        if token_a in (' ', '') and len(events[i-1]) == 1:
            new_events.append(events[i-1])
        elif token_a not in (' ', '') and token_b not in (' ', ''):
            new_events.append(events[i-1])
            if connect:
                new_events.append([[' ', 0, end_time_a, start_time_b]])
            else:
                new_events.append([[' ', 0, end_time_a, end_time_a]])
        else:
            new_events.append(events[i-1])

    if new_events[-1][0][0] != ' ':
        new_events.append([[' ', 0, end_time_a, end_time_a]])
    new_events.append(events[-1])

    return new_events

def adjust_extensions(events, phoneme_durations):
    if len(events) == 1:
        return events

    idx_last_vowel = None
    n_consonants_after_last_vowel = 0
    rest_after_last_vowel = False
    target_ids = np.arange(len(events))
    for i in range(len(events)):
        token = re.sub('[0-9{}]', '', events[i][0])
        if idx_last_vowel is None and token not in phoneme_durations:
            idx_last_vowel = i
            n_consonants_after_last_vowel = 0
        else:
            if token == '_' and not n_consonants_after_last_vowel:
                events[i][0] = events[idx_last_vowel][0]
            elif token == '_' and n_consonants_after_last_vowel:
                events[i][0] = events[idx_last_vowel][0]
                start = idx_last_vowel + 1
                target_ids[start:start+n_consonants_after_last_vowel] += 1 + int(rest_after_last_vowel)
                target_ids[i] -= n_consonants_after_last_vowel
                if rest_after_last_vowel:
                    target_ids[i-1] -= n_consonants_after_last_vowel
            elif token in phoneme_durations:
                n_consonants_after_last_vowel += 1
            elif token == ' ':
                rest_after_last_vowel = True
            else:
                rest_after_last_vowel = False
                n_consonants_after_last_vowel = 0
                idx_last_vowel = i

    new_events = [0] * len(events)
    for i in range(len(events)):
        new_events[target_ids[i]] = events[i]

    # adjust time of consonants that were repositioned
    for i in range(1, len(new_events)):
        if new_events[i][2] < new_events[i-1][2]:
            new_events[i][2] = new_events[i-1][2]
            new_events[i][3] = new_events[i-1][3]
        if new_events[i][0][0] == '{':
            new_events[i][0] = new_events[i][0][1:]
        if new_events[i][0][-1] == '}' and i < len(new_events) - 1:
            new_events[i][0] = new_events[i][0][:-1]

    first_p = new_events[0][0]
    if not first_p.isspace() and first_p[0] != '{':
        new_events[0][0] = '{' + first_p
    last_p = new_events[-1][0]
    if not last_p.isspace() and last_p[-1] != '}':
        new_events[-1][0] = last_p + '}'
    return new_events


def adjust_consonant_lengths(events, phoneme_durations):
    t_init = events[0][2]
    t_end = events[-1][3]
    duration = t_end - t_init
    consonant_durations = {}
    consonant_duration = 0
    for event in events:
        c = re.sub('[0-9{}]', '', event[0])
        if c in phoneme_durations:
            consonant_durations[c] = phoneme_durations[c]
            consonant_duration += phoneme_durations[c]

    if not consonant_duration <= 0.4 * duration:
        scale = 0.4 * duration / consonant_duration
        for k, v in consonant_durations.items():
            consonant_durations[k] = scale * v

    idx_last_vowel = None
    for i in range(len(events)):
        task = re.sub('[0-9{}]', '', events[i][0])
        if task in consonant_durations:
            duration = consonant_durations[task]
            if idx_last_vowel is None:  # consonant comes before any vowel
                events[i][2] = t_init
                events[i][3] = t_init + duration
            else:  # consonant comes after a vowel, must offset
                events[idx_last_vowel][3] -= duration
                for k in range(idx_last_vowel+1, i):
                    events[k][2] -= duration
                    events[k][3] -= duration
                events[i][2] = events[i-1][3]
                events[i][3] = events[i-1][3] + duration
        else:
            events[i][2] = t_init
            events[i][3] = events[i][3]
            t_init = events[i][3]
            idx_last_vowel = i
        t_init = events[i][3]

    return events


def adjust_consonants(events, phoneme_durations):
    if len(events) == 1:
        return events

    start = 0
    split_ids = []
    t_init = events[0][2]

    # get each substring group
    for i in range(1, len(events)):
        if events[i][2] != t_init:
            split_ids.append((start, i))
            start = i
            t_init = events[i][2]
    split_ids.append((start, len(events)))

    for (start, end) in split_ids:
        events[start:end] = adjust_consonant_lengths(
            events[start:end], phoneme_durations)

    return events


def event2alignment(events, hop_length=256, sampling_rate=22050):
    frame_length = float(hop_length) / float(sampling_rate)

    n_frames = int(events[-1][-1][-1] / frame_length)
    n_tokens = np.sum([len(e) for e in events])
    alignment = np.zeros((n_tokens, n_frames))

    cur_event = -1
    for event in events:
        for i in range(len(event)):
            if len(event) == 1 or cur_event == -1 or event[i][0] != event[i-1][0]:
                cur_event += 1
            token, freq, start_time, end_time = event[i]
            alignment[cur_event, int(start_time/frame_length):int(end_time/frame_length)] = 1

    return alignment[:cur_event+1]


def event2f0(events, hop_length=256, sampling_rate=22050):
    frame_length = float(hop_length) / float(sampling_rate)
    n_frames = int(events[-1][-1][-1] / frame_length)
    f0s = np.zeros((1, n_frames))

    for event in events:
        for i in range(len(event)):
            token, freq, start_time, end_time = event[i]
            f0s[0, int(start_time/frame_length):int(end_time/frame_length)] = freq

    return f0s


def event2text(events, convert_stress):
    text_clean = ''
    for event in events:
        for i in range(len(event)):
            if i > 0 and event[i][0] == event[i-1][0]:
                continue
            if event[i][0] == ' ' and len(event) > 1:
                if text_clean[-1] != "}":
                    text_clean = text_clean[:-1] + '} {'
                else:
                    text_clean += ' {'
            else:
                if event[i][0][-1] in ('}', ' '):
                    text_clean += event[i][0]
                else:
                    text_clean += event[i][0] + ' '

    if convert_stress:
        text_clean = re.sub('[0-9]', '1', text_clean)

    text_encoded = text_to_sequence(text_clean)
    return text_encoded, text_clean


def remove_excess_frames(alignment, f0s):
    excess_frames = np.sum(alignment.sum(0) == 0)
    alignment = alignment[:, :-excess_frames] if excess_frames > 0 else alignment
    f0s = f0s[:, :-excess_frames] if excess_frames > 0 else f0s
    return alignment, f0s


def get_data_from_text_events_with_phonemes(text_events, ticks=False, tempo=120, resolution=960, phoneme_durations=None, convert_stress=False):
    def pitch_to_freq(pitch):
        return 440*(2**((pitch - 69)/12))

    if ticks:
        num = int
        to_time = lambda t: t * 60/(tempo*resolution)
    else:
        num = float
        to_time = lambda t: t

    if phoneme_durations is None:
        phoneme_durations = PHONEMEDURATION
    
    events = []
    phonemes = ''
    word = []
    time = 0
    note_off = True
    rest_start = -1
    new_word = True
    for e in text_events:
        e_split = e.split('_')
        if '_' not in e:
            phonemes = e
        elif e == '_R_':
            phonemes = '_'
        elif e_split[0] == 'ON':
            if rest_start >= 0:
                if new_word and word:
                    last_p = word[-1][0]
                    if last_p != ' ':
                        word[-1][0] = last_p + '}'
                    events.append(word)
                    word = []
                word.append([' ', 0, to_time(rest_start), to_time(time)])
                rest_start = -1
            freq = pitch_to_freq(int(e_split[1]))
            start = time
            note_off = False
        elif e_split[0] == 'W':
            t = num(e_split[1])
            if note_off and rest_start < 0:
                rest_start = time
            time += t
        elif e == '_OFF_':
            if new_word and word:
                last_p = word[-1][0]
                if last_p != ' ':
                    word[-1][0] = last_p + '}'
                events.append(word)
                word = []            
            for i, p in enumerate(phonemes.split()):
                if new_word and i == 0:
                    word.append(['{' + p, freq, to_time(start), to_time(time)])
                else:
                    word.append([p, freq, to_time(start), to_time(time)])
            new_word = False
            note_off = True
        else:
            new_word = True

    last_p = word[-1][0]
    if last_p != ' ':
        word[-1][0] = last_p + '}'
    events.append(word)

    # make adjustments
    events = [adjust_extensions(e, phoneme_durations)
                      for e in events]
    events = [adjust_consonants(e, phoneme_durations)
                      for e in events]
    events = add_space_between_events(events)

    # convert data to alignment, f0 and text encoded
    alignment = event2alignment(events)
    f0s = event2f0(events)
    alignment, f0s = remove_excess_frames(alignment, f0s)
    text_encoded, text_clean = event2text(events, convert_stress)

    # convert data to torch
    alignment = torch.from_numpy(alignment).permute(1, 0)[:, None].float()
    f0s = torch.from_numpy(f0s)[None].float()
    text_encoded = torch.LongTensor(text_encoded)[None]

    return {'rhythm': alignment, 'pitch_contour': f0s, 'text_encoded': text_encoded}


Overwriting /content/mellotron/mellotron_utils.py
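In `get_data_from_text_events_with_phonemes`, `ON_<pitch>` carries a MIDI note number and `W_<n>` carries a wait. When `ticks=True`, waits are converted from MIDI ticks to seconds with `t * 60 / (tempo * resolution)`. Pitches are converted to Hz with the equal-temperament formula `440 * 2 ** ((pitch - 69) / 12)`. The two conversions are shown here on their own, using the function's defaults `tempo=120` and `resolution=960`:

```python
def pitch_to_freq(pitch):
    """MIDI note number -> frequency in Hz (A4 = note 69 = 440 Hz)."""
    return 440 * (2 ** ((pitch - 69) / 12))

def ticks_to_seconds(ticks, tempo=120, resolution=960):
    """MIDI ticks -> seconds, given tempo in BPM and resolution in ticks per quarter note."""
    return ticks * 60 / (tempo * resolution)

print(round(pitch_to_freq(78), 2))  # first 'ON_78' in the demo melody -> 739.99
print(ticks_to_seconds(420))        # 'W_420' read as ticks at 120 BPM -> 0.21875
```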


In [9]:
import IPython.display as ipd

import sys
################################################################################################
sys.path.append('/content/mellotron')             #####  modified for  colab ######
sys.path.append('/content/mellotron/waveglow/')   #####  modified for  colab ######
################################################################################################

from itertools import cycle
import numpy as np
import scipy as sp
from scipy.io.wavfile import write
import pandas as pd
import librosa
import torch

from hparams import create_hparams
from model import Tacotron2
from waveglow.denoiser import Denoiser
from layers import TacotronSTFT
################################################################################################
from train_utils import load_model #####  modified for inference  on colab #####################
################################################################################################
from data_utils import TextMelLoader, TextMelCollate
from text import cmudict, text_to_sequence
%cd mellotron/
from mellotron_utils import get_data_from_text_events_with_phonemes
%cd ..

/content/mellotron
/content


In [10]:
def load_mel(path):
    # librosa resamples to hparams.sampling_rate and returns float32 samples
    # already scaled to [-1, 1], so no division by hparams.max_wav_value here
    audio, sampling_rate = librosa.core.load(path, sr=hparams.sampling_rate)
    audio = torch.from_numpy(audio)
    if sampling_rate != hparams.sampling_rate:
        raise ValueError("{} SR doesn't match target {} SR".format(
            sampling_rate, hparams.sampling_rate))
    audio_norm = audio.unsqueeze(0)
    audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
    melspec = stft.mel_spectrogram(audio_norm)
    melspec = melspec.cuda()
    return melspec
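For reference on scaling: `max_wav_value=32768.0` is the factor that maps signed 16-bit PCM samples into [-1, 1), which is the range `TacotronSTFT.mel_spectrogram` expects. Here is a stdlib-only sketch of that int16 normalization. `int16_to_float` is a hypothetical helper.

```python
import array

def int16_to_float(samples, max_wav_value=32768.0):
    """Scale signed 16-bit PCM samples to floats in [-1.0, 1.0)."""
    pcm = array.array("h", samples)  # 'h' = signed short; out-of-range values raise OverflowError
    return [s / max_wav_value for s in pcm]

print(int16_to_float([-32768, 0, 16384, 32767]))
```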
        

In [11]:
hparams = create_hparams()

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



In [12]:
stft = TacotronSTFT(hparams.filter_length, hparams.hop_length, hparams.win_length,
                    hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
                    hparams.mel_fmax)
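With these settings, one mel frame covers `hop_length / sampling_rate` = 256 / 22050 ≈ 11.6 ms, which is about 86 frames per second. The sketch below gives the frame count for `n` samples. It assumes the tacotron2-style STFT, which reflect-pads `filter_length // 2` samples on each side before a strided convolution. It only illustrates the arithmetic under that assumption and is not code from the repo.

```python
def n_mel_frames(n_samples, filter_length=1024, hop_length=256):
    """Frames from a strided STFT after reflect-padding filter_length // 2 samples per side."""
    padded = n_samples + 2 * (filter_length // 2)
    return (padded - filter_length) // hop_length + 1

sampling_rate, hop_length = 22050, 256
print(1000 * hop_length / sampling_rate)  # ms per frame, about 11.61
print(n_mel_frames(sampling_rate))        # frames for one second of audio -> 87
```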

## Load Models


In [13]:
checkpoint_path = "/content/mellotron_libritts.pt"
tacotron = load_model(hparams).cuda().eval()
tacotron.load_state_dict(torch.load(checkpoint_path)['state_dict'])

<All keys matched successfully>

In [14]:
waveglow_path = '/content/waveglow_256channels_v4.pt'
waveglow = torch.load(waveglow_path)['model']
waveglow = waveglow.remove_weightnorm(waveglow)
waveglow.cuda().eval()
from apex import amp
waveglow, _ = amp.initialize(waveglow, [], opt_level="O3")
denoiser = Denoiser(waveglow).cuda().eval()



Selected optimization level O3:  Pure FP16 training.
Defaults for this optimization level are:
enabled                : True
opt_level              : O3
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : False
master_weights         : False
loss_scale             : 1.0
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O3
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : False
master_weights         : False
loss_scale             : 1.0
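By default, `Denoiser(waveglow)` estimates WaveGlow's bias by vocoding an all-zero mel spectrogram. `denoiser(audio, 0.01)` then subtracts that bias magnitude spectrum, scaled by the strength argument, from the STFT magnitudes of the audio and clamps the result at zero. The list-based sketch below shows only that per-frame subtraction step and is a simplified illustration. The real module works on PyTorch STFT tensors and resynthesizes audio using the original phases.

```python
def subtract_bias(magnitude_frame, bias_frame, strength=0.01):
    """Spectral subtraction on one STFT frame: max(mag - strength * bias, 0) per bin."""
    if len(magnitude_frame) != len(bias_frame):
        raise ValueError("frame and bias must have the same number of bins")
    return [max(m - strength * b, 0.0) for m, b in zip(magnitude_frame, bias_frame)]

print(subtract_bias([1.0, 0.5, 0.001], [10.0, 10.0, 10.0], strength=0.01))
```

Bins that are quieter than the scaled bias are silenced entirely. A larger strength therefore removes more hiss, at the cost of more artifacts.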


## Set up data loaders for Google Colab

In [15]:
%%file /content/mellotron/data/examples_filelist.txt
/content/mellotron/data/example1.wav|exploring the expanses of space to keep our planet safe|1
/content/mellotron/data/example2.wav|and all the species that call it home|1

Overwriting /content/mellotron/data/examples_filelist.txt
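Each filelist line has three `|`-separated fields: audio path, transcript and speaker ID. This is the same three-element layout that `dataloader.audiopaths_and_text` holds. Below is a stdlib-only sketch of parsing the format. The helper name is hypothetical, and the sketch assumes that transcripts contain no `|`.

```python
def parse_filelist(lines):
    """Parse 'audio_path|text|speaker_id' lines into (path, text, speaker_id) tuples."""
    entries = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        path, text, sid = line.split("|")  # raises ValueError if a line has != 3 fields
        entries.append((path, text, sid))
    return entries

example = [
    "/content/mellotron/data/example1.wav|exploring the expanses of space to keep our planet safe|1",
    "/content/mellotron/data/example2.wav|and all the species that call it home|1",
]
for path, text, sid in parse_filelist(example):
    print(sid, path, text)
```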


In [16]:
arpabet_dict = cmudict.CMUDict('/content/mellotron/data/cmu_dictionary')
audio_paths = '/content/mellotron/data/examples_filelist.txt'
dataloader = TextMelLoader(audio_paths, hparams)
datacollate = TextMelCollate(1)

## Load data for reference voice

In [17]:
file_idx = 0
audio_path, text, sid = dataloader.audiopaths_and_text[file_idx]

# get audio path, encoded text, pitch contour and mel for gst
text_encoded = torch.LongTensor(text_to_sequence(text, hparams.text_cleaners, arpabet_dict))[None, :].cuda()    
pitch_contour = dataloader[file_idx][3][None].cuda()
mel = load_mel(audio_path)
print(audio_path, text)

# load source data to obtain rhythm using tacotron 2 as a forced aligner
x, y = tacotron.parse_batch(datacollate([dataloader[file_idx]]))

/content/mellotron/data/example1.wav exploring the expanses of space to keep our planet safe
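`text_to_sequence` receives `arpabet_dict`, a CMUdict wrapper, so words can be encoded as ARPAbet phonemes instead of letters. In the Tacotron 2 text pipeline that Mellotron inherits, phoneme strings are written inside curly braces, for example `{S P EY1 S}`. The sketch below shows the lookup idea with a tiny hand-written dictionary. The entries, the `to_arpabet` name, and the keep-unknown-words-as-letters rule are illustrative assumptions, not Mellotron's exact behavior.

```python
# illustrative hand-written entries; the notebook uses cmudict.CMUDict instead
TOY_ARPABET = {
    "space": "S P EY1 S",
    "planet": "P L AE1 N AH0 T",
    "safe": "S EY1 F",
}


def to_arpabet(text, dictionary):
    """Replace known words with {ARPAbet} groups; keep unknown words as letters."""
    words = []
    for word in text.lower().split():
        phonemes = dictionary.get(word)
        words.append("{%s}" % phonemes if phonemes else word)
    return " ".join(words)


print(to_arpabet("keep our planet safe", TOY_ARPABET))
# keep our {P L AE1 N AH0 T} {S EY1 F}
```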


## Define the Speaker Set

In [18]:
speaker_ids = TextMelLoader("/content/mellotron/filelists/libritts_train_clean_100_audiopath_text_sid_atleast5min_val_filelist.txt", hparams).speaker_ids
speakers = pd.read_csv('/content/mellotron/filelists/libritts_speakerinfo.txt', engine='python', header=None, comment=';', sep=r' *\| *', 
                       names=['ID', 'SEX', 'SUBSET', 'MINUTES', 'NAME'])

speakers['MELLOTRON_ID'] = speakers['ID'].apply(lambda x: speaker_ids[x] if x in speaker_ids else -1)
female_speakers = cycle(
    speakers.query("SEX == 'F' and MINUTES > 20 and MELLOTRON_ID >= 0")['MELLOTRON_ID'].sample(frac=1).tolist())
male_speakers = cycle(
    speakers.query("SEX == 'M' and MINUTES > 20 and MELLOTRON_ID >= 0")['MELLOTRON_ID'].sample(frac=1).tolist())
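The speaker cell keeps LibriTTS speakers with more than 20 minutes of audio that also appear in Mellotron's speaker map (`MELLOTRON_ID >= 0`), shuffles them with `sample(frac=1)`, and wraps each list in `itertools.cycle`, so every `next()` call returns a speaker ID and the sequence repeats indefinitely. The stdlib sketch below reproduces that selection without pandas. The rows, minutes, and `speaker_ids` mapping are made up, and `speaker_cycle` is a hypothetical helper.

```python
import random
from itertools import cycle

# made-up rows shaped like libritts_speakerinfo.txt: (ID, SEX, SUBSET, MINUTES, NAME)
rows = [
    (101, "F", "train-clean-100", 25.0, "Speaker A"),
    (102, "M", "train-clean-100", 24.5, "Speaker B"),
    (103, "F", "train-clean-100", 12.0, "Speaker C"),  # under 20 minutes: excluded
    (104, "M", "train-clean-100", 30.0, "Speaker D"),
    (105, "F", "train-clean-100", 22.0, "Speaker E"),  # not in speaker_ids: excluded
]
# hypothetical mapping from LibriTTS speaker ID to Mellotron's internal speaker index
speaker_ids = {101: 0, 102: 1, 103: 2, 104: 3}


def speaker_cycle(rows, speaker_ids, sex, min_minutes=20, seed=None):
    """Shuffle the eligible Mellotron IDs for one sex and cycle through them forever."""
    eligible = []
    for speaker, row_sex, _subset, minutes, _name in rows:
        mellotron_id = speaker_ids.get(speaker, -1)  # -1 marks "unknown to the model"
        if row_sex == sex and minutes > min_minutes and mellotron_id >= 0:
            eligible.append(mellotron_id)
    random.Random(seed).shuffle(eligible)  # random order, like .sample(frac=1)
    return cycle(eligible)


male_speakers = speaker_cycle(rows, speaker_ids, "M", seed=0)
print([next(male_speakers) for _ in range(4)])  # IDs 1 and 3, alternating
```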

## Singing Voice from Text Events


In [19]:
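# factor applied to the reference pitch contour before synthesis
frequency_scaling = 0.5

def get_audio(data):
    # move the rhythm (alignment map), pitch contour and encoded text to the GPU
    rhythm = data['rhythm'].cuda()
    pitch_contour = data['pitch_contour'].cuda()
    text_encoded = data['text_encoded'].cuda()
    
    # take the next male speaker from the shuffled cycle, so repeated calls can change voice
    speaker_id = torch.LongTensor([next(male_speakers)]).cuda()
    with torch.no_grad():
        # synthesize with the supplied rhythm instead of computed attention;
        # `mel` is the reference loaded earlier and conditions the GST (style) tokens
        mel_outputs, mel_outputs_postnet, gate_outputs, alignments_transfer = tacotron.inference_noattention(
            (text_encoded, mel, speaker_id, pitch_contour * frequency_scaling, rhythm))
        # vocode with the FP16 WaveGlow, then remove its bias with the denoiser (strength 0.01)
        audio = denoiser(waveglow.infer(mel_outputs_postnet.half(), sigma=0.8), 0.01)[0, 0]
        audio = audio.cpu().numpy()
        torch.cuda.empty_cache()
        return ipd.Audio(audio, rate=hparams.sampling_rate)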
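`get_audio` multiplies the reference pitch contour by `frequency_scaling = 0.5`. Musical intervals are frequency ratios, so halving every F0 value lowers the melody by exactly one octave (12 semitones), presumably to place it in a range that suits the male voices drawn from `male_speakers`. The sketch below converts a scaling factor into semitones and applies it to a toy contour; `semitone_shift` and `scale_contour` are hypothetical helpers.

```python
import math


def semitone_shift(scale):
    """Pitch change in semitones when every F0 value is multiplied by `scale`."""
    return 12 * math.log2(scale)


def scale_contour(contour_hz, scale):
    """Multiply each F0 value by `scale`; frames at 0 Hz stay at 0."""
    return [f0 * scale for f0 in contour_hz]


print(semitone_shift(0.5))                      # -12.0 (one octave down)
print(scale_contour([440.0, 0.0, 220.0], 0.5))  # [220.0, 0.0, 110.0]
```

A factor of `2 ** (n / 12)` would shift the melody by `n` semitones instead of a whole octave.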

In [39]:
text = ['W_1600', 'W_24', 'UW1', 'ON_78', 'W_420', '_OFF_', 'W_220', 'D IY0 Z', 'ON_67', 'W_470', '_OFF_', 'W_690', 'N_W', 'S T IY1', 'ON_74', 'W_120', '_OFF_', 'N_W', 'Y UW1', 'ON_66', 'W_230', '_OFF_', 'W_180', 'N_W', 'AA1 R', 'ON_74', 'W_200', '_OFF_', 'N_W', 'W ER1 L D', 'ON_76', 'W_780', '_OFF_', 'N_W', 'AY1', 'ON_74', 'W_890', '_OFF_', 'N_W', 'M IY1', 'ON_66', 'W_430', '_OFF_']
# Remaining events of the score, commented out in this demo:
# 'W_2000', 'W_1040', 'W_80', 'N_W', 'N OW1', 'ON_74', 'W_500', '_OFF_', 'W_30', 'N_W', 'DH AH0', 'ON_63', 'W_140', '_OFF_', 'N_W', 'IH1 Z', 'ON_74', 'W_370', '_OFF_', 'W_490', 'N_W', 'HH ER1', 'ON_66', 'W_640', '_OFF_', 'N_W', 'HH IH1 M', 'ON_66', 'W_500', '_OFF_', 'N_W', 'B AH1 T', 'ON_31', 'N_L', 'AY1', 'ON_67', 'W_210', '_OFF_', 'W_120', '_C_', 'N_L', 'W AA1', 'ON_74', 'W_890', '_OFF_', 'W_300', '_OFF_', 'W_360', 'N_W', 'AH0', 'ON_74', 'W_440', '_OFF_', 'W_40', 'N_W', 'K EH1', 'ON_74', 'W_230', '_OFF_', 'W_210', 'N_W', 'T UW1', 'ON_74', 'W_400', '_OFF_', 'W_40', 'N_L', 'AY1', 'ON_74', 'W_440', '_OFF_', 'N AH1 F', 'ON_74', 'W_250', '_OFF_', 'N_W', 'AW1 T', 'ON_74', 'W_680', '_OFF_', 'W_40', 'N_W', 'HH ER0', 'ON_58', 'W_300', '_OFF_', 'N_W', 'DH AE1 T', 'ON_72', 'W_490', '_OFF_', 'W_70', 'N_W', 'N OW1', 'ON_74', 'W_890', 'N_L', 'AH0 N D', 'ON_74', 'W_1560', '_OFF_', 'W_2000', 'W_1220', '_OFF_', 'N_L', 'AH0', 'ON_74', 'W_920', '_OFF_', 'W_550', '_C_', 'N_W', 'Y AO1 R', 'ON_74', 'W_490', '_OFF_', 'W_180', 'N_W', 'Y UW1', 'ON_77', 'W_450', '_OFF_', 'W_160', '_C_', 'N_L', 'L AH1', 'ON_74', 'W_490', '_OFF_', 'W_1460', 'N_L', 'JH AH1 S T', 'ON_74', 'W_280', '_OFF_', 'W_300', '_OFF_', 'W_520', 'N_L', 'G OW1', 'ON_74', 'W_260', 'N_W', 'EH1', 'ON_74', 'W_840', '_OFF_', 'W_50', 'N_W', 'AH0', 'ON_75', 'W_200', '_OFF_', 'N_W', 'AW1 T', 'ON_69', 'W_930', '_OFF_', 'W_210', 'N_W', 'L AH1 V', 'ON_74', 'W_520', '_OFF_', 'N_W', 'Y UW1', 'ON_74', 'W_1890', '_OFF_', 'W_40', 'V IH0 N', 'ON_74', 'W_790', '_OFF_', 'N_W', 'DH AH0', 'ON_74', 'W_850', '_OFF_', 'W_75', '_OFF_', 'N_W', 'K L OW1', 'ON_70',
# 'W_570', 'N_W', 'IH1 T', 'ON_74', 'W_480', '_OFF_', 'N_W', 'AY1', 'ON_74', 'W_290', '_OFF_', 'N_W', 'W ER1 L D', 'ON_74', 'W_450', '_OFF_', 'W_160', '_R_', 'ON_77', 'W_440', '_OFF_', 'N_W', 'AY1 Z', 'ON_76', 'W_470', '_OFF_', 'N_W', 'S OW1 L', 'ON_74', 'W_520', '_OFF_', 'N_W', 'AH1 V', 'ON_74', 'W_450', '_OFF_', 'N_W', 'HH AY1', 'ON_74', 'W_440', 'N_L', 'G OW1', 'ON_74', 'W_540', '_OFF_', 'W_320', '_OFF_', 'W_320', 'N_L', 'IH1 T S', 'ON_74', 'W_700', '_OFF_', 'W_100', 'N_W', 'AY1', 'ON_74', 'W_190', 'N_W', 'G AA1', 'ON_74', 'W_2000', 'W_1360', '_OFF_', 'N_L', 'AY1', 'ON_74', 'W_300', '_OFF_', 'W_30', 'N_W', 'D AE1', 'ON_74', 'W_490', '_OFF_', 'N_W', 'DH EH1 R', 'ON_74', 'W_30', 'N_L', 'IH1', 'ON_74', 'W_450', '_OFF_', 'W_470', '_OFF_', 'W_910', '_OFF_', 'W_90', 'R G EH1 T', 'ON_70', 'W_300', '_OFF_', 'N D L AH0', 'ON_74', 'W_440', '_OFF_', 'W_40', 'N_W', 'AY1 M', 'ON_74', 'W_580', 'N_L', 'B AH1 T', 'ON_74', 'W_220', '_OFF_', 'W_120', '_OFF_', 'W_120', 'N_W', 'G OW1', 'ON_74', 'W_990', '_OFF_', 'W_490', 'N_W', 'AO1', 'ON_74', 'W_20', 'N_W', 'TH IH1 NG', 'ON_72', 'W_1100', '_OFF_', 'N_L', 'AY1', 'ON_75', 'W_470', '_OFF_', 'L IY0', 'ON_74', 'W_280', '_OFF_', 'W_140', 'N_L', 'B AE1 K', 'ON_74', 'W_250', '_OFF_', 'N_W', 'AY1', 'ON_74', 'W_570', '_OFF_', 'N_W', 'AY1', 'ON_77', 'W_680', '_OFF_', 'N_W', 'S AH1 M', 'ON_67', 'W_2000', 'W_1600', 'N_L', 'M IH2', 'ON_74', 'W_400', '_OFF_', 'W_220', '_OFF_', 'W_210', 'N_W', 'AA1 R M Z', 'ON_74', 'W_270', '_OFF_', 'P IY0', 'ON_70', 'W_720', 'N_L', 'AH0 N D', 'ON_74', 'W_340', '_OFF_', 'W_1430', '_OFF_', 'N_W', 'DH EH1 R', 'ON_74', 'W_880', 'N_W', 'G AA1', 'ON_74', 'W_460', '_OFF_', 'W_1370', 'N_L', 'K L OW1 S', 'ON_74', 'W_260', '_OFF_', 'N_L', 'S OW1', 'ON_79', 'W_440', '_OFF_', 'W_340', '_OFF_', 'W_120', 'N_L', 'OW1', 'ON_74', 'W_180', '_OFF_', 'N_W', 'Y UW1', 'ON_64', 'W_2000', 'W_880', '_OFF_', 'W_2000', 'W_2000', 'W_930', 'V IH0 N', 'ON_74', 'W_450', '_OFF_', 'N_W', 'Y UH0', 'ON_74', 'W_540', '_OFF_', 'W_480', 'N_W',
# 'B AY1', 'ON_74', 'W_260', '_OFF_', 'W_40', 'N_W', 'ER0', 'ON_74', 'W_660', 'N_L', 'Y UW1', 'ON_74', 'W_410', '_OFF_', '_R_', 'ON_67', 'W_2000', 'W_1200', '_OFF_', 'T IH0 NG', 'ON_77', 'W_880', '_OFF_', 'N_W', 'SH IY1 Z', 'ON_74', 'W_376', '_OFF_', 'W_300', 'N_W', 'AY1', 'ON_74', 'W_480', '_OFF_', 'N_W', 'W AO1', 'ON_74', 'W_2000', 'W_1430', '_OFF_', 'W_260', '_OFF_', 'V AH0 L', 'ON_74', 'W_590', '_OFF_', 'N D ER0', 'ON_74', 'W_250', '_OFF_', 'W_120', 'N_W', 'G AA1', 'ON_74', 'W_170', 'N_L', 'K AH1 M', 'ON_75', 'W_250', '_OFF_', 'N_W', 'T AO1 T', 'ON_72', 'W_440', '_OFF_', '_C_', 'N_L', 'B R EY1', 'ON_74', 'W_480', '_OFF_', 'W_125', 'ON_74', 'W_800', '_OFF_', 'N_W', 'AY1', 'ON_74', 'W_300', '_OFF_', 'N S T EH2 D', 'ON_74', 'W_520', 'N_L', 'HH EH1', 'ON_74', 'W_700', '_OFF_', 'N_W', 'AE1 T', 'ON_74', 'W_120', '_OFF_', 'ON_74', 'W_1150', 'N_L', 'L AH1 V', 'ON_74', 'W_140', '_OFF_', 'N_W', 'G EH1 T', 'ON_74', 'W_200', '_OFF_', 'N_W', 'TH AO1 T', 'ON_74', 'W_480', '_OFF_', 'W_300', '_OFF_', 'N_W', 'S T AA1 R Z', 'ON_74', 'W_480', '_OFF_', 'W_300', 'N_W', 'P EY1 JH', 'ON_74', 'W_1150', '_OFF_', 'W_50', 'N_L', 'Y UW1', 'ON_74', 'W_480', '_OFF_', 'N_L', 'Y UW1', 'ON_74', 'W_470', '_OFF_', 'N_W', 'S L OW1', 'ON_81', 'W_430', '_OFF_', '_C_', 'N_W', 'Y UW1', 'ON_74', 'W_340', '_OFF_', 'W_80', 'N_W', 'F UW1 L', 'ON_74', 'W_440', '_OFF_', 'W_180', 'G IY0', 'ON_74', 'W_2000', 'W_2000', 'W_340', 'N_L', 'F AY1', 'ON_74', 'W_814', '_OFF_', 'W_120', 'N_L', 'IH1 T', 'ON_74', 'W_340', '_OFF_', 'N_W', 'AW1 T', 'ON_72', 'W_250', '_OFF_', 'W_140', 'L AH0', 'ON_74', 'W_580', '_OFF_', 'W_40', 'N T AH0 N', 'ON_74', 'W_440', '_OFF_', 'W_20', 'N_L', 'DH EH1 R', 'ON_74', 'W_1900', '_OFF_', 'W_320', '_OFF_', 'N_W', 'F AO1 R', 'ON_74', 'W_1300', 'N_L', 'AY1', 'ON_69', 'W_970', '_OFF_', 'W_580', '_OFF_', 'N_W', 'S EY1', 'ON_74', 'W_710', '_OFF_', 'W_210', 'N_L', 'AY1', 'ON_74', 'W_550', 'N_L', 'AY1', 'ON_74', 'W_450', '_OFF_', 'W_170', 'N_L', 'AH0 N D', 'ON_74', 'W_175', '_OFF_', 'W_2000', 'W_900',
# 'N_L', 'N AH1', 'ON_74', 'W_340', '_OFF_', 'W_60', 'N_L', 'S IH1 K S', 'ON_74', 'W_740', '_OFF_', 'N_L', 'IH1 T', 'ON_74', 'W_370', '_OFF_', 'W_150', '_OFF_', 'N_W', 'AH0', 'ON_74', 'W_850', '_OFF_', 'W_320', '_OFF_', 'W_140', 'W AY1 L', 'ON_74', 'W_400', 'N_L', 'AY1', 'ON_74', 'W_270', 'N_L', 'W AY1', 'ON_74', 'W_440', '_OFF_', 'N_W', 'AA1 N', 'ON_74', 'W_460', '_OFF_', 'W_730', '_OFF_', 'W_40', 'N_W', 'W AA1 Z', 'ON_77', 'W_400', '_OFF_', 'W_300', '_OFF_', 'W_30', 'N_W', 'F R EH1 N D Z', 'ON_74', 'W_1335', '_OFF_', 'W_40', 'N_W', 'L AY1 K', 'ON_76', 'W_580', '_OFF_', 'W_20', 'N_W', 'IH1 T', 'ON_74', 'W_520', 'N_L', 'S OW1'
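The `text` cell encodes the song as a flat list of events. The notebook does not document the tokens, so the following reading is an assumption based on the tokens alone: strings such as `'D IY0 Z'` are ARPAbet syllables, `ON_<n>` starts a note at MIDI note number `<n>`, `_OFF_` ends it, and `W_<n>` waits `<n>` time units. The meaning of `N_W`, `N_L`, `_C_` and `_R_` is not stated, so the sketch simply skips them. It pairs each syllable with the next `ON_` event and converts MIDI numbers to Hz with the standard equal-temperament formula `440 * 2 ** ((n - 69) / 12)`. `pair_syllables_with_notes` and `midi_to_hz` are hypothetical helpers, not Mellotron functions.

```python
CONTROL_TOKENS = {"_OFF_", "N_W", "N_L", "_C_", "_R_"}


def midi_to_hz(note):
    """Equal-temperament frequency of a MIDI note number (A4 = note 69 = 440 Hz)."""
    return 440.0 * 2 ** ((note - 69) / 12)


def pair_syllables_with_notes(events):
    """Return (syllable, midi_note, duration) triples under the assumptions above.

    duration sums the W_<n> values while the note sounds, i.e. between its
    ON_ event and the next _OFF_ or ON_, in whatever unit the W_ tokens use.
    """
    notes = []
    pending = None   # syllable waiting for its ON_ event
    current = None   # [syllable, midi_note, duration] of the sounding note
    for token in events:
        if token.startswith("ON_"):
            current = [pending, int(token[3:]), 0]
            notes.append(current)
            pending = None
        elif token.startswith("W_"):
            if current is not None:
                current[2] += int(token[2:])  # waits outside a note (rests) are ignored
        elif token == "_OFF_":
            current = None
        elif token not in CONTROL_TOKENS:
            pending = token  # an ARPAbet syllable such as 'D IY0 Z'
    return [tuple(n) for n in notes]


events = ['W_1600', 'W_24', 'UW1', 'ON_78', 'W_420', '_OFF_',
          'W_220', 'D IY0 Z', 'ON_67', 'W_470', '_OFF_']
for syllable, note, duration in pair_syllables_with_notes(events):
    print(f"{syllable!r}: MIDI {note} ~ {midi_to_hz(note):.1f} Hz, W total {duration}")
```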

In [40]:
data = get_data_from_text_events_with_phonemes(text, ticks=True, tempo=102.000051)
get_audio(data)
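`get_data_from_text_events_with_phonemes` is called with `ticks=True` and `tempo=102.000051`. If the `W_<n>` values are MIDI-style ticks, converting them to seconds needs both the tempo in beats per minute and a tick resolution (ticks per beat). The helper's actual resolution is not shown in this notebook, so the `ticks_per_beat` value below is a placeholder, and `ticks_to_seconds` is a hypothetical function.

```python
def ticks_to_seconds(ticks, tempo_bpm, ticks_per_beat):
    """Convert a tick count to seconds; one beat lasts 60 / tempo_bpm seconds."""
    seconds_per_beat = 60.0 / tempo_bpm
    return ticks * seconds_per_beat / ticks_per_beat


tempo = 102.000051      # value passed in the notebook
ticks_per_beat = 480    # placeholder resolution, NOT taken from Mellotron
print(f"one beat = {60.0 / tempo:.3f} s")
print(f"W_420 = {ticks_to_seconds(420, tempo, ticks_per_beat):.3f} s (if 480 ticks per beat)")
```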


## End of Notebook