Skip to content

Commit

Permalink
[Frameworks] Add the tag attribute to the ml frameworks (#1618)
Browse files Browse the repository at this point in the history
  • Loading branch information
guy1992l committed Jan 8, 2022
1 parent 6c0eb23 commit 3b5cfe4
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 53 deletions.
23 changes: 8 additions & 15 deletions mlrun/frameworks/_ml_common/mlrun_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,18 @@ class MLMLRunInterface:

@classmethod
def add_interface(
cls,
model_handler: MLModelHandler,
context,
model_name,
data={},
*args,
**kwargs
cls, model_handler: MLModelHandler, context, tag, data={}, *args, **kwargs
):
"""
Wrap the given model with MLRun model features, providing it with MLRun model attributes including its
parameters and methods.
:param model: The model to wrap.
:param context: MLRun context to work with. If no context is given it will be retrieved via
'mlrun.get_or_create_ctx(None)'
:param model_name: name under whcih the model will be saved within the databse.
:param data: Optional: The train_test_split X_train, X_test, y_train, y_test can be passed,
or the test data X_test, y_test can be passed.
:return: The wrapped model.
:param model_handler: The model to wrap.
:param context: MLRun context to work with. If no context is given it will be retrieved via
'mlrun.get_or_create_ctx(None)'
:param tag: Tag for the model to log with.
:param data: The train_test_split X_train, X_test, y_train, y_test can be passed, or the test data
X_test, y_test can be passed.
"""
model = model_handler.model

Expand Down Expand Up @@ -102,6 +94,7 @@ def _post_fit(*args, **kwargs):
raise ValueError("No column name for y was specified")

model_handler.log(
tag=tag,
algorithm=str(model.__class__.__name__),
training_set=train_set,
label_column=label_column,
Expand Down
34 changes: 18 additions & 16 deletions mlrun/frameworks/lgbm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,40 +25,42 @@ def apply_mlrun(
X_test: Union[np.ndarray, pd.DataFrame] = None,
y_test: Union[np.ndarray, pd.DataFrame] = None,
model_name: str = None,
tag: str = "",
generate_test_set: bool = True,
**kwargs
):
) -> LGBMModelHandler:
"""
Wrap the given model with MLRun model, saving the model's
attributes and methods while giving it mlrun's additional features.
examples::
Wrap the given model with MLRun model, saving the model's attributes and methods while giving it mlrun's additional
features.
example:
model = LGBMClassifier()
X_train, X_test, y_train, y_test = train_test_split(X, y,
test_size=0.2)
model = apply_mlrun(model, context, X_test=X_test, y_test=y_test)
model.fit(X_train, y_train)
:param model: The model which will have the fit()
function wrapped
:param context: MLRun context to work with. If no context
is given it will be retrieved via
'mlrun.get_or_create_ctx(None)'
:param X_test: X_test dataset
:param y_test: y_test dataset
:param model_name: The model artifact name (Optional)
:param generate_test_set: Generates a test_set dataset artifact
:return: The model with MLRun's interface.
:param model: The model which will have the fit() function wrapped
:param context: MLRun context to work with. If no context is given it will be retrieved via
'mlrun.get_or_create_ctx(None)'
:param X_test: X_test dataset
:param y_test: y_test dataset
:param model_name: The model artifact name (Optional)
:param tag: Tag of a version to give to the logged model.
:param generate_test_set: Generates a test_set dataset artifact
:return: The model in a MLRun model handler.
"""
if context is None:
context = mlrun.get_or_create_ctx("mlrun_lgbm")

kwargs["X_test"] = X_test
kwargs["y_test"] = y_test
kwargs["generate_test_set"] = generate_test_set

mh = LGBMModelHandler(
model_name=model_name or "model", model=model, context=context
)

# Add MLRun's interface to the model:
MLMLRunInterface.add_interface(mh, context, model_name, kwargs)
MLMLRunInterface.add_interface(mh, context, tag, kwargs)
return mh
25 changes: 13 additions & 12 deletions mlrun/frameworks/sklearn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,30 @@ def apply_mlrun(
X_test=None,
y_test=None,
model_name=None,
tag: str = "",
generate_test_set=True,
**kwargs
):
) -> SKLearnModelHandler:
"""
Wrap the given model with MLRun model, saving the model's attributes and methods while giving it mlrun's additional
features.
examples::
examples:
model = LogisticRegression()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
model = apply_mlrun(model, context, X_test=X_test, y_test=y_test)
model.fit(X_train, y_train)
:param model: The model to wrap.
:param context: MLRun context to work with. If no context is given it will be retrieved via
'mlrun.get_or_create_ctx(None)'
:param X_test: X test data (for accuracy and plots generation)
:param y_test: y test data (for accuracy and plots generation)
:param model_name: model artifact name
:param generate_test_set: will generate a test_set dataset artifact
:param model: The model which will have the fit() function wrapped
:param context: MLRun context to work with. If no context is given it will be retrieved via
'mlrun.get_or_create_ctx(None)'
:param X_test: X_test dataset
:param y_test: y_test dataset
:param model_name: The model artifact name (Optional)
:param tag: Tag of a version to give to the logged model.
:param generate_test_set: Generates a test_set dataset artifact
:return: The model with MLRun's interface.
:return: The model in a MLRun model handler.
"""
if context is None:
context = mlrun.get_or_create_ctx("mlrun_sklearn")
Expand All @@ -55,5 +56,5 @@ def apply_mlrun(
)

# Add MLRun's interface to the model:
MLMLRunInterface.add_interface(mh, context, model_name, kwargs)
MLMLRunInterface.add_interface(mh, context, tag, kwargs)
return mh
22 changes: 12 additions & 10 deletions mlrun/frameworks/xgboost/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,30 @@ def apply_mlrun(
X_test=None,
y_test=None,
model_name=None,
tag: str = "",
generate_test_set=True,
**kwargs
) -> XGBoostModelHandler:
"""
Wrap the given model with MLRun model, saving the model's attributes and methods while giving it mlrun's additional
features.
examples::
examples:
model = XGBRegressor()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
apply_mlrun(model, context, X_test=X_test, y_test=y_test)
model.fit(X_train, y_train)
:param model: The model to wrap.
:param context: MLRun context to work with. If no context is given it will be retrieved via
'mlrun.get_or_create_ctx(None)'
:param X_test: X test data (for accuracy and plots generation)
:param y_test: y test data (for accuracy and plots generation)
:param model_name: model artifact name
:param generate_test_set: will generate a test_set dataset artifact
:param model: The model which will have the fit() function wrapped
:param context: MLRun context to work with. If no context is given it will be retrieved via
'mlrun.get_or_create_ctx(None)'
:param X_test: X_test dataset
:param y_test: y_test dataset
:param model_name: The model artifact name (Optional)
:param tag: Tag of a version to give to the logged model.
:param generate_test_set: Generates a test_set dataset artifact
:return: The model with MLRun's interface.
:return: The model in a MLRun model handler.
"""
if context is None:
context = mlrun.get_or_create_ctx("mlrun_xgb")
Expand All @@ -54,5 +56,5 @@ def apply_mlrun(
)

# Add MLRun's interface to the model:
MLMLRunInterface.add_interface(mh, context, model_name, kwargs)
MLMLRunInterface.add_interface(mh, context, tag, kwargs)
return mh

0 comments on commit 3b5cfe4

Please sign in to comment.