## Imports

In [1]:
import torch
from torch.autograd import Variable

import numpy as np
from IPython.display import Audio
import matplotlib.pyplot as plt
%matplotlib inline

import models
from tacotron2.text import text_to_sequence
from common.utils import load_wav_to_torch, to_gpu
from common.layers import TacotronSTFT
from hparams import Hyperparameters as hp

In [2]:
import os

# Optional GPU pinning. The inference cells below build the models with
# `to_cuda=True` and move tensors to CUDA, so the GPU must stay visible: setting
# CUDA_VISIBLE_DEVICES to "" would hide every device and make those calls fail.
# Set e.g. "0" to restrict the notebook to a single GPU, or leave None to keep
# whatever the environment already specifies. It only takes effect if set
# before the first CUDA call in this kernel.
CUDA_DEVICES = None

if CUDA_DEVICES is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_DEVICES

## Paths to checkpoints

In [3]:
# All training runs live under one output root; change it here to point the
# notebook at a different machine or volume.
OUTPUT_ROOT = '/workspace/output'

# Specific training checkpoints to synthesize with.
taco_path = os.path.join(OUTPUT_ROOT, 'new_try_2', 'sm_bl_lj_1st700_anneal', 'checkpoint_Tacotron2_1300')
wg_path = os.path.join(OUTPUT_ROOT, 'sm_wg', 'checkpoint_WaveGlow_1750')

## Load models

In [4]:
# Deserialize both checkpoints with map_location='cpu', so tensors saved from a
# GPU can be loaded regardless of which devices are visible right now.
# torch.load unpickles its input -- only point it at trusted checkpoint files.
taco_checkpoint, wg_checkpoint = [
    torch.load(ckpt_path, map_location='cpu')
    for ckpt_path in (taco_path, wg_path)
]

In [7]:
# Summarize the checkpoint's top-level contents instead of dumping them: the
# optimizer state alone holds per-parameter tensors ('exp_avg', 'exp_avg_sq', ...)
# and floods the notebook with thousands of lines of output.
{key: type(value).__name__ for key, value in taco_checkpoint.items()}

In [None]:
def _build_from_checkpoint(model_name, checkpoint):
    """Instantiate `model_name` from the config stored in `checkpoint`, on CUDA."""
    return models.get_model(model_name, checkpoint['config'], to_cuda=True)


t2 = _build_from_checkpoint('Tacotron2', taco_checkpoint)
wg = _build_from_checkpoint('WaveGlow', wg_checkpoint)

In [None]:
def _strip_wrapper_prefix(key, prefix='module.'):
    """Remove leading `prefix` occurrences from a state-dict key.

    Checkpoints saved from a wrapped model (presumably nn.DataParallel /
    DistributedDataParallel) prefix every parameter name with 'module.'. Only the
    *leading* prefix is removed: str.replace would also mangle names that merely
    contain 'module.' elsewhere (e.g. 'encoder.submodule.weight' ->
    'encoder.subweight'), which then fail to load.
    """
    while key.startswith(prefix):
        key = key[len(prefix):]
    return key


# Load each checkpoint's weights into its freshly built model.
for model, checkpoint in [(t2, taco_checkpoint), (wg, wg_checkpoint)]:
    new_state_dict = {
        _strip_wrapper_prefix(key): value
        for key, value in checkpoint['state_dict'].items()
    }
    model.load_state_dict(new_state_dict)

In [None]:
# Put both networks into evaluation mode before synthesis.
for net in (t2, wg):
    net.eval()
del net
print('Done')

## Set speaker and text

In [None]:
# Sentence to synthesize and the integer index of the speaker to condition on.
text, speaker_id = "hello. how are you today?", 1

## Select inference type (reference audio or GST token)

In [None]:
# Conditioning mode: 'ref' (reference audio) or 'token' (GST token).
# Currently disabled, together with the commented-out reference-audio cell
# further down; both must be re-enabled to use it.
#inf_type = 'ref'

### Reference audio

In [None]:
# Reference recording whose mel spectrogram the (commented-out) 'ref' branch
# further down would compute. Disabled.
#ref_audio = '/workspace/training_data/blizzard_2013/wavs/CA-MP2-03-013.wav'

#### Listen to the reference audio

In [None]:
# Play the reference recording at hp.sampling_rate (disabled along with ref_audio).
#Audio(ref_audio, rate=hp.sampling_rate)

### Or use a GST style token

In [None]:
# Placeholder for a GST style token. NOTE(review): no cell visible in this
# notebook reads style_token (t2.infer is called without it) -- TODO wire it
# into inference or remove it.
style_token = None

