[src] batch renormalization finished #65

Status: Open

Wants to merge 30 commits into base: svd_draft

Commits (30):
37091d6
[egs] Update Librispeech RNNLM results; use correct training data (#2…
keli78 Dec 6, 2018
b50a4cf
[scripts] RNNLM: old iteration model cleanup; save space (#2885)
slckl Dec 7, 2018
a464bd7
[scripts] Make prepare_lang.sh cleanup beforehand (prevents certain f…
danpovey Dec 11, 2018
c41cbb1
[scripts] Expose dim-range-node at xconfig level (#2903)
yangxueruivs Dec 11, 2018
aa0ac7b
[scripts] Fix bug related to multi-task in train_raw_rnn.py (#2907)
danpovey Dec 12, 2018
3e50be9
[scripts] Cosmetic fix/clarification to utils/prepare_lang.sh (#2912)
danpovey Dec 12, 2018
791cd82
[scripts,egs] Added a new lexicon learning (adaptation) recipe for te…
xiaohui-zhang Dec 14, 2018
b126161
[egs] TDNN+LSTM example scripts, with RNNLM, for Librispeech (#2857)
GaofengCheng Dec 15, 2018
78f0127
[src] cosmetic fix in nnet1 code (#2921)
csukuangfj Dec 17, 2018
44980dd
[src] Fix incorrect invocation of mutex in nnet-batch-compute code (#…
danpovey Dec 21, 2018
a46f554
[egs,minor] Fix typo in comment in voxceleb script (#2926)
corollari Dec 23, 2018
2edb074
[src,egs] Mostly cosmetic changes; add some missing includes (#2936)
yzmyyff Dec 24, 2018
9b320ad
[egs] Fix path of rescoring binaries used in tfrnnlm scripts (#2941)
virenderkadyan Dec 27, 2018
3b0162b
[src] Fix bug in nnet3-latgen-faster-batch for determinize=false (#2945)
danpovey Dec 28, 2018
b984543
[egs] Add example for rimes handwriting database; Madcat arabic scrip…
aarora8 Dec 28, 2018
46826d9
[egs] Add scripts for yomdle korean (#2942)
aarora8 Dec 28, 2018
3e77220
[build] Refactor/cleanup build system, easier build on ubuntu 18.04. …
danpovey Dec 31, 2018
5a720ac
[scripts,egs] Changes for Python 2/3 compatibility (#2925)
desh2608 Dec 31, 2018
ca32c4e
[egs] Add more modern DNN recipe for fisher_callhome_spanish (#2951)
GoVivace Dec 31, 2018
1ea2ba7
[scripts] switch from bc to perl to reduce dependencies (diarization …
mmaciej2 Jan 1, 2019
969869c
[scripts] Further fix for Python 2/3 compatibility (#2957)
desh2608 Jan 2, 2019
8b800b5
First commit
GaofengCheng Jan 2, 2019
1f5c6eb
add backprop
GaofengCheng Jan 3, 2019
5fd0e8c
update
GaofengCheng Jan 3, 2019
0696624
Batch-renorm OK
GaofengCheng Jan 5, 2019
902540e
clean-up
GaofengCheng Jan 5, 2019
5162bd7
Update nnet-normalize-component.cc
GaofengCheng Jan 7, 2019
ca3ff04
Revert "Update nnet-normalize-component.cc"
GaofengCheng Jan 7, 2019
1732ae7
tempt-test
GaofengCheng Jan 8, 2019
bfe09be
add stats average for batch-renorm
GaofengCheng Jan 9, 2019
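The work this PR adds lives in the last nine commits above, from "First commit" through "add stats average for batch-renorm": a batch-renormalization option for nnet3's normalization components (nnet-normalize-component.cc). The earlier commits are upstream Kaldi changes merged into the branch. For context, here is a minimal NumPy sketch of the batch-renorm forward pass from Ioffe (2017), "Batch Renormalization"; it illustrates the technique only, is not the Kaldi component's code, and the rmax/dmax defaults are illustrative:

    import numpy as np

    def batch_renorm_forward(x, running_mean, running_std, rmax=3.0, dmax=5.0, eps=1e-5):
        # x: (N, D) minibatch; running_mean / running_std: (D,) moving statistics.
        mu_b = x.mean(axis=0)
        sigma_b = np.sqrt(x.var(axis=0) + eps)
        # r and d correct the minibatch statistics toward the running ones;
        # both are clipped and treated as constants during backpropagation.
        r = np.clip(sigma_b / running_std, 1.0 / rmax, rmax)
        d = np.clip((mu_b - running_mean) / running_std, -dmax, dmax)
        return (x - mu_b) / sigma_b * r + d

    x = np.random.randn(16, 8)
    y = batch_renorm_forward(x, running_mean=np.zeros(8), running_std=np.ones(8))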
1 change: 1 addition & 0 deletions egs/aishell2/s5/local/word_segmentation.py
@@ -4,6 +4,7 @@
# 2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU)
# Apache 2.0

+from __future__ import print_function
import sys
import jieba
reload(sys)
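Most of the diff below is mechanical Python 2/3 compatibility work, and the one-line change above is the most common pattern: from __future__ import print_function makes print a function under Python 2, so the Python 3 call syntax behaves the same on both interpreters. A minimal sketch (the messages are illustrative only):

    from __future__ import print_function  # no-op on Python 3
    import sys

    # Without the import, Python 2 parses print as a statement and
    # "print(a, b)" prints a tuple; with it, both lines below behave
    # identically on Python 2 and Python 3.
    print("segmented", "output")
    print("warning: empty line", file=sys.stderr)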
5 changes: 3 additions & 2 deletions egs/ami/s5/local/sort_bad_utts.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python

+from __future__ import print_function
import sys
import argparse
import logging
@@ -38,10 +39,10 @@ def GetSortedWers(utt_info_file):
utt_wer_sorted = sorted(utt_wer, key = lambda k : k[1])
try:
import numpy as np
-bins = range(0,105,5)
+bins = list(range(0,105,5))
bins.append(sys.float_info.max)

-hist, bin_edges = np.histogram(map(lambda x: x[1], utt_wer_sorted),
+hist, bin_edges = np.histogram([x[1] for x in utt_wer_sorted],
bins = bins)
num_utts = len(utt_wer)
string = ''
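The hunk above shows the second recurring pattern: on Python 3, range() and map() return lazy objects rather than lists, so code that mutates the result or hands it to an API expecting a sequence breaks. A short illustration of why the rewrite is needed (the WER pairs are hypothetical):

    import sys

    bins = list(range(0, 105, 5))   # a bare range(...) has no .append on Python 3
    bins.append(sys.float_info.max)

    utt_wer_sorted = [("utt1", 12.5), ("utt2", 40.0)]   # hypothetical data
    wers = [x[1] for x in utt_wer_sorted]               # eager list, safe to pass to numpy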
2 changes: 1 addition & 1 deletion egs/ami/s5/local/tfrnnlm/run_lstm.sh
@@ -39,7 +39,7 @@ if [ $stage -le 3 ]; then
decode_dir=${basedir}/decode_${decode_set}

# Lattice rescoring
-steps/lmrescore_rnnlm_lat.sh \
+steps/tfrnnlm/lmrescore_rnnlm_lat.sh \
--cmd "$tfrnnlm_cmd --mem 16G" \
--rnnlm-ver tensorflow --weight $weight --max-ngram-order $ngram_order \
data/lang_$LM $dir \
2 changes: 1 addition & 1 deletion egs/ami/s5/local/tfrnnlm/run_vanilla_rnnlm.sh
@@ -39,7 +39,7 @@ if [ $stage -le 3 ]; then
decode_dir=${basedir}/decode_${decode_set}

# Lattice rescoring
-steps/lmrescore_rnnlm_lat.sh \
+steps/tfrnnlm/lmrescore_rnnlm_lat.sh \
--cmd "$tfrnnlm_cmd --mem 16G" \
--rnnlm-ver tensorflow --weight $weight --max-ngram-order $ngram_order \
data/lang_$LM $dir \
1 change: 1 addition & 0 deletions egs/an4/s5/local/data_prep.py
@@ -15,6 +15,7 @@
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.

+from __future__ import print_function
import os
import re
import sys
1 change: 1 addition & 0 deletions egs/an4/s5/local/lexicon_prep.py
@@ -15,6 +15,7 @@
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.

+from __future__ import print_function
import os
import re
import sys
@@ -4,13 +4,14 @@
# creates a segments file in the provided data directory
# into uniform segments with specified window and overlap

+from __future__ import division
import imp, sys, argparse, os, math, subprocess

min_segment_length = 10 # in seconds
def segment(total_length, window_length, overlap = 0):
increment = window_length - overlap
num_windows = int(math.ceil(float(total_length)/increment))
-segments = map(lambda x: (x * increment, min( total_length, (x * increment) + window_length)), range(0, num_windows))
+segments = [(x * increment, min( total_length, (x * increment) + window_length)) for x in range(0, num_windows)]
if segments[-1][1] - segments[-1][0] < min_segment_length:
segments[-2] = (segments[-2][0], segments[-1][1])
segments.pop()
@@ -53,7 +54,7 @@ def prepare_segments_file(kaldi_data_dir, window_length, overlap):
parser = argparse.ArgumentParser()
parser.add_argument('--window-length', type = float, default = 30.0, help = 'length of the window used to cut the segment')
parser.add_argument('--overlap', type = float, default = 5.0, help = 'overlap of neighboring windows')
-parser.add_argument('data_dir', type=str, help='directory such as data/train')
+parser.add_argument('data_dir', help='directory such as data/train')

params = parser.parse_args()

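This script picks up the third recurring pattern: from __future__ import division gives Python 2 the Python 3 semantics of /, where integer operands produce a float (true division) and // is the explicit floor. A two-line illustration:

    from __future__ import division  # no-op on Python 3

    print(7 / 2)    # 3.5 on both interpreters; plain Python 2 would print 3
    print(7 // 2)   # 3: floor division, for when an integer is actually wanted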
@@ -38,14 +38,14 @@ def fill_ctm(input_ctm_file, output_ctm_file, recording_names):

sys.stderr.write(str(" ".join(sys.argv)))
parser = argparse.ArgumentParser(usage)
-parser.add_argument('input_ctm_file', type=str, help='ctm file for the recordings')
-parser.add_argument('output_ctm_file', type=str, help='ctm file for the recordings')
-parser.add_argument('recording_name_file', type=str, help='file with names of the recordings')
+parser.add_argument('input_ctm_file', help='ctm file for the recordings')
+parser.add_argument('output_ctm_file', help='ctm file for the recordings')
+parser.add_argument('recording_name_file', help='file with names of the recordings')

params = parser.parse_args()

try:
-file_names = map(lambda x: x.strip(), open("{0}".format(params.recording_name_file)).readlines())
+file_names = [x.strip() for x in open("{0}".format(params.recording_name_file)).readlines()]
except IOError:
raise Exception("Expected to find {0}".format(params.recording_name_file))

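Dropping type=str from add_argument, as in the hunk above, is safe cleanup rather than a 2/3 fix: argparse delivers arguments as str by default. A quick sketch confirming the behavior (the file name is a hypothetical value):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('recording_name_file', help='file with names of the recordings')
    args = parser.parse_args(['recordings.txt'])
    assert isinstance(args.recording_name_file, str)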
3 changes: 2 additions & 1 deletion egs/aspire/s5/local/multi_condition/get_air_file_patterns.py
@@ -3,6 +3,7 @@

# script to generate the file_patterns of the AIR database
# see load_air.m file in AIR db to understand the naming convention
+from __future__ import print_function
import sys, glob, re, os.path

air_dir = sys.argv[1]
@@ -45,4 +46,4 @@
file_patterns.append(file_pattern+" "+output_file_name)
file_patterns = list(set(file_patterns))
file_patterns.sort()
-print "\n".join(file_patterns)
+print("\n".join(file_patterns))
8 changes: 5 additions & 3 deletions egs/aspire/s5/local/multi_condition/normalize_wavs.py
@@ -3,6 +3,8 @@

# normalizes the wave files provided in input file list with a common scaling factor
# the common scaling factor is computed to 1/\sqrt(1/(total_samples) * \sum_i{\sum_j x_i(j)^2}) where total_samples is sum of all samples of all wavefiles. If the data is multi-channel then each channel is treated as a seperate wave files
+from __future__ import division
+from __future__ import print_function
import argparse, scipy.io.wavfile, warnings, numpy as np, math

def get_normalization_coefficient(file_list, is_rir, additional_scaling):
@@ -29,7 +31,7 @@ def get_normalization_coefficient(file_list, is_rir, additional_scaling):
assert(rate == sampling_rate)
else:
sampling_rate = rate
-data = data / dtype_max_value
+data = data/dtype_max_value
if is_rir:
# just count the energy of the direct impulse response
# this is treated as energy of signal from 0.001 seconds before impulse
@@ -55,8 +57,8 @@ def get_normalization_coefficient(file_list, is_rir, additional_scaling):
except IOError:
warnings.warn("Did not find the file {0}.".format(file))
assert(total_samples > 0)
-scaling_coefficient = np.sqrt(total_samples / total_energy)
-print "Scaling coefficient is {0}.".format(scaling_coefficient)
+scaling_coefficient = np.sqrt(total_samples/total_energy)
+print("Scaling coefficient is {0}.".format(scaling_coefficient))
if math.isnan(scaling_coefficient):
raise Exception(" Nan encountered while computing scaling coefficient. This is mostly due to numerical overflow")
return scaling_coefficient
6 changes: 3 additions & 3 deletions egs/aspire/s5/local/multi_condition/read_rir.py
@@ -29,9 +29,9 @@ def usage():
#sys.stderr.write(" ".join(sys.argv)+"\n")
parser = argparse.ArgumentParser(usage())
parser.add_argument('--output-sampling-rate', type = int, default = 8000, help = 'sampling rate of the output')
-parser.add_argument('type', type = str, default = None, help = 'database type', choices = ['air'])
-parser.add_argument('input', type = str, default = None, help = 'directory containing the multi-channel data for a particular recording, or file name or file-regex-pattern')
-parser.add_argument('output_filename', type = str, default = None, help = 'output filename (if "-" then output is written to output pipe)')
+parser.add_argument('type', default = None, help = 'database type', choices = ['air'])
+parser.add_argument('input', default = None, help = 'directory containing the multi-channel data for a particular recording, or file name or file-regex-pattern')
+parser.add_argument('output_filename', default = None, help = 'output filename (if "-" then output is written to output pipe)')
params = parser.parse_args()

if params.output_filename == "-":
14 changes: 8 additions & 6 deletions egs/aspire/s5/local/multi_condition/reverberate_wavs.py
@@ -4,18 +4,20 @@
# script to generate multicondition training data / dev data / test data
import argparse, glob, math, os, random, scipy.io.wavfile, sys

-class list_cyclic_iterator:
+class list_cyclic_iterator(object):
def __init__(self, list, random_seed = 0):
self.list_index = 0
self.list = list
random.seed(random_seed)
random.shuffle(self.list)

-def next(self):
+def __next__(self):
item = self.list[self.list_index]
self.list_index = (self.list_index + 1) % len(self.list)
return item

+next = __next__ # for Python 2

def return_nonempty_lines(lines):
new_lines = []
for line in lines:
@@ -71,15 +73,15 @@ def return_nonempty_lines(lines):
for i in range(len(wav_files)):
wav_file = " ".join(wav_files[i].split()[1:])
output_wav_file = wav_out_files[i]
-impulse_file = impulses.next()
+impulse_file = next(impulses)
noise_file = ''
snr = ''
found_impulse = False
if add_noise:
-for i in xrange(len(impulse_noise_index)):
+for i in range(len(impulse_noise_index)):
if impulse_file in impulse_noise_index[i][0]:
-noise_file = impulse_noise_index[i][1].next()
-snr = snrs.next()
+noise_file = next(impulse_noise_index[i][1])
+snr = next(snrs)
assert(len(wav_file.strip()) > 0)
assert(len(impulse_file.strip()) > 0)
assert(len(noise_file.strip()) > 0)
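The changes above modernize the iterator protocol: Python 3 renamed the iterator method next() to __next__(), and the builtin next(obj) calls whichever is defined, so call sites switch from obj.next() to next(obj) while the class keeps a next = __next__ alias for any remaining Python 2 callers. A trimmed sketch of the pattern (seeding and shuffling omitted, file names hypothetical):

    class list_cyclic_iterator(object):
        def __init__(self, items):
            self.items = list(items)
            self.index = 0

        def __next__(self):                # the Python 3 iterator protocol method
            item = self.items[self.index]
            self.index = (self.index + 1) % len(self.items)
            return item

        next = __next__                    # keeps obj.next() working on Python 2

    impulses = list_cyclic_iterator(['rir_a.wav', 'rir_b.wav'])
    print(next(impulses), next(impulses), next(impulses))   # rir_a.wav rir_b.wav rir_a.wav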
19 changes: 10 additions & 9 deletions egs/babel/s5b/local/lonestar.py
@@ -1,4 +1,5 @@
#!/usr/bin/env python
+from __future__ import print_function
from pylauncher import *
import pylauncher
import sys
@@ -39,7 +40,7 @@ def KaldiLauncher(lo, **kwargs):

logfiles = list()
commands = list()
-for q in xrange(lo.jobstart, lo.jobend+1):
+for q in range(lo.jobstart, lo.jobend+1):
s = "bash " + lo.queue_scriptfile + " " + str(q)
commands.append(s)

@@ -74,7 +75,7 @@ def KaldiLauncher(lo, **kwargs):
time.sleep(delay);

lines=tail(10, logfile)
-with_status=filter(lambda x:re.search(r'with status (\d+)', x), lines)
+with_status=[x for x in lines if re.search(r'with status (\d+)', x)]

if len(with_status) == 0:
sys.stderr.write("The last line(s) of the log-file " + logfile + " does not seem"
@@ -98,7 +99,7 @@ def KaldiLauncher(lo, **kwargs):
sys.exit(-1);

#Remove service files. Be careful not to remove something that might be needed in problem diagnostics
-for i in xrange(len(commands)):
+for i in range(len(commands)):
out_file=os.path.join(qdir, ce.outstring+str(i))

#First, let's wait on files missing (it might be that those are missing
@@ -149,7 +150,7 @@ def KaldiLauncher(lo, **kwargs):

#print job.final_report()

-class LauncherOpts:
+class LauncherOpts(object):
def __init__(self):
self.sync=0
self.nof_threads = 1
@@ -199,7 +200,7 @@ def CmdLineParser(argv):
jobend=int(m.group(2))
argv.pop(0)
elif re.match("^.+=.*:.*$", argv[0]):
-print >> sys.stderr, "warning: suspicious JOB argument " + argv[0];
+print("warning: suspicious JOB argument " + argv[0], file=sys.stderr);

if jobstart > jobend:
sys.stderr.write("lonestar.py: JOBSTART("+ str(jobstart) + ") must be lower than JOBEND(" + str(jobend) + ")\n")
@@ -238,8 +239,8 @@ def setup_paths_and_vars(opts):
cwd = os.getcwd()

if opts.varname and (opts.varname not in opts.logfile ) and (opts.jobstart != opts.jobend):
-print >>sys.stderr, "lonestar.py: you are trying to run a parallel job" \
-"but you are putting the output into just one log file (" + opts.logfile + ")";
+print("lonestar.py: you are trying to run a parallel job" \
+"but you are putting the output into just one log file (" + opts.logfile + ")", file=sys.stderr);
sys.exit(1)

if not os.path.isabs(opts.logfile):
@@ -261,8 +262,8 @@ def setup_paths_and_vars(opts):
taskname=os.path.basename(queue_logfile)
taskname = taskname.replace(".log", "");
if taskname == "":
-print >> sys.stderr, "lonestar.py: you specified the log file name in such form " \
-"that leads to an empty task name ("+logfile + ")";
+print("lonestar.py: you specified the log file name in such form " \
+"that leads to an empty task name ("+logfile + ")", file=sys.stderr);
sys.exit(1)

if not os.path.isabs(queue_logfile):
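lonestar.py collects two more patterns in one file. Python 2's "print >> sys.stderr, msg" redirection is a syntax error on Python 3 and becomes the file= keyword argument, and a bare "class C:" (an old-style class on Python 2) becomes "class C(object):" so both interpreters see new-style class semantics. A minimal sketch:

    from __future__ import print_function
    import sys

    print("lonestar.py: suspicious JOB argument", file=sys.stderr)

    class LauncherOpts(object):   # new-style on Python 2, identical on Python 3
        def __init__(self):
            self.sync = 0
            self.nof_threads = 1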
29 changes: 15 additions & 14 deletions egs/babel/s5b/local/resegment/segmentation.py
@@ -3,6 +3,7 @@
# Copyright 2014 Vimal Manohar
# Apache 2.0

+from __future__ import division
import os, glob, argparse, sys, re, time
from argparse import ArgumentParser

@@ -19,12 +20,12 @@

def mean(l):
if len(l) > 0:
-return float(sum(l)) / len(l)
+return (float(sum(l))/len(l))
return 0

# Analysis class
# Stores statistics like the confusion matrix, length of the segments etc.
-class Analysis:
+class Analysis(object):
def __init__(self, file_id, frame_shift, prefix):
self.confusion_matrix = [0] * 9
self.type_counts = [ [[] for j in range(0,9)] for i in range(0,3) ]
@@ -274,8 +275,8 @@ def read_rttm_file(rttm_file, temp_dir, frame_shift):
i = len(this_file)
category = splits[6]
word = splits[5]
-start_time = int(float(splits[3])/frame_shift + 0.5)
-duration = int(float(splits[4])/frame_shift + 0.5)
+start_time = int((float(splits[3])/frame_shift) + 0.5)
+duration = int((float(splits[4])/frame_shift) + 0.5)
if i < start_time:
this_file.extend(["0"]*(start_time - i))
if type1 == "NON-LEX":
@@ -295,7 +296,7 @@ def read_rttm_file(rttm_file, temp_dir, frame_shift):
# Stats class to store some basic stats about the number of
# times the post-processor goes through particular loops or blocks
# of code in the algorithm. This is just for debugging.
-class Stats:
+class Stats(object):
def __init__(self):
self.inter_utt_nonspeech = 0
self.merge_nonspeech_segment = 0
@@ -321,7 +322,7 @@ def reset(self):
self.noise_only = 0

# Timer class to time functions
-class Timer:
+class Timer(object):
def __enter__(self):
self.start = time.clock()
return self
@@ -332,7 +333,7 @@ def __exit__(self, *args):
# The main class for post-processing a file.
# This does the segmentation either looking at the file isolated
# or by looking at both classes simultaneously
-class JointResegmenter:
+class JointResegmenter(object):
def __init__(self, P, A, f, options, phone_map, stats = None, reference = None):

# Pointers to prediction arrays and Initialization
@@ -1290,22 +1291,22 @@ def main():
dest='hard_max_segment_length', default=15.0, \
help="Hard maximum on the segment length above which the segment " \
+ "will be broken even if in the middle of speech (default: %(default)s)")
-parser.add_argument('--first-separator', type=str, \
+parser.add_argument('--first-separator', \
dest='first_separator', default="-", \
help="Separator between recording-id and start-time (default: %(default)s)")
-parser.add_argument('--second-separator', type=str, \
+parser.add_argument('--second-separator', \
dest='second_separator', default="-", \
help="Separator between start-time and end-time (default: %(default)s)")
-parser.add_argument('--remove-noise-only-segments', type=str, \
+parser.add_argument('--remove-noise-only-segments', \
dest='remove_noise_only_segments', default="true", choices=("true", "false"), \
help="Remove segments that have only noise. (default: %(default)s)")
parser.add_argument('--min-inter-utt-silence-length', type=float, \
dest='min_inter_utt_silence_length', default=1.0, \
help="Minimum silence that must exist between two separate utterances (default: %(default)s)");
-parser.add_argument('--channel1-file', type=str, \
+parser.add_argument('--channel1-file', \
dest='channel1_file', default="inLine", \
help="String that matches with the channel 1 file (default: %(default)s)")
-parser.add_argument('--channel2-file', type=str, \
+parser.add_argument('--channel2-file', \
dest='channel2_file', default="outLine", \
help="String that matches with the channel 2 file (default: %(default)s)")
parser.add_argument('--isolated-resegmentation', \
@@ -1388,7 +1389,7 @@ def main():

speech_cap = None
if options.speech_cap_length != None:
-speech_cap = int( options.speech_cap_length / options.frame_shift )
+speech_cap = int(options.speech_cap_length/options.frame_shift)
# End if

for f in pred_files:
@@ -1454,7 +1455,7 @@ def main():
f2 = f3
# End if

-if (len(A1) - len(A2)) > options.max_length_diff / options.frame_shift:
+if (len(A1) - len(A2)) > int(options.max_length_diff/options.frame_shift):
sys.stderr.write( \
"%s: Warning: Lengths of %s and %s differ by more than %f. " \
% (sys.argv[0], f1,f2, options.max_length_diff) \
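segmentation.py shows where true division actually changes behavior: quantities in seconds are converted to frame counts, and with / now returning floats the code wraps each conversion in int() (with + 0.5 where round-to-nearest is intended) instead of relying on Python 2's silent integer truncation. A small worked example:

    from __future__ import division

    frame_shift = 0.01                                    # seconds per frame
    start_frame = int((2.0 / frame_shift) + 0.5)          # round to nearest frame -> 200
    max_length_diff = 1.0                                 # seconds
    max_diff_frames = int(max_length_diff / frame_shift)  # whole frames -> 100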