In [1]:
import os

# Expose no CUDA devices to this process. This is set before TensorFlow is
# imported so the whole export runs on CPU.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

In [2]:
import malaya_speech.train.model.mini_jasper as jasper
import malaya_speech
import tensorflow as tf
import numpy as np
import json






The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.




In [3]:
# Read the saved vocabulary, then extend it with three extra symbols.
# The output layer is sized from len(unique_vocab), so the order and the
# number of entries must not change.
with open('malaya-speech-sst-vocab.json') as vocab_file:
    base_vocab = json.load(vocab_file)

unique_vocab = base_vocab + ['{', '}', '[']

In [4]:
# Speech-to-text feature extractor with per-feature normalization enabled.
featurizer = malaya_speech.tf_featurization.STTFeaturizer(normalize_per_feature = True)

In [5]:
# Graph inputs. No explicit names are passed, so TensorFlow names them
# 'Placeholder' and 'Placeholder_1' in creation order. Later cells fetch the
# tensors by exactly those names, so do not reorder these two lines.
# X: rank-2 float input (presumably a batch of zero-padded waveforms - TODO confirm).
X = tf.placeholder(tf.float32, [None, None])
# X_len: one int per row of X; used below to trim each row before featurizing.
X_len = tf.placeholder(tf.int32, [None])

In [6]:
# In-graph preprocessing: featurize every row of X separately, then pad the
# per-row feature tensors to a common length so they can be stacked.
batch_size = tf.shape(X)[0]
# infer_shape = False lets each stored element have a different shape.
features = tf.TensorArray(dtype = tf.float32, size = batch_size, dynamic_size = True, infer_shape = False)
features_len = tf.TensorArray(dtype = tf.int32, size = batch_size)

init_state = (0, features, features_len)

# Pass 1: featurize row i and record the size of its first axis.
def condition(i, features, features_len):
    return i < batch_size

def body(i, features, features_len):
    # Trim row i to its true length so the padding is not featurized.
    f = featurizer(X[i, :X_len[i]])
    f_len = tf.shape(f)[0]
    return i + 1, features.write(i, f), features_len.write(i, f_len)

_, features, features_len = tf.while_loop(condition, body, init_state)
features_len = features_len.stack()
padded_features = tf.TensorArray(dtype = tf.float32, size = batch_size)
# Longest feature sequence in the batch; every element is padded up to this.
maxlen = tf.reduce_max(features_len)

# NOTE: init_state, condition and body are redefined below for the second
# loop, shadowing the definitions used by the first loop.
init_state = (0, padded_features)

# Pass 2: pad each feature tensor along its first axis up to maxlen.
def condition(i, padded_features):
    return i < batch_size

def body(i, padded_features):
    f = features.read(i)
    # Pad only at the end of the first axis; the second axis is left untouched.
    f = tf.pad(f, [[0, maxlen - tf.shape(f)[0]], [0,0]])
    return i + 1, padded_features.write(i, f)

_, padded_features = tf.while_loop(condition, body, init_state)
padded_features = padded_features.stack()
# Pin the static size of the last axis to 80 (presumably the number of
# features per frame produced by the featurizer - TODO confirm).
padded_features.set_shape((None, None, 80))

In [7]:
# Build the mini-Jasper model in inference mode on the padded features.
model = jasper.Model(padded_features, features_len, training = False)
# Project the model outputs to len(unique_vocab) + 1 classes (the extra class
# is presumably the CTC blank label - TODO confirm).
logits = tf.layers.dense(model.logits['outputs'], len(unique_vocab) + 1)
seq_lens = model.logits['src_length']
# Swap the first two axes: tf.nn.ctc_beam_search_decoder, used in the next
# cell, expects time-major input.
logits = tf.transpose(logits, [1, 0, 2])
# Give both tensors stable names; the freeze and quantize cells fetch
# 'logits' and 'seq_lens' by name.
logits = tf.identity(logits, name = 'logits')
seq_lens = tf.identity(seq_lens, name = 'seq_lens')


