New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring generate_data, adding unit tests #50

Closed
wants to merge 1 commit into
base: master
from
Jump to file or symbol
Failed to load files and symbols.
+226 −117
Diff settings

Always

Just for now

Copy path View file
@@ -21,5 +21,8 @@ yapf \
# Run static analysis (pyflakes) for Python bugs/cruft in bin/ and the package.
pyflakes bin/ ingredient_phrase_tagger/
# Run unit tests; "discover" locates the test modules automatically.
python -m unittest discover
# Run the end-to-end (E2E) tests.
bash ./test_e2e
@@ -1,9 +1,7 @@
import re
import decimal
import optparse
import pandas as pd
import utils
import translator
class Cli(object):
@@ -30,122 +28,10 @@ def generate_data(self, count, offset):
for index, row in df_slice.iterrows():
try:
# extract the display name
display_input = utils.cleanUnicodeFractions(row["input"])
tokens = utils.tokenize(display_input)
del (row["input"])
rowData = self.addPrefixes(
[(t, self.matchUp(t, row)) for t in tokens])
for i, (token, tags) in enumerate(rowData):
features = utils.getFeatures(token, i + 1, tokens)
print utils.joinLine(
[token] + features + [self.bestTag(tags)])
print translator.translate_row(row)
# ToDo: deal with this
except UnicodeDecodeError:
pass
print
def parseNumbers(self, s):
    """
    Parses a string that represents a number into a decimal data type so that
    we can match the quantity field in the db with the quantity that appears
    in the display name. Rounds the result to 2 places.

    The token is first passed through utils.unclump (presumably this turns
    clumped tokens such as "1$1/2" back into "1 1/2" -- confirm in utils).
    Three forms are then recognized:
      * digits only, e.g. "2"
      * whole number, whitespace, single-digit fraction, e.g. "1 1/2"
      * single-digit fraction on its own, e.g. "1/2"

    Returns:
        A decimal.Decimal, or None if the token has none of the forms above
        or its fraction has a zero denominator (instead of raising
        ZeroDivisionError).
    """
    ss = utils.unclump(s)

    if re.match(r'^\d+$', ss) is not None:
        return decimal.Decimal(round(float(ss), 2))

    mixed = re.match(r'(\d+)\s+(\d)/(\d)', ss)
    if mixed is not None:
        denominator = float(mixed.group(3))
        if denominator == 0:
            # "1 1/0" is not a usable quantity; treat it as "not a number".
            return None
        num = int(mixed.group(1)) + (float(mixed.group(2)) / denominator)
        return decimal.Decimal(str(round(num, 2)))

    fraction = re.match(r'^(\d)/(\d)$', ss)
    if fraction is not None:
        denominator = float(fraction.group(2))
        if denominator == 0:
            # Same guard as above for a bare "n/0".
            return None
        num = float(fraction.group(1)) / denominator
        return decimal.Decimal(str(round(num, 2)))

    return None
def matchUp(self, token, ingredientRow):
    """
    Returns our best guess of the match between the tags and the
    words from the display text.

    This problem is difficult for the following reasons:
        * not all the words in the display name have associated tags
        * the quantity field is stored as a number, but it appears
          as a string in the display name
        * the comment is often a compilation of different comments in
          the display name

    Args:
        token: A single token from the display text.
        ingredientRow: Mapping of field name -> value (must provide
            iteritems()).

    Returns:
        A list of upper-cased field names whose value matches the token.
        A string field that contains the token more than once contributes
        its name once per occurrence.
    """
    ret = []

    # strip parens from the token, since they often appear in the
    # display_name, but are removed from the comment.
    token = utils.normalizeToken(token)
    decimalToken = self.parseNumbers(token)

    for key, val in ingredientRow.iteritems():
        if isinstance(val, basestring):
            # Text field: compare token-by-token after normalization.
            for vt in utils.tokenize(val):
                if utils.normalizeToken(vt) == token:
                    ret.append(key.upper())

        elif decimalToken is not None:
            # Non-text field: compare numerically against the parsed token.
            try:
                if val == decimalToken:
                    ret.append(key.upper())
            except Exception:
                # Best effort: a value that cannot be compared is simply not
                # a match. (Unlike a bare "except:", this does not swallow
                # KeyboardInterrupt or SystemExit.)
                pass

    return ret
