
Add label_shapes parameter to KerasEstimator and TorchEstimator #2140

Merged: 8 commits merged into master from the loss branch on Jul 29, 2020

Conversation

abditag2 (Collaborator)

Checklist before submitting

  • Did you read the contributor guide?
  • Did you update the docs?
  • Did you write any tests to validate this change?
  • Did you update the CHANGELOG, if this change affects users?

Description

Fixes # (issue).
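For context, a hedged sketch of how the new parameter could be passed. Only `label_shapes` itself is introduced by this PR; the other constructor arguments shown follow Horovod's Spark estimator API as assumptions, and the Spark backend/store setup a real run needs is omitted:

```python
import torch
import horovod.spark.torch as hvd_spark

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Hypothetical configuration: label_shapes declares the shape of each
# label column explicitly, so the estimator need not force labels to
# match the model's output shape (useful with custom loss functions).
estimator = hvd_spark.TorchEstimator(
    model=model,
    optimizer=optimizer,
    loss=torch.nn.MSELoss(),
    input_shapes=[[-1, 2]],
    feature_cols=['features'],
    label_cols=['y'],
    label_shapes=[[-1, 1]],  # one shape per label column (assumed format)
    batch_size=16,
    epochs=1)
```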

Review process to land

  1. All tests and other checks must succeed.
  2. At least one member of the technical steering committee must review and approve.
  3. If any member of the technical steering committee requests changes, they must be addressed.

@abditag2 changed the title from "add label_shapes parameter to KerasEstimator and TorchEstimator" to "Add label_shapes parameter to KerasEstimator and TorchEstimator" on Jul 24, 2020
@abditag2 requested a review from @tgaddair on July 24, 2020 23:14
@abditag2 self-assigned this on Jul 24, 2020
@tgaddair (Collaborator) left a comment:


Looks good! Just a few small things.

horovod/spark/common/util.py:

```python
col_size = metadata[col]['shape']
if col_size is None:
    ...

def _check_label_cols_size(target_shapes):
    for _idx, _col, target_shape in zip(range(label_count), label_columns, target_shapes):
        ...
```
@tgaddair (Collaborator): No need to use `_` here; that's only for private methods/variables.
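For example, the loop header could drop the underscores; a minimal sketch, with `label_columns` and `label_count` stubbed in since they come from the enclosing scope in `util.py`:

```python
label_columns = ['y1', 'y2']      # stand-ins for the enclosing scope
label_count = len(label_columns)

def check_label_cols_size(target_shapes):
    # Plain loop-variable names: a leading '_' is conventionally
    # reserved for private methods/variables, not loop variables.
    for idx, col, target_shape in zip(range(label_count), label_columns, target_shapes):
        print(idx, col, target_shape)

check_label_cols_size([[-1, 1], [-1, 1]])
```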

```python
            .format(col=_col, label=_col_size, idx=_idx, output=target_size))
```

```python
if output_shapes is not None:
    _check_label_cols_size(output_shapes)
```
@tgaddair (Collaborator): Seems this should only be done if `label_shapes` is None, right? Otherwise, we would be forcing the model output size to equal the label size, which may not be correct (e.g., for custom loss functions).

@abditag2 (Collaborator, Author): Fixed.
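For reference, a minimal sketch of one plausible shape of the fix, assuming the names from the excerpts above and an enclosing validation function that receives both `label_shapes` and `output_shapes` (whether explicit label shapes are also validated is an assumption):

```python
# Only fall back to the model's output shapes when the user did not
# pass label_shapes explicitly.
if label_shapes is not None:
    # Trust the explicit label shapes; with a custom loss the labels
    # need not match the model output element-for-element.
    _check_label_cols_size(label_shapes)
elif output_shapes is not None:
    # No explicit label shapes: validate against the model's outputs.
    _check_label_cols_size(output_shapes)
```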

```python
labels = [label.reshape(output.shape)
          if output.shape.numel() == label.shape.numel() else label
          for label, output in zip(labels, outputs)]
# If label_shapes parameter is not provided, reshape the label columns
```
@tgaddair (Collaborator): Can move this comment into the else branch so it's closer to the code it's describing.

@abditag2 (Collaborator, Author): Moved.
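To see the fallback reshape in isolation, a self-contained PyTorch sketch with made-up tensors:

```python
import torch

outputs = [torch.zeros(4, 1)]              # model output, shape (4, 1)
labels = [torch.tensor([0., 1., 1., 0.])]  # label column, shape (4,)

# Element counts match (4 == 4), so the label is reshaped to the
# output's shape; otherwise it would be left untouched.
labels = [label.reshape(output.shape)
          if output.shape.numel() == label.shape.numel() else label
          for label, output in zip(labels, outputs)]

assert labels[0].shape == torch.Size([4, 1])
```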

@abditag2 force-pushed the loss branch 3 times, most recently from 5c1e055 to 3bf9789, on July 28, 2020 19:46
All 8 commits signed off by: Fardin Abdi <fardin@uber.com>
@tgaddair (Collaborator) left a comment:

LGTM!

@abditag2 abditag2 merged commit cb67186 into master Jul 29, 2020