Skip to content

Commit

Permalink
[SPARK-4127] Python bindings for StreamingLinearRegressionWithSGD
Browse files Browse the repository at this point in the history
  • Loading branch information
MechCoder committed Jun 25, 2015
1 parent 085a721 commit d42bdae
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions python/pyspark/mllib/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from numpy import array

from pyspark import RDD
from pyspark.streaming.dstream import DStream
from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc
from pyspark.mllib.linalg import SparseVector, Vectors, _convert_to_vector
from pyspark.mllib.util import Saveable, Loader
Expand Down Expand Up @@ -570,6 +571,47 @@ def train(cls, data, isotonic=True):
return IsotonicRegressionModel(boundaries.toArray(), predictions.toArray(), isotonic)


@inherit_doc
class StreamingLinearRegressionWithSGD(LinearRegressionModel):
    """
    Train a linear regression model on a stream of data, using SGD on
    each incoming batch. After training on a batch, the weights obtained
    are used as the initial weights for the next batch.

    :param stepSize: Step size for each iteration of gradient descent
                     (default: 0.1).
    :param numIterations: Number of SGD iterations run per batch
                          (default: 50).
    :param miniBatchFraction: Fraction of each batch's data used per
                              iteration (default: 1.0).
    """

    def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0):
        self.stepSize = stepSize
        self.numIterations = numIterations
        self.miniBatchFraction = miniBatchFraction
        # No model exists until setInitialWeights is called.
        self.latestModel = None

    def _validate_dstream(self, dstream):
        """
        Raise if *dstream* is not a DStream or if the model has not been
        initialized via setInitialWeights.
        """
        if not isinstance(dstream, DStream):
            raise TypeError(
                "dstream should be a DStream object, got %s" % type(dstream))
        if not self.latestModel:
            raise ValueError(
                "Model must be initialized using setInitialWeights")

    def setInitialWeights(self, initialWeights):
        """
        Set the initial weights. Must be called before trainOn or any of
        the predict methods.
        """
        initialWeights = _convert_to_vector(initialWeights)
        self.latestModel = LinearRegressionModel(initialWeights, 0)

    def trainOn(self, dstream):
        """Train the model on each batch of LabeledPoints in *dstream*."""
        self._validate_dstream(dstream)

        def update(rdd):
            # Skip empty batches: `if rdd:` was always true because RDD
            # does not define truthiness; use isEmpty() instead.
            if not rdd.isEmpty():
                self.latestModel = LinearRegressionWithSGD.train(
                    rdd, self.numIterations, self.stepSize,
                    self.miniBatchFraction, self.latestModel.weights)

        dstream.foreachRDD(update)

    def predictOn(self, dstream):
        """
        Predict on each batch of feature vectors in *dstream*; returns a
        DStream of predictions.
        """
        self._validate_dstream(dstream)
        # Go through a lambda so each batch uses the model current at
        # prediction time; binding self.latestModel.predict directly would
        # freeze the model as of this call and ignore later training.
        return dstream.map(lambda x: self.latestModel.predict(x))

    def predictOnValues(self, dstream):
        """
        Predict on the values of each (key, feature-vector) pair in
        *dstream*; returns a DStream of (key, prediction) pairs.
        """
        self._validate_dstream(dstream)
        return dstream.mapValues(lambda x: self.latestModel.predict(x))


def _test():
import doctest
from pyspark import SparkContext
Expand Down

0 comments on commit d42bdae

Please sign in to comment.