In [1]:
import argparse
import json
import logging
import os
import random
import sys
import time
import collections
import re
import uuid

import torch
import librosa
import numpy as np
from scipy.io import wavfile

from shutil import copyfile
from shutil import rmtree
from copy import deepcopy

MAX_WAV_VALUE = 32768.0

FILE_ROOT = os.path.dirname(os.path.realpath("."))
PROJECT_ROOT = os.path.dirname(os.path.realpath("."))
FILE_ROOT = os.path.join(FILE_ROOT, "tmp")
os.makedirs(FILE_ROOT, exist_ok=True)

os.environ['PYTHONPATH'] = PROJECT_ROOT
sys.path.append(PROJECT_ROOT)

from daft_exprt.extract_features import extract_energy, extract_pitch, mel_spectrogram_HiFi, rescale_wav_to_float32
from daft_exprt.hparams import HyperParams
from daft_exprt.model import DaftExprt
from daft_exprt.cleaners import collapse_whitespace, text_cleaner
from daft_exprt.symbols import ascii, eos, punctuation, whitespace
from daft_exprt.utils import chunker

from hifi_gan.models import Generator
from hifi_gan import AttrDict

_logger = logging.getLogger(__name__)
random.seed(1234)

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
from daft_exprt.synthesize import *

In [3]:
chkpt_path = "/work/tc046/tc046/lordzuko/work/daft-exprt/trainings/daft_bc2013_v1/checkpoints/DaftExprt_best"
vocoder_config_path = "/work/tc046/tc046/lordzuko/work/daft-exprt/hifi_gan/config_v1.json"
vocoder_chkpt_path = "/work/tc046/tc046/lordzuko/work/daft-exprt/trainings/hifigan/checkpoints/g_00100000"
daft_config_path = "/work/tc046/tc046/lordzuko/work/speech-editor/conf/daft_config.json"

In [4]:
hparams = HyperParams(**json.load(open(daft_config_path)), verbose=False)
random.seed(hparams.seed)
torch.manual_seed(hparams.seed)
torch.backends.cudnn.deterministic = True
_logger.warning('You have chosen to seed training. This will turn on the CUDNN deterministic setting, '
                'which can slow down your training considerably! You may see unexpected behavior when '
                'restarting from checkpoints.\n')

You have chosen to seed training. This will turn on the CUDNN deterministic setting, which can slow down your training considerably! You may see unexpected behavior when restarting from checkpoints.



In [5]:
model = get_model(chkpt_path, hparams)
vocoder = get_vocoder(vocoder_config_path, vocoder_chkpt_path)
dictionary = get_dictionary(hparams)

Removing weight norm...


