From 3bf978988f0d3f02341d4c9f3314fb39fe96cfa5 Mon Sep 17 00:00:00 2001
From: Fardin Abdi
Date: Tue, 28 Jul 2020 12:40:08 -0700
Subject: [PATCH] addressed comments

Signed-off-by: Fardin Abdi
---
 horovod/spark/common/util.py  | 25 ++++++++++++-------------
 horovod/spark/torch/remote.py |  4 ++--
 2 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/horovod/spark/common/util.py b/horovod/spark/common/util.py
index 074c19f0cd..5b674b0b00 100644
--- a/horovod/spark/common/util.py
+++ b/horovod/spark/common/util.py
@@ -171,38 +171,37 @@ def check_shape_compatibility(metadata, feature_columns, label_columns,
                                  'model input at index {idx} with size {input}'
                                  .format(col=col, feature=col_size, idx=idx, input=input_size))
 
+    label_count = len(label_columns)
     if label_shapes is not None:
-        label_count = len(label_columns)
         if label_count != len(label_shapes):
             raise ValueError('Label column count {labels} must equal '
                              'provided label shapes count {outputs}'
                              .format(labels=label_count, outputs=len(label_shapes)))
 
-    label_count = len(label_columns)
     if output_shapes is not None:
         if label_count != len(output_shapes):
             raise ValueError('Label column count {labels} must equal '
                              'model outputs count {outputs}'
                              .format(labels=label_count, outputs=len(output_shapes)))
 
-    def _check_label_cols_size(target_shapes):
-        for _idx, _col, target_shape in zip(range(label_count), label_columns, target_shapes):
-            _col_size = metadata[_col]['shape']
-            if _col_size is None:
+    def _check_label_cols_size(target_shapes, target_name):
+        for idx, col, target_shape in zip(range(label_count), label_columns, target_shapes):
+            col_size = metadata[col]['shape']
+            if col_size is None:
                 # When training directly on Parquet, we do not compute shape metadata
                 continue
 
             target_size = abs(np.prod(target_shape))
-            if _col_size != target_size:
+            if col_size != target_size:
                 raise ValueError('Label column \'{col}\' with size {label} must equal that of the '
-                                 'model output and label shape at index {idx} with size {output}'
-                                 .format(col=_col, label=_col_size, idx=_idx, output=target_size))
-
-    if output_shapes is not None:
-        _check_label_cols_size(output_shapes)
+                                 '{target_name} shape at index {idx} with size {output}'
+                                 .format(col=col, label=col_size, idx=idx, output=target_size,
+                                         target_name=target_name))
 
     if label_shapes is not None:
-        _check_label_cols_size(label_shapes)
+        _check_label_cols_size(label_shapes, 'label')
+    elif output_shapes is not None:
+        _check_label_cols_size(output_shapes, 'model output')
 
 
 def _get_col_info(df):
diff --git a/horovod/spark/torch/remote.py b/horovod/spark/torch/remote.py
index f620984473..cc7f002896 100644
--- a/horovod/spark/torch/remote.py
+++ b/horovod/spark/torch/remote.py
@@ -267,12 +267,12 @@ def transform_outputs(outputs, labels):
 
             # reshape labels to match the output shape of the model
            if hasattr(outputs[0], 'shape'):
-                # If label_shapes parameter is not provided, reshape the label columns
-                # data to match the shape of the model output
                 if label_shapes:
                     labels = [label.reshape(label_shape) for label, label_shape
                               in zip(labels, label_shapes)]
                 else:
+                    # If label_shapes parameter is not provided, reshape the label
+                    # columns data to match the shape of the model output
                     labels = [label.reshape(output.shape) if
                               output.shape.numel() == label.shape.numel() else label
                               for label, output in zip(labels, outputs)]
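
Note (illustrative, not part of the patch): with the util.py change above, `label_count` is computed once and `_check_label_cols_size` reports whether it compared against the user-provided label shapes or the model output shapes, preferring `label_shapes` when both are given. A minimal sketch of that precedence and size check, using hypothetical metadata and shape values:

    import numpy as np

    # Hypothetical inputs: per-column flattened sizes and candidate shapes.
    metadata = {'label': {'shape': 4}}
    label_columns = ['label']
    label_shapes = [[2, 2]]        # user-provided label shapes
    output_shapes = [[-1, 4]]      # shapes taken from the model outputs

    # Mirrors the patched selection: label_shapes wins, model output is the fallback.
    if label_shapes is not None:
        target_shapes, target_name = label_shapes, 'label'
    elif output_shapes is not None:
        target_shapes, target_name = output_shapes, 'model output'

    for idx, (col, target_shape) in enumerate(zip(label_columns, target_shapes)):
        col_size = metadata[col]['shape']
        target_size = abs(np.prod(target_shape))
        if col_size != target_size:
            raise ValueError('Label column {col} with size {label} must equal that of the '
                             '{target_name} shape at index {idx} with size {output}'
                             .format(col=col, label=col_size, target_name=target_name,
                                     idx=idx, output=target_size))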
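
Note (illustrative, not part of the patch): the remote.py hunk only moves the comment into the `else` branch where the fallback actually applies; the reshape logic is unchanged. A small PyTorch sketch of that fallback for a single output/label pair, with made-up tensor sizes:

    import torch

    output = torch.zeros(8, 2, 2)               # model output, 32 elements
    label = torch.arange(32.0).reshape(8, 4)    # label column data, also 32 elements

    # Without label_shapes, reshape the label to the model output's shape
    # only when the element counts match; otherwise leave it as-is.
    label = (label.reshape(output.shape)
             if output.shape.numel() == label.shape.numel() else label)

    assert label.shape == output.shape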