New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add label_shapes parameter to KerasEstimator and TorchEstimator #2140
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good! Just a few small things.
horovod/spark/common/util.py
Outdated
col_size = metadata[col]['shape'] | ||
if col_size is None: | ||
def _check_label_cols_size(target_shapes): | ||
for _idx, _col, target_shape in zip(range(label_count), label_columns, target_shapes): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to use _
here, that's only for private methods / variables.
horovod/spark/common/util.py
Outdated
.format(col=_col, label=_col_size, idx=_idx, output=target_size)) | ||
|
||
if output_shapes is not None: | ||
_check_label_cols_size(output_shapes) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems this should only be done if label_shapes is None
, right? Otherwise, we will be forcing the model output size to equal the label size, which may not be correct (for these custom loss functions).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
horovod/spark/torch/remote.py
Outdated
labels = [label.reshape(output.shape) | ||
if output.shape.numel() == label.shape.numel() else label | ||
for label, output in zip(labels, outputs)] | ||
# If label_shapes parameter is not provided, reshape the label columns |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can move this comment into the else
branch so it's closer to the code it's describing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
moved.
5c1e055
to
3bf9789
Compare
Signed-off-by: Fardin Abdi <fardin@uber.com>
Signed-off-by: Fardin Abdi <fardin@uber.com>
Signed-off-by: Fardin Abdi <fardin@uber.com>
Signed-off-by: Fardin Abdi <fardin@uber.com>
Signed-off-by: Fardin Abdi <fardin@uber.com>
Signed-off-by: Fardin Abdi <fardin@uber.com>
Signed-off-by: Fardin Abdi <fardin@uber.com>
Signed-off-by: Fardin Abdi <fardin@uber.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Checklist before submitting
Description
Fixes # (issue).
Review process to land