In [151]:
def phonemize_sentence(sentence, dictionary, hparams):
    ''' Phonemize sentence using MFA
    '''
    # get MFA variables
    word_trans = dictionary
    g2p_model = hparams.mfa_g2p_model

    # characters to consider in the sentence
    if hparams.language == 'english':
        all_chars = ascii + punctuation
    else:
        raise NotImplementedError()
    
    # clean sentence
    # "that's, an 'example! ' of a sentence. '"
    sentence = text_cleaner(sentence.strip(), hparams.language).lower().strip()
    # split sentence:
    # [',', "that's", ',', 'an', "example'", '!', "'", 'of', 'a', 'sentence', '.', '.', '.', "'"]
    sent_words = re.findall(f"[\w']+|[{punctuation}]", sentence.lower().strip())
    # remove characters that are not letters or punctuation:
    # [',', "that's", ',', 'an', "example'", '!', 'of', 'a', 'sentence', '.', '.', '.']
    sent_words = [x for x in sent_words if len(re.sub(f'[^{all_chars}]', '', x)) != 0]
    # be sure to begin the sentence with a word and not a punctuation
    # ["that's", ',', 'an', "example'", '!', 'of', 'a', 'sentence', '.', '.', '.']
    while sent_words[0] in punctuation:
        sent_words.pop(0)
    # keep only one punctuation type at the end
    # ["that's", ',', 'an', "example'", '!', 'of', 'a', 'sentence']
    punctuation_end = "."
    while sent_words[-1] in punctuation:
        punctuation_end = sent_words.pop(-1)
    sent_words.append(punctuation_end)
    
    _words = deepcopy(sent_words)

    # phonemize words and add word boundaries
    sentence_phonemized, unk_words = [], []
    while len(sent_words) != 0:
        word = sent_words.pop(0)
        if word in word_trans:
            # phones = random.choice(word_trans[word])
            phones = word_trans[word][0]
            sentence_phonemized.append(phones)
        else:
            unk_words.append(word)
            sentence_phonemized.append('<unk>')
        # at this point we pass to the next word
        # we must add a word boundary between two consecutive words
        # print("sent_words: ", sent_words)
        if len(sent_words) != 0:
            word_bound = sent_words.pop(0) if sent_words[0] in punctuation else whitespace
            sentence_phonemized.append(word_bound)
    # add EOS token
    sentence_phonemized.append(eos)
    
    # use MFA g2p model to phonemize unknown words
    if len(unk_words) != 0:
        rand_name = str(uuid.uuid4())
        oovs = os.path.join(FILE_ROOT, f'{rand_name}_oovs.txt')
        with open(oovs, 'w', encoding='utf-8') as f:
            for word in unk_words:
                f.write(f'{word}\n')
        # generate transcription for unknown words
        oovs_trans = os.path.join(FILE_ROOT, f'{rand_name}_oovs_trans.txt')
        tmp_dir = os.path.join(FILE_ROOT, f'{rand_name}')
        os.system(f'mfa g2p {g2p_model} {oovs} {oovs_trans} -t {tmp_dir}')
        # extract transcriptions
        with open(oovs_trans, 'r', encoding='utf-8') as f:
            lines = [line.strip().split() for line in f.readlines()]
        for line in lines:
            transcription = line[1:]
            unk_idx = sentence_phonemized.index('<unk>')
            sentence_phonemized[unk_idx] = transcription
        # remove files
        os.remove(oovs)
        os.remove(oovs_trans)
        rmtree(tmp_dir, ignore_errors=True)

    nb_symbols = 0
    word_idx = 0
    idxs = []
    words = []
    phones = []
    ignore_idxs = []
    for item in sentence_phonemized:
        if isinstance(item, list):  # correspond to phonemes of a word
            nb_symbols += len(item)
            idxs.append(nb_symbols)
            words.append(_words[word_idx])
            phones.extend(item)
            word_idx += 1
        else:  # correspond to word boundaries
            nb_symbols += 1
            idxs.append(nb_symbols)
            words.append(item)
            phones.append(item)
            ignore_idxs.append(nb_symbols)

    return sentence_phonemized, words, phones, idxs, ignore_idxs


In [158]:
text = "She wondered why he should be miserable."
sentences = [text]
phonemeized_sents = prepare_sentences_for_inference(sentences,dictionary, hparams)
filenames = ["a.wav"]

In [159]:
# sentence_phonemized, words, phones, idxs, ignore_idxs
print("phonemized_sents: ", phonemeized_sents[0][0], len(phonemeized_sents[0][0]))
print("words: ",phonemeized_sents[0][1], len(phonemeized_sents[0][1]))
print("phones: ", phonemeized_sents[0][2], len(phonemeized_sents[0][2]))
print("idxs: ", phonemeized_sents[0][3], len(phonemeized_sents[0][3]))
print("ignore_idxs: ", phonemeized_sents[0][4], len(phonemeized_sents[0][4]))

phonemized_sents:  [['SH', 'IY1'], ' ', ['W', 'AH1', 'N', 'D', 'ER0', 'D'], ' ', ['W', 'AY1'], ' ', ['HH', 'IY1'], ' ', ['SH', 'UH1', 'D'], ' ', ['B', 'IY1'], ' ', ['M', 'IH1', 'Z', 'R', 'AH0', 'B', 'AH0', 'L'], '.', '~'] 15
words:  ['she', ' ', 'wondered', ' ', 'why', ' ', 'he', ' ', 'should', ' ', 'be', ' ', 'miserable', '.', '~'] 15
phones:  ['SH', 'IY1', ' ', 'W', 'AH1', 'N', 'D', 'ER0', 'D', ' ', 'W', 'AY1', ' ', 'HH', 'IY1', ' ', 'SH', 'UH1', 'D', ' ', 'B', 'IY1', ' ', 'M', 'IH1', 'Z', 'R', 'AH0', 'B', 'AH0', 'L', '.', '~'] 33
idxs:  [2, 3, 9, 10, 12, 13, 15, 16, 19, 20, 22, 23, 31, 32, 33] 15
ignore_idxs:  [3, 10, 13, 16, 20, 23, 32, 33] 8


In [8]:
phonemeized_sents = [phonemeized_sents[0][0]]
style_bank = os.path.join(PROJECT_ROOT, 'scripts', 'style_bank', 'english')
# ref_path = "/scratch/space1/tc046/lordzuko/work/data/raw_data/BC2013_daft_orig/CB/wavs/CB-EM-01-05.wav"
ref_path = "/scratch/space1/tc046/lordzuko/work/data/raw_data/BC2013_daft_orig/CB/wavs/CB-EM-04-96.wav"
# ref_path = "/scratch/space1/tc046/lordzuko/work/data/raw_data/BC2013_daft_orig/CB/wavs/CB-EM-04-100.wav"
ref_parameters = extract_reference_parameters(ref_path, hparams)

  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore


