Commit
more work on sklearn kit
jfischer committed Jan 13, 2020
1 parent bc57902 commit ddcbfb5
Showing 4 changed files with 81 additions and 87 deletions.
98 changes: 12 additions & 86 deletions dataworkspaces/kits/scikit_learn.py
@@ -333,6 +333,7 @@ def __init__(self, predictor, metrics:Union[str,type],
self.workspace_dir = workspace_dir
self.metrics = metrics
self.verbose = verbose
self.score_has_been_run = False
self._init_dws_state()

def _init_dws_state(self):
@@ -342,21 +343,15 @@ def _init_dws_state(self):
self.results_resource)

def _save_model(self):
if self.model_save_file.endswith('.joblib') or \
self.model_save_file.endswith('.pkl'):
model_save_file = self.model_save_file
else:
if not self.model_save_file.endswith('.joblib'):
model_save_file = self.model_save_file + '.joblib'
else:
model_save_file = self.model_save_file
tempname = None
try:
if model_save_file.endswith('.joblib'):
with NamedTemporaryFile(delete=False, suffix='.joblib') as f:
tempname = f.name
joblib.dump(self, tempname)
else:
with NamedTemporaryFile(delete=False, suffix='pkl') as f:
tempname = f.name
pickle.dump(self, f)
with NamedTemporaryFile(delete=False, suffix='.joblib') as f:
tempname = f.name
joblib.dump(self, tempname)
resource = self._dws_state.workspace.get_resource(self._dws_state.results_ref.name)
if self._dws_state.results_ref.subpath is not None:
target_name = join(self._dws_state.results_ref.subpath,
@@ -402,6 +397,10 @@ def fit(self, X, y, *args, **kwargs):
return result

def score(self, X, y, sample_weight=None):
if self.score_has_been_run:
# This might be from a saved model, so we reset the
# execution time, etc.
self._dws_state.reset_lineage()
for (param, value) in self.predictor.get_params(deep=True).items():
self._dws_state.lineage.add_param(param, value)
api_resource = self._dws_state.find_input_resources_and_return_if_api(X, y)
@@ -419,87 +418,14 @@ def score(self, X, y, sample_weight=None):
else:
metrics_inst = self.metrics(y, predictions, sample_weight=sample_weight)
self._dws_state.write_metrics_and_complete(metrics_inst.to_dict())
self.score_has_been_run = True
return metrics_inst.score()

def predict(self, X):
return self.predictor.predict(X)




def add_lineage_to_predictor_instance(predictor,
metrics_class:type,
input_resource:Optional[Union[str, ResourceRef]]=None,
results_resource:Optional[Union[str, ResourceRef]]=None,
workspace_dir:Optional[str]=None,
verbose:bool=False):
"""
This function wraps a predictor instance with a subclass that overrides
key methods to make calls to the data lineage api.
"""
if hasattr(predictor, '_dws_model_wrap') and predictor._dws_model_wrap is True: # type: ignore
print("dws>> %s is already wrapped" % repr(predictor))
return predictor # already wrapped
assert issubclass(metrics_class, Metrics),\
"%s is not a subclass of Metrics" % metrics_class.__name__
workspace = find_and_load_workspace(batch=True, verbose=verbose,
uri_or_local_path=workspace_dir)

class WrappedPredictor: # type: ignore
_dws_model_wrap = True
def __init__(self):
self.predictor = predictor
self._dws_state = _DwsModelState(workspace, input_resource,
results_resource)

@classmethod
def _get_param_names(cls):
"""Get parameter names for the estimator"""
return predictor.__class__.get_param_names()
def get_params(self, deep=True):
"""Get parameters for this estimator."""
return self.predictor.get_params(deep=deep)
def set_params(self, **params):
return self.predictor.set_params(**params)
def __repr__(self):
class_name = self.__class__.__name__
return '%s(%s)' % (class_name, _pprint(self.get_params(deep=False),
offset=len(class_name),),)
def __getstate__(self):
return self.predictor.__getstate__()
def __setstate__(self, state):
return self.predictor.__setstate__(state)
def fit(self, X, y, *args, **kwargs):
api_resource = self._dws_state.find_input_resources_and_return_if_api(X, y)
if api_resource is not None:
api_resource.init_hash_state()
hash_state = api_resource.get_hash_state()
_add_to_hash(X, hash_state)
_add_to_hash(y, hash_state)
api_resource.save_current_hash() # in case we evaluate in a separate process
return self.predictor.fit(X, y, *args, **kwargs)
def score(self, X, y, sample_weight=None):
for (param, value) in self.predictor.get_params(deep=True).items():
self._dws_state.lineage.add_param(param, value)
api_resource = self._dws_state.find_input_resources_and_return_if_api(X, y)
if api_resource is not None:
api_resource.dup_hash_state()
hash_state = api_resource.get_hash_state()
_add_to_hash(X, hash_state)
if y is not None:
_add_to_hash(y, hash_state)
api_resource.save_current_hash()
api_resource.pop_hash_state()
predictions = self.predictor.predict(X)
metrics = metrics_class(y, predictions, sample_weight=sample_weight)
self._dws_state.write_metrics_and_complete(metrics.to_dict())
return metrics.score()
def predict(self, X):
return self.predictor.predict(X)

WrappedPredictor.__name__ = 'Wrapped'+predictor.__class__.__name__
return WrappedPredictor()




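Taken together, the changes above let a LineagePredictor be saved by _save_model (always as a joblib pickle now), reloaded later, and scored again: the new score_has_been_run flag makes score() call reset_lineage() (added to wrapper_utils.py below), so a rerun records fresh timing and in-progress state. The following is an illustrative sketch of that flow, not code from this commit; it is modeled on the new test_sklearn_kit.py at the bottom of the commit, and it assumes the resource names, metrics string, and save location used there, plus an already-initialized workspace as the working directory.

# Illustrative sketch only, not part of the commit. Assumes it runs from inside a
# workspace that already has the 'sklearn-digits-dataset' (source-data) and
# 'results' resources, as set up in the test below.
import joblib  # standalone joblib package; the test uses sklearn.externals.joblib
from os.path import join
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
import dataworkspaces.kits.scikit_learn as skkit

dataset = skkit.load_dataset_from_resource('sklearn-digits-dataset')
X_train, X_test, y_train, y_test = train_test_split(
    dataset.data, dataset.target, test_size=0.5, shuffle=False)

clf = skkit.LineagePredictor(SVC(gamma=0.001),
                             'multiclass_classification',
                             input_resource='sklearn-digits-dataset',
                             model_save_file='digits.joblib')
clf.fit(X_train, y_train)
clf.score(X_test, y_test)  # after fit+score, metrics are written and the wrapper
                           # is saved under the results resource (per the test below)

# Later, possibly in a separate process: reload the saved wrapper and score again.
# score_has_been_run was pickled as True, so score() calls reset_lineage() and the
# rerun gets fresh start/execution times instead of the stale ones.
clf2 = joblib.load(join('results', 'digits.joblib'))  # save path as checked in the test
clf2.score(X_test, y_test)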
8 changes: 8 additions & 0 deletions dataworkspaces/kits/wrapper_utils.py
@@ -185,3 +185,11 @@ def write_metrics_and_complete(self, metrics):
print("dws>> Metrics: %s" % repr(metrics))
self.lineage.write_results(metrics)
self.lineage.complete()

def reset_lineage(self):
"""If you are rerunning a step, call this to reset the start and execution
times as well as the in_progress marker in the lineage.
"""
self.lineage.step.execution_time_seconds=None
self.lineage.step.start_time=datetime.datetime.now()
self.lineage.in_progress = True
2 changes: 1 addition & 1 deletion tests/Makefile
@@ -5,7 +5,7 @@ DATAWORKSPACES:=$(shell cd ../dataworkspaces; pwd)
help:
@echo targets are: test clean mypy pyflakes check help install-rclone-deb

UNIT_TESTS=test_git_utils test_file_utils test_move_results test_snapshots test_push_pull test_local_files_resource test_hashtree test_lineage_utils test_git_fat_integration test_lineage test_jupyter_kit test_api test_wrapper_utils test_tensorflow test_scratch_dir
UNIT_TESTS=test_git_utils test_file_utils test_move_results test_snapshots test_push_pull test_local_files_resource test_hashtree test_lineage_utils test_git_fat_integration test_lineage test_jupyter_kit test_sklearn_kit test_api test_wrapper_utils test_tensorflow test_scratch_dir

MYPY_KITS=scikit_learn.py jupyter.py tensorflow.py wrapper_utils.py

60 changes: 60 additions & 0 deletions tests/test_sklearn_kit.py
@@ -0,0 +1,60 @@
import unittest
import sys
from os.path import exists, join

from utils_for_tests import SimpleCase, WS_DIR

try:
import sklearn
SKLEARN_INSTALLED=True
except ImportError:
SKLEARN_INSTALLED=False

class TestSklearnKit(SimpleCase):
def _add_digits_dataset(self):
self._run_git(['clone',
'https://github.com/jfischer/sklearn-digits-dataset.git'])
self._run_dws(['add','git','--role=source-data','--read-only',
'./sklearn-digits-dataset'])

def wrapper_tc(self, model_save_file):
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.externals import joblib
import dataworkspaces.kits.scikit_learn as skkit
self._setup_initial_repo(git_resources='code,results')
self._add_digits_dataset()
dataset = skkit.load_dataset_from_resource('sklearn-digits-dataset',
workspace_dir=WS_DIR)
X_train, X_test, y_train, y_test = train_test_split(
dataset.data, dataset.target, test_size=0.5, shuffle=False)
classifier = skkit.LineagePredictor(SVC(gamma=0.001),
'multiclass_classification',
input_resource='sklearn-digits-dataset',
model_save_file=model_save_file,
workspace_dir=WS_DIR,
verbose=False)
classifier.fit(X_train, y_train)
score = classifier.score(X_test, y_test)
self.assertAlmostEqual(score, 0.9688, 3,
"Score of %s not almost equal to 0.9688" % score)
results_dir = join(WS_DIR, 'results')
results_file = join(results_dir, 'results.json')
self.assertTrue(exists(results_file))
save_file = join(results_dir, model_save_file)
self.assertTrue(exists(save_file))

# test reloading the trained model
classifier2 = joblib.load(save_file)
score2 = classifier2.score(X_test, y_test)
self.assertAlmostEqual(score2, 0.9688, 3,
"Score of %s not almost equal to 0.9688" % score2)

@unittest.skipUnless(SKLEARN_INSTALLED, "Sklearn not available")
def test_wrapper(self):
self.wrapper_tc('digits.joblib')


if __name__ == '__main__':
unittest.main()
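
Because the module ends with a unittest.main() guard, the new test can also be run on its own, independent of the UNIT_TESTS list in the Makefile above. A minimal sketch, assuming scikit-learn is installed and the working directory is tests/:

# Load and run just the sklearn kit tests; TestSklearnKit skips itself if sklearn is missing.
import unittest
suite = unittest.defaultTestLoader.loadTestsFromName('test_sklearn_kit')
unittest.TextTestRunner(verbosity=2).run(suite)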
