Commit da648fd
fix some RDLM training options
rsennrich committed Apr 27, 2015
1 parent ce55bc4 commit da648fd
Showing 4 changed files with 14 additions and 16 deletions.
scripts/training/rdlm/README (2 additions, 2 deletions)

@@ -31,8 +31,8 @@ RDLM is split into two neural network models, which can be trained with

 mkdir working_dir_head
 mkdir working_dir_label
-./train_rdlm.py --nplm-home /path/to/nplm --working-dir working_dir_head --output-dir /path/to/output_directory --output-model rdlm_head --mode head --output-vocab-size 500000 --noise-samples 100
-./train_rdlm.py --nplm-home /path/to/nplm --working-dir working_dir_label --output-dir /path/to/output_directory --output-model rdlm_label --mode label --output-vocab-size 75 --noise-samples 50
+./train_rdlm.py --nplm-home /path/to/nplm --corpus [your_training_corpus] --working-dir working_dir_head --output-dir /path/to/output_directory --output-model rdlm_head --mode head --output-vocab-size 500000 --noise 100
+./train_rdlm.py --nplm-home /path/to/nplm --corpus [your_training_corpus] --working-dir working_dir_label --output-dir /path/to/output_directory --output-model rdlm_label --mode label --output-vocab-size 75 --noise 50

 for more options, run `train_rdlm.py --help`. Parameters you may want to adjust
 include the vocabulary size of the label model (depending on the number of
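The corrected commands use options that train_rdlm.py actually defines: --corpus (stored as corpus_stem, used further down) and --noise, which replaces the nonexistent --noise-samples. As a rough sketch, the corresponding argparse declarations might look like this; only --output-vocab-size is visible in this commit's diff, so the other two declarations are assumptions inferred from the README usage:

    # Sketch of the train_rdlm.py options the fixed README commands rely on.
    # Only --output-vocab-size appears in this diff; --corpus and --noise are
    # assumed declarations, reconstructed from the README invocations.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--corpus", dest="corpus_stem", metavar="PATH",
                        help="training corpus (one parse tree per line)")
    parser.add_argument("--noise", dest="noise", type=int, metavar="INT",
                        help="number of noise samples for noise-contrastive estimation")
    parser.add_argument("--output-vocab-size", dest="output_vocab_size", type=int,
                        metavar="INT", help="output vocabulary size (default: %(default)s)")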
scripts/training/rdlm/extract_syntactic_ngrams.py (3 additions, 2 deletions)

@@ -113,13 +113,14 @@ def get_syntactic_ngrams(xml, options, vocab, output_vocab, parent_heads=None, parent_labels=None):
         int_list.extend(parent_heads)
         int_list.extend(parent_labels)
 
+        # write root of tree
         if options.mode == 'label':
             int_list.append(output_vocab.get(label, 0))
-            sys.stdout.write(' '.join(map(str, int_list)) + '\n')
+            options.output.write(' '.join(map(str, int_list)) + '\n')
         elif options.mode == 'head' and not head == '<dummy_head>':
             int_list.append(vocab.get(label, 0))
             int_list.append(output_vocab.get(head, output_vocab.get(preterminal, 0)))
-            sys.stdout.write(' '.join(map(str, int_list)) + '\n')
+            options.output.write(' '.join(map(str, int_list)) + '\n')
 
         parent_heads.append(vocab.get(head, 0))
         parent_labels.append(vocab.get(label, 0))
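The substantive change in this file is that get_syntactic_ngrams() now writes to options.output instead of unconditionally to sys.stdout, so a caller can direct the numberized n-grams to any file object; the train_rdlm.py hunks below rely on exactly that for the validation corpus. A minimal sketch of such a call, assuming input/output/mode are the relevant Namespace fields (a real call needs the remaining options, e.g. vocabulary files and context sizes):

    # Hypothetical driver: route the extractor's output to a file, not stdout.
    # Only the input/output fields are grounded in this commit's diff; the
    # remaining options a real call would need are omitted here.
    import argparse
    import extract_syntactic_ngrams

    extract_options = argparse.Namespace(
        input=open('corpus.trees'),             # parsed corpus, one tree per line
        output=open('corpus.numberized', 'w'),  # previously hard-wired to sys.stdout
        mode='head',                            # 'head' or 'label', as in the README
    )
    extract_syntactic_ngrams.main(extract_options)
    extract_options.output.close()              # flush before anything reads the file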
scripts/training/rdlm/extract_vocab.py (0 additions, 4 deletions)

@@ -59,10 +59,6 @@ def get_head(xml, args):
             preterminal = child.get('label')
             head = escape_text(child.text.strip())
 
-        # hack for split compounds
-        elif child[-1].get('label') == 'SEGMENT':
-            return escape_text(child[-1].text.strip()), 'SEGMENT'
-
         elif args.ptkvz and head and child.get('label') == 'avz':
             for grandchild in child:
                 if grandchild.get('label') == 'PTKVZ':
scripts/training/rdlm/train_rdlm.py (9 additions, 8 deletions)

@@ -43,7 +43,7 @@
parser.add_argument("--input-words-file", dest="input_words_file", metavar="PATH", help="input vocabulary (default: %(default)s)")
parser.add_argument("--output-words-file", dest="output_words_file", metavar="PATH", help="output vocabulary (default: %(default)s)")
parser.add_argument("--input_vocab_size", dest="input_vocab_size", type=int, metavar="INT", help="input vocabulary size (default: %(default)s)")
parser.add_argument("--output_vocab_size", dest="output_vocab_size", type=int, metavar="INT", help="output vocabulary size (default: %(default)s)")
parser.add_argument("--output-vocab-size", dest="output_vocab_size", type=int, metavar="INT", help="output vocabulary size (default: %(default)s)")


parser.set_defaults(
@@ -95,7 +95,7 @@ def prepare_vocabulary(options):
     filtered_vocab = open(orig).readlines()
     orig = vocab_prefix + '.nonterminals'
     filtered_vocab += open(orig).readlines()
-    filtered_vocab = [word for word in filtered_vocab if not word.startswith(prefix) for prefix in blacklist]
+    filtered_vocab = [word for word in filtered_vocab if not any(word.startswith(prefix) for prefix in blacklist)]
     if options.output_vocab_size:
         filtered_vocab = filtered_vocab[:options.output_vocab_size]
     else:
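The replaced comprehension was broken Python: comprehension clauses evaluate left to right, so the if test reads prefix before the trailing `for prefix in blacklist` binds it, raising a NameError (or silently reusing a stale variable from an enclosing scope). Wrapping the test in any() expresses the intended filter: drop every word that starts with one of the blacklisted prefixes. A self-contained illustration (the vocabulary and blacklist here are made up; the real blacklist is not shown in this diff):

    # The fixed filter on made-up data.
    blacklist = ['<null', '<root']
    filtered_vocab = ['the\n', '<null_head>\n', 'NP\n', '<root_label>\n']

    filtered_vocab = [word for word in filtered_vocab
                      if not any(word.startswith(prefix) for prefix in blacklist)]
    print(filtered_vocab)  # ['the\n', 'NP\n']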
@@ -127,12 +127,13 @@ def main(options):
     sys.stderr.write('extracting syntactic n-grams\n')
     extract_syntactic_ngrams.main(extract_options)
 
-    if validation_corpus:
-        extract_options.input = options.validation_corpus
-        options.validation_file = os.path.join(options.working_dir, os.path.basename(options.validation_corpus) + '.numberized')
-        extract_options.output = options.validation_file
+    if options.validation_corpus:
+        extract_options.input = open(options.validation_corpus)
+        options.validation_file = os.path.join(options.working_dir, os.path.basename(options.validation_corpus))
+        extract_options.output = open(options.validation_file + '.numberized', 'w')
         sys.stderr.write('extracting syntactic n-grams (validation file)\n')
         extract_syntactic_ngrams.main(extract_options)
+        extract_options.output.close()
 
     sys.stderr.write('training neural network\n')
     train_nplm.main(options)
@@ -141,8 +141,8 @@ def main(options):
     ret = subprocess.call([os.path.join(sys.path[0], 'average_null_embedding.py'),
                            options.nplm_home,
                            os.path.join(options.output_dir, options.output_model + '.model.nplm.' + str(options.epochs)),
-                           os.path.join(options.working_dir, options.corpus_stem + '.numberized'),
-                           os.path.join(options.output_dir, options.output_model + '.model.nplm.')
+                           os.path.join(options.working_dir, os.path.basename(options.corpus_stem) + '.numberized'),
+                           os.path.join(options.output_dir, options.output_model + '.model.nplm')
     ])
     if ret:
         raise Exception("averaging null words failed")
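The last hunk fixes two path bugs in the call to average_null_embedding.py: when --corpus is given as a path, joining working_dir with the full corpus_stem points outside the working directory (os.path.join discards working_dir entirely for an absolute path), and the output model name carried a trailing dot. A worked illustration with made-up values:

    # Illustration of the two path fixes, with made-up values.
    import os

    working_dir = 'working_dir_head'
    corpus_stem = '/data/train.trees'  # --corpus given as an absolute path
    output_dir, output_model = '/path/to/output_directory', 'rdlm_head'

    # old code: os.path.join(working_dir, corpus_stem + '.numberized') yields
    # '/data/train.trees.numberized', not the numberized file that the
    # extraction step wrote into the working directory
    print(os.path.join(working_dir, os.path.basename(corpus_stem) + '.numberized'))
    # -> working_dir_head/train.trees.numberized

    # old code appended '.model.nplm.' with a trailing dot
    print(os.path.join(output_dir, output_model + '.model.nplm'))
    # -> /path/to/output_directory/rdlm_head.model.nplm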
