In [1]:
import json
import re
import subprocess
import torch
from pathlib import Path

from scipy.io.wavfile import write as wav_write
from tqdm.notebook import tqdm

from src.models.hifi_gan.models import Generator, load_model as load_hifi
from src.train_config import TrainParams, load_config
from src.preprocessing.text.cleaners import english_cleaners

In [2]:
# Load the experiment configuration from its YAML file; later cells read
# `device`, `checkpoint_name`, `train_hifi.model_param` and `n_mels` from it.
config = load_config("configs/fastspeech_base_no_vp/fastspeech2_gst_no_vp_tune.yml")

In [3]:
# Device taken from the config; used below as `map_location` for `torch.load`
# and as the target of every `.to(device)` call.
device = config.device

In [4]:
# Root directory of this experiment's checkpoints; the "fastspeech2" and "hifi"
# sub-directories are read from it in later cells.
checkpoint_path = Path(f"checkpoints/{config.checkpoint_name}")

In [6]:
# Collect the HiFi-GAN generator checkpoints stored with this experiment:
# every file below "<checkpoint>/hifi" whose name starts with "g_".
# To use the pretrained vocoder instead, search Path(config.pretrained_hifi)
# with rglob("*") and the same "g_" prefix filter.
hifi_checkpoint_dir = checkpoint_path / "hifi"
generators = [
    candidate
    for candidate in hifi_checkpoint_dir.rglob("*.*")
    if candidate.name.startswith("g_")
]
generators

[PosixPath('checkpoints/fastspeech2_no_vp_tune/hifi/g_2669999.pkl')]

In [7]:
# Arguments for the `mfa g2p` command issued by `text_to_file`:
# the G2P model archive, and the file the command writes its output to
# (also the default input of `parse_g2p`).
G2P_MODEL_PATH = "models/en/g2p/english_g2p.zip"
G2P_OUTPUT_PATH = "predictions/to_g2p.txt"

In [7]:
def text_to_file(user_query: str) -> None:
    """Normalize ``user_query`` and run ``mfa g2p`` on the result.

    The query is passed through ``english_cleaners``, reduced to its runs of
    ASCII letters joined by single spaces, and written to ``tmp.txt`` in the
    current working directory. ``mfa g2p`` is then invoked with
    ``G2P_MODEL_PATH``, that file and ``G2P_OUTPUT_PATH``. The temporary file
    is removed afterwards, even when the command fails.

    Args:
        user_query: Raw text to convert.

    Raises:
        subprocess.CalledProcessError: If ``mfa g2p`` exits with a non-zero
            status (previously the exit code was silently ignored).
        FileNotFoundError: If the ``mfa`` executable cannot be found.
    """
    text_path = Path("tmp.txt")
    # Normalize before touching the filesystem so a cleaner failure leaves
    # no stray temporary file behind.
    normalized_content = english_cleaners(user_query)
    normalized_content = " ".join(re.findall("[a-zA-Z]+", normalized_content))
    try:
        with open(text_path, "w") as fout:
            fout.write(normalized_content)
        subprocess.run(
            ["mfa", "g2p", G2P_MODEL_PATH, str(text_path.absolute()), G2P_OUTPUT_PATH],
            check=True,
        )
    finally:
        # Always clean up, whether or not the G2P call succeeded.
        if text_path.exists():
            text_path.unlink()

In [8]:
# Pronunciation overrides: when a word from the G2P output is a key here, the
# phoneme string below replaces whatever the G2P model produced for it.
default = {"he": "HH IY1", "she": "SH IY1", "we": "W IY1", "be": "B IY0", "the": "DH AH0", "whenever": "W EH0 N EH1 V ER0", "year": "AH0 Y IH1 R"}

