diff --git a/modelstore/models/managers.py b/modelstore/models/managers.py index 9d8ed0df..7d1ba7d0 100644 --- a/modelstore/models/managers.py +++ b/modelstore/models/managers.py @@ -31,6 +31,7 @@ from modelstore.models.shap import ShapManager from modelstore.models.sklearn import SKLearnManager from modelstore.models.skorch import SkorchManager +from modelstore.models.statsmodels import StatsModelsManager from modelstore.models.tensorflow import TensorflowManager from modelstore.models.transformers import TransformersManager from modelstore.models.xgboost import XGBoostManager @@ -55,6 +56,7 @@ ShapManager, SKLearnManager, SkorchManager, + StatsModelsManager, TensorflowManager, TransformersManager, XGBoostManager, diff --git a/modelstore/models/statsmodels.py b/modelstore/models/statsmodels.py new file mode 100644 index 00000000..9c17a407 --- /dev/null +++ b/modelstore/models/statsmodels.py @@ -0,0 +1,76 @@ +# Copyright 2020 Neal Lathia +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from functools import partial +from typing import Any + +from modelstore.metadata import metadata +from modelstore.models.model_manager import ModelManager +from modelstore.storage.storage import CloudStorage + +MODEL_PICKLE = "model.pkl" + + +def _save_model(tmp_dir, model, file_name): + path = os.path.join(tmp_dir, file_name) + model.save(path) + return path + + +class StatsModelsManager(ModelManager): + + """ + Model persistence for statsmodels fitted result objects: + https://www.statsmodels.org/stable/index.html + """ + + NAME = "statsmodels" + + def __init__(self, storage: CloudStorage = None): + super().__init__(self.NAME, storage) + + def required_dependencies(self) -> list: + return ["statsmodels"] + + def _required_kwargs(self): + return ["model"] + + def matches_with(self, **kwargs) -> bool: + # pylint: disable=import-outside-toplevel + from statsmodels.base.wrapper import ResultsWrapper + + return isinstance(kwargs.get("model"), ResultsWrapper) + + def _get_functions(self, **kwargs) -> list: + if not self.matches_with(**kwargs): + raise TypeError("This model is not a statsmodels ResultsWrapper!") + + return [partial(_save_model, model=kwargs["model"], file_name=MODEL_PICKLE)] + + def get_params(self, **kwargs) -> dict: + try: + params = kwargs["model"].params + if hasattr(params, "to_dict"): + return params.to_dict() + # numpy array — convert to a plain dict keyed by position + return {str(i): float(v) for i, v in enumerate(params)} + except Exception: + return {} + + def load(self, model_path: str, meta_data: metadata.Summary) -> Any: + super().load(model_path, meta_data) + # pylint: disable=import-outside-toplevel + from statsmodels.iolib.smpickle import load_pickle + + return load_pickle(os.path.join(model_path, MODEL_PICKLE)) diff --git a/tests/models/test_managers.py b/tests/models/test_managers.py index e6615072..c8bea78a 100644 --- a/tests/models/test_managers.py +++ b/tests/models/test_managers.py @@ -27,7 +27,7 @@ def test_iter_libraries(): mgrs = {library: manager for library, manager in managers.iter_libraries()} - assert len(mgrs) == 18 + assert len(mgrs) == 19 assert isinstance(mgrs["sklearn"], SKLearnManager) assert isinstance(mgrs["pytorch"], PyTorchManager) assert isinstance(mgrs["xgboost"], XGBoostManager) diff --git a/tests/models/test_statsmodels.py b/tests/models/test_statsmodels.py new file mode 100644 index 00000000..616519c9 --- /dev/null +++ b/tests/models/test_statsmodels.py @@ -0,0 +1,93 @@ +# Copyright 2020 Neal Lathia +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os + +import numpy as np +import pandas as pd +import pytest +import statsmodels.api as sm + +from modelstore.metadata import metadata +from modelstore.models.statsmodels import MODEL_PICKLE, StatsModelsManager + +# pylint: disable=unused-import +from tests.models.utils import classification_data + +# pylint: disable=protected-access +# pylint: disable=redefined-outer-name +# pylint: disable=missing-function-docstring + + +@pytest.fixture +def statsmodels_model(classification_data): + X_train, y_train = classification_data + X_train = sm.add_constant(X_train) + model = sm.OLS(y_train, X_train).fit() + return model + + +@pytest.fixture +def statsmodels_manager(): + return StatsModelsManager() + + +def test_model_info(statsmodels_manager, statsmodels_model): + res = statsmodels_manager.model_info(model=statsmodels_model) + assert res == metadata.ModelType("statsmodels", "RegressionResultsWrapper", None) + + +def test_model_data(statsmodels_manager, statsmodels_model): + res = statsmodels_manager.model_data(model=statsmodels_model) + assert res is None + + +def test_required_kwargs(statsmodels_manager): + assert statsmodels_manager._required_kwargs() == ["model"] + + +def test_matches_with(statsmodels_manager, statsmodels_model): + assert statsmodels_manager.matches_with(model=statsmodels_model) + assert not statsmodels_manager.matches_with(model="a-string-value") + assert not statsmodels_manager.matches_with(classifier=statsmodels_model) + + +def test_get_functions(statsmodels_manager, statsmodels_model): + assert len(statsmodels_manager._get_functions(model=statsmodels_model)) == 1 + with pytest.raises(TypeError): + statsmodels_manager._get_functions(model="not-a-statsmodels-model") + + +def test_get_params(statsmodels_manager, statsmodels_model): + result = statsmodels_manager.get_params(model=statsmodels_model) + assert isinstance(result, dict) + assert len(result) > 0 + try: + json.dumps(result) + except Exception as exc: + pytest.fail(f"Params are not JSON-serialisable: {str(exc)}") + + +def test_load_model(tmp_path, statsmodels_manager, statsmodels_model): + # Save the model to a tmp directory + model_path = os.path.join(tmp_path, MODEL_PICKLE) + statsmodels_model.save(model_path) + assert os.path.exists(model_path) + + # Load the model + loaded_model = statsmodels_manager.load(tmp_path, None) + + # Expect type and params to match + assert type(loaded_model) == type(statsmodels_model) + assert np.allclose(loaded_model.params, statsmodels_model.params)