Skip to content

Commit

Permalink
update docs for sklearn kit
Browse files Browse the repository at this point in the history
  • Loading branch information
jfischer committed Jan 13, 2020
1 parent ddcbfb5 commit f463db3
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 11 deletions.
101 changes: 93 additions & 8 deletions dataworkspaces/kits/scikit_learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def load_dataset_from_resource(resource_name:str, subpath:Optional[str]=None,
* ``data`` - a NumPy array of shape number_samples * number_features
* ``target`` - a NumPy array of length number_samples
* ``resource`` - a :class:`~ResourceRef` that provides the resource name and
* ``resource`` - a :class:`~dataworkspaces.workspace.ResourceRef` that provides the resource name and
subpath (if any) for the data
Some other attributes that may also be present, depending on the data set:
Expand All @@ -77,10 +77,6 @@ def load_dataset_from_resource(resource_name:str, subpath:Optional[str]=None,
Data sets may define their own attributes as well (see below).
The ``data`` and ``target`` attributes can be used directly (e.g. passed to
``train_test_split()``) or the entire bunch used as a parameter to
:func:`~train_and_predict_with_cv`.
**Parameters**
resource_name
Expand Down Expand Up @@ -311,9 +307,79 @@ def print_metrics(self, file=sys.stdout):

import sklearn.utils.metaestimators
class LineagePredictor(sklearn.utils.metaestimators._BaseComposition):
"""This is a wrapper for adding lineage to any predictor in sklearn.
To use it, instantiate the predictor (for classification or regression)
and then create a new instance of :class:`~LineagePredictor`.
The initializer finds the associated workspace and initializes a
:class:`~dataworkspaces.lineage.Lineage` instance. The input_resource
is recorded in this lineage. Other methods call the underlying wrapped
predictor's methods, with additional functionality as needed (see below).
**Parameters**
predictor
Any sklearn predictor instance. It must have ``fit`` and ``predict``
methods.
metrics
Either a string naming a metrics type or a subclass of :class:`~Metrics`.
If a string, it should be one of: binary_classification,
multiclass_classification, or regression.
input_resource
Resource providing the input data to this model. May be
specified by name, by a local file path, or via a
:class:`~dataworkspaces.workspace.ResourceRef`.
resource_resource
(optional) Resource where the results are to be stored.
May be specified by name, by a local file path, or via a
:class:`!ResourceRef`.
If not specified, will try to infer from the workspace.
model_save_file
(optional) Name of file to store a (joblib-formmatted)
serialization of the trained model upon completion of the ``fit()``
method. This should be a relative path, as it is stored under
the results resource. If model_save_file is not specified,
no model is saved.
workspace_dir
(optional) Directory specifying the workspace. Usually can be
inferred from the current directory.
verbose
If True, print a lot of detailed information about the execution
of Data Workspaces.
**Example**
Here is an example useage of the wrapper, taken from the
:ref:`Quick Start <quickstart>`::
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from dataworkspaces.kits.scikit_learn import load_dataset_from_resource
from dataworkspaces.kits.scikit_learn import LineagePredictor
dataset = load_dataset_from_resource('sklearn-digits-dataset')
X_train, X_test, y_train, y_test = train_test_split(
dataset.data, dataset.target, test_size=0.5, shuffle=False)
classifier = LineagePredictor(SVC(gamma=0.001),
metrics='multiclass_classification',
input_resource=dataset.resource,
model_save_file='digits.joblib')
classifier.fit(X_train, y_train)
score = classifier.score(X_test, y_test)
**Methods**
"""
_dws_model_wrap = True
def __init__(self, predictor, metrics:Union[str,type],
input_resource:Optional[Union[str, ResourceRef]]=None,
input_resource:Union[str, ResourceRef],
results_resource:Optional[Union[str, ResourceRef]]=None,
model_save_file:Optional[str]=None,
workspace_dir:Optional[str]=None,
Expand All @@ -330,6 +396,9 @@ def __init__(self, predictor, metrics:Union[str,type],
self.input_resource = input_resource
self.results_resource = results_resource
self.model_save_file = model_save_file
if model_save_file is not None:
assert not isabs(model_save_file),\
"Model save file should not be an absolute path"
self.workspace_dir = workspace_dir
self.metrics = metrics
self.verbose = verbose
Expand Down Expand Up @@ -369,7 +438,6 @@ def _save_model(self):
def __getstate__(self):
state = super().__getstate__()
if '_dws_state' in state:
print("__get_state__: deleting _dws_state")
del state['_dws_state']
return state

Expand All @@ -378,12 +446,18 @@ def __setstate__(self, state):
self._init_dws_state()

def set_params(self, **params):
print("set_params(%s)"%repr(params))
""""""
super().set_params(**params)
self._init_dws_state()
return self

def fit(self, X, y, *args, **kwargs):
"""The underlying fit() method of a predictor trains the predictio based
on the input data (X) and labels (y).
If the input resource is an api resource, the wrapper captures the hash of
the inputs.
If ``model_save_file`` was specified, it also saves the trained model."""
api_resource = self._dws_state.find_input_resources_and_return_if_api(X, y)
if api_resource is not None:
api_resource.init_hash_state()
Expand All @@ -397,6 +471,15 @@ def fit(self, X, y, *args, **kwargs):
return result

def score(self, X, y, sample_weight=None):
"""This method make predictions from a trained model and scores them
according to the metrics specified when instantiated the wrapper.
If the input resource is an api resource, the wrapper captures its hash.
The wapper runs the wrapped predictor's :meth:`~predict` method to
generate predictions. A `metrics` object is instantiated to compute the metrics
for the predictions and a ``results.json`` file is written to the
results resource. The lineage data is saved and finally the score
is computed from the predictions and returned to the caller."""
if self.score_has_been_run:
# This might be from a saved model, so we reset the
# execution time, etc.
Expand All @@ -422,6 +505,8 @@ def score(self, X, y, sample_weight=None):
return metrics_inst.score()

def predict(self, X):
"""The underlying :meth:`~predict` method is called directly,
without affecting the lineage."""
return self.predictor.predict(X)


Expand Down
2 changes: 1 addition & 1 deletion dataworkspaces/kits/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def add_lineage_to_keras_model_class(Cls:type,
``ResourceRef``. If no inputs are specified, will try to infer from the
workspace.
* ``results_resource`` -- optional resource where the results are to be stored.
My be specified by name, by a local file path, or via a ``ResourceRef``.
May be specified by name, by a local file path, or via a ``ResourceRef``.
if not specified, will try to infer from the workspace.
* ``workspace-dir`` -- Optional directory specifying the workspace. Usually can be
inferred from the current directory.
Expand Down
3 changes: 2 additions & 1 deletion docs/kits.rst
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ Scikit-learn

.. automodule:: dataworkspaces.kits.scikit_learn
:no-undoc-members:
:members: load_dataset_from_resource,train_and_predict_with_cv,Metrics,BinaryClassificationMetrics,MulticlassClassificationMetrics
:members: load_dataset_from_resource,LineagePredictor,Metrics,BinaryClassificationMetrics,MulticlassClassificationMetrics


TensorFlow
----------
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sklearn_kit.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def wrapper_tc(self, model_save_file):
dataset.data, dataset.target, test_size=0.5, shuffle=False)
classifier = skkit.LineagePredictor(SVC(gamma=0.001),
'multiclass_classification',
input_resource='sklearn-digits-dataset',
input_resource=dataset.resource,
model_save_file=model_save_file,
workspace_dir=WS_DIR,
verbose=False)
Expand Down

0 comments on commit f463db3

Please sign in to comment.