Skip to content

Commit

Permalink
Fix pylint. (#10296)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed May 17, 2024
1 parent 835e59e commit ba9b4cb
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions python-package/xgboost/spark/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,8 @@ def _validate_and_convert_feature_col_as_array_col(
(DoubleType, FloatType, LongType, IntegerType, ShortType),
):
raise ValueError(
"If feature column is array type, its elements must be number type."
"If feature column is array type, its elements must be number type, "
f"got {features_col_datatype.elementType}."
)
features_array_col = features_col.cast(ArrayType(FloatType())).alias(alias.data)
elif isinstance(features_col_datatype, VectorUDT):
Expand Down Expand Up @@ -1379,15 +1380,15 @@ def _transform(self, dataset: DataFrame) -> DataFrame:
# to avoid the `self` object to be pickled to remote.
xgb_sklearn_model = self._xgb_sklearn_model

has_base_margin = False
base_margin_col = None
if (
self.isDefined(self.base_margin_col)
and self.getOrDefault(self.base_margin_col) != ""
):
has_base_margin = True
base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
alias.margin
)
has_base_margin = base_margin_col is not None

features_col, feature_col_names = self._get_feature_col(dataset)
enable_sparse_data_optim = self.getOrDefault(self.enable_sparse_data_optim)
Expand Down Expand Up @@ -1472,6 +1473,7 @@ def to_gpu_if_possible(data: ArrayLike) -> ArrayLike:
yield predict_func(model, X, base_margin)

if has_base_margin:
assert base_margin_col is not None
pred_col = predict_udf(struct(*features_col, base_margin_col))
else:
pred_col = predict_udf(struct(*features_col))
Expand Down

0 comments on commit ba9b4cb

Please sign in to comment.