From cab643b156fdecfded51d70d29072fc43f397bbb Mon Sep 17 00:00:00 2001
From: Eduardo Carvalho
Date: Sat, 19 Dec 2020 00:20:39 +0100
Subject: [PATCH] feat: save and load preds_mapper

---
 README.md                          |  2 ++
 census_example.ipynb               | 20 +++++++++++-
 forest_example.ipynb               |  6 ++--
 multi_task_example.ipynb           | 11 ++++++-
 pytorch_tabnet/abstract_model.py   | 21 +++++++++---
 pytorch_tabnet/multiclass_utils.py | 11 +++++++
 pytorch_tabnet/multitask.py        |  4 +--
 pytorch_tabnet/tab_model.py        |  5 +--
 pytorch_tabnet/utils.py            |  9 ++++++
 regression_example.ipynb           | 52 +++++++++++++++++++++++++++++-
 10 files changed, 127 insertions(+), 14 deletions(-)

diff --git a/README.md b/README.md
index 2b2295af..049c9c49 100644
--- a/README.md
+++ b/README.md
@@ -70,6 +70,8 @@ clf.fit(
 preds = clf.predict(X_test)
 ```
 
+The targets in `y_train/y_valid` must all be of a single type (e.g. all strings or all integers).
+
 ### Default eval_metric
 
 A few classical evaluation metrics are implemented (see bellow section for custom ones):
diff --git a/census_example.ipynb b/census_example.ipynb
index a3e3a070..3ceb0ad6 100755
--- a/census_example.ipynb
+++ b/census_example.ipynb
@@ -288,6 +288,15 @@
     "assert np.isclose(valid_auc, np.max(clf.history['valid_auc']), atol=1e-6)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "clf.predict(X_test)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -338,6 +347,15 @@
     "assert(test_auc == loaded_test_auc)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "loaded_clf.predict(X_test)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -461,7 +479,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.6"
+   "version": "3.7.5"
   },
   "toc": {
    "base_numbering": 1,
diff --git a/forest_example.ipynb b/forest_example.ipynb
index 3cc2e82a..5ba8bd1d 100644
--- a/forest_example.ipynb
+++ b/forest_example.ipynb
@@ -495,9 +495,9 @@
  ],
  "metadata": {
   "kernelspec": {
-  "display_name": ".shap",
+  "display_name": "Python 3",
    "language": "python",
-  "name": ".shap"
+  "name": "python3"
   },
   "language_info": {
    "codemirror_mode": {
@@ -509,7 +509,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.6.8"
+   "version": "3.7.5"
   },
   "toc": {
    "base_numbering": 1,
diff --git a/multi_task_example.ipynb b/multi_task_example.ipynb
index 21acf83b..8766b497 100644
--- a/multi_task_example.ipynb
+++ b/multi_task_example.ipynb
@@ -386,6 +386,15 @@
     "assert(test_aucs == loaded_test_auc)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "loaded_clf.predict(X_test)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -448,7 +457,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.6.9"
+   "version": "3.7.5"
   },
   "toc": {
    "base_numbering": 1,
diff --git a/pytorch_tabnet/abstract_model.py b/pytorch_tabnet/abstract_model.py
index 714f9ceb..50d99bb1 100644
--- a/pytorch_tabnet/abstract_model.py
+++ b/pytorch_tabnet/abstract_model.py
@@ -12,6 +12,7 @@
     validate_eval_set,
     create_dataloaders,
     define_device,
+    ComplexEncoder,
 )
 from pytorch_tabnet.callbacks import (
     CallbackContainer,
@@ -333,6 +334,10 @@ def load_weights_from_unsupervised(self, unsupervised_model):
 
         self.network.load_state_dict(update_state_dict)
 
+    def load_class_attrs(self, class_attrs):
+        for attr_name, attr_value in class_attrs.items():
+            setattr(self, attr_name, attr_value)
+
     def save_model(self, path):
         """Saving TabNet model in two distinct files.
 
@@ -348,19 +353,26 @@ def save_model(self, path):
         """
         saved_params = {}
+        init_params = {}
         for key, val in self.get_params().items():
             if isinstance(val, type):
                 # Don't save torch specific params
                 continue
             else:
-                saved_params[key] = val
+                init_params[key] = val
+        saved_params["init_params"] = init_params
+
+        class_attrs = {
+            "preds_mapper": self.preds_mapper
+        }
+        saved_params["class_attrs"] = class_attrs
 
         # Create folder
         Path(path).mkdir(parents=True, exist_ok=True)
 
         # Save models params
         with open(Path(path).joinpath("model_params.json"), "w", encoding="utf8") as f:
-            json.dump(saved_params, f)
+            json.dump(saved_params, f, cls=ComplexEncoder)
 
         # Save state_dict
         torch.save(self.network.state_dict(), Path(path).joinpath("network.pt"))
@@ -381,7 +393,7 @@ def load_model(self, filepath):
         with zipfile.ZipFile(filepath) as z:
             with z.open("model_params.json") as f:
                 loaded_params = json.load(f)
-                loaded_params["device_name"] = self.device_name
+                loaded_params["init_params"]["device_name"] = self.device_name
             with z.open("network.pt") as f:
                 try:
                     saved_state_dict = torch.load(f, map_location=self.device)
@@ -396,11 +408,12 @@
         except KeyError:
             raise KeyError("Your zip file is missing at least one component")
 
-        self.__init__(**loaded_params)
+        self.__init__(**loaded_params["init_params"])
 
         self._set_network()
         self.network.load_state_dict(saved_state_dict)
         self.network.eval()
+        self.load_class_attrs(loaded_params["class_attrs"])
 
         return
diff --git a/pytorch_tabnet/multiclass_utils.py b/pytorch_tabnet/multiclass_utils.py
index 9c279eec..8dbf08c5 100644
--- a/pytorch_tabnet/multiclass_utils.py
+++ b/pytorch_tabnet/multiclass_utils.py
@@ -16,6 +16,7 @@
 import scipy.sparse as sp
 import numpy as np
+import pandas as pd
 
 
 def _assert_all_finite(X, allow_nan=False):
@@ -344,6 +345,14 @@
     return "binary"  # [1, 2] or [["a"], ["b"]]
 
 
+def check_unique_type(y):
+    target_types = pd.Series(y).map(type).unique()
+    if len(target_types) != 1:
+        raise TypeError(
+            f"Values on the target must have the same type. Target has types {target_types}"
+        )
+
+
 def infer_output_dim(y_train):
     """
     Infer output_dim from targets
@@ -360,6 +369,7 @@
     train_labels : list
         Sorted list of initial classes
     """
+    check_unique_type(y_train)
     train_labels = unique_labels(y_train)
     output_dim = len(train_labels)
 
@@ -368,6 +378,7 @@
 
 def check_output_dim(labels, y):
     if y is not None:
+        check_unique_type(y)
         valid_labels = unique_labels(y)
         if not set(valid_labels).issubset(set(labels)):
             raise ValueError(
diff --git a/pytorch_tabnet/multitask.py b/pytorch_tabnet/multitask.py
index 2c30205e..c72c479f 100644
--- a/pytorch_tabnet/multitask.py
+++ b/pytorch_tabnet/multitask.py
@@ -75,7 +75,7 @@ def update_fit_params(self, X_train, y_train, eval_set, weights):
             for classes in self.classes_
         ]
         self.preds_mapper = [
-            {index: class_label for index, class_label in enumerate(classes)}
+            {str(index): str(class_label) for index, class_label in enumerate(classes)}
             for classes in self.classes_
         ]
         self.updated_weights = weights
@@ -121,7 +121,7 @@ def predict(self, X):
         results = [np.hstack(task_res) for task_res in results.values()]
         # map all task individually
         results = [
-            np.vectorize(self.preds_mapper[task_idx].get)(task_res)
+            np.vectorize(self.preds_mapper[task_idx].get)(task_res.astype(str))
             for task_idx, task_res in enumerate(results)
         ]
         return results
diff --git a/pytorch_tabnet/tab_model.py b/pytorch_tabnet/tab_model.py
index 143569ea..2ebae53c 100755
--- a/pytorch_tabnet/tab_model.py
+++ b/pytorch_tabnet/tab_model.py
@@ -59,7 +59,7 @@ def update_fit_params(
             class_label: index for index, class_label in enumerate(self.classes_)
         }
         self.preds_mapper = {
-            index: class_label for index, class_label in enumerate(self.classes_)
+            str(index): class_label for index, class_label in enumerate(self.classes_)
         }
         self.updated_weights = self.weight_updater(weights)
 
@@ -71,7 +71,7 @@ def stack_batches(self, list_y_true, list_y_score):
 
     def predict_func(self, outputs):
         outputs = np.argmax(outputs, axis=1)
-        return np.vectorize(self.preds_mapper.get)(outputs)
+        return np.vectorize(self.preds_mapper.get)(outputs.astype(str))
 
     def predict_proba(self, X):
         """
@@ -132,6 +132,7 @@ def update_fit_params(
                   "Use reshape(-1, 1) for single regression."
             raise ValueError(msg)
         self.output_dim = y_train.shape[1]
+        self.preds_mapper = None
         self.updated_weights = weights
         filter_weights(self.updated_weights)
 
diff --git a/pytorch_tabnet/utils.py b/pytorch_tabnet/utils.py
index 690f7124..86a09c63 100644
--- a/pytorch_tabnet/utils.py
+++ b/pytorch_tabnet/utils.py
@@ -3,6 +3,7 @@
 import torch
 import numpy as np
 import scipy
+import json
 from sklearn.utils import check_array
 
 
@@ -328,3 +329,11 @@ def define_device(device_name):
         return "cpu"
     else:
         return device_name
+
+
+class ComplexEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, np.int64):
+            return int(obj)
+        # Let the base class default method raise the TypeError
+        return json.JSONEncoder.default(self, obj)
diff --git a/regression_example.ipynb b/regression_example.ipynb
index 5f18079b..efd0f8c1 100644
--- a/regression_example.ipynb
+++ b/regression_example.ipynb
@@ -219,6 +219,56 @@
     "print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_score}\")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Save model and load"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# save tabnet model\n",
+    "saving_path_name = \"./tabnet_model_test_1\"\n",
+    "saved_filepath = clf.save_model(saving_path_name)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# define new model with basic parameters and load state dict weights\n",
+    "loaded_clf = TabNetRegressor()\n",
+    "loaded_clf.load_model(saved_filepath)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "loaded_preds = loaded_clf.predict(X_test)\n",
+    "loaded_test_mse = mean_squared_error(loaded_preds, y_test)\n",
+    "\n",
+    "print(f\"FINAL TEST SCORE FOR {dataset_name} : {loaded_test_mse}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "assert(test_score == loaded_test_mse)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -352,7 +402,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.6"
+   "version": "3.7.5"
   },
   "toc": {
    "base_numbering": 1,
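
Taken together, the changes above persist `preds_mapper` (alongside the init parameters) in `model_params.json`, so a reloaded model can map raw class indices back to the original labels. A minimal round-trip sketch, assuming an already-fitted `TabNetClassifier` named `clf` and a feature matrix `X_test` (neither is defined in this patch):

```python
# Save/load round trip for the behaviour added in this patch.
# Assumes `clf` is a fitted TabNetClassifier and `X_test` is a numpy array.
from pytorch_tabnet.tab_model import TabNetClassifier

saved_filepath = clf.save_model("./tabnet_model_test")  # returns the path of the written zip

loaded_clf = TabNetClassifier()        # fresh instance with default parameters
loaded_clf.load_model(saved_filepath)  # restores init_params, weights and preds_mapper

# Because preds_mapper is restored, predict() returns the original class labels
# rather than integer indices.
preds = loaded_clf.predict(X_test)
```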