def addPrefixes(self, data):
    """
    We use BIO tagging/chunking to differentiate between tags
    at the start of a tag sequence and those in the middle. This
    is a common technique in entity recognition.

    Reference: http://www.kdd.cis.ksu.edu/Courses/Spring-2013/CIS798/Handouts/04-ramshaw95text.pdf

    Args:
        data: A sequence of (token, tags) pairs, where tags is a list of
            tag names.

    Returns:
        A new list of (token, prefixed_tags) pairs. A tag is prefixed with
        "I-" when the immediately preceding token carried the same tag,
        and with "B-" otherwise.
    """
    # Tags of the previous token; empty before the first token so that
    # every tag of the first token starts a sequence ("B-").
    prevTags = ()
    newData = []

    for token, tags in data:
        newTags = [
            "%s-%s" % ("I" if t in prevTags else "B", t) for t in tags
        ]
        newData.append((token, newTags))
        prevTags = tags

    return newData
def bestTag(self, tags):
    """
    Picks the single tag to emit for a token.

    A lone tag is returned as-is. Otherwise the first tag that is not a
    COMMENT tag wins; if there is no such tag (including when tags is
    empty), "OTHER" is returned.
    """
    if len(tags) == 1:
        return tags[0]

    # several (or zero) candidates: prefer the first non-COMMENT tag, and
    # fall back to "OTHER" when we have no idea what to guess
    comment_tags = ("B-COMMENT", "I-COMMENT")
    return next((tag for tag in tags if tag not in comment_tags), "OTHER")
