API: Changes repr api to mirror loss+pred #41

Merged (2 commits) on Mar 2, 2018
1 change: 1 addition & 0 deletions .travis.yml
@@ -10,6 +10,7 @@ install:

script:
- flake8 tensorrec
- flake8 test
- nosetests --with-timer

notifications:
52 changes: 26 additions & 26 deletions README.md
@@ -77,34 +77,34 @@ import tensorflow as tf
import tensorrec

# Define a custom representation function graph
def tanh_representation_graph(tf_features, n_components, n_features, node_name_ending):
"""
This representation function embeds the user/item features by passing them through a single tanh layer.
:param tf_features: tf.SparseTensor
The user/item features as a SparseTensor of dimensions [n_users/items, n_features]
:param n_components: int
The dimensionality of the resulting representation.
:param n_features: int
The number of features in tf_features
:param node_name_ending: String
Either 'user' or 'item'
:return:
A tuple of (tf.Tensor, list) where the first value is the resulting representation in n_components
dimensions and the second value is a list containing all tf.Variables which should be subject to
regularization.
"""
tf_tanh_weights = tf.Variable(tf.random_normal([n_features, n_components],
stddev=.5),
name='tanh_weights_%s' % node_name_ending)

tf_repr = tf.nn.tanh(tf.sparse_tensor_dense_matmul(tf_features, tf_tanh_weights))

# Return repr layer and variables
return tf_repr, [tf_tanh_weights]
class TanhRepresentationGraph(tensorrec.representation_graphs.AbstractRepresentationGraph):
def connect_representation_graph(self, tf_features, n_components, n_features, node_name_ending):
"""
This representation function embeds the user/item features by passing them through a single tanh layer.
:param tf_features: tf.SparseTensor
The user/item features as a SparseTensor of dimensions [n_users/items, n_features]
:param n_components: int
The dimensionality of the resulting representation.
:param n_features: int
The number of features in tf_features
:param node_name_ending: String
Either 'user' or 'item'
:return:
A tuple of (tf.Tensor, list) where the first value is the resulting representation in n_components
dimensions and the second value is a list containing all tf.Variables which should be subject to
regularization.
"""
tf_tanh_weights = tf.Variable(tf.random_normal([n_features, n_components], stddev=.5),
name='tanh_weights_%s' % node_name_ending)

tf_repr = tf.nn.tanh(tf.sparse_tensor_dense_matmul(tf_features, tf_tanh_weights))

# Return repr layer and variables
return tf_repr, [tf_tanh_weights]

# Build a model with the custom representation function
model = tensorrec.TensorRec(user_repr_graph=tanh_representation_graph,
item_repr_graph=tanh_representation_graph)
model = tensorrec.TensorRec(user_repr_graph=TanhRepresentationGraph(),
item_repr_graph=TanhRepresentationGraph())

# Generate some dummy data
interactions, user_features, item_features = tensorrec.util.generate_dummy_data(num_users=100,
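For orientation, here is a minimal end-to-end sketch of the new class-based API from this README change. The graph class is the one defined in the diff above; the `num_items` and `interaction_density` arguments and the short epoch count are illustrative assumptions, since the diff truncates after `num_users`:

```python
import tensorflow as tf
import tensorrec

class TanhRepresentationGraph(tensorrec.representation_graphs.AbstractRepresentationGraph):
    def connect_representation_graph(self, tf_features, n_components, n_features, node_name_ending):
        # Single tanh layer over the sparse input features
        tf_tanh_weights = tf.Variable(tf.random_normal([n_features, n_components], stddev=.5),
                                      name='tanh_weights_%s' % node_name_ending)
        tf_repr = tf.nn.tanh(tf.sparse_tensor_dense_matmul(tf_features, tf_tanh_weights))
        return tf_repr, [tf_tanh_weights]

# Under the new API, graphs are passed as instances, not bare functions
model = tensorrec.TensorRec(user_repr_graph=TanhRepresentationGraph(),
                            item_repr_graph=TanhRepresentationGraph())

# Dummy data and a short fit; argument values here are assumptions for illustration
interactions, user_features, item_features = tensorrec.util.generate_dummy_data(
    num_users=100, num_items=150, interaction_density=.05)
model.fit(interactions, user_features, item_features, epochs=5, verbose=True)
```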
33 changes: 25 additions & 8 deletions etc/plot_movielens.py
@@ -2,13 +2,16 @@
import os

import imageio
imageio.plugins.ffmpeg.download()
imageio.plugins.ffmpeg.download() # noqa

import matplotlib.pyplot as plt
import moviepy.editor as mpy
import numpy as np

from tensorrec import TensorRec
from tensorrec.loss_graphs import WMRBLossGraph
from tensorrec.prediction_graphs import DotProductPredictionGraph
from tensorrec.eval import precision_at_k, recall_at_k
from tensorrec.loss_graphs import SeparationDenseLossGraph
from tensorrec.prediction_graphs import EuclidianDistancePredictionGraph

from test.datasets import get_movielens_100k

@@ -19,7 +22,7 @@
get_movielens_100k(negative_value=-1.0)

epochs = 300
alpha = 0.00001
alpha = 0.0001
n_components = 2
biased = False
verbose = True
@@ -28,31 +31,45 @@
fit_kwargs = {'epochs': 1, 'alpha': alpha, 'verbose': verbose, 'learning_rate': learning_rate,
'n_sampled_items': int(item_features.shape[0] * .01)}

model = TensorRec(n_components=n_components, biased=biased, loss_graph=WMRBLossGraph(),
prediction_graph=DotProductPredictionGraph())
model = TensorRec(n_components=n_components, biased=biased, loss_graph=SeparationDenseLossGraph(),
prediction_graph=EuclidianDistancePredictionGraph())

for epoch in range(epochs):
model.fit_partial(interactions=train_interactions, user_features=user_features, item_features=item_features,
**fit_kwargs)

movie_positions = model.predict_item_representation(item_features)
user_positions = model.predict_user_representation(user_features)

movies_to_plot = (100, 200)
user_to_plot = (200, 400)

_, ax = plt.subplots()
ax.grid(b=True, which='both')
ax.axhline(y=0, color='k')
ax.axvline(x=0, color='k')
ax.scatter(*zip(*user_positions[user_to_plot[0]:user_to_plot[1]]), color='r', s=1)
ax.scatter(*zip(*movie_positions[movies_to_plot[0]:movies_to_plot[1]]))

for i, movie_name in enumerate(item_titles[movies_to_plot[0]:movies_to_plot[1]]):
ax.annotate(movie_name, movie_positions[i + movies_to_plot[0]], fontsize='x-small')
plt.savefig('/tmp/tensorrec/movielens/epoch_{}.png'.format(epoch))
logging.info("Finished epoch {}".format(epoch))

p_at_k = precision_at_k(model, test_interactions,
user_features=user_features,
item_features=item_features,
k=5)
r_at_k = recall_at_k(model, test_interactions,
user_features=user_features,
item_features=item_features,
k=30)

print("Precision:5: {}, Recall@30: {}".format(np.mean(p_at_k), np.mean(r_at_k)))

fps = 12
file_list = glob.glob('/tmp/tensorrec/movielens/*.png') # Get all the pngs in the current directory
list.sort(file_list, key=lambda x: int(x.split('_')[1].split('.png')[0])) # Sort the images by #, this may need to be tweaked for your use case
file_list = glob.glob('/tmp/tensorrec/movielens/*.png')
list.sort(file_list, key=lambda x: int(x.split('_')[1].split('.png')[0]))
clip = mpy.ImageSequenceClip(file_list, fps=fps)
clip.write_gif('/tmp/tensorrec/movielens/movielens.gif', fps=fps)
for file in file_list:
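The sort key above extracts the epoch number from each frame's filename so frames are ordered numerically rather than lexically. A quick illustration of what the key computes, using the filenames this script writes:

```python
# Frames are written as 'epoch_{N}.png'; the key pulls N out so that
# 'epoch_10.png' sorts after 'epoch_9.png' instead of before it.
def frame_number(path):
    return int(path.split('_')[1].split('.png')[0])

print(frame_number('/tmp/tensorrec/movielens/epoch_12.png'))  # -> 12
```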
4 changes: 3 additions & 1 deletion tensorrec/__init__.py
@@ -1,11 +1,13 @@
from .tensorrec import TensorRec
from . import eval
from . import loss_graphs
from . import representation_graphs
from . import prediction_graphs
from . import util

__version__ = '0.1'

__all__ = [TensorRec, eval, util, loss_graphs]
__all__ = [TensorRec, eval, util, loss_graphs, representation_graphs, prediction_graphs]

# Suppress TensorFlow logs
import logging
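With `representation_graphs` and `prediction_graphs` now imported in `__init__.py` and listed in `__all__`, both modules become reachable from the package root. A small sketch of the newly exported surface, using class names taken from elsewhere in this diff:

```python
import tensorrec

# Both graph modules now resolve as package attributes
user_graph = tensorrec.representation_graphs.ReLURepresentationGraph(relu_size=64)
prediction = tensorrec.prediction_graphs.DotProductPredictionGraph()

model = tensorrec.TensorRec(user_repr_graph=user_graph,
                            prediction_graph=prediction)
```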
74 changes: 51 additions & 23 deletions tensorrec/representation_graphs.py
@@ -1,33 +1,61 @@
import abc
import tensorflow as tf


def linear_representation_graph(tf_features, n_components, n_features, node_name_ending):
# Rough approximation of http://ceur-ws.org/Vol-1448/paper4.pdf
class AbstractRepresentationGraph(object):
__metaclass__ = abc.ABCMeta

# Create variable nodes
tf_linear_weights = tf.Variable(tf.random_normal([n_features, n_components], stddev=.5),
name='linear_weights_{}'.format(node_name_ending))
tf_repr = tf.sparse_tensor_dense_matmul(tf_features, tf_linear_weights)
@abc.abstractmethod
def connect_representation_graph(self, tf_features, n_components, n_features, node_name_ending):
pass

# Return repr layer and variables
return tf_repr, [tf_linear_weights]

class LinearRepresentationGraph(AbstractRepresentationGraph):
"""
Calculates the representation by passing the features through a linear embedding.
Rough approximation of http://ceur-ws.org/Vol-1448/paper4.pdf
"""

def relu_representation_graph(tf_features, n_components, n_features, node_name_ending):
relu_size = 4 * n_components
def connect_representation_graph(self, tf_features, n_components, n_features, node_name_ending):
# Create variable nodes
tf_linear_weights = tf.Variable(tf.random_normal([n_features, n_components], stddev=.5),
name='linear_weights_{}'.format(node_name_ending))
tf_repr = tf.sparse_tensor_dense_matmul(tf_features, tf_linear_weights)

# Create variable nodes
tf_relu_weights = tf.Variable(tf.random_normal([n_features, relu_size], stddev=.5),
name='relu_weights_{}'.format(node_name_ending))
tf_relu_biases = tf.Variable(tf.zeros([1, relu_size]),
name='relu_biases_{}'.format(node_name_ending))
tf_linear_weights = tf.Variable(tf.random_normal([relu_size, n_components], stddev=.5),
name='linear_weights_{}'.format(node_name_ending))
# Return repr layer and variables
return tf_repr, [tf_linear_weights]

# Create ReLU layer
tf_relu = tf.nn.relu(tf.add(tf.sparse_tensor_dense_matmul(tf_features, tf_relu_weights),
tf_relu_biases))
tf_repr = tf.matmul(tf_relu, tf_linear_weights)

# Return repr layer and variables
return tf_repr, [tf_relu_weights, tf_linear_weights, tf_relu_biases]
class ReLURepresentationGraph(AbstractRepresentationGraph):
"""
Calculates the representations by passing the features through a single-layer ReLU neural network.
:param relu_size: int or None
The number of nodes in the ReLU layer. If None, the layer will be of size 4*n_components.
"""

def __init__(self, relu_size=None):
self.relu_size = relu_size

def connect_representation_graph(self, tf_features, n_components, n_features, node_name_ending):

# Infer ReLU layer size if necessary
if self.relu_size is None:
relu_size = 4 * n_components
else:
relu_size = self.relu_size

# Create variable nodes
tf_relu_weights = tf.Variable(tf.random_normal([n_features, relu_size], stddev=.5),
name='relu_weights_{}'.format(node_name_ending))
tf_relu_biases = tf.Variable(tf.zeros([1, relu_size]),
name='relu_biases_{}'.format(node_name_ending))
tf_linear_weights = tf.Variable(tf.random_normal([relu_size, n_components], stddev=.5),
name='linear_weights_{}'.format(node_name_ending))

# Create ReLU layer
tf_relu = tf.nn.relu(tf.add(tf.sparse_tensor_dense_matmul(tf_features, tf_relu_weights),
tf_relu_biases))
tf_repr = tf.matmul(tf_relu, tf_linear_weights)

# Return repr layer and variables
return tf_repr, [tf_relu_weights, tf_linear_weights, tf_relu_biases]
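The new constructor parameter makes the ReLU layer width configurable instead of hard-coding `4 * n_components`. A short usage sketch of the two built-in graphs (the sizes are illustrative):

```python
from tensorrec import TensorRec
from tensorrec.representation_graphs import LinearRepresentationGraph, ReLURepresentationGraph

# Explicit hidden width for the user graph; the item graph keeps the default,
# which falls back to 4 * n_components inside connect_representation_graph.
model = TensorRec(n_components=10,
                  user_repr_graph=ReLURepresentationGraph(relu_size=64),
                  item_repr_graph=LinearRepresentationGraph())
```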
29 changes: 17 additions & 12 deletions tensorrec/tensorrec.py
@@ -11,16 +11,16 @@
)
from .recommendation_graphs import (project_biases, split_sparse_tensor_indices, bias_prediction_dense,
bias_prediction_serial, rank_predictions, gather_sampled_item_predictions)
from .representation_graphs import linear_representation_graph
from .representation_graphs import AbstractRepresentationGraph, LinearRepresentationGraph
from .session_management import get_session
from .util import sample_items, calculate_batched_alpha


class TensorRec(object):

def __init__(self, n_components=100,
user_repr_graph=linear_representation_graph,
item_repr_graph=linear_representation_graph,
user_repr_graph=LinearRepresentationGraph(),
item_repr_graph=LinearRepresentationGraph(),
prediction_graph=DotProductPredictionGraph(),
loss_graph=RMSELossGraph(),
biased=True):
@@ -42,10 +42,15 @@ def __init__(self, n_components=100,
"""

# Arg-check
if (n_components is None) or (user_repr_graph is None) or (item_repr_graph is None) or (loss_graph is None):
if (n_components is None) or (user_repr_graph is None) or (item_repr_graph is None) \
or (prediction_graph is None) or (loss_graph is None):
raise ValueError("All arguments to TensorRec() must be non-None")
if n_components < 1:
raise ValueError("n_components must be >= 1")
if not isinstance(user_repr_graph, AbstractRepresentationGraph):
raise ValueError("user_repr_graph must inherit AbstractRepresentationGraph")
if not isinstance(item_repr_graph, AbstractRepresentationGraph):
raise ValueError("item_repr_graph must inherit AbstractRepresentationGraph")
if not isinstance(prediction_graph, AbstractPredictionGraph):
raise ValueError("prediction_graph must inherit AbstractPredictionGraph")
if not isinstance(loss_graph, AbstractLossGraph):
raise ValueError("loss_graph must inherit AbstractLossGraph")

@@ -241,15 +246,15 @@ def _build_tf_graph(self, n_user_features, n_item_features):

# Build the representations
self.tf_user_representation, user_weights = \
self.user_repr_graph_factory(tf_features=tf_user_features,
n_components=self.n_components,
n_features=n_user_features,
node_name_ending='user')
self.user_repr_graph_factory.connect_representation_graph(tf_features=tf_user_features,
n_components=self.n_components,
n_features=n_user_features,
node_name_ending='user')
self.tf_item_representation, item_weights = \
self.item_repr_graph_factory(tf_features=tf_item_features,
n_components=self.n_components,
n_features=n_item_features,
node_name_ending='item')
self.item_repr_graph_factory.connect_representation_graph(tf_features=tf_item_features,
n_components=self.n_components,
n_features=n_item_features,
node_name_ending='item')

# Collect the weights for normalization
tf_weights = []
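With these isinstance checks, function-style graphs from the old API are rejected at construction time rather than failing later inside graph building. A sketch of the failure mode, where the function is a stand-in for any pre-change graph:

```python
from tensorrec import TensorRec
from tensorrec.representation_graphs import LinearRepresentationGraph

def my_old_repr_graph(tf_features, n_components, n_features, node_name_ending):
    pass  # function-style graph from the old API

try:
    TensorRec(user_repr_graph=my_old_repr_graph)  # bare function, not an instance
except ValueError as err:
    print(err)  # user_repr_graph must inherit AbstractRepresentationGraph

# The class-based equivalent passes the check
model = TensorRec(user_repr_graph=LinearRepresentationGraph())
```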
52 changes: 26 additions & 26 deletions test/test_readme.py
@@ -36,34 +36,34 @@ def test_custom_repr_graph(self):
import tensorrec

# Define a custom representation function graph
def tanh_representation_graph(tf_features, n_components, n_features, node_name_ending):
"""
This representation function embeds the user/item features by passing them through a single tanh layer.
:param tf_features: tf.SparseTensor
The user/item features as a SparseTensor of dimensions [n_users/items, n_features]
:param n_components: int
The dimensionality of the resulting representation.
:param n_features: int
The number of features in tf_features
:param node_name_ending: String
Either 'user' or 'item'
:return:
A tuple of (tf.Tensor, list) where the first value is the resulting representation in n_components
dimensions and the second value is a list containing all tf.Variables which should be subject to
regularization.
"""
tf_tanh_weights = tf.Variable(tf.random_normal([n_features, n_components],
stddev=.5),
name='tanh_weights_%s' % node_name_ending)

tf_repr = tf.nn.tanh(tf.sparse_tensor_dense_matmul(tf_features, tf_tanh_weights))

# Return repr layer and variables
return tf_repr, [tf_tanh_weights]
class TanhRepresentationGraph(tensorrec.representation_graphs.AbstractRepresentationGraph):
def connect_representation_graph(self, tf_features, n_components, n_features, node_name_ending):
"""
This representation function embeds the user/item features by passing them through a single tanh layer.
:param tf_features: tf.SparseTensor
The user/item features as a SparseTensor of dimensions [n_users/items, n_features]
:param n_components: int
The dimensionality of the resulting representation.
:param n_features: int
The number of features in tf_features
:param node_name_ending: String
Either 'user' or 'item'
:return:
A tuple of (tf.Tensor, list) where the first value is the resulting representation in n_components
dimensions and the second value is a list containing all tf.Variables which should be subject to
regularization.
"""
tf_tanh_weights = tf.Variable(tf.random_normal([n_features, n_components], stddev=.5),
name='tanh_weights_%s' % node_name_ending)

tf_repr = tf.nn.tanh(tf.sparse_tensor_dense_matmul(tf_features, tf_tanh_weights))

# Return repr layer and variables
return tf_repr, [tf_tanh_weights]

# Build a model with the custom representation function
model = tensorrec.TensorRec(user_repr_graph=tanh_representation_graph,
item_repr_graph=tanh_representation_graph)
model = tensorrec.TensorRec(user_repr_graph=TanhRepresentationGraph(),
item_repr_graph=TanhRepresentationGraph())

# Generate some dummy data
interactions, user_features, item_features = tensorrec.util.generate_dummy_data(num_users=100,