Reference: https://github.com/tensorflow/tensorflow/issues/342

In [1]:
import tensorflow as tf
import numpy as np
from pprint import pprint

#### A naive lookup table

In [2]:
# Lookup table 1: ten rows, every row holding the value 1.
x_ones = np.ones((10, 1), np.float32)

# Lookup table 2: row k holds the value k, so the lookup result depends on
# WHICH ids are selected (check the difference below!).
# NOTE: despite the name, the values run from 0 to 9, not 1 to 9.
x_one_to_nine = np.arange(10, dtype=np.float32).reshape(10, 1)

x_ones

array([[ 1.],
       [ 1.],
       [ 1.],
       [ 1.],
       [ 1.],
       [ 1.],
       [ 1.],
       [ 1.],
       [ 1.],
       [ 1.]], dtype=float32)

In [3]:
# Feeds for the two SparseTensors built below. The static shapes document the
# expected layout and let TF reject a malformed feed up front instead of
# failing inside the sparse ops:
#   sp_indices:     (nnz, 2) int64   [row, column] coordinate of each entry
#   sp_shape:       (2,)     int64   dense shape [batch_size, max_ids_per_row]
#   sp_ids_val:     (nnz,)   int64   table row id for each entry
#   sp_weights_val: (nnz,)   float32 weight for each entry
sp_indices = tf.placeholder(tf.int64, shape=[None, 2], name='sp_indices')
sp_shape = tf.placeholder(tf.int64, shape=[2], name='sp_shape')
sp_ids_val = tf.placeholder(tf.int64, shape=[None], name='sp_ids_val')
sp_weights_val = tf.placeholder(tf.float32, shape=[None], name='sp_weights_val')

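#### `tf.nn.embedding_lookup_sparse` is similar to `tf.nn.embedding_lookup`
#### but it takes the ids (and optional per-id weights) as `tf.SparseTensor`s rather than `tf.Tensor`s, and reduces each row's looked-up embeddings into a single vector using `combiner` (`'sum'`, `'mean'`, or `'sqrtn'`)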

In [4]:
# Both SparseTensors share the same coordinates and dense shape, so entry k of
# sp_weights is the weight applied to id k of sp_ids.
sp_ids, sp_weights = (tf.SparseTensor(sp_indices, entry_values, sp_shape)
                      for entry_values in (sp_ids_val, sp_weights_val))

In [5]:
# The dense lookup table; fed with x_ones or x_one_to_nine in the cells below.
X = tf.placeholder(tf.float32, shape=[10, 1])

# With combiner='sum', output row i is the weighted sum of the table rows
# selected by batch row i:  y[i] = sum_j weight[i, j] * X[id[i, j]].
y = tf.nn.embedding_lookup_sparse(params=X,
                                  sp_ids=sp_ids,
                                  sp_weights=sp_weights,
                                  combiner='sum')

In [6]:
# One session shared by the evaluation cells below.
# The graph contains only placeholders and ops (no tf.Variable), so there is
# nothing to initialize. Add tf.global_variables_initializer() here if
# variables are ever introduced.
sess = tf.Session()

In [7]:
# Evaluate the lookup against the all-ones table. Every table row equals 1, so
# each output row reduces to the sum of that row's weights.
fetches = [
    y,
    tf.sparse_tensor_to_dense(sp_ids),
    tf.sparse_tensor_to_dense(sp_weights),
]
feed = {
    X: x_ones,
    # 3 entries in minibatch entry 0, 2 entries in entry 1.
    sp_indices: [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],
    # batch size: 2, max index: 2 (so index count == 3)
    sp_shape: [2, 3],
    sp_ids_val: [2, 5, 8, 3, 4],
    sp_weights_val: [1.0, 1.5, 2.5, 3.5, 4.5],
}
y_values = sess.run(fetches, feed_dict=feed)

pprint(y_values)

[array([[ 5.],
       [ 8.]], dtype=float32),
 array([[2, 5, 8],
       [3, 4, 0]]),
 array([[ 1. ,  1.5,  2.5],
       [ 3.5,  4.5,  0. ]], dtype=float32)]


In [8]:
# Same sparse ids and weights, but table row k now holds the value k, so each
# output row is sum(weight * id). For example, row 0 is
# 2*1.0 + 5*1.5 + 8*2.5 = 29.5.
fetches = [
    y,
    tf.sparse_tensor_to_dense(sp_ids),
    tf.sparse_tensor_to_dense(sp_weights),
]
feed = {
    X: x_one_to_nine,
    # 3 entries in minibatch entry 0, 2 entries in entry 1.
    sp_indices: [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],
    # batch size: 2, max index: 2 (so index count == 3)
    sp_shape: [2, 3],
    sp_ids_val: [2, 5, 8, 3, 4],
    sp_weights_val: [1.0, 1.5, 2.5, 3.5, 4.5],
}
y_values = sess.run(fetches, feed_dict=feed)

pprint(y_values)

[array([[ 29.5],
       [ 28.5]], dtype=float32),
 array([[2, 5, 8],
       [3, 4, 0]]),
 array([[ 1. ,  1.5,  2.5],
       [ 3.5,  4.5,  0. ]], dtype=float32)]
