optimized the JIT compilation speed for flows
haowen-xu committed Mar 2, 2020
1 parent 299ed80 commit c6cdb3e
Showing 10 changed files with 122 additions and 87 deletions.
49 changes: 31 additions & 18 deletions tensorkit/backend/pytorch_/core.py
@@ -52,9 +52,9 @@
     # shape utils
     'shape', 'rank', 'reshape', 'repeat', 'expand', 'squeeze', 'expand_dim',
     'swap_axes', 'transpose',
-    'get_broadcast_shape', 'broadcast_to_shape', 'broadcast_to',
-    'explicit_broadcast', 'flatten_to_ndims',
-    'unflatten_from_ndims', 'reshape_tail',
+    'get_broadcast_shape', 'broadcast_to_shape', 'strict_broadcast_to_shape',
+    'broadcast_to', 'strict_broadcast_to', 'explicit_broadcast',
+    'flatten_to_ndims', 'unflatten_from_ndims', 'reshape_tail',
 
     # split / join / indexing / gathering ...
     'index_select', 'concat', 'split', 'stack', 'unstack', 'slice', 'slice_axis',
@@ -316,13 +316,14 @@ def as_tensor(data,

     if isinstance(data, Tensor):
         # input `data` may be `StochasticTensor`, `Tensor` or `numpy.ndarray`
-        kwargs = {}
-        if data.dtype != target_dtype:
-            kwargs['dtype'] = target_dtype
-        if str(data.device) != device:
-            kwargs['device'] = device
-        if kwargs:
-            data = data.to(**kwargs)
+        from_dev = str(data.device)
+        if data.dtype != target_dtype and from_dev != device:
+            data = data.to(device=device, dtype=target_dtype)
+        elif data.dtype != target_dtype:
+            data = data.to(target_dtype)
+        elif from_dev != device:
+            data = data.to(device=device)
+
         if force_copy:
             data = data.clone()
         return data
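Note: the rewrite above replaces a kwargs dict plus `data.to(**kwargs)` with an explicit if/elif chain, so each `.to()` call keeps a static signature; `**kwargs` expansion is not something TorchScript can compile, and building the dict also costs time in eager mode. A minimal standalone sketch of the branch structure (hypothetical function, not the actual tensorkit `as_tensor`, which also accepts non-Tensor inputs):

    import torch

    def cast_float32_to(data: torch.Tensor, device: str) -> torch.Tensor:
        # each branch calls .to() with a fixed signature -- the form that
        # TorchScript can compile, unlike data.to(**kwargs)
        if data.dtype != torch.float32 and str(data.device) != device:
            data = data.to(device=device, dtype=torch.float32)
        elif data.dtype != torch.float32:
            data = data.to(torch.float32)
        elif str(data.device) != device:
            data = data.to(device=device)
        return data

    y = cast_float32_to(torch.zeros(3, dtype=torch.float64), 'cpu')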
@@ -756,10 +757,16 @@ def broadcast_to_shape(input: Tensor, new_shape: List[int]) -> Tensor:
     output = input
     if list(output.shape) != new_shape:
         output = output + torch.zeros(new_shape, dtype=output.dtype, device=output.device)
-        if list(output.shape) != new_shape:
-            raise ValueError(
-                '`input` cannot be broadcast to `new_shape`: shape(input) {} '
-                'vs new_shape {}'.format(shape(input), new_shape))
     return output
 
 
+@jit
+def strict_broadcast_to_shape(input: Tensor, new_shape: List[int]) -> Tensor:
+    output = broadcast_to_shape(input, new_shape)
+    if list(output.shape) != new_shape:
+        raise ValueError(
+            '`input` cannot be broadcast to `new_shape`: shape(input) {} '
+            'vs new_shape {}'.format(shape(input), new_shape))
+    return output


@@ -768,10 +775,16 @@ def broadcast_to(input: Tensor, target: Tensor) -> Tensor:
     output = input
     if output.shape != target.shape:
         output = output + torch.zeros(target.shape, dtype=output.dtype, device=output.device)
-        if output.shape != target.shape:
-            raise ValueError(
-                '`input` cannot be broadcast to `target`: shape(input) {} '
-                'vs shape(target) {}'.format(shape(input), shape(target)))
     return output
 
 
+@jit
+def strict_broadcast_to(input: Tensor, target: Tensor) -> Tensor:
+    output = broadcast_to(input, target)
+    if output.shape != target.shape:
+        raise ValueError(
+            '`input` cannot be broadcast to `target`: shape(input) {} '
+            'vs shape(target) {}'.format(shape(input), shape(target)))
+    return output


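Note: the raising shape check was moved out of `broadcast_to_shape` / `broadcast_to` into the new `strict_*` variants, so the common path compiles to a single add-against-zeros broadcast. The two differ only when `input` has dimensions that broadcast the target shape up instead of matching it; a small sketch of the non-strict behaviour, reusing the function body from the diff:

    import torch
    from typing import List

    @torch.jit.script
    def broadcast_to_shape(input: torch.Tensor, new_shape: List[int]) -> torch.Tensor:
        output = input
        if list(output.shape) != new_shape:
            output = output + torch.zeros(new_shape, dtype=output.dtype,
                                          device=output.device)
        return output

    x = torch.randn(3, 4)
    y = broadcast_to_shape(x, [1, 4])
    print(list(y.shape))  # [3, 4]: x broadcast over zeros([1, 4]) without error,
                          # whereas strict_broadcast_to_shape(x, [1, 4]) raises ValueError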
56 changes: 31 additions & 25 deletions tensorkit/backend/pytorch_/flows.py
@@ -148,6 +148,8 @@ def forward(self,
                 the previous flow layer and this layer.
         """
         if inverse:
+            if not self.explicitly_invertible:
+                raise RuntimeError('Flow is not explicitly invertible.')
             event_ndims = self.y_event_ndims
         else:
             event_ndims = self.x_event_ndims
@@ -291,17 +293,10 @@ def forward(self,

 class SequentialFlow(Flow):
 
-    __constants__ = Flow.__constants__ + ('_chain', '_inverse_chain')
+    __constants__ = Flow.__constants__ + ('_chain',)
 
     _chain: ModuleList
 
-    # The inverse chain is provided, such that JIT support is still okay.
-    # TODO: This separated inverse chain will cause `state_dict()` to have
-    #       duplicated weights.  Deal with this issue.
-    _inverse_chain: ModuleList
-
-    flatten_to_ndims: bool
-
     def __init__(self,
                  *flows: Union[Module, Sequence[Module]]):
         from tensorkit.layers import flatten_nested_layers
@@ -331,11 +326,27 @@ def __init__(self,
         )
 
         self._chain = ModuleList(flows)
-        if self.explicitly_invertible:
-            self._inverse_chain = ModuleList(reversed(flows))
-        else:
-            self._inverse_chain = ModuleList([_NotInvertibleFlow()])
         self.flatten_to_ndims = bool(flatten_to_ndims)
 
+    # The following method is not compiled by JIT, because:
+    #
+    # 1. PyTorch JIT does not support "self._chain[::-1]" yet, nor does it
+    #    support subscripting a ModuleList.
+    # 2. If we provided a separate "_inverse_chain", it would cost much more
+    #    time to compile the module by JIT, and would double the number of
+    #    parameters returned from `state_dict()`.
+    @jit_ignore
+    def _call_chain(self,
+                    input: Tensor,
+                    input_log_det: Optional[Tensor],
+                    inverse: bool,
+                    compute_log_det: bool
+                    ) -> Tuple[Tensor, Optional[Tensor]]:
+        output, output_log_det = input, input_log_det
+        chain = self._chain[::-1] if inverse else self._chain
+        for flow in chain:
+            output, output_log_det = flow(
+                output, output_log_det, inverse, compute_log_det)
+        return output, output_log_det
 
     def _transform(self,
                    input: Tensor,
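Note: `jit_ignore` presumably wraps `torch.jit.ignore`, which keeps a method in plain Python even when the enclosing module is scripted; that makes `self._chain[::-1]` legal and avoids a second, parameter-duplicating ModuleList. A minimal sketch of the pattern with a hypothetical module (not tensorkit's actual classes):

    import torch
    from torch import nn

    class ReversibleChain(nn.Module):
        def __init__(self):
            super().__init__()
            self.chain = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])

        @torch.jit.ignore
        def _call_chain(self, x: torch.Tensor, inverse: bool) -> torch.Tensor:
            # runs as plain Python even inside a scripted module, so the
            # slice reversal below is fine and state_dict() stays clean
            chain = self.chain[::-1] if inverse else self.chain
            for layer in chain:
                x = layer(x)
            return x

        def forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor:
            return self._call_chain(x, inverse)

    m = torch.jit.script(ReversibleChain())
    y = m(torch.randn(2, 4))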
@@ -346,21 +357,15 @@ def _transform(self,
         output, output_log_det = input, input_log_det
         event_ndims = self.y_event_ndims if inverse else self.x_event_ndims
 
-        if rank(output) > event_ndims:
+        if rank(output) > event_ndims + 1:
             output, batch_shape = flatten_to_ndims(output, event_ndims + 1)
             if output_log_det is not None:
                 output_log_det = reshape(output_log_det, [-1])
         else:
             batch_shape: Optional[List[int]] = None
 
-        if inverse:
-            for flow in self._inverse_chain:
-                output, output_log_det = flow(
-                    output, output_log_det, True, compute_log_det)
-        else:
-            for flow in self._chain:
-                output, output_log_det = flow(
-                    output, output_log_det, False, compute_log_det)
+        output, output_log_det = self._call_chain(
+            output, output_log_det, inverse, compute_log_det)
 
         if batch_shape is not None:
             output = unflatten_from_ndims(output, batch_shape)
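Note: changing the condition to `rank(output) > event_ndims + 1` skips the flatten/unflatten round trip when there is exactly one batch dimension, where it would be an identity reshape. A hypothetical re-implementation of the two helpers, for illustration only:

    import torch
    from typing import List, Tuple

    def flatten_to_ndims(x: torch.Tensor, ndims: int) -> Tuple[torch.Tensor, List[int]]:
        # merge all leading batch dims into one, keeping the trailing
        # (ndims - 1) dims intact: [2, 3, 4, 5] -> [24, 5] for ndims == 2
        batch_shape = list(x.shape[:x.dim() - ndims + 1])
        return x.reshape([-1] + list(x.shape[x.dim() - ndims + 1:])), batch_shape

    def unflatten_from_ndims(x: torch.Tensor, batch_shape: List[int]) -> torch.Tensor:
        return x.reshape(batch_shape + list(x.shape[1:]))

    x = torch.randn(2, 3, 4, 5)
    y, batch_shape = flatten_to_ndims(x, 2)   # y.shape == [24, 5]
    assert torch.equal(unflatten_from_ndims(y, batch_shape), x)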
@@ -612,16 +617,17 @@ def _transform(self,
         weight, log_det = self.invertible_matrix(
             inverse=inverse, compute_log_det=compute_log_det)
         spatial_ndims = self.x_event_ndims - 1
-        weight = reshape(weight, shape(weight) + [1] * spatial_ndims)
+        weight = torch.reshape(weight, weight.shape + (1,) * spatial_ndims)
 
         # compute the output
         output = self._affine_transform(input, weight)
 
         # compute the log_det
         output_log_det = input_log_det
         if log_det is not None:
-            for axis in int_range(-spatial_ndims, 0):
-                log_det = log_det * float(input.shape[axis])
+            log_det *= torch.prod(
+                torch.as_tensor(input.shape[input.dim() - spatial_ndims:],
+                                dtype=log_det.dtype, device=log_det.device))
             if input_log_det is not None:
                 output_log_det = input_log_det + log_det
             else:
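Note on the log-det change: an invertible 1x1 convolution applies the same weight at every spatial position, so the total log-determinant is the per-position log-det times the number of positions. The old Python loop multiplied dimension by dimension and would be unrolled by JIT; the new form is a single torch.prod. A quick equivalence check under assumed NCHW shapes:

    import torch

    x = torch.randn(8, 3, 5, 7)   # N, C, H, W
    log_det = torch.tensor(0.25)
    spatial_ndims = 2

    # loop form (old): multiply by each spatial dim in turn
    loop_val = log_det.clone()
    for axis in range(-spatial_ndims, 0):
        loop_val = loop_val * float(x.shape[axis])

    # single-op form (new): one product over the trailing spatial dims
    prod_val = log_det * torch.prod(
        torch.as_tensor(x.shape[x.dim() - spatial_ndims:],
                        dtype=log_det.dtype, device=log_det.device))

    assert torch.allclose(loop_val, prod_val)   # both equal 0.25 * 5 * 7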
15 changes: 10 additions & 5 deletions tensorkit/backend/pytorch_/nn.py
@@ -127,11 +127,16 @@ def cross_entropy_with_logits(logits: Tensor,
                               reduction: str = 'none',  # {'sum', 'mean' or 'none'}
                               negative: bool = False) -> Tensor:
     if logits.shape[:-1] != labels.shape:
-        logits_shape = list(logits.shape)
-        labels_shape = list(labels.shape)
-        b_shape = get_broadcast_shape(logits_shape[:-1], labels_shape)
-        logits = broadcast_to_shape(logits, b_shape + logits_shape[-1:])
-        labels = broadcast_to_shape(labels, b_shape)
+        labels = labels + torch.zeros(
+            logits.shape[:-1], dtype=labels.dtype, device=labels.device)
+        logits = logits + torch.zeros(
+            labels.shape + logits.shape[-1:], dtype=logits.dtype, device=logits.device)
+
+        # logits_shape = list(logits.shape)
+        # labels_shape = list(labels.shape)
+        # b_shape = get_broadcast_shape(logits_shape[:-1], labels_shape)
+        # logits = strict_broadcast_to_shape(logits, b_shape + logits_shape[-1:])
+        # labels = strict_broadcast_to_shape(labels, b_shape)
 
     if len(logits.shape) < 2 or len(labels.shape) < 1:
         raise ValueError('`logits` must be at least 2d, and `labels` must
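Note: the two add-with-zeros lines broadcast `labels` and `logits` against each other directly, replacing the `get_broadcast_shape` / `broadcast_to_shape` calls kept above as a comment block; this avoids constructing intermediate Python shape lists under JIT. A small shape walk-through with assumed sizes:

    import torch

    logits = torch.randn(5, 1, 7)        # batch dims (5, 1), 7 classes
    labels = torch.randint(0, 7, (3,))   # batch dims (3,)

    labels = labels + torch.zeros(
        logits.shape[:-1], dtype=labels.dtype, device=labels.device)
    logits = logits + torch.zeros(
        labels.shape + logits.shape[-1:], dtype=logits.dtype, device=logits.device)

    print(labels.shape, logits.shape)    # torch.Size([5, 3]) torch.Size([5, 3, 7])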
5 changes: 3 additions & 2 deletions tensorkit/distributions/uniform.py
@@ -147,14 +147,15 @@ def _log_prob(self,
                   reduce_ndims: int) -> T.Tensor:
         low = self.low if self.low is not None else 0.
         high = self.high if self.high is not None else 1.
-        b_shape = T.get_broadcast_shape(T.shape(given), self.value_shape)
         log_pdf = self._get_neg_log_high_minus_low()
         log_pdf = log_pdf_mask(
             T.logical_and(low <= given, given <= high),
             log_pdf,
             self.log_zero,
         )
-        log_pdf = T.broadcast_to_shape(log_pdf, b_shape)
+
+        # broadcast against given if required
+        log_pdf = T.broadcast_to(log_pdf, given)
         if reduce_ndims > 0:
             log_pdf = T.reduce_sum(log_pdf, axis=list(range(-reduce_ndims, 0)))
         return log_pdf
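Note: `T.broadcast_to` takes a target tensor rather than a shape list, so no broadcast shape has to be computed first. Broadcasting the (often lower-rank) log-density to the shape of `given` before the sum matters, because the reduction over event dimensions must count every event element. A rough analogue in plain torch, assuming a Uniform(0, 1) with event shape [3]:

    import torch

    given = torch.rand(5, 3)                            # batch [5], event [3]
    log_pdf = torch.tensor(0.)                          # -log(high - low) == 0 here
    log_pdf = torch.broadcast_to(log_pdf, given.shape)  # like T.broadcast_to(log_pdf, given)
    log_prob = log_pdf.sum(dim=-1)                      # reduce over the event dim
    print(log_prob.shape)                               # torch.Size([5])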
6 changes: 2 additions & 4 deletions tensorkit/flows/rearrangement.py
@@ -65,10 +65,8 @@ def _transform(self,
                    inverse: bool,
                    compute_log_det: bool
                    ) -> Tuple[Tensor, Optional[Tensor]]:
-        if inverse:
-            output = index_select(input, self.inv_permutation, axis=self.axis)
-        else:
-            output = index_select(input, self.permutation, axis=self.axis)
+        perm = self.inv_permutation if inverse else self.permutation
+        output = index_select(input, perm, axis=self.axis)
         output_log_det = input_log_det
         if compute_log_det and output_log_det is None:
             output_log_det = float_scalar_like(0., input)
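Note: collapsing the if/else into a single `index_select` call leaves one branch-free gather for the JIT. A permutation flow is volume-preserving, hence the zero log-det; the inverse permutation is the argsort of the forward one (a sketch, assuming the same convention as `inv_permutation` above):

    import torch

    perm = torch.randperm(6)
    inv_perm = torch.argsort(perm)        # perm[inv_perm[k]] == k
    x = torch.randn(4, 6)
    y = x.index_select(-1, perm)          # forward
    x2 = y.index_select(-1, inv_perm)     # inverse
    assert torch.equal(x, x2)             # log|det J| == 0: volume preserved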
8 changes: 6 additions & 2 deletions tensorkit/flows/reshape_.py
@@ -77,10 +77,14 @@ def _transform(self,
                    inverse: bool,
                    compute_log_det: bool
                    ) -> Tuple[Tensor, Optional[Tensor]]:
+        target_shape: List[int] = []
         if inverse:
-            output = reshape_tail(input, self.y_event_ndims, self.x_event_shape)
+            source_ndims = self.y_event_ndims
+            target_shape.extend(self.x_event_shape)
         else:
-            output = reshape_tail(input, self.x_event_ndims, self.y_event_shape)
+            source_ndims = self.x_event_ndims
+            target_shape.extend(self.y_event_shape)
+        output = reshape_tail(input, source_ndims, target_shape)
 
         output_log_det = input_log_det
         if compute_log_det and output_log_det is None:
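Note: both branches now feed a single `reshape_tail` call, and the `target_shape: List[int] = []` annotation is what lets this compile: TorchScript cannot infer the element type of an empty list literal. A minimal scripted sketch of the idiom (hypothetical helper, not the tensorkit flow):

    import torch
    from typing import List

    @torch.jit.script
    def pick_shape(inverse: bool, x_shape: List[int], y_shape: List[int]) -> List[int]:
        # the annotation pins the element type before the branches extend it
        target_shape: List[int] = []
        if inverse:
            target_shape.extend(x_shape)
        else:
            target_shape.extend(y_shape)
        return target_shape

    print(pick_shape(True, [4, 4], [16]))   # [4, 4]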
14 changes: 7 additions & 7 deletions tensorkit/flows/split_.py
@@ -131,28 +131,28 @@ def _transform(self,
                    inverse: bool,
                    compute_log_det: bool
                    ) -> Tuple[Tensor, Optional[Tensor]]:
+        sections: List[int] = []
         if inverse:
-            out_left, out_right = split(
-                input, sections=self.y_sections, axis=self.y_axis)
+            sections.extend(self.y_sections)
+            axis = self.y_axis
             join_axis = self.x_axis
         else:
-            out_left, out_right = split(
-                input, sections=self.x_sections, axis=self.x_axis)
+            sections.extend(self.x_sections)
+            axis = self.x_axis
             join_axis = self.y_axis
 
         # apply the left transformation
+        out_left, out_right = split(input, sections=sections, axis=axis)
         out_left, output_log_det = self.left(
             input=out_left, input_log_det=input_log_det, inverse=inverse,
             compute_log_det=compute_log_det,
         )
 
         if self.right is not None:
             out_right, output_log_det = self.right(
                 input=out_right, input_log_det=output_log_det, inverse=inverse,
                 compute_log_det=compute_log_det,
             )
-            output = concat([out_left, out_right], axis=join_axis)
 
+        output = concat([out_left, out_right], axis=join_axis)
         return output, output_log_det


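Note: hoisting `sections`, `axis` and `join_axis` out of the branches leaves a single split call and a single concat, with the concat now running unconditionally after both halves are produced. A stripped-down sketch of the resulting control flow, with identity transforms standing in for the left/right sub-flows:

    import torch
    from typing import List

    def split_transform(x: torch.Tensor, sections: List[int], axis: int,
                        join_axis: int) -> torch.Tensor:
        left, right = torch.split(x, sections, dim=axis)
        # ... the left / right sub-flows would transform the halves here ...
        return torch.cat([left, right], dim=join_axis)

    x = torch.randn(2, 6)
    y = split_transform(x, [2, 4], axis=-1, join_axis=-1)
    assert torch.equal(x, y)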
2 changes: 1 addition & 1 deletion tensorkit/tensor/random_extras.py
@@ -342,7 +342,7 @@ def discretized_logistic_log_prob(given: Tensor,
     if min_val is not None and max_val is not None:
         if biased_edges:
             # broadcasted given, shape == x_mid
-            broadcast_given = broadcast_to_shape(given, shape(x_low))
+            broadcast_given = broadcast_to(given, x_low)
 
             # the left-edge bin case
             # log(sigmoid(x_high) - sigmoid(-infinity))
14 changes: 6 additions & 8 deletions tests/flows/test_core.py
@@ -124,11 +124,9 @@ def test_call(self):

         # test output_log_det shape error
         flow = tk.layers.jit_compile(_MyBadFlow())
-        with pytest.raises(Exception,
-                           match='The shape of `output_log_det` is not expected'):
+        with pytest.raises(Exception, match='(shape|size)'):
            _ = flow(x)
-        with pytest.raises(Exception,
-                           match='The shape of `output_log_det` is not expected'):
+        with pytest.raises(Exception, match='(shape|size)'):
            _ = flow(x, inverse=True)


@@ -278,7 +276,7 @@ def test_call(self):
         flow = tk.layers.jit_compile(SequentialFlow(flows))
 
         with pytest.raises(Exception,
-                           match='Not an explicitly invertible flow'):
+                           match='Flow is not explicitly invertible'):
             _ = flow(x, inverse=True)


@@ -521,7 +519,7 @@ def test_ExpScale(self):
                   T.random.randn([2, 1, 1]),
                   T.random.randn([2, 3, 4])]:
             expected_y = x * T.exp(pre_scale)
-            expected_log_det = T.broadcast_to_shape(pre_scale, T.shape(x))
+            expected_log_det = T.strict_broadcast_to_shape(pre_scale, T.shape(x))
             check_scale(self, scale, x, pre_scale, expected_y, expected_log_det)
 
     def test_SigmoidScale(self):
@@ -542,7 +540,7 @@ def test_SigmoidScale(self):
                   T.random.randn([2, 1, 1]),
                   T.random.randn([2, 3, 4])]:
             expected_y = x * T.nn.sigmoid(pre_scale + pre_scale_bias)
-            expected_log_det = T.broadcast_to_shape(
+            expected_log_det = T.strict_broadcast_to_shape(
                 T.nn.log_sigmoid(pre_scale + pre_scale_bias), T.shape(x))
             check_scale(self, scale, x, pre_scale, expected_y, expected_log_det)

@@ -557,7 +555,7 @@ def test_LinearScale(self):
                   T.random.randn([2, 1, 1]),
                   T.random.randn([2, 3, 4])]:
             expected_y = x * pre_scale
-            expected_log_det = T.broadcast_to_shape(
+            expected_log_det = T.strict_broadcast_to_shape(
                 T.log(T.abs(pre_scale)), T.shape(x))
             check_scale(self, scale, x, pre_scale, expected_y, expected_log_det)

