Skip to content

Commit

Permalink
Added more useful test cases w/ movielens data
Browse files Browse the repository at this point in the history
  • Loading branch information
ThoseGrapefruits committed Mar 12, 2018
1 parent c79711a commit 98b4282
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions test/test_tensorrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tensorrec import TensorRec
from tensorrec.util import generate_dummy_data_with_indicator, generate_dummy_data
from tensorrec.session_management import set_session
from test.datasets import get_movielens_100k


class TensorRecTestCase(TestCase):
Expand Down Expand Up @@ -127,13 +128,27 @@ def test_predict_item_bias_unbiased_model(self):
self.unbiased_model.predict_item_bias,
self.item_features)


class TensorRecMovielensPrediction(TestCase):

@classmethod
def setUpClass(cls):
cls.movielens_100k = get_movielens_100k()
train_interactions, test_interactions, cls.user_features, cls.item_features, _ = cls.movielens_100k
cls.model = TensorRec()
cls.model.fit(
interactions=train_interactions,
user_features=cls.user_features,
item_features=cls.item_features,
epochs=5)

def test_predict_user_bias(self):
user_bias = self.standard_model.predict_user_bias(self.user_features)
print(user_bias) # TODO what are expected values?
user_bias = self.model.predict_user_bias(self.user_features)
self.assertTrue(any(user_bias)) # Make sure it isn't all 0s

def test_predict_item_bias(self):
item_bias = self.standard_model.predict_item_bias(self.item_features)
print(item_bias) # TODO what are expected values?
item_bias = self.model.predict_item_bias(self.item_features)
self.assertTrue(any(item_bias)) # Make sure it isn't all 0s


class TensorRecNormalizedTestCase(TensorRecTestCase):
Expand Down

0 comments on commit 98b4282

Please sign in to comment.