forked from nytimes/ingredient-phrase-tagger
-
Notifications
You must be signed in to change notification settings - Fork 76
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
148 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import csv | ||
|
||
_REQUIRED_COLUMNS = ['input', 'name', 'qty', 'range_end', 'unit', 'comment'] | ||
|
||
|
||
class Error(Exception): | ||
pass | ||
|
||
|
||
class InvalidHeaderError(Error): | ||
pass | ||
|
||
|
||
class Reader(object): | ||
"""Reads labelled ingredient data formatted as a CSV. | ||
Input data must be a CSV file, encoded in UTF-8, and containing the | ||
following columns: | ||
input | ||
name | ||
qty | ||
range_end | ||
unit | ||
comment | ||
""" | ||
|
||
def __init__(self, data_file): | ||
self._csv_reader = csv.DictReader(data_file) | ||
for required_column in _REQUIRED_COLUMNS: | ||
if required_column not in self._csv_reader.fieldnames: | ||
raise InvalidHeaderError( | ||
'Data file is missing required column: %s' % | ||
required_column) | ||
|
||
def __iter__(self): | ||
return self | ||
|
||
def next(self): | ||
return _parse_row(self._csv_reader.next()) | ||
|
||
|
||
def _parse_row(row): | ||
"""Parses a row of raw data from a labelled ingredient CSV file. | ||
Args: | ||
row: A row of labelled ingredient data. This is modified in place so | ||
that any of its values that contain a number (e.g. "6.4") are | ||
converted to floats and the 'index' value is converted to an int. | ||
Returns: | ||
A dictionary representing the row's values, for example: | ||
{ | ||
'input': '1/2 cup yellow cornmeal', | ||
'name': 'yellow cornmeal', | ||
'qty': 0.5, | ||
'range_end': 0.0, | ||
'unit': 'cup', | ||
'comment': '', | ||
} | ||
""" | ||
# Certain rows have range_end set to empty. | ||
if row['range_end'] == '': | ||
range_end = 0.0 | ||
else: | ||
range_end = float(row['range_end']) | ||
|
||
return { | ||
'input': row['input'].decode('utf-8'), | ||
'name': row['name'].decode('utf-8'), | ||
'qty': float(row['qty']), | ||
'range_end': range_end, | ||
'unit': row['unit'].decode('utf-8'), | ||
'comment': row['comment'].decode('utf-8'), | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import io | ||
import unittest | ||
|
||
from ingredient_phrase_tagger.training import labelled_data | ||
|
||
|
||
class ReaderTest(unittest.TestCase): | ||
|
||
def test_reads_valid_label_file(self): | ||
mock_file = io.BytesIO(""" | ||
index,input,name,qty,range_end,unit,comment | ||
63,4 to 6 large cloves garlic,garlic,4.0,6.0,clove, | ||
77,3 bananas,bananas,3.0,0.0,, | ||
106,"2 1/2 pounds bell peppers (about 6 peppers in assorted colors), cut into 2-inch chunks",bell peppers,2.5,0.0,pound,"(about 6 peppers in assorted colors), cut into 2-inch chunks" | ||
""".strip()) | ||
reader = labelled_data.Reader(mock_file) | ||
self.assertEqual([{ | ||
'input': u'4 to 6 large cloves garlic', | ||
'qty': 4.0, | ||
'unit': u'clove', | ||
'name': u'garlic', | ||
'range_end': 6.0, | ||
'comment': u'', | ||
}, { | ||
'input': u'3 bananas', | ||
'qty': 3.0, | ||
'unit': u'', | ||
'name': u'bananas', | ||
'comment': u'', | ||
'range_end': 0.0, | ||
}, { | ||
'input': (u'2 1/2 pounds bell peppers (about 6 peppers in ' | ||
u'assorted colors), cut into 2-inch chunks'), | ||
'qty': | ||
2.5, | ||
'unit': | ||
u'pound', | ||
'name': | ||
u'bell peppers', | ||
'range_end': | ||
0.0, | ||
'comment': (u'(about 6 peppers in assorted colors), cut into ' | ||
u'2-inch chunks'), | ||
}], [r for r in reader]) | ||
|
||
def test_interprets_empty_range_end_as_zero(self): | ||
mock_file = io.BytesIO(""" | ||
index,input,name,qty,range_end,unit,comment | ||
77,3 bananas,bananas,3.0,,, | ||
""".strip()) | ||
reader = labelled_data.Reader(mock_file) | ||
self.assertEqual({ | ||
'input': u'3 bananas', | ||
'qty': 3.0, | ||
'unit': u'', | ||
'name': u'bananas', | ||
'comment': u'', | ||
'range_end': 0.0, | ||
}, reader.next()) | ||
|
||
def test_raises_error_when_csv_does_not_have_required_columns(self): | ||
with self.assertRaises(labelled_data.InvalidHeaderError): | ||
mock_file = io.BytesIO(""" | ||
index,input,UNEXPECTED_COLUMN,qty,range_end,unit,comment | ||
77,3 bananas,bananas,3.0,0.0,, | ||
""".strip()) | ||
labelled_data.Reader(mock_file).next() |