Patch for Spark 3 Update
dciborow committed Feb 1, 2021
1 parent 13ce0c9 commit 3b18519
Showing 6 changed files with 127 additions and 74 deletions.
src/main/python/mmlspark/recommendation/RankingTrainValidationSplit.py
@@ -1,34 +1,56 @@
# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root for information.

import sys
from mmlspark.recommendation.RankingTrainValidationSplitModel import RankingTrainValidationSplitModel as tvmodel
from mmlspark.recommendation._RankingTrainValidationSplit import _RankingTrainValidationSplit
from pyspark import keyword_only
from pyspark.ml.param import Params
from pyspark.ml.tuning import _ValidatorParams
from pyspark.ml.util import *
from mmlspark.recommendation.ValidatorSetterParams import ValidatorSetterParams
from pyspark import keyword_only
from pyspark.ml import Estimator
from pyspark.ml.param import Params, Param, TypeConverters
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaParams
from pyspark.ml import Estimator


if sys.version >= '3':
if sys.version >= "3":
basestring = str


@inherit_doc
class RankingTrainValidationSplit(Estimator, _ValidatorParams):
trainRatio = Param(Params._dummy(), "trainRatio", "Param for ratio between train and\
validation data. Must be between 0 and 1.", typeConverter=TypeConverters.toFloat)
userCol = Param(Params._dummy(), "userCol",
"userCol: column name for user ids. Ids must be within the integer value range. (default: user)")
class RankingTrainValidationSplit(Estimator, ValidatorSetterParams):
trainRatio = Param(
Params._dummy(),
"trainRatio",
"Param for ratio between train and\
validation data. Must be between 0 and 1.",
typeConverter=TypeConverters.toFloat,
)
userCol = Param(
Params._dummy(),
"userCol",
"userCol: column name for user ids. Ids must be within the integer value range. (default: user)",
)
ratingCol = Param(Params._dummy(), "ratingCol", "ratingCol: column name for ratings (default: rating)")

itemCol = Param(Params._dummy(), "itemCol",
"itemCol: column name for item ids. Ids must be within the integer value range. (default: item)")
itemCol = Param(
Params._dummy(),
"itemCol",
"itemCol: column name for item ids. Ids must be within the integer value range. (default: item)",
)

def setEstimator(self, value):
"""
Sets the value of :py:attr:`estimator`.
"""
return self._set(estimator=value)

def setEvaluator(self, value):
"""
Sets the value of :py:attr:`evaluator`.
"""
return self._set(evaluator=value)

def setEstimatorParamMaps(self, value):
"""
Sets the value of :py:attr:`estimatorParamMaps`.
"""
return self._set(estimatorParamMaps=value)

def setTrainRatio(self, value):
"""
@@ -44,63 +66,43 @@ def getTrainRatio(self):

def setItemCol(self, value):
"""
Args:
itemCol (str): column name for item ids. Ids must be within the integer value range. (default: item)
Sets the value of :py:attr:`itemCol`.
"""
self._set(itemCol=value)
return self

def getItemCol(self):
"""
Returns:
str: column name for item ids. Ids must be within the integer value range. (default: item)
Gets the value of :py:attr:`itemCol`.
"""
return self.getOrDefault(self.itemCol)

def setRatingCol(self, value):
"""
Args:
ratingCol (str): column name for ratings (default: rating)
Sets the value of :py:attr:`ratingCol`.
"""
self._set(ratingCol=value)
return self

def getRatingCol(self):
"""
Returns:
str: column name for ratings (default: rating)
Gets the value of :py:attr:`ratingCol`.
"""
return self.getOrDefault(self.ratingCol)

def setUserCol(self, value):
"""
Args:
userCol (str): column name for user ids. Ids must be within the integer value range. (default: user)
Sets the value of :py:attr:`userCol`.
"""
self._set(userCol=value)
return self

def getUserCol(self):
"""
Returns:
str: column name for user ids. Ids must be within the integer value range. (default: user)
Gets the value of :py:attr:`userCol`.
"""
return self.getOrDefault(self.userCol)

@keyword_only
def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, seed=None):
"""
@@ -142,6 +144,7 @@ def copy(self, extra=None):

def _create_model(self, java_model):
from mmlspark.recommendation.RankingTrainValidationSplitModel import RankingTrainValidationSplitModel

model = RankingTrainValidationSplitModel._from_java(java_model)
return model

@@ -153,8 +156,9 @@ def _to_java(self):

estimator, epms, evaluator = super(RankingTrainValidationSplit, self)._to_java_impl()

_java_obj = JavaParams._new_java_obj("com.microsoft.ml.spark.recommendation.RankingTrainValidationSplit",
self.uid)
_java_obj = JavaParams._new_java_obj(
"com.microsoft.ml.spark.recommendation.RankingTrainValidationSplit", self.uid
)
_java_obj.setEstimatorParamMaps(epms)
_java_obj.setEvaluator(evaluator)
_java_obj.setEstimator(estimator)
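For context, a minimal usage sketch of the class this file defines. This is an illustration under assumptions, not code from the commit: the column names and the `ratings` DataFrame are hypothetical, and the RankingEvaluator arguments mirror the test file further down.

from pyspark.ml.recommendation import ALS
from pyspark.ml.tuning import ParamGridBuilder
from mmlspark.recommendation.RankingEvaluator import RankingEvaluator
from mmlspark.recommendation.RankingTrainValidationSplit import RankingTrainValidationSplit

als = ALS(userCol="user", itemCol="item", ratingCol="rating")
tvs = (RankingTrainValidationSplit()
       .setEstimator(als)
       .setEstimatorParamMaps(ParamGridBuilder().addGrid(als.regParam, [0.1, 1.0]).build())
       .setEvaluator(RankingEvaluator(k=3))
       .setUserCol("user")
       .setItemCol("item")
       .setRatingCol("rating")
       .setTrainRatio(0.8))
model = tvs.fit(ratings)  # `ratings` is an assumed user/item/rating DataFrame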
src/main/python/mmlspark/recommendation/RankingTrainValidationSplitModel.py
@@ -3,21 +3,19 @@

import sys

if sys.version >= '3':
if sys.version >= "3":
basestring = str

from pyspark.ml.common import inherit_doc
from pyspark.ml.tuning import _ValidatorParams
from pyspark.ml.util import *
from mmlspark.recommendation._RankingTrainValidationSplitModel import _RankingTrainValidationSplitModel
from pyspark.ml.wrapper import JavaParams
from pyspark.ml.util import *
from pyspark.ml.common import _py2java

from mmlspark.recommendation.ValidatorSetterParams import ValidatorSetterParams


# Load information from java_stage to the instance.
@inherit_doc
class RankingTrainValidationSplitModel(_RankingTrainValidationSplitModel, _ValidatorParams):

class RankingTrainValidationSplitModel(_RankingTrainValidationSplitModel, ValidatorSetterParams):
def __init__(self, bestModel=None, validationMetrics=[]):
super(RankingTrainValidationSplitModel, self).__init__()
#: best model from cross validation
@@ -48,23 +46,35 @@ def copy(self, extra=None):
def recommendForAllUsers(self, numItems):
return self.bestModel._call_java("recommendForAllUsers", numItems)

def recommendForAllItems(self, numItems):
return self.bestModel._call_java("recommendForAllItems", numItems)
def recommendForAllItems(self, numUsers):
return self.bestModel._call_java("recommendForAllItems", numUsers)

def setEstimator(self, value):
"""
Sets the value of :py:attr:`estimator`.
"""
return self._set(estimator=value)

def setEvaluator(self, value):
"""
Sets the value of :py:attr:`evaluator`.
"""
return self._set(evaluator=value)

def setEstimatorParamMaps(self, value):
"""
Sets the value of :py:attr:`estimatorParamMaps`.
"""
return self._set(estimatorParamMaps=value)

@classmethod
def _from_java(cls, java_stage):
"""
Given a Java TrainValidationSplitModel, create and return a Python wrapper of it.
Used for ML persistence.
"""

# Load information from java_stage to the instance.
bestModel = JavaParams._from_java(java_stage.getBestModel())
estimator, epms, evaluator = super(RankingTrainValidationSplitModel,
cls)._from_java_impl(java_stage)
# Create a new instance of this stage.
py_stage = cls(bestModel=bestModel).setEstimator(estimator)
py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
py_stage = cls(bestModel=bestModel)

py_stage._resetUid(java_stage.uid())
return py_stage
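
The rename above (numItems to numUsers in recommendForAllItems) matches the method's semantics: it returns the top users for each item, while recommendForAllUsers returns the top items for each user. A short sketch, assuming `model` is a fitted RankingTrainValidationSplitModel:

top_items_per_user = model.recommendForAllUsers(3)  # 3 highest-scoring items per user
top_users_per_item = model.recommendForAllItems(3)  # 3 highest-scoring users per item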
4 changes: 1 addition & 3 deletions src/main/python/mmlspark/recommendation/SAR.py
@@ -4,11 +4,9 @@

import sys

if sys.version >= '3':
if sys.version >= "3":
basestring = str

from pyspark.ml.param.shared import *
from pyspark.ml.common import inherit_doc
from mmlspark.core.schema.Utils import *
from mmlspark.recommendation._SAR import _SAR as sar
from mmlspark.recommendation.SARModel import SARModel as sarm
4 changes: 1 addition & 3 deletions src/main/python/mmlspark/recommendation/SARModel.py
@@ -4,11 +4,9 @@

import sys

if sys.version >= '3':
if sys.version >= "3":
basestring = str

from pyspark.ml.param.shared import *
from pyspark.ml.common import inherit_doc
from mmlspark.core.schema.Utils import *
from mmlspark.recommendation._SARModel import _SARModel as sarModel

43 changes: 43 additions & 0 deletions src/main/python/mmlspark/recommendation/ValidatorSetterParams.py
@@ -0,0 +1,43 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.ml.param.shared import HasSeed
from pyspark.ml.tuning import _ValidatorParams


class ValidatorSetterParams(_ValidatorParams, HasSeed):
"""
Common params for TrainValidationSplit and CrossValidator.
"""

def setEstimator(self, value):
"""
Sets the value of :py:attr:`estimator`.
"""
return self._set(estimator=value)

def setEstimatorParamMaps(self, value):
"""
Sets the value of :py:attr:`estimatorParamMaps`.
"""
return self._set(estimatorParamMaps=value)

def setEvaluator(self, value):
"""
Sets the value of :py:attr:`evaluator`.
"""
return self._set(evaluator=value)
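
ValidatorSetterParams collects the setters that pyspark's _ValidatorParams mixin does not provide (in Spark 3 they are defined on the concrete CrossValidator and TrainValidationSplit classes), so RankingTrainValidationSplit and its model can share one implementation. A minimal sketch of adopting the mixin in a new validator; the class below is hypothetical, not part of this commit:

from pyspark.ml import Estimator
from mmlspark.recommendation.ValidatorSetterParams import ValidatorSetterParams

class MyRankingValidator(Estimator, ValidatorSetterParams):
    # estimator, estimatorParamMaps, evaluator and seed params, plus their
    # setter methods, are all inherited from ValidatorSetterParams
    def _fit(self, dataset):
        raise NotImplementedError  # placeholder for real fitting logic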
12 changes: 6 additions & 6 deletions src/test/python/mmlsparktest/recommendation/test_ranking.py
@@ -3,17 +3,17 @@

# Prepare training and test data.
import unittest

from mmlspark.recommendation.RankingAdapter import RankingAdapter
from mmlspark.recommendation.RankingEvaluator import RankingEvaluator
from mmlspark.recommendation.RankingTrainValidationSplit import RankingTrainValidationSplit
from mmlspark.recommendation.RecommendationIndexer import RecommendationIndexer
from mmlspark.recommendation.SAR import SAR
from mmlsparktest.spark import *
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer
from pyspark.ml.recommendation import ALS
from pyspark.ml.tuning import *
from pyspark.sql.types import *
from mmlsparktest.spark import *


class RankingSpec(unittest.TestCase):
@@ -56,7 +56,7 @@ def getRatings():
["originalCustomerID", "newCategoryID", "rating", "notTime"])
return ratings

def ignore_adapter_evaluator(self):
def test_adapter_evaluator(self):
ratings = self.getRatings()

user_id = "originalCustomerID"
@@ -81,7 +81,7 @@ def ignore_adapter_evaluator(self):
for metric in metrics:
print(metric + ": " + str(RankingEvaluator(k=3, metricName=metric).evaluate(output)))

def ignore_adapter_evaluator_sar(self):
def test_adapter_evaluator_sar(self):
ratings = self.getRatings()

user_id = "originalCustomerID"
@@ -107,7 +107,7 @@ def ignore_adapter_evaluator_sar(self):
for metric in metrics:
print(metric + ": " + str(RankingEvaluator(k=3, metricName=metric).evaluate(output)))

def ignore_all_tiny(self):
def test_all_tiny(self):
ratings = RankingSpec.getRatings()

customerIndex = StringIndexer() \
@@ -128,7 +128,7 @@ def ignore_all_tiny(self):
.setItemCol(ratingsIndex.getOutputCol())

alsModel = als.fit(transformedDf)
usersRecs = alsModel._call_java("recommendForAllUsers", 3)
usersRecs = alsModel.recommendForAllUsers(3)
print(usersRecs.take(1))

paramGrid = ParamGridBuilder() \
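The final test change swaps a private-bridge call for the public pyspark API; both lines below return the same recommendations DataFrame (assuming a fitted `alsModel`, as in this test):

usersRecs = alsModel._call_java("recommendForAllUsers", 3)  # old: reach into the JVM model directly
usersRecs = alsModel.recommendForAllUsers(3)                # new: public ALSModel method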
