diff --git a/tensorkit/backend/pytorch_/core.py b/tensorkit/backend/pytorch_/core.py index 5096368..e9fac8e 100644 --- a/tensorkit/backend/pytorch_/core.py +++ b/tensorkit/backend/pytorch_/core.py @@ -52,9 +52,9 @@ # shape utils 'shape', 'rank', 'reshape', 'repeat', 'expand', 'squeeze', 'expand_dim', 'swap_axes', 'transpose', - 'get_broadcast_shape', 'broadcast_to_shape', 'broadcast_to', - 'explicit_broadcast', 'flatten_to_ndims', - 'unflatten_from_ndims', 'reshape_tail', + 'get_broadcast_shape', 'broadcast_to_shape', 'strict_broadcast_to_shape', + 'broadcast_to', 'strict_broadcast_to', 'explicit_broadcast', + 'flatten_to_ndims', 'unflatten_from_ndims', 'reshape_tail', # split / join / indexing / gathering ... 'index_select', 'concat', 'split', 'stack', 'unstack', 'slice', 'slice_axis', @@ -316,13 +316,14 @@ def as_tensor(data, if isinstance(data, Tensor): # input `data` may be `StochasticTensor`, `Tensor` or `numpy.ndarray` - kwargs = {} - if data.dtype != target_dtype: - kwargs['dtype'] = target_dtype - if str(data.device) != device: - kwargs['device'] = device - if kwargs: - data = data.to(**kwargs) + from_dev = str(data.device) + if data.dtype != target_dtype and from_dev != device: + data = data.to(device=device, dtype=target_dtype) + elif data.dtype != target_dtype: + data = data.to(target_dtype) + elif from_dev != device: + data = data.to(device=device) + if force_copy: data = data.clone() return data @@ -756,10 +757,16 @@ def broadcast_to_shape(input: Tensor, new_shape: List[int]) -> Tensor: output = input if list(output.shape) != new_shape: output = output + torch.zeros(new_shape, dtype=output.dtype, device=output.device) - if list(output.shape) != new_shape: - raise ValueError( - '`input` cannot be broadcast to `new_shape`: shape(input) {} ' - 'vs new_shape {}'.format(shape(input), new_shape)) + return output + + +@jit +def strict_broadcast_to_shape(input: Tensor, new_shape: List[int]) -> Tensor: + output = broadcast_to_shape(input, new_shape) + if list(output.shape) != new_shape: + raise ValueError( + '`input` cannot be broadcast to `new_shape`: shape(input) {} ' + 'vs new_shape {}'.format(shape(input), new_shape)) return output @@ -768,10 +775,16 @@ def broadcast_to(input: Tensor, target: Tensor) -> Tensor: output = input if output.shape != target.shape: output = output + torch.zeros(target.shape, dtype=output.dtype, device=output.device) - if output.shape != target.shape: - raise ValueError( - '`input` cannot be broadcast to `target`: shape(input) {} ' - 'vs shape(target) {}'.format(shape(input), shape(target))) + return output + + +@jit +def strict_broadcast_to(input: Tensor, target: Tensor) -> Tensor: + output = broadcast_to(input, target) + if output.shape != target.shape: + raise ValueError( + '`input` cannot be broadcast to `target`: shape(input) {} ' + 'vs shape(target) {}'.format(shape(input), shape(target))) return output diff --git a/tensorkit/backend/pytorch_/flows.py b/tensorkit/backend/pytorch_/flows.py index 431d57b..0c51671 100644 --- a/tensorkit/backend/pytorch_/flows.py +++ b/tensorkit/backend/pytorch_/flows.py @@ -148,6 +148,8 @@ def forward(self, the previous flow layer and this layer. 
""" if inverse: + if not self.explicitly_invertible: + raise RuntimeError('Flow is not explicitly invertible.') event_ndims = self.y_event_ndims else: event_ndims = self.x_event_ndims @@ -291,17 +293,10 @@ def forward(self, class SequentialFlow(Flow): - __constants__ = Flow.__constants__ + ('_chain', '_inverse_chain') + __constants__ = Flow.__constants__ + ('_chain',) _chain: ModuleList - # The inverse chain is provided, such that JIT support is still okay. - # TODO: This separated inverse chain will cause `state_dict()` to have - # duplicated weights. Deal with this issue. - _inverse_chain: ModuleList - - flatten_to_ndims: bool - def __init__(self, *flows: Union[Module, Sequence[Module]]): from tensorkit.layers import flatten_nested_layers @@ -331,11 +326,27 @@ def __init__(self, ) self._chain = ModuleList(flows) - if self.explicitly_invertible: - self._inverse_chain = ModuleList(reversed(flows)) - else: - self._inverse_chain = ModuleList([_NotInvertibleFlow()]) - self.flatten_to_ndims = bool(flatten_to_ndims) + + # The following method is not compiled by JIT, because: + # + # 1. PyTorch JIT does not support "self._chain[::-1]" yet, nor does it + # support subscription in ModuleList. + # 2. If we provide a separated "_inverse_chain", then it will cost much + # more time to compile the module by JIT, and will double the number + # of parameters returned from `state_dict()`. + @jit_ignore + def _call_chain(self, + input: Tensor, + input_log_det: Optional[Tensor], + inverse: bool, + compute_log_det: bool + ) -> Tuple[Tensor, Optional[Tensor]]: + output, output_log_det = input, input_log_det + chain = self._chain[::-1] if inverse else self._chain + for flow in chain: + output, output_log_det = flow( + output, output_log_det, inverse, compute_log_det) + return output, output_log_det def _transform(self, input: Tensor, @@ -346,21 +357,15 @@ def _transform(self, output, output_log_det = input, input_log_det event_ndims = self.y_event_ndims if inverse else self.x_event_ndims - if rank(output) > event_ndims: + if rank(output) > event_ndims + 1: output, batch_shape = flatten_to_ndims(output, event_ndims + 1) if output_log_det is not None: output_log_det = reshape(output_log_det, [-1]) else: batch_shape: Optional[List[int]] = None - if inverse: - for flow in self._inverse_chain: - output, output_log_det = flow( - output, output_log_det, True, compute_log_det) - else: - for flow in self._chain: - output, output_log_det = flow( - output, output_log_det, False, compute_log_det) + output, output_log_det = self._call_chain( + output, output_log_det, inverse, compute_log_det) if batch_shape is not None: output = unflatten_from_ndims(output, batch_shape) @@ -612,7 +617,7 @@ def _transform(self, weight, log_det = self.invertible_matrix( inverse=inverse, compute_log_det=compute_log_det) spatial_ndims = self.x_event_ndims - 1 - weight = reshape(weight, shape(weight) + [1] * spatial_ndims) + weight = torch.reshape(weight, weight.shape + (1,) * spatial_ndims) # compute the output output = self._affine_transform(input, weight) @@ -620,8 +625,9 @@ def _transform(self, # compute the log_det output_log_det = input_log_det if log_det is not None: - for axis in int_range(-spatial_ndims, 0): - log_det = log_det * float(input.shape[axis]) + log_det *= torch.prod( + torch.as_tensor(input.shape[input.dim() - spatial_ndims:], + dtype=log_det.dtype, device=log_det.device)) if input_log_det is not None: output_log_det = input_log_det + log_det else: diff --git a/tensorkit/backend/pytorch_/nn.py 
b/tensorkit/backend/pytorch_/nn.py index a35b1f7..85dcd6b 100644 --- a/tensorkit/backend/pytorch_/nn.py +++ b/tensorkit/backend/pytorch_/nn.py @@ -127,11 +127,16 @@ def cross_entropy_with_logits(logits: Tensor, reduction: str = 'none', # {'sum', 'mean' or 'none'} negative: bool = False) -> Tensor: if logits.shape[:-1] != labels.shape: - logits_shape = list(logits.shape) - labels_shape = list(labels.shape) - b_shape = get_broadcast_shape(logits_shape[:-1], labels_shape) - logits = broadcast_to_shape(logits, b_shape + logits_shape[-1:]) - labels = broadcast_to_shape(labels, b_shape) + labels = labels + torch.zeros( + logits.shape[:-1], dtype=labels.dtype, device=labels.device) + logits = logits + torch.zeros( + labels.shape + logits.shape[-1:], dtype=logits.dtype, device=logits.device) + + # logits_shape = list(logits.shape) + # labels_shape = list(labels.shape) + # b_shape = get_broadcast_shape(logits_shape[:-1], labels_shape) + # logits = strict_broadcast_to_shape(logits, b_shape + logits_shape[-1:]) + # labels = strict_broadcast_to_shape(labels, b_shape) if len(logits.shape) < 2 or len(labels.shape) < 1: raise ValueError('`logits` must be at least 2d, and `labels` must ' diff --git a/tensorkit/distributions/uniform.py b/tensorkit/distributions/uniform.py index fca142f..3a60b7e 100644 --- a/tensorkit/distributions/uniform.py +++ b/tensorkit/distributions/uniform.py @@ -147,14 +147,15 @@ def _log_prob(self, reduce_ndims: int) -> T.Tensor: low = self.low if self.low is not None else 0. high = self.high if self.high is not None else 1. - b_shape = T.get_broadcast_shape(T.shape(given), self.value_shape) log_pdf = self._get_neg_log_high_minus_low() log_pdf = log_pdf_mask( T.logical_and(low <= given, given <= high), log_pdf, self.log_zero, ) - log_pdf = T.broadcast_to_shape(log_pdf, b_shape) + + # broadcast against given if required + log_pdf = T.broadcast_to(log_pdf, given) if reduce_ndims > 0: log_pdf = T.reduce_sum(log_pdf, axis=list(range(-reduce_ndims, 0))) return log_pdf diff --git a/tensorkit/flows/rearrangement.py b/tensorkit/flows/rearrangement.py index 7d687b5..ae54275 100644 --- a/tensorkit/flows/rearrangement.py +++ b/tensorkit/flows/rearrangement.py @@ -65,10 +65,8 @@ def _transform(self, inverse: bool, compute_log_det: bool ) -> Tuple[Tensor, Optional[Tensor]]: - if inverse: - output = index_select(input, self.inv_permutation, axis=self.axis) - else: - output = index_select(input, self.permutation, axis=self.axis) + perm = self.inv_permutation if inverse else self.permutation + output = index_select(input, perm, axis=self.axis) output_log_det = input_log_det if compute_log_det and output_log_det is None: output_log_det = float_scalar_like(0., input) diff --git a/tensorkit/flows/reshape_.py b/tensorkit/flows/reshape_.py index b2a373a..bf6229b 100644 --- a/tensorkit/flows/reshape_.py +++ b/tensorkit/flows/reshape_.py @@ -77,10 +77,14 @@ def _transform(self, inverse: bool, compute_log_det: bool ) -> Tuple[Tensor, Optional[Tensor]]: + target_shape: List[int] = [] if inverse: - output = reshape_tail(input, self.y_event_ndims, self.x_event_shape) + source_ndims = self.y_event_ndims + target_shape.extend(self.x_event_shape) else: - output = reshape_tail(input, self.x_event_ndims, self.y_event_shape) + source_ndims = self.x_event_ndims + target_shape.extend(self.y_event_shape) + output = reshape_tail(input, source_ndims, target_shape) output_log_det = input_log_det if compute_log_det and output_log_det is None: diff --git a/tensorkit/flows/split_.py b/tensorkit/flows/split_.py index 
6cf9111..7ee94b4 100644 --- a/tensorkit/flows/split_.py +++ b/tensorkit/flows/split_.py @@ -131,28 +131,28 @@ def _transform(self, inverse: bool, compute_log_det: bool ) -> Tuple[Tensor, Optional[Tensor]]: + sections: List[int] = [] if inverse: - out_left, out_right = split( - input, sections=self.y_sections, axis=self.y_axis) + sections.extend(self.y_sections) + axis = self.y_axis join_axis = self.x_axis else: - out_left, out_right = split( - input, sections=self.x_sections, axis=self.x_axis) + sections.extend(self.x_sections) + axis = self.x_axis join_axis = self.y_axis - # apply the left transformation + out_left, out_right = split(input, sections=sections, axis=axis) out_left, output_log_det = self.left( input=out_left, input_log_det=input_log_det, inverse=inverse, compute_log_det=compute_log_det, ) - if self.right is not None: out_right, output_log_det = self.right( input=out_right, input_log_det=output_log_det, inverse=inverse, compute_log_det=compute_log_det, ) - output = concat([out_left, out_right], axis=join_axis) + output = concat([out_left, out_right], axis=join_axis) return output, output_log_det diff --git a/tensorkit/tensor/random_extras.py b/tensorkit/tensor/random_extras.py index 6e75490..4227d77 100644 --- a/tensorkit/tensor/random_extras.py +++ b/tensorkit/tensor/random_extras.py @@ -342,7 +342,7 @@ def discretized_logistic_log_prob(given: Tensor, if min_val is not None and max_val is not None: if biased_edges: # broadcasted given, shape == x_mid - broadcast_given = broadcast_to_shape(given, shape(x_low)) + broadcast_given = broadcast_to(given, x_low) # the left-edge bin case # log(sigmoid(x_high) - sigmoid(-infinity)) diff --git a/tests/flows/test_core.py b/tests/flows/test_core.py index 952b13f..440638c 100644 --- a/tests/flows/test_core.py +++ b/tests/flows/test_core.py @@ -124,11 +124,9 @@ def test_call(self): # test output_log_det shape error flow = tk.layers.jit_compile(_MyBadFlow()) - with pytest.raises(Exception, - match='The shape of `output_log_det` is not expected'): + with pytest.raises(Exception, match='(shape|size)'): _ = flow(x) - with pytest.raises(Exception, - match='The shape of `output_log_det` is not expected'): + with pytest.raises(Exception, match='(shape|size)'): _ = flow(x, inverse=True) @@ -278,7 +276,7 @@ def test_call(self): flow = tk.layers.jit_compile(SequentialFlow(flows)) with pytest.raises(Exception, - match='Not an explicitly invertible flow'): + match='Flow is not explicitly invertible'): _ = flow(x, inverse=True) @@ -521,7 +519,7 @@ def test_ExpScale(self): T.random.randn([2, 1, 1]), T.random.randn([2, 3, 4])]: expected_y = x * T.exp(pre_scale) - expected_log_det = T.broadcast_to_shape(pre_scale, T.shape(x)) + expected_log_det = T.strict_broadcast_to_shape(pre_scale, T.shape(x)) check_scale(self, scale, x, pre_scale, expected_y, expected_log_det) def test_SigmoidScale(self): @@ -542,7 +540,7 @@ def test_SigmoidScale(self): T.random.randn([2, 1, 1]), T.random.randn([2, 3, 4])]: expected_y = x * T.nn.sigmoid(pre_scale + pre_scale_bias) - expected_log_det = T.broadcast_to_shape( + expected_log_det = T.strict_broadcast_to_shape( T.nn.log_sigmoid(pre_scale + pre_scale_bias), T.shape(x)) check_scale(self, scale, x, pre_scale, expected_y, expected_log_det) @@ -557,7 +555,7 @@ def test_LinearScale(self): T.random.randn([2, 1, 1]), T.random.randn([2, 3, 4])]: expected_y = x * pre_scale - expected_log_det = T.broadcast_to_shape( + expected_log_det = T.strict_broadcast_to_shape( T.log(T.abs(pre_scale)), T.shape(x)) check_scale(self, scale, x, 
pre_scale, expected_y, expected_log_det) diff --git a/tests/tensor/test_core.py b/tests/tensor/test_core.py index 4c5078e..8af2bcb 100644 --- a/tests/tensor/test_core.py +++ b/tests/tensor/test_core.py @@ -682,40 +682,50 @@ def test_shape_utils(self): with pytest.raises(Exception, match='cannot broadcast'): _ = T.get_broadcast_shape([2], [3]) - # test broadcast_to_shape + # test broadcast_to x = np.random.randn(1, 2, 1) t = T.as_tensor(x) g = lambda shape: T.ones(shape, dtype=T.boolean) - t2 = T.broadcast_to(t, g([4, 5, 2, 1])) - self.assertEqual(T.shape(t2), [4, 5, 2, 1]) - assert_equal(t2, np.tile(x.reshape([1, 1, 2, 1]), [4, 5, 1, 1])) + for fn in (T.broadcast_to, T.strict_broadcast_to): + t2 = fn(t, g([4, 5, 2, 1])) + self.assertEqual(T.shape(t2), [4, 5, 2, 1]) + assert_equal(t2, np.tile(x.reshape([1, 1, 2, 1]), [4, 5, 1, 1])) - with pytest.raises(Exception, match='(shape|size)'): - _ = T.broadcast_to(t, g([2, 5])) + with pytest.raises(Exception, match='(shape|size)'): + _ = fn(t, g([1, 5, 1])) + + assert_equal( + T.broadcast_to(t, g([2, 5])), + np.tile(x.reshape([1, 2, 1]), [1, 1, 5])) with pytest.raises(Exception, match='(shape|size)'): - _ = T.broadcast_to(t, g([1, 1, 1])) + _ = T.strict_broadcast_to(t, g([2, 5])) with pytest.raises(Exception, match='(shape|size)'): - _ = T.broadcast_to(t, g([1, 5, 1])) + _ = T.strict_broadcast_to(t, g([1, 1, 1])) # test broadcast_to_shape x = np.random.randn(1, 2, 1) t = T.as_tensor(x) - t2 = T.broadcast_to_shape(t, [4, 5, 2, 1]) - self.assertEqual(T.shape(t2), [4, 5, 2, 1]) - assert_equal(t2, np.tile(x.reshape([1, 1, 2, 1]), [4, 5, 1, 1])) + for fn in (T.broadcast_to_shape, T.strict_broadcast_to_shape): + t2 = fn(t, [4, 5, 2, 1]) + self.assertEqual(T.shape(t2), [4, 5, 2, 1]) + assert_equal(t2, np.tile(x.reshape([1, 1, 2, 1]), [4, 5, 1, 1])) - with pytest.raises(Exception, match='(shape|size)'): - _ = T.broadcast_to_shape(t, [2, 5]) + with pytest.raises(Exception, match='(shape|size)'): + _ = fn(t, [1, 5, 1]) + + assert_equal( + T.broadcast_to_shape(t, [2, 5]), + np.tile(x.reshape([1, 2, 1]), [1, 1, 5])) with pytest.raises(Exception, match='(shape|size)'): - _ = T.broadcast_to_shape(t, [1, 1, 1]) + _ = T.strict_broadcast_to_shape(t, [2, 5]) with pytest.raises(Exception, match='(shape|size)'): - _ = T.broadcast_to_shape(t, [1, 5, 1]) + _ = T.strict_broadcast_to_shape(t, [1, 1, 1]) # test explicit_broadcast def explicit_broadcast(x, y):
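
Notes on the changes above.

The core.py hunk splits broadcasting into a permissive and a strict form: `broadcast_to_shape` / `broadcast_to` now rely only on adding a zero tensor, which broadcasts both operands and may therefore produce a shape larger than the one requested, while the new `strict_broadcast_to_shape` / `strict_broadcast_to` wrappers keep the old exact-shape check. A minimal standalone sketch of that split in plain PyTorch (the `_sketch` names are hypothetical, not part of the tensorkit API):

    from typing import List

    import torch


    def broadcast_to_shape_sketch(input: torch.Tensor, new_shape: List[int]) -> torch.Tensor:
        # Permissive: `x + zeros(new_shape)` broadcasts *both* operands, so the
        # result may be larger than `new_shape` (e.g. [1, 2, 1] against [2, 5]).
        output = input
        if list(output.shape) != new_shape:
            output = output + torch.zeros(new_shape, dtype=output.dtype, device=output.device)
        return output


    def strict_broadcast_to_shape_sketch(input: torch.Tensor, new_shape: List[int]) -> torch.Tensor:
        # Strict: additionally require the output to match `new_shape` exactly,
        # mirroring the check that used to be inlined in `broadcast_to_shape`.
        output = broadcast_to_shape_sketch(input, new_shape)
        if list(output.shape) != new_shape:
            raise ValueError(
                '`input` cannot be broadcast to `new_shape`: shape(input) {} '
                'vs new_shape {}'.format(list(input.shape), new_shape))
        return output


    t = torch.randn(1, 2, 1)
    print(broadcast_to_shape_sketch(t, [2, 5]).shape)               # torch.Size([1, 2, 5])
    print(strict_broadcast_to_shape_sketch(t, [4, 5, 2, 1]).shape)  # torch.Size([4, 5, 2, 1])
    strict_broadcast_to_shape_sketch(t, [2, 5])                     # raises ValueError

This matches the updated tests in tests/tensor/test_core.py, where `T.broadcast_to_shape(t, [2, 5])` now succeeds with shape `[1, 2, 5]` while only the strict variant raises.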
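
The SequentialFlow change drops the duplicated `_inverse_chain` (which doubled every parameter in `state_dict()` and slowed JIT compilation) and instead walks the single `_chain` forwards or backwards inside a helper that is excluded from TorchScript, because scripted code cannot evaluate `self._chain[::-1]`. A simplified sketch of the same pattern, assuming plain `torch.jit.ignore` in place of tensorkit's `jit_ignore` and a made-up scale flow whose interface is reduced to `(x, log_det, inverse)`:

    import math
    from typing import Tuple

    import torch
    from torch import nn


    class _ScaleFlow(nn.Module):
        # Toy flow: y = scale * x over the last axis, log|det J| = D * log|scale|.
        def __init__(self, scale: float):
            super().__init__()
            self.scale = scale

        def forward(self, x: torch.Tensor, log_det: torch.Tensor, inverse: bool
                    ) -> Tuple[torch.Tensor, torch.Tensor]:
            delta = x.shape[-1] * math.log(abs(self.scale))
            if inverse:
                return x / self.scale, log_det - delta
            return x * self.scale, log_det + delta


    class SequentialFlowSketch(nn.Module):
        def __init__(self, *flows: nn.Module):
            super().__init__()
            # A single ModuleList: each parameter shows up exactly once in state_dict().
            self.chain = nn.ModuleList(flows)

        @torch.jit.ignore
        def _call_chain(self, x: torch.Tensor, log_det: torch.Tensor, inverse: bool
                        ) -> Tuple[torch.Tensor, torch.Tensor]:
            # Kept out of TorchScript: reversing a ModuleList is done eagerly here.
            chain = self.chain[::-1] if inverse else self.chain
            for flow in chain:
                x, log_det = flow(x, log_det, inverse)
            return x, log_det

        def forward(self, x: torch.Tensor, log_det: torch.Tensor, inverse: bool = False
                    ) -> Tuple[torch.Tensor, torch.Tensor]:
            return self._call_chain(x, log_det, inverse)


    flow = SequentialFlowSketch(_ScaleFlow(2.0), _ScaleFlow(0.5))
    x = torch.randn(4, 3)
    y, ld = flow(x, torch.zeros(4))
    x_back, _ = flow(y, ld, inverse=True)   # reversed traversal recovers x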
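
The nn.py hunk applies the same zero-tensor trick inside `cross_entropy_with_logits`: `labels` is first broadcast against the batch part of `logits`, then `logits` is broadcast against the (possibly enlarged) labels plus its class axis, replacing the `get_broadcast_shape` / `broadcast_to_shape` pair. A small shape-only illustration in plain PyTorch (the shapes are made up for the example):

    import torch

    logits = torch.randn(5, 1, 7)       # batch dims (5, 1), 7 classes
    labels = torch.randint(0, 7, (3,))  # broadcasts against the batch dims

    if logits.shape[:-1] != labels.shape:
        labels = labels + torch.zeros(logits.shape[:-1], dtype=labels.dtype,
                                      device=labels.device)
        logits = logits + torch.zeros(labels.shape + logits.shape[-1:],
                                      dtype=logits.dtype, device=logits.device)

    print(labels.shape)  # torch.Size([5, 3])
    print(logits.shape)  # torch.Size([5, 3, 7])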