In [73]:
import tensorflow as tf
import numpy as np

In [78]:
tf.reset_default_graph()

# Seed both NumPy and TensorFlow so that Restart & Run All reproduces the same
# labels, RNN outputs and weight initialisation. The graph-level seed must be
# set after the reset because it belongs to the (new) default graph.
SEED = 0
np.random.seed(SEED)
tf.set_random_seed(SEED)

# Number of examples
N = 4
# (Maximum) number of time steps in this batch
T = 8
RNN_DIM = 128
NUM_CLASSES = 10

# The *actual* length of the examples
example_len = [1, 2, 3, 8]
# One length per example; each must fit in the batch and be non-zero, because
# the per-example mean loss is later divided by it.
assert len(example_len) == N
assert all(0 < n_steps <= T for n_steps in example_len)

# The classes of the examples at each step
# (between 1 and NUM_CLASSES - 1, 0 means padding)
y = np.random.randint(1, NUM_CLASSES, [N, T])
for i, length in enumerate(example_len):
    # Everything past the real length of example i is padding
    y[i, length:] = 0
    
# The RNN outputs: random stand-ins of shape [N, T, RNN_DIM]
rnn_outputs = tf.convert_to_tensor(np.random.randn(N, T, RNN_DIM), dtype=tf.float32)

# Output layer weights, shape [RNN_DIM, NUM_CLASSES]
W = tf.get_variable(
    name="W",
    initializer=tf.random_normal_initializer(),
    shape=[RNN_DIM, NUM_CLASSES])

# Swap the first two dimensions (batch and time) of the outputs so that
# map_fn, which iterates over the leading dimension, walks over time steps:
# [N, T, RNN_DIM] -> [T, N, RNN_DIM]
rnn_outputs_t = tf.transpose(rnn_outputs, [1, 0, 2])

# For each time step, calculate the logits and probabilities per batch.
# Every slice handed to the lambda is a 2-D [N, RNN_DIM] matrix, so a plain
# tf.matmul is the right op here (tf.batch_matmul is deprecated and no longer
# exists in TensorFlow >= 1.0). Result shape: [T, N, NUM_CLASSES].
logits = tf.map_fn(lambda x: tf.matmul(x, W), rnn_outputs_t, name="logits")
probs = tf.map_fn(tf.nn.softmax, logits, name="probs")

# Calculate the losses for each time step.
# The labels are transposed to time-major [T, N] so they line up with the
# time-major logits; map_fn then hands the loss one (logits, labels) pair per
# time step. Keyword arguments are used because the positional order of
# logits/labels is easy to get wrong and TensorFlow >= 1.0 rejects positional
# arguments for this op.
y_by_time = tf.transpose(y)
losses = tf.map_fn(
    lambda step: tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=step[0], labels=step[1]),
    [logits, y_by_time],
    dtype=tf.float32)

# Set all losses where y=0 (padding) to 0.
# sign() maps every real class id (>= 1) to 1.0 and the padding id 0 to 0.0.
mask = tf.sign(tf.to_float(y_by_time))
masked_losses = mask * losses

# Calculate the losses for the whole sequence: sum over the time axis, then
# divide by each example's true length so padding does not dilute the mean.
example_len_f = tf.constant(example_len, dtype=tf.float32)
mean_loss_by_example = tf.reduce_sum(masked_losses, reduction_indices=0) / example_len_f
mean_loss = tf.reduce_mean(mean_loss_by_example)

# Tensors to evaluate, keyed by the label they are looked up under below
fetches = {
    "masked_losses": masked_losses,
    "mean_loss_by_example": mean_loss_by_example,
    "mean_loss": mean_loss,
}

# Evaluate the graph once; the values of that single run are read from result[0]
result = tf.contrib.learn.run_n(fetches, n=1, feed_dict=None)

# Print the per-step masked losses, the per-example means, then the overall mean
for fetch_name in ("masked_losses", "mean_loss_by_example", "mean_loss"):
    print(result[0][fetch_name])

[[ 10.8038826    6.80986166  24.33704185  21.3249836 ]
 [  0.          15.77543354  27.58551407  28.09634399]
 [  0.           0.           2.44341588  24.29099083]
 [  0.           0.           0.          20.08712196]
 [  0.           0.           0.          34.75185013]
 [  0.           0.           0.          44.74113464]
 [  0.           0.           0.          23.10951996]
 [  0.           0.           0.          20.82464027]]
[ 10.8038826   11.29264736  18.1219902   27.15332413]
16.843
