In [None]:
import torch
import torchaudio
import matplotlib.pyplot as plt
import IPython

# Seed PyTorch's global RNG so the stochastic parts of inference are repeatable.
torch.random.manual_seed(0)

# Pick the compute device. torch device strings are lowercase: the previous
# fallback value 'CPU' is rejected by `.to(device)` and `map_location=device`,
# so the notebook could not run on a machine without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Record the environment alongside the outputs.
print(f'Using {device}')
print(torch.__version__)
print(torchaudio.__version__)

## Text Processing
### Phoneme-based encoding

In [None]:
# Pretrained TTS pipeline: text processor -> Tacotron2 (spectrogram) -> vocoder.
bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
processor = bundle.get_text_processor()
tacotron2 = bundle.get_tacotron2().to(device)
vocoder = bundle.get_vocoder().to(device)

# Shorter alternative inputs, handy for quick experiments:
# text = "Hello World, this is the new age of A.I. and Python!"
# text = "Hello"
text = """Pride and Prejudice is set in rural England at the turn of the 19th century, and it follows the Bennet family, which includes five very different sisters. The eldest, Jane, is sweet-tempered and modest. She is her sister Elizabeth’s confidant and friend. Elizabeth, the heroine of the novel, is intelligent and high-spirited. She shares her father’s distaste for the conventional views of society as to the importance of wealth and rank. The third daughter, Mary, is plain, bookish, and pompous, while Lydia and Kitty, the two youngest, are flighty and immature."""

with torch.inference_mode():
    # Encode the text and move the encoded batch to the model's device.
    processed, lengths = processor(text)
    processed = processed.to(device)
    lengths = lengths.to(device)

    # Text encoding -> spectrogram.
    spec, spec_lengths, _ = tacotron2.infer(processed, lengths)

    # Spectrogram -> waveform. The plotting cell that follows reads
    # `waveforms`, so this call must run for the notebook to work on a fresh
    # kernel. The second result gets its own name so it does not overwrite
    # the input `lengths`.
    # NOTE(review): this call used to be commented out, presumably because it
    # is slow on long text — switch to one of the shorter inputs if needed.
    waveforms, waveform_lengths = vocoder(spec, spec_lengths)

# Show the encoded input (not the processor object) and its length.
print(processed, lengths)

In [None]:
def plot(wfms, spec, sample_rate):
    """Draw a waveform and a spectrogram, and return an audio player.

    The first entry of ``wfms`` is plotted as a line in the top panel and the
    first entry of ``spec`` is shown as an image in the bottom panel.

    Args:
        wfms: Tensor of waveforms; only index 0 is plotted and played.
            (Assumes a leading batch dimension — TODO confirm.)
        spec: Tensor of spectrograms; only index 0 is shown.
        sample_rate: Playback rate handed to the audio widget.

    Returns:
        An ``IPython.display.Audio`` object for the first waveform.
    """
    # Both tensors must live on the CPU, detached from autograd, before
    # matplotlib can consume them.
    audio = wfms.detach().cpu()
    first_spec = spec[0].detach().cpu()

    _, (wave_ax, spec_ax) = plt.subplots(nrows=2, ncols=1)

    # Top panel: raw samples, x-axis clipped to the signal length.
    wave_ax.plot(audio[0])
    wave_ax.set_xlim(0, audio.shape[-1])
    wave_ax.grid(True)

    # Bottom panel: spectrogram with the first row drawn at the bottom.
    spec_ax.imshow(first_spec, origin="lower", aspect="auto")

    # Slice (rather than index) so the array handed to Audio stays 2-D.
    return IPython.display.Audio(audio[:1], rate=sample_rate)

In [None]:
# Plot and play the synthesized audio; relies on `waveforms`, `spec` and
# `vocoder` having been defined by earlier cells.
plot(waveforms, spec, vocoder.sample_rate)

In [None]:
# Alternative vocoder: NVIDIA WaveGlow, fetched through torch.hub.
#
# Workaround to load model mapped on GPU
# https://stackoverflow.com/a/61840832
# The architecture is built without weights (pretrained=False) and the
# checkpoint is downloaded separately so `map_location` can place the tensors
# on the selected device.
waveglow = torch.hub.load(
    "NVIDIA/DeepLearningExamples:torchhub",
    "nvidia_waveglow",
    # Matches the fp32 checkpoint loaded below (was "fp16").
    model_math="fp32",
    pretrained=False,
)
checkpoint = torch.hub.load_state_dict_from_url(
    "https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth",  # noqa: E501
    progress=False,
    map_location=device,
)

# Drop the "module." prefix from the checkpoint's parameter names so the keys
# line up with the freshly constructed model.
# NOTE(review): presumably the prefix comes from a DataParallel-style wrapper
# at training time — confirm.
state_dict = {key.replace("module.", ""): value for key, value in checkpoint["state_dict"].items()}

waveglow.load_state_dict(state_dict)
waveglow = waveglow.remove_weightnorm(waveglow)
waveglow = waveglow.to(device)
waveglow.eval()

# Vocode the Tacotron2 spectrogram. This rebinds `waveforms`, replacing any
# value an earlier cell assigned to that name.
with torch.no_grad():
    waveforms = waveglow.infer(spec)

In [None]:
# Plot and play the WaveGlow result now bound to `waveforms`.
# NOTE(review): the rate passed here is the bundle vocoder's sample rate, not a
# WaveGlow-specific value — confirm the two match.
plot(waveforms, spec, vocoder.sample_rate)

In [None]:
# `plot` has three required positional parameters, so the bare `plot()` call
# that was here always raised TypeError. Pass the same inputs as the previous
# cell so the notebook runs to the end.
# NOTE(review): the intended arguments are a guess — adjust or delete this cell.
plot(waveforms, spec, vocoder.sample_rate)