# Alignment on LJSpeech Dataset

## Inspection

!du -sh ../..
!df -h

!ls ../..

In [None]:
# Quietly (-q) install pandas, tqdm and ipywidgets through the %pip magic.
%pip install pandas -q
%pip install tqdm -q
%pip install ipywidgets -q

In [None]:
import os
import IPython.display as ipd

import tensorflow as tf

%load_ext autoreload
%autoreload 2

from data_readers.ljspeech_reader import LJSpeechReader  # noqa

# Optional switch, left disabled (presumably hides every GPU to force a CPU run):
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
# Turn on memory growth for each physical GPU that TensorFlow reports.
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
for device in gpu_devices:
    tf.config.experimental.set_memory_growth(device, True)
# Display the detected GPU list as the cell output.
gpu_devices

In [None]:
# Flip to True to download and unpack the LJSpeech-1.1 archive in the
# current working directory (disabled by default).
using_colab = False
if using_colab:
    # Download the archive, extract it, then remove the archive file.
    !wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
    !tar -xf LJSpeech-1.1.tar.bz2 --checkpoint=.5000
    !rm LJSpeech-1.1.tar.bz2

In [None]:
# Relative path handed to the reader (matches the directory name of the
# archive unpacked by the download cell).
ljs_file = r'LJSpeech-1.1'
ljsr = LJSpeechReader(ljs_file)
# The generator is consumed with next() in the following cell, where each
# item is unpacked as (id, text, audio, sampling rate).
gen = ljsr.generate_audios()

In [None]:
# Pull one example and round-trip it through serialize/deserialize; the
# names are rebound to the deserialized values.
cur_id, cur_txt, cur_audio, sr = next(gen)
s = LJSpeechReader.serialize(cur_id, cur_txt, cur_audio, sr)
cur_id, cur_txt, cur_audio, sr = LJSpeechReader.deserialize(s)

In [None]:
# Print the id and text, then render an audio player for column 0 of the
# audio tensor (presumably the first channel) at the example's rate.
print(cur_id, cur_txt)
ipd.Audio(cur_audio[:, 0].numpy(), rate=sr.numpy())

## Write TFRecords file

In [None]:
# Write the TFRecords file only when it is not already on disk, so
# re-running the cell does not regenerate it.
ljstfrecords = 'ljspeech.tfrecords'
if not os.path.isfile(ljstfrecords):
    ljsr.write_tfrecords_file(ljstfrecords)

In [None]:
# Read the TFRecords file back and decode every record with
# LJSpeechReader.deserialize. The path is repeated here as a literal rather
# than taken from ljstfrecords.
dataset = tf.data.TFRecordDataset(
    'ljspeech.tfrecords'
).map(LJSpeechReader.deserialize)

In [None]:
# Fetch the sixth record (skip five, take one) and unpack its four fields.
sample = list(dataset.skip(5).take(1))[0]
cur_id, cur_txt, cur_audio, sr = sample

# Show the metadata and render a player for column 0 of the audio tensor.
print(cur_id, cur_txt, sr.numpy())
ipd.Audio(cur_audio[:, 0].numpy(), rate=sr.numpy())

## Inspect models

In [None]:
from models.alignment_model import PraticantoForcedAligner  # noqa
from models import alignment_losses  # noqa

In [None]:
# Build the aligner from the reader's token vocabulary with
# sampling_rate=22050, then print the resulting model's summary.
pfa = PraticantoForcedAligner(vocab=ljsr.tokens, sampling_rate=22050)
alignment_model = pfa.build_models()
alignment_model.summary()

In [None]:
# Display the shapes of the text and of column 0 of the audio once a leading
# axis of size 1 has been added to each.
tf.expand_dims(cur_txt, axis=0).shape, tf.expand_dims(cur_audio[:, 0], axis=0).shape