### Initial Inference

In [81]:
dur_factor = None #1.25  # decrease speed
pitch_transform = 'add'  # pitch shift
pitch_factor = None # 50Hz
energy_factor = None

# add duration factors for each symbol in the sentence
dur_factors = [] if dur_factor is not None else None
energy_factors = [] if energy_factor is not None else None
pitch_factors = [pitch_transform, []] if pitch_factor is not None else None

for sentence in phonemeized_sents:
    # count number of symbols in the sentence
    nb_symbols = 0
    for item in sentence:
        if isinstance(item, list):  # correspond to phonemes of a word
            nb_symbols += len(item)
        else:  # correspond to word boundaries
            nb_symbols += 1
    print("num symbols: ", nb_symbols)
    # append to lists
    if dur_factors is not None:
        dur_factors.append([dur_factor for _ in range(nb_symbols)])
    if energy_factors is not None:
        energy_factors.append([energy_factor for _ in range(nb_symbols)])
    if pitch_factors is not None:
        pitch_factors[1].append([pitch_factor for _ in range(nb_symbols)])

num symbols:  47


In [82]:
speaker_ids = [0]*len(sentences)
refs = [ref_parameters]
batch_size = 1
# generate mel-specs and synthesize audios with Griffin-Lim
batch_predictions = generate_mel_specs(model, phonemeized_sents, speaker_ids, refs,
                   hparams, dur_factors, energy_factors, pitch_factors, batch_size, filenames)


# duration_pred, duration_int, energy_pred, pitch_pred    
control_values_init = {}
v = batch_predictions["a.wav"]

control_values_init["d"] = v[0].unsqueeze(0).detach().cpu().numpy()
control_values_init["e"] = v[2].unsqueeze(0).detach().cpu().numpy()
control_values_init["p"] = v[3].unsqueeze(0).detach().cpu().numpy()
mels = v[4].unsqueeze(0)
wavs = vocoder_infer(mels, vocoder, lengths=None)
    
wavfile.write(os.path.join("./", "{}".format("a.wav")), hparams.sampling_rate, wavs[0])

enc_outputs:before-:  tensor([[[-0.0109,  0.0335, -0.7462,  ..., -0.3758,  0.0136, -0.0476],
         [-0.0145, -0.1048,  0.1548,  ..., -0.3636,  0.0388, -0.0189],
         [ 0.0066,  0.2483,  0.6379,  ..., -0.3526,  0.0311, -0.0394],
         ...,
         [-0.0240, -0.8161, -0.3820,  ..., -0.3446,  0.0325, -0.0781],
         [ 0.1019, -0.0201,  0.0670,  ..., -0.3558,  0.0062, -0.0278],
         [-0.1175, -0.2773, -0.1465,  ..., -0.3587,  0.0020, -0.0246]]])
shapes:  torch.Size([1, 47]) torch.Size([1, 47]) torch.Size([1, 47])
enc_outputs:after-:  tensor([[[-0.0109,  0.0335, -0.7462,  ..., -0.3758,  0.0136, -0.0476],
         [-0.0145, -0.1048,  0.1548,  ..., -0.3636,  0.0388, -0.0189],
         [ 0.0066,  0.2483,  0.6379,  ..., -0.3526,  0.0311, -0.0394],
         ...,
         [-0.0240, -0.8161, -0.3820,  ..., -0.3446,  0.0325, -0.0781],
         [ 0.1019, -0.0201,  0.0670,  ..., -0.3558,  0.0062, -0.0278],
         [-0.1175, -0.2773, -0.1465,  ..., -0.3587,  0.0020, -0.0246]]])
no-f

In [83]:
control_values_init["d"]

array([[0.07153996, 0.02839828, 0.04961199, 0.08612181, 0.05773258,
        0.03707456, 0.        , 0.06056041, 0.03342576, 0.04543884,
        0.05716084, 0.        , 0.04015773, 0.02630617, 0.06014895,
        0.        , 0.0827309 , 0.07458448, 0.05278093, 0.05188501,
        0.        , 0.07195716, 0.06014519, 0.05228197, 0.        ,
        0.0517866 , 0.        , 0.07471395, 0.05485927, 0.07311358,
        0.        , 0.03576265, 0.07061581, 0.        , 0.04703812,
        0.06260144, 0.06879132, 0.08413483, 0.03990091, 0.03874783,
        0.04140083, 0.05174303, 0.06163193, 0.05684355, 0.07259117,
        0.        , 0.        ]], dtype=float32)