In [None]:
# Reference-audio conditioning (currently disabled). When inf_type == 'ref', this
# would load ref_audio, scale it by 1 / hp.max_wav_value, compute its mel
# spectrogram with TacotronSTFT, re-add a leading dimension, and move it with
# to_gpu into ref_mel.
# NOTE(review): ref_mel is never passed to t2.infer in this notebook, so
# re-enabling this cell alone would not change the synthesized audio.
# FIXME: the ValueError message has three '{}' placeholders but only two format
# arguments, so raising it would itself fail with IndexError; presumably
# ref_audio was meant to be the first argument.
# NOTE(review): torch.autograd.Variable is deprecated; a plain tensor suffices.
# if inf_type == 'ref':
#     stft = TacotronSTFT(
#         hp.filter_length, hp.hop_length, hp.win_length,
#         hp.n_mel_channels, hp.sampling_rate, hp.mel_fmin,
#         hp.mel_fmax
#     )

#     audio, sampling_rate = load_wav_to_torch(ref_audio)

#     if sampling_rate != stft.sampling_rate:
#         raise ValueError("{} {} SR doesn't match target {} SR".format(
#             sampling_rate, stft.sampling_rate))

#     audio_norm = audio / hp.max_wav_value
#     audio_norm = audio_norm.unsqueeze(0)
#     audio_norm = Variable(audio_norm, requires_grad=False)
#     ref_mel = stft.mel_spectrogram(audio_norm)
#     ref_mel = torch.squeeze(ref_mel, 0)

#     ref_mel = ref_mel.unsqueeze(0)
    
#     ref_mel = to_gpu(ref_mel)
# elif inf_type == 'token':
#     pass

## Infer

In [None]:
# Encode the text as symbol ids with the 'english_cleaners' normalisation and add
# a leading batch dimension -> shape (1, seq_len).
# Tensors go to the device holding Tacotron2's weights (assumes t2 is a
# torch.nn.Module, as its load_state_dict/eval usage suggests). This keeps the
# cell consistent with how the model was built instead of hard-coding 'cuda'.
device = next(t2.parameters()).device
inputs = np.array(text_to_sequence(text, ['english_cleaners']))[None, :]
inputs = torch.from_numpy(inputs).to(device=device, dtype=torch.int64)
# int() accepts both the original Python int and the one-element tensor left in
# `speaker_id` by a previous run of this cell, so re-executing it is safe.
speaker_id = torch.tensor([int(speaker_id)], dtype=torch.int64, device=device)

In [None]:
# Look up the speaker's embedding via t2.speakers_embedding, for inspection only.
# no_grad avoids building (and holding on to) an autograd graph for it.
with torch.no_grad():
    embedded_speaker = t2.speakers_embedding(speaker_id)

In [None]:
# Display the speaker embedding computed above (inspection only; the synthesis
# cell passes speaker_id to t2.infer, not this tensor).
embedded_speaker

In [None]:
# Two-stage synthesis: Tacotron2 turns the symbol ids (conditioned on speaker_id)
# into a mel spectrogram, and WaveGlow turns that mel into a waveform. Gradients
# are disabled because nothing here is trained.
with torch.no_grad():
    taco_outputs = t2.infer(inputs, speaker_id)
    _, mel, _, _ = taco_outputs
    audio = wg.infer(mel)
del taco_outputs

In [None]:
# Plot the predicted mel spectrogram. Assumes `mel` is (1, n_mel_channels, frames)
# -- TODO confirm. origin='lower' puts the lowest mel channel at the bottom, and
# aspect='auto' keeps a long utterance from being squashed into a thin strip.
fig, ax = plt.subplots(figsize=(10, 4))
image = ax.imshow(
    mel.squeeze(0).detach().cpu().numpy(),
    origin='lower',
    aspect='auto',
    interpolation='none',
)
ax.set(title='Predicted mel spectrogram', xlabel='Frame', ylabel='Mel channel')
fig.colorbar(image, ax=ax)
plt.show()

In [None]:
# Copy the first waveform of the batch to host memory as a NumPy array.
audio_numpy = audio[0].detach().cpu().numpy()
# Use the configured sample rate instead of a hard-coded 22050, so playback
# speed/pitch stays correct if the config changes (and matches the rate used for
# the reference audio above). Assumes WaveGlow was trained at hp.sampling_rate --
# TODO confirm against wg_checkpoint['config'].
rate = hp.sampling_rate

In [None]:
# Render an inline player for the synthesized waveform at `rate` samples/second.
Audio(data=audio_numpy, rate=rate)