print ''
def _parse_args(self, argv):
"""
@@ -0,0 +1,134 @@
import decimal
import re
import utils
def translate_row(row):
    r"""Translates a row of labeled data into CRF++-compatible tag strings.

    Args:
        row: A row of data from the input CSV of labeled ingredient data.
            NOTE: the 'input' entry is deleted from row in place before the
            remaining fields are matched against the tokens (presumably so
            the raw phrase is not matched against itself).

    Returns:
        The row of input converted to CRF++-compatible tags, one
        newline-terminated line per token, e.g.

            2\tI1\tL4\tNoCAP\tNoPAREN\tB-QTY
            cups\tI2\tL4\tNoCAP\tNoPAREN\tB-UNIT
            flour\tI3\tL4\tNoCAP\tNoPAREN\tB-NAME
    """
    # extract the display name
    display_input = utils.cleanUnicodeFractions(row['input'])
    tokens = utils.tokenize(display_input)
    del (row['input'])

    rowData = _addPrefixes([(t, _matchUp(t, row)) for t in tokens])

    # One output line per token: token, its features (position is 1-based),
    # then the chosen tag. Joined once rather than grown with "+=".
    return ''.join(
        utils.joinLine(
            [token] + utils.getFeatures(token, i + 1, tokens) +
            [_bestTag(tags)]) + '\n'
        for i, (token, tags) in enumerate(rowData))
def _parseNumbers(s):
    """
    Parses a string that represents a number into a decimal data type so that
    we can match the quantity field in the db with the quantity that appears
    in the display name. Rounds the result to 2 places.

    The token is first passed through utils.unclump (presumably this turns
    clumped tokens such as "1$1/2" back into "1 1/2" -- confirm in utils).
    Three forms are then recognized:
      * digits only, e.g. "2"
      * whole number, whitespace, single-digit fraction, e.g. "1 1/2"
      * single-digit fraction on its own, e.g. "1/2"

    Returns:
        A decimal.Decimal, or None if the token has none of the forms above
        or its fraction has a zero denominator (instead of raising
        ZeroDivisionError).
    """
    ss = utils.unclump(s)

    if re.match(r'^\d+$', ss) is not None:
        return decimal.Decimal(round(float(ss), 2))

    mixed = re.match(r'(\d+)\s+(\d)/(\d)', ss)
    if mixed is not None:
        denominator = float(mixed.group(3))
        if denominator == 0:
            # "1 1/0" is not a usable quantity; treat it as "not a number".
            return None
        num = int(mixed.group(1)) + (float(mixed.group(2)) / denominator)
        return decimal.Decimal(str(round(num, 2)))

    fraction = re.match(r'^(\d)/(\d)$', ss)
    if fraction is not None:
        denominator = float(fraction.group(2))
        if denominator == 0:
            # Same guard as above for a bare "n/0".
            return None
        num = float(fraction.group(1)) / denominator
        return decimal.Decimal(str(round(num, 2)))

    return None
def _matchUp(token, ingredientRow):
    """
    Returns our best guess of the match between the tags and the
    words from the display text.

    This problem is difficult for the following reasons:
        * not all the words in the display name have associated tags
        * the quantity field is stored as a number, but it appears
          as a string in the display name
        * the comment is often a compilation of different comments in
          the display name

    Args:
        token: A single token from the display text.
        ingredientRow: Mapping of field name -> value (must provide
            iteritems()).

    Returns:
        A list of upper-cased field names whose value matches the token.
        A string field that contains the token more than once contributes
        its name once per occurrence.
    """
    ret = []

    # strip parens from the token, since they often appear in the
    # display_name, but are removed from the comment.
    token = utils.normalizeToken(token)
    decimalToken = _parseNumbers(token)

    for key, val in ingredientRow.iteritems():
        if isinstance(val, basestring):
            # Text field: compare token-by-token after normalization.
            for vt in utils.tokenize(val):
                if utils.normalizeToken(vt) == token:
                    ret.append(key.upper())

        elif decimalToken is not None:
            # Non-text field: compare numerically against the parsed token.
            try:
                if val == decimalToken:
                    ret.append(key.upper())
            except Exception:
                # Best effort: a value that cannot be compared is simply not
                # a match. (Unlike a bare "except:", this does not swallow
                # KeyboardInterrupt or SystemExit.)
                pass

    return ret
def _addPrefixes(data):
    """
    We use BIO tagging/chunking to differentiate between tags
    at the start of a tag sequence and those in the middle. This
    is a common technique in entity recognition.

    Reference: http://www.kdd.cis.ksu.edu/Courses/Spring-2013/CIS798/Handouts/04-ramshaw95text.pdf

    Args:
        data: A sequence of (token, tags) pairs, where tags is a list of
            tag names.

    Returns:
        A new list of (token, prefixed_tags) pairs. A tag is prefixed with
        "I-" when the immediately preceding token carried the same tag,
        and with "B-" otherwise.
    """
    # Tags of the previous token; empty before the first token so that
    # every tag of the first token starts a sequence ("B-").
    prevTags = ()
    newData = []

    for token, tags in data:
        newTags = [
            "%s-%s" % ("I" if t in prevTags else "B", t) for t in tags
        ]
        newData.append((token, newTags))
        prevTags = tags

    return newData
def _bestTag(tags):
    """
    Picks the single tag to emit for a token.

    A lone tag is returned as-is. Otherwise the first tag that is not a
    COMMENT tag wins; if there is no such tag (including when tags is
    empty), "OTHER" is returned.
    """
    if len(tags) == 1:
        return tags[0]

    # several (or zero) candidates: prefer the first non-COMMENT tag, and
    # fall back to "OTHER" when we have no idea what to guess
    comment_tags = ("B-COMMENT", "I-COMMENT")
    return next((tag for tag in tags if tag not in comment_tags), "OTHER")
Copy path View file
No changes.
Copy path View file
@@ -0,0 +1,86 @@
import unittest
from ingredient_phrase_tagger.training import translator
class TranslatorTest(unittest.TestCase):
    """Compares translator.translate_row output with hand-written tag lines."""

    def _assertTranslation(self, expected, row):
        # Surrounding whitespace is ignored on both sides so the expected
        # text can be written as a readable triple-quoted block.
        self.assertMultiLineEqual(expected.strip(),
                                  translator.translate_row(row).strip())

    def test_translates_row_with_simple_phrase(self):
        # Whole-number quantity, unit and a one-word name.
        labeled_row = {
            'index': 162,
            'input': '2 cups flour',
            'name': 'flour',
            'qty': 2.0,
            'range_end': 0.0,
            'unit': 'cup',
            'comment': '',
        }
        self._assertTranslation("""
