Add support for saving and loading models #377
Closed
Changes from 14 commits
Commits (16):
1c640fa Add support for saving and loading models (NegatioN)
a27d15d Add a test to confirm we're saving all relevant parameters (NegatioN)
26e172f Add test to ensure model is instanciated on load() (NegatioN)
7f63dcd Properly clean up after tests (NegatioN)
4f98c92 Formatting (NegatioN)
a4a5856 Reformat to match black settings (NegatioN)
23efc0c Add test to confirm performance of model is identical after model loa… (NegatioN)
16bf53b Change load method to classmethod (NegatioN)
4beaebd Remove redundant check of actual model performance (NegatioN)
2968358 Don't load dataset multiple times (NegatioN)
56060a7 reformat (NegatioN)
12ab436 Trigger rebuild (NegatioN)
b981710 Change from classmethod to staticmethod since circleCI is not having it (NegatioN)
6e556ad Trigger build (NegatioN)
67af48b use pytest fixtures to instanziate and clean up (NegatioN)
b7fd48d also update method tooltip (NegatioN)
@@ -0,0 +1,81 @@

    import pytest

    import numpy as np
    import os

    from sklearn.metrics import roc_auc_score

    from lightfm.lightfm import LightFM
    from lightfm.datasets import fetch_movielens


    def _binarize(dataset):

        positives = dataset.data >= 4.0
        dataset.data[positives] = 1.0
        dataset.data[np.logical_not(positives)] = -1.0

        return dataset


    def _cleanup():
        os.remove(TEST_FILE_PATH)


    TEST_FILE_PATH = "./tests/test.npz"
    movielens = fetch_movielens()
    train, test = _binarize(movielens["train"]), _binarize(movielens["test"])


    def test_all_params_persisted():
        model = LightFM(loss="warp")
        model.fit(movielens["train"], epochs=1, num_threads=4)
        model.save(TEST_FILE_PATH)

        # Load and confirm all model params are present.
        saved_model_params = list(np.load(TEST_FILE_PATH).keys())
        for x in dir(model):
            ob = getattr(model, x)
            # We don't need to persist model functions, or magic variables of the model.
            if not callable(ob) and not x.startswith("__"):
                assert x in saved_model_params

        _cleanup()


    def test_model_populated():
        model = LightFM(loss="warp")
        model.fit(movielens["train"], epochs=1, num_threads=4)
        model.save(TEST_FILE_PATH)

        # Load a model onto an uninstantiated object
        loaded_model = LightFM.load(TEST_FILE_PATH)

        assert loaded_model.item_embeddings.any()
        assert loaded_model.user_embeddings.any()

        _cleanup()


    def test_model_performance():
        # Train and persist a model
        model = LightFM(random_state=10)
        model.fit_partial(train, epochs=10, num_threads=4)
        model.save(TEST_FILE_PATH)

        train_predictions = model.predict(train.row, train.col)
        test_predictions = model.predict(test.row, test.col)

        trn_pred = roc_auc_score(train.data, train_predictions)
        tst_pred = roc_auc_score(test.data, test_predictions)

        # Performance is same as before when loaded from disk
        loaded_model = LightFM.load(TEST_FILE_PATH)

        train_predictions = loaded_model.predict(train.row, train.col)
        test_predictions = loaded_model.predict(test.row, test.col)

        assert roc_auc_score(train.data, train_predictions) == trn_pred
        assert roc_auc_score(test.data, test_predictions) == tst_pred

        _cleanup()
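For readers without the rest of the PR at hand, the contract these tests check (every non-callable, non-dunder attribute ends up as a key in the `.npz` file, and loading restores it) can be illustrated with a minimal sketch. `TinyModel` and its attributes are hypothetical stand-ins, not the `save`/`load` implementation from this PR, which is not shown in this diff; the sketch also assumes every attribute converts to a non-object NumPy array.

```python
import os
import tempfile

import numpy as np


class TinyModel:
    """Hypothetical stand-in for a model; not the LightFM implementation."""

    def __init__(self, no_components=4, loss="logistic"):
        self.no_components = no_components
        self.loss = loss
        self.item_embeddings = None

    def fit(self, n_items):
        # Fake "training": fill the embeddings with reproducible random values.
        rng = np.random.RandomState(0)
        self.item_embeddings = rng.rand(n_items, self.no_components)
        return self

    def save(self, path):
        # Store every non-callable, non-dunder attribute under its own key.
        # Assumption: the model has been fit, so no attribute is None.
        params = {
            name: getattr(self, name)
            for name in dir(self)
            if not name.startswith("__") and not callable(getattr(self, name))
        }
        np.savez(path, **params)

    @classmethod
    def load(cls, path):
        model = cls()
        with np.load(path, allow_pickle=False) as archive:
            for name in archive.files:
                value = archive[name]
                # Scalars and strings come back as 0-d arrays; unwrap them.
                setattr(model, name, value.item() if value.ndim == 0 else value)
        return model


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "model.npz")
    TinyModel().fit(n_items=5).save(path)
    print(sorted(np.load(path).files))
```

Loading with `allow_pickle=False` is a deliberate choice in this sketch: it refuses object arrays, so a saved file cannot execute arbitrary code on load.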
If a test fails, `_cleanup()` is never executed and the test file is never removed. Could you use pytest fixtures for setup/teardown?
Should be done now. Needs another cache-clean. Also, if you do end up merging this, just squash everything. I don't think it makes sense to keep any of the history except the initial commit.