In [None]:
# Forward pass on one example: the text is split into bytes, column 0 of the
# audio is used, and both get a leading axis of size 1. Displays the output
# shape.
alignment_model([
    tf.expand_dims(tf.strings.bytes_split(cur_txt), axis=0),
    tf.expand_dims(cur_audio[:, 0], axis=0)
]).shape

In [None]:
# Forward pass on a hand-written sentence paired with sr * 2 zero-valued
# samples (presumably two seconds of silence). Displays the output shape.
sample = 'This is my text and it is quite long'
char_input = tf.expand_dims(tf.strings.bytes_split(sample), axis=0)
audio = tf.zeros((1, sr * 2))
out = alignment_model([char_input, audio])
out.shape

In [None]:
# Hand-made batch of four 5 x 10 matrices used to exercise the alignment loss.
t = tf.convert_to_tensor([
    # Example 0: rows 0-3 each cover one contiguous run of columns, in order;
    # row 4 is all zeros.
    [
        [1, 1, 1,   0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0,   1, 1, 0, 0, 0, 0, 0],
        [0, 0, 0,   0, 0, 1, 1, 0, 0, 0],
        [0, 0, 0,   0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0,   0, 0, 0, 0, 0, 0, 0],
    ],
    # Example 1: rows 1 and 2 cover their column runs in swapped order
    # (columns 5-6 before columns 3-4); row 4 takes the last column.
    [
        [1, 1, 1,   0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0,   0, 0, 1, 1, 0, 0, 0],
        [0, 0, 0,   1, 1, 0, 0, 0, 0, 0],
        [0, 0, 0,   0, 0, 0, 0, 1, 1, 0],
        [0, 0, 0,   0, 0, 0, 0, 0, 0, 1],
    ],
    # Example 2: ordered runs, with column 3 shared 0.2 / 0.8 between
    # rows 0 and 1.
    [
        [1, 1, 1, 0.2, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0.8, 1, 0, 0, 0, 0, 0],
        [0, 0, 0,   0, 0, 1, 1, 0, 0, 0],
        [0, 0, 0,   0, 0, 0, 0, 1, 1, 0],
        [0, 0, 0,   0, 0, 0, 0, 0, 0, 1],
    ],
    # Example 3: ordered runs plus extra non-zero entries in row 4 and in
    # columns 8-9, all of which fall outside its [4, 8] entry below.
    [
        [1, 1, 1,   0, 0, 0, 0, 0, 0, 0.0],
        [0, 0, 0,   1, 1, 0, 0, 0, 0, 0.0],
        [0, 0, 0,   0, 0, 1, 1, 0, 1, 0.0],
        [0, 0, 0,   0, 0, 0, 0, 1, 0, 0.1],
        [1, 1, 0,   0, 1, 0, 0, 0, 0, 0.1],
    ],
])
# One [rows, columns] pair per example (presumably the unpadded character
# and spectrogram lengths — TODO confirm against alignment_losses).
unpadded_shapes = [
    [4, 10],
    [5, 10],
    [5, 10],
    [4, 8],
]
# Build the loss with its default arguments and evaluate it with the shapes
# as the first argument and the matrices as the second.
a_loss = alignment_losses.alignment_loss()
a_loss(unpadded_shapes, t)

## Training

In [None]:
def prep_inputs(cur_id, cur_txt, cur_audio, sr):
    """Turn one decoded record into (characters, waveform, lengths).

    Args:
        cur_id: Record identifier; not used here.
        cur_txt: Transcript; its shape is pinned to a scalar before it is
            split into single bytes.
        cur_audio: Audio tensor; only column 0 is kept (presumably shaped
            (samples, channels) — TODO confirm against the reader).
        sr: Sampling rate; not used here.

    Returns:
        Tuple ``(chars, waveform, shapes)``: the byte-split transcript,
        ``cur_audio[:, 0]``, and the concatenation of the character count
        with ``1 + (num_samples - pfa.frame_length) // pfa.frame_step``.
    """
    # No [BOS]/[EOS] markers are added around the character sequence.
    chars = tf.strings.bytes_split(tf.ensure_shape(cur_txt, ()))
    waveform = cur_audio[:, 0]

    # Presumably the number of frames the model derives from the waveform —
    # TODO confirm against the aligner's framing parameters.
    n_frames = 1 + (tf.shape(waveform) - pfa.frame_length) // pfa.frame_step
    shapes = tf.concat([tf.shape(chars), n_frames], axis=0)
    return chars, waveform, shapes


