-
Notifications
You must be signed in to change notification settings - Fork 57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor MultiNodeIterator classes #259
Conversation
raise RuntimeError('Multi node iterator supports numpy.float32 ' | ||
'or tuple of numpy.float32 as the data type ' | ||
'of the batch element only.') | ||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need this else
. Let's reduce indent level.
# 3. whether dataset is paired. | ||
# 4. is_new_epoch. | ||
# 5. current_position. | ||
_info = numpy.ones((5, )) \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given that this is a refactoring patch, I'd want this conversion as a symmetric functions for tests. Maybe we could increase the number of elements in future. Let's have unit tests for this conversions for future. Like this:
def _build_ctrl_msg(stop, is_valid_data_type, is_paired_dataset, is_new_epoch, current_pos):
...
def _parse_ctrl_msg(msg):
....
def test_msg_conversion():
msg = [....]
assert msg == _build_ctr_msg(_parse_ctrl_msg(msg))
@@ -117,23 +137,28 @@ def __next__(self): | |||
# Check if master iterator received stop signal. | |||
_info = self.communicator.bcast(None, root=self.rank_master) | |||
stop = bool(_info[0]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See above.
if stop: | ||
raise StopIteration | ||
elif not valid_data_type: | ||
raise RuntimeError('Multi node iterator supports numpy.float32 ' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TypeError should be raised here.
Passing arguments of the wrong type (e.g. passing a list when an int is expected) should result in a TypeError,
if stop: | ||
raise StopIteration | ||
elif not valid_data_type: | ||
raise RuntimeError('Multi node iterator supports numpy.float32 ' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto: TypeError
.
This PR refactor two MultiNodeIterator classes.
I mainly did the following: