Expose snt.{merge,split}_leading_dims.
PiperOrigin-RevId: 291003968
Change-Id: I7678d4aa15bb08650203afe0e21da9141c96a574
tomhennigan authored and sonnet-copybara committed Jan 22, 2020
1 parent af606a4 commit c616bbd
Showing 2 changed files with 57 additions and 6 deletions.
4 changes: 4 additions & 0 deletions sonnet/__init__.py
@@ -32,6 +32,8 @@
 from sonnet.src.base import no_name_scope
 from sonnet.src.base import Optimizer
 from sonnet.src.batch_apply import BatchApply
+from sonnet.src.batch_apply import merge_leading_dims
+from sonnet.src.batch_apply import split_leading_dim
 from sonnet.src.batch_norm import BaseBatchNorm
 from sonnet.src.batch_norm import BatchNorm
 from sonnet.src.bias import Bias
@@ -132,6 +134,7 @@
     "initializers",
     "log_variables",
     "lstm_with_recurrent_dropout",
+    "merge_leading_dims",
     "no_name_scope",
     "nets",
     "once",
@@ -140,6 +143,7 @@
     "pad",
     "regularizers",
     "scale_gradient",
+    "split_leading_dim",
     "static_unroll",
 )

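Usage sketch (not part of the commit): with the two exports added above, the helpers become callable from the top-level package. A minimal example, assuming Sonnet 2 and TensorFlow 2; the shapes follow the doctests added in sonnet/src/batch_apply.py below.

    import sonnet as snt
    import tensorflow as tf

    x = tf.ones([3, 2, 5])                                    # e.g. [T, B, N]
    merged = snt.merge_leading_dims(x, num_dims=2)            # shape [6, 5]
    restored = snt.split_leading_dim(merged, x, num_dims=2)   # back to [3, 2, 5]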
59 changes: 53 additions & 6 deletions sonnet/src/batch_apply.py
@@ -61,7 +61,7 @@ def __call__(self, *args, **kwargs):

     num_dims = self.num_dims
     merge = lambda x: merge_leading_dims(x, num_dims=num_dims)
-    split = lambda x: split_leading_dim(x, num_dims=num_dims, inputs=example)
+    split = lambda x: split_leading_dim(x, num_dims=num_dims, example=example)

     # Merge leading dimensions of inputs.
     # Example: [T, B, N] -> [T*B, N]
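Aside (not part of the diff): the two lambdas above are the core of BatchApply's merge -> apply -> split pattern. A hand-written equivalent for num_dims=2, using snt.Linear as a stand-in for the wrapped module (any module expecting rank-2 input would do):

    import sonnet as snt
    import tensorflow as tf

    mod = snt.Linear(4)                              # stand-in wrapped module
    x = tf.ones([7, 3, 5])                           # [T, B, N]
    merged = snt.merge_leading_dims(x, num_dims=2)   # [T*B, N] == [21, 5]
    y = mod(merged)                                  # [21, 4]
    y = snt.split_leading_dim(y, x, num_dims=2)      # [T, B, 4] == [7, 3, 4]
    # Roughly what snt.BatchApply(mod)(x) computes.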
@@ -89,19 +89,41 @@ def first_leaf(args, kwargs) -> Optional[Any]:

 def split_leading_dim(
     x: Optional[tf.Tensor],
-    inputs: tf.Tensor,
+    example: tf.Tensor,
     num_dims: int,
 ) -> Optional[tf.Tensor]:
-  """Split the first dimension of a tensor."""
+  """Split the first dimension of a tensor to match an example.
+
+  See :func:`merge_leading_dims`.
+
+  >>> x = tf.ones([6, 1])
+  >>> example = tf.ones([3, 2, 1])
+  >>> snt.split_leading_dim(x, example, 2)
+  <tf.Tensor: ...shape=(3, 2, 1), ...>
+
+  If ``x`` is not a :tf:`Tensor` or :tf:`Variable` then it is returned
+  unchanged:
+
+  >>> snt.split_leading_dim('not a tensor', example, 2)
+  'not a tensor'
+
+  Args:
+    x: A tensor with leading dims merged.
+    example: A tensor with leading dims not merged.
+    num_dims: The number of leading dimensions of ``example`` to use.
+
+  Returns:
+    A tensor with leading dims split, or the input unchanged.
+  """
   if x is None or not isinstance(x, (tf.Tensor, tf.Variable)):
     return x
 
-  static_shape = inputs.shape[:num_dims] + x.shape[1:]
+  static_shape = example.shape[:num_dims] + x.shape[1:]
   if static_shape.is_fully_defined():  # pytype: disable=attribute-error
     return tf.reshape(x, static_shape)
 
   # Shape can't be inferred statically.
-  leading_dims = tf.shape(inputs)[:num_dims]
+  leading_dims = tf.shape(example)[:num_dims]
   other_dims = tf.shape(x)[1:]
   dynamic_shape = tf.concat([leading_dims, other_dims], axis=0)
   return tf.reshape(x, dynamic_shape)
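Aside (not part of the diff): the static_shape branch above covers shapes that are fully known at trace time; otherwise split_leading_dim falls back to tf.shape and a dynamic reshape. A sketch of the dynamic case, assuming a tf.function traced with unknown leading dimensions:

    import sonnet as snt
    import tensorflow as tf

    @tf.function(input_signature=[
        tf.TensorSpec([None, 5]),         # merged tensor, leading dim unknown
        tf.TensorSpec([None, None, 5]),   # example with T and B unknown
    ])
    def restore(x, example):
      # example.shape[:2] is not fully defined here, so the tf.shape path runs.
      return snt.split_leading_dim(x, example, num_dims=2)

    out = restore(tf.ones([6, 5]), tf.ones([3, 2, 5]))   # out is [3, 2, 5] at run time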
@@ -119,7 +141,32 @@ def merge_leading_dims(
     x: Optional[tf.Tensor],
     num_dims: int,
 ) -> Optional[tf.Tensor]:
-  """Merges leading dimensions."""
+  """Merges leading dimensions of a tensor.
+
+  See :func:`split_leading_dim`.
+
+  >>> x = tf.ones([3, 2, 1])
+  >>> snt.merge_leading_dims(x, num_dims=2)
+  <tf.Tensor: ...shape=(6, 1), ...>
+
+  If the rank of ``x`` is less than ``num_dims`` it is returned unchanged:
+
+  >>> snt.merge_leading_dims(x, 4)
+  <tf.Tensor: ...shape=(3, 2, 1), ...>
+
+  If ``x`` is not a :tf:`Tensor` or :tf:`Variable` then it is returned
+  unchanged:
+
+  >>> snt.merge_leading_dims('not a tensor', 1)
+  'not a tensor'
+
+  Args:
+    x: A :tf:`Tensor` to merge.
+    num_dims: The number of leading dimensions to merge.
+
+  Returns:
+    A :tf:`Tensor` with merged leading dimensions or the input unchanged.
+  """
   if x is None or not isinstance(x, (tf.Tensor, tf.Variable)):
     return x