def prep_batch_inputs(cur_txt, cur_audio, seq_lengths):
    """Repackage a padded batch as ``(features, seq_lengths)``.

    Args:
        cur_txt: Batched character sequences, stored under ``'char_seq'``.
        cur_audio: Batched waveforms, stored under ``'waveform'``.
        seq_lengths: Passed through unchanged as the second tuple element.

    Returns:
        A two-element tuple: the feature dictionary and ``seq_lengths``.
    """
    features = dict(char_seq=cur_txt, waveform=cur_audio)
    return features, seq_lengths

In [None]:
# Look up the '[PAD]' token through the aligner's char_table and display
# the result.
pad_index = pfa.char_table('[PAD]')
pad_index

In [None]:
batch_size = 32
# Training pipeline: shuffle the raw records (buffer of 6 * batch_size),
# repeat indefinitely, decode, derive per-example inputs, pad each batch,
# repackage as (features, lengths) and prefetch.
dataset = tf.data.TFRecordDataset(
    'ljspeech.tfrecords'
).shuffle(6 * batch_size).repeat().map(LJSpeechReader.deserialize).map(
    prep_inputs
).padded_batch(
    # Earlier variant with fixed padded shapes, kept for reference:
    # batch_size, padding_values=(pad_index, 0.0), padded_shapes=(200, 400000)
    # Pad characters with '[PAD]', waveforms with 0.0 and lengths with 0.
    batch_size, padding_values=('[PAD]', 0.0, 0), drop_remainder=True
).map(prep_batch_inputs).prefetch(tf.data.AUTOTUNE)

In [None]:
# Materialise a single training batch and display the padded shapes of its
# inputs together with the per-example lengths.
sample = list(dataset.take(1))
(
    sample[0][0]['char_seq'].shape,
    sample[0][0]['waveform'].shape,
    str(sample[0][1]),
)

In [None]:
# Join the characters of the first batch element back into one string.
tf.strings.join(sample[0][0]['char_seq'][0])

In [None]:
# Output shape of the aligner's MelSpectrogram for 219293 zero samples.
pfa.MelSpectrogram(tf.zeros((219293,))).shape

In [None]:
# Scratch arithmetic; evaluates to 35.9609375.
(1023 * 10 - 1024) / 256

In [None]:
def create_mask(unpadded_shape, padded_shape):
    """Build a matrix of ones of ``unpadded_shape`` zero-padded to ``padded_shape``.

    The ones occupy the top-left corner; padding is appended after the end
    of each of the two axes only. For example, an unpadded shape of (4, 6)
    inside a padded shape of (6, 10) gives

        1 1 1 1 1 1 0 0 0 0
        1 1 1 1 1 1 0 0 0 0
        1 1 1 1 1 1 0 0 0 0
        1 1 1 1 1 1 0 0 0 0
        0 0 0 0 0 0 0 0 0 0
        0 0 0 0 0 0 0 0 0 0

    Args:
        unpadded_shape: Two-element shape of the block of ones.
        padded_shape: Two-element shape of the returned mask.

    Returns:
        The padded mask tensor.
    """
    mask = tf.ones(unpadded_shape)
    trailing = padded_shape - unpadded_shape
    leading = tf.zeros((2,), dtype=tf.int32)
    # One (before, after) pair per axis: nothing before, the difference after.
    paddings = tf.stack([leading, trailing], axis=1)
    return tf.pad(mask, paddings)


