In [1]:
import os

# Set CUDA_VISIBLE_DEVICES to an empty string before TensorFlow is imported
# (next cell) so that no CUDA device is exposed to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

In [2]:
from malaya_speech.train.model import hubert, ctc
from malaya_speech.train.model.conformer.model import Model as ConformerModel
import malaya_speech
import tensorflow as tf
import numpy as np
import json
from glob import glob
import string

In [3]:
# Character vocabulary used for targets and decoding: the empty string at
# index 0, then a-z, then 0-9, and the space character last.
unique_vocab = ['', *string.ascii_lowercase, *string.digits, ' ']
len(unique_vocab)

38

In [4]:
# !wget https://f000.backblazeb2.com/file/malaya-speech-model/language-model/dump-combined/model.trie.klm
# !wget https://f000.backblazeb2.com/file/malaya-speech-model/ctc-decoder/ctc_decoders-1.0-cp36-cp36m-linux_x86_64.whl
# !pip3 install ctc_decoders-1.0-cp36-cp36m-linux_x86_64.whl 

In [5]:
from ctc_decoders import Scorer
from ctc_decoders import ctc_beam_search_decoder

In [6]:
# NOTE(review): presumably the number of mel bins; not referenced by any
# other visible cell — confirm whether it is still needed.
n_mels = 80
# Sample rate that audio is loaded at / resampled to.
sr = 16000
# Maximum clip duration in seconds (compared against len(wav) / sr);
# `generate` skips longer clips.
maxlen = 18
# Minimum transcript length in characters; `generate` skips shorter texts.
minlen_text = 1

def mp3_to_wav(file, sr = sr):
    """
    Decode a compressed audio file into a mono float waveform.

    Parameters
    -----------
    file: str
        path of the audio file, passed to `AudioSegment.from_file`.
    sr: int, optional (default=module-level `sr`)
        frame rate the audio is resampled to.

    Returns
    --------
    result: Tuple[np.ndarray, int]
        samples converted with `malaya_speech.astype.int_to_float`, and `sr`.
    """
    # AudioSegment was not imported by any cell, so this function raised
    # NameError on every call and `generate` dropped each .mp3 entry after
    # printing the error. Import lazily so pydub is only required when an
    # mp3 is actually encountered.
    from pydub import AudioSegment

    segment = AudioSegment.from_file(file)
    segment = segment.set_frame_rate(sr).set_channels(1)
    samples = np.array(segment.get_array_of_samples())
    return malaya_speech.astype.int_to_float(samples), sr


def generate(file):
    """
    Yield one example per usable entry of a JSON manifest.

    Parameters
    -----------
    file: str
        path of a JSON file holding parallel lists under the keys 'X'
        (audio paths) and 'Y' (cleaned transcripts).

    Yields
    -------
    result: dict
        'waveforms', 'waveforms_length', 'targets' and 'targets_length'.
        Entries whose audio is longer than `maxlen` seconds, whose transcript
        is shorter than `minlen_text`, or that raise while being loaded or
        encoded are skipped.
    """
    with open(file) as fopen:
        dataset = json.load(fopen)
    audios, cleaned_texts = dataset['X'], dataset['Y']

    # Build the char -> id table once instead of a linear `list.index` scan
    # for every character; `setdefault` keeps the first index, like `index`.
    char_to_id = {}
    for idx, char in enumerate(unique_vocab):
        char_to_id.setdefault(char, idx)

    for i, audio in enumerate(audios):
        try:
            if audio.endswith('.mp3'):
                wav_data, _ = mp3_to_wav(audio)
            else:
                wav_data, _ = malaya_speech.load(audio, sr = sr)

            if (len(wav_data) / sr) > maxlen:
                # audio too long
                continue

            text = cleaned_texts[i]
            if len(text) < minlen_text:
                # text too short
                continue

            t = [char_to_id[c] for c in text]

            yield {
                'waveforms': wav_data,
                'waveforms_length': [len(wav_data)],
                'targets': t,
                'targets_length': [len(t)],
            }
        except Exception as e:
            # Best-effort: a broken entry must not stop the stream, but report
            # which entry failed so silently shrinking datasets can be traced.
            print(f'skipped {audio}: {e!r}')


