Skip to content

Commit

Permalink
fix: refactor python wrappers to use common class
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft committed Apr 30, 2021
1 parent 99b580f commit 13a953b
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 206 deletions.
73 changes: 2 additions & 71 deletions src/main/python/mmlspark/lightgbm/LightGBMClassificationModel.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,16 @@
# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

import sys
from pyspark import SQLContext
from pyspark import SparkContext

if sys.version >= '3':
basestring = str

from mmlspark.lightgbm._LightGBMClassificationModel import _LightGBMClassificationModel
from mmlspark.lightgbm.mixin import LightGBMModelMixin
from pyspark import SparkContext
from pyspark.ml.common import inherit_doc
from pyspark.ml.linalg import SparseVector, DenseVector
from pyspark.ml.wrapper import JavaParams
from mmlspark.core.serialize.java_params_patch import *


@inherit_doc
class LightGBMClassificationModel(_LightGBMClassificationModel):
def saveNativeModel(self, filename, overwrite=True):
    """Persist the underlying booster in its string serialization.

    :param filename: Local or WASB remote destination path.
    :param overwrite: When True (the default), replace any existing file.
    """
    # Delegate straight to the JVM-side model object.
    self._java_obj.saveNativeModel(filename, overwrite)

class LightGBMClassificationModel(LightGBMModelMixin, _LightGBMClassificationModel):
@staticmethod
def loadNativeModelFromFile(filename):
"""
Expand All @@ -44,62 +31,6 @@ def loadNativeModelFromString(model):
java_model = loader.loadNativeModelFromString(model)
return JavaParams._from_java(java_model)

def getFeatureImportances(self, importance_type="split"):
    """Return per-feature importances as a Python list.

    :param importance_type: Importance measure, either "split" or "gain".
    :return: A list with one importance value per feature.
    """
    java_importances = self._java_obj.getFeatureImportances(importance_type)
    return list(java_importances)

def getFeatureShaps(self, vector):
    """Compute the local SHAP feature importances for one instance.

    :param vector: A pyspark.ml.linalg DenseVector or SparseVector.
    :return: List of SHAP values.
    :raises TypeError: If ``vector`` is neither dense nor sparse.
    """
    if isinstance(vector, SparseVector):
        indices = [int(i) for i in vector.indices]
        values = [float(v) for v in vector.values]
        return list(self._java_obj.getSparseFeatureShaps(vector.size, indices, values))
    if isinstance(vector, DenseVector):
        return list(self._java_obj.getDenseFeatureShaps([float(v) for v in vector]))
    raise TypeError("Vector argument to getFeatureShaps must be a pyspark.linalg sparse or dense vector type")

def getBoosterBestIteration(self):
    """Return the booster's best iteration.

    :return: The best iteration index, if early stopping was triggered.
    """
    return self._java_obj.getBoosterBestIteration()

def getBoosterNumTotalIterations(self):
    """Return how many iterations were trained in total.

    :return: The total number of trained iterations.
    """
    return self._java_obj.getBoosterNumTotalIterations()

def getBoosterNumTotalModel(self):
    """Return the total count of trained models.

    This can exceed the iteration count: in multiclass training one
    model is produced per class for each iteration.

    :return: The total number of models.
    """
    return self._java_obj.getBoosterNumTotalModel()

def getBoosterNumFeatures(self):
    """Return the booster's feature count.

    :return: The number of features.
    """
    return self._java_obj.getBoosterNumFeatures()

def getBoosterNumClasses(self):
"""Get the number of classes from the booster.
Expand Down
70 changes: 2 additions & 68 deletions src/main/python/mmlspark/lightgbm/LightGBMRankerModel.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,16 @@
# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

import sys
from pyspark import SQLContext
from pyspark import SparkContext

if sys.version >= '3':
basestring = str

from mmlspark.lightgbm._LightGBMRankerModel import _LightGBMRankerModel
from mmlspark.lightgbm.mixin import LightGBMModelMixin
from pyspark import SparkContext
from pyspark.ml.common import inherit_doc
from pyspark.ml.linalg import SparseVector, DenseVector
from pyspark.ml.wrapper import JavaParams
from mmlspark.core.serialize.java_params_patch import *


@inherit_doc
class LightGBMRankerModel(_LightGBMRankerModel):
def saveNativeModel(self, filename, overwrite=True):
    """Write the booster's string serialization to a file.

    :param filename: Local or WASB remote destination path.
    :param overwrite: Replace an existing file when True (the default).
    """
    # The JVM-side model performs the actual save.
    self._java_obj.saveNativeModel(filename, overwrite)

