From 79ad668a4ce0b41a7b96c8b8c56e8fc85b091b3f Mon Sep 17 00:00:00 2001 From: Pavel Belevich Date: Thu, 10 Dec 2020 19:26:54 -0500 Subject: [PATCH] Allow decryption output tensor to be less than input(skipping padding) [ghstack-poisoned] --- test/test_csprng.py | 12 +++++------- torchcsprng/csrc/kernels_body.inc | 3 ++- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/test/test_csprng.py b/test/test_csprng.py index b78087c..f50c66a 100644 --- a/test/test_csprng.py +++ b/test/test_csprng.py @@ -377,11 +377,12 @@ def create_aes(m, k): for initial_size in [0, 4, 8, 15, 16, 23, 42]: initial = torch.empty(initial_size, dtype=initial_dtype).random_() initial_np = initial.numpy().view(np.int8) + initial_size_bytes = initial_size * sizeof(initial_dtype) for encrypted_dtype in self.all_dtypes: - encrypted_size = (initial_size * sizeof(initial_dtype) + block_size_bytes - 1) // block_size_bytes * block_size_bytes // sizeof(encrypted_dtype) + encrypted_size = (initial_size_bytes + block_size_bytes - 1) // block_size_bytes * block_size_bytes // sizeof(encrypted_dtype) encrypted = torch.zeros(encrypted_size, dtype=encrypted_dtype) for decrypted_dtype in self.all_dtypes: - decrypted_size = (encrypted_size * sizeof(encrypted_dtype) + block_size_bytes - 1) // block_size_bytes * block_size_bytes // sizeof(decrypted_dtype) + decrypted_size = (initial_size_bytes + sizeof(decrypted_dtype) - 1) // sizeof(decrypted_dtype) decrypted = torch.zeros(decrypted_size, dtype=decrypted_dtype) for mode in ["ecb", "ctr"]: for device in self.all_devices: @@ -399,16 +400,13 @@ def create_aes(m, k): self.assertTrue(np.array_equal(encrypted_np, encrypted_expected)) csprng.decrypt(encrypted, decrypted, key, "aes128", mode) - decrypted_np = decrypted.cpu().numpy().view(np.int8) + decrypted_np = decrypted.cpu().numpy().view(np.int8)[:initial_size_bytes] aes = create_aes(mode, key_np) - decrypted_expected = np.frombuffer(aes.decrypt(pad(encrypted_np.tobytes(), block_size_bytes)), dtype=np.int8) + decrypted_expected = np.frombuffer(aes.decrypt(pad(encrypted_np.tobytes(), block_size_bytes)), dtype=np.int8)[:initial_size_bytes] self.assertTrue(np.array_equal(decrypted_np, decrypted_expected)) - padding_size_bytes = initial_size * sizeof(initial_dtype) - decrypted_size * sizeof(decrypted_dtype) - if padding_size_bytes != 0: - decrypted_np = decrypted_np[:padding_size_bytes] self.assertTrue(np.array_equal(initial_np, decrypted_np)) if __name__ == '__main__': diff --git a/torchcsprng/csrc/kernels_body.inc b/torchcsprng/csrc/kernels_body.inc index dd7c02e..639976b 100644 --- a/torchcsprng/csrc/kernels_body.inc +++ b/torchcsprng/csrc/kernels_body.inc @@ -420,7 +420,8 @@ Tensor decrypt(Tensor input, Tensor output, Tensor key, const std::string& ciphe TORCH_CHECK(input.device() == output.device() && input.device() == key.device(), "input, output and key tensors must have the same device"); const auto output_size_bytes = output.numel() * output.itemsize(); const auto input_size_bytes = input.numel() * input.itemsize(); - TORCH_CHECK(output_size_bytes == input_size_bytes, "input and output tensors must have the same size in byte"); + const auto diff = input_size_bytes - output_size_bytes; + TORCH_CHECK(0 <= diff && diff < aes::block_t_size, "output tensor size in bytes must be less then or equal to input tensor size in bytes, the difference must be less than block size"); TORCH_CHECK(input_size_bytes % aes::block_t_size == 0, "input tensor size in bytes must divisible by cipher block size in bytes"); check_cipher(cipher, key); const auto key_bytes = reinterpret_cast(key.contiguous().data_ptr());