Commit: Add LSTM prototype.

maciejkula committed Jul 1, 2017
1 parent 81c1f86 commit dcca006

Showing 3 changed files with 91 additions and 11 deletions.
20 changes: 15 additions & 5 deletions spotlight/sequence/pooling.py → spotlight/sequence/implicit.py
@@ -10,15 +10,16 @@
 from spotlight.losses import (bpr_loss,
                               hinge_loss,
                               pointwise_loss)
-from spotlight.sequence.representations import PoolNet
+from spotlight.sequence.representations import LSTMNet, PoolNet
 from spotlight.sampling import sample_items
 from spotlight.torch_utils import cpu, gpu, minibatch, shuffle


-class ImplicitPoolingModel(object):
+class ImplicitSequenceModel(object):

     def __init__(self,
                  loss='pointwise',
+                 representation='pooling',
                  embedding_dim=32,
                  n_iter=10,
                  batch_size=256,
@@ -34,7 +35,11 @@ def __init__(self,
                         'hinge',
                         'adaptive_hinge')

+        assert representation in ('pooling',
+                                  'lstm')
+
         self._loss = loss
+        self._representation = representation
         self._embedding_dim = embedding_dim
         self._n_iter = n_iter
         self._learning_rate = learning_rate
@@ -57,9 +62,14 @@ def fit(self, interactions, verbose=False):

         self._num_items = interactions.num_items

-        self._net = PoolNet(self._num_items,
-                            self._embedding_dim,
-                            self._sparse)
+        if self._representation == 'pooling':
+            self._net = PoolNet(self._num_items,
+                                self._embedding_dim,
+                                self._sparse)
+        else:
+            self._net = LSTMNet(self._num_items,
+                                self._embedding_dim,
+                                self._sparse)

         if self._optimizer is None:
             self._optimizer = optim.Adam(
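With this change the concrete sequence representation is chosen by a constructor argument rather than by picking a class. A minimal usage sketch, assuming the synthetic-data helpers behave as they do in the test file below (all parameter values here are illustrative):

    import numpy as np

    from spotlight.datasets import synthetic
    from spotlight.sequence.implicit import ImplicitSequenceModel

    # Synthetic sequential data, mirroring tests/sequence/test_implicit.py.
    interactions = synthetic.generate_sequential(num_users=100,
                                                 num_items=500,
                                                 num_interactions=5000,
                                                 concentration_parameter=1e-3,
                                                 random_state=np.random.RandomState(42))
    sequences = interactions.to_sequence(max_sequence_length=10)

    # representation='lstm' makes fit() build an LSTMNet; the default
    # 'pooling' keeps the previous PoolNet behaviour.
    model = ImplicitSequenceModel(loss='bpr', representation='lstm', n_iter=10)
    model.fit(sequences, verbose=True)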
39 changes: 39 additions & 0 deletions spotlight/sequence/representations.py
@@ -48,3 +48,42 @@ def forward(self, user_representations, targets):
         dot = (user_representations * target_embedding).sum(1)

         return target_bias + dot
+
+
+class LSTMNet(nn.Module):
+
+    def __init__(self, num_items, embedding_dim, sparse=False):
+        super().__init__()
+
+        self.embedding_dim = embedding_dim
+
+        self.item_embeddings = ScaledEmbedding(num_items, embedding_dim,
+                                               sparse=sparse,
+                                               padding_idx=PADDING_IDX)
+        self.item_biases = ZeroEmbedding(num_items, 1, sparse=sparse,
+                                         padding_idx=PADDING_IDX)
+
+        self.lstm = nn.LSTM(batch_first=True,
+                            input_size=embedding_dim,
+                            hidden_size=embedding_dim)
+
+    def user_representation(self, item_sequences):
+
+        sequence_embeddings = self.item_embeddings(item_sequences)
+
+        user_representations, (hidden, cell) = self.lstm(sequence_embeddings)
+
+        # Final hidden state, one vector per sequence. For this single-layer,
+        # unidirectional LSTM it equals user_representations[:, -1, :];
+        # returning the full user_representations tensor would instead keep
+        # a representation for every time step.
+        return hidden.view(-1, self.embedding_dim)
+
+    def forward(self, user_representations, targets):
+
+        target_embedding = self.item_embeddings(targets)
+        target_bias = self.item_biases(targets)
+
+        dot = (user_representations * target_embedding).sum(1)
+
+        return target_bias + dot
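A note on user_representation above: for a single-layer, unidirectional LSTM, the final hidden state equals the last time step of the output sequence, so the two single-vector variants mentioned in the code are interchangeable. A standalone sketch of that equivalence, assuming only a recent PyTorch:

    import torch
    import torch.nn as nn

    batch, seq_len, dim = 4, 7, 32
    lstm = nn.LSTM(input_size=dim, hidden_size=dim, batch_first=True)

    x = torch.randn(batch, seq_len, dim)  # stand-in for sequence_embeddings
    output, (hidden, cell) = lstm(x)      # output: (batch, seq_len, dim),
                                          # hidden: (1, batch, dim)

    # The flattened final hidden state matches the last output step exactly.
    assert torch.allclose(hidden.view(-1, dim), output[:, -1, :])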
43 changes: 37 additions & 6 deletions tests/sequence/test_pooling.py → tests/sequence/test_implicit.py
@@ -5,7 +5,7 @@
 from spotlight.cross_validation import user_based_train_test_split
 from spotlight.datasets import movielens, synthetic
 from spotlight.evaluation import sequence_mrr_score
-from spotlight.sequence.pooling import ImplicitPoolingModel
+from spotlight.sequence.implicit import ImplicitSequenceModel


 RANDOM_STATE = np.random.RandomState(42)
@@ -28,11 +28,42 @@ def test_implicit_pooling_synthetic(randomness, expected_mrr):
     train = train.to_sequence(max_sequence_length=30)
     test = test.to_sequence(max_sequence_length=30)

-    model = ImplicitPoolingModel(loss='bpr',
-                                 batch_size=1024,
-                                 learning_rate=1e-1,
-                                 l2=1e-9,
-                                 n_iter=3)
+    model = ImplicitSequenceModel(loss='bpr',
+                                  batch_size=1024,
+                                  learning_rate=1e-1,
+                                  l2=1e-9,
+                                  n_iter=3)
     model.fit(train, verbose=True)
     mrr = sequence_mrr_score(model, test)

     print('MRR {} randomness {}'.format(mrr.mean(), randomness))

     assert mrr.mean() > expected_mrr
+
+
+@pytest.mark.parametrize('randomness, expected_mrr', [
+    (1e-3, 0.10),
+    (1e2, 0.005),
+])
+def test_implicit_lstm_synthetic(randomness, expected_mrr):
+
+    interactions = synthetic.generate_sequential(num_users=1000,
+                                                 num_items=1000,
+                                                 num_interactions=10000,
+                                                 concentration_parameter=randomness,
+                                                 random_state=RANDOM_STATE)
+    train, test = user_based_train_test_split(interactions,
+                                              random_state=RANDOM_STATE)
+
+    train = train.to_sequence(max_sequence_length=10)
+    test = test.to_sequence(max_sequence_length=10)
+
+    model = ImplicitSequenceModel(loss='bpr',
+                                  representation='lstm',
+                                  batch_size=128,
+                                  learning_rate=1e-1,
+                                  l2=1e-7,
+                                  n_iter=300)
+    model.fit(train, verbose=True)
+    mrr = sequence_mrr_score(model, test)

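To run just the new LSTM test locally, one option is pytest's Python entry point (assuming pytest is installed and the repository is checked out at this commit); the same selection works as `-k lstm` on the command line:

    import pytest

    # -k selects tests whose names match 'lstm'; -s disables output capture
    # so any progress printing from fit(verbose=True) is visible.
    pytest.main(['tests/sequence/test_implicit.py', '-k', 'lstm', '-s'])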