In [84]:
control_values_init["e"]

array([[-1.259926  ,  0.0547027 ,  2.2075067 , -0.53379595,  1.4796746 ,
        -0.1130432 ,  0.        , -1.1543552 ,  1.0563328 ,  1.0694501 ,
        -0.74712014,  0.        , -0.94322866,  0.89613485, -0.63331246,
         0.        , -1.130187  ,  0.7843849 , -0.6115345 ,  0.36612192,
         0.        , -0.09563295,  0.44997376, -0.64130163,  0.        ,
         0.8105938 ,  0.        , -0.58725715,  0.66506827, -1.0940711 ,
         0.        ,  0.87069833, -0.8133275 ,  0.        ,  0.6626423 ,
         0.9503517 ,  0.6938566 , -1.172732  , -0.3553613 ,  0.49991325,
         0.7131107 ,  0.5386547 , -0.7668078 ,  0.04634795,  0.05067344,
         0.        ,  0.        ]], dtype=float32)

In [85]:
control_values_init["p"]

array([[-0.02353824,  1.7063059 ,  2.5300465 ,  1.2209028 ,  0.96856123,
         0.18960619,  0.        ,  0.0254237 ,  0.99284846,  0.3658704 ,
         0.09693921,  0.        ,  0.03084239,  0.27619335,  0.29360187,
         0.        ,  0.3773249 , -0.13530944, -0.54148346, -0.12286822,
         0.        ,  0.06538778, -0.2298479 , -0.15782307,  0.        ,
         0.44838503,  0.        ,  0.11425585,  0.4341223 , -0.1507506 ,
         0.        ,  0.7297614 ,  0.4344019 ,  0.        ,  1.0624006 ,
         0.9117279 ,  1.0594965 ,  0.42424563,  0.8624264 ,  0.626301  ,
         0.6383288 ,  0.6641659 ,  0.27507088,  0.6182978 ,  0.08646534,
         0.        ,  0.        ]], dtype=float32)

### Fine control

In [86]:
# fc = {}
# fc["d"] = np.ones(control_values_init["d"].shape)
# fc["e"] = np.ones(control_values_init["e"].shape)
# fc["p"] = np.ones(control_values_init["p"].shape)


In [87]:
fc = {}
fc["d"] = control_values_init["d"]
# fc["d"][0][:6] *= 2.
fc["e"] = control_values_init["e"]
fc["p"] = control_values_init["p"]
fc["p"][0][:6] *= 3.

In [88]:
fc["p"][0][:6]

array([-0.07061473,  5.1189175 ,  7.5901394 ,  3.6627083 ,  2.9056838 ,
        0.56881857], dtype=float32)

In [89]:
control_values_init["p"]

array([[-0.07061473,  5.1189175 ,  7.5901394 ,  3.6627083 ,  2.9056838 ,
         0.56881857,  0.        ,  0.0254237 ,  0.99284846,  0.3658704 ,
         0.09693921,  0.        ,  0.03084239,  0.27619335,  0.29360187,
         0.        ,  0.3773249 , -0.13530944, -0.54148346, -0.12286822,
         0.        ,  0.06538778, -0.2298479 , -0.15782307,  0.        ,
         0.44838503,  0.        ,  0.11425585,  0.4341223 , -0.1507506 ,
         0.        ,  0.7297614 ,  0.4344019 ,  0.        ,  1.0624006 ,
         0.9117279 ,  1.0594965 ,  0.42424563,  0.8624264 ,  0.626301  ,
         0.6383288 ,  0.6641659 ,  0.27507088,  0.6182978 ,  0.08646534,
         0.        ,  0.        ]], dtype=float32)

In [90]:
dur_factor = fc["d"] #1.25  # decrease speed
pitch_transform = 'add'  # pitch shift
pitch_factor = fc["p"] # 50Hz
energy_factor = fc["e"]

# add duration factors for each symbol in the sentence
dur_factors = [] if dur_factor is not None else None
energy_factors = [] if energy_factor is not None else None
pitch_factors = [pitch_transform, []] if pitch_factor is not None else None

if dur_factors is not None:
    dur_factors = dur_factor
if energy_factors is not None:
    energy_factors = energy_factor
if pitch_factors is not None:
    pitch_factors[1] = pitch_factor

