diff --git a/ingredient_phrase_tagger/training/translator.py b/ingredient_phrase_tagger/training/translator.py index c45fbbd..2b4aa58 100644 --- a/ingredient_phrase_tagger/training/translator.py +++ b/ingredient_phrase_tagger/training/translator.py @@ -20,18 +20,34 @@ def translate_row(row): # extract the display name display_input = utils.cleanUnicodeFractions(row['input']) tokens = utils.tokenize(display_input) - del (row['input']) - rowData = _addPrefixes([(t, _matchUp(t, row)) for t in tokens]) + labels = _row_to_labels(row) + label_data = _addPrefixes([(t, _matchUp(t, labels)) for t in tokens]) translated = '' - for i, (token, tags) in enumerate(rowData): + for i, (token, tags) in enumerate(label_data): features = utils.getFeatures(token, i + 1, tokens) translated += utils.joinLine( [token] + features + [_bestTag(tags)]) + '\n' return translated +def _row_to_labels(row): + """Extracts labels from a labelled ingredient data row. + + Args: + A row of full data about an ingredient, including input and labels. + + Returns: + A dictionary of the label data extracted from the row. + """ + labels = {} + label_keys = ['name', 'qty', 'range_end', 'unit', 'comment'] + for key in label_keys: + labels[key] = row[key] + return labels + + def _parseNumbers(s): """ Parses a string that represents a number into a decimal data type so that @@ -57,7 +73,7 @@ def _parseNumbers(s): return None -def _matchUp(token, ingredientRow): +def _matchUp(token, labels): """ Returns our best guess of the match between the tags and the words from the display text. @@ -77,21 +93,19 @@ def _matchUp(token, ingredientRow): token = utils.normalizeToken(token) decimalToken = _parseNumbers(token) - # Note: We iterate in this specific order to preserve parity with the - # legacy implementation. The legacy implementation is likely incorrect and - # shouldn't actually include 'index', but we will revisit when we're ready - # to change behavior. - for key in ['index', 'name', 'qty', 'range_end', 'unit', 'comment']: - val = ingredientRow[key] - if isinstance(val, basestring): - - for n, vt in enumerate(utils.tokenize(val)): + # Iterate through the labels in descending order of label importance. + # TODO(mtlynch): Reorder this list so that it is in better order of + # importance. + for label_key in ['name', 'qty', 'range_end', 'unit', 'comment']: + label_value = labels[label_key] + if isinstance(label_value, basestring): + for n, vt in enumerate(utils.tokenize(label_value)): if utils.normalizeToken(vt) == token: - ret.append(key.upper()) + ret.append(label_key.upper()) elif decimalToken is not None: - if val == decimalToken: - ret.append(key.upper()) + if label_value == decimalToken: + ret.append(label_key.upper()) return ret diff --git a/tests/golden/training_data.crf b/tests/golden/training_data.crf index 1a317e3..c76022b 100644 --- a/tests/golden/training_data.crf +++ b/tests/golden/training_data.crf @@ -16,7 +16,7 @@ squash I15 L20 NoCAP NoPAREN B-NAME , I16 L20 NoCAP NoPAREN OTHER defrosted I17 L20 NoCAP NoPAREN I-COMMENT -1 I1 L20 NoCAP NoPAREN B-INDEX +1 I1 L20 NoCAP NoPAREN B-QTY cup I2 L20 NoCAP NoPAREN B-UNIT peeled I3 L20 NoCAP NoPAREN I-COMMENT and I4 L20 NoCAP NoPAREN I-COMMENT @@ -29,7 +29,7 @@ about I9 L20 NoCAP YesPAREN I-COMMENT ) I11 L20 NoCAP YesPAREN I-COMMENT , I12 L20 NoCAP NoPAREN OTHER or I13 L20 NoCAP NoPAREN I-COMMENT -1 I14 L20 NoCAP NoPAREN B-INDEX +1 I14 L20 NoCAP NoPAREN B-QTY cup I15 L20 NoCAP NoPAREN B-UNIT canned I16 L20 NoCAP NoPAREN I-COMMENT , I17 L20 NoCAP NoPAREN OTHER