In [1]:
import os
import json
import yaml
import sys
import time
import copy
import IPython.display as ipd
import pprint
from pathlib import Path
from tqdm import tqdm


import numpy as np
import torch
import torchaudio
from librosa.filters import mel as librosa_mel_fn
#import matplotlib
#matplotlib.use("Agg")
import matplotlib.pyplot as plt
from scipy.io.wavfile import write


import toybox

In [116]:
def plot_audio(audio, samplerate, title='time-domain waveform'):
    """
    usage:
        # audio is [channel, time(num_frames)] ex.torch.Size([1, 68608])
        # audio[0,:]: list of 1ch audio data
        # audio.shape[1]: int value of 1ch audio data length
        audio, sample_rate = torchaudio.load(str(iwav_path))
        %matplotlib inline
        plot_audio(audio, sample_rate)
    """
    # transform to mono
    channel = 0
    audio = audio[channel,:].view(1,-1)
    # to numpy
    audio = audio.to('cpu').detach().numpy().copy()
    time = np.linspace(0., audio.shape[1]/samplerate, audio.shape[1])
    
    fig, ax = plt.subplots(figsize=(12,9))
    
    ax.plot(time, audio[0, :])
    ax.set_title(title, fontsize=20, y=-0.12)
    ax.tick_params(direction='in')
    #ax.set_xlim(0, 3)
    ax.set_xlabel('Time')
    ax.set_ylabel('Amp')
    #ax.legend()
    plt.tight_layout()
    fig.canvas.draw()
    plt.show()
    #fig.savefig('figure.png')
    plt.close(fig)
    return fig

def plot_mel(tensors:list, titles:list[str]):
    """
    usage:
        mel = mel_process(...)
        fig_mel = plot_mel([mel_groundtruth[0], mel_prediction[0]],
                            ['groundtruth', 'inferenced(model)'])

    """
    xlim = max([t.shape[1] for t in tensors])
    fig, axs = plt.subplots(nrows=len(tensors),
                            ncols=1,
                            figsize=(12, 9),
                            constrained_layout=True)

    if len(tensors) == 1:
        axs = [axs]
    
    for i in range(len(tensors)):
        im = axs[i].imshow(tensors[i],
                           aspect="auto",
                           origin="lower",
                           interpolation='none')
        #plt.colorbar(im, ax=axs[i])
        fig.colorbar(im, ax=axs[i])
        axs[i].set_title(titles[i])
        axs[i].set_xlim([0, xlim])
    fig.canvas.draw()
    #plt.show()
    #plt.close()
    plt.close(fig)  # fig.close() 
    return fig

def convert_phn_to_id(phonemes, phn2id):
    """
    phonemes: phonemes separated by ' '
    phn2id: phn2id dict
    """
    return [phn2id[x] for x in ['<bos>'] + phonemes.split(' ') + ['<eos>']]


def text2phnid(text, phn2id, language='en', add_blank=True):
    if language == 'en':
        from text import G2pEn
        word2phn = G2pEn()
        phonemes = word2phn(text)
        if add_blank:
            phonemes = ' <blank> '.join(phonemes)
        return phonemes, convert_phn_to_id(phonemes, phn2id)
    else:
        raise ValueError(
            'Language should be en (for English)!')

In [154]:
model_index = 3  # 0:gradtts, 1:gradseptts, 2:gradtfktts, 3:gradtfk5tts, 4:gradtimektts, 5:gradfreqktts
ckpt_name = '820_651082.pt'
# for inference
#N_STEP = 100
#TEMP = 1.5

In [155]:
# First, please check changing <model_name>
#ckpt_file_dir: logs4model/<model_name>/<runtime_name>/ckpt/
config_yaml = 'configs/config_exp_mid.yaml'
config = toybox.load_yaml_and_expand_var('configs/config_exp_mid.yaml')

model_info = [
    ['gradtts', 'gt_k3'],
    ['gradseptts', 'sgt_k3'],
    ['gradtfktts', 'tfk_k3'],
    ['gradtfk5tts', 'tfk_k5'],
    ['gradtimektts', 'timek_k3'],
    ['gradfreqktts', 'freqk_k3'],
]
model_name = f'{model_info[model_index][0]}' # gradtts, gradseptts, gradtfktts, gradtfk5tts, gradtimektts, gradfreqktts
runtime_name = f'{model_info[model_index][1]}'

