
Merge pull request #56 from jfkirk/attention
ENH: Adds attention
jfkirk committed Apr 10, 2018
2 parents f597f47 + 6bc4748 commit b799305
Showing 7 changed files with 252 additions and 161 deletions.
57 changes: 57 additions & 0 deletions examples/attention_example.py
@@ -0,0 +1,57 @@
from tensorrec import TensorRec
from tensorrec.eval import fit_and_eval
from tensorrec.representation_graphs import (
    LinearRepresentationGraph, NormalizedLinearRepresentationGraph
)
from tensorrec.loss_graphs import BalancedWMRBLossGraph

from test.datasets import get_movielens_100k

import logging
logging.getLogger().setLevel(logging.INFO)

# Load the movielens dataset
train_interactions, test_interactions, user_features, item_features, _ = get_movielens_100k(negative_value=0)

# Construct parameters for fitting
epochs = 500
alpha = 0.00001
n_components = 10
verbose = True
learning_rate = .01
n_sampled_items = int(item_features.shape[0] * .1)
fit_kwargs = {'epochs': epochs, 'alpha': alpha, 'verbose': verbose, 'learning_rate': learning_rate,
              'n_sampled_items': n_sampled_items}

# Build two models -- one without an attention graph, one with a linear attention graph
model_without_attention = TensorRec(
    n_components=n_components,
    n_tastes=3,
    user_repr_graph=NormalizedLinearRepresentationGraph(),
    attention_graph=None,
    loss_graph=BalancedWMRBLossGraph(),
)

model_with_attention = TensorRec(
    n_components=n_components,
    n_tastes=3,
    user_repr_graph=NormalizedLinearRepresentationGraph(),
    attention_graph=LinearRepresentationGraph(),
    loss_graph=BalancedWMRBLossGraph(),
)

results_without_attention = fit_and_eval(model=model_without_attention,
                                         user_features=user_features,
                                         item_features=item_features,
                                         train_interactions=train_interactions,
                                         test_interactions=test_interactions,
                                         fit_kwargs=fit_kwargs)
results_with_attention = fit_and_eval(model=model_with_attention,
                                      user_features=user_features,
                                      item_features=item_features,
                                      train_interactions=train_interactions,
                                      test_interactions=test_interactions,
                                      fit_kwargs=fit_kwargs)

logging.info("Results without attention: {}".format(results_without_attention))
logging.info("Results with attention: {}".format(results_with_attention))
60 changes: 29 additions & 31 deletions examples/check_movielens_losses.py
@@ -1,11 +1,9 @@
 from tensorrec import TensorRec
 from tensorrec.eval import fit_and_eval
 from tensorrec.representation_graphs import (
-    LinearRepresentationGraph, ReLURepresentationGraph
-)
-from tensorrec.loss_graphs import (
-    RMSEDenseLossGraph, SeparationDenseLossGraph, WMRBLossGraph,
+    LinearRepresentationGraph, ReLURepresentationGraph, NormalizedLinearRepresentationGraph
 )
+from tensorrec.loss_graphs import WMRBLossGraph, BalancedWMRBLossGraph
 from tensorrec.prediction_graphs import (
     DotProductPredictionGraph, CosineSimilarityPredictionGraph, EuclidianSimilarityPredictionGraph
 )
@@ -26,6 +24,7 @@
 verbose = True
 learning_rate = .01
 n_sampled_items = int(item_features.shape[0] * .1)
+biased = False
 fit_kwargs = {'epochs': epochs, 'alpha': alpha, 'verbose': verbose, 'learning_rate': learning_rate,
               'n_sampled_items': n_sampled_items}

@@ -43,35 +42,34 @@
 res_strings.append(header)

 # Iterate through many possibilities for model configuration
-for biased in (True, False):
-    for loss_graph in (RMSEDenseLossGraph, SeparationDenseLossGraph, WMRBLossGraph):
-        for pred_graph in (DotProductPredictionGraph, CosineSimilarityPredictionGraph,
-                           EuclidianSimilarityPredictionGraph):
-            for repr_graph in (LinearRepresentationGraph, ReLURepresentationGraph):
-                for n_tastes in (1, 3):
+for loss_graph in (WMRBLossGraph, BalancedWMRBLossGraph):
+    for pred_graph in (DotProductPredictionGraph, CosineSimilarityPredictionGraph,
+                       EuclidianSimilarityPredictionGraph):
+        for repr_graph in (LinearRepresentationGraph, ReLURepresentationGraph):
+            for n_tastes in (1, 3):

-                    # Build the model, fit, and get a result packet
-                    model = TensorRec(n_components=n_components,
-                                      n_tastes=n_tastes,
-                                      biased=biased,
-                                      loss_graph=loss_graph(),
-                                      prediction_graph=pred_graph(),
-                                      user_repr_graph=LinearRepresentationGraph(),
-                                      item_repr_graph=repr_graph())
-                    result = fit_and_eval(model, user_features, item_features, train_interactions, test_interactions,
-                                          fit_kwargs)
+                # Build the model, fit, and get a result packet
+                model = TensorRec(n_components=n_components,
+                                  n_tastes=n_tastes,
+                                  biased=biased,
+                                  loss_graph=loss_graph(),
+                                  prediction_graph=pred_graph(),
+                                  user_repr_graph=NormalizedLinearRepresentationGraph(),
+                                  item_repr_graph=repr_graph())
+                result = fit_and_eval(model, user_features, item_features, train_interactions, test_interactions,
+                                      fit_kwargs)

-                    # Build results row for this configuration
-                    res_string = "{}".format(loss_graph.__name__)
-                    res_string = append_to_string_at_point(res_string, pred_graph.__name__, 30)
-                    res_string = append_to_string_at_point(res_string, repr_graph.__name__, 66)
-                    res_string = append_to_string_at_point(res_string, biased, 98)
-                    res_string = append_to_string_at_point(res_string, n_tastes, 108)
-                    res_string = append_to_string_at_point(res_string, ": {}".format(result[0]), 118)
-                    res_string = append_to_string_at_point(res_string, result[1], 141)
-                    res_string = append_to_string_at_point(res_string, result[2], 164)
-                    res_strings.append(res_string)
-                    print(res_string)
+                # Build results row for this configuration
+                res_string = "{}".format(loss_graph.__name__)
+                res_string = append_to_string_at_point(res_string, pred_graph.__name__, 30)
+                res_string = append_to_string_at_point(res_string, repr_graph.__name__, 66)
+                res_string = append_to_string_at_point(res_string, biased, 98)
+                res_string = append_to_string_at_point(res_string, n_tastes, 108)
+                res_string = append_to_string_at_point(res_string, ": {}".format(result[0]), 118)
+                res_string = append_to_string_at_point(res_string, result[1], 141)
+                res_string = append_to_string_at_point(res_string, result[2], 164)
+                res_strings.append(res_string)
+                print(res_string)

 print('--------------------------------------------------')
 for res_string in res_strings:
14 changes: 5 additions & 9 deletions examples/plot_movielens.py
@@ -11,8 +11,7 @@
 from tensorrec import TensorRec
 from tensorrec.eval import precision_at_k, recall_at_k
 from tensorrec.loss_graphs import BalancedWMRBLossGraph
-from tensorrec.prediction_graphs import DotProductPredictionGraph
-from tensorrec.representation_graphs import NormalizedLinearRepresentationGraph
+from tensorrec.representation_graphs import ReLURepresentationGraph

 from test.datasets import get_movielens_100k

@@ -31,10 +30,9 @@

 # Build the TensorRec model
 model = TensorRec(n_components=2,
-                  biased=True,
+                  biased=False,
                   loss_graph=BalancedWMRBLossGraph(),
-                  prediction_graph=DotProductPredictionGraph(),
-                  user_repr_graph=NormalizedLinearRepresentationGraph(),
+                  item_repr_graph=ReLURepresentationGraph(),
                   normalize_users=True,
                   normalize_items=True,
                   n_tastes=3)
@@ -48,11 +46,9 @@
 model.fit_partial(interactions=train_interactions, user_features=user_features, item_features=item_features,
                   **fit_kwargs)

-# The position of a movie or user is that movie's/user's 2-dimensional representation. The size of the movie dot is
-# related to its item bias.
+# The position of a movie or user is that movie's/user's 2-dimensional representation.
 movie_positions = model.predict_item_representation(item_features)
 user_positions = model.predict_user_representation(user_features)
-movie_sizes = model.predict_item_bias(item_features) * 10 + 1.0

 # Handle multiple tastes, if applicable. If there are more than 1 taste per user, only the first of each user's
 # tastes will be plotted.
@@ -64,7 +60,7 @@
 ax.axhline(y=0, color='k')
 ax.axvline(x=0, color='k')
 ax.scatter(*zip(*user_positions[user_to_plot]), color='r', s=1)
-ax.scatter(*zip(*movie_positions[movies_to_plot]), s=movie_sizes)
+ax.scatter(*zip(*movie_positions[movies_to_plot]), s=2)
 ax.set_aspect('equal')

 for i, movie in enumerate(movies_to_plot):
35 changes: 26 additions & 9 deletions tensorrec/recommendation_graphs.py
@@ -82,15 +82,31 @@ def rank_predictions(tf_prediction):
     return tf.nn.top_k(-tf_indices_of_ranks, k=tf_prediction_item_size)[1] + 1


-def collapse_mixture_of_tastes(tastes_predictions):
+def collapse_mixture_of_tastes(tastes_predictions, tastes_attentions):
     """
     Collapses a list of prediction nodes into a single prediction node.
     :param tastes_predictions:
+    :param tastes_attentions:
     :return:
     """
-    stacked_tastes = tf.stack(tastes_predictions)
-    max_prediction = tf.reduce_max(stacked_tastes, axis=0)
-    return max_prediction
+    stacked_predictions = tf.stack(tastes_predictions)
+
+    # If there is attention, the attentions are used to weight each prediction
+    if tastes_attentions is not None:
+
+        # Stack the attentions and perform softmax across the tastes
+        stacked_attentions = tf.stack(tastes_attentions)
+        softmax_attentions = tf.nn.softmax(stacked_attentions, axis=0)
+
+        # The softmax'd attentions serve as weights for the taste predictions
+        weighted_predictions = tf.multiply(stacked_predictions, softmax_attentions)
+        result_prediction = tf.reduce_sum(weighted_predictions, axis=0)
+
+    # If there is no attention, the max prediction is returned
+    else:
+        result_prediction = tf.reduce_max(stacked_predictions, axis=0)
+
+    return result_prediction


 def relative_cosine(tf_tensor_1, tf_tensor_2):
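For intuition, here is a minimal NumPy sketch (not part of this commit) of the attention-weighted collapse above. The shapes are illustrative assumptions: 3 tastes, 2 users, 4 items.

import numpy as np

# Stacked per-taste predictions and attentions, shape (tastes, users, items)
predictions = np.random.rand(3, 2, 4)
attentions = np.random.rand(3, 2, 4)

# Softmax over the taste axis, mirroring tf.nn.softmax(stacked_attentions, axis=0)
exp_att = np.exp(attentions)
weights = exp_att / exp_att.sum(axis=0, keepdims=True)

# Softmax-weighted sum over tastes replaces the old reduce_max collapse
collapsed = (predictions * weights).sum(axis=0)
print(collapsed.shape)  # (2, 4): one collapsed prediction per user/item pair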
@@ -105,16 +121,17 @@ def relative_cosine(tf_tensor_1, tf_tensor_2):
     return tf.matmul(normalized_t1, normalized_t2, transpose_b=True)


-def predict_similar_items(tf_item_representation, tf_similar_items_ids):
+def predict_similar_items(prediction_graph_factory, tf_item_representation, tf_similar_items_ids):
     """
-    Calculates the cosine between the given item ids and all other items.
+    Calculates the similarity between the given item ids and all other items using the prediction graph.
+    :param prediction_graph_factory:
     :param tf_item_representation:
     :param tf_similar_items_ids:
     :return:
     """
     gathered_items = tf.gather(tf_item_representation, tf_similar_items_ids)
-    sims = relative_cosine(
-        tf_tensor_1=gathered_items,
-        tf_tensor_2=tf_item_representation
+    sims = prediction_graph_factory.connect_dense_prediction_graph(
+        tf_user_representation=gathered_items,
+        tf_item_representation=tf_item_representation
     )
     return sims
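Similarly, a minimal NumPy sketch (not part of this commit) of what predict_similar_items computes when the injected prediction graph is a plain dot product. The names and shapes are illustrative assumptions, not tensorrec's API.

import numpy as np

item_representation = np.random.rand(100, 10)  # 100 items, 10 components
similar_items_ids = np.array([3, 42])          # query item ids

# Analogue of tf.gather: pull out the query items' representations
gathered_items = item_representation[similar_items_ids]

# A dot-product prediction graph scores every (query item, item) pair
sims = gathered_items @ item_representation.T
print(sims.shape)  # (2, 100): one row of similarity scores per query item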