def get_dataset(
    file,
    batch_size = 2,
    shuffle_size = 20,
    thread_count = 24,
    maxlen_feature = 1800,
):
    """
    Build a factory for a padded, batched `tf.data.Dataset` over `generate(file)`.

    Parameters
    -----------
    file: str
        manifest path forwarded to `generate`.
    batch_size: int, optional (default=2)
        number of examples per padded batch.
    shuffle_size: int, optional (default=20)
        accepted for interface compatibility; not used.
    thread_count: int, optional (default=24)
        accepted for interface compatibility; not used.
    maxlen_feature: int, optional (default=1800)
        accepted for interface compatibility; not used.

    Returns
    --------
    result: Callable[[], tf.data.Dataset]
        zero-argument function that constructs the dataset when called.
    """
    def get():
        # Single source of truth for the feature names and dtypes; the
        # shapes and padding values below are derived from it.
        feature_dtypes = {
            'waveforms': tf.float32,
            'waveforms_length': tf.int32,
            'targets': tf.int32,
            'targets_length': tf.int32,
        }
        # Every feature is a variable-length 1-D vector.
        vector_shapes = {
            key: tf.TensorShape([None]) for key in feature_dtypes
        }

        dataset = tf.data.Dataset.from_generator(
            generate,
            feature_dtypes,
            output_shapes = vector_shapes,
            args = (file,),
        )
        # tf.data.experimental.AUTOTUNE replaces tf.contrib.data.AUTOTUNE,
        # avoiding the contrib module that is removed in TensorFlow 2.
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
        dataset = dataset.padded_batch(
            batch_size,
            padded_shapes = vector_shapes,
            # Pad every feature with zeros of its own dtype.
            padding_values = {
                key: tf.constant(0, dtype = dtype)
                for key, dtype in feature_dtypes.items()
            },
        )
        return dataset

    return get