config_path4model = Path(f'./configs/config_{runtime_name}.yaml')
config4model = toybox.load_yaml_and_expand_var(config_path4model)
print(f'config_path4model: {config_path4model}')
print(f'exists: {config_path4model.exists()}')

test_ds_path = Path(config['test_datalist_path'])
if test_ds_path.exists():
    print(f'Exists {str(test_ds_path)}')
    with open(config['test_datalist_path']) as j:
        test_ds_list = json.load(j)
    print(f'loaded {test_ds_path}')
else:
    print(f'No exist {test_ds_path}')

ckpt_dir_path = Path(f'./logs4model/{model_name}/run_{runtime_name}/ckpt')
ckpt_path = ckpt_dir_path / f'{model_name}_{ckpt_name}'
print(f"ckpt_path: {ckpt_path}")
print(f'ckpt_dir_exist :{ckpt_dir_path.exists()}')
print(f'ckpt_path_exist:{ckpt_path.exists()}')

config_path4model: configs/config_tfk_k5.yaml
exists: True
Exists configs/test_dataset.json
loaded configs/test_dataset.json
ckpt_path: logs4model/gradtfk5tts/run_tfk_k5/ckpt/gradtfk5tts_820_651082.pt
ckpt_dir_exist :True
ckpt_path_exist:True


In [156]:
# for audio params
n_mels: int = config['n_mels'] # 80
n_fft: int = config['n_fft'] # 1024
sample_rate: int = config['sample_rate'] # 22050
hop_size: int = config['hop_size'] # 256
win_size: int = config['win_size'] # 1024
f_min: int = config['f_min'] # 0
f_max: int = config['f_max'] # 8000
random_seed: int = config['random_seed'] # 1234
print(n_mels, n_fft, sample_rate, hop_size, win_size, f_min, f_max, random_seed)

# for text analysis
print(f"phn2id_path: {config['phn2id_path']}")
with open(config['phn2id_path']) as f:
    phn2id = json.load(f)

vocab_size = len(phn2id) + 1

80 1024 22050 256 1024 0 8000 1234
phn2id_path: ./configs/phn2id.json


In [157]:
# for hifigan
# setting file paths
# from https://github.com/huawei-noah/Speech-Backbones/tree/main/Grad-TTS/hifi-gan
# https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing
HiFiGAN_CONFIG = './hifigan/official_pretrained/LJ_V2/config.json'
HiFiGAN_ckpt = './hifigan/official_pretrained/LJ_V2/generator_v2'

from hifigan import models, env

with open(HiFiGAN_CONFIG) as f:
    hifigan_hparams = env.AttrDict(json.load(f))

hifigan_randomseed = hifigan_hparams.seed
print(f'hifigan_randomseed: {hifigan_randomseed}')

hifigan_randomseed: 1234


In [158]:
# for cpu device
import os

print(f"all cpu at using device: {os.cpu_count()}")
print(f"Number of available CPU: {len(os.sched_getaffinity(0))}") # Number of available CPUs can also be obtained. ,use systemcall at linux.
#print(f"GPU_name: {torch.cuda.get_device_name()}\nGPU avail: {torch.cuda.is_available()}\n")
DEVICE = 'cpu'
device = torch.device(DEVICE)
print(f'device: {device}')

# setting random_seed ==============
print(f'device: {random_seed}')
toybox.set_seed(random_seed)
print(str(torch.get_default_device()))

all cpu at using device: 52
Number of available CPU: 4
device: cpu
device: 1234
cpu


In [159]:
# load model
from gradtts import GradTTS
from gradseptts import GradSepTTS
from gradtfktts import GradTFKTTS
from gradtfk5tts import GradTFKTTS as GradTFK5TTS
from gradtimektts import GradTimeKTTS
from gradfreqktts import GradFreqKTTS
from gradtfkfultts import GradTFKFULTTS

print(f'model_name: {model_name}')
print("[seq] loading Model")
print(f'model_step: {N_STEP}')
print(f'model_temp: {TEMP}')

print('loading ', ckpt_path)
_, _, state_dict = torch.load(ckpt_path,
                            map_location=device)