Instructions for updating:
Use `tf.keras.layers.SeparableConv1D` instead.
Instructions for updating:
Please use `layer.__call__` method instead.
Instructions for updating:
Use keras.layers.BatchNormalization instead.  In particular, `tf.control_dependencies(tf.GraphKeys.UPDATE_OPS)` should not be used (consult the `tf.keras.layers.batch_normalization` documentation).
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
Instructions for updating:
Use keras.layers.Dense instead.


In [8]:
# Beam-search CTC decoding over the time-major logits; only the single best
# path is requested (top_paths=1).
decoded = tf.nn.ctc_beam_search_decoder(logits, seq_lens, beam_width=100, top_paths=1, merge_repeated=True)
# decoded[0][0] is the best path as a SparseTensor. tf.to_int32 is deprecated;
# tf.cast is its documented replacement and accepts SparseTensor inputs too.
preds = tf.sparse.to_dense(tf.cast(decoded[0][0], tf.int32))
# Named so the dense predictions can be fetched from the graph by name.
preds = tf.identity(preds, 'preds')

Instructions for updating:
Use `tf.cast` instead.


In [9]:
# InteractiveSession installs itself as the default session for the notebook.
sess = tf.InteractiveSession()
# Initialize every global variable; the next cell overwrites them with the
# values restored from the checkpoint.
sess.run(tf.global_variables_initializer())

In [10]:
# Restore all global variables of the current graph from the trained checkpoint.
var_lists = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
saver = tf.train.Saver(var_list = var_lists)
saver.restore(sess, 'asr-mini-jasper-ctc/model.ckpt-365000')

INFO:tensorflow:Restoring parameters from asr-mini-jasper-ctc/model.ckpt-365000


In [14]:
# Sample recordings used to sanity-check the restored model.
files = [
    'savewav_2020-11-26_22-36-06_294832.wav',
    'savewav_2020-11-26_22-40-56_929661.wav',
    'download.wav',
    'husein-zolkepli.wav',
    'mas-aisyah.wav',
    'khalil-nooh.wav',
    'wadi-annuar.wav',
    '675.wav',
    '664.wav',
]

# malaya_speech.load returns a sequence; only its first item is kept
# (presumably the waveform - TODO confirm).
ys = []
for wav_path in files:
    ys.append(malaya_speech.load(wav_path)[0])

In [15]:
# Pad the loaded signals into one batch, keeping the original lengths.
padded, lens = malaya_speech.padding.sequence_1d(ys, return_len = True)

# Run the `preds` tensor on the batch. `decoded` is rebound here from the
# decoder op output to the fetched result.
inference_feed = {X: padded, X_len: lens}
decoded = sess.run(preds, feed_dict = inference_feed)

In [16]:
# Turn each row of predicted ids back into text and strip the '<PAD>' marker.
results = [
    malaya_speech.char.decode(row, lookup = unique_vocab).replace('<PAD>', '')
    for row in decoded
]
results

['halo nama saya sin saya tak suka mandi ketak saya masam',
 'helo nama saya musin saya suke mandi sa man di gitiap hari',
 'halau lo hlo anasalamualaikum luar tantu achana sekolah malaysia',
 'testing nama saya musin bin ja kapi',
 'sebut perkatan ungla',
 'tolong sebut atik kata',
 'jadi dalam perjalanan ini dunia yang susah ini ketika dabi mengajar muazbin jabat tadi ni alah mak ini',
 'ini dan melalui kenyatan mesej itu mastura menegaskan',
 'pilihan tepat apabila dia kini lebih berani dan']

In [18]:
# Decode the predicted id rows to text again, dropping the '<PAD>' marker.
results = [
    malaya_speech.char.decode(row, lookup = unique_vocab).replace('<PAD>', '')
    for row in decoded
]
results

