In [1]:
import os

# Restrict CUDA to the device with index 1. This is set before TensorFlow is
# imported further down so the value is already in the environment when
# TensorFlow starts up.
# NOTE(review): the index is machine-specific — confirm before re-running.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [2]:
import sys

# Put the parent of the current working directory at the front of sys.path so
# that the project package imported in a later cell resolves to the local
# checkout (presumably the notebook lives one level below the project root —
# verify if the notebook is moved).
#
# The previous expression was dirname(dirname(abspath(__name__))). __name__ is
# a module name, not a path; it only produced the right directory because
# abspath() resolves any relative name against the working directory. Asking
# for the parent of the working directory directly yields the same value.
SOURCE_DIR = os.path.abspath(os.pardir)

# Keep the cell idempotent: re-running it must not stack duplicate entries,
# while a first run still places SOURCE_DIR at the very front.
if sys.path[:1] != [SOURCE_DIR]:
    sys.path.insert(0, SOURCE_DIR)

In [3]:
import tensorflow as tf

# tf.compat.v1.enable_eager_execution()

In [4]:
from malaya_speech.train.model.utils import shape_list
import numpy as np

In [5]:
# Precomputed test input; the next cell casts it to float32 and uses it as the
# log-likelihood tensor `ll`.
# NOTE(review): presumably shaped (2, 100, 80) to match the mask built from
# it below — confirm against the file.
x = np.load('test.npy')

In [6]:
# Log-likelihood input as a float32 tensor.
ll = tf.convert_to_tensor(x.astype(np.float32))

# Number of valid steps along axis 1 for each of the two batch items.
VALID_LENGTHS = [60, 80]
# Take the mask extents from the data rather than repeating them as literals,
# so the mask always has exactly the same shape as `ll`.
n_steps, n_cols = x.shape[1:]

# [B, T] -> [B, T, 1] -> [B, T, S]: a step is either fully valid or fully
# masked across the last axis.
mask = tf.sequence_mask(VALID_LENGTHS, n_steps, dtype=tf.float32)
mask = tf.tile(tf.expand_dims(mask, -1), [1, 1, n_cols])
mask.shape

TensorShape([Dimension(2), Dimension(100), Dimension(80)])