print("[seq] Initializing diffusion-TTS...")
if model_name == "gradtts":
    model = GradTTS.build_model(config4model, vocab_size)
elif model_name == "gradseptts":
    model = GradSepTTS.build_model(config4model, vocab_size)
elif model_name == "gradtfktts":
    model = GradTFKTTS.build_model(config4model, vocab_size)
elif model_name == "gradtfk5tts":
    model = GradTFK5TTS.build_model(config4model, vocab_size)
elif model_name == "gradtfkfultts":
    model = GradTFKFULTTS.build_model(config4model, vocab_size)
elif model_name == "gradtimektts":
    model = GradTimeKTTS.build_model(config4model, vocab_size)
elif model_name == "gradfreqktts":
    model = GradFreqKTTS.build_model(config4model, vocab_size)
else:
    raise ValueError(f"Error: '{model_name}' is not supported")

model = model.to(device)
model.load_state_dict(state_dict)
print(f'Number of encoder + duration predictor parameters: {model.encoder.nparams/1e6}m')
print(f'Number of decoder parameters: {model.decoder.nparams/1e6}m')
print(f'Total parameters: {model.nparams/1e6}m')

# generator ===================
print("[seq] loading HiFiGAN")
vocoder = models.Generator(hifigan_hparams)

vocoder.load_state_dict(torch.load(
    HiFiGAN_ckpt, map_location=device)['generator'])
vocoder = vocoder.eval().to(device)
vocoder.remove_weight_norm()

print("loading UTMOS ===================================")
predictor_utmos = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)


model_name: gradtfk5tts
[seq] loading Model
model_step: 100
model_temp: 1.5
loading  logs4model/gradtfk5tts/run_tfk_k5/ckpt/gradtfk5tts_820_651082.pt
[seq] Initializing diffusion-TTS...
Number of encoder + duration predictor parameters: 3.549137m
Number of decoder parameters: 2.044639m
Total parameters: 5.593776m
[seq] loading HiFiGAN
Removing weight norm...


Using cache found in /work/sora-sa/.cache/torch/hub/tarepan_SpeechMOS_v1.2.0


In [166]:
N_STEP = 100
TEMP = 1.5

# temp infer
i = 56
infer_data_num = 5

In [167]:


for i in tqdm(range(infer_data_num)):
    # temp infer
    i = 56
    print(i)
    print('[seq]text2mel')
    text = test_ds_list[i]['text']
    phonemes, phnid = text2phnid(text, phn2id, 'en')
    phonemes_len_int = len(phonemes)
    phnid_len_int = len(phnid)
    print(f'phonemes_len: {phonemes_len_int}')
    print(f'phnid_len: {phnid_len_int}')
    phnid_len = torch.tensor(len(phnid), dtype=torch.long).unsqueeze(0).to(device)
    phnid = torch.tensor(phnid).unsqueeze(0).to(device)

    # [seq] synth speech
    # process text to mel
    # mel is [n_mels, n_frame]
    start_time = time.perf_counter()
    _, mel_prediction, _ = model.forward(phnid,
                                     phnid_len,
                                     n_timesteps=N_STEP,
                                     temperature=TEMP,
                                     solver='original')
    end_time = time.perf_counter()

    dt = end_time - start_time
    dt4mel = dt * 22050 / ( mel_prediction.shape[-1] * 256)
    print(f'{model_name} dt: {dt}')
    print(f'{model_name} RTF: {dt4mel}')
    
    # for save mel
    x = mel_prediction.unsqueeze(0) # [batch, channel(freq), n_frame(time)] ex.[1, 80, 619]
    """
    # save
    #mel_npy_path =  RESULT_MEL_DIR_PATH / f"{test_ds_filename}.npy"
    #print(f'test_ds_index_{i}: {mel_npy_path}')
    np.save(mel_npy_path, mel4save.cpu().detach().numpy().copy())
    """

    # [seq]mel2wav =========================================================
    print('[seq]mel2wav')
    """
    x = np.load(mel_npy_path) # [1, n_mel, n_frame]
    """
    x2audio = torch.FloatTensor(x).to(device)
    x2audio = x2audio.squeeze().unsqueeze(0)
    # x2audio is [1, n_mels, n_frames]
    assert x2audio.shape[0] == 1
    with torch.no_grad():
        # vocoder.forward(x).cpu() is torch.Size([1, 1, 167168])
        audio = (vocoder.forward(x2audio).cpu().squeeze().clamp(-1,1).numpy() * 32768).astype(np.int16)
    """
    write(
        synth_wav_path,
        hifigan_hparams.sampling_rate,
        audio)
    """
    # [seq]wav2utmos =========================================================
    print('[seq]wav2utmos')
    #iwav_path = RESULT_WAV_DIR_PATH / f"{filename}.wav"
    #wav, samplerate = torchaudio.load(iwav_path)
    """
    wav, samplerate = torchaudio.load(synth_wav_path)
    """
    audio = torch.from_numpy(audio).unsqueeze(0).to(torch.float32)
    score_utmos = predictor_utmos(audio, 22050)
    score_utmos_float = score_utmos.item()
    print(f'utmos: {score_utmos_float}')
    #eval_dict = {'name': filename, 'path': str(iwav_path), 'utmos': score_float}
    #score_utmos_list.append(eval_dict)
    
    # path, テキスト文、phonimes, phonimes数, dt, RTF, utmos
    """
    eval_dict = {
        'name': test_ds_filename,
        'phonemes_len': phonemes_len_int,
        'phnid_len': phnid_len_int,
        'dt': dt,
        'RTF4mel': dt4mel,
        'utmos': score_utmos_float
    }
    eval_list.append(eval_dict)
    """

  0%|                                                                                                       | 0/5 [00:00<?, ?it/s]

