Merge of pull requests #49, #50, and #52. Fixes issues #2, #4, #11, #12, #46, #47, and #48
kdavis-mozilla committed Oct 13, 2016
1 parent 9fb60a7 commit a3abc9d
Showing 12 changed files with 739 additions and 243 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -2,3 +2,5 @@
*.pyc
.DS_Store
/logs
/data/ted/TEDLIUM_release2
/data/ted/TEDLIUM_release2.tar.gz
127 changes: 62 additions & 65 deletions DeepSpeech.ipynb
@@ -83,11 +83,12 @@
"import tempfile\n",
"import subprocess\n",
"import numpy as np\n",
"from math import ceil\n",
"import tensorflow as tf\n",
"from util.log import merge_logs\n",
"from util.gpu import get_available_gpus\n",
"from util.importers.ted_lium import read_data_sets\n",
"from util.text import sparse_tensor_value_to_text, wers\n",
"from util.text import sparse_tensor_value_to_texts, wers\n",
"from tensorflow.python.ops import ctc_ops"
]
},
@@ -123,11 +124,11 @@
"beta1 = 0.9 # TODO: Determine a reasonable value for this\n",
"beta2 = 0.999 # TODO: Determine a reasonable value for this\n",
"epsilon = 1e-8 # TODO: Determine a reasonable value for this\n",
"training_iters = 1250 # TODO: Determine a reasonable value for this\n",
"batch_size = 1 # TODO: Determine a reasonable value for this\n",
"training_iters = 15 # TODO: Determine a reasonable value for this\n",
"batch_size = 5 # TODO: Determine a reasonable value for this\n",
"display_step = 10 # TODO: Determine a reasonable value for this\n",
"validation_step = 50 # TODO: Determine a reasonable value for this\n",
"checkpoint_step = 1000 # TODO: Determine a reasonable value for this\n",
"checkpoint_step = 5 # TODO: Determine a reasonable value for this\n",
"checkpoint_dir = tempfile.gettempdir() # TODO: Determine a reasonable value for this"
]
},
@@ -191,14 +192,14 @@
"source": [
"Now we will introduce several constants related to the geometry of the network.\n",
"\n",
"The network views each speech sample as a sequence of time-slices $x^{(i)}_t$ of length $T^{(i)}$. As the speech samples vary in length, we know that $T^{(i)}$ need not equal $T^{(j)}$ for $i \\ne j$. However, BRNN in TensorFlow are unable to deal with sequences with differing lengths. Thus, we must pad speech sample sequences with trailing zeros such that they are all of the same length. This common padded length is captured in the variable `n_steps` which will be set after the data set is loaded. "
"The network views each speech sample as a sequence of time-slices $x^{(i)}_t$ of length $T^{(i)}$. As the speech samples vary in length, we know that $T^{(i)}$ need not equal $T^{(j)}$ for $i \\ne j$. For each batch, BRNN in TensorFlow needs to know `n_steps` which is the maximum $T^{(i)}$ for the batch."
]
},
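To make the per-batch `n_steps` concrete, here is a minimal numpy sketch (not part of the notebook; the sequence lengths 50, 73, 61 and the 26-feature dimension are invented for illustration):

```python
import numpy as np

# Three speech samples of differing length T^(i), each a sequence of
# 26-dimensional MFCC feature vectors.
samples = [np.random.rand(T, 26) for T in (50, 73, 61)]

# For this batch, n_steps is the length of the longest sequence in the batch.
n_steps = max(s.shape[0] for s in samples)  # 73

# Pad each sample with trailing zero time-slices up to n_steps so the batch
# can be stacked into a single [batch, n_steps, n_input] array.
batch_x = np.stack([np.pad(s, ((0, n_steps - s.shape[0]), (0, 0)), mode='constant')
                    for s in samples])
print(batch_x.shape)  # (3, 73, 26)
```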
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each of the `n_steps` vectors is a vector of MFCC features of a time-slice of the speech sample. We will make the number of MFCC features dependent upon the sample rate of the data set. Generically, if the sample rate is 8kHz we use 13 features. If the sample rate is 16kHz we use 26 features... We capture the dimension of these vectors, equivalently the number of MFCC features, in the variable `n_input`"
"Each of the at maximum `n_steps` vectors is a vector of MFCC features of a time-slice of the speech sample. We will make the number of MFCC features dependent upon the sample rate of the data set. Generically, if the sample rate is 8kHz we use 13 features. If the sample rate is 16kHz we use 26 features... We capture the dimension of these vectors, equivalently the number of MFCC features, in the variable `n_input`"
]
},
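As a minimal sketch of the rule just described (the helper `numcep_for_rate` is hypothetical, not a function in this repository):

```python
def numcep_for_rate(sample_rate):
    """Number of MFCC features (n_input) to use for a given sample rate."""
    return 13 if sample_rate <= 8000 else 26

n_input = numcep_for_rate(16000)  # 26 features for 16kHz audio
```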
{
@@ -604,10 +605,13 @@
},
"outputs": [],
"source": [
"def calculate_accuracy_and_loss(n_steps, batch_set):\n",
"def calculate_accuracy_and_loss(batch_set):\n",
" # Obtain the next batch of data\n",
" batch_x, batch_y, batch_seq_len = batch_set.next_batch(batch_size)\n",
" batch_x, batch_y, n_steps = ted_lium.train.next_batch()\n",
"\n",
" # Set batch_seq_len for the batch\n",
" batch_seq_len = batch_x.shape[0] * [n_steps]\n",
" \n",
" # Calculate the logits of the batch using BiRNN\n",
" logits = BiRNN(batch_x, n_steps)\n",
" \n",
@@ -639,14 +643,21 @@
"source": [
"The first lines of `calculate_accuracy_and_loss()`\n",
"```python\n",
"def calculate_accuracy_and_loss(n_steps, batch_set):\n",
"def calculate_accuracy_and_loss(batch_set):\n",
" # Obtain the next batch of data\n",
" batch_x, batch_y, batch_seq_len = batch_set.next_batch(batch_size)\n",
" batch_x, batch_y, n_steps = ted_lium.train.next_batch()\n",
"```\n",
"simply obtian the next mini-batch of data.\n",
"\n",
"The next line\n",
"```python\n",
" # Set batch_seq_len for the batch\n",
" batch_seq_len = batch_x.shape[0] * [n_steps]\n",
"```\n",
"creates `batch_seq_len` a list of the lengths of the sequences in `batch_x`. (As the sequences are zero padded to the same length, the list contains the value `n_steps` a total of `batch_x.shape[0]` times.)\n",
"\n",
"The next line\n",
"```python\n",
" # Calculate the logits from the BiRNN\n",
" logits = BiRNN(batch_x, n_steps)\n",
"```\n",
@@ -863,7 +874,7 @@
},
"outputs": [],
"source": [
"def get_tower_results(n_steps, batch_set, optimizer=None):\n",
"def get_tower_results(batch_set, optimizer=None):\n",
" # Tower decodings to return\n",
" tower_decodings = []\n",
" # Tower labels to return\n",
@@ -879,10 +890,7 @@
" with tf.name_scope('tower_%d' % i) as scope:\n",
" # Calculate the avg_loss and accuracy and retrieve the decoded \n",
" # batch along with the original batch's labels (Y) of this tower\n",
" avg_loss, accuracy, decoded, labels = calculate_accuracy_and_loss(\\\n",
" n_steps, \\\n",
" batch_set \\\n",
" )\n",
" avg_loss, accuracy, decoded, labels = calculate_accuracy_and_loss(batch_set)\n",
" \n",
" # Allow for variables to be re-used by the next tower\n",
" tf.get_variable_scope().reuse_variables()\n",
@@ -1090,17 +1098,8 @@
"outputs": [],
"source": [
"def decode_batch(data_set):\n",
" # Set n_steps parameter\n",
" n_steps = data_set.max_batch_seq_len\n",
"\n",
" # Calculate the total number of batches\n",
" total_batch = int(data_set.num_examples/batch_size)\n",
"\n",
" # Require that we have at least as many batches as devices\n",
" assert total_batch >= len(available_devices)\n",
" \n",
" # Get gradients for each tower (Runs across all GPU's)\n",
" tower_decodings, tower_labels, _, _, _ = get_tower_results(n_steps, data_set)\n",
" tower_decodings, tower_labels, _, _, _ = get_tower_results(data_set)\n",
" return tower_decodings, tower_labels\n",
" "
]
Expand Down Expand Up @@ -1130,8 +1129,8 @@
" # Iterating over the towers\n",
" for i in range(len(tower_decodings)):\n",
" decoded, labels = session.run([tower_decodings[i], tower_labels[i]], feed_dict)\n",
" originals.extend(sparse_tensor_value_to_text(labels))\n",
" results.extend(sparse_tensor_value_to_text(decoded))\n",
" originals.extend(sparse_tensor_value_to_texts(labels))\n",
" results.extend(sparse_tensor_value_to_texts(decoded))\n",
" \n",
" # Pairwise calculation of all rates\n",
" rates, mean = wers(originals, results)\n",
@@ -1186,24 +1185,18 @@
"outputs": [],
"source": [
"def train(session, data_sets):\n",
" # Set n_steps parameter\n",
" n_steps = data_sets.train.max_batch_seq_len\n",
"\n",
" # Calculate the total number of batches\n",
" total_batch = int(data_sets.train.num_examples/batch_size)\n",
"\n",
" # Require that we have at least as many batches as devices\n",
" assert total_batch >= len(available_devices)\n",
"\n",
" total_batches = data_sets.train.total_batches\n",
" \n",
" # Create optimizer\n",
" optimizer = create_optimizer()\n",
"\n",
" # Get gradients for each tower (Runs across all GPU's)\n",
" tower_decodings, tower_labels, tower_gradients, tower_loss, accuracy = \\\n",
" get_tower_results(n_steps, data_sets.train, optimizer)\n",
" get_tower_results(data_sets.train, optimizer)\n",
" \n",
" # Validation step preparation\n",
" validation_tower_decodings, validation_tower_labels = decode_batch(data_sets.validation)\n",
" validation_tower_decodings, validation_tower_labels = decode_batch(data_sets.dev)\n",
"\n",
" # Average tower gradients\n",
" avg_tower_gradients = average_gradients(tower_gradients)\n",
Expand Down Expand Up @@ -1239,22 +1232,22 @@
" print\n",
"\n",
" # Loop over the batches\n",
" for batch in range(total_batch/len(available_devices)):\n",
" for batch in range(int(ceil(float(total_batches)/len(available_devices)))):\n",
" # Compute the average loss for the last batch\n",
" _, batch_avg_loss = session.run([apply_gradient_op, tower_loss], feed_dict_train)\n",
"\n",
" # Add batch to total_accuracy\n",
" total_accuracy += session.run(accuracy, feed_dict_train)\n",
"\n",
" # Log all variable states in current step\n",
" step = epoch * total_batch + batch * len(available_devices)\n",
" step = epoch * total_batches + batch * len(available_devices)\n",
" summary_str = session.run(merged, feed_dict_train)\n",
" writer.add_summary(summary_str, step)\n",
" writer.flush()\n",
" \n",
" # Print progress message\n",
" if epoch % display_step == 0:\n",
" print \"Epoch:\", '%04d' % (epoch+1), \"avg_cer=\", \"{:.9f}\".format((total_accuracy / total_batch))\n",
" print \"Epoch:\", '%04d' % (epoch+1), \"avg_cer=\", \"{:.9f}\".format((total_accuracy / total_batches))\n",
" _, last_train_wer = print_wer_report(session, \"Training\", tower_decodings, tower_labels)\n",
" print\n",
"\n",
@@ -1285,24 +1278,26 @@
},
"outputs": [],
"source": [
"# Create session in which to execute\n",
"session = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True))\n",
"\n",
"# Obtain ted lium data\n",
"ted_lium = read_data_sets('./data/smoke_test', n_input, n_context)\n",
"\n",
"# Take start time for time measurement\n",
"time_started = datetime.datetime.utcnow()\n",
"\n",
"# Train the network\n",
"last_train_wer, last_validation_wer = train(session, ted_lium)\n",
"\n",
"# Take final time for time measurement\n",
"time_finished = datetime.datetime.utcnow()\n",
"\n",
"# Calculate duration in seconds\n",
"duration = time_finished - time_started\n",
"duration = duration.days * 86400 + duration.seconds"
"# Define CPU as device on which the muti-gpu training is orchestrated\n",
"with tf.device('/cpu:0'):\n",
" # Obtain ted lium data\n",
" ted_lium = read_data_sets(tf.get_default_graph(), './data/ted', batch_size, n_input, n_context)\n",
" \n",
" # Create session in which to execute\n",
" session = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True))\n",
" \n",
" # Take start time for time measurement\n",
" time_started = datetime.datetime.utcnow()\n",
" \n",
" # Train the network\n",
" last_train_wer, last_validation_wer = train(session, ted_lium)\n",
" \n",
" # Take final time for time measurement\n",
" time_finished = datetime.datetime.utcnow()\n",
" \n",
" # Calculate duration in seconds\n",
" duration = time_finished - time_started\n",
" duration = duration.days * 86400 + duration.seconds"
]
},
{
@@ -1320,9 +1315,11 @@
},
"outputs": [],
"source": [
"# Test network\n",
"test_decodings, test_labels = decode_batch(ted_lium.test)\n",
"_, test_wer = print_wer_report(session, \"Test\", test_decodings, test_labels)"
"# Define CPU as device on which the muti-gpu testing is orchestrated\n",
"with tf.device('/cpu:0'):\n",
" # Test network\n",
" test_decodings, test_labels = decode_batch(ted_lium.test)\n",
" _, test_wer = print_wer_report(session, \"Test\", test_decodings, test_labels)"
]
},
{
Expand Down Expand Up @@ -1374,9 +1371,9 @@
" 'n_hidden_6': n_hidden_6, \\\n",
" 'n_cell_dim': n_cell_dim, \\\n",
" 'n_character': n_character, \\\n",
" 'num_examples_train': ted_lium.train.num_examples, \\\n",
" 'num_examples_validation': ted_lium.validation.num_examples, \\\n",
" 'num_examples_test': ted_lium.test.num_examples \\\n",
" 'total_batches_train': ted_lium.train.total_batches, \\\n",
" 'total_batches_validation': ted_lium.validation.total_batches, \\\n",
" 'total_batches_test': ted_lium.test.total_batches \\\n",
" }, \\\n",
" 'results': { \\\n",
" 'duration': duration, \\\n",
@@ -1422,7 +1419,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
"version": "2.7.11"
}
},
"nbformat": 4,
Empty file added data/ted/.gitkeep
Empty file.
68 changes: 68 additions & 0 deletions util/audio.py
@@ -0,0 +1,68 @@
import numpy as np
import scipy.io.wavfile as wav

from python_speech_features import mfcc

def audiofile_to_input_vector(audio_filename, numcep, numcontext):
    # Load wav files
    fs, audio = wav.read(audio_filename)

    # Get mfcc coefficients
    orig_inputs = mfcc(audio, samplerate=fs, numcep=numcep)

    # For each time slice of the training set, we need to copy the context around it;
    # this turns the numcep-dimensional feature vector into a vector of
    # numcep + 2*numcep*numcontext dimensions, because of:
    # - numcep dimensions for the current mfcc feature set
    # - numcontext*numcep dimensions for each of the past and future (x2) mfcc feature sets
    # => so numcep + 2*numcontext*numcep
    train_inputs = np.array([], np.float32)
    train_inputs.resize((orig_inputs.shape[0], numcep + 2*numcep*numcontext))

    # Prepare prefix/postfix context (TODO: fill empty_mfcc with the MFCC of silence)
    empty_mfcc = np.array([])
    empty_mfcc.resize((numcep))

    # Prepare train_inputs with past and future contexts
    time_slices = range(train_inputs.shape[0])
    context_past_min = time_slices[0] + numcontext
    context_future_max = time_slices[-1] - numcontext
    for time_slice in time_slices:
        ### Reminder: array[start:stop:step]
        ### slices from index |start| up to |stop| (not included), every |step|
        # Pick up to numcontext time slices in the past, and complete with empty
        # mfcc features
        need_empty_past = max(0, (context_past_min - time_slice))
        empty_source_past = list(empty_mfcc for empty_slots in range(need_empty_past))
        data_source_past = orig_inputs[max(0, time_slice - numcontext):time_slice]
        assert(len(empty_source_past) + len(data_source_past) == numcontext)

        # Pick up to numcontext time slices in the future, and complete with empty
        # mfcc features
        need_empty_future = max(0, (time_slice - context_future_max))
        empty_source_future = list(empty_mfcc for empty_slots in range(need_empty_future))
        data_source_future = orig_inputs[time_slice + 1:time_slice + numcontext + 1]
        assert(len(empty_source_future) + len(data_source_future) == numcontext)

        if need_empty_past:
            past = np.concatenate((empty_source_past, data_source_past))
        else:
            past = data_source_past

        if need_empty_future:
            future = np.concatenate((data_source_future, empty_source_future))
        else:
            future = data_source_future

        past = np.reshape(past, numcontext*numcep)
        now = orig_inputs[time_slice]
        future = np.reshape(future, numcontext*numcep)

        train_inputs[time_slice] = np.concatenate((past, now, future))
        assert(len(train_inputs[time_slice]) == numcep + 2*numcep*numcontext)

    # Whiten inputs (TODO: should we whiten?)
    train_inputs = (train_inputs - np.mean(train_inputs))/np.std(train_inputs)

    # Return results
    return train_inputs
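A short usage sketch (run from the repository root so `util` is importable; the WAV filename is hypothetical, and `numcep=26`, `numcontext=9` are example values chosen to match the 16kHz/26-feature setup described in the notebook):

```python
from util.audio import audiofile_to_input_vector

numcep, numcontext = 26, 9
features = audiofile_to_input_vector('sample.wav', numcep, numcontext)

# One row per time-slice; each row holds the current MFCC vector plus
# numcontext past and numcontext future vectors:
# 26 + 2*26*9 = 494 dimensions.
print(features.shape)  # (number_of_time_slices, 494)
```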