# Manual walk-through of the padding steps used in create_mask, with
# concrete shapes: s1 is the unpadded shape and s2 the padded one.
s1 = tf.convert_to_tensor([72, 391])
s2 = tf.convert_to_tensor([166, 792])
v1 = tf.ones(s1)

# One (before, after) pair per axis: zero before, s2 - s1 after.
pad_shape = tf.stack([
    tf.zeros((2,), dtype=tf.int32),
    s2-s1
], axis=1)

v2 = tf.pad(v1, pad_shape)

In [None]:
# Compare the shapes before and after padding.
v1.shape, v2.shape

In [None]:
# Display the paddings tensor itself: column 0 is zeros, column 1 is s2 - s1.
tf.stack([
    tf.zeros((2,), dtype=tf.int32),
    s2-s1
], axis=1)

In [None]:
# !ls checkpoints -l

In [None]:
# cp checkpoints/m_13_0.414.chkpt*.* .

In [None]:
# Instantiate one loss object per entry of alignment_losses.possible_losses.
model_losses = [
    alignment_losses.alignment_loss(x)
    for x in alignment_losses.possible_losses
]
print(model_losses)

# The first entry is the training loss; the remaining ones are only tracked
# as metrics.
alignment_model.compile(
    optimizer=tf.keras.optimizers.Adam(
        learning_rate=1e-3, clipnorm=0.1, beta_1=0.8, beta_2=0.99, epsilon=0.1),
    loss=model_losses[0],
    metrics=model_losses[1:],
)
# Resume from a specific checkpoint; this hard-coded file must exist under
# checkpoints/ for the call to succeed.
alignment_model.load_weights('checkpoints/m_44_0.397.chkpt')

In [None]:
# NOTE(review): this removes the whole checkpoints directory, including the
# checkpoints/m_44_0.397.chkpt file that load_weights reads in this
# notebook — confirm that is intended before re-running.
!rm -rf checkpoints
os.makedirs('checkpoints', exist_ok=True)
# File names embed the epoch number and the loss formatted to three decimals.
filepath = 'checkpoints/m_{epoch}_{loss:.3f}.chkpt'
# Monitor the training loss, checking once per epoch; only the weights are
# written, and only for the best value seen so far.
chkpt_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='loss', verbose=1, save_best_only=True,
    save_weights_only=True, mode='auto', save_freq='epoch',
)

def scheduler(epoch, lr, constant_lr=1e-4):
    """Learning-rate schedule for ``tf.keras.callbacks.LearningRateScheduler``.

    With the default ``constant_lr`` the function returns 1e-4 for every
    epoch, so calling it as ``scheduler(epoch, lr)`` always yields that
    fixed rate. Passing ``constant_lr=None`` selects the warm-up schedule
    instead of leaving it as unreachable code:

    * epochs 0 and 1: 2e-6
    * epochs 2 and 4: ten times the current rate
    * every other epoch: the current rate unchanged

    Args:
        epoch: Zero-based epoch index.
        lr: Learning rate currently in use.
        constant_lr: Fixed rate to return for every epoch, or ``None`` to
            use the warm-up schedule.

    Returns:
        The learning rate to use for ``epoch``.
    """
    if constant_lr is not None:
        return constant_lr
    if epoch <= 1:
        return 2e-6
    if epoch in (2, 4):
        return lr * 10
    return lr
# Wrap scheduler in a Keras LearningRateScheduler callback (verbose=1).
lr_callback = tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1)

# Multiply the learning rate by 0.2 after 10 epochs without a loss
# improvement of at least 1e-4, never going below 1e-7.
# NOTE(review): reduce_callback is not part of the callbacks list passed to
# alignment_model.fit in this notebook — confirm whether that is intended.
reduce_callback = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='loss', factor=0.2, patience=10, verbose=1,
    mode='auto', min_delta=0.0001, cooldown=0, min_lr=1e-7
)

In [None]:
# The training dataset repeats indefinitely, so steps_per_epoch sets the
# epoch length: one batch for every batch_size entries of ljsr.df_audios.
alignment_model.fit(
    dataset,
    epochs=300,
    steps_per_epoch=len(ljsr.df_audios) // batch_size,
    callbacks=[lr_callback, chkpt_callback]
)

