Skip to content

Commit

Permalink
Merge f718ecd into a83442d
Browse files Browse the repository at this point in the history
  • Loading branch information
mtlynch committed May 2, 2018
2 parents a83442d + f718ecd commit 81517e6
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 31 deletions.
36 changes: 5 additions & 31 deletions ingredient_phrase_tagger/training/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import csv
import optparse

import labelled_data
import translator


Expand All @@ -22,15 +22,13 @@ def generate_data(self, count, offset):
start = int(offset)
end = int(offset) + int(count)

with open(self.opts.data_path) as csv_file:
csv_reader = csv.DictReader(csv_file)
for index, row in enumerate(csv_reader):
with open(self.opts.data_path) as data_file:
data_reader = labelled_data.Reader(data_file)
for index, row in enumerate(data_reader):
if index < start or index >= end:
continue

parsed_row = _parse_row(row)

print translator.translate_row(parsed_row).encode('utf-8')
print translator.translate_row(row).encode('utf-8')

def _parse_args(self, argv):
"""
Expand All @@ -48,27 +46,3 @@ def _parse_args(self, argv):

(options, args) = opts.parse_args(argv)
return options


def _parse_row(row):
"""Converts string values in a row to numbers where possible.
Args:
row: A row of labelled ingredient data. This is modified in place so
that any of its values that contain a number (e.g. "6.4") are
converted to floats and the 'index' value is converted to an int.
"""
# Certain rows have range_end set to empty.
if row['range_end'] == '':
range_end = 0.0
else:
range_end = float(row['range_end'])

return {
'input': row['input'].decode('utf-8'),
'name': row['name'].decode('utf-8'),
'qty': float(row['qty']),
'range_end': range_end,
'unit': row['unit'].decode('utf-8'),
'comment': row['comment'].decode('utf-8'),
}
76 changes: 76 additions & 0 deletions ingredient_phrase_tagger/training/labelled_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import csv

_REQUIRED_COLUMNS = ['input', 'name', 'qty', 'range_end', 'unit', 'comment']


class Error(Exception):
    """Root exception type for this module's error classes."""


class InvalidHeaderError(Error):
    """Raised when a data file's header lacks a required column."""


class Reader(object):
"""Reads labelled ingredient data formatted as a CSV.
Input data must be a CSV file, encoded in UTF-8, and containing the
following columns:
input
name
qty
range_end
unit
comment
"""

def __init__(self, data_file):
self._csv_reader = csv.DictReader(data_file)
for required_column in _REQUIRED_COLUMNS:
if required_column not in self._csv_reader.fieldnames:
raise InvalidHeaderError(
'Data file is missing required column: %s' %
required_column)

def __iter__(self):
return self

def next(self):
return _parse_row(self._csv_reader.next())


def _parse_row(row):
"""Parses a row of raw data from a labelled ingredient CSV file.
Args:
row: A row of labelled ingredient data. This is modified in place so
that any of its values that contain a number (e.g. "6.4") are
converted to floats and the 'index' value is converted to an int.
Returns:
A dictionary representing the row's values, for example:
{
'input': '1/2 cup yellow cornmeal',
'name': 'yellow cornmeal',
'qty': 0.5,
'range_end': 0.0,
'unit': 'cup',
'comment': '',
}
"""
# Certain rows have range_end set to empty.
if row['range_end'] == '':
range_end = 0.0
else:
range_end = float(row['range_end'])

return {
'input': row['input'].decode('utf-8'),
'name': row['name'].decode('utf-8'),
'qty': float(row['qty']),
'range_end': range_end,
'unit': row['unit'].decode('utf-8'),
'comment': row['comment'].decode('utf-8'),
}
67 changes: 67 additions & 0 deletions tests/test_labelled_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import io
import unittest

from ingredient_phrase_tagger.training import labelled_data


class ReaderTest(unittest.TestCase):
    """Exercises labelled_data.Reader against small in-memory CSV files."""

    def test_reads_valid_label_file(self):
        # Three rows: a quantity range, an empty unit, and quoted fields that
        # contain commas.
        csv_data = """
index,input,name,qty,range_end,unit,comment
63,4 to 6 large cloves garlic,garlic,4.0,6.0,clove,
77,3 bananas,bananas,3.0,0.0,,
106,"2 1/2 pounds bell peppers (about 6 peppers in assorted colors), cut into 2-inch chunks",bell peppers,2.5,0.0,pound,"(about 6 peppers in assorted colors), cut into 2-inch chunks"
""".strip()
        reader = labelled_data.Reader(io.BytesIO(csv_data))

        garlic = {
            'input': u'4 to 6 large cloves garlic',
            'name': u'garlic',
            'qty': 4.0,
            'range_end': 6.0,
            'unit': u'clove',
            'comment': u'',
        }
        bananas = {
            'input': u'3 bananas',
            'name': u'bananas',
            'qty': 3.0,
            'range_end': 0.0,
            'unit': u'',
            'comment': u'',
        }
        bell_peppers = {
            'input': (u'2 1/2 pounds bell peppers (about 6 peppers in '
                      u'assorted colors), cut into 2-inch chunks'),
            'name': u'bell peppers',
            'qty': 2.5,
            'range_end': 0.0,
            'unit': u'pound',
            'comment': (u'(about 6 peppers in assorted colors), cut into '
                        u'2-inch chunks'),
        }
        self.assertEqual([garlic, bananas, bell_peppers], list(reader))

    def test_interprets_empty_range_end_as_zero(self):
        # The range_end field of the only data row is left blank.
        csv_data = """
index,input,name,qty,range_end,unit,comment
77,3 bananas,bananas,3.0,,,
""".strip()
        reader = labelled_data.Reader(io.BytesIO(csv_data))

        expected = {
            'input': u'3 bananas',
            'name': u'bananas',
            'qty': 3.0,
            'range_end': 0.0,
            'unit': u'',
            'comment': u'',
        }
        self.assertEqual(expected, reader.next())

    def test_raises_error_when_csv_does_not_have_required_columns(self):
        # The header replaces the 'name' column with an unknown one.
        mock_file = io.BytesIO("""
index,input,UNEXPECTED_COLUMN,qty,range_end,unit,comment
77,3 bananas,bananas,3.0,0.0,,
""".strip())
        with self.assertRaises(labelled_data.InvalidHeaderError):
            labelled_data.Reader(mock_file).next()

0 comments on commit 81517e6

Please sign in to comment.