Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WMRB on Serial Predictions #45

Merged
merged 5 commits into from
Mar 10, 2018
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions tensorrec/loss_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,36 @@ class AbstractLossGraph(object):
is_sampled_with_replacement = False

@abc.abstractmethod
def connect_loss_graph(self, tf_prediction_serial, tf_interactions_serial, tf_prediction, tf_interactions,
tf_rankings, tf_sample_predictions):
def connect_loss_graph(self, tf_prediction_serial, tf_interactions_serial, tf_interactions, tf_n_users, tf_n_items,
tf_prediction, tf_rankings, tf_sample_predictions, tf_n_sampled_items):
"""
This method is responsible for consuming a number of possible nodes from the graph and calculating loss from
those nodes.

The following parameters are always passed in
:param tf_prediction_serial: tf.Tensor
The recommendation scores as a Tensor of shape [n_samples, 1]
:param tf_interactions_serial: tf.Tensor
The sample interactions corresponding to tf_prediction_serial as a Tensor of shape [n_samples, 1]
:param tf_prediction: tf.Tensor
The recommendation scores as a Tensor of shape [n_users, n_items]
:param tf_interactions: tf.SparseTensor
The sample interactions as a SparseTensor of shape [n_users, n_items]
:param tf_n_users: tf.placeholder
The number of users in tf_interactions
:param tf_n_items: tf.placeholder
The number of items in tf_interactions

The following parameters are passed in if is_dense is True
:param tf_prediction: tf.Tensor
The recommendation scores as a Tensor of shape [n_users, n_items]
:param tf_rankings: tf.Tensor
The item ranks as a Tensor of shape [n_users, n_items]

The following parameters are passed in if is_sample_based is True
:param tf_sample_predictions: tf.Tensor
The recommendation scores of a sample of items of shape [n_users, n_sampled_items]
:param tf_n_sampled_items: tf.placeholder
The number of items per user in tf_sample_predictions

:return: tf.Tensor
The loss value.
"""
Expand Down Expand Up @@ -124,30 +137,34 @@ class WMRBLossGraph(AbstractLossGraph):
Approximation of http://ceur-ws.org/Vol-1905/recsys2017_poster3.pdf
Interactions can be any positive values, but magnitude is ignored. Negative interactions are also ignored.
"""
is_dense = True
is_sample_based = True

def connect_loss_graph(self, tf_prediction, tf_interactions, tf_sample_predictions, **kwargs):
def connect_loss_graph(self, tf_prediction_serial, tf_interactions, tf_sample_predictions, tf_n_items,
tf_n_sampled_items, **kwargs):

# WMRB expects [-1, 1] bounded predictions
bounded_prediction = tf.nn.tanh(tf_prediction)
bounded_prediction = tf.nn.tanh(tf_prediction_serial)
bounded_sample_prediction = tf.nn.tanh(tf_sample_predictions)

return self.weighted_margin_rank_batch(tf_prediction=bounded_prediction,
return self.weighted_margin_rank_batch(tf_prediction_serial=bounded_prediction,
tf_interactions=tf_interactions,
tf_sample_predictions=bounded_sample_prediction)
tf_sample_predictions=bounded_sample_prediction,
tf_n_items=tf_n_items,
tf_n_sampled_items=tf_n_sampled_items)

@classmethod
def weighted_margin_rank_batch(cls, tf_prediction, tf_interactions, tf_sample_predictions):
def weighted_margin_rank_batch(cls, tf_prediction_serial, tf_interactions, tf_sample_predictions, tf_n_items,
tf_n_sampled_items):
positive_interaction_mask = tf.greater(tf_interactions.values, 0.0)
positive_interaction_indices = tf.boolean_mask(tf_interactions.indices,
positive_interaction_mask)

# [ n_positive_interactions ]
positive_predictions = tf.gather_nd(tf_prediction, indices=positive_interaction_indices)
positive_predictions = tf.boolean_mask(tf_prediction_serial,
positive_interaction_mask)

n_items = tf.cast(tf.shape(tf_prediction)[1], dtype=tf.float32)
n_sampled_items = tf.cast(tf.shape(tf_sample_predictions)[1], dtype=tf.float32)
n_items = tf.cast(tf_n_items, dtype=tf.float32)
n_sampled_items = tf.cast(tf_n_sampled_items, dtype=tf.float32)

# [ n_positive_interactions, n_sampled_items ]
mapped_predictions_sample_per_interaction = tf.gather(params=tf_sample_predictions,
Expand Down
21 changes: 8 additions & 13 deletions tensorrec/recommendation_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,17 @@ def bias_prediction_serial(tf_prediction_serial, tf_projected_user_biases, tf_pr
return tf_prediction_serial + gathered_user_biases + gathered_item_biases


def gather_sampled_item_predictions(tf_prediction, tf_sampled_item_indices):
def densify_sampled_item_predictions(tf_sample_predictions, tf_n_sampled_items, tf_n_users):
"""
Gathers the predictions for the given sampled items.
:param tf_prediction:
:param tf_sampled_item_indices:
Turns the serial predictions of the sample items in to a dense matrix of shape [ n_users, n_sampled_items ]
:param tf_sample_predictions:
:param tf_n_sampled_items:
:param tf_n_users:
:return:
"""
prediction_shape = tf.shape(tf_prediction)
flattened_prediction = tf.reshape(tf_prediction, shape=[prediction_shape[0] * prediction_shape[1]])

indices_shape = tf.shape(tf_sampled_item_indices)
flattened_indices = tf.reshape(tf_sampled_item_indices, shape=[indices_shape[0] * indices_shape[1]])

gathered_predictions = tf.gather(params=flattened_prediction, indices=flattened_indices)
reshaped_gathered_predictions = tf.reshape(gathered_predictions, shape=indices_shape)
return reshaped_gathered_predictions
densified_shape = tf.stack([tf_n_users, tf_n_sampled_items])
densified_predictions = tf.reshape(tf_sample_predictions, shape=densified_shape)
return densified_predictions


def rank_predictions(tf_prediction):
Expand Down
52 changes: 40 additions & 12 deletions tensorrec/tensorrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
EuclidianSimilarityPredictionGraph
)
from .recommendation_graphs import (project_biases, split_sparse_tensor_indices, bias_prediction_dense,
bias_prediction_serial, rank_predictions, gather_sampled_item_predictions)
bias_prediction_serial, rank_predictions, densify_sampled_item_predictions)
from .representation_graphs import AbstractRepresentationGraph, LinearRepresentationGraph
from .session_management import get_session
from .util import sample_items, calculate_batched_alpha
Expand Down Expand Up @@ -85,7 +85,7 @@ def __init__(self, n_components=100,
# Feed placeholders
'tf_n_users', 'tf_n_items', 'tf_user_feature_indices', 'tf_user_feature_values', 'tf_item_feature_indices',
'tf_item_feature_values', 'tf_interaction_indices', 'tf_interaction_values', 'tf_learning_rate', 'tf_alpha',
'tf_sampled_item_indices'
'tf_sample_indices', 'tf_n_sampled_items'
]
self.graph_operation_hook_attr_names = [

Expand Down Expand Up @@ -245,15 +245,21 @@ def _build_tf_graph(self, n_user_features, n_item_features):
# Initialize placeholder values for inputs
self.tf_n_users = tf.placeholder('int64')
self.tf_n_items = tf.placeholder('int64')
self.tf_n_sampled_items = tf.placeholder('int64')

# SparseTensor placeholders
self.tf_user_feature_indices = tf.placeholder('int64', [None, 2])
self.tf_user_feature_values = tf.placeholder('float', [None])
self.tf_item_feature_indices = tf.placeholder('int64', [None, 2])
self.tf_item_feature_values = tf.placeholder('float', [None])
self.tf_interaction_indices = tf.placeholder('int64', [None, 2])
self.tf_interaction_values = tf.placeholder('float', [None])

self.tf_sample_indices = tf.placeholder('int64', [None, None])
self.tf_learning_rate = tf.placeholder('float', None)
self.tf_alpha = tf.placeholder('float', None)
self.tf_sampled_item_indices = tf.placeholder('int64', [None, None])

# from nose.tools import set_trace;set_trace()
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: remove


# Construct the features and interactions as sparse matrices
tf_user_features = tf.SparseTensor(self.tf_user_feature_indices, self.tf_user_feature_values,
Expand Down Expand Up @@ -294,6 +300,16 @@ def _build_tf_graph(self, n_user_features, n_item_features):
tf_x_item=tf_x_item,
)

tf_transposed_sample_indices = tf.transpose(self.tf_sample_indices)
tf_x_user_sample = tf_transposed_sample_indices[0]
tf_x_item_sample = tf_transposed_sample_indices[1]
tf_sample_predictions = self.prediction_graph_factory.connect_serial_prediction_graph(
tf_user_representation=self.tf_user_representation,
tf_item_representation=self.tf_item_representation,
tf_x_user=tf_x_user_sample,
tf_x_item=tf_x_item_sample,
)

# Add biases, if this is a biased estimator
if self.biased:
tf_user_feature_biases, tf_projected_user_biases = project_biases(
Expand All @@ -316,6 +332,12 @@ def _build_tf_graph(self, n_user_features, n_item_features):
tf_x_user=tf_x_user,
tf_x_item=tf_x_item)

tf_sample_predictions = bias_prediction_serial(tf_prediction_serial=tf_sample_predictions,
tf_projected_user_biases=tf_projected_user_biases,
tf_projected_item_biases=tf_projected_item_biases,
tf_x_user=tf_x_user_sample,
tf_x_item=tf_x_item_sample)

tf_interactions_serial = tf_interactions.values

# Construct API nodes
Expand All @@ -339,18 +361,23 @@ def _build_tf_graph(self, n_user_features, n_item_features):
loss_graph_kwargs = {
'tf_prediction_serial': self.tf_prediction_serial,
'tf_interactions_serial': tf_interactions_serial,
'tf_interactions': tf_interactions,
'tf_n_users': self.tf_n_users,
'tf_n_items': self.tf_n_items,
}
if self.loss_graph_factory.is_dense:
loss_graph_kwargs.update({
'tf_prediction': self.tf_prediction,
'tf_interactions': tf_interactions,
'tf_rankings': self.tf_rankings,
})
if self.loss_graph_factory.is_sample_based:
tf_sample_predictions = gather_sampled_item_predictions(
tf_prediction=self.tf_prediction, tf_sampled_item_indices=self.tf_sampled_item_indices
tf_sample_predictions = densify_sampled_item_predictions(
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: should rename one of these tf_sample_predictions

tf_sample_predictions=tf_sample_predictions,
tf_n_sampled_items=self.tf_n_sampled_items,
tf_n_users=self.tf_n_users,
)
loss_graph_kwargs.update({'tf_sample_predictions': tf_sample_predictions})
loss_graph_kwargs.update({'tf_sample_predictions': tf_sample_predictions,
'tf_n_sampled_items': self.tf_n_sampled_items})

# Build loss graph
self.tf_basic_loss = self.loss_graph_factory.connect_loss_graph(**loss_graph_kwargs)
Expand Down Expand Up @@ -475,11 +502,12 @@ def fit_partial(self, interactions, user_features, item_features, epochs=1, lear

# Handle random item sampling, if applicable
if self.loss_graph_factory.is_sample_based:
sampled_item_indices = sample_items(n_users=feed_dict[self.tf_n_users],
n_items=feed_dict[self.tf_n_items],
n_sampled_items=n_sampled_items,
replace=self.loss_graph_factory.is_sampled_with_replacement)
feed_dict[self.tf_sampled_item_indices] = sampled_item_indices
sample_indices = sample_items(n_users=feed_dict[self.tf_n_users],
n_items=feed_dict[self.tf_n_items],
n_sampled_items=n_sampled_items,
replace=self.loss_graph_factory.is_sampled_with_replacement)
feed_dict[self.tf_sample_indices] = sample_indices
feed_dict[self.tf_n_sampled_items] = n_sampled_items

# TODO find something more elegant than these cascaded ifs
if not verbose:
Expand Down
12 changes: 8 additions & 4 deletions tensorrec/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@


def sample_items(n_items, n_users, n_sampled_items, replace):
return np.array([
np.random.choice(a=n_items, size=n_sampled_items, replace=replace) + (user_count * n_users)
for user_count in range(n_users)
])
items_per_user = [np.random.choice(a=n_items, size=n_sampled_items, replace=replace) for _ in range(n_users)]

sample_indices = []
for user, users_items in enumerate(items_per_user):
for item in users_items:
sample_indices.append((user, item))

return sample_indices


def calculate_batched_alpha(num_batches, alpha):
Expand Down
27 changes: 10 additions & 17 deletions test/test_recommendation_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from tensorrec.recommendation_graphs import (
project_biases, split_sparse_tensor_indices, bias_prediction_dense, bias_prediction_serial,
gather_sampled_item_predictions, rank_predictions
densify_sampled_item_predictions, rank_predictions
)
from tensorrec.session_management import get_session

Expand Down Expand Up @@ -105,25 +105,18 @@ def test_bias_prediction_serial(self):

self.assertTrue((biased_predictions == expected_biased_predictions).all())

def test_gather_sampled_item_predictions(self):
input_data = np.array([
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]
])
sample_indices = np.array([
[0, 3], # Corresponds to [1, 4]
[5, 6], # Corresponds to [6, 7]
[8, 8], # Corresponds to [9, 9]
])
result = gather_sampled_item_predictions(
tf_prediction=tf.identity(input_data), tf_sampled_item_indices=tf.identity(sample_indices)
def test_densify_sampled_item_predictions(self):
input_data = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
result = densify_sampled_item_predictions(
tf_sample_predictions=input_data,
tf_n_sampled_items=4,
tf_n_users=3,
).eval(session=self.session)

expected_result = np.array([
[1, 4],
[6, 7],
[9, 9],
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]
])
self.assertTrue((result == expected_result).all())

Expand Down