## Representation Mixing 
The representation mixing method is described in detail in this [arXiv paper](https://arxiv.org/abs/1811.07240), and sample comparisons against baselines can be found [here](https://s3.amazonaws.com/representation-mixing-site/index.html).

This notebook shows a simple example of inference using a model trained with representation mixing, and the specifics of choosing what type of input and input mask to feed the model.

Models were trained on the [LJSpeech dataset](https://keithito.com/LJ-Speech-Dataset/). Special thanks to Ryuichi Yamamoto for inspiration to make a Colab notebook demo based on the [Tacotron 2 + WaveNet example](https://colab.research.google.com/github/r9y9/Colaboratory/blob/master/Tacotron2_and_WaveNet_text_to_speech_demo.ipynb) as part of the [blog post](https://r9y9.github.io/blog/2018/05/20/tacotron2/), and also for the pretrained conditional WaveNet used in the paper.

**Total runtime for two minimal sentences, fast output:** ~3 minutes


**Total runtime for two minimal sentences, WaveGlow output:** ~6 minutes

## Setup

In [0]:
# Author: Kyle Kastner
# License: BSD 3-Clause

import os
from os.path import exists, join, expanduser
import IPython
from IPython.display import Audio
import matplotlib.pyplot as plt
plt.style.use('classic')


os.chdir(os.path.expanduser("~"))

representation_mixing_dir = "representation_mixing"
if not os.path.exists(representation_mixing_dir):
  ! git clone https://github.com/kastnerkyle/$representation_mixing_dir


In [0]:
# Install dependencies - use the default TF for now, but if we get version errors 1.6.0 is 
# what I used to train everything
#! pip uninstall tensorflow-gpu
#! pip uninstall tensorflow
#! pip install -q --upgrade "tensorflow<=1.6.0"
! pip install -q --upgrade "unidecode"

In [0]:
os.chdir(representation_mixing_dir)
os.chdir("pretrained")

## Imports and Model Download

Here we set up the necessary imports, then download and unpack the pretrained model if it isn't already present.

In [0]:
import tensorflow as tf
import numpy as np
from collections import namedtuple
import sys
import matplotlib.pyplot as plt
import copy
import time
from audio import soundsc
from audio import stft
from audio import iterate_invert_spectrogram
from transform_text import transform_text
from transform_text import inverse_transform_text
from scipy.io import wavfile


dl_link = "https://www.dropbox.com/s/domlnxdk7wetqr3/representation_mixing_model.tar.gz?raw=1"
model_path = "model_final_blended_maskfix_dirtyattnandhidconnect_softplusstep_1scale_pt925drop/models/model-512000"

if not os.path.exists(model_path.split(os.sep)[0]):
    print("Downloading model from {}".format(dl_link))
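    # urlretrieve lives in different modules on Python 2 and Python 3
    try:
        from urllib.request import urlretrieve  # Python 3
    except ImportError:
        from urllib import urlretrieve  # Python 2
    urlretrieve(dl_link, "representation_mixing_model.tar.gz")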
    import tarfile
    tar = tarfile.open("representation_mixing_model.tar.gz", "r:gz")
    tar.extractall()
    tar.close()


## Hyperparameters and Configs

Now we set up the general configs for the pretrained model.

In [0]:
random_state = np.random.RandomState(1999)
# change this if you change the runtime type
# CPU sampling is MUCH faster than GPU here, at least in my tests
#config = tf.ConfigProto(
#    device_count={'GPU': 0}
#)
config = None

# The following hyperparameters shouldn't be changed, they match pretrained model
# and data settings.
batch_size = 64
seq_len = 256
window_mixtures = 10
enc_units = 128
dec_units = 512
emb_dim = 15
sample_rate = 22050
window_size = 512
step = 128
n_mel = 80

# These hyperparameters can be modified
# For example try sonify_iter = 1000 for higher quality output
sonify_iter = 100
gl_window = 512
gl_step = 32
gl_iter = 100


# dummy batch sizes from validation iterator in training code
# mels = (256, 64, 80)
# mel_mask = (256, 64)
# text = (145, 64, 1)
# text_mask = (145, 64)
# mask = (145, 64)
# mask_mask = (145, 64)
# reset = (64, 1)
mels = np.zeros((seq_len, batch_size, n_mel))
mel_mask = np.ones_like(mels[..., 0])
text = np.zeros((145, batch_size, 1))
text_mask = np.ones_like(text[..., 0])
mask = np.ones((145, batch_size))
mask_mask = np.ones_like(mask)
reset = np.zeros((batch_size, 1))


# pre-calculated per-dimension mean and std for mel data from the training iterator, read from a file
d = np.load("norm-mean-std-txt-cleanenglish_cleanersenglish_phone_cleaners-logmel-wsz512-wst128-leh125-ueh7800-nmel80.npz")
saved_std = d["std"]
saved_mean = d["mean"]

## Input Text to Synthesize
Enter whatever sentences you like in the cell below (after the `cat << EOS > sample_lines.txt` line but before `EOS`). I put some minimal and boring examples there for demonstration. Sampling takes longer with more sentences, and also with longer sentences. You can enter up to 64 sentences (the batch size).

In [0]:
%%bash
cat << EOS > sample_lines.txt
The cat ate bread.
Cumulonimbus clouds are not dead.
EOS

cat sample_lines.txt

## Representation Mixing for Inference
Now we need to choose the input representation to use for representation mixing.

Lines like "@dh@ah @k@ae@t @ey@t @b@r@eh@d", using the "@" symbol, represent phonemes, and should have an all 1s mask.

Lines like "the cat ate bread." represent text, and should have an all 0s mask.

Lines like "the @k@ae@t ate @b@r@eh@d" are mixed, and would have a mask made of both 0s and 1s.

You can make these manually, or use the ```transform_text``` function with the ```symbol_processing``` argument set to create the sentence and the mask automatically (using cmudict and a custom function for pronunciation).

```"blended_pref"``` symbol processing is denoted **PWCB** in the paper, and means to use CMUDict pronunciations wherever possible, but fall back to characters if no pronunciation is found (for example, "Cumulonimbus" in the demo sentence). Note that the first CMUDict entry found is used for pronunciation, which can lead to issues with words that have more than one pronunciation, like "desert" and "wind".
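
As a rough illustration of the masking convention above, the sketch below tokenizes a line by hand and builds the matching mask. `tokenize_with_mask` is a hypothetical helper, not part of the repository. It assumes that each "@"-prefixed phoneme and each character counts as one symbol, and that spaces get a 0 in the mask; the real ```transform_text``` function may treat spaces and punctuation differently.

```python
def tokenize_with_mask(line):
    """Split a line into symbols plus a 0/1 mask (1 = phoneme, 0 = character)."""
    symbols = []
    mask = []
    words = line.split(" ")
    for i, word in enumerate(words):
        if "@" in word:
            # phoneme word such as "@k@ae@t" -> ["@k", "@ae", "@t"]
            phones = ["@" + p for p in word.split("@")[1:]]
            symbols.extend(phones)
            mask.extend([1] * len(phones))
        else:
            # character word: one symbol per character
            symbols.extend(list(word))
            mask.extend([0] * len(word))
        if i < len(words) - 1:
            # assumption: the space between words is marked as a character (0)
            symbols.append(" ")
            mask.append(0)
    return symbols, mask


symbols, mask = tokenize_with_mask("the @k@ae@t ate @b@r@eh@d")
print(symbols)
print(mask)
```

A fully phonemic line gives a mask of 1s everywhere except the assumed 0s at spaces, and a plain text line gives all 0s.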

In [0]:
# Uncomment these lines if you want to see the char and phone symbols
#from symbols import char_symbols
#print(char_symbols)
#from symbols import phone_symbols
#print(["@" + p for p in phone_symbols])

In [0]:
#symbol_processing = "chars_only"
#symbol_processing = "phones_only"
symbol_processing = "blended_pref"

pre_lines = []
post_lines = []
int_lines = []
masks = []

with open("sample_lines.txt", "r") as f:
    lines = f.readlines()
lines = [l.strip() for l in lines]
print("lines to sample, format 'before : after transformation'")
for l in lines:
    if "." in l:
        # The training set rarely has multi-sentence inputs
        # thus having a period in the middle causes some issues
        # a simple fix is to replace periods in the middle of the input with
        # commas instead
        lm1 = l[:-1].replace(".", ",")
        l = lm1 + l[-1]
    pre_lines.append(l)
    tt, mm = transform_text(l, symbol_processing=symbol_processing)
    ot = inverse_transform_text(tt, mm)
    # cut off eos
    int_lines.append(tt[:-1])
    masks.append(mm[:-1])
    post_lines.append(ot[:-1])
    print("{} : {}".format(l, ot[:-1]))
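
The period handling in the loop above can be checked in isolation. The sketch below wraps the same two lines in a function; `soften_inner_periods` is an illustrative name, not something defined in the repository.

```python
def soften_inner_periods(line):
    """Replace every period except a trailing one with a comma."""
    if "." in line:
        # keep the last character as-is, swap periods everywhere before it
        return line[:-1].replace(".", ",") + line[-1]
    return line


print(soften_inner_periods("The cat ate bread."))
print(soften_inner_periods("The cat ate bread. Then it slept."))
```

A single-sentence line passes through unchanged, while a two-sentence line keeps only its final period.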

## Sample from the attention model

With the inputs and masks formatted, it is time to actually sample from the model. This cell samples until the attention center reaches the end of the sequence for every input sentence, then runs a little longer (~0.2 seconds) to ensure the end is not cut off.
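
To make the padding at the end concrete, here is the arithmetic the sampling cell uses, run on its own with the hyperparameter values set earlier in this notebook. The name `pad_seconds` is introduced here only for readability.

```python
sample_rate = 22050  # audio samples per second
window_size = 512    # STFT window length in samples
step = 128           # STFT hop in samples, one mel frame per hop
pad_seconds = 0.2    # target amount of extra audio

# the last analysis window already covers part of the padding
min_part = window_size / float(sample_rate)
extra_steps = max(0, int((sample_rate * (pad_seconds - min_part)) / float(step)))

seconds_per_frame = step / float(sample_rate)
print(extra_steps)                      # 30 with these settings
print(extra_steps * seconds_per_frame)  # roughly 0.174 seconds of extra frames
```

So sampling continues for 30 extra frames once every sentence has reached its last symbol.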

This should be fairly quick (~1 minute) if you are using the demo inputs. Runtime will scale with the length of the longest sentence since we sample a batch all together.
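
The reason runtime follows the longest sentence is that all sentences are packed into one zero-padded, time-major array with a companion mask. A minimal sketch of that packing, using made-up integer sequences and a hypothetical `pad_time_major` helper:

```python
import numpy as np


def pad_time_major(seqs, batch_size):
    """Pack variable-length sequences into (time, batch, 1) plus a (time, batch) mask."""
    longest = max(len(s) for s in seqs)
    data = np.zeros((longest, batch_size, 1))
    valid = np.zeros((longest, batch_size))
    for n, s in enumerate(seqs):
        data[:len(s), n, 0] = s
        valid[:len(s), n] = 1.
    return data, valid


# two toy "sentences" in a batch with room for four
data, valid = pad_time_major([[3, 1, 4], [1, 5]], batch_size=4)
print(data.shape)
print(valid.T)
```

Unused batch slots stay all zero in both arrays, which matches how the sampling cell fills only the first `n_to_sample` columns.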

In [0]:
print_every = 20

if len(int_lines) > batch_size:
    raise ValueError("More lines to test ({}) than batch_size ({}), needs multi-batch support!".format(len(int_lines), batch_size))
longest = max([len(il) for il in int_lines])
text = np.zeros((longest, batch_size, 1))
text_mask = np.zeros((longest, batch_size))
mask = np.zeros((longest, batch_size, 1))
mask_mask = np.zeros((longest, batch_size))
for n, il in enumerate(int_lines):
    text[:len(il), n, 0] = il
    text_mask[:len(il), n] = 1.
    mask[:len(il), n, 0] = masks[n]
    mask_mask[:len(il), n] = 1.
n_to_sample = len(int_lines)

start_time = time.time()
with tf.Session(config=config) as sess:
    saver = tf.train.import_meta_graph(model_path + '.meta')
    saver.restore(sess, model_path)
    fields = ["mels",
              "mel_mask",
              "in_mels",
              "in_mel_mask",
              "out_mels",
              "out_mel_mask",
              "text",
              "text_mask",
              "mask",
              "mask_mask",
              "bias",
              "cell_dropout",
              "prenet_dropout",
              "bn_flag",
              "pred",
              "att_w_init",
              "att_k_init",
              "att_h_init",
              "att_c_init",
              "h1_init",
              "c1_init",
              "h2_init",
              "c2_init",
              "att_w",
              "att_k",
              "att_phi",
              "att_h",
              "att_c",
              "h1",
              "c1",
              "h2",
              "c2"]
    vs = namedtuple('Params', fields)(
        *[tf.get_collection(name)[0] for name in fields]
    )
    att_w_init = np.zeros((batch_size, 2 * enc_units))
    att_k_init = np.zeros((batch_size, window_mixtures))
    att_h_init = np.zeros((batch_size, dec_units))
    att_c_init = np.zeros((batch_size, dec_units))
    h1_init = np.zeros((batch_size, dec_units))
    c1_init = np.zeros((batch_size, dec_units))
    h2_init = np.zeros((batch_size, dec_units))
    c2_init = np.zeros((batch_size, dec_units))

    # zero out the mel storage to ensure no leakage
    in_mels = 0. * mels[:1]
    in_mel_mask = 0. * mel_mask[:1] + 1.

    preds = []
    att_ws = []
    att_phis = []
    is_finished_sampling = [False] * n_to_sample
    finished_at = 100000000
    finished_step = [-1] * n_to_sample

    ii = 0

    # add ~.2 sec to the end to ensure it doesn't cut off early
    # sample_rate * .2 / fft_step
    # min_part to account for the last window
    min_part = window_size / float(sample_rate)
    extra_steps = max(0, int((sample_rate * (.2 - min_part)) / float(step)))
    keep_printing = True
    while True:
        if ii % print_every == 0 and keep_printing:
            print("pred step {}".format(ii))
        feed = {
                vs.in_mels: in_mels,
                vs.in_mel_mask: in_mel_mask,
                vs.bn_flag: 1.,
                vs.text: text,
                vs.text_mask: text_mask,
                vs.mask: mask,
                vs.mask_mask: mask_mask,
                vs.cell_dropout: 1.,
                vs.att_w_init: att_w_init,
                vs.att_k_init: att_k_init,
                vs.att_h_init: att_h_init,
                vs.att_c_init: att_c_init,
                vs.h1_init: h1_init,
                vs.c1_init: c1_init,
                vs.h2_init: h2_init,
                vs.c2_init: c2_init}
        outs = [vs.att_w, vs.att_k,
                vs.att_h, vs.att_c,
                vs.h1, vs.c1, vs.h2, vs.c2,
                vs.att_phi, vs.pred]
        r = sess.run(outs, feed_dict=feed)
        att_w_np = r[0]
        att_k_np = r[1]
        att_h_np = r[2]
        att_c_np = r[3]
        h1_np = r[4]
        c1_np = r[5]
        h2_np = r[6]
        c2_np = r[7]
        att_phi_np = r[8]
        pred_np = r[9]

        ii += 1
        max_text = max([text_mask[:, mbi].sum() for mbi in range(n_to_sample)])
        if ii > 30 * max_text:
            # it's gone too far, kill
            finished_step = [int(30 * max_text)] * n_to_sample
            print("Exceeded 30 * max text length of {},  terminating...".format(max_text))
            break

        att_ws.append(att_w_np[0])
        att_phis.append(att_phi_np[0])
        preds.append(pred_np[0])

        # set next inits and input values
        in_mels[0] = pred_np
        att_w_init = att_w_np[-1]
        att_k_init = att_k_np[-1]
        att_h_init = att_h_np[-1]
        att_c_init = att_c_np[-1]
        h1_init = h1_np[-1]
        c1_init = c1_np[-1]
        h2_init = h2_np[-1]
        c2_init = c2_np[-1]

        for mbi in range(n_to_sample):
            last_sym = int(text_mask[:, mbi].sum()) - 1
            if np.argmax(att_phi_np[0, mbi]) >= last_sym or np.argmax(att_phi_np[0, mbi]) == text_mask.shape[0]:
                if is_finished_sampling[mbi] == False:
                    is_finished_sampling[mbi] = True
                    finished_step[mbi] = ii

        if all(is_finished_sampling) and keep_printing:
            keep_printing = False
            print("All samples finished at step {}".format(finished_at))
            print("Extra padding {} finishing at {}".format(extra_steps, ii + extra_steps))
        elif keep_printing:
            # should assign until all are finished
            finished_at = ii

        if ii > (finished_at + extra_steps):
            print("Extra padding {} finished at step {}".format(extra_steps, ii))
            break
end_time = time.time()
preds = np.array(preds)
att_ws = np.array(att_ws)
att_phis = np.array(att_phis)
print("Total time spent in RNN sampling {}".format(end_time - start_time))

## Visualizing the Output
What does the output from the model look like? We can plot both the attention, and the mel output to see what is going on.

First, we define a convenience plotter function.

In [0]:
def implot(arr, axarr, axis_off=True, scale=None, title="", interpolation=None,
           cmap=None, autoaspect=True):
    mag = arr
    # Transpose so time is X axis, and invert y axis so
    # frequency is low at bottom
    mag = mag.T
    if interpolation is None:
        pltr = axarr.matshow
    else:
        pltr = axarr.imshow
    if cmap is not None:
        pltr(mag, cmap=cmap, origin="lower", interpolation=interpolation)
    else:
        pltr(mag, origin="lower", interpolation=interpolation)
    if axis_off:
        plt.axis("off")
    
    if autoaspect:
        x1 = mag.shape[0]
        y1 = mag.shape[1]
        if scale == "specgram":
            y1 = int(y1 * .20)

        def _autoaspect(x_range, y_range):
            """
            The aspect to make a plot square with ax.set_aspect in Matplotlib
            """
            mx = max(x_range, y_range)
            mn = min(x_range, y_range)
            if x_range <= y_range:
                return mx / float(mn)
            else:
                return mn / float(mx)
        asp = _autoaspect(x1, y1)
        axarr.set_aspect(asp)
    plt.title(title)

Now we can plot the attention and the mel spectrogram. Set ```sentence``` to choose which example (from your defined sentences) is viewed.


In [0]:
sentence = 1

format_lines = []
for pl in post_lines:
    parts = pl.split(" ")
    chunks = [["@" + pi for pi in p.split("@")[1:]] if "@" in p else p for p in parts]
    fill = [" "] * len(chunks)
    out = []
    for ii in range(len(chunks)):
        out.append(chunks[ii])
        out.append([fill[ii]])
    # cut trailing space
    out = out[:-1]
    out_flat = [item for sublist in out for item in sublist]
    format_lines.append(out_flat)

spectrogram = preds[:, sentence] * saved_std + saved_mean
f, axarr = plt.subplots(1, 1)
implot(spectrogram[:finished_step[sentence]], axarr, axis_off=False,
       scale="specgram", cmap="viridis")
axarr.set_ylabel("Mel-Freq Bin")
axarr.set_xlabel("Time (sample frames)")
plt.figure()

phi_i = att_phis[:, sentence]
f, axarr = plt.subplots(1, 1)
implot(phi_i[:finished_step[sentence], :len(masks[sentence])], axarr,
       axis_off=False, cmap="gray")
axarr.set_ylabel("Symbols")
axarr.set_xlabel("Time (sample frames)")
axarr.yaxis.set_major_locator(plt.FixedLocator(range(len(masks[sentence]))))
ticks = axarr.set_yticklabels([c for c in format_lines[sentence]])
plt.show()

## Audio Utilities
This cell defines two convenience functions to approximately convert the mel spectrogram into a waveform.

In [0]:
def logmel(waveform):
    z = tf.contrib.signal.stft(waveform, window_size, step)
    magnitudes = tf.abs(z)
    filterbank = tf.contrib.signal.linear_to_mel_weight_matrix(
        num_mel_bins=n_mel,
        num_spectrogram_bins=magnitudes.shape[-1].value,
        sample_rate=sample_rate,
        lower_edge_hertz=125.,
        upper_edge_hertz=7800.)
    melspectrogram = tf.tensordot(magnitudes, filterbank, 1)
    return tf.log1p(melspectrogram)

def sonify(spectrogram, samples, transform_op_fn, logscaled=True):
    # Very special thanks to Carl Thome for showing this technique
    # https://twitter.com/carlthome/status/1002187555700396035
    # All credit to him, any bugs are my own
    graph = tf.Graph()
    with graph.as_default():

        noise = tf.Variable(tf.random_normal([samples], stddev=1e-6))

        x = transform_op_fn(noise)
        y = spectrogram

        if logscaled:
            x = tf.expm1(x)
            y = tf.expm1(y)

        # tf.nn.normalize arguments changed between versions...
        def normalize(a):
            return a / tf.sqrt(tf.maximum(tf.reduce_sum(a ** 2, axis=0), 1E-12))

        x = normalize(x)
        y = normalize(y)
        tf.losses.mean_squared_error(x, y[-tf.shape(x)[0]:])

        optimizer = tf.contrib.opt.ScipyOptimizerInterface(
            loss=tf.losses.get_total_loss(),
            var_list=[noise],
            tol=1e-16,
            method='L-BFGS-B',
            options={
                'maxiter': sonify_iter,
                'disp': False
            })

    # THIS REALLY SHOULDN'T RUN ON GPU BUT SEEMS TO?
    config = tf.ConfigProto(
        device_count={'CPU' : 1, 'GPU' : 0},
        allow_soft_placement=True,
        log_device_placement=False
        )
    with tf.Session(config=config, graph=graph) as session:
        session.run(tf.global_variables_initializer())
        optimizer.minimize(session)
        waveform = session.run(noise)
    return waveform
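
For intuition about what `sonify` minimizes, here is a NumPy sketch of the same objective evaluated on fixed arrays rather than optimized over a waveform. `normalized_spectral_mse` is an illustrative name, and the random inputs are stand-ins for real log-mel spectrograms of shape (frames, mel bins).

```python
import numpy as np


def normalized_spectral_mse(x_log, y_log):
    """Mean squared error between per-bin normalized linear-scale spectrograms."""
    # undo the log1p compression, as sonify does when logscaled=True
    x = np.expm1(x_log)
    y = np.expm1(y_log)

    def normalize(a):
        # unit L2 norm along the time axis, separately for each mel bin
        return a / np.sqrt(np.maximum(np.sum(a ** 2, axis=0), 1e-12))

    x = normalize(x)
    y = normalize(y)
    # compare against the last frames of the target, matching the length of x
    y = y[-x.shape[0]:]
    return np.mean((x - y) ** 2)


rng = np.random.RandomState(0)
target = np.log1p(rng.rand(20, 80))
louder = np.log1p(3.0 * np.expm1(target))  # same spectrogram, 3x the magnitude
other = np.log1p(rng.rand(20, 80))

print(normalized_spectral_mse(target, target))
print(normalized_spectral_mse(louder, target))
print(normalized_spectral_mse(other, target))
```

Because of the normalization, a uniformly louder copy of the target scores essentially the same as the target itself, so the objective cares about spectral shape over time rather than overall level.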

## Two-stage L-BFGS + GL Waveforms

With the utility code defined above, we can run the two-stage sampling pipeline. 

This will take ~1 minute for the demo sentences, and will take longer for more sentences. In theory this could be parallelized over multiple cores or machines, but the simple version here is easier to work with.
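
The second stage (GL) is Griffin-Lim phase reconstruction. The notebook uses the repository's own `stft` and `iterate_invert_spectrogram` for this; the sketch below is not that code, but a minimal stand-in built on `scipy.signal.stft` and `scipy.signal.istft`, run on a synthetic tone. The function names, the 8 kHz rate, and the window settings are all assumptions chosen to keep the example small.

```python
import numpy as np
from scipy.signal import stft, istft


def griffin_lim(magnitude, nperseg, noverlap, n_iter, seed=0):
    """Estimate a waveform whose STFT magnitude approximates `magnitude`."""
    rng = np.random.RandomState(seed)
    # start from random phase
    angles = np.exp(2j * np.pi * rng.rand(*magnitude.shape))
    for _ in range(n_iter):
        # invert with the current phase, then re-analyze and keep only the new phase
        _, x = istft(magnitude * angles, nperseg=nperseg, noverlap=noverlap)
        _, _, z = stft(x, nperseg=nperseg, noverlap=noverlap)
        angles = np.exp(1j * np.angle(z))
    _, x = istft(magnitude * angles, nperseg=nperseg, noverlap=noverlap)
    return x


def magnitude_error(x, magnitude, nperseg, noverlap):
    """Distance between the STFT magnitude of x and the target magnitude."""
    _, _, z = stft(x, nperseg=nperseg, noverlap=noverlap)
    return np.linalg.norm(np.abs(z) - magnitude)


# synthetic test signal: a 440 Hz tone, length chosen as a multiple of the hop (64)
sr = 8000
t = np.arange(4096) / float(sr)
tone = np.sin(2 * np.pi * 440 * t)
_, _, z = stft(tone, nperseg=256, noverlap=192)
target = np.abs(z)

rough = griffin_lim(target, 256, 192, n_iter=0)
refined = griffin_lim(target, 256, 192, n_iter=30)
print(magnitude_error(rough, target, 256, 192))
print(magnitude_error(refined, target, 256, 192))
```

More iterations bring the magnitude of the reconstructed signal closer to the target, which is the role `gl_iter` plays in the cell below.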

In [0]:
pre_time = 0
post_time = 0
joint_time = 0
for jj in range(n_to_sample):
    # use extra steps from earlier
    pjj = preds[:(finished_step[jj] + extra_steps), jj]
    spectrogram = pjj * saved_std + saved_mean
    mel_dump = "sample_{}_mels.npz".format(jj)
    np.savez(mel_dump, mels=spectrogram)
    prename = "sample_{}_pre.wav".format(jj)

    this_time = time.time()

    reconstructed_waveform = sonify(spectrogram, len(spectrogram) * step, logmel)
    end_this_time = time.time()
    wavfile.write(prename, sample_rate, soundsc(reconstructed_waveform))
    elapsed = end_this_time - this_time

    print("Elapsed pre sampling time {} s".format(elapsed))
    pre_time += elapsed
    joint_time += elapsed

    fftsize = gl_window
    substep = gl_step
    postname = "sample_{}_post.wav".format(jj)
    this_time = time.time()

    rw_s = np.abs(stft(soundsc(reconstructed_waveform).astype("float64"), fftsize=fftsize, step=substep, real=False,
                       compute_onesided=False))
    rw = iterate_invert_spectrogram(rw_s, fftsize, substep, n_iter=gl_iter, verbose=False)
    end_this_time = time.time()
    wavfile.write(postname, sample_rate, soundsc(rw))
    elapsed = end_this_time - this_time

    print("Elapsed post sampling time {} s".format(elapsed))
    post_time += elapsed
    joint_time += elapsed
print("Combined pre time {} s".format(pre_time))
print("Combined post time {} s".format(post_time))
print("Combined joint time {} s".format(joint_time))

## L-BFGS + GL Audio Samples

Now we can listen to the samples, after the two-stage pipeline. Not bad! But perhaps we can do better...

In [0]:
with open("sample_lines.txt", "r") as f:
    lines = f.readlines()
lines = [l.strip() for l in lines]

def sort(files):
    return sorted(files, key=lambda k: int(k.split("_")[1]))
    
mel_files = sort([f for f in os.listdir(".") if "_mels.npz" in f])
audio_files = sort([f for f in os.listdir(".") if "_post.wav" in f])          
maps = zip(lines, mel_files[:len(lines)], audio_files[:len(lines)])

for idx, (text, mel, audio) in enumerate(maps):
  print(idx, text)
  IPython.display.display(Audio(audio, rate=22050))

## BONUS: WaveGlow Sampling

After the paper on representation mixing was finalized, NVIDIA released [WaveGlow](https://github.com/NVIDIA/waveglow). [FloWaveNet](https://github.com/ksw0306/FloWaveNet) (a very similar idea) was released alongside it.

These are both *much* faster ways to create high quality audio than a standard autoregressive conditional WaveNet, so we demonstrate output with WaveGlow here instead.

However, the quality is currently not quite as good as WaveNet, perhaps because WaveGlow is more sensitive to the quality and scale of the input conditioning and wasn't trained on the outputs of my model. Specifically, there seems to be some kind of distortion after WaveGlow sampling when going directly from the waveforms heard above.

This will need more work, but take this as a simple demonstration of using WaveGlow at least.

Ultimately, the choice of "neural vocoder" is a design choice, which is largely independent of the frontend - a frontend trained with representation mixing should be compatible with nearly any "neural vocoder" or DSP inversion routine.

In [0]:
if not os.path.exists("waveglow"):
    !git clone https://github.com/NVIDIA/waveglow.git
    os.chdir("waveglow")
    !git submodule init
    !git submodule update
    os.chdir("..")

In [0]:
! pip3 install librosa
# this will be 1.0 once there is a pip-able version
! pip3 install torch_nightly -f https://download.pytorch.org/whl/nightly/cu92/torch_nightly.html


In [0]:
waveglow_model = "https://drive.google.com/file/d/1cjKPHbtAMh_4HTHmuIGNkbOkPBD9qwhj/view?usp=sharing"
waveglow_model_id = waveglow_model.split("//")[1].split("/")[3]
waveglow_mels = "https://drive.google.com/file/d/1g_VXK2lpP9J25dQFhQwx7doWl_p20fXA/view?usp=sharing"
waveglow_mels_id = waveglow_mels.split("//")[1].split("/")[3]

In [0]:
# from https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive/39225039#39225039
import requests

def download_file_from_google_drive(id, destination):
    def get_confirm_token(response):
        for key, value in response.cookies.items():
            if key.startswith('download_warning'):
                return value

        return None

    def save_response_content(response, destination):
        CHUNK_SIZE = 32768

        with open(destination, "wb") as f:
            for chunk in response.iter_content(CHUNK_SIZE):
                if chunk: # filter out keep-alive new chunks
                    f.write(chunk)

    URL = "https://docs.google.com/uc?export=download"

    session = requests.Session()

    response = session.get(URL, params = { 'id' : id }, stream = True)
    token = get_confirm_token(response)

    if token:
        params = { 'id' : id, 'confirm' : token }
        response = session.get(URL, params = params, stream = True)

    save_response_content(response, destination)    

In [0]:
os.chdir("waveglow")
if not os.path.exists("waveglow_old.pt"):
    download_file_from_google_drive(waveglow_model_id, "waveglow_old.pt")
os.chdir("..")

In [0]:
os.chdir("waveglow")
# test example, from the NVIDIA repo
if not os.path.exists("mel_spectrograms.zip"):
    download_file_from_google_drive(waveglow_mels_id, "mel_spectrograms.zip")
    ! unzip mel_spectrograms.zip
os.chdir("..")

## Sampling with WaveGlow
Do a bit of name cleanup, and move the mel spectrograms into the ``samples_temp`` directory.

There are many warnings about `nn.functional.tanh is deprecated. Use torch.tanh instead.`, but for now we can ignore them.

In [0]:
os.chdir("waveglow")
# bug? mel2samp.py requires train_files.txt to exist, so create an empty one
!touch train_files.txt
!rm *.wav
!python3 mel2samp.py -f <(ls ../sample*.wav) -o . -c config.json
!mkdir -p samples_temp
!rm samples_temp/*.pt
!for qq in *.wav.pt; do echo $(basename "$qq" .wav.pt); mv "$qq" samples_temp/$(basename "$qq" .wav.pt).pt; done
!python3 inference.py -f <(ls samples_temp/*.pt) -w waveglow_old.pt -o . --is_fp16 -s 0.6
os.chdir("..")

## WaveGlow Audio Samples
With this complete, we can listen to the samples after WaveGlow synthesis. There is some distortion, which may be due to the pretrained WaveGlow not being trained on this output directly.

In [0]:
print("WaveGlow samples")
with open("sample_lines.txt", "r") as f:
    lines = f.readlines()
lines = [l.strip() for l in lines]

def sort(files):
    return sorted(files, key=lambda k: int(k.split("_")[1]))
    
audio_files = sort(["waveglow/" + f for f in os.listdir("waveglow") if "post" in f])          
maps = zip(lines, audio_files[:len(lines)])

for idx, (text, audio) in enumerate(maps):
  print(idx, text)
  IPython.display.display(Audio(audio, rate=22050))

If you made it this far, congratulations! Have fun with the notebook, and thanks for reading.

kk