change DataIterator to work on get_example and get_batch and add methods... #40
Changes from all commits
@@ -22,10 +22,14 @@ class Transformer(AbstractDataStream):
         this attribute. Use it to access data from the wrapped data stream
         by calling ``next(self.child_epoch_iterator)``.
+
+    batch : boolean
+        Determine wheter the model is working on examples or on batches
 
     """
-    def __init__(self, data_stream, **kwargs):
+    def __init__(self, data_stream, batch_input=False, **kwargs):
         super(Transformer, self).__init__(**kwargs)
         self.data_stream = data_stream
+        self.batch_input = batch_input
 
     @property
     def sources(self):
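As a hedged illustration of how the new constructor flag is meant to be used (the class name, the doubling logic, and the import path below are assumptions, not part of this diff), a subclass that works example by example would simply leave ``batch_input`` at its default:

```python
from fuel.transformers import Transformer  # import path assumed


class DoubleExamples(Transformer):
    """Hypothetical transformer that doubles each numeric source value."""
    def __init__(self, data_stream, **kwargs):
        # batch_input defaults to False, so get_data() will route calls
        # to get_data_from_example() on this instance.
        super(DoubleExamples, self).__init__(data_stream, **kwargs)

    def get_data_from_example(self, request=None):
        data = next(self.child_epoch_iterator)
        # Assumes the sources are numbers or numpy arrays.
        return tuple(2 * source for source in data)
```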
@@ -60,6 +64,24 @@ def get_epoch_iterator(self, **kwargs):
         self.child_epoch_iterator = self.data_stream.get_epoch_iterator()
         return super(Transformer, self).get_epoch_iterator(**kwargs)
 
+    def get_data(self, request=None):
+        if self.batch_input:
+            return self.get_data_from_batch(request)
+        else:
+            return self.get_data_from_example(request)
+
+    def get_data_from_example(self, request=None):
+        raise NotImplementedError(
+            str(type(self)) +
+            "does not have an example method"
+        )

Review comment: We generally prefer to use formatting instead of string concatenation, it looks a bit cleaner IMHO. Something slightly more informative would be good too, e.g. …

+
+    def get_data_from_batch(self, request=None):
+        raise NotImplementedError(
+            str(type(self)) +
+            "does not have a batch input method"
+        )
+
 
 class Mapping(Transformer):
     """Applies a mapping to the data of the wrapped data stream.
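The review suggestion above is cut off. The sketch below is a self-contained illustration of the dispatch pattern added in this hunk, with ``str.format``-based messages standing in for whatever wording the reviewer had in mind; the class name and the exact messages are made up:

```python
class DispatchSketch(object):
    """Standalone sketch of the batch/example dispatch added in this PR."""
    def __init__(self, batch_input=False):
        self.batch_input = batch_input

    def get_data(self, request=None):
        # Route the call exactly as the new Transformer.get_data does.
        if self.batch_input:
            return self.get_data_from_batch(request)
        return self.get_data_from_example(request)

    def get_data_from_example(self, request=None):
        # Formatting instead of concatenation; wording is illustrative only.
        raise NotImplementedError(
            "{} does not implement get_data_from_example".format(
                type(self).__name__))

    def get_data_from_batch(self, request=None):
        raise NotImplementedError(
            "{} does not implement get_data_from_batch".format(
                type(self).__name__))


# Calling the example path on a stand-in that implements neither method
# raises the formatted error.
try:
    DispatchSketch(batch_input=False).get_data()
except NotImplementedError as error:
    print(error)  # DispatchSketch does not implement get_data_from_example
```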
@@ -164,7 +186,7 @@ def __init__(self, data_stream, iteration_scheme):
             data_stream, iteration_scheme=iteration_scheme)
         self.cache = [[] for _ in self.sources]
 
-    def get_data(self, request=None):
+    def get_data_from_example(self, request=None):

Review comment: This should be …

         if request > len(self.cache[0]):
             self._cache()
         data = []
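For context, here is a much-simplified, single-source sketch of the cache-and-refill idea behind the method changed above (made-up name and toy data; the real class keeps one cache per source):

```python
class CacheSketch(object):
    """Toy stand-in for the caching transformer changed above."""
    def __init__(self, batches):
        self._batches = iter(batches)  # upstream stream of batches
        self.cache = []

    def _fill_cache(self):
        # Pull one upstream batch and add its examples to the cache.
        self.cache.extend(next(self._batches))

    def get_data_from_example(self, request):
        # ``request`` is how many examples the iteration scheme asks for.
        while request > len(self.cache):
            self._fill_cache()
        data, self.cache = self.cache[:request], self.cache[request:]
        return data


stream = CacheSketch([[0, 1, 2, 3], [4, 5, 6, 7]])
print(stream.get_data_from_example(3))  # [0, 1, 2]
print(stream.get_data_from_example(3))  # [3, 4, 5]
```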
@@ -237,7 +259,7 @@ def __init__(self, data_stream, iteration_scheme, strictness=0):
             data_stream, iteration_scheme=iteration_scheme)
         self.strictness = strictness
 
-    def get_data(self, request=None):
+    def get_data_from_example(self, request=None):
         """Get data from the dataset."""
         if request is None:
             raise ValueError
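Similarly, a toy sketch of grouping examples into batches, which appears to be what this transformer does (judging from the ``strictness`` parameter and the Unpack docstring below calling Unpack the inverse of Batch). The ``strictness`` semantics in the comments are assumptions, not taken from the diff:

```python
class BatchSketch(object):
    """Toy stand-in: groups upstream examples into batches of ``request``."""
    def __init__(self, examples, strictness=0):
        self._examples = iter(examples)
        # Assumed meaning: 0 = allow a short final batch, 1 = drop it,
        # 2 = treat a short final batch as an error.
        self.strictness = strictness

    def get_data_from_example(self, request=None):
        if request is None:
            raise ValueError("a batch size is required")
        batch = []
        for _ in range(request):
            try:
                batch.append(next(self._examples))
            except StopIteration:
                break
        if len(batch) < request:
            if self.strictness == 2:
                raise ValueError("incomplete final batch")
            if self.strictness == 1 or not batch:
                raise StopIteration
        return batch


stream = BatchSketch(range(5))
print(stream.get_data_from_example(2))  # [0, 1]
```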
@@ -261,20 +283,23 @@ def get_data(self, request=None):
 class Unpack(Transformer):
     """Unpacks batches to compose a stream of examples.
 
+

Review comment: Not sure why new lines were added here, but there should be only one.

     This class is the inverse of the Batch class: it turns a minibatch into
     a stream of examples.
 
+
     Parameters
     ----------
     data_stream : :class:`AbstractDataStream` instance
         The data stream to unpack
 
+
     """
     def __init__(self, data_stream):
-        super(Unpack, self).__init__(data_stream)
+        super(Unpack, self).__init__(data_stream, batch_input=True)
         self.data = None
 
-    def get_data(self, request=None):
+    def get_data_from_batch(self, request=None):
         if not self.data:
             data = next(self.child_epoch_iterator)
             self.data = izip(*data)
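For readers unfamiliar with the ``izip(*data)`` idiom, here is what unpacking one batch looks like with toy data (Python 3's built-in ``zip`` in place of the Python 2 ``izip`` used in the diff):

```python
# One batch: a tuple with one entry per source, each holding three examples.
batch = ([1, 2, 3], ['a', 'b', 'c'])

# Transposing the batch yields one tuple per example, which Unpack then
# hands out one at a time on each get_data_from_batch() call.
examples = list(zip(*batch))
print(examples)  # [(1, 'a'), (2, 'b'), (3, 'c')]
```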
@@ -311,7 +336,7 @@ class Padding(Transformer):
 
     """
     def __init__(self, data_stream, mask_sources=None, mask_dtype=None):
-        super(Padding, self).__init__(data_stream)
+        super(Padding, self).__init__(data_stream, batch_input=True)
         if mask_sources is None:
             mask_sources = self.data_stream.sources
         self.mask_sources = mask_sources
@@ -329,7 +354,7 @@ def sources(self):
             sources.append(source + '_mask')
         return tuple(sources)
 
-    def get_data(self, request=None):
+    def get_data_from_batch(self, request=None):
         if request is not None:
             raise ValueError
         data = list(next(self.child_epoch_iterator))
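As a rough, hedged sketch of what padding produces for a single variable-length source (this mirrors the idea behind the added ``*_mask`` sources, not the exact implementation):

```python
import numpy


def pad_source(sequences, dtype='float32', mask_dtype='uint8'):
    """Pad variable-length sequences to a rectangle and build a mask."""
    lengths = [len(sequence) for sequence in sequences]
    padded = numpy.zeros((len(sequences), max(lengths)), dtype=dtype)
    mask = numpy.zeros(padded.shape, dtype=mask_dtype)
    for i, sequence in enumerate(sequences):
        padded[i, :len(sequence)] = sequence
        mask[i, :len(sequence)] = 1  # 1 marks real data, 0 marks padding
    return padded, mask


padded, mask = pad_source([[1, 2, 3], [4, 5], [6]])
# padded is a 3x3 array with zero padding; mask has ones where data is real.
```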
@@ -459,13 +484,13 @@ class MultiProcessing(Transformer):
 
     """
     def __init__(self, data_stream, max_store=100):
-        super(MultiProcessing, self).__init__(data_stream)
+        super(MultiProcessing, self).__init__(data_stream, batch_input=True)
         self.background = BackgroundProcess(data_stream, max_store)
         self.proc = Process(target=self.background.main)
         self.proc.daemon = True
         self.proc.start()
 
-    def get_data(self, request=None):
+    def get_data_from_batch(self, request=None):

Review comment: Since this one just passes on data, I think it's agnostic to whether the input is a batch or single example. You can just leave the …

         if request is not None:
             raise ValueError
         data = self.background.get_next_data()
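The ``BackgroundProcess`` protocol is not visible in this diff, so the following is only a generic sketch of the producer/consumer pattern the class wraps: a daemon process fills a bounded queue (playing the role of ``max_store``) and the consumer drains it until a sentinel arrives.

```python
from multiprocessing import Process, Queue


def produce(queue, n_batches):
    # Stand-in for the background producer: push data, then a sentinel.
    for i in range(n_batches):
        queue.put([i] * 4)  # a fake "batch"
    queue.put(None)


if __name__ == '__main__':
    queue = Queue(maxsize=100)  # bounded, like max_store
    proc = Process(target=produce, args=(queue, 3))
    proc.daemon = True
    proc.start()
    while True:
        data = queue.get()
        if data is None:
            break
        print(data)
    proc.join()
```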
Review comment on the new ``batch : boolean`` docstring entry above: ``batch_input`` instead of ``batch``, and there should be no new line between parameter descriptions.