Commit

updated the check
Signed-off-by: Fardin Abdi <fardin@uber.com>
abditag2 committed Jul 25, 2020
1 parent a3a68cc commit fd15fad
Showing 2 changed files with 33 additions and 10 deletions.
25 changes: 15 additions & 10 deletions horovod/spark/common/util.py
@@ -178,26 +178,31 @@ def check_shape_compatibility(metadata, feature_columns, label_columns,
                              'provided label shapes count {outputs}'
                              .format(labels=label_count, outputs=len(label_shapes)))

+    label_count = len(label_columns)
     if output_shapes is not None:
-        label_count = len(label_columns)
         if label_count != len(output_shapes):
             raise ValueError('Label column count {labels} must equal '
                              'model outputs count {outputs}'
                              .format(labels=label_count, outputs=len(output_shapes)))

-    if output_shapes is not None and label_shapes is None:
-        label_count = len(label_columns)
-        for idx, col, output_shape in zip(range(label_count), label_columns, output_shapes):
-            col_size = metadata[col]['shape']
-            if col_size is None:
+    def _check_label_cols_size(target_shapes):
+        for _idx, _col, target_shape in zip(range(label_count), label_columns, target_shapes):
+            _col_size = metadata[_col]['shape']
+            if _col_size is None:
                 # When training directly on Parquet, we do not compute shape metadata
                 continue

-            output_size = abs(np.prod(output_shape))
-            if col_size != output_size:
+            target_size = abs(np.prod(target_shape))
+            if _col_size != target_size:
                 raise ValueError('Label column \'{col}\' with size {label} must equal that of the '
-                                 'model output at index {idx} with size {output}'
-                                 .format(col=col, label=col_size, idx=idx, output=output_size))
+                                 'model output and label shape at index {idx} with size {output}'
+                                 .format(col=_col, label=_col_size, idx=_idx, output=target_size))
+
+    if output_shapes is not None:
+        _check_label_cols_size(output_shapes)
+
+    if label_shapes is not None:
+        _check_label_cols_size(label_shapes)


 def _get_col_info(df):
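
The refactored _check_label_cols_size helper compares element counts rather than literal shapes: abs(np.prod(shape)) collapses a target shape such as [-1, 3, 4] into 12 elements, so a variable batch dimension of -1 does not affect the comparison against the column size recorded in the shape metadata. The sketch below is a minimal, standalone illustration of that comparison rather than Horovod's actual helper; the column name and sizes are made up for the example.

import numpy as np

def check_label_sizes(metadata, label_columns, target_shapes):
    # Mirrors the comparison above: each label column's recorded size must
    # match the element count implied by the corresponding target shape.
    for idx, (col, target_shape) in enumerate(zip(label_columns, target_shapes)):
        col_size = metadata[col]['shape']
        if col_size is None:
            # Shape metadata may be absent, e.g. when training directly on Parquet.
            continue
        # abs(np.prod(...)) turns [-1, 3, 4] into 12, ignoring the -1 batch dimension.
        target_size = abs(np.prod(target_shape))
        if col_size != target_size:
            raise ValueError('column {} has size {}, expected {} at index {}'
                             .format(col, col_size, target_size, idx))

metadata = {'y_embedding': {'shape': 12}}                      # hypothetical column metadata
check_label_sizes(metadata, ['y_embedding'], [[-1, 3, 4]])     # 12 == 12, passes
try:
    check_label_sizes(metadata, ['y_embedding'], [[-1, 3, 5]])  # 15 != 12
except ValueError as e:
    print(e)

Because only element counts are compared, output shapes and label shapes that differ in layout (for example [1, 1] versus [1], or [-1, 2, 3, 2] versus [-1, 3, 4]) are both accepted against the same label columns, which is the case the new test below exercises.
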
18 changes: 18 additions & 0 deletions test/test_spark.py
@@ -1431,6 +1431,18 @@ def test_check_shape_compatibility(self):
         util.check_shape_compatibility(metadata, feature_columns, label_columns,
                                        input_shapes, output_shapes)

+        input_shapes = [[1], [1], [-1, 3, 4]]
+        label_shapes = [[1], [-1, 3, 4]]
+        util.check_shape_compatibility(metadata, feature_columns, label_columns,
+                                       input_shapes, output_shapes, label_shapes)
+
+        # The case where label_shapes is different from output_shapes
+        input_shapes = [[1], [1], [-1, 3, 4]]
+        output_shapes = [[1, 1], [-1, 2, 3, 2]]
+        label_shapes = [[1], [-1, 3, 4]]
+        util.check_shape_compatibility(metadata, feature_columns, label_columns,
+                                       input_shapes, output_shapes, label_shapes)
+
         bad_input_shapes = [[1], [1], [-1, 3, 5]]
         with pytest.raises(ValueError):
             util.check_shape_compatibility(metadata, feature_columns, label_columns,
@@ -1446,6 +1458,12 @@ def test_check_shape_compatibility(self):
             util.check_shape_compatibility(metadata, feature_columns, label_columns,
                                            input_shapes, bad_output_shapes)

+        input_shapes = [[1], [1], [-1, 3, 4]]
+        bad_label_shapes = [[-1, 3, 4]]
+        with pytest.raises(ValueError):
+            util.check_shape_compatibility(metadata, feature_columns, label_columns,
+                                           input_shapes, output_shapes, bad_label_shapes)
+
     @mock.patch('horovod.spark.common.store.HDFSStore._get_filesystem_fn')
     def test_sync_hdfs_store(self, mock_get_fs_fn):
         mock_fs = mock.Mock()