2\tI1\tL4\tNoCAP\tNoPAREN\tB-QTY
cups\tI2\tL4\tNoCAP\tNoPAREN\tB-UNIT
flour\tI3\tL4\tNoCAP\tNoPAREN\tB-NAME
""", labeled_row)

    def test_translates_row_with_simple_fraction(self):
        # Bare fraction quantity and a two-word name (B-NAME then I-NAME).
        labeled_row = {
            'index': 161,
            'input': '1/2 cup yellow cornmeal',
            'name': 'yellow cornmeal',
            'qty': 0.5,
            'range_end': 0.0,
            'unit': 'cup',
            'comment': '',
        }
        self._assertTranslation("""
1/2\tI1\tL8\tNoCAP\tNoPAREN\tB-QTY
cup\tI2\tL8\tNoCAP\tNoPAREN\tB-UNIT
yellow\tI3\tL8\tNoCAP\tNoPAREN\tB-NAME
cornmeal\tI4\tL8\tNoCAP\tNoPAREN\tI-NAME
""", labeled_row)

    def test_translates_row_with_complex_fraction(self):
        # Mixed number: expected as the single token "1$1/2".
        labeled_row = {
            'index': 158,
            'input': '1 1/2 teaspoons salt',
            'name': 'salt',
            'qty': 1.5,
            'range_end': 0.0,
            'unit': 'teaspoon',
            'comment': '',
        }
        self._assertTranslation("""
1$1/2\tI1\tL4\tNoCAP\tNoPAREN\tB-QTY
teaspoons\tI2\tL4\tNoCAP\tNoPAREN\tB-UNIT
salt\tI3\tL4\tNoCAP\tNoPAREN\tB-NAME
""", labeled_row)

    def test_translates_row_with_comment(self):
        # Words with no matching field are OTHER; the trailing comment is
        # tagged as one B-/I-COMMENT run, commas included.
        labeled_row = {
            'index': 412,
            'input': 'Half a vanilla bean, split lengthwise, seeds scraped',
            'name': 'vanilla bean',
            'qty': 0.5,
            'range_end': 0.0,
            'unit': '',
            'comment': 'split lengthwise, seeds scraped',
        }
        self._assertTranslation("""
Half\tI1\tL12\tYesCAP\tNoPAREN\tOTHER
a\tI2\tL12\tNoCAP\tNoPAREN\tOTHER
vanilla\tI3\tL12\tNoCAP\tNoPAREN\tB-NAME
bean\tI4\tL12\tNoCAP\tNoPAREN\tI-NAME
,\tI5\tL12\tNoCAP\tNoPAREN\tB-COMMENT
split\tI6\tL12\tNoCAP\tNoPAREN\tI-COMMENT
lengthwise\tI7\tL12\tNoCAP\tNoPAREN\tI-COMMENT
,\tI8\tL12\tNoCAP\tNoPAREN\tI-COMMENT
seeds\tI9\tL12\tNoCAP\tNoPAREN\tI-COMMENT
scraped\tI10\tL12\tNoCAP\tNoPAREN\tI-COMMENT
""", labeled_row)
ProTip! Use n and p to navigate between commits in a pull request.