In [91]:
speaker_ids = [0]*len(sentences)
refs = [ref_parameters]
batch_size = 1
# generate mel-specs and synthesize audios with Griffin-Lim
batch_predictions = generate_mel_specs(model, phonemeized_sents, speaker_ids, refs,
                   hparams, dur_factors, energy_factors, pitch_factors, batch_size, filenames, fine_control=True)


# duration_pred, duration_int, energy_pred, pitch_pred    
control_values_updated_fc = {}
v = batch_predictions["a.wav"]

control_values_updated_fc["d"] = v[0].unsqueeze(0).detach().cpu().numpy()
control_values_updated_fc["e"] = v[2].unsqueeze(0).detach().cpu().numpy()
control_values_updated_fc["p"] = v[3].unsqueeze(0).detach().cpu().numpy()
mels = v[4].unsqueeze(0)
wavs = vocoder_infer(mels, vocoder, lengths=None)
    
wavfile.write(os.path.join("./", "{}".format("c.wav")), hparams.sampling_rate, wavs[0])

enc_outputs:before-:  tensor([[[-0.0109,  0.0335, -0.7462,  ..., -0.3758,  0.0136, -0.0476],
         [-0.0145, -0.1048,  0.1548,  ..., -0.3636,  0.0388, -0.0189],
         [ 0.0066,  0.2483,  0.6379,  ..., -0.3526,  0.0311, -0.0394],
         ...,
         [-0.0240, -0.8161, -0.3820,  ..., -0.3446,  0.0325, -0.0781],
         [ 0.1019, -0.0201,  0.0670,  ..., -0.3558,  0.0062, -0.0278],
         [-0.1175, -0.2773, -0.1465,  ..., -0.3587,  0.0020, -0.0246]]])
shapes-:  torch.Size([1, 47]) torch.Size([1, 47]) torch.Size([1, 47])
fc:dp:-before  tensor([[0.0715, 0.0284, 0.0496, 0.0861, 0.0577, 0.0371, 0.0000, 0.0606, 0.0334,
         0.0454, 0.0572, 0.0000, 0.0402, 0.0263, 0.0601, 0.0000, 0.0827, 0.0746,
         0.0528, 0.0519, 0.0000, 0.0720, 0.0601, 0.0523, 0.0000, 0.0518, 0.0000,
         0.0747, 0.0549, 0.0731, 0.0000, 0.0358, 0.0706, 0.0000, 0.0470, 0.0626,
         0.0688, 0.0841, 0.0399, 0.0387, 0.0414, 0.0517, 0.0616, 0.0568, 0.0726,
         0.0000, 0.0000]]) torch.Size([1, 47])

In [34]:
# speaker_ids = [0]*len(sentences)
# refs = [ref_parameters]
# batch_size = 1
# # generate mel-specs and synthesize audios with Griffin-Lim
# batch_predictions = generate_mel_specs(model, phonemeized_sents, speaker_ids, refs,
#                    hparams, dur_factors, energy_factors, pitch_factors, batch_size, filenames)


# # duration_pred, duration_int, energy_pred, pitch_pred    
# control_values_updated = {}
# v = batch_predictions["a.wav"]

# control_values_updated["d"] = v[0].unsqueeze(0).detach().cpu().numpy()
# control_values_updated["e"] = v[2].unsqueeze(0).detach().cpu().numpy()
# control_values_updated["p"] = v[3].unsqueeze(0).detach().cpu().numpy()
# mels = v[4].unsqueeze(0)
# wavs = vocoder_infer(mels, vocoder, lengths=None)
    
# wavfile.write(os.path.join("./", "{}".format("b.wav")), hparams.sampling_rate, wavs[0])

In [35]:
# control_values_updated["d"]

In [36]:
# control_values_updated["e"]

In [37]:
# control_values_updated["p"]

In [None]:
### Initial

In [160]:
len([['HH', 'EH1', 'R', 'IY0', 'AH0', 'T'], 
  ' ', ['S', 'M', 'AY1', 'L', 'D'], ' ', 
  ['AH0', 'G', 'EY1', 'N'], ',', 
  ['AE1', 'N', 'D'], ' ', ['HH', 'ER0'], ' ', 
  ['S', 'M', 'AY1', 'L', 'Z'], ' ', ['G', 'R', 'UW1'], 
  ' ', ['S', 'T', 'R', 'AO1', 'NG', 'ER0'], '.', '~'])

17

In [161]:
len(['harriet', ' ', 'smiled', ' ', 'again', ',', ',', ' ', 'and', ' ', 'her', ' ', 'smiles', ' ', 'grew', '.', '~'])


17

In [162]:
len(['harriet', 'smiled', 'again', ',', 'and', 'her', 'smiles', 'grew', 'stronger', '.'])

10