class LightGBMRankerModel(LightGBMModelMixin, _LightGBMRankerModel):
@staticmethod
def loadNativeModelFromFile(filename):
"""
Expand All @@ -44,59 +31,6 @@ def loadNativeModelFromString(model):
java_model = loader.loadNativeModelFromString(model)
return JavaParams._from_java(java_model)

def getFeatureImportances(self, importance_type="split"):
    """Return the feature importances as a list.

    :param importance_type: Either "split" or "gain".
    :return: One importance value per feature.
    """
    raw = self._java_obj.getFeatureImportances(importance_type)
    return list(raw)

def getFeatureShaps(self, vector):
    """Compute the local SHAP feature importances for one instance.

    :param vector: A pyspark.ml.linalg DenseVector or SparseVector.
    :return: List of SHAP values.
    :raises TypeError: If ``vector`` is neither dense nor sparse.
    """
    if isinstance(vector, DenseVector):
        dense_values = [float(v) for v in vector]
        return list(self._java_obj.getDenseFeatureShaps(dense_values))
    elif isinstance(vector, SparseVector):
        # BUGFIX: SparseVector.size is a plain int (the vector's dimension).
        # The previous code iterated it ([float(v) for v in vector.size]),
        # which raised TypeError on every sparse input. Pass the size through
        # directly, matching the classifier/regressor implementations.
        sparse_indices = [int(i) for i in vector.indices]
        sparse_values = [float(v) for v in vector.values]
        return list(self._java_obj.getSparseFeatureShaps(vector.size, sparse_indices, sparse_values))
    else:
        raise TypeError("Vector argument to getFeatureShaps must be a pyspark.linalg sparse or dense vector type")

def getBoosterBestIteration(self):
    """Return the best booster iteration.

    :return: The best iteration, when early stopping was triggered.
    """
    best = self._java_obj.getBoosterBestIteration()
    return best

def getBoosterNumTotalIterations(self):
    """Return the total number of iterations the booster trained.

    :return: Total trained iterations.
    """
    total = self._java_obj.getBoosterNumTotalIterations()
    return total

def getBoosterNumTotalModel(self):
    """Return the total count of trained models.

    :return: Total number of models.
    """
    return self._java_obj.getBoosterNumTotalModel()

def getBoosterNumFeatures(self):
    """Return how many features the booster was trained on.

    :return: The feature count.
    """
    count = self._java_obj.getBoosterNumFeatures()
    return count

def getBoosterNumClasses(self):
"""Get the number of classes from the booster.
Expand Down
69 changes: 2 additions & 67 deletions src/main/python/mmlspark/lightgbm/LightGBMRegressionModel.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,15 @@
# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

import sys
from pyspark import SQLContext
from pyspark import SparkContext

if sys.version >= '3':
basestring = str

from mmlspark.lightgbm._LightGBMRegressionModel import _LightGBMRegressionModel
from mmlspark.lightgbm.mixin import LightGBMModelMixin
from pyspark import SparkContext
from pyspark.ml.common import inherit_doc
from pyspark.ml.linalg import SparseVector, DenseVector
from pyspark.ml.wrapper import JavaParams
from mmlspark.core.serialize.java_params_patch import *

@inherit_doc
class LightGBMRegressionModel(_LightGBMRegressionModel):
def saveNativeModel(self, filename, overwrite=True):
    """Save the booster in string form to a local or WASB remote path.

    :param filename: Destination path.
    :param overwrite: Overwrite an existing file when True (the default).
    """
    # All persistence logic lives in the JVM-side model.
    self._java_obj.saveNativeModel(filename, overwrite)