def parse_g2p(PHONEMES_TO_IDS, g2p_path: str = G2P_OUTPUT_PATH):
    """Convert a G2P output file into a flat list of phoneme ids.

    Each non-empty line must look like ``word<TAB>PH1 PH2 ...``. Words found
    in ``default`` use the override pronunciation instead of the one in the
    file. The resulting id sequence is wrapped, at both ends, with the id
    stored under the empty-string key of ``PHONEMES_TO_IDS``.

    Args:
        PHONEMES_TO_IDS: Mapping from phoneme symbol to integer id; must
            contain the ``""`` key.
        g2p_path: Path of the file to parse.

    Returns:
        list: Phoneme ids for the whole file, in file order.

    Raises:
        KeyError: If a phoneme (or ``""``) is missing from ``PHONEMES_TO_IDS``.
        ValueError: If a non-empty line has no tab separator.
    """
    with open(g2p_path, "r") as fin:
        boundary_id = PHONEMES_TO_IDS[""]
        phonemes_ids = [boundary_id]
        for line in fin:
            line = line.rstrip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no word; the
                # tuple unpacking below would otherwise raise on them.
                continue
            word, word_to_phones = line.split("\t", 1)
            word_to_phones = default.get(word, word_to_phones)
            phonemes_ids.extend(
                PHONEMES_TO_IDS[ph] for ph in word_to_phones.split(" ")
            )
        phonemes_ids.append(boundary_id)
    return phonemes_ids

In [9]:
# Evaluation sentences (30 entries). They are not referenced by the visible
# synthesis cells, which read the pre-transcribed `huawei_phones` list instead;
# the two lists appear to hold the same sentences in the same order.
texts = [
    'Do you realize what time it is?',
    'He comes back to the valley.',
    'This dress does not look worth much!',
    'What happened tonight has nothing to do with Henry.',
    'Today, five years later, we are facing a similar situation.',
    'When I saw you kissing, you looked really happy.',
    'Only one vehicle may be allowed to park at any given time.',
    'The deadlines are indeed very tight.',
    "I'm glad you enjoyed yourself.",
    'What are you still doing here?',
    'This is an animal that is admired for its whiteness and cleanliness.  ',
    'Perhaps there is another way to pose these issues.',
    "Your students' test scores drop lower and lower every year.",
    'Wherever her tears fell, a fruit tree grew.',
    'I was about to head back to my hotel and go to sleep.',
    'You said she really helped last time.',
    'My favorite season, spring, is here.',
    "He's the rich guy who built the airplanes.",
    'Otto and Elizabeth gave it to us, for the wedding - incredibly generous.',
    'Look, the police said that there was nothing stolen from the house.',
    'And I suppose we can thank your brother for that.',
    "That's a pretty dangerous thing you're doing.",
    'He arrived in Japan for the first time at the age of twenty six.',
    'Sam thought we were having fun being together.',
    "Well, the true value of something isn't always determined by its price.",
    "No, it's not polite to discuss a lady's age.",
    "Just another quarter-mile and I don't have to be tolerant ever again.",
    "But Jones' apartment had only been rented out for a week.",
    'What your perfect day would have been like?',
    'Not a very useful skill, especially when the money runs out.',
]

In [8]:
# Symbol substitutions from the "huawei" phoneme set to MFA symbols.
# NOTE(review): this mapping is not referenced by any visible cell, and none of
# the strings below contain 'AX1' or 'UX1' — confirm whether it is still needed.
huawei_phon_to_mfa_phon_ = {
    'AX1': 'AO1',
    'UX1': 'UW1'
}

