Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change DataIterator to work on get_example and get_batch and add methods... #40

Closed
wants to merge 8 commits into from
43 changes: 34 additions & 9 deletions fuel/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@ class Transformer(AbstractDataStream):
this attribute. Use it to access data from the wrapped data stream
by calling ``next(self.child_epoch_iterator)``.

batch : boolean
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

batch_input instead of batch, and there should be no new line between parameter descriptions.

Determine wheter the model is working on examples or on batches

"""
def __init__(self, data_stream, **kwargs):
def __init__(self, data_stream, batch_input=False, **kwargs):
super(Transformer, self).__init__(**kwargs)
self.data_stream = data_stream
self.batch_input = batch_input

@property
def sources(self):
Expand Down Expand Up @@ -60,6 +64,24 @@ def get_epoch_iterator(self, **kwargs):
self.child_epoch_iterator = self.data_stream.get_epoch_iterator()
return super(Transformer, self).get_epoch_iterator(**kwargs)

def get_data(self, request=None):
if self.batch_input:
return self.get_data_from_batch(request)
else:
return self.get_data_from_example(request)

def get_data_from_example(self, request=None):
raise NotImplementedError(
str(type(self)) +
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We generally prefer to use formatting instead of string concatenation, it looks a bit cleaner IMHO. Something slightly more informative would be good too e.g. "{}does not support examples as inputs, butbatch_inputwas set toFalse".format(type(self))

"does not have an example method"
)

def get_data_from_batch(self, request=None):
raise NotImplementedError(
str(type(self)) +
"does not have a batch input method"
)


class Mapping(Transformer):
"""Applies a mapping to the data of the wrapped data stream.
Expand Down Expand Up @@ -164,7 +186,7 @@ def __init__(self, data_stream, iteration_scheme):
data_stream, iteration_scheme=iteration_scheme)
self.cache = [[] for _ in self.sources]

def get_data(self, request=None):
def get_data_from_example(self, request=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be get_data_from_batch (don't forget to pass batch_input=True to super).

if request > len(self.cache[0]):
self._cache()
data = []
Expand Down Expand Up @@ -237,7 +259,7 @@ def __init__(self, data_stream, iteration_scheme, strictness=0):
data_stream, iteration_scheme=iteration_scheme)
self.strictness = strictness

def get_data(self, request=None):
def get_data_from_example(self, request=None):
"""Get data from the dataset."""
if request is None:
raise ValueError
Expand All @@ -261,20 +283,23 @@ def get_data(self, request=None):
class Unpack(Transformer):
"""Unpacks batches to compose a stream of examples.


Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why new lines were added here, but there should be only one.

This class is the inverse of the Batch class: it turns a minibatch into
a stream of examples.


Parameters
----------
data_stream : :class:`AbstractDataStream` instance
The data stream to unpack


"""
def __init__(self, data_stream):
super(Unpack, self).__init__(data_stream)
super(Unpack, self).__init__(data_stream, batch_input=True)
self.data = None

def get_data(self, request=None):
def get_data_from_batch(self, request=None):
if not self.data:
data = next(self.child_epoch_iterator)
self.data = izip(*data)
Expand Down Expand Up @@ -311,7 +336,7 @@ class Padding(Transformer):

"""
def __init__(self, data_stream, mask_sources=None, mask_dtype=None):
super(Padding, self).__init__(data_stream)
super(Padding, self).__init__(data_stream, batch_input=True)
if mask_sources is None:
mask_sources = self.data_stream.sources
self.mask_sources = mask_sources
Expand All @@ -329,7 +354,7 @@ def sources(self):
sources.append(source + '_mask')
return tuple(sources)

def get_data(self, request=None):
def get_data_from_batch(self, request=None):
if request is not None:
raise ValueError
data = list(next(self.child_epoch_iterator))
Expand Down Expand Up @@ -459,13 +484,13 @@ class MultiProcessing(Transformer):

"""
def __init__(self, data_stream, max_store=100):
super(MultiProcessing, self).__init__(data_stream)
super(MultiProcessing, self).__init__(data_stream, batch_input=True)
self.background = BackgroundProcess(data_stream, max_store)
self.proc = Process(target=self.background.main)
self.proc.daemon = True
self.proc.start()

def get_data(self, request=None):
def get_data_from_batch(self, request=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this one just passes on data, I think its agnostic to whether the input is a batch or single example. You can just leave the get_data method here.

if request is not None:
raise ValueError
data = self.background.get_next_data()
Expand Down