Skip to content

Commit

Permalink
Added support for LightGBM and XGBoost
Browse files Browse the repository at this point in the history
  • Loading branch information
nbarraille committed Apr 13, 2019
1 parent 6cec035 commit 5400dc8
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 7 deletions.
7 changes: 7 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@

### Improvements

## 0.1.5 (2019-04-13)

### Improvements

- Added support for LightGBM
- Added support for XGBoost

## 0.1.4 (2019-04-12)

### Improvements
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ At the moment, we support the following frameworks:
- Scikit Learn (Supervised learning models and pipeline)
- Keras
- PyTorch
- XGBoost
- LightGBM

Coming soon:

Expand Down
2 changes: 1 addition & 1 deletion blazee/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
# Or predict a batch
>>> preds = model.batch_predict(X)
"""
__version__ = '0.1.4'
__version__ = '0.1.5'

from .client import Client as Blazee
from .model import BlazeeModel
Expand Down
42 changes: 42 additions & 0 deletions blazee/lightgbm_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import io
import os

from blazee.utils import (SerializedModel, add_file_deps,
get_files_dependencies, get_requirements)


def is_lightgbm(model):
    """Return True when ``model`` is a LightGBM ``Booster`` instance.

    LightGBM is imported lazily so that it stays an optional dependency.
    If the import fails, the model is reported as not being a LightGBM
    model instead of raising.

    Args:
        model: Any object.

    Returns:
        bool: ``True`` if ``model`` is an instance of
        ``lightgbm.basic.Booster``, ``False`` otherwise.
    """
    try:
        from lightgbm.basic import Booster
    except Exception:
        # Covers a missing package as well as an installed package that
        # fails while being imported. Unlike a bare ``except:``, this lets
        # KeyboardInterrupt and SystemExit propagate.
        return False
    return isinstance(model, Booster)


def _get_model_metadata(model, include_files):
    """Build the metadata dict attached to a serialized LightGBM model.

    Args:
        model: The model being serialized. Not read by this function.
        include_files: Passed through unchanged under ``'include_files'``.

    Returns:
        dict: ``'lib_versions'`` holds the result of
        ``get_requirements(['lightgbm'])`` and ``'include_files'`` holds the
        ``include_files`` argument.
    """
    # NOTE(review): ``get_files_dependencies`` is imported by this module but
    # not applied here, whereas the scikit-learn metadata helper adds it to
    # the dependency list when ``include_files`` is set -- confirm that the
    # omission is intentional.
    lib_versions = get_requirements(['lightgbm'])
    return dict(lib_versions=lib_versions, include_files=include_files)


def serialize_lightgbm(model, include_files):
    """Serialize a LightGBM model into a ``SerializedModel``.

    The model is written to disk with ``model.save_model`` and read back as
    bytes, which are stored under the name ``model.txt``. ``include_files``
    is forwarded to ``add_file_deps`` and to the metadata helper.

    Args:
        model: Object exposing ``save_model(path)``.
        include_files: Extra files to bundle; passed to ``add_file_deps``
            and recorded in the metadata.

    Returns:
        SerializedModel: Built with type ``'lightgbm'``, the metadata dict
        and the list of ``(name, content)`` file tuples.

    Raises:
        Whatever ``model.save_model`` or reading the saved file raises; the
        temporary directory is removed in either case.
    """
    import tempfile

    # Save into a private temporary directory instead of a fixed ``tmp.txt``
    # in the current working directory: the fixed name could overwrite (and
    # then delete) a user's own file, collide between concurrent calls, and
    # fail when the working directory is not writable.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the original file name in case ``save_model`` is sensitive
        # to the name or extension it is given.
        tmp_file = os.path.join(tmp_dir, 'tmp.txt')
        model.save_model(tmp_file)
        with open(tmp_file, 'rb') as f:
            buffer = f.read()

    files = [('model.txt', buffer)]

    add_file_deps(files, include_files)

    meta = _get_model_metadata(model, include_files)

    return SerializedModel('lightgbm', meta, files)
6 changes: 6 additions & 0 deletions blazee/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@
from dateutil import parser

from blazee.keras_utils import is_keras, serialize_keras
from blazee.lightgbm_utils import is_lightgbm, serialize_lightgbm
from blazee.prediction import Prediction
from blazee.pytorch_utils import is_pytorch, serialize_pytorch
from blazee.sklearn_utils import is_sklearn, serialize_sklearn
from blazee.utils import generate_zip, pretty_size
from blazee.xgboost_utils import is_xgboost, serialize_xgboost


def _serialize_model(model, include_files=None):
Expand All @@ -39,6 +41,10 @@ def _serialize_model(model, include_files=None):
return serialize_keras(model, include_files)
elif is_pytorch(model):
return serialize_pytorch(model, include_files)
elif is_lightgbm(model):
return serialize_lightgbm(model, include_files)
elif is_xgboost(model):
return serialize_xgboost(model, include_files)
else:
raise TypeError(f'Model Type not supported: {type(model)}')

Expand Down
43 changes: 39 additions & 4 deletions blazee/sklearn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,49 @@ def is_sklearn(model):
return False


def _get_model_metadata(model, include_files):
deps = ['scikit-learn']
def _get_estimator_dependencies(estimator):
    """Return the extra package names needed by a single estimator.

    Each supported third-party wrapper is imported lazily so that none of
    them is a hard dependency. An import that fails simply means the
    estimator cannot come from that library, and the next one is tried.

    Args:
        estimator: Any object, typically a scikit-learn compatible
            estimator.

    Returns:
        list: The result of ``get_keras_deps()`` for a Keras scikit-learn
        wrapper, ``['xgboost']`` for an XGBoost scikit-learn estimator,
        ``['lightgbm']`` for a LightGBM scikit-learn estimator, and ``[]``
        when none of them match.
    """
    # Keras
    try:
        from keras.wrappers.scikit_learn import BaseWrapper
    except Exception:
        # Only the optional import is guarded; KeyboardInterrupt and
        # SystemExit are no longer swallowed as with a bare ``except:``.
        BaseWrapper = None
    if BaseWrapper is not None and isinstance(estimator, BaseWrapper):
        return get_keras_deps()

    # XGBoost: match the ``XGBModel`` base class so that regressors and
    # other wrappers are detected too, not only ``XGBClassifier``. This
    # mirrors the LightGBM check below, which uses the base ``LGBMModel``.
    try:
        from xgboost.sklearn import XGBModel
    except Exception:
        XGBModel = None
    if XGBModel is not None and isinstance(estimator, XGBModel):
        return ['xgboost']

    # LightGBM
    try:
        from lightgbm.sklearn import LGBMModel
    except Exception:
        LGBMModel = None
    if LGBMModel is not None and isinstance(estimator, LGBMModel):
        return ['lightgbm']

    return []


def _get_model_metadata(model, include_files):
deps = ['scikit-learn']
from sklearn.pipeline import Pipeline
from sklearn.base import BaseEstimator
from sklearn.model_selection._search import BaseSearchCV

# Import other dependencies if needed
if isinstance(model, Pipeline):
for _, estimator in model.steps:
deps += _get_estimator_dependencies(estimator)
elif isinstance(model, BaseSearchCV):
deps += _get_estimator_dependencies(model.estimator)
elif isinstance(model, BaseEstimator):
deps += _get_estimator_dependencies(model)
else:
raise ValueError(f"Model of type {type(model)} not supported")

if include_files:
deps += get_files_dependencies(include_files)

Expand Down
4 changes: 2 additions & 2 deletions blazee/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ def pretty_size(num_bytes):
if num_bytes < 1024:
return f'{num_bytes} B'
elif num_bytes < 1024 * 1024:
return f'{num_bytes / 1024:1f} KB'
return f'{num_bytes / 1024:.1f} KB'
else:
return f'{num_bytes / 1024 / 1024:1f} MB'
return f'{num_bytes / 1024 / 1024:.1f} MB'


def generate_zip(files):
Expand Down
42 changes: 42 additions & 0 deletions blazee/xgboost_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import io
import os

from blazee.utils import (SerializedModel, add_file_deps,
get_files_dependencies, get_requirements)


def is_xgboost(model):
    """Return True when ``model`` is an XGBoost ``Booster`` instance.

    XGBoost is imported lazily so that it stays an optional dependency.
    If the import fails, the model is reported as not being an XGBoost
    model instead of raising.

    Args:
        model: Any object.

    Returns:
        bool: ``True`` if ``model`` is an instance of
        ``xgboost.core.Booster``, ``False`` otherwise.
    """
    try:
        from xgboost.core import Booster
    except Exception:
        # Covers a missing package as well as an installed package that
        # fails while being imported. Unlike a bare ``except:``, this lets
        # KeyboardInterrupt and SystemExit propagate.
        return False
    return isinstance(model, Booster)


def _get_model_metadata(model, include_files):
    """Build the metadata dict attached to a serialized XGBoost model.

    Args:
        model: The model being serialized. Not read by this function.
        include_files: Passed through unchanged under ``'include_files'``.

    Returns:
        dict: ``'lib_versions'`` holds the result of
        ``get_requirements(['xgboost'])`` and ``'include_files'`` holds the
        ``include_files`` argument.
    """
    # NOTE(review): ``get_files_dependencies`` is imported by this module but
    # not applied here, whereas the scikit-learn metadata helper adds it to
    # the dependency list when ``include_files`` is set -- confirm that the
    # omission is intentional.
    lib_versions = get_requirements(['xgboost'])
    return dict(lib_versions=lib_versions, include_files=include_files)


def serialize_xgboost(model, include_files):
    """Serialize an XGBoost model into a ``SerializedModel``.

    The model is written to disk with ``model.save_model`` and read back as
    bytes, which are stored under the name ``model.txt``. ``include_files``
    is forwarded to ``add_file_deps`` and to the metadata helper.

    Args:
        model: Object exposing ``save_model(path)``.
        include_files: Extra files to bundle; passed to ``add_file_deps``
            and recorded in the metadata.

    Returns:
        SerializedModel: Built with type ``'xgboost'``, the metadata dict
        and the list of ``(name, content)`` file tuples.

    Raises:
        Whatever ``model.save_model`` or reading the saved file raises; the
        temporary directory is removed in either case.
    """
    import tempfile

    # Save into a private temporary directory instead of a fixed ``tmp.txt``
    # in the current working directory: the fixed name could overwrite (and
    # then delete) a user's own file, collide between concurrent calls, and
    # fail when the working directory is not writable.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the original file name in case ``save_model`` is sensitive
        # to the name or extension it is given.
        tmp_file = os.path.join(tmp_dir, 'tmp.txt')
        model.save_model(tmp_file)
        with open(tmp_file, 'rb') as f:
            buffer = f.read()

    files = [('model.txt', buffer)]

    add_file_deps(files, include_files)

    meta = _get_model_metadata(model, include_files)

    return SerializedModel('xgboost', meta, files)

0 comments on commit 5400dc8

Please sign in to comment.