['halo nama saya sin saya tak suka mandi ketap saya masan',
 'helo enam saya husin saya suka mandi saya mandi tetiap hari',
 'halau l samualaikum luar tantu awashanal sekolah malaysia',
 'testing nama saya musin binza kapli',
 'tolong sebut arti kata',
 'jadi dalam perjalanan ini dunia yang susah ini ketika dabi mengajar muaz bin jabat tadi ni alah mak ini',
 'ini dan melalui kenyatan mesej itu mastura menegaskan',
 'pilihan tepat apabila dia kini lebih berani dan']

In [19]:
# Re-save the session with a fresh Saver into the output directory; the
# freeze step below reads its checkpoint from that directory.
saver = tf.train.Saver()
saver.save(sess, 'asr-mini-jasper-ctc-output/model.ckpt')

'asr-mini-jasper-ctc-output/model.ckpt'

In [20]:
# Name fragments that select a node for export, and fragments that veto it.
_WANTED_NAME_PARTS = (
    'Placeholder',
    'logits',
    't_logits',
    'seq_lens',
    'alphas',
    'self/Softmax',
)
_UNWANTED_NAME_PARTS = ('adam', 'beta', 'global_step', 'Assign')


def _keep_node(node):
    """Return True when `node` belongs in the comma-separated export list.

    A node is selected when its op type contains 'Variable' or its name
    contains any wanted fragment, and it is then rejected if its name contains
    any unwanted fragment.
    """
    selected = 'Variable' in node.op or any(
        part in node.name for part in _WANTED_NAME_PARTS
    )
    vetoed = any(part in node.name for part in _UNWANTED_NAME_PARTS)
    return selected and not vetoed


# Comma-separated node names, in graph order, passed to freeze_graph below.
strings = ','.join(
    node.name
    for node in tf.get_default_graph().as_graph_def().node
    if _keep_node(node)
)
strings.split(',')

['Placeholder',
 'Placeholder_1',
 'w2l_encoder/conv11/depthwise_kernel',
 'w2l_encoder/conv11/pointwise_kernel',
 'w2l_encoder/conv11/bn/gamma',
 'w2l_encoder/conv11/bn/moving_mean',
 'w2l_encoder/conv11/bn/moving_variance',
 'w2l_encoder/conv21/depthwise_kernel',
 'w2l_encoder/conv21/pointwise_kernel',
 'w2l_encoder/conv21/bn/gamma',
 'w2l_encoder/conv21/bn/moving_mean',
 'w2l_encoder/conv21/bn/moving_variance',
 'w2l_encoder/conv22/depthwise_kernel',
 'w2l_encoder/conv22/pointwise_kernel',
 'w2l_encoder/conv22/bn/gamma',
 'w2l_encoder/conv22/bn/moving_mean',
 'w2l_encoder/conv22/bn/moving_variance',
 'w2l_encoder/conv23/res/depthwise_kernel',
 'w2l_encoder/conv23/res/pointwise_kernel',
 'w2l_encoder/conv23/res_bn/gamma',
 'w2l_encoder/conv23/res_bn/moving_mean',
 'w2l_encoder/conv23/res_bn/moving_variance',
 'w2l_encoder/conv23/depthwise_kernel',
 'w2l_encoder/conv23/pointwise_kernel',
 'w2l_encoder/conv23/bn/gamma',
 'w2l_encoder/conv23/bn/moving_mean',
 'w2l_encoder/conv23/bn/movi

In [21]:
def freeze_graph(model_dir, output_node_names):
    """Freeze the latest checkpoint found in `model_dir` into `frozen_model.pb`.

    The checkpoint's meta graph is imported into a fresh graph, its variables
    are restored and converted to constants, and the resulting GraphDef is
    written next to the checkpoint files.

    Args:
        model_dir: directory containing the checkpoint state and files.
        output_node_names: node names to keep as graph outputs, either as one
            comma-separated string or as an iterable of names.

    Returns:
        The path of the written `frozen_model.pb` file.

    Raises:
        AssertionError: if `model_dir` does not exist or holds no checkpoint.
    """
    if not tf.gfile.Exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exists. Please specify an export "
            'directory: %s' % model_dir
        )

    # get_checkpoint_state returns None when the directory has no checkpoint
    # state; fail with a clear message instead of an AttributeError on None.
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise AssertionError(
            'No checkpoint found in export directory: %s' % model_dir
        )
    input_checkpoint = checkpoint.model_checkpoint_path

    if isinstance(output_node_names, str):
        node_names = output_node_names.split(',')
    else:
        node_names = list(output_node_names)

    # os.path handles a checkpoint path without any directory part, which the
    # manual '/'-split turned into '/frozen_model.pb'.
    output_graph = os.path.join(
        os.path.dirname(input_checkpoint), 'frozen_model.pb'
    )

    # Entering the session makes its new, empty graph the default graph, so
    # the meta graph is imported into it rather than into the notebook's graph.
    with tf.Session(graph = tf.Graph()) as sess:
        saver = tf.train.import_meta_graph(
            input_checkpoint + '.meta', clear_devices = True
        )
        saver.restore(sess, input_checkpoint)
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            tf.get_default_graph().as_graph_def(),
            node_names,
        )
        with tf.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        print('%d ops in the final graph.' % len(output_graph_def.node))

    return output_graph