## Visual evaluation of results

In [None]:
import models
from matplotlib import pyplot as plt
%matplotlib inline

# Stand-alone spectrogram and mel-spectrogram models, used only for
# plotting in the cells of this notebook.
m_spec = models.alignment_model.get_spectrogram_model()
m_logmel = models.alignment_model.get_melspec_model()

In [None]:
# alignment_model.load_weights('checkpoints/m_1_33.717620849609375.chkpt')
# samples[0][1]

In [None]:
# Draw one batch from the training dataset; samples[0] is (features, lengths).
samples = [x for x in dataset.take(1)]

In [None]:
# Run the model on the batch features and compare the prediction shape with
# the shape of the per-example lengths.
preds = alignment_model(samples[0][0])
padded_char_len = preds.shape[1]
preds.shape, samples[0][1].shape

In [None]:
# Batch element to inspect, and its unpadded (character, spectrogram) lengths.
idx = 4
unpadded_lens = samples[0][1][idx]
char_len = unpadded_lens[0].numpy()
spec_len = unpadded_lens[1].numpy()
print('Unpadded:', unpadded_lens)

# Both figures share this x-axis limit so they line up visually.
xmax = spec_len

# Figure 1: one curve per character index, restricted to the unpadded
# spectrogram length.
plt.figure(figsize=(15, 6))
# Alternative character ranges tried during inspection:
# for k in range(0, padded_char_len):
# for k in range(0, 15):
# for k in [0, 1, 2, 3, 4, 5, 6, 7, 8, -2, -1]:
for k in range(0, char_len, 1):
    plt.plot(preds[idx, k, 0:spec_len].numpy(), label=str(k))
    # plt.plot(preds[idx, k, :].numpy())
    plt.ylim(0, 1)
    # plt.show()
# plt.legend()
plt.xlim(0, xmax)


# Recover the waveform and the text (padding tokens stripped) of the element.
audio_data = samples[0][0]['waveform'][idx]
txt_data = tf.strings.join(samples[0][0]['char_seq'][idx]).numpy().decode('UTF-8').replace('[PAD]', '')

# Figure 2: output of the mel-spectrogram model for the same waveform,
# transposed for plotting.
logmel = m_logmel(tf.expand_dims(
    audio_data, axis=0)
)
print(logmel.shape, txt_data)
# Optional explicit axes, currently disabled:
# t = tf.cast(tf.range(0, logmel.shape[1]), tf.float32) * 256.0 / tf.cast(sr, tf.float32)
# mels = tf.range(0, logmel.shape[2], delta=1)
plt.figure(figsize=(15, 6))
plt.pcolormesh(
    # t.numpy(),
    # mels.numpy(),
    tf.transpose(logmel[0]).numpy()
)
plt.xlim(0, xmax)

plt.show()

## Decode prediction

In [None]:
import numpy as np
from models.decoder import PFADecoder
# Decoder instance used to turn a prediction matrix into an alignment.
pfa_dec = PFADecoder()

In [None]:
# Crop the prediction of the selected element to its unpadded lengths,
# print the number of cells, and decode the alignment. The result is
# consumed in this notebook as rows whose column 0 is plotted against
# column 1.
m = preds[idx, 0:char_len, 0:spec_len].numpy()
print(m.shape[0] * m.shape[1])
alignment = np.array(pfa_dec.decode_alignment(m))

In [None]:
# Figure 1: decoded alignment, column 1 on the x axis and column 0 on the
# y axis.
plt.figure(figsize=(15, 6))
plt.plot(alignment[:, 1], alignment[:, 0])
plt.xlim(0, xmax)

# Figure 2: the mel-spectrogram model output with the same x-axis limit,
# for visual comparison.
plt.figure(figsize=(15, 6))
plt.pcolormesh(
    tf.transpose(logmel[0]).numpy()
)
plt.xlim(0, xmax)

