In [1]:
import tensorflow as tf
import numpy as np

In [72]:
# Start from an empty default graph so this cell can be re-run in the same
# kernel. NOTE(review): without this, the tf.get_variable(name="W") call
# further down would presumably collide with the "W" left over from the
# previous run -- confirm against the TensorFlow version in use.
tf.reset_default_graph()

# Seed NumPy's global generator so every np.random draw in this notebook
# (the labels below, and any later np.random call) is identical after
# "Restart Kernel -> Run All".
SEED = 0
np.random.seed(SEED)

# Number of examples
N = 4
# (Maximum) number of time steps in this batch
T = 8
RNN_DIM = 128
NUM_CLASSES = 10

# The *actual* length of the examples
example_len = [1, 2, 3, 8]
# One length per example, each in 1..T: the per-example loss is later divided
# by these lengths, and every step at or beyond `length` is marked as padding.
assert len(example_len) == N
assert all(0 < length <= T for length in example_len)

# The classes of the examples at each step (between 1 and 9, 0 means padding)
y = np.random.randint(1, 10, [N, T])
for i, length in enumerate(example_len):
    # Everything past the true length of example i becomes the padding class 0
    y[i, length:] = 0
    
# The RNN outputs, batch-major: [N, T, RNN_DIM]
rnn_outputs = tf.convert_to_tensor(np.random.randn(N, T, RNN_DIM), dtype=tf.float32)

# Output layer weights: [RNN_DIM, NUM_CLASSES].
# The initializer is seeded so the printed losses do not change between runs
# merely because W was re-drawn.
W = tf.get_variable(
    name="W",
    initializer=tf.random_normal_initializer(seed=0),
    shape=[RNN_DIM, NUM_CLASSES])

# Swap the first two dimensions (batch and time) of the outputs -> [T, N, RNN_DIM]
rnn_outputs_t = tf.transpose(rnn_outputs, [1, 0, 2])

# For each time step, calculate the logits and probabilities per batch.
# map_fn iterates over the leading (time) axis, so each x is [N, RNN_DIM]
# and both results are time-major: [T, N, NUM_CLASSES].
logits = tf.map_fn(lambda x: tf.batch_matmul(x, W), rnn_outputs_t, name="logits")
probs = tf.map_fn(tf.nn.softmax, logits, name="probs")

# Per-step cross-entropy
# Bring the logits back to batch-major layout: [Batch, Time, Classes]
logits = tf.transpose(logits, [1, 0, 2])


def _sequence_xent(logits_and_labels):
    """Cross-entropy of one example's per-step logits against its labels."""
    example_logits, example_labels = logits_and_labels
    return tf.nn.sparse_softmax_cross_entropy_with_logits(
        example_logits, example_labels)


# map_fn walks the batch axis, pairing each example's logits with its labels
losses = tf.map_fn(_sequence_xent, [logits, y], dtype=tf.float32)

# Zero out the loss at every padded step (y == 0): sign() maps the labels
# 1..9 to 1.0 and the padding label 0 to 0.0
mask = tf.sign(tf.to_float(y))
masked_losses = mask * losses

# Average each example's loss over its true length, then over the batch
loss_sum_by_example = tf.reduce_sum(masked_losses, reduction_indices=1)
mean_loss_by_example = loss_sum_by_example / example_len
mean_loss = tf.reduce_mean(mean_loss_by_example)

# Tensors to evaluate, keyed by the name used to look them up afterwards
fetches = {
    "masked_losses": masked_losses,
    "mean_loss_by_example": mean_loss_by_example,
    "mean_loss": mean_loss,
}

# Evaluate the graph once; run_n returns one result dict per run
result = tf.contrib.learn.run_n(fetches, n=1, feed_dict=None)

# Show the per-step losses, the per-example means and the batch mean, in that order
first_run = result[0]
for key in ("masked_losses", "mean_loss_by_example", "mean_loss"):
    print(first_run[key])

[[  2.38418551e-07   0.00000000e+00   0.00000000e+00   0.00000000e+00
    0.00000000e+00   0.00000000e+00   0.00000000e+00   0.00000000e+00]
 [  6.72377443e+00   3.12269268e+01   0.00000000e+00   0.00000000e+00
    0.00000000e+00   0.00000000e+00   0.00000000e+00   0.00000000e+00]
 [  7.59313777e-02   2.27465210e+01   2.21607952e+01   0.00000000e+00
    0.00000000e+00   0.00000000e+00   0.00000000e+00   0.00000000e+00]
 [  1.85533726e+00   1.92939949e+01   7.84722519e+00   9.50123692e+00
    1.75973301e+01   7.04502090e-05   3.47891846e+01   2.17570434e-03]]
[  2.38418551e-07   1.89753513e+01   1.49944153e+01   1.13608189e+01]
11.3326
