Skip to content

Commit

Permalink
[SPARK-21310][ML][PYSPARK] Expose offset in PySpark
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?
Add offset to PySpark in GLM as in apache#16699.

## How was this patch tested?
Python test

Author: actuaryzhang <actuaryzhang10@gmail.com>

Closes apache#18534 from actuaryzhang/pythonOffset.
  • Loading branch information
actuaryzhang authored and yanboliang committed Jul 5, 2017
1 parent a386432 commit 4852b7d
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 4 deletions.
25 changes: 21 additions & 4 deletions python/pyspark/ml/regression.py
Expand Up @@ -1376,17 +1376,20 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
typeConverter=TypeConverters.toFloat)
solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
"options: irls.", typeConverter=TypeConverters.toString)
offsetCol = Param(Params._dummy(), "offsetCol", "The offset column name. If this is not set " +
"or empty, we treat all instance offsets as 0.0",
typeConverter=TypeConverters.toString)

@keyword_only
def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None,
variancePower=0.0, linkPower=None):
variancePower=0.0, linkPower=None, offsetCol=None):
"""
__init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
variancePower=0.0, linkPower=None)
variancePower=0.0, linkPower=None, offsetCol=None)
"""
super(GeneralizedLinearRegression, self).__init__()
self._java_obj = self._new_java_obj(
Expand All @@ -1402,12 +1405,12 @@ def __init__(self, labelCol="label", featuresCol="features", predictionCol="pred
def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction",
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None,
variancePower=0.0, linkPower=None):
variancePower=0.0, linkPower=None, offsetCol=None):
"""
setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \
variancePower=0.0, linkPower=None)
variancePower=0.0, linkPower=None, offsetCol=None)
Sets params for generalized linear regression.
"""
kwargs = self._input_kwargs
Expand Down Expand Up @@ -1486,6 +1489,20 @@ def getLinkPower(self):
"""
return self.getOrDefault(self.linkPower)

@since("2.3.0")
def setOffsetCol(self, value):
"""
Sets the value of :py:attr:`offsetCol`.
"""
return self._set(offsetCol=value)

@since("2.3.0")
def getOffsetCol(self):
"""
Gets the value of offsetCol or its default value.
"""
return self.getOrDefault(self.offsetCol)


class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable,
JavaMLReadable):
Expand Down
14 changes: 14 additions & 0 deletions python/pyspark/ml/tests.py
Expand Up @@ -1291,6 +1291,20 @@ def test_tweedie_distribution(self):
self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4))
self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4))

def test_offset(self):

df = self.spark.createDataFrame(
[(0.2, 1.0, 2.0, Vectors.dense(0.0, 5.0)),
(0.5, 2.1, 0.5, Vectors.dense(1.0, 2.0)),
(0.9, 0.4, 1.0, Vectors.dense(2.0, 1.0)),
(0.7, 0.7, 0.0, Vectors.dense(3.0, 3.0))], ["label", "weight", "offset", "features"])

glr = GeneralizedLinearRegression(family="poisson", weightCol="weight", offsetCol="offset")
model = glr.fit(df)
self.assertTrue(np.allclose(model.coefficients.toArray(), [0.664647, -0.3192581],
atol=1E-4))
self.assertTrue(np.isclose(model.intercept, -1.561613, atol=1E-4))


class FPGrowthTests(SparkSessionTestCase):
def setUp(self):
Expand Down

0 comments on commit 4852b7d

Please sign in to comment.