
Commit

[scripts] Add script to limit <unk> history of ARPA files, for OOV-modeling efficiency (#2483)
olix20 authored and danpovey committed Jun 12, 2018
1 parent 8fd2ba5 commit be96725
Showing 1 changed file with 157 additions and 0 deletions.
egs/wsj/s5/utils/lang/limit_arpa_unk_history.py (157 additions, 0 deletions)
@@ -0,0 +1,157 @@
#!/usr/bin/env python3

# Copyright 2018 Armin Oliya
# Apache 2.0.

'''
This script takes an existing ARPA language model and limits the <unk> history
to make it suitable for downstream <unk> modeling.
This is for the case when you don't have access
to the original text corpus that is used for creating the LM.
If you do, you can use pocolm with the option --limit-unk-history=true.
This keeps the graph compact after adding the unk model.
'''
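# For illustration (made-up example entries), with "<unk>" as the oov word:
# an entry such as "-2.31 foo <unk> bar", where <unk> appears in the word
# history, is removed entirely, while an entry such as "-4.09 every <unk> -0.64"
# keeps its probability but loses its backoff weight, becoming "-4.09 every <unk>".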

import argparse
import io
import re
import sys
from collections import defaultdict


parser = argparse.ArgumentParser(
    description='''This script takes an existing ARPA language model
and limits the <unk> history to make it suitable
for downstream <unk> modeling.
It supports up to 5-grams.''',
usage='''utils/lang/limit_arpa_unk_history.py
<oov-dict-entry> <input-arpa >output-arpa''',
epilog='''E.g.: gunzip -c src.arpa.gz |
utils/lang/limit_arpa_unk_history.py "<unk>" | gzip -c >dest.arpa.gz''')

parser.add_argument(
'oov_dict_entry',
help='oov identifier, for example "<unk>"', type=str)
args = parser.parse_args()


def get_ngram_stats(old_lm_lines):
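    # The \data\ section at the top of an ARPA file lists one count per
    # n-gram order, e.g. "ngram 1=20003" (the count here is only an example).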
ngram_counts = defaultdict(int)

for i in range(10):
g = re.search(r"ngram (\d)=(\d+)", old_lm_lines[i])
if g:
ngram_counts[int(g.group(1))] = int(g.group(2))

if len(ngram_counts) == 0:
sys.exit("""Couldn't get counts per ngram section.
The input doesn't seem to be a valid ARPA language model.""")

max_ngrams = list(ngram_counts.keys())[-1]
skip_rows = ngram_counts[1]

if max_ngrams > 5:
sys.exit("This script supports up to 5-gram language models.")

return max_ngrams, skip_rows, ngram_counts


def find_and_replace_unks(old_lm_lines, max_ngrams, skip_rows):
ngram_diffs = defaultdict(int)
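    # unk_pattern matches entries where the oov word is followed by another
    # word (i.e. it appears in the history), e.g. "-2.31 foo <unk> bar";
    # backoff_pattern matches entries that end with the oov word plus a
    # backoff weight, e.g. "-4.09 every <unk> -0.64" (example values only).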
    unk_pattern = re.compile(
        r"[0-9.-]+(?:[\s\t]\S+){1,3}[\s\t]" + args.oov_dict_entry +
        r"[\s\t](?!-[0-9]+\.[0-9]+).*")
    backoff_pattern = re.compile(
        r"[0-9.-]+(?:[\s\t]\S+){1,3}[\s\t]" + args.oov_dict_entry +
        r"[\s\t]-[0-9]+\.[0-9]+")
passed_2grams, last_ngram = False, False
unk_row_count, backoff_row_count = 0, 0

print("Upadting the language model .. ", file=sys.stderr)
new_lm_lines = old_lm_lines[:skip_rows]

for i in range(skip_rows, len(old_lm_lines)):
line = old_lm_lines[i].strip()

if "\{}-grams:".format(3) in line:
passed_2grams = True
if "\{}-grams:".format(max_ngrams) in line:
last_ngram = True

# remove any n-gram states of the form: foo <unk> -> X
# that is, any n-grams of order > 2 where <unk>
# is the second-to-last word
# here we skip 1-gram and 2-gram sections of arpa

if passed_2grams:
g_unk = unk_pattern.search(line)
if g_unk:
ngram = len(g_unk.group(0).split()) - 1

xiaohui-zhang (Contributor) commented on Jul 19, 2018:

Hello @olix20, I tested this script and it gives wrong headers (# ngrams). The reason is that you are not accounting for the fact that the last field here could be the backoff probability, so you can't get the correct "ngram" order this way. You can easily get it as you did in line 75.
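
A minimal sketch of the kind of fix this comment suggests, for illustration only (it reuses the script's own old_lm_lines, skip_rows, unk_pattern and ngram_diffs, and tracks the current "\N-grams:" section header while scanning, so the order never has to be inferred from the number of fields, which may or may not end in a backoff weight):

# Sketch (not part of this commit): derive the n-gram order from the current
# section header rather than from the field count of the matched line.
current_order = 0
for raw_line in old_lm_lines[skip_rows:]:
    header = re.match(r"\\(\d)-grams:", raw_line.strip())
    if header:
        current_order = int(header.group(1))
        continue
    if current_order >= 3 and unk_pattern.search(raw_line):
        ngram_diffs[current_order] -= 1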

ngram_diffs[ngram] = ngram_diffs[ngram] - 1
unk_row_count += 1
continue

# remove backoff probability from the lines that end with <unk>
# for example, the -0.64 in -4.09 every <unk> -0.64
# here we skip the last n-gram section because it
# doesn't include backoff probabilities

if not last_ngram:
g_backoff = backoff_pattern.search(line)
if g_backoff:
updated_row = g_backoff.group(0).split()[:-1]
updated_row = updated_row[0] + \
"\t" + " ".join(updated_row[1:]) + "\n"
new_lm_lines.append(updated_row)
backoff_row_count += 1
continue

new_lm_lines.append(line+"\n")

print("Removed {} lines including {} as second-to-last term.".format(
unk_row_count, args.oov_dict_entry), file=sys.stderr)
print("Removed backoff probabilties from {} lines.".format(
backoff_row_count), file=sys.stderr)

return new_lm_lines, ngram_diffs


def read_old_lm():
print("Reading ARPA LM frome input stream .. ", file=sys.stderr)

with io.TextIOWrapper(
sys.stdin.buffer,
encoding="latin-1") as input_stream:
old_lm_lines = input_stream.readlines()

return old_lm_lines


def write_new_lm(new_lm_lines, ngram_counts, ngram_diffs):
    ''' Update the n-gram counts in the header of the ARPA LM. '''
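    # For example (illustrative numbers only): if 120 trigram entries were
    # removed, the header line "ngram 3=152083" becomes "ngram 3=151963".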

for i in range(10):
g = re.search(r"ngram (\d)=(\d+)", new_lm_lines[i])
if g:
n = int(g.group(1))
if n in ngram_diffs:
# ngram_diffs contains negative values
new_num_ngrams = ngram_counts[n] + ngram_diffs[n]
new_lm_lines[i] = "ngram {}={}\n".format(
n, new_num_ngrams)

with io.TextIOWrapper(
sys.stdout.buffer,
encoding="latin-1") as output_stream:
output_stream.writelines(new_lm_lines)


def main():
old_lm_lines = read_old_lm()
max_ngrams, skip_rows, ngram_counts = get_ngram_stats(old_lm_lines)
new_lm_lines, ngram_diffs = find_and_replace_unks(
old_lm_lines, max_ngrams, skip_rows)
write_new_lm(new_lm_lines, ngram_counts, ngram_diffs)


if __name__ == "__main__":
main()