# One space-separated phoneme string per evaluation sentence (30 entries).
# These are tokenized with `str.split()` in `to_phones`, so the leading,
# trailing and doubled spaces do not produce extra tokens.
huawei_phones = [
    ' D UW1 Y UW1 R IY1 AH0 L AY2 Z W AH1 T T AY1 M IH1 T IH1 Z  ',
    ' HH IY1 K AH1 M Z B AE1 K T UW1 DH AH0 V AE1 L IY0  ',
    ' DH IH1 S D R EH1 S D AH1 Z N AA1 T L UH1 K W ER1 TH M AH1 CH  ',
    ' W AH1 T HH AE1 P AH0 N D T AH0 N AY1 T HH AE1 Z N AH1 TH IH0 NG T UW1 D UW1 W IH1 DH HH EH1 N R IY0  ',
    ' T AH0 D EY1  F AY1 V Y IH1 R Z L EY1 T ER0  W IY1 AA1 R F EY1 S IH0 NG AH0 S IH1 M AH0 L ER0 S IH2 CH UW0 EY1 SH AH0 N  ',
    ' W EH1 N AY1 S AO1 Y UW1 K IH1 S IH0 NG  Y UW1 L UH1 K T R IH1 L IY0 HH AE1 P IY0  ',
    ' OW1 N L IY0 W AH1 N V IY1 HH IH0 K AH0 L M EY1 B IY1 AH0 L AW1 D T UW1 P AA1 R K AE1 T EH1 N IY0 G IH1 V AH0 N T AY1 M  ',
    ' DH AH0 D EH1 D L AY2 N Z AA1 R IH2 N D IY1 D V EH1 R IY0 T AY1 T  ',
    ' AY1 EH1 M G L AE1 D Y UW1 EH2 N JH OY1 D Y ER0 S EH1 L F  ',
    ' W AH1 T AA1 R Y UW1 S T IH1 L D UW1 IH0 NG HH IY1 R  ',
    ' DH IH1 S IH1 Z AE1 N AE1 N AH0 M AH0 L DH AE1 T IH1 Z AH0 D M AY1 ER0 D F AO1 R IH1 T S W AY1 T N AH0 S AH0 N D K L EH1 N L IY0 N IH0 S  ',
    ' P ER0 HH AE1 P S DH EH1 R IH1 Z AH0 N AH1 DH ER0 W EY1 T UW1 P OW1 Z DH IY1 Z IH1 SH UW0 Z  ',
    ' Y AO1 R S T UW1 D AH0 N T S T EH1 S T S K AO1 R Z D R AA1 P L OW1 ER0 AH0 N D L OW1 ER0 EH1 V ER0 IY0 Y IH1 R  ',
    ' W EH0 R EH1 V ER0 HH ER1 T IH1 R Z F EH1 L  AH0 F R UW1 T T R IY1 G R UW1  ',
    ' AY1 W AA1 Z AH0 B AW1 T T UW1 HH EH1 D B AE1 K T UW1 M AY1 HH OW0 T EH1 L AH0 N D G OW1 T UW1 S L IY1 P  ',
    ' Y UW1 S EH1 D SH IY1 R IH1 L IY0 HH EH1 L P T L AE1 S T T AY1 M  ',
    ' M AY1 F EY1 V ER0 IH0 T S IY1 Z AH0 N  S P R IH1 NG  IH1 Z HH IY1 R  ',
    ' HH IY1 EH1 S DH AH0 R IH1 CH G AY1 HH UW1 B IH1 L T DH IY0 EH1 R P L EY0 N Z  ',
    ' AA1 T OW2 AH0 N D IH0 L IH1 Z AH0 B AH0 TH G EY1 V IH1 T T UW1 AH1 S  F AO1 R DH AH0 W EH1 D IH0 NG  IH2 N K R EH1 D AH0 B L IY0 JH EH1 N ER0 AH0 S  ',
    ' L UH1 K  DH AH0 P AH0 L IY1 S S EH1 D DH AE1 T DH EH1 R W AA1 Z N AH1 TH IH0 NG S T OW1 L AH0 N F R AH1 M DH AH0 HH AW1 S  ',
    ' AH0 N D AY1 S AH0 P OW1 Z W IY1 K AE1 N TH AE1 NG K Y AO1 R B R AH1 DH ER0 F AO1 R DH AE1 T  ',
    ' DH AE1 T EH1 S EY0 P R IH1 T IY0 D EY1 N JH ER0 AH0 S TH IH1 NG Y UW1 R EY1 D UW1 IH0 NG  ',
    ' HH IY1 ER0 AY1 V D IH0 N JH AH0 P AE1 N F AO1 R DH AH0 F ER1 S T T AY1 M AE1 T DH IY0 EY1 JH AH1 V T W EH1 N T IY0 S IH1 K S  ',
    ' S AE1 M TH AO1 T W IY1 W ER1 HH AE1 V IH0 NG F AH1 N B IY1 IH0 NG T AH0 G EH1 DH ER0  ',
    ' W EH1 L  DH AH0 T R UW1 V AE1 L Y UW0 AH1 V S AH1 M TH IH0 NG IH1 S N T IY1 AO1 L W EY2 Z D IH0 T ER1 M AH0 N D B AY1 IH1 T S P R AY1 S  ',
    ' N OW1  IH1 T EH1 S N AA1 T P AH0 L AY1 T T UW1 D IH0 S K AH1 S AH0 L EY1 D IY0 EH1 S EY1 JH  ',
    ' JH AH1 S T AH0 N AH1 DH ER0 K W AO1 R T ER0 M AY1 L AH0 N D AY1 D AA1 N T IY1 HH AE1 V T UW1 B IY1 T AA1 L ER0 AH0 N T EH1 V ER0 AH0 G EH1 N  ',
    ' B AH1 T JH OW1 N Z AH0 P AA1 R T M AH0 N T HH AE1 D OW1 N L IY0 B IH1 N R EH1 N T IH0 D AW1 T F AO1 R AH0 W IY1 K  ',
    ' W AH1 T Y AO1 R P ER1 F IH1 K T D EY1 W UH1 D HH AE1 V B IH1 N L AY1 K  ',
    ' N AA1 T AH0 V EH1 R IY0 Y UW1 S F AH0 L S K IH1 L  AH0 S P EH1 SH L IY0 W EH1 N DH AH0 M AH1 N IY0 R AH1 N Z AW1 T  ',
]

