Skip to content

Commit

Permalink
feat: save and load preds_mapper
Browse files Browse the repository at this point in the history
  • Loading branch information
eduardocarvp authored and Optimox committed Jan 12, 2021
1 parent 7ae20c9 commit cab643b
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 14 deletions.
2 changes: 2 additions & 0 deletions README.md
Expand Up @@ -70,6 +70,8 @@ clf.fit(
preds = clf.predict(X_test)
```

The targets on `y_train/y_valid` should contain a unique type (i.e. they must all be strings or integers).

### Default eval_metric

A few classical evaluation metrics are implemented (see bellow section for custom ones):
Expand Down
20 changes: 19 additions & 1 deletion census_example.ipynb
Expand Up @@ -288,6 +288,15 @@
"assert np.isclose(valid_auc, np.max(clf.history['valid_auc']), atol=1e-6)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"clf.predict(X_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -338,6 +347,15 @@
"assert(test_auc == loaded_test_auc)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loaded_clf.predict(X_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -461,7 +479,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.7.5"
},
"toc": {
"base_numbering": 1,
Expand Down
6 changes: 3 additions & 3 deletions forest_example.ipynb
Expand Up @@ -495,9 +495,9 @@
],
"metadata": {
"kernelspec": {
"display_name": ".shap",
"display_name": "Python 3",
"language": "python",
"name": ".shap"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -509,7 +509,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
"version": "3.7.5"
},
"toc": {
"base_numbering": 1,
Expand Down
11 changes: 10 additions & 1 deletion multi_task_example.ipynb
Expand Up @@ -386,6 +386,15 @@
"assert(test_aucs == loaded_test_auc)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loaded_clf.predict(X_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -448,7 +457,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
"version": "3.7.5"
},
"toc": {
"base_numbering": 1,
Expand Down
21 changes: 17 additions & 4 deletions pytorch_tabnet/abstract_model.py
Expand Up @@ -12,6 +12,7 @@
validate_eval_set,
create_dataloaders,
define_device,
ComplexEncoder,
)
from pytorch_tabnet.callbacks import (
CallbackContainer,
Expand Down Expand Up @@ -333,6 +334,10 @@ def load_weights_from_unsupervised(self, unsupervised_model):

self.network.load_state_dict(update_state_dict)

def load_class_attrs(self, class_attrs):
for attr_name, attr_value in class_attrs.items():
setattr(self, attr_name, attr_value)

def save_model(self, path):
"""Saving TabNet model in two distinct files.
Expand All @@ -348,19 +353,26 @@ def save_model(self, path):
"""
saved_params = {}
init_params = {}
for key, val in self.get_params().items():
if isinstance(val, type):
# Don't save torch specific params
continue
else:
saved_params[key] = val
init_params[key] = val
saved_params["init_params"] = init_params

class_attrs = {
"preds_mapper": self.preds_mapper
}
saved_params["class_attrs"] = class_attrs

# Create folder
Path(path).mkdir(parents=True, exist_ok=True)

# Save models params
with open(Path(path).joinpath("model_params.json"), "w", encoding="utf8") as f:
json.dump(saved_params, f)
json.dump(saved_params, f, cls=ComplexEncoder)

# Save state_dict
torch.save(self.network.state_dict(), Path(path).joinpath("network.pt"))
Expand All @@ -381,7 +393,7 @@ def load_model(self, filepath):
with zipfile.ZipFile(filepath) as z:
with z.open("model_params.json") as f:
loaded_params = json.load(f)
loaded_params["device_name"] = self.device_name
loaded_params["init_params"]["device_name"] = self.device_name
with z.open("network.pt") as f:
try:
saved_state_dict = torch.load(f, map_location=self.device)
Expand All @@ -396,11 +408,12 @@ def load_model(self, filepath):
except KeyError:
raise KeyError("Your zip file is missing at least one component")

self.__init__(**loaded_params)
self.__init__(**loaded_params["init_params"])

self._set_network()
self.network.load_state_dict(saved_state_dict)
self.network.eval()
self.load_class_attrs(loaded_params["class_attrs"])

return

Expand Down
11 changes: 11 additions & 0 deletions pytorch_tabnet/multiclass_utils.py
Expand Up @@ -16,6 +16,7 @@
import scipy.sparse as sp

import numpy as np
import pandas as pd


def _assert_all_finite(X, allow_nan=False):
Expand Down Expand Up @@ -344,6 +345,14 @@ def type_of_target(y):
return "binary" # [1, 2] or [["a"], ["b"]]


def check_unique_type(y):
target_types = pd.Series(y).map(type).unique()
if len(target_types) != 1:
raise TypeError(
f"Values on the target must have the same type. Target has types {target_types}"
)


def infer_output_dim(y_train):
"""
Infer output_dim from targets
Expand All @@ -360,6 +369,7 @@ def infer_output_dim(y_train):
train_labels : list
Sorted list of initial classes
"""
check_unique_type(y_train)
train_labels = unique_labels(y_train)
output_dim = len(train_labels)

Expand All @@ -368,6 +378,7 @@ def infer_output_dim(y_train):

def check_output_dim(labels, y):
if y is not None:
check_unique_type(y)
valid_labels = unique_labels(y)
if not set(valid_labels).issubset(set(labels)):
raise ValueError(
Expand Down
4 changes: 2 additions & 2 deletions pytorch_tabnet/multitask.py
Expand Up @@ -75,7 +75,7 @@ def update_fit_params(self, X_train, y_train, eval_set, weights):
for classes in self.classes_
]
self.preds_mapper = [
{index: class_label for index, class_label in enumerate(classes)}
{str(index): str(class_label) for index, class_label in enumerate(classes)}
for classes in self.classes_
]
self.updated_weights = weights
Expand Down Expand Up @@ -121,7 +121,7 @@ def predict(self, X):
results = [np.hstack(task_res) for task_res in results.values()]
# map all task individually
results = [
np.vectorize(self.preds_mapper[task_idx].get)(task_res)
np.vectorize(self.preds_mapper[task_idx].get)(task_res.astype(str))
for task_idx, task_res in enumerate(results)
]
return results
Expand Down
5 changes: 3 additions & 2 deletions pytorch_tabnet/tab_model.py
Expand Up @@ -59,7 +59,7 @@ def update_fit_params(
class_label: index for index, class_label in enumerate(self.classes_)
}
self.preds_mapper = {
index: class_label for index, class_label in enumerate(self.classes_)
str(index): class_label for index, class_label in enumerate(self.classes_)
}
self.updated_weights = self.weight_updater(weights)

Expand All @@ -71,7 +71,7 @@ def stack_batches(self, list_y_true, list_y_score):

def predict_func(self, outputs):
outputs = np.argmax(outputs, axis=1)
return np.vectorize(self.preds_mapper.get)(outputs)
return np.vectorize(self.preds_mapper.get)(outputs.astype(str))

def predict_proba(self, X):
"""
Expand Down Expand Up @@ -132,6 +132,7 @@ def update_fit_params(
"Use reshape(-1, 1) for single regression."
raise ValueError(msg)
self.output_dim = y_train.shape[1]
self.preds_mapper = None

self.updated_weights = weights
filter_weights(self.updated_weights)
Expand Down
9 changes: 9 additions & 0 deletions pytorch_tabnet/utils.py
Expand Up @@ -3,6 +3,7 @@
import torch
import numpy as np
import scipy
import json
from sklearn.utils import check_array


Expand Down Expand Up @@ -328,3 +329,11 @@ def define_device(device_name):
return "cpu"
else:
return device_name


class ComplexEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.int64):
return int(obj)
# Let the base class default method raise the TypeError
return json.JSONEncoder.default(self, obj)
52 changes: 51 additions & 1 deletion regression_example.ipynb
Expand Up @@ -219,6 +219,56 @@
"print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_score}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Save model and load"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# save tabnet model\n",
"saving_path_name = \"./tabnet_model_test_1\"\n",
"saved_filepath = clf.save_model(saving_path_name)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# define new model with basic parameters and load state dict weights\n",
"loaded_clf = TabNetRegressor()\n",
"loaded_clf.load_model(saved_filepath)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loaded_preds = loaded_clf.predict(X_test)\n",
"loaded_test_mse = mean_squared_error(loaded_preds, y_test)\n",
"\n",
"print(f\"FINAL TEST SCORE FOR {dataset_name} : {loaded_test_mse}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"assert(test_score == loaded_test_mse)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -352,7 +402,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.7.5"
},
"toc": {
"base_numbering": 1,
Expand Down

0 comments on commit cab643b

Please sign in to comment.