class LightGBMRegressionModel(LightGBMModelMixin, _LightGBMRegressionModel):
@staticmethod
def loadNativeModelFromFile(filename):
"""
Expand All @@ -42,55 +29,3 @@ def loadNativeModelFromString(model):
loader = ctx._jvm.com.microsoft.ml.spark.lightgbm.LightGBMRegressionModel
java_model = loader.loadNativeModelFromString(model)
return JavaParams._from_java(java_model)

def getFeatureImportances(self, importance_type="split"):
    """Return the feature importances as a list.

    :param importance_type: Either "split" or "gain".
    :return: One importance value per feature.
    """
    importances = self._java_obj.getFeatureImportances(importance_type)
    return list(importances)

def getFeatureShaps(self, vector):
    """Compute the local SHAP feature importances for one instance.

    :param vector: A pyspark.ml.linalg DenseVector or SparseVector.
    :return: List of SHAP values.
    :raises TypeError: If ``vector`` is neither dense nor sparse.
    """
    if isinstance(vector, SparseVector):
        indices = [int(i) for i in vector.indices]
        values = [float(v) for v in vector.values]
        return list(self._java_obj.getSparseFeatureShaps(vector.size, indices, values))
    if isinstance(vector, DenseVector):
        return list(self._java_obj.getDenseFeatureShaps([float(v) for v in vector]))
    raise TypeError("Vector argument to getFeatureShaps must be a pyspark.linalg sparse or dense vector type")

def getBoosterBestIteration(self):
    """Return the booster's best iteration.

    :return: The best iteration index, if early stopping was triggered.
    """
    return self._java_obj.getBoosterBestIteration()

def getBoosterNumTotalIterations(self):
    """Return the total number of iterations trained.

    :return: Total trained iterations.
    """
    return self._java_obj.getBoosterNumTotalIterations()

def getBoosterNumTotalModel(self):
    """Return the total number of models trained.

    :return: Total model count.
    """
    models = self._java_obj.getBoosterNumTotalModel()
    return models

def getBoosterNumFeatures(self):
    """Return the number of features known to the booster.

    :return: The feature count.
    """
    return self._java_obj.getBoosterNumFeatures()
66 changes: 66 additions & 0 deletions src/main/python/mmlspark/lightgbm/mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

from pyspark.ml.linalg import SparseVector, DenseVector
from pyspark.ml.common import inherit_doc
from mmlspark.core.serialize.java_params_patch import *

@inherit_doc
class LightGBMModelMixin:
    """Shared wrapper methods for the LightGBM Spark model classes.

    Mixed into the classification, ranker and regression model wrappers to
    expose the common booster operations of the JVM-side model object,
    which each host class provides as ``self._java_obj``.
    """

    def saveNativeModel(self, filename, overwrite=True):
        """
        Save the booster as string format to a local or WASB remote location.

        :param filename: Destination path.
        :param overwrite: Replace an existing file when True (the default).
        """
        self._java_obj.saveNativeModel(filename, overwrite)

    def getFeatureImportances(self, importance_type="split"):
        """
        Get the feature importances as a list. The importance_type can be "split" or "gain".

        :param importance_type: Importance measure, "split" or "gain".
        :return: One importance value per feature.
        """
        return list(self._java_obj.getFeatureImportances(importance_type))

    def getFeatureShaps(self, vector):
        """
        Get the local shap feature importances.

        :param vector: A pyspark.ml.linalg DenseVector or SparseVector.
        :return: List of SHAP values.
        :raises TypeError: If ``vector`` is neither dense nor sparse.
        """
        if isinstance(vector, DenseVector):
            dense_values = [float(v) for v in vector]
            return list(self._java_obj.getDenseFeatureShaps(dense_values))
        elif isinstance(vector, SparseVector):
            sparse_indices = [int(i) for i in vector.indices]
            sparse_values = [float(v) for v in vector.values]
            return list(self._java_obj.getSparseFeatureShaps(vector.size, sparse_indices, sparse_values))
        else:
            # Fixed message: the vector types live in pyspark.ml.linalg,
            # not "pyspark.linalg" as the old message claimed.
            raise TypeError("Vector argument to getFeatureShaps must be a pyspark.ml.linalg sparse or dense vector type")

    def getBoosterBestIteration(self):
        """Get the best iteration from the booster.

        Returns:
            The best iteration, if early stopping was triggered.
        """
        return self._java_obj.getBoosterBestIteration()

    def getBoosterNumTotalIterations(self):
        """Get the total number of iterations trained.

        Returns:
            The total number of iterations trained.
        """
        return self._java_obj.getBoosterNumTotalIterations()

    def getBoosterNumTotalModel(self):
        """Get the total number of models trained.

        Note this may be larger than the number of iterations, since in
        multiclass a model is trained per class for each iteration.

        Returns:
            The total number of models.
        """
        return self._java_obj.getBoosterNumTotalModel()

    def getBoosterNumFeatures(self):
        """Get the number of features from the booster.

        Returns:
            The number of features.
        """
        return self._java_obj.getBoosterNumFeatures()

0 comments on commit 13a953b

Please sign in to comment.