In [1]:
import os

os.environ['CUDA_VISIBLE_DEVICES'] = ''

In [2]:
import sys

SOURCE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__name__)))
sys.path.insert(0, SOURCE_DIR)

In [3]:
import malaya_speech
import malaya_speech.config
from malaya_speech.train.model import srgan
from malaya_speech.train.model import enhancement
import tensorflow as tf
import numpy as np






The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.





In [4]:
sr = 44100
partition_size = 2048
reduction_factor = 4

In [5]:
tf.reset_default_graph()

x = tf.placeholder(tf.float32, (None,))
x_ = tf.expand_dims(x, 1)
partitioned_x = malaya_speech.tf_featurization.pad_and_partition(x_, partition_size // reduction_factor)

with tf.variable_scope('generator') as gen:
    model = srgan.Model(partitioned_x, training = True)
    
model.logits.set_shape((None, partition_size, 1))
logits = tf.reshape(model.logits, (-1, 1))
logits = logits[:tf.shape(x)[0] * reduction_factor, 0]

logits = tf.identity(logits, name = 'logits')

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [6]:
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

In [7]:
path = 'srgan-mae'
ckpt_path = tf.train.latest_checkpoint(path)
ckpt_path

'srgan-mae/model.ckpt-795000'

In [8]:
var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
saver = tf.train.Saver(var_list = var_list)
saver.restore(sess, ckpt_path)

INFO:tensorflow:Restoring parameters from srgan-mae/model.ckpt-795000


In [9]:
import IPython.display as ipd
import museval
import matplotlib.pyplot as plt
from glob import glob

In [10]:
def get_pair(f):
    return f.split('/')[-1].split('-')[0]


def read_wav(f):
    return malaya_speech.load(f, sr = sr)

In [11]:
Y = glob('testset-super-resolution/*-y.wav')
len(Y)

100

In [12]:
from tqdm import tqdm
from glob import glob

sdrs, isrs, sars = [], [], []

for y in tqdm(Y):
    y_ = f'testset-super-resolution/{get_pair(y)}-y_.wav'
    y, _ = read_wav(y)
    y_, _ = read_wav(y_)
    y__ = sess.run(logits, feed_dict = {x: y_})
    sdr, isr, _, sar = museval.evaluate(np.expand_dims(y__, 0), 
                                    np.expand_dims(y, 0))
    sdrs.append(np.nanmean(sdr))
    isrs.append(np.nanmean(isr))
    sars.append(np.nanmean(sar))

100%|██████████| 100/100 [03:16<00:00,  1.96s/it]


In [13]:
np.mean(sdrs), np.mean(isrs), np.mean(sars)

(16.34558233829829, 22.067493520420438, 17.02439164454031)

In [14]:
l, sr_ = malaya_speech.load('89.wav', sr = sr // reduction_factor)
len(l) / sr_, sr_

(4.106031746031746, 11025)

In [15]:
ipd.Audio(l, rate = sr_)

In [16]:
y_ = sess.run(logits, feed_dict = {x: l})

In [17]:
ipd.Audio(y_, rate = sr_ * reduction_factor)

In [18]:
resampled = malaya_speech.resample(l, sr // reduction_factor, sr)
ipd.Audio(resampled, rate = sr)

In [19]:
saver = tf.train.Saver()
saver.save(sess, 'srgan-output/model.ckpt')

'srgan-output/model.ckpt'

In [20]:
strings = ','.join(
    [
        n.name
        for n in tf.get_default_graph().as_graph_def().node
        if ('Variable' in n.op
        or 'Placeholder' in n.name
        or 'logits' in n.name)
        and 'adam' not in n.name
        and 'beta' not in n.name
        and 'global_step' not in n.name
        and 'Assign' not in n.name
    ]
)
strings.split(',')

['Placeholder',
 'generator/conv1d/kernel/Read/ReadVariableOp',
 'generator/conv1d/bias/Read/ReadVariableOp',
 'generator/conv1d/conv1d/ExpandDims_1/ReadVariableOp',
 'generator/conv1d/BiasAdd/ReadVariableOp',
 'generator/p_re_lu/alpha/Read/ReadVariableOp',
 'generator/p_re_lu/ReadVariableOp',
 'generator/conv1d_1/kernel/Read/ReadVariableOp',
 'generator/conv1d_1/bias/Read/ReadVariableOp',
 'generator/conv1d_1/conv1d/ExpandDims_1/ReadVariableOp',
 'generator/conv1d_1/BiasAdd/ReadVariableOp',
 'generator/batch_normalization/gamma/Read/ReadVariableOp',
 'generator/batch_normalization/moving_mean/Read/ReadVariableOp',
 'generator/batch_normalization/moving_variance/Read/ReadVariableOp',
 'generator/batch_normalization/batchnorm/mul/ReadVariableOp',
 'generator/batch_normalization/batchnorm/ReadVariableOp',
 'generator/p_re_lu_1/alpha/Read/ReadVariableOp',
 'generator/p_re_lu_1/ReadVariableOp',
 'generator/conv1d_2/kernel/Read/ReadVariableOp',
 'generator/conv1d_2/bias/Read/ReadVariableOp'

In [21]:
def freeze_graph(model_dir, output_node_names):

    if not tf.gfile.Exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exists. Please specify an export "
            'directory: %s' % model_dir
        )

    checkpoint = tf.train.get_checkpoint_state(model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path

    absolute_model_dir = '/'.join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_dir + '/frozen_model.pb'
    clear_devices = True
    with tf.Session(graph = tf.Graph()) as sess:
        saver = tf.train.import_meta_graph(
            input_checkpoint + '.meta', clear_devices = clear_devices
        )
        saver.restore(sess, input_checkpoint)
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            tf.get_default_graph().as_graph_def(),
            output_node_names.split(','),
        )
        with tf.gfile.GFile(output_graph, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
        print('%d ops in the final graph.' % len(output_graph_def.node))

In [22]:
freeze_graph('srgan-output', strings)

INFO:tensorflow:Restoring parameters from srgan-output/model.ckpt
Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
INFO:tensorflow:Froze 225 variables.
INFO:tensorflow:Converted 225 variables to const ops.
1568 ops in the final graph.


In [23]:
def load_graph(frozen_graph_filename, **kwargs):
    with tf.gfile.GFile(frozen_graph_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # https://github.com/onnx/tensorflow-onnx/issues/77#issuecomment-445066091
    # to fix import T5
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for index in xrange(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'AssignAdd':
            node.op = 'Add'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
        elif node.op == 'Assign':
            node.op = 'Identity'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
            if 'validate_shape' in node.attr:
                del node.attr['validate_shape']
            if len(node.input) == 2:
                node.input[0] = node.input[1]
                del node.input[1]

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def)
    return graph

In [24]:
g = load_graph('srgan-output/frozen_model.pb')
x = g.get_tensor_by_name('import/Placeholder:0')
logits = g.get_tensor_by_name('import/logits:0')

In [25]:
test_sess = tf.InteractiveSession(graph = g)



In [26]:
y_ = test_sess.run(logits, feed_dict = {x: l})

In [27]:
ipd.Audio(y_, rate = sr)

In [28]:
from tensorflow.tools.graph_transforms import TransformGraph

pb = 'srgan-output/frozen_model.pb'

transforms = ['add_default_attributes',
             'remove_nodes(op=Identity, op=CheckNumerics)',
             'fold_batch_norms',
             'fold_old_batch_norms',
             'quantize_weights(fallback_min=-1024, fallback_max=1024)',
             'strip_unused_nodes',
             'sort_by_execution_order']

input_graph_def = tf.GraphDef()
with tf.gfile.FastGFile(pb, 'rb') as f:
    input_graph_def.ParseFromString(f.read())

transformed_graph_def = TransformGraph(input_graph_def, 
                                           ['Placeholder'],
                                           ['logits'], transforms)
    
with tf.gfile.GFile(f'{pb}.quantized', 'wb') as f:
    f.write(transformed_graph_def.SerializeToString())

Instructions for updating:
Use tf.gfile.GFile.


In [29]:
g = load_graph('srgan-output/frozen_model.pb.quantized')
x = g.get_tensor_by_name('import/Placeholder:0')
logits = g.get_tensor_by_name('import/logits:0')
test_sess = tf.InteractiveSession(graph = g)

In [None]:
y_ = test_sess.run(logits, feed_dict = {x: l})
ipd.Audio(y_, rate = sr)