Implement a native client using TensorFlow C API
Cwiiis committed Feb 23, 2017
1 parent b876f6f commit 2740336
Showing 50 changed files with 5,752 additions and 11 deletions.
32 changes: 27 additions & 5 deletions DeepSpeech.py
@@ -17,6 +17,7 @@
 from tensorflow.contrib.session_bundle import exporter
 from tensorflow.python.ops import ctc_ops
 from tensorflow.contrib.rnn.python.ops import core_rnn_cell
+from tensorflow.python.tools import freeze_graph
 from util.gpu import get_available_gpus
 from util.log import merge_logs
 from util.spell import correction
@@ -1151,7 +1152,7 @@ def train():
     # Run inference

     # Input tensor will be of shape [batch_size, n_steps, n_input + 2*n_input*n_context]
-    input_tensor = tf.placeholder(tf.float32, [None, None, n_input + 2*n_input*n_context])
+    input_tensor = tf.placeholder(tf.float32, [None, None, n_input + 2*n_input*n_context], name='input_node')

     # Calculate input sequence length. This is done by tiling n_steps, batch_size times.
     # If there are multiple sequences, it is assumed they are padded with zeros to be of
@@ -1166,7 +1167,7 @@
     # Beam search decode the batch
     decoded, _ = ctc_ops.ctc_beam_search_decoder(logits, seq_length, merge_repeated=False)
     decoded = tf.convert_to_tensor(
-        [tf.sparse_tensor_to_dense(sparse_tensor) for sparse_tensor in decoded])
+        [tf.sparse_tensor_to_dense(sparse_tensor) for sparse_tensor in decoded], name='output_node')

     # TODO: Transform the decoded output to a string
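
Naming the input and output tensors is what lets a native client locate these endpoints in the exported graph without access to the Python code. As a rough C++ sketch (not part of this commit), a client could query the graph by those names; `n_steps` and `n_features` are assumptions standing in for the real feature dimensions:

```cpp
#include <vector>
#include "tensorflow/core/public/session.h"

// Rough sketch, not from this commit: query the graph via the node names
// set above. `n_steps` and `n_features` must match the feature pipeline.
std::vector<tensorflow::Tensor> RunModel(tensorflow::Session* session,
                                         int n_steps, int n_features) {
  tensorflow::Tensor input(tensorflow::DT_FLOAT,
                           tensorflow::TensorShape({1, n_steps, n_features}));
  // ... fill input.tensor<float, 3>() with MFCC features ...
  std::vector<tensorflow::Tensor> outputs;
  TF_CHECK_OK(session->Run({{"input_node", input}},  // feed by name
                           {"output_node"},          // fetch by name
                           {}, &outputs));
  return outputs;
}
```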

@@ -1178,8 +1179,9 @@
         # TODO: This restores the most recent checkpoint, but if we use validation to counteract
         # over-fitting, we may want to restore an earlier checkpoint.
         checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
-        saver.restore(session, checkpoint.model_checkpoint_path)
-        print 'Restored checkpoint at training epoch %d' % (int(checkpoint.model_checkpoint_path.split('-')[-1]) + 1)
+        checkpoint_path = checkpoint.model_checkpoint_path
+        saver.restore(session, checkpoint_path)
+        print 'Restored checkpoint at training epoch %d' % (int(checkpoint_path.split('-')[-1]) + 1)

         # Initialise the model exporter and export the model
         model_exporter.init(session.graph.as_graph_def(),
@@ -1194,8 +1196,28 @@
             print 'Removing old export'
             shutil.rmtree(actual_export_dir)
         try:
+            # Export serving model
             model_exporter.export(export_dir, tf.constant(export_version), session)
-            print 'Model exported at %s' % (export_dir)
+
+            # Export graph
+            input_graph_name = 'input_graph.pb'
+            tf.train.write_graph(session.graph, export_dir, input_graph_name, as_text=False)
+
+            # Freeze graph
+            input_graph_path = os.path.join(export_dir, input_graph_name)
+            input_saver_def_path = ''
+            input_binary = True
+            output_node_names = 'output_node'
+            restore_op_name = 'save/restore_all'
+            filename_tensor_name = 'save/Const:0'
+            output_graph_path = os.path.join(export_dir, 'output_graph.pb')
+            clear_devices = False
+            freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
+                                      input_binary, checkpoint_path, output_node_names,
+                                      restore_op_name, filename_tensor_name,
+                                      output_graph_path, clear_devices, '')
+
+            print 'Models exported at %s' % (export_dir)
         except RuntimeError:
             print sys.exc_info()[1]
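
The frozen `output_graph.pb` bundles the graph definition and the checkpointed weights into a single file, which is what makes a client with no Python dependency possible. A minimal TensorFlow C++ loading sketch (not part of this commit; the function name is illustrative):

```cpp
#include <memory>
#include <string>
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session.h"

// Sketch only: load the frozen graph written by the freeze_graph() call above.
std::unique_ptr<tensorflow::Session> LoadFrozenModel(const std::string& path) {
  tensorflow::GraphDef graph_def;
  TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
                                          path, &graph_def));
  std::unique_ptr<tensorflow::Session> session(
      tensorflow::NewSession(tensorflow::SessionOptions()));
  TF_CHECK_OK(session->Create(graph_def));
  return session;  // ready to serve "input_node" -> "output_node" queries
}
```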

19 changes: 19 additions & 0 deletions native_client/BUILD
@@ -0,0 +1,19 @@
# Description: Deepspeech native client library.

cc_library(
    name = "deepspeech",
    srcs = ["deepspeech.cc",
            "c_speech_features/c_speech_features.c",
            "kiss_fft130/kiss_fft.c",
            "kiss_fft130/tools/kiss_fftr.c"],
    hdrs = ["deepspeech.h",
            "c_speech_features/c_speech_features.h",
            "kiss_fft130/kiss_fft.h",
            "kiss_fft130/_kiss_fft_guts.h",
            "kiss_fft130/tools/kiss_fftr.h"],
    includes = ["c_speech_features",
                "kiss_fft130"],
    deps = [
        "//tensorflow/core:tensorflow",
    ],
)
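
The library statically bundles `c_speech_features` and KISS FFT, so MFCC feature extraction happens inside the native library rather than in Python. The contents of `deepspeech.h` are not shown in this view; purely as an illustration, a surface for such a library might look like the following (every name here is hypothetical, not the actual API):

```cpp
// Hypothetical illustration only -- not the header added by this commit.
#ifndef DEEPSPEECH_H
#define DEEPSPEECH_H

class DeepSpeechModel {  // hypothetical name
 public:
  // Loads a frozen graph such as output_graph.pb.
  explicit DeepSpeechModel(const char* aModelPath);
  ~DeepSpeechModel();

  // Computes MFCC features from 16-bit mono PCM and runs the graph;
  // returns a caller-owned transcription string.
  char* stt(const short* aBuffer, unsigned int aBufferSize, int aSampleRate);
};

#endif  // DEEPSPEECH_H
```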
14 changes: 14 additions & 0 deletions native_client/Makefile
@@ -0,0 +1,14 @@

TFDIR ?= ../../tensorflow
CFLAGS ?= -O2 -Wall

default: deepspeech

clean:
	rm -f deepspeech

deepspeech: client.cc
	c++ -o deepspeech ${CFLAGS} client.cc `pkg-config --cflags --libs sox` -L${TFDIR}/bazel-bin/tensorflow -L${TFDIR}/bazel-bin/native_client -ldeepspeech -ltensorflow

run: deepspeech
	LD_LIBRARY_PATH=${TFDIR}/bazel-bin/tensorflow:${TFDIR}/bazel-bin/native_client:${LD_LIBRARY_PATH} ./deepspeech ${ARGS}
48 changes: 48 additions & 0 deletions native_client/README.md
@@ -0,0 +1,48 @@
# DeepSpeech native client

A native client for running queries on an exported DeepSpeech model.

## Requirements

* [TensorFlow source](https://www.tensorflow.org/install/install_sources)
* [libsox](https://sourceforge.net/projects/sox/)

## Preparation

Create a symbolic link in your TensorFlow checkout to the DeepSpeech `native_client` directory.

```
cd tensorflow
ln -s ../DeepSpeech/native_client ./
```

## Building

Before building the TensorFlow stand-alone library, you will need to prepare your environment to configure and build TensorFlow. Follow the [instructions](https://www.tensorflow.org/install/install_sources) on the TensorFlow site for your platform, up to the end of 'Configure the installation'.

To build the TensorFlow library, execute the following command:

```
bazel build -c opt //tensorflow:libtensorflow.so
```

Then you can build the DeepSpeech native library.

```
bazel build -c opt //native_client:deepspeech
```

Finally, you can change to the `native_client` directory and use the `Makefile`. By default, the `Makefile` assumes there is a TensorFlow checkout in a directory above the DeepSpeech checkout. If that is not the case, set the environment variable `TFDIR` to point to the right directory.

```
cd ../DeepSpeech/native_client
make deepspeech
```

## Running

The client can be run via the `Makefile`, and accepts audio in any format your installation of SoX supports.

```
ARGS="/path/to/output_graph.pb /path/to/audio/file.ogg" make run
```
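
`client.cc` itself is not shown in this view, but the `Makefile` implies its shape: decode the audio with libsox, extract features, and run the frozen graph. A rough sketch of that flow under those assumptions (error handling omitted; the model-running steps are placeholders):

```cpp
#include <sox.h>
#include <vector>

// Rough sketch of the flow the Makefile implies -- not the actual client.cc.
int main(int argc, char* argv[]) {
  // argv[1]: path to output_graph.pb, argv[2]: audio in any SoX-supported format
  if (argc < 3) return 1;
  sox_init();
  sox_format_t* in = sox_open_read(argv[2], nullptr, nullptr, nullptr);
  // signal.length may be 0 for streams; a real client would read in chunks.
  std::vector<sox_sample_t> samples(in->signal.length);
  sox_read(in, samples.data(), samples.size());
  sox_close(in);
  sox_quit();
  // ... convert `samples` to 16-bit PCM, compute MFCCs (c_speech_features),
  // run the frozen graph against "input_node"/"output_node", print result ...
  return 0;
}
```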