In [38]:
def monotonic_alignment_search(ll: tf.Tensor,
                               mask: tf.Tensor) -> tf.Tensor:
    """Monotonic alignment search, reference from jaywalnut310's glow-tts.
    https://github.com/jaywalnut310/glow-tts/blob/master/commons.py#L60

    A forward pass over the T axis accumulates path scores and records, for
    every cell, whether staying in the same S column scored at least as high
    as arriving from the column to its left. A backward pass then follows the
    recorded choices from the last step to the first and emits one one-hot
    row per step.

    Args:
        ll: [tf.float32; [B, T, S]], loglikelihood matrix.
        mask: [tf.float32; [B, T, S]], attention mask.
    Returns:
        [tf.float32; [B, T, S]], alignment.
    """
    # B, T, S
    bsize, timestep, seqlen = shape_list(ll)
    # (expected) T x [B, S]
    direction = tf.TensorArray(dtype=tf.bool, size=timestep)
    prob = tf.zeros([bsize, seqlen], dtype=tf.float32)
    # [B, S]; both tensors are loop-invariant, so they are built once here
    # instead of on every iteration. They are tiled/filled to full shape so
    # that every tf.where operand below has the same shape.
    x_range = tf.tile(tf.range(seqlen)[None], [bsize, 1])
    lowest = tf.fill(tf.shape(x_range), tf.float32.min)

    def forward_cond(j, direction, prob):
        return j < timestep

    def forward_step(j, direction, prob):
        # score of the left neighbour; column 0 has none and gets the floor
        prev = tf.pad(prob, [[0, 0], [1, 0]],
                      mode='CONSTANT',
                      constant_values=tf.float32.min)[:, :-1]
        cur = prob
        # True where staying in the same column is at least as good
        max_mask = tf.math.greater_equal(cur, prev)
        prob_max = tf.where(max_mask, cur, prev)
        direction = direction.write(j, max_mask)
        # columns with index greater than the step index are held at the floor
        j_ = tf.fill(tf.shape(x_range), j)
        prob = tf.where(tf.math.less_equal(x_range, j_),
                        prob_max + ll[:, j], lowest)
        return j + 1, direction, prob

    _, direction, _ = tf.while_loop(
        forward_cond, forward_step, (0, direction, prob))
    # T x [B, S] -> [B, T, S]
    direction = tf.cast(tf.transpose(direction.stack(), [1, 0, 2]), tf.int32)
    direction.set_shape((None, None, None))

    # masked-out cells are forced to 1 ("stay"), which leaves the column
    # index unchanged when the backward pass walks through them
    correct = tf.fill(tf.shape(direction), 1)
    direction = tf.where(tf.cast(mask, tf.bool), direction, correct)
    # (expected) T x [B, S]
    attn = tf.TensorArray(dtype=tf.float32, size=timestep)
    # [B], start column: sum of the mask over S at step 0, minus one
    index = tf.cast(tf.reduce_sum(mask[:, 0], axis=-1), tf.int32) - 1
    # [B], [B]
    index_range, values = tf.range(bsize), tf.ones(bsize)

    def backward_cond(j, attn, index):
        return j >= 0

    def backward_step(j, attn, index):
        # one-hot row: a single 1 at column index[b] for every batch item b
        attn = attn.write(j, tf.scatter_nd(
            tf.stack([index_range, index], axis=1),
            values, [bsize, seqlen]))
        # [B], recorded choice at (b, j, index[b])
        stay = tf.gather_nd(
            direction,
            tf.stack([index_range, tf.cast(values, tf.int32) * j, index],
                     axis=1))
        # [B], 1 keeps the column, 0 moves one column to the left
        index = index + stay - 1
        return j - 1, attn, index

    _, attn, _ = tf.while_loop(
        backward_cond, backward_step, (timestep - 1, attn, index))
    # T x [B, S] -> [B, T, S]
    stacked = tf.transpose(attn.stack(), [1, 0, 2])
    stacked.set_shape((None, None, None))
    return stacked * mask

# Build the alignment tensor for the test inputs. With the session-based
# workflow used in this notebook this is a symbolic tensor; its values are
# obtained later through sess.run(o).
o = monotonic_alignment_search(ll, mask)

In [33]:
# NOTE(review): the repr recorded under this cell (int32, 'Cast_11:0') comes
# from execution In [33], which predates the final function definition in
# In [38] (that version returns `stacked * mask`). Re-run this cell after the
# definition cell to refresh the output.
o

<tf.Tensor 'Cast_11:0' shape=(2, ?, 80) dtype=int32>

In [27]:
# Session used to evaluate the graph tensor `o` in the next cell.
# NOTE(review): the session is not closed anywhere in the visible cells;
# call sess.close() once the checks are done.
sess = tf.Session()

In [39]:
# Evaluate the alignment tensor to a concrete array and show its shape.
o_ = sess.run(o)
o_.shape

(2, 100, 80)

In [40]:
# Reference alignment that the TensorFlow result is compared against below.
# NOTE(review): presumably produced by the reference implementation for the
# same inputs — confirm the provenance of this file.
align = np.load('align.npy')

In [41]:
# Show both shapes side by side; they should agree so that the element-wise
# comparison in the next cell compares like with like.
align.shape, o_.shape

((2, 100, 80), (2, 100, 80))

In [42]:
# Fraction of positions where the evaluated result, cast to int32, equals the
# reference alignment; 1.0 means the two arrays agree everywhere.
matches = align == o_.astype(np.int32)
matches.mean()

1.0