Skip to content

Commit

Permalink
Don't load dataset multiple times
Browse files Browse the repository at this point in the history
  • Loading branch information
NegatioN committed Dec 4, 2018
1 parent 4beaebd commit 2968358
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions tests/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,24 @@ def cleanup():


def test_all_params_persisted():
# Train and persist a model
data = fetch_movielens(min_rating=5.0)
model = LightFM(loss="warp")
model.fit(data["train"], epochs=5, num_threads=4)
model.fit(movielens["train"], epochs=1, num_threads=4)
model.save(TEST_FILE_PATH)

# Load and confirm all model params are present.
saved_model_params = list(np.load(TEST_FILE_PATH).keys())
for x in dir(model):
ob = getattr(model, x)
# We don't need to persist model functions, or magic variables of the model.
if not callable(ob) and not x.startswith("__"):
assert x in saved_model_params

cleanup()


def test_model_populated():
# Train and persist a model
data = fetch_movielens(min_rating=5.0)
model = LightFM(loss="warp")
model.fit(data["train"], epochs=5, num_threads=4)
model.fit(movielens["train"], epochs=1, num_threads=4)
model.save(TEST_FILE_PATH)

# Load a model onto an uninstanciated object
Expand Down

0 comments on commit 2968358

Please sign in to comment.