Skip to content

Commit

Permalink
Merge pull request #701 from talalryz/u/talal/fix_math_binary_pyspark
Browse files Browse the repository at this point in the history
U/talal/fix math binary pyspark
  • Loading branch information
ancasarb committed Jun 29, 2020
2 parents bb82854 + 91dbea9 commit 9362f01
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 19 deletions.
45 changes: 27 additions & 18 deletions python/mleap/pyspark/feature/math_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pyspark.ml.wrapper import JavaTransformer

from mleap.pyspark.py2scala import jvm_scala_object
from mleap.pyspark.py2scala import ScalaNone
from mleap.pyspark.py2scala import Some


Expand Down Expand Up @@ -43,11 +44,13 @@ class MathBinary(JavaTransformer, HasOutputCol, JavaMLReadable, JavaMLWritable):

@keyword_only
def __init__(
self,
operation=None,
inputA=None,
inputB=None,
outputCol=None,
self,
operation=None,
inputA=None,
inputB=None,
outputCol=None,
defaultA=None,
defaultB=None,
):
"""
Computes the mathematical binary `operation` over
Expand All @@ -57,15 +60,24 @@ def __init__(
:param inputA: column name for the left side of operation (string)
:param inputB: column name for the right side of operation (string)
:param outputCol: output column name (string)
NOTE: `operation` is not a JavaParam because the underlying MathBinary
scala object uses a MathBinaryModel to store the info about the binary
operation.
:param defaultA: Default to use instead of inputA. This will only be used
when inputA is None. For example when defaultA=4,
operation=BinaryOperation.Multiply and inputB=f1, then all entries of
col f1 will be multiplied by 4.
:param defaultB: Default to use instead of inputB. This will only be used
when inputB is None. For example when defaultB=4,
operation=BinaryOperation.Multiply and inputA=f1, then all entries of
col f1 will be multiplied by 4.
NOTE: `operation`, `defaultA`, `defaultB` is not a JavaParam because
the underlying MathBinary scala object uses a MathBinaryModel to store
the info about the binary operation.
`operation` has a None default value even though it should *never* be
None. A None value is necessary upon deserialization to instantiate a
MathBinary without errors. Afterwards, pyspark sets the _java_obj to
the deserialized scala object, which encodes the operation.
the deserialized scala object, which encodes the operation (as well
as the default values for A and B).
"""
super(MathBinary, self).__init__()

Expand All @@ -80,14 +92,11 @@ def __init__(
operation.name
)

# IMPORTANT: defaults for missing values are forced to None.
# I've found an issue when setting default values for A and B,
# Remember to treat your missing values before the MathBinary
# (for example, you could use an Imputer)
scalaMathBinaryModel = _jvm().ml.combust.mleap.core.feature.MathBinaryModel(
scalaBinaryOperation, Some(None), Some(None)
scalaBinaryOperation,
Some(defaultA) if defaultA else ScalaNone(),
Some(defaultB) if defaultB else ScalaNone(),
)

self._java_obj = self._new_java_obj(
"org.apache.spark.ml.mleap.feature.MathBinary",
self.uid,
Expand All @@ -102,7 +111,8 @@ def setParams(self, inputA=None, inputB=None, outputCol=None):
"""
Sets params for this MathBinary.
"""
kwargs = self._input_kwargs
# For the correct behavior of MathBinary, params that are None must be unset
kwargs = {k: v for k, v in self._input_kwargs.items() if v is not None}
return self._set(**kwargs)

def setInputA(self, value):
Expand All @@ -122,4 +132,3 @@ def setOutputCol(self, value):
Sets the value of :py:attr:`outputCol`.
"""
return self._set(outputCol=value)

3 changes: 3 additions & 0 deletions python/mleap/pyspark/py2scala.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,6 @@ def Some(value):
an Option[<value>]
"""
return _jvm().scala.Some(value)

def ScalaNone():
return jvm_scala_object(_jvm().scala, "None")
75 changes: 74 additions & 1 deletion python/tests/pyspark/feature/math_binary_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ def test_serialize_deserialize_math_binary(self):

add_transformer.serializeToBundle(file_path, self.input)
deserialized_math_binary = SimpleSparkSerializer().deserializeFromBundle(file_path)

result = deserialized_math_binary.transform(self.input).toPandas()[['add(f1, f2)']]
assert_frame_equal(self.expected_add, result)

Expand Down Expand Up @@ -147,3 +146,77 @@ def test_serialize_deserialize_pipeline(self):

result = pipeline_model.transform(self.input).toPandas()[['mul(f1, add(f1, f2))']]
assert_frame_equal(expected, result)

def test_add_math_binary_defaults_none(self):
add_transformer = self._new_add_math_binary()

none_df = self.spark.createDataFrame([
(None, float(i * 2))
for i in range(1, 3)
], INPUT_SCHEMA)

# Summing None + int yields Nones
expected_df = pd.DataFrame([
(None,)
for i in range(1, 3)
], columns=['add(f1, f2)'])

result = add_transformer.transform(none_df).toPandas()[['add(f1, f2)']]
assert_frame_equal(expected_df, result)

def test_mult_math_binary_default_inputA(self):
mult_transformer = MathBinary(
operation=BinaryOperation.Multiply,
inputB="f2",
outputCol="mult(1, f2)",
defaultA=1.0,
)
none_df = self.spark.createDataFrame([
(None, float(i * 1234))
for i in range(1, 3)
], INPUT_SCHEMA)

expected_df = pd.DataFrame([
(float(i * 1234), )
for i in range(1, 3)
], columns=['mult(1, f2)'])
result = mult_transformer.transform(none_df).toPandas()[['mult(1, f2)']]
assert_frame_equal(expected_df, result)

def test_mult_math_binary_default_inputB(self):
mult_transformer = MathBinary(
operation=BinaryOperation.Multiply,
inputA="f1",
outputCol="mult(f1, 2)",
defaultB=2.0,
)
none_df = self.spark.createDataFrame([
(float(i * 1234), None)
for i in range(1, 3)
], INPUT_SCHEMA)

expected_df = pd.DataFrame([
(float(i * 1234 * 2), )
for i in range(1, 3)
], columns=['mult(f1, 2)'])
result = mult_transformer.transform(none_df).toPandas()[['mult(f1, 2)']]
assert_frame_equal(expected_df, result)

def test_mult_math_binary_default_both(self):
mult_transformer = MathBinary(
operation=BinaryOperation.Multiply,
outputCol="mult(7, 8)",
defaultA=7.0,
defaultB=8.0,
)
none_df = self.spark.createDataFrame([
(None, None)
for i in range(1, 3)
], INPUT_SCHEMA)

expected_df = pd.DataFrame([
(float(7 * 8), )
for i in range(1, 3)
], columns=['mult(7, 8)'])
result = mult_transformer.transform(none_df).toPandas()[['mult(7, 8)']]
assert_frame_equal(expected_df, result)

0 comments on commit 9362f01

Please sign in to comment.