From fd15fad6ca84b7b43d05d2f0291128d0e0ec96a0 Mon Sep 17 00:00:00 2001
From: Fardin Abdi
Date: Fri, 24 Jul 2020 18:38:05 -0700
Subject: [PATCH] updated the check

Signed-off-by: Fardin Abdi
---
 horovod/spark/common/util.py | 25 +++++++++++++++----------
 test/test_spark.py           | 18 ++++++++++++++++++
 2 files changed, 33 insertions(+), 10 deletions(-)

diff --git a/horovod/spark/common/util.py b/horovod/spark/common/util.py
index 4532629035..074c19f0cd 100644
--- a/horovod/spark/common/util.py
+++ b/horovod/spark/common/util.py
@@ -178,26 +178,31 @@ def check_shape_compatibility(metadata, feature_columns, label_columns,
                              'provided label shapes count {outputs}'
                              .format(labels=label_count, outputs=len(label_shapes)))
 
+    label_count = len(label_columns)
     if output_shapes is not None:
-        label_count = len(label_columns)
         if label_count != len(output_shapes):
             raise ValueError('Label column count {labels} must equal '
                              'model outputs count {outputs}'
                              .format(labels=label_count, outputs=len(output_shapes)))
 
-    if output_shapes is not None and label_shapes is None:
-        label_count = len(label_columns)
-        for idx, col, output_shape in zip(range(label_count), label_columns, output_shapes):
-            col_size = metadata[col]['shape']
-            if col_size is None:
+    def _check_label_cols_size(target_shapes):
+        for _idx, _col, target_shape in zip(range(label_count), label_columns, target_shapes):
+            _col_size = metadata[_col]['shape']
+            if _col_size is None:
                 # When training directly on Parquet, we do not compute shape metadata
                 continue
 
-            output_size = abs(np.prod(output_shape))
-            if col_size != output_size:
+            target_size = abs(np.prod(target_shape))
+            if _col_size != target_size:
                 raise ValueError('Label column \'{col}\' with size {label} must equal that of the '
-                                 'model output at index {idx} with size {output}'
-                                 .format(col=col, label=col_size, idx=idx, output=output_size))
+                                 'model output and label shape at index {idx} with size {output}'
+                                 .format(col=_col, label=_col_size, idx=_idx, output=target_size))
+
+    if output_shapes is not None:
+        _check_label_cols_size(output_shapes)
+
+    if label_shapes is not None:
+        _check_label_cols_size(label_shapes)
 
 
 def _get_col_info(df):
diff --git a/test/test_spark.py b/test/test_spark.py
index b0a56ffdc2..c426c01020 100644
--- a/test/test_spark.py
+++ b/test/test_spark.py
@@ -1431,6 +1431,18 @@ def test_check_shape_compatibility(self):
         util.check_shape_compatibility(metadata, feature_columns, label_columns,
                                        input_shapes, output_shapes)
 
+        input_shapes = [[1], [1], [-1, 3, 4]]
+        label_shapes = [[1], [-1, 3, 4]]
+        util.check_shape_compatibility(metadata, feature_columns, label_columns,
+                                       input_shapes, output_shapes, label_shapes)
+
+        # The case where label_shapes is different from output_shapes
+        input_shapes = [[1], [1], [-1, 3, 4]]
+        output_shapes = [[1, 1], [-1, 2, 3, 2]]
+        label_shapes = [[1], [-1, 3, 4]]
+        util.check_shape_compatibility(metadata, feature_columns, label_columns,
+                                       input_shapes, output_shapes, label_shapes)
+
         bad_input_shapes = [[1], [1], [-1, 3, 5]]
         with pytest.raises(ValueError):
             util.check_shape_compatibility(metadata, feature_columns, label_columns,
@@ -1446,6 +1458,12 @@ def test_check_shape_compatibility(self):
             util.check_shape_compatibility(metadata, feature_columns, label_columns,
                                            input_shapes, bad_output_shapes)
 
+        input_shapes = [[1], [1], [-1, 3, 4]]
+        bad_label_shapes = [[-1, 3, 4]]
+        with pytest.raises(ValueError):
+            util.check_shape_compatibility(metadata, feature_columns, label_columns,
+                                           input_shapes, output_shapes, bad_label_shapes)
+
     @mock.patch('horovod.spark.common.store.HDFSStore._get_filesystem_fn')
     def test_sync_hdfs_store(self, mock_get_fs_fn):
         mock_fs = mock.Mock()
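
Note (not part of the patch): the sketch below is a minimal standalone illustration of the size check this change factors into _check_label_cols_size, which now runs against output_shapes and label_shapes whenever each is provided, instead of only against output_shapes when label_shapes is absent. Each label column's element count must equal abs(np.prod(shape)) of the shape at the same index, so a -1 batch dimension only flips the sign and is effectively ignored. The helper name check_label_sizes and the metadata values are hypothetical, simplified stand-ins for what util._get_metadata computes.

import numpy as np

def check_label_sizes(metadata, label_columns, target_shapes):
    # Compare each label column's size against the flattened element
    # count of the shape at the same index.
    for idx, (col, shape) in enumerate(zip(label_columns, target_shapes)):
        col_size = metadata[col]['shape']
        if col_size is None:
            # Shape metadata is unavailable when training directly on Parquet.
            continue
        # abs() neutralizes a -1 batch dimension: np.prod([-1, 3, 4]) == -12.
        target_size = abs(np.prod(shape))
        if col_size != target_size:
            raise ValueError('Label column {col!r} with size {label} must equal that of '
                             'the shape at index {idx} with size {size}'
                             .format(col=col, label=col_size, idx=idx, size=target_size))

# Hypothetical metadata: 'y_embedding' stores 12 values per row.
metadata = {'y1': {'shape': 1}, 'y_embedding': {'shape': 12}}

# Passes: [-1, 3, 4] flattens to 12 elements, matching the column size.
check_label_sizes(metadata, ['y1', 'y_embedding'], [[1], [-1, 3, 4]])

# Fails: [-1, 3, 5] flattens to 15 elements, not 12.
try:
    check_label_sizes(metadata, ['y1', 'y_embedding'], [[1], [-1, 3, 5]])
except ValueError as e:
    print(e)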