@inherit_doc
class StreamingLinearRegressionWithSGD(LinearRegressionModel):
    """
    Fit and apply a linear regression model on the batches of a DStream.

    The current model is kept in ``latestModel``. It starts as None, is
    created by :meth:`setInitialWeights`, and is replaced by
    :meth:`trainOn` each time a non-empty batch arrives, using the
    previous model's weights as the starting point for
    ``LinearRegressionWithSGD.train``.

    :param stepSize: Step size passed to ``LinearRegressionWithSGD.train``
                     for every batch.
    :param numIterations: Number of iterations passed to
                          ``LinearRegressionWithSGD.train`` for every batch.
    :param miniBatchFraction: Mini-batch fraction passed to
                              ``LinearRegressionWithSGD.train`` for every
                              batch.
    """

    # NOTE(review): this class derives from LinearRegressionModel but never
    # calls its constructor, so inherited model methods presumably cannot be
    # used directly on an instance -- go through ``latestModel`` instead.
    # Confirm whether the base class is intentional.

    def __init__(self, stepSize, numIterations, miniBatchFraction):
        self.stepSize = stepSize
        self.numIterations = numIterations
        self.miniBatchFraction = miniBatchFraction
        # No model exists until setInitialWeights() is called.
        self.latestModel = None

    def _validate_dstream(self, dstream):
        """
        Check that `dstream` is a DStream and that a model is available.

        :raises TypeError: if `dstream` is not a DStream instance.
        :raises ValueError: if setInitialWeights has not been called yet.
        """
        if not isinstance(dstream, DStream):
            raise TypeError(
                "dstream should be a DStream object, got %s" % type(dstream))
        if self.latestModel is None:
            raise ValueError(
                "Model must be initialized using setInitialWeights")

    def setInitialWeights(self, initialWeights):
        """
        Create the initial model from `initialWeights` with intercept 0.

        This must be called before trainOn, predictOn or predictOnValues.

        :param initialWeights: Weights for the initial model; converted
                               with ``_convert_to_vector``.
        """
        initialWeights = _convert_to_vector(initialWeights)
        self.latestModel = LinearRegressionModel(initialWeights, 0)

    def trainOn(self, dstream):
        """
        Register a per-batch update that retrains the model on `dstream`.

        :param dstream: DStream whose RDDs are passed as training data to
                        ``LinearRegressionWithSGD.train``.
        """
        self._validate_dstream(dstream)

        def update(rdd):
            # An RDD object is always truthy, so emptiness has to be
            # checked explicitly; empty batches leave the model untouched.
            if not rdd.isEmpty():
                self.latestModel = LinearRegressionWithSGD.train(
                    rdd, self.numIterations, self.stepSize,
                    self.miniBatchFraction, self.latestModel.weights)

        dstream.foreachRDD(update)

    def predictOn(self, dstream):
        """
        Return a DStream of predictions for each element of `dstream`.

        :param dstream: DStream whose elements are passed to the current
                        model's ``predict``.
        """
        self._validate_dstream(dstream)
        # Look the model up inside the function: binding
        # ``self.latestModel.predict`` here would pin predictions to the
        # model object that exists right now, even after trainOn rebinds
        # ``self.latestModel``.
        return dstream.map(lambda x: self.latestModel.predict(x))

    def predictOnValues(self, dstream):
        """
        Return a DStream in which each value is replaced by its prediction.

        :param dstream: DStream processed with ``mapValues``; each value is
                        passed to the current model's ``predict``.
        """
        self._validate_dstream(dstream)
        # Same late lookup as in predictOn, so a retrained model is not
        # shadowed by the one present when this method was called.
        return dstream.mapValues(lambda x: self.latestModel.predict(x))