plt.show()

In [None]:
# Encode the selected waveform as WAV (a second axis of size 1 is added
# first) and write it to outputs/out.wav.
contents = tf.audio.encode_wav(tf.expand_dims(audio_data, 1), sr)
tf.io.write_file('outputs/out.wav', contents)

In [None]:
def write_srt(chars, alignment, time_delta, filename):
    """Write a per-character subtitle file for an alignment.

    Args:
        chars: Indexable sequence of characters.
        alignment: Sequence of ``(char_idx, spec_idx)`` pairs.
        time_delta: Factor that converts a ``spec_idx`` into a time value.
        filename: Path of the file to write.
    """
    _write_chardict(
        _compute_chardict(chars, alignment, time_delta, filename),
        filename,
    )


def _write_chardict(char_dict, filename):
    """Write every entry of ``char_dict`` to ``filename`` as a numbered cue.

    Each cue consists of a running number starting at 1, a
    ``start --> end`` line built with ``convert_seconds_to_srt`` and the
    entry's text followed by a fixed suffix. Cues are separated by an
    empty line and nothing is written after the last one.

    Args:
        char_dict: Mapping whose values hold the keys ``'t0'``, ``'tf'``
            and ``'char'``; entries are written in iteration order.
        filename: Path of the file to (over)write.
    """
    with open(filename, 'w') as srt_file:
        for cue_number, key in enumerate(char_dict, start=1):
            entry = char_dict[key]
            start, end, label = entry['t0'], entry['tf'], entry['char']
            # The separator goes between cues only, never before the first.
            if cue_number > 1:
                srt_file.write('\n\n')

            srt_file.write(f'{cue_number}\n')
            # Fixed suffix appended to every cue text (original note:
            # "don't show in UI").
            payload = label + '|||8760|||9760|||Arial, 44pt|||False|||'
            start_stamp = convert_seconds_to_srt(start)
            end_stamp = convert_seconds_to_srt(end)
            srt_file.write(f'{start_stamp} --> {end_stamp}\n')
            srt_file.write(f'{payload}')


def _compute_chardict(chars, alignment, time_delta, filename):
    char_dict = {}
    char_dict[0] = {'char': chars[0], 't0': 0}
    prev_char = 0
    for char_idx, spec_idx in alignment:
        if char_idx > prev_char:
            char_dict[prev_char]['tf'] = spec_idx * time_delta
            char_dict[char_idx] = {
                'char': chars[char_idx], 't0': spec_idx * time_delta
            }
        prev_char = char_idx
    char_dict[prev_char]['tf'] = alignment[-1][1] * time_delta
    return char_dict


def convert_seconds_to_srt(time_in_s):
    """Format a duration in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to whole milliseconds before it is split into
    fields, so a fractional part that rounds up to a full second carries
    into the seconds, minutes and hours (1.9996 gives ``00:00:02,000``
    rather than a four-digit millisecond field).

    Args:
        time_in_s: Non-negative duration in seconds (any value accepted by
            ``float()``).

    Returns:
        The timestamp string, e.g. ``00:00:01,417``.
    """
    total_ms = int(round(float(time_in_s) * 1000))
    total_s, milliseconds = divmod(total_ms, 1000)
    total_min, seconds = divmod(total_s, 60)
    hours, minutes = divmod(total_min, 60)
    return f'{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}'

In [None]:
filename = 'outputs/out.srt'
# open() does not create missing directories, so make sure the target
# directory exists even when the WAV-writing cell has not been run.
os.makedirs(os.path.dirname(filename), exist_ok=True)
# Time value of one alignment step: 256 divided by the sampling rate
# (assumes 256 samples per spectrogram frame — TODO confirm against
# pfa.frame_step).
time_delta = 256 / sr.numpy()
write_srt(txt_data, alignment, time_delta, filename)

