Skip to content

Commit

Permalink
addressed comments
Browse files Browse the repository at this point in the history
Signed-off-by: Fardin Abdi <fardin@uber.com>
  • Loading branch information
abditag2 committed Jul 28, 2020
1 parent 307ea05 commit 3bf9789
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 15 deletions.
25 changes: 12 additions & 13 deletions horovod/spark/common/util.py
Expand Up @@ -171,38 +171,37 @@ def check_shape_compatibility(metadata, feature_columns, label_columns,
'model input at index {idx} with size {input}'
.format(col=col, feature=col_size, idx=idx, input=input_size))

label_count = len(label_columns)
if label_shapes is not None:
label_count = len(label_columns)
if label_count != len(label_shapes):
raise ValueError('Label column count {labels} must equal '
'provided label shapes count {outputs}'
.format(labels=label_count, outputs=len(label_shapes)))

label_count = len(label_columns)
if output_shapes is not None:
if label_count != len(output_shapes):
raise ValueError('Label column count {labels} must equal '
'model outputs count {outputs}'
.format(labels=label_count, outputs=len(output_shapes)))

def _check_label_cols_size(target_shapes):
for _idx, _col, target_shape in zip(range(label_count), label_columns, target_shapes):
_col_size = metadata[_col]['shape']
if _col_size is None:
def _check_label_cols_size(target_shapes, target_name):
for idx, col, target_shape in zip(range(label_count), label_columns, target_shapes):
col_size = metadata[col]['shape']
if col_size is None:
# When training directly on Parquet, we do not compute shape metadata
continue

target_size = abs(np.prod(target_shape))
if _col_size != target_size:
if col_size != target_size:
raise ValueError('Label column \'{col}\' with size {label} must equal that of the '
'model output and label shape at index {idx} with size {output}'
.format(col=_col, label=_col_size, idx=_idx, output=target_size))

if output_shapes is not None:
_check_label_cols_size(output_shapes)
'{target_name} shape at index {idx} with size {output}'
.format(col=col, label=col_size, idx=idx, output=target_size,
target_name=target_name))

if label_shapes is not None:
_check_label_cols_size(label_shapes)
_check_label_cols_size(label_shapes, 'label')
elif output_shapes is not None:
_check_label_cols_size(output_shapes, 'model output')


def _get_col_info(df):
Expand Down
4 changes: 2 additions & 2 deletions horovod/spark/torch/remote.py
Expand Up @@ -267,12 +267,12 @@ def transform_outputs(outputs, labels):

# reshape labels to match the output shape of the model
if hasattr(outputs[0], 'shape'):
# Prefer the user-provided label_shapes when given; otherwise fall back to
# reshaping the label columns data to match the shape of the model output
if label_shapes:
labels = [label.reshape(label_shape)
for label, label_shape in zip(labels, label_shapes)]
else:
# If label_shapes parameter is not provided, reshape the label
# columns data to match the shape of the model output
labels = [label.reshape(output.shape) if
output.shape.numel() == label.shape.numel() else label
for label, output in zip(labels, outputs)]
Expand Down

0 comments on commit 3bf9789

Please sign in to comment.