Skip to content

Commit

Permalink
Fixes binary feature postprocessing upcast (#2101)
Browse files Browse the repository at this point in the history
* Fixes binary feature postprocessing upcast

* make more numpy-thonic

Co-authored-by: Geoffrey Angus <geoffrey@predibase.com>
  • Loading branch information
geoffreyangus and geoffreyangus committed Jun 6, 2022
1 parent b20cbdd commit d3eea13
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions ludwig/features/binary_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,19 +361,20 @@ def postprocess_predictions(

probabilities_col = f"{self.feature_name}_{PROBABILITIES}"
if probabilities_col in result:
result[probabilities_col] = backend.df_engine.map_objects(
result[probabilities_col],
lambda prob: np.array([1 - prob, prob], dtype=result[probabilities_col].dtype),
)

false_col = f"{probabilities_col}_{class_names[0]}"
result[false_col] = backend.df_engine.map_objects(result[probabilities_col], lambda prob: 1 - prob)
result[false_col] = backend.df_engine.map_objects(result[probabilities_col], lambda probs: probs[0])

true_col = f"{probabilities_col}_{class_names[1]}"
result[true_col] = result[probabilities_col]
result[true_col] = backend.df_engine.map_objects(result[probabilities_col], lambda probs: probs[1])

prob_col = f"{self.feature_name}_{PROBABILITY}"
result[prob_col] = result[[false_col, true_col]].max(axis=1)

result[probabilities_col] = backend.df_engine.map_objects(
result[probabilities_col], lambda prob: [1 - prob, prob]
)

return result

@staticmethod
Expand Down

0 comments on commit d3eea13

Please sign in to comment.