Skip to content

Commit

Permalink
BUGFIX: clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
nikisix committed May 22, 2018
1 parent 3077db7 commit e93c812
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/test_evaluation_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from spotlight.cross_validation import random_train_test_split
from spotlight.datasets import movielens
from spotlight.factorization.implicit import ImplicitFactorizationModel
from spotlight.sequence import ImplicitSequenceModel


RANDOM_STATE = np.random.RandomState(42)
CUDA = bool(os.environ.get('SPOTLIGHT_CUDA', False))
Expand Down Expand Up @@ -41,24 +43,22 @@ def data_implicit_sequence():
train, test = random_train_test_split(interactions,
random_state=RANDOM_STATE)

h = hyperparameters

model = ImplicitSequenceModel(loss='adaptive_hinge',
representation='lstm',
batch_size=[8, 16, 32, 256],
learning_rate=[1e-3, 1e-2, 5 * 1e-2, 1e-1],
l2=[1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 0.0],
n_iter=list(range(1, 2)),
use_cuda=CUDA,
random_state=random_state)
random_state=RANDOM_STATE)

model.fit(train, verbose=True)

return train, test, model


@pytest.mark.parametrize('k', 10)
def test_precision_recall(data_implicit_sequence, k):
def test_sequence_precision_recall(data_implicit_sequence, k):

(train, test, model) = data_implicit_sequence

Expand Down

0 comments on commit e93c812

Please sign in to comment.