Skip to content

Commit

Permalink
Merge pull request #389 from rizar/fix_mapping_dict_and_add_sources
Browse files Browse the repository at this point in the history
fix the case mapping_accepts=dict and add_sources=smth
  • Loading branch information
dmitriy-serdyuk authored May 2, 2017
2 parents 425fe35 + c08127b commit 42e21a2
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
15 changes: 9 additions & 6 deletions fuel/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,17 @@ def get_data(self, request=None):
if request is not None:
raise ValueError
data = next(self.child_epoch_iterator)
mapping_input = data
if self.mapping_accepts == dict:
data = OrderedDict(equizip(self.data_stream.sources, data))
image = self.mapping(data)
mapping_input = OrderedDict(equizip(self.data_stream.sources,
data))
image = self.mapping(mapping_input)
image_sources = self.add_sources if self.add_sources else self.sources
if self.mapping_accepts == dict:
image = tuple(image[source] for source in self.sources)
if not self.add_sources:
return image
return data + image
image = tuple(image[source] for source in image_sources)
if self.add_sources:
return data + image
return image


@add_metaclass(ABCMeta)
Expand Down
11 changes: 11 additions & 0 deletions tests/transformers/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,17 @@ def test_add_sources(self):
assert_equal(list(transformer.get_epoch_iterator()),
list(zip(self.data, [[2, 4, 6], [4, 6, 2], [6, 4, 2]])))

def test_mapping_dict_add_sources(self):
stream = DataStream(IterableDataset(self.data))
transformer = Mapping(
stream,
lambda d: {'doubled': [2 * i for i in d['data']]},
mapping_accepts=dict,
add_sources=('doubled',))
assert_equal(transformer.sources, ('data', 'doubled'))
assert_equal(list(transformer.get_epoch_iterator()),
list(zip(self.data, [[2, 4, 6], [4, 6, 2], [6, 4, 2]])))

def test_sort_mapping_trivial_key(self):
stream = DataStream(IterableDataset(self.data))
transformer = Mapping(stream, SortMapping(operator.itemgetter(0)))
Expand Down

0 comments on commit 42e21a2

Please sign in to comment.