# Model inference

Create a new Python 3 (PyTorch 1.4 Python 3.6 GPU Optimized) kernel for this notebook on Amazon SageMaker Studio.

#### Import libraries

In [None]:
import sys

sys.path.append('tacotron2')
sys.path.append('tacotron2/waveglow')

import numpy as np
import torch

%matplotlib inline
from matplotlib import pylab as plt

import IPython.display as ipd

from hparams import create_hparams
from text import text_to_sequence
from denoiser import Denoiser
from train import load_model

#### Define utilities

In [None]:
def plot_data(data, figsize = (16, 4)):
    """Draw every image in `data` side by side in a single row of subplots.

    Args:
        data: Non-empty sequence of array-like images accepted by `imshow`;
            one subplot is created per element.
        figsize: Figure size forwarded to `plt.subplots`.
    """
    # squeeze = False keeps `axes` two-dimensional even when there is only one
    # panel, so the row can always be iterated (a bare Axes is not indexable).
    _, axes = plt.subplots(1, len(data), figsize = figsize, squeeze = False)

    for axis, image in zip(axes[0], data):
        # 'lower' places row 0 at the bottom of the panel. The previously used
        # value 'bottom' is not a valid `origin` and raises in matplotlib >= 3.3.
        axis.imshow(image, aspect = 'auto', origin = 'lower', interpolation = 'none')

#### Setup parameters

In [None]:
# Hyperparameter object built by create_hparams(); the sampling rate is set
# explicitly here and reused later as the playback rate of the audio widgets.
parameters = create_hparams()
parameters.sampling_rate = 22050

# Checkpoint locations, relative to the notebook's working directory.
tacotron2_checkpoint_path = "tacotron2/outdir/tacotron2_statedict.pt"
waveglow_checkpoint_path = "tacotron2/waveglow/outdir/waveglow_256channels.pt"

#### Load Tacotron2 model

In [None]:
# Create the model from the hyperparameters, then restore the weights stored
# under the checkpoint's 'state_dict' key.
tacotron2_model = load_model(parameters)
tacotron2_model.load_state_dict(torch.load(tacotron2_checkpoint_path)['state_dict'])

# Move to the GPU, switch to evaluation mode and cast to half precision.
# The trailing semicolon keeps the (very long) module repr out of the cell output.
tacotron2_model.cuda().eval().half();

#### Load WaveGlow model

In [None]:
# NOTE: torch.load unpickles the object stored under 'model'; only load
# checkpoints that come from a trusted source.
waveglow_model = torch.load(waveglow_checkpoint_path)['model']
waveglow_model.cuda().eval().half()

# After the model-wide .half(), cast every module in `convinv` back to float32.
for convinv_layer in waveglow_model.convinv:
    convinv_layer.float()

# The denoiser is constructed from the loaded WaveGlow model and used in the
# bias-removal step at the end of the notebook.
denoiser = Denoiser(waveglow_model)

#### Prepare input text

In [None]:
text = "Waveglow is really awesome!"

# Convert the text with the 'english_cleaners' pipeline (presumably to integer
# symbol ids - confirm against text_to_sequence); [None, :] adds a leading axis
# of size 1 so the model receives a single-item batch.
sequence = np.array(text_to_sequence(text, ['english_cleaners']))[None, :]

# torch.autograd.Variable is a deprecated no-op wrapper (tensors and variables
# were merged in PyTorch 0.4), so the tensor is used directly.
sequence = torch.from_numpy(sequence).cuda().long()

#### Decode input text and plot results

In [None]:
# Pure inference: disable autograd so no graph is recorded, matching the
# WaveGlow synthesis step below.
with torch.no_grad():
    mel_outputs, mel_outputs_postnet, _, alignments = tacotron2_model.inference(sequence)

# Cast the half-precision outputs to float32 on the CPU for plotting and take
# the first (only) batch item; the alignment matrix is transposed for display.
plot_data((mel_outputs.detach().float().cpu().numpy()[0],
           mel_outputs_postnet.detach().float().cpu().numpy()[0],
           alignments.detach().float().cpu().numpy()[0].T))

#### Synthesize audio from spectrogram using WaveGlow

In [None]:
# Generate the waveform from the post-net mel spectrogram without tracking gradients.
with torch.no_grad():
    audio = waveglow_model.infer(mel_outputs_postnet, sigma = 0.666)

# Play the first batch item. The hyperparameter object is bound to `parameters`
# in the setup cell; the name `hparams` is never defined in this notebook.
ipd.Audio(audio[0].data.cpu().numpy(), rate = parameters.sampling_rate)

#### Remove WaveGlow bias

In [None]:
# Apply the denoiser built from the WaveGlow model and keep index 0 of the
# second axis of its output.
denoised_audio = denoiser(audio, strength = 0.01)[:, 0]

# Play the denoised result at the sampling rate configured in the setup cell.
ipd.Audio(denoised_audio.cpu().numpy(), rate = parameters.sampling_rate)