Skip to content

Commit

Permalink
added support in RNN estimators for variable sequence lengths
Browse files Browse the repository at this point in the history
  • Loading branch information
mheilman committed Mar 7, 2016
1 parent c99ed59 commit 71a3155
Showing 1 changed file with 29 additions and 6 deletions.
35 changes: 29 additions & 6 deletions skflow/estimators/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ def null_input_op_fn(X):
return X


def _setup_sequence_length(sequence_length, X):
try:
# Try using self.sequence_length as a function that returns a
# tensor of ints as long as the input.
result = sequence_length(X)
except TypeError as e:
# sequence_length isn't a function, so assume it's an int or None.
result = sequence_length
return result


class TensorFlowRNNClassifier(TensorFlowEstimator, ClassifierMixin):
"""TensorFlow RNN Classifier model.
Expand All @@ -37,8 +48,12 @@ class TensorFlowRNNClassifier(TensorFlowEstimator, ClassifierMixin):
creating word embeddings, byte list, etc. This takes
an argument X for input and returns transformed X.
bidirectional: boolean, Whether this is a bidirectional rnn.
sequence_length: If sequence_length is provided, dynamic calculation is performed.
This saves computational time when unrolling past max sequence length.
sequence_length: A number or a function that returns a tensor of
integers representing the lengths of sequences in a
minibatch. The RNN will only compute outputs and
states for up to the specified number of inputs,
and the prediction will be made based on the last
state.
initial_state: An initial state for the RNN. This must be a tensor of appropriate type
and shape [batch_size x cell.state_size].
n_classes: Number of classes in the target.
Expand Down Expand Up @@ -100,11 +115,13 @@ def __init__(self, rnn_size, n_classes, cell_type='gru', num_layers=1,
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

def _model_fn(self, X, y):
    """Build the RNN model with a logistic-regression head and apply it.

    ``self.sequence_length`` may be a plain value or a function of ``X``;
    it is resolved through ``_setup_sequence_length`` first, and only the
    resolved value is forwarded to ``models.get_rnn_model``.

    Args:
        X: Model input, also passed to a callable ``sequence_length``.
        y: Model target.

    Returns:
        Whatever the callable produced by ``models.get_rnn_model`` returns
        when invoked with ``(X, y)``.
    """
    sequence_length = _setup_sequence_length(self.sequence_length, X)

    # Pass the resolved length exactly once; also passing the raw
    # ``self.sequence_length`` would shift ``initial_state`` into an extra
    # positional slot.
    return models.get_rnn_model(self.rnn_size, self.cell_type,
                                self.num_layers,
                                self.input_op_fn, self.bidirectional,
                                models.logistic_regression,
                                sequence_length,
                                self.initial_state)(X, y)

@property
Expand All @@ -129,8 +146,12 @@ class TensorFlowRNNRegressor(TensorFlowEstimator, RegressorMixin):
creating word embeddings, byte list, etc. This takes
an argument X for input and returns transformed X.
bidirectional: boolean, Whether this is a bidirectional rnn.
sequence_length: If sequence_length is provided, dynamic calculation is performed.
This saves computational time when unrolling past max sequence length.
sequence_length: A number or a function that returns a tensor of
integers representing the lengths of sequences in a
minibatch. The RNN will only compute outputs and
states for up to the specified number of inputs,
and the prediction will be made based on the last
state.
initial_state: An initial state for the RNN. This must be a tensor of appropriate type
and shape [batch_size x cell.state_size].
tf_master: TensorFlow master. Empty string is default for local.
Expand Down Expand Up @@ -189,11 +210,13 @@ def __init__(self, rnn_size, cell_type='gru', num_layers=1,
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

def _model_fn(self, X, y):
    """Build the RNN model with a linear-regression head and apply it.

    ``self.sequence_length`` may be a plain value or a function of ``X``;
    it is resolved through ``_setup_sequence_length`` first, and only the
    resolved value is forwarded to ``models.get_rnn_model``.

    Args:
        X: Model input, also passed to a callable ``sequence_length``.
        y: Model target.

    Returns:
        Whatever the callable produced by ``models.get_rnn_model`` returns
        when invoked with ``(X, y)``.
    """
    sequence_length = _setup_sequence_length(self.sequence_length, X)

    # Pass the resolved length exactly once; also passing the raw
    # ``self.sequence_length`` would shift ``initial_state`` into an extra
    # positional slot.
    return models.get_rnn_model(self.rnn_size, self.cell_type,
                                self.num_layers,
                                self.input_op_fn, self.bidirectional,
                                models.linear_regression,
                                sequence_length,
                                self.initial_state)(X, y)

@property
Expand Down

0 comments on commit 71a3155

Please sign in to comment.