In [9]:
def to_phones(PHONEMES_TO_IDS, phones):
    """Translate a whitespace-separated phoneme string into a list of ids.

    Args:
        PHONEMES_TO_IDS: Mapping from phoneme symbol to integer id.
        phones: Phoneme symbols separated by any amount of whitespace;
            leading and trailing whitespace is ignored.

    Returns:
        list: One id per phoneme symbol, in order of appearance.

    Raises:
        KeyError: If a symbol is not a key of ``PHONEMES_TO_IDS``.
    """
    ids = []
    for symbol in phones.split():
        ids.append(PHONEMES_TO_IDS[symbol])
    return ids

In [11]:
# Read the phoneme vocabulary saved with the FastSpeech2 checkpoint, then turn
# every pre-transcribed sentence into its list of phoneme ids.
with open(checkpoint_path / "fastspeech2"/ "phonemes.json") as f:
    phonemes_to_ids = json.load(f)
phonemes_list = [to_phones(phonemes_to_ids, hp) for hp in huawei_phones]

In [10]:
# Load the serialized FastSpeech2 model object saved with the checkpoint.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code — only load checkpoints from a trusted source.
fastspeech2_model = torch.load(checkpoint_path / "fastspeech2" / "fastspeech2_model.pth", map_location=device)

In [11]:
# Switch the model to evaluation mode before running inference.
fastspeech2_model = fastspeech2_model.eval()

In [21]:
def get_tacotron_batch(
    phonemes_ids, speaker_id, device, ref_mel
):
    """Pack one utterance into the 4-tuple passed to the model's ``inference``.

    Args:
        phonemes_ids: Sequence of integer phoneme ids for one utterance.
        speaker_id: Integer id of the target speaker.
        device: Device the created tensors are moved to.
        ref_mel: Reference mel; returned untouched as the last element.

    Returns:
        tuple: ``(phonemes, num_phonemes, speaker_ids, reference_mel)`` where
        ``phonemes`` is a ``(1, len(phonemes_ids))`` long tensor,
        ``num_phonemes`` a one-element long tensor holding the length, and
        ``speaker_ids`` a one-element long tensor holding ``speaker_id``.
    """

    def _long_on_device(values):
        # Integer (int64) tensor placed on the requested device.
        return torch.LongTensor(values).to(device)

    num_phonemes = _long_on_device([len(phonemes_ids)])
    # Add the leading batch dimension: a batch of exactly one utterance.
    phonemes = _long_on_device(phonemes_ids).unsqueeze(0)
    speaker_ids = _long_on_device([speaker_id])
    return phonemes, num_phonemes, speaker_ids, ref_mel

In [19]:
# Directory searched recursively for reference "*.pkl" files by the synthesis
# loop; each file's parent directory name is used there as the speaker key.
reference_pathes = Path("references/english")

In [14]:
# Output root for synthesized audio; the synthesis loop writes to
# <generated_path>/<generator stem>/<speaker>/<reference stem>/<n>.wav.
generated_path = Path(f"generated_hifi/{config.checkpoint_name}")

In [15]:
# Speaker-name -> id mapping saved with the FastSpeech2 checkpoint; indexed in
# the synthesis loop by the reference file's parent directory name.
with open(checkpoint_path / "fastspeech2"/ "speakers.json") as f:
    speaker_to_id = json.load(f)

In [16]:
# Mel statistics saved with the checkpoint, cast to float32. The synthesis loop
# uses them to de-normalize the model output: mels * mels_std + mels_mean.
mels_mean = torch.load(checkpoint_path / "fastspeech2" / "mels_mean.pth", map_location=device).float()
mels_std = torch.load(checkpoint_path / "fastspeech2" / "mels_std.pth", map_location=device).float()

In [28]:
SAMPLE_RATE = 22050    # Hz, written into the header of every generated WAV
MAX_WAV_VALUE = 32768  # scale from the vocoder's float waveform to int16