In [22]:
# Freeze the checkpoint saved in the output directory, keeping the nodes
# listed in `strings` as graph outputs.
freeze_graph('asr-mini-jasper-ctc-output', strings)

INFO:tensorflow:Restoring parameters from asr-mini-jasper-ctc-output/model.ckpt
Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
INFO:tensorflow:Froze 260 variables.
INFO:tensorflow:Converted 260 variables to const ops.
1901 ops in the final graph.


In [23]:
def load_graph(frozen_graph_filename):
    """Load a serialized GraphDef file into a new tf.Graph and return it.

    tf.import_graph_def is called without a name, so the imported nodes get
    TensorFlow's default 'import' prefix; callers look tensors up as
    'import/<node>:0'.
    """
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(frozen_graph_filename, 'rb') as pb_file:
        serialized = pb_file.read()
    graph_def.ParseFromString(serialized)

    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def)
    return graph


In [25]:
# Sanity check: reload the frozen graph and run it on the same padded batch.
g = load_graph('asr-mini-jasper-ctc-output/frozen_model.pb')
# Tensor names carry the 'import/' prefix that load_graph's
# tf.import_graph_def call adds by default.
x = g.get_tensor_by_name('import/Placeholder:0')
x_lens = g.get_tensor_by_name('import/Placeholder_1:0')
# NOTE: `logits` and `seq_lens` are rebound here to tensors of the reloaded
# graph, replacing the tensors of the original model graph.
logits = g.get_tensor_by_name('import/logits:0')
seq_lens = g.get_tensor_by_name('import/seq_lens:0')
# NOTE(review): this session is never closed in the visible cells.
test_sess = tf.InteractiveSession(graph = g)
result = test_sess.run([logits, seq_lens], feed_dict = {x: padded, x_lens: lens})
result