In [None]:
def convert_seconds_to_srt(time_in_s):
    """Format a duration in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to whole milliseconds before it is split into
    fields, so a fractional part that rounds up to a full second carries
    into the seconds, minutes and hours (1.9996 gives ``00:00:02,000``
    rather than a four-digit millisecond field).

    Args:
        time_in_s: Non-negative duration in seconds (any value accepted by
            ``float()``).

    Returns:
        The timestamp string, e.g. ``00:00:01,417``.
    """
    total_ms = int(round(float(time_in_s) * 1000))
    total_s, milliseconds = divmod(total_ms, 1000)
    total_min, seconds = divmod(total_s, 60)
    hours, minutes = divmod(total_min, 60)
    return f'{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}'

In [None]:
# Spot check with 1 h 28 min 4.281 s; should display '01:28:04,281'.
convert_seconds_to_srt(1*3600 + 28*60 + 4.281)

## Misc

In [None]:
# Recover the waveform and the text (padding tokens stripped) of batch
# element idx.
audio_data = samples[0][0]['waveform'][idx]
txt_data = tf.strings.join(samples[0][0]['char_seq'][idx]).numpy().decode('UTF-8').replace('[PAD]', '')

# Plot the output of the mel-spectrogram model for that waveform, cropped
# on the x axis to the unpadded spectrogram length.
logmel = m_logmel(tf.expand_dims(
    audio_data, axis=0)
)
print(logmel.shape, txt_data)
# Optional explicit axes, currently disabled:
# t = tf.cast(tf.range(0, logmel.shape[1]), tf.float32) * 256.0 / tf.cast(sr, tf.float32)
# mels = tf.range(0, logmel.shape[2], delta=1)
plt.figure(figsize=(15, 6))

plt.pcolormesh(
    # t.numpy(),
    # mels.numpy(),
    tf.transpose(logmel[0]).numpy()
)
plt.xlim(0, spec_len)

In [None]:
# Prediction curves for the first and the last unpadded character index.
plt.plot(preds[idx, 0, 0:spec_len].numpy())
plt.plot(preds[idx, char_len - 1, 0:spec_len].numpy())

In [None]:
# Sum and maximum of the predictions over axis 0 of the selected element
# (the axis that is indexed by character in the plots of this notebook).
tf.reduce_sum(preds[idx], axis=0), tf.reduce_max(preds[idx], axis=0)


In [None]:
# audio_data is a single waveform taken from the padded batch, so it is
# sliced along one axis only; a second index would raise. The window is
# five times the sampling rate, i.e. sr-based like the axes below, instead
# of a hard-coded 16000.
spec = m_spec(tf.expand_dims(
    audio_data[0:5 * int(sr.numpy())], axis=0)
)
print(spec.shape)
# Axis values: frame index * 256 / sr, and bin index * sr / 1024
# (assumes 256 samples per frame and 1024 as the FFT length — TODO confirm
# against the aligner's framing parameters).
t = tf.cast(tf.range(0, spec.shape[1]), tf.float32) * 256.0 / tf.cast(sr, tf.float32)
freqs = tf.range(0, spec.shape[2], delta=1, dtype=tf.float32) * tf.cast(sr, tf.float32) / 1024.

# Plot the log of the spectrogram; 1e-6 keeps the logarithm finite at zero.
plt.figure(figsize=(15, 6))
plt.pcolormesh(
    t.numpy(),
    freqs.numpy(),
    tf.math.log(1e-6 + tf.transpose(spec[0])).numpy()
)

In [None]:
logmel = m_logmel(tf.expand_dims(
    audio_data[0:16000 * 5, 0], axis=0)
)
print(logmel.shape)
t = tf.cast(tf.range(0, logmel.shape[1]), tf.float32) * 256.0 / tf.cast(sr, tf.float32)
mels = tf.range(0, logmel.shape[2], delta=1)
plt.figure(figsize=(15, 6))

plt.pcolormesh(
    t.numpy(),
    mels.numpy(),
    tf.transpose(logmel[0]).numpy()