# Build every vocoder exactly once. Reading the checkpoint from disk and
# constructing the generator is independent of the reference and the sentence,
# so doing it inside the synthesis loops repeated the same work for every file.
loaded_generators = []
for generator_path in generators:
    # NOTE(review): torch.load unpickles the file — trusted checkpoints only.
    state_dict = torch.load(generator_path, map_location="cpu")
    state_dict["generator"] = {k: v.to(device) for k, v in state_dict["generator"].items()}
    generator = Generator(config=config.train_hifi.model_param, num_mels=config.n_mels).to(device)
    generator.load_state_dict(state_dict["generator"])
    generator.remove_weight_norm()
    generator.eval()
    loaded_generators.append((generator_path.stem, generator))

# For every reference file and every sentence: predict a mel with FastSpeech2,
# de-normalize it, vocode it with each generator and save a 16-bit WAV to
# <generated_path>/<generator stem>/<speaker>/<reference stem>/<n>.wav.
for reference in tqdm(list(reference_pathes.rglob("*.pkl"))):
    # The parent directory name is the speaker key into speakers.json.
    speaker = reference.parent.name
    speaker_id = speaker_to_id[speaker]
    ref_mel = torch.load(reference, map_location=device)
    for i, phonemes in enumerate(phonemes_list):
        # batch layout: phonemes, num_phonemes, speaker_ids, reference_mel
        batch = get_tacotron_batch(phonemes, speaker_id, device, ref_mel)
        with torch.no_grad():
            output = fastspeech2_model.inference(batch)
            mels = output[1].permute(0, 2, 1).squeeze(0)
            mels = mels * mels_std.to(device) + mels_mean.to(device)
            x = mels.unsqueeze(0)
            for generator_name, generator in loaded_generators:
                y_g_hat = generator(x)
                audio = y_g_hat.squeeze() * MAX_WAV_VALUE
                # Clamp before the cast: a full-scale sample (1.0 * 32768)
                # does not fit in int16 and would otherwise overflow.
                audio = torch.clamp(audio, min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE - 1)
                audio = audio.type(torch.int16).detach().cpu().numpy()
                save_path = generated_path / generator_name / speaker / reference.stem
                save_path.mkdir(exist_ok=True, parents=True)
                # Files are numbered from 1 in sentence order.
                wav_write(save_path / f"{i + 1}.wav", SAMPLE_RATE, audio)
    # Release cached GPU memory once per reference rather than once per file.
    torch.cuda.empty_cache()


  0%|          | 0/45 [00:00<?, ?it/s]

generated_hifi/fastspeech2_no_vp/g_02500000/0016/neutral
generated_hifi/fastspeech2_no_vp/g_02500000/0016/happy
generated_hifi/fastspeech2_no_vp/g_02500000/0016/sad
generated_hifi/fastspeech2_no_vp/g_02500000/0016/angry
generated_hifi/fastspeech2_no_vp/g_02500000/0016/surprise
generated_hifi/fastspeech2_no_vp/g_02500000/0012/neutral
generated_hifi/fastspeech2_no_vp/g_02500000/0012/happy
generated_hifi/fastspeech2_no_vp/g_02500000/0012/sad
generated_hifi/fastspeech2_no_vp/g_02500000/0012/angry
generated_hifi/fastspeech2_no_vp/g_02500000/0012/surprise
generated_hifi/fastspeech2_no_vp/g_02500000/0018/neutral
generated_hifi/fastspeech2_no_vp/g_02500000/0018/happy
generated_hifi/fastspeech2_no_vp/g_02500000/0018/sad
generated_hifi/fastspeech2_no_vp/g_02500000/0018/angry
generated_hifi/fastspeech2_no_vp/g_02500000/0018/surprise
generated_hifi/fastspeech2_no_vp/g_02500000/0020/neutral
generated_hifi/fastspeech2_no_vp/g_02500000/0020/happy
generated_hifi/fastspeech2_no_vp/g_02500000/0020/sad
g

In [25]:
# Ad-hoc shape inspection of the last `mels` value left in the kernel.
# NOTE(review): the recorded 3-D output does not look like it came from the
# synthesis loop as written (there `mels` is the result of `.squeeze(0)`);
# presumably it was produced by an earlier version — re-run or remove.
mels[1].shape

torch.Size([1, 790, 80])