[array([[[-3.56353378e+01, -3.54871101e+01, -4.68669176e+00, ...,
          -3.52585754e+01, -3.42276573e+01,  2.41703320e+00],
         [-4.02817230e+01, -4.03476601e+01, -3.51572442e+00, ...,
          -4.01500816e+01, -3.98890228e+01,  8.21235466e+00],
         [-1.18128624e+02, -1.18496262e+02, -2.15755062e+01, ...,
          -1.16172142e+02, -1.18450768e+02,  1.15618849e+01],
         ...,
         [-5.60658607e+01, -5.43611794e+01, -1.12070551e+01, ...,
          -5.68232346e+01, -5.44302444e+01,  1.41392593e+01],
         [-6.58228683e+01, -6.72648010e+01, -1.38980141e+01, ...,
          -6.67019577e+01, -6.65573730e+01,  1.29448862e+01],
         [-5.87072830e+01, -5.75884590e+01, -1.06604137e+01, ...,
          -5.80280991e+01, -5.71711540e+01,  8.39609718e+00]],
 
        [[-3.67545586e+01, -3.60251694e+01,  6.41669178e+00, ...,
          -3.49555550e+01, -3.52595940e+01,  3.58636022e+00],
         [-4.33908730e+01, -4.32421989e+01, -1.83674848e+00, ...,
          -4.31035233

In [26]:
from tensorflow.tools.graph_transforms import TransformGraph

In [27]:
# Graph Transform Tool passes applied, in this order, to the frozen graph.
transforms = [
    'add_default_attributes',
    'remove_nodes(op=Identity, op=CheckNumerics, op=Dropout)',
    'fold_batch_norms',
    'fold_old_batch_norms',
    'quantize_weights(fallback_min=-10, fallback_max=10)',
    'strip_unused_nodes',
    'sort_by_execution_order',
]

pb = 'asr-mini-jasper-ctc-output/frozen_model.pb'

input_graph_def = tf.GraphDef()
# tf.gfile.FastGFile is deprecated; tf.gfile.GFile is the replacement and is
# what the rest of this notebook already uses.
with tf.gfile.GFile(pb, 'rb') as f:
    input_graph_def.ParseFromString(f.read())

# Input and output node names are given without the ':0' tensor suffix.
transformed_graph_def = TransformGraph(
    input_graph_def,
    ['Placeholder', 'Placeholder_1'],
    ['logits', 'seq_lens'],
    transforms,
)

# Write the transformed graph next to the frozen one, with a '.quantized' suffix.
with tf.gfile.GFile(f'{pb}.quantized', 'wb') as f:
    f.write(transformed_graph_def.SerializeToString())

Instructions for updating:
Use tf.gfile.GFile.


In [28]:
# Sanity check: reload the quantized graph and run it on the same padded
# batch, so its logits can be compared with the frozen graph's output.
g = load_graph(f'{pb}.quantized')
# Tensor names carry the 'import/' prefix that load_graph's
# tf.import_graph_def call adds by default.
x = g.get_tensor_by_name('import/Placeholder:0')
x_lens = g.get_tensor_by_name('import/Placeholder_1:0')
logits = g.get_tensor_by_name('import/logits:0')
seq_lens = g.get_tensor_by_name('import/seq_lens:0')
# NOTE(review): `test_sess` is rebound here without closing the session
# created by the earlier frozen-graph check.
test_sess = tf.InteractiveSession(graph = g)
result = test_sess.run([logits, seq_lens], feed_dict = {x: padded, x_lens: lens})
result

[array([[[ -35.352634  ,  -34.899845  ,   -5.159472  , ...,
           -34.89525   ,  -33.7444    ,    2.2478917 ],
         [ -39.960167  ,  -39.98972   ,   -3.643434  , ...,
           -39.698025  ,  -39.401844  ,    7.950354  ],
         [-119.95948   , -120.627     ,  -23.972021  , ...,
          -118.14435   , -120.08638   ,   11.283413  ],
         ...,
         [ -55.476936  ,  -53.85571   ,  -10.735267  , ...,
           -55.997746  ,  -54.019707  ,   14.123409  ],
         [ -64.66348   ,  -66.42641   ,  -14.147051  , ...,
           -65.54734   ,  -65.34446   ,   13.345505  ],
         [ -58.10192   ,  -56.994358  ,  -10.657103  , ...,
           -57.14263   ,  -56.475216  ,    8.465521  ]],
 
        [[ -36.24082   ,  -35.336575  ,    6.0977454 , ...,
           -34.388954  ,  -34.611923  ,    3.697202  ],
         [ -43.161827  ,  -42.81343   ,   -2.0230806 , ...,
           -42.6968    ,  -42.02452   ,    7.72977   ],
         [-130.20447   , -131.03014   ,  -16.43127   , 