56
[seq]text2mel
phonemes_len: 1016
phnid_len: 193
gradtfk5tts dt: 26.09882480185479
gradtfk5tts RTF: 3.2252011235703133
[seq]mel2wav
[seq]wav2utmos


 20%|███████████████████                                                                            | 1/5 [00:32<02:09, 32.40s/it]

utmos: 3.861962080001831
56
[seq]text2mel
phonemes_len: 1016
phnid_len: 193
gradtfk5tts dt: 28.3867592420429
gradtfk5tts RTF: 3.4929020161107474
[seq]mel2wav
[seq]wav2utmos


 40%|██████████████████████████████████████                                                         | 2/5 [01:01<01:32, 30.68s/it]

utmos: 3.9090278148651123
56
[seq]text2mel
phonemes_len: 1016
phnid_len: 193
gradtfk5tts dt: 27.371107167564332
gradtfk5tts RTF: 3.2473146578253784
[seq]mel2wav
[seq]wav2utmos


 60%|█████████████████████████████████████████████████████████                                      | 3/5 [01:35<01:03, 31.98s/it]

utmos: 3.5535855293273926
56
[seq]text2mel
phonemes_len: 1016
phnid_len: 193
gradtfk5tts dt: 28.324207940138876
gradtfk5tts RTF: 3.4409643042581
[seq]mel2wav
[seq]wav2utmos


 80%|████████████████████████████████████████████████████████████████████████████                   | 4/5 [02:10<00:33, 33.14s/it]

utmos: 4.15135383605957
56
[seq]text2mel
phonemes_len: 1016
phnid_len: 193
gradtfk5tts dt: 28.249231291934848
gradtfk5tts RTF: 3.4464387282398823
[seq]mel2wav
[seq]wav2utmos


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [02:44<00:00, 32.96s/it]

utmos: 4.078372001647949





In [168]:
print(model_name)
print(N_STEP)

gradtfk5tts
100


In [115]:
#4.23,4.20,4.35,4.24
#3.664,4.26,4.04,3.889
#4.34, 4.31,4.33, 4.163
#3.95, 3.923,4.178, 3.765

In [None]:
#tfk
#N=50:3.8072, 3.899, 3.64, 3.597, 3.504
#N=100:3.697, 4.347, 3.898, 3.750, 4.1905
#N=200:3.938, 4.109, 3.854, 3.869, 3.709
#tfk5
#N=50:3.880,3.738, 3.849, 3.290, 3.657
#N=100:3.861, 3.9090, 3.5535, 4.151, 4.078

In [106]:
219419-214649

4770

In [None]:
30000