In [7]:
# Build the dataset over the test manifest and take the symbolic next-batch
# tensors from a one-shot iterator; each later `sess.run` on these tensors
# pulls one batch from the dataset.
dev_dataset = get_dataset('bahasa-asr-test.json')()
features = dev_dataset.make_one_shot_iterator().get_next()
features

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)`.


{'waveforms': <tf.Tensor 'IteratorGetNext:2' shape=(?, ?) dtype=float32>,
 'waveforms_length': <tf.Tensor 'IteratorGetNext:3' shape=(?, ?) dtype=int32>,
 'targets': <tf.Tensor 'IteratorGetNext:0' shape=(?, ?) dtype=int32>,
 'targets_length': <tf.Tensor 'IteratorGetNext:1' shape=(?, ?) dtype=int32>}

In [8]:
# NOTE(review): this flag is not read by any other visible cell — confirm
# whether it is still needed.
training = True

In [9]:
class Encoder:
    """Thin adapter exposing a ConformerModel through an `(x, input_mask, training)` call."""

    def __init__(self, config):
        """Keep `config` and build the wrapped conformer from its keyword arguments."""
        self.config = config
        self.encoder = ConformerModel(**config)

    def __call__(self, x, input_mask, training = True):
        """
        Run the wrapped conformer on `x`.

        `input_mask` is accepted but not forwarded to the conformer.
        """
        conformer = self.encoder
        return conformer(x, training = training)

In [10]:
# Start from the library's base conformer encoder config, with subsampling
# type set to 'none' and dropout set to 0.0.
# NOTE(review): no copy is taken, so these assignments mutate the object
# exposed by malaya_speech.config in place.
config_conformer = malaya_speech.config.conformer_base_encoder_config
config_conformer['subsampling']['type'] = 'none'
config_conformer['dropout'] = 0.0
encoder = Encoder(config_conformer)
# HuBERT config with every dropout-style option set to 0.0.
cfg = hubert.HuBERTConfig(
    extractor_mode='layer_norm',
    dropout=0.0,
    attention_dropout=0.0,
    encoder_layerdrop=0.0,
    dropout_input=0.0,
    dropout_features=0.0,
    final_dim=256,
)
# Third argument: 'pad', 'eos', 'unk' followed by the strings '0'..'99'.
model = hubert.Model(cfg, encoder, ['pad', 'eos', 'unk'] + [str(i) for i in range(100)])
X = features['waveforms']
# Each example carries its length as a one-element list, so column 0 of the
# batched tensor is the per-example waveform length.
X_len = features['waveforms_length'][:, 0]
r = model(X, padding_mask = X_len, features_only = True, mask = False)
# Project to len(unique_vocab) + 1 classes; the extra last class is
# presumably the CTC blank (ctc_token_idx = len(unique_vocab) is what the
# pyctcdecode decoder is given later) — TODO confirm.
logits = tf.layers.dense(r['x'], len(unique_vocab) + 1)
log_probs = tf.nn.log_softmax(logits)
# Per-example output length: number of positions not flagged in padding_mask.
seq_lens = tf.reduce_sum(
    tf.cast(tf.logical_not(r['padding_mask']), tf.int32), axis = 1
)


Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Instructions for updating:
Use keras.layers.Dense instead.
Instructions for updating:
Please use `layer.__call__` method instead.


In [11]:
# Swap axes 0 and 1 of the logits (the layout handed to
# tf.nn.ctc_beam_search_decoder below) and give both tensors fixed graph names.
# NOTE: this cell rebinds `logits`; re-running it transposes again.
logits = tf.transpose(logits, [1, 0, 2])
logits = tf.identity(logits, name = 'logits')
seq_lens = tf.identity(seq_lens, name = 'seq_lens')

In [12]:
# decoded = tf.nn.ctc_beam_search_decoder(
#     logits,
#     seq_lens,
#     beam_width = beam_size,
#     top_paths = 1,
#     merge_repeated = True)[0][0]
# decoded._indices, decoded._values

In [13]:
# Display the symbolic tensors built so far (shapes and dtypes only).
logits, seq_lens, log_probs

(<tf.Tensor 'logits:0' shape=(?, ?, 39) dtype=float32>,
 <tf.Tensor 'seq_lens:0' shape=(?,) dtype=int32>,
 <tf.Tensor 'LogSoftmax:0' shape=(?, ?, 39) dtype=float32>)

In [14]:
# In-graph CTC beam search (no language model): beam width 10, best path only.
decoded = tf.nn.ctc_beam_search_decoder(logits, seq_lens, beam_width=10, top_paths=1, merge_repeated=True)
# decoded[0][0] is the best path as a sparse tensor. tf.cast replaces the
# deprecated tf.to_int32; to_dense fills missing positions with 0.
preds = tf.sparse.to_dense(tf.cast(decoded[0][0], tf.int32))
preds = tf.identity(preds, 'preds')
# NOTE: rebinds `log_probs` to a named copy; re-running adds another identity op.
log_probs = tf.identity(log_probs, 'log_probs')
preds, log_probs

Instructions for updating:
Use `tf.cast` instead.


(<tf.Tensor 'preds:0' shape=(?, ?) dtype=int32>,
 <tf.Tensor 'log_probs:0' shape=(?, ?, 39) dtype=float32>)

In [15]:
# Open a session, initialise every global variable, then restore all of them
# from the trained checkpoint.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
saver = tf.train.Saver(var_list = var_list)
saver.restore(sess, 'hubert-conformer-base-ctc-char/model.ckpt-2000000')

INFO:tensorflow:Restoring parameters from hubert-conformer-base-ctc-char/model.ckpt-2000000


In [16]:
import six
import string
from typing import List

def decode(ids, lookup: List[str] = None):
    """
    Decode integer representation to string based on ascii table or lookup variable.

    Parameters
    -----------
    ids: List[int]
    lookup: List[str], optional (default=None)
        list of unique strings. When given (and non-empty), each id is used
        as an index into it.

    Returns
    --------
    result: str
    """
    if lookup:
        return ''.join(lookup[id_] for id_ in ids)

    # Byte fallback: shift each id down by the number of reserved tokens and
    # decode the whole byte string at once. Decoding one byte at a time, as
    # before, fails on any multi-byte UTF-8 character.
    # NOTE(review): NUM_RESERVED_TOKENS is not defined in any visible cell,
    # so this path raises NameError for non-empty `ids` unless it is defined
    # elsewhere — confirm where it should come from.
    raw = bytes(id_ - NUM_RESERVED_TOKENS for id_ in ids)
    return raw.decode('utf-8')

In [17]:
# %%time

# kenlm_model = kenlm.Model('model.trie.klm')
# decoder = build_ctcdecoder(
#     unique_vocab + ['_'],
#     kenlm_model,
#     alpha=0.1,
#     beta=3.0,
#     ctc_token_idx=len(unique_vocab)
# )

In [18]:
%%time

from pyctcdecode import build_ctcdecoder
import kenlm

# Load the KenLM language model and build the pyctcdecode beam-search decoder.
# The label list is the character vocabulary plus a trailing '_' entry, and
# that entry's index (len(unique_vocab)) is passed as the CTC token index.
kenlm_model = kenlm.Model('out.trie.klm')
decoder = build_ctcdecoder(
    unique_vocab + ['_'],
    kenlm_model,
    alpha=0.2,
    beta=1.0,
    ctc_token_idx=len(unique_vocab)
)

CPU times: user 6.69 ms, sys: 19.8 ms, total: 26.5 ms
Wall time: 27.6 ms


In [19]:
# External scorer for ctc_decoders' beam search, built from the same KenLM
# file and the character vocabulary.
# NOTE(review): presumably the first two arguments are the LM weight and the
# word-insertion weight — confirm against the ctc_decoders Scorer signature.
scorer = Scorer(0.5, 1.0, 'out.trie.klm', unique_vocab)

In [20]:
# `logits` had axes 0 and 1 swapped in an earlier cell; swap them back and
# apply softmax. Despite its name, logits_t holds softmax probabilities.
logits_t = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))

In [21]:
# r = sess.run([preds, logits_t, seq_lens, features['targets']])
# out = decoder2.decode_beams(r[1][1,:r[2][1]], 
#                            prune_history=True)
# text, lm_state, timesteps, logit_score, lm_score = out[0]

# text, ctc_beam_search_decoder(r[1][1,:r[2][1]], unique_vocab, 20, ext_scoring_func = scorer)[0][1]

In [22]:
# out = decoder.decode_beams(np.pad(r[1][:,0], [[0,0], [1,0]], constant_values = -13.0), 
#                            prune_history=True)
# text, lm_state, timesteps, logit_score, lm_score = out[0]
# text

In [23]:
# out = decoder2.decode_beams(np.pad(r[1][:,0], [[0,0], [1,0]], constant_values = -13.0), 
#                            prune_history=True)
# text, lm_state, timesteps, logit_score, lm_score = out[0]
# text

In [24]:
# decode(r[0][1], unique_vocab), decode(r[-1][1], unique_vocab)

In [25]:
# %%time

# ctc_beam_search_decoder(r[1][0,:r[2][0]], unique_vocab, 20, ext_scoring_func = scorer)[0][1]

In [26]:
# %%time

# ctc_beam_search_decoder(r[1][1,:r[2][1]], unique_vocab, 20, ext_scoring_func = scorer)[0]

In [27]:
from malaya_speech.utils import metrics, char

# Per-utterance error rates for three decoders:
#   wer / cer         - in-graph TF beam search (`preds`), no language model
#   wer_lm / cer_lm   - ctc_decoders beam search with the KenLM scorer
#   wer_lm2 / cer_lm2 - pyctcdecode beam search with KenLM
wer, cer, wer_lm, cer_lm = [], [], [], []
wer_lm2, cer_lm2 = [], []
index = 0
while True:
    # Only exhaustion of the one-shot iterator should end the loop. The
    # previous blanket `except Exception: break` around the whole body also
    # hid decoding and metric errors, silently truncating the evaluation.
    try:
        r = sess.run([preds, logits_t, seq_lens, features['targets']])
    except tf.errors.OutOfRangeError:
        break

    for no, row in enumerate(r[0]):
        d = decode(row, lookup = unique_vocab).replace('<PAD>', '')
        t = decode(r[-1][no], lookup = unique_vocab).replace('<PAD>', '')
        wer.append(malaya_speech.metrics.calculate_wer(t, d))
        cer.append(malaya_speech.metrics.calculate_cer(t, d))

        # Probabilities of this utterance, trimmed to its valid length.
        probs = r[1][no, :r[2][no]]

        d_lm = ctc_beam_search_decoder(
            probs, unique_vocab, 20, ext_scoring_func = scorer
        )[0][1]
        wer_lm.append(malaya_speech.metrics.calculate_wer(t, d_lm))
        cer_lm.append(malaya_speech.metrics.calculate_cer(t, d_lm))

        out = decoder.decode_beams(probs, prune_history=True)
        d_lm2, lm_state, timesteps, logit_score, lm_score = out[0]
        wer_lm2.append(malaya_speech.metrics.calculate_wer(t, d_lm2))
        cer_lm2.append(malaya_speech.metrics.calculate_cer(t, d_lm2))

    # Number of batches processed so far.
    index += 1

In [28]:
# Mean WER / CER per decoder, in order: TF beam search (no LM),
# ctc_decoders with the KenLM scorer, pyctcdecode with KenLM.
np.mean(wer), np.mean(cer), np.mean(wer_lm), np.mean(cer_lm), np.mean(wer_lm2), np.mean(cer_lm2)

(0.23871400816641347,
 0.06089981404967854,
 0.14147911604534796,
 0.04507517237844968,
 0.1456196894458469,
 0.043937028321459626)

In [29]:
# Last utterance left over from the evaluation loop: TF beam-search output,
# reference text, ctc_decoders output, pyctcdecode output.
d, t, d_lm, d_lm2

('saji tempat bangdik menarutumi s buatanya itu piring saji yang isinya',
 'saji tempat bangdik menaruh tumis tahu buatannya itu piring saji yang isinya',
 'saji tempat bangdik menaruh tumis buatannya itu piring saji yang isinya',
 'saji tempat bangdik menaruh tumis buatannya itu piring saji yang isinya')

In [30]:
%%time

# Time ctc_decoders on a single utterance. Relies on `r` and `no` left over
# from the last iteration of the evaluation loop.
d_lm = ctc_beam_search_decoder(r[1][no,:r[2][no]], 
                                           unique_vocab, 20, ext_scoring_func = scorer)[0][1]

CPU times: user 8.12 ms, sys: 0 ns, total: 8.12 ms
Wall time: 8.08 ms


In [31]:
%%time

# Time pyctcdecode on the same utterance. Relies on `r` and `no` left over
# from the last iteration of the evaluation loop.
out = decoder.decode_beams(r[1][no,:r[2][no]], prune_history=True)

CPU times: user 168 ms, sys: 4.18 ms, total: 173 ms
Wall time: 170 ms
