From a632bf4a5cc323f4d4bc42b1ca4286ff29752f9f Mon Sep 17 00:00:00 2001 From: Brian Knott Date: Wed, 23 Jun 2021 10:51:06 -0700 Subject: [PATCH] Handle all scaling logic in AST Summary: Remove the need for the `_scale` kwarg in logic functions by handling all encoding logic in `ArithmeticSharedTensor` Reviewed By: lvdmaaten Differential Revision: D29064976 fbshipit-source-id: 9f0f73b979422cfc979f04069868cd34f2fd3452 --- crypten/common/approximations.py | 4 +- crypten/mpc/max_helper.py | 2 +- crypten/mpc/mpc.py | 59 +++++++++++++--------------- crypten/mpc/primitives/arithmetic.py | 29 ++++++++++++++ crypten/mpc/primitives/beaver.py | 25 ++++++++++-- crypten/nn/module.py | 2 +- test/test_distributions.py | 2 +- test/test_mpc.py | 39 ++++++++---------- 8 files changed, 101 insertions(+), 61 deletions(-) diff --git a/crypten/common/approximations.py b/crypten/common/approximations.py index ae0d5c03..e8918a45 100644 --- a/crypten/common/approximations.py +++ b/crypten/common/approximations.py @@ -203,7 +203,7 @@ def reciprocal(self, input_in_01=False): method = config.reciprocal_method if not config.reciprocal_all_pos: - sgn = self.sign(_scale=False) + sgn = self.sign() pos = sgn * self with ConfigManager("reciprocal_all_pos", True): return sgn * reciprocal(pos) @@ -348,7 +348,7 @@ def sigmoid(self): tanh_approx = tanh(self.div(2)) return tanh_approx.div(2) + 0.5 elif method == "reciprocal": - ltz = self._ltz(_scale=False) + ltz = self._ltz() sign = 1 - 2 * ltz pos_input = self.mul(sign) diff --git a/crypten/mpc/max_helper.py b/crypten/mpc/max_helper.py index 7edfbcb3..d85643d3 100644 --- a/crypten/mpc/max_helper.py +++ b/crypten/mpc/max_helper.py @@ -28,7 +28,7 @@ def _argmax_helper_pairwise(enc_tensor, dim=None): # Use either prod or sum & comparison depending on size if row_length - 1 < torch.iinfo(torch.long).bits * 2: - pairwise_comparisons = a.ge(b, _scale=False) + pairwise_comparisons = a.ge(b) result = pairwise_comparisons.prod(0) result.share *= enc_tensor.encoder._scale result.encoder = enc_tensor.encoder diff --git a/crypten/mpc/mpc.py b/crypten/mpc/mpc.py index 5903598e..5fe2775e 100644 --- a/crypten/mpc/mpc.py +++ b/crypten/mpc/mpc.py @@ -413,56 +413,55 @@ def bernoulli(self): # Comparators @mode(Ptype.binary) - def _ltz(self, _scale=True): + def _ltz(self): """Returns 1 for elements that are < 0 and 0 otherwise""" shift = torch.iinfo(torch.long).bits - 1 - result = (self >> shift).to(Ptype.arithmetic, bits=1) - if _scale: - return result * result.encoder._scale - else: - result.encoder._scale = 1 - return result + + precision = 0 if self.encoder.scale == 1 else None + result = (self >> shift).to(Ptype.arithmetic, precision=precision, bits=1) + result.encoder._scale = 1 + return result @mode(Ptype.arithmetic) - def ge(self, y, _scale=True): + def ge(self, y): """Returns self >= y""" - return 1 - self.lt(y, _scale=_scale) + return 1 - self.lt(y) @mode(Ptype.arithmetic) - def gt(self, y, _scale=True): + def gt(self, y): """Returns self > y""" - return (-self + y)._ltz(_scale=_scale) + return (-self + y)._ltz() @mode(Ptype.arithmetic) - def le(self, y, _scale=True): + def le(self, y): """Returns self <= y""" - return 1 - self.gt(y, _scale=_scale) + return 1 - self.gt(y) @mode(Ptype.arithmetic) - def lt(self, y, _scale=True): + def lt(self, y): """Returns self < y""" - return (self - y)._ltz(_scale=_scale) + return (self - y)._ltz() @mode(Ptype.arithmetic) - def eq(self, y, _scale=True): + def eq(self, y): """Returns self == y""" if comm.get().get_world_size() == 2: - return (self - y)._eqz_2PC(_scale=_scale) + return (self - y)._eqz_2PC() - return 1 - self.ne(y, _scale=_scale) + return 1 - self.ne(y) @mode(Ptype.arithmetic) - def ne(self, y, _scale=True): + def ne(self, y): """Returns self != y""" if comm.get().get_world_size() == 2: - return 1 - self.eq(y, _scale=_scale) + return 1 - self.eq(y) difference = self - y difference.share = torch_stack([difference.share, -(difference.share)]) - return difference._ltz(_scale=_scale).sum(0) + return difference._ltz().sum(0) @mode(Ptype.arithmetic) - def _eqz_2PC(self, _scale=True): + def _eqz_2PC(self): """Returns self == 0""" # Create BinarySharedTensors from shares x0 = MPCTensor(self.share, src=0, ptype=Ptype.binary) @@ -470,30 +469,28 @@ def _eqz_2PC(self, _scale=True): # Perform equality testing using binary shares x0._tensor = x0._tensor.eq(x1._tensor) - x0.encoder = x0.encoder if _scale else self.encoder + x0.encoder = self.encoder # Convert to Arithmetic sharing result = x0.to(Ptype.arithmetic, bits=1) - - if not _scale: - result.encoder._scale = 1 + result.encoder._scale = 1 return result @mode(Ptype.arithmetic) - def sign(self, _scale=True): + def sign(self): """Computes the sign value of a tensor (0 is considered positive)""" - return 1 - 2 * self._ltz(_scale=_scale) + return 1 - 2 * self._ltz() @mode(Ptype.arithmetic) def abs(self): """Computes the absolute value of a tensor""" - return self * self.sign(_scale=False) + return self * self.sign() @mode(Ptype.arithmetic) def relu(self): """Compute a Rectified Linear function on the input tensor.""" - return self * self.ge(0, _scale=False) + return self * self.ge(0) @mode(Ptype.arithmetic) def weighted_index(self, dim=None): @@ -521,7 +518,7 @@ def weighted_index(self, dim=None): ) r = MPCTensor.rand(max_weight.size(), device=self.device) * max_weight - gt = x.gt(r, _scale=False) + gt = x.gt(r) shifted = gt.roll(1, dims=dim) shifted.share.index_fill_(dim, torch.tensor(0, device=self.device), 0) diff --git a/crypten/mpc/primitives/arithmetic.py b/crypten/mpc/primitives/arithmetic.py index 0a46c07e..f59453e4 100644 --- a/crypten/mpc/primitives/arithmetic.py +++ b/crypten/mpc/primitives/arithmetic.py @@ -310,6 +310,30 @@ def get_plain_text(self, dst=None): return torch.empty(self.share.size()) return self.encoder.decode(self.reveal(dst=dst)) + def encode_(self, new_encoder): + """Rescales the input to a new encoding in-place""" + if self.encoder.scale == new_encoder.scale: + return self + elif self.encoder.scale < new_encoder.scale: + scale_factor = new_encoder.scale // self.encoder.scale + self.share *= scale_factor + else: + scale_factor = self.encoder.scale // new_encoder.scale + self = self.div_(scale_factor) + self.encoder = new_encoder + return self + + def encode(self, new_encoder): + """Rescales the input to a new encoding""" + return self.clone().encode_(new_encoder) + + def encode_as_(self, other): + """Rescales self to have the same encoding as other""" + return self.encode_(other.encoder) + + def encode_as(self, other): + return self.encode(other.encoder) + def _arithmetic_function_(self, y, op, *args, **kwargs): return self._arithmetic_function(y, op, inplace=True, *args, **kwargs) @@ -350,6 +374,11 @@ def _arithmetic_function(self, y, op, inplace=False, *args, **kwargs): # noqa:C result.share = getattr(torch, op)(result.share, y, *args, **kwargs) elif private: if additive_func: # ['add', 'sub', 'add_', 'sub_'] + # Re-encode if necessary: + if self.encoder.scale > y.encoder.scale: + y.encode_as_(result) + elif self.encoder.scale < y.encoder.scale: + result.encode_as_(y) result.share = getattr(result.share, op)(y.share) else: # ['mul', 'matmul', 'convNd', 'conv_transposeNd'] # NOTE: 'mul_' calls 'mul' here diff --git a/crypten/mpc/primitives/beaver.py b/crypten/mpc/primitives/beaver.py index 7da8b43f..21758b97 100644 --- a/crypten/mpc/primitives/beaver.py +++ b/crypten/mpc/primitives/beaver.py @@ -11,6 +11,22 @@ from crypten.common.util import count_wraps +class IgnoreEncodings: + """Context Manager to ignore tensor encodings""" + + def __init__(self, list_of_tensors): + self.list_of_tensors = list_of_tensors + self.encodings_cache = [tensor.encoder.scale for tensor in list_of_tensors] + + def __enter__(self): + for tensor in self.list_of_tensors: + tensor.encoder._scale = 1 + + def __exit__(self, exc_type, exc_value, exc_traceback): + for i, tensor in enumerate(self.list_of_tensors): + tensor.encoder._scale = self.encodings_cache[i] + + def __beaver_protocol(op, x, y, *args, **kwargs): """Performs Beaver protocol for additively secret-shared tensors x and y @@ -58,7 +74,8 @@ def __beaver_protocol(op, x, y, *args, **kwargs): raise ValueError("Beaver Triples verification failed!") # Vectorized reveal to reduce rounds of communication - epsilon, delta = ArithmeticSharedTensor.reveal_batch([x - a, y - b]) + with IgnoreEncodings([a, b, x, y]): + epsilon, delta = ArithmeticSharedTensor.reveal_batch([x - a, y - b]) # z = c + (a * delta) + (epsilon * b) + epsilon * delta c._tensor += getattr(torch, op)(epsilon, b._tensor, *args, **kwargs) @@ -103,7 +120,8 @@ def square(x): provider = crypten.mpc.get_default_provider() r, r2 = provider.square(x.size(), device=x.device) - epsilon = (x - r).reveal() + with IgnoreEncodings([x, r]): + epsilon = (x - r).reveal() return r2 + 2 * r * epsilon + epsilon * epsilon @@ -125,7 +143,8 @@ def wraps(x): beta_xr = theta_r.clone() beta_xr._tensor = count_wraps([x._tensor, r._tensor]) - z = x + r + with IgnoreEncodings([x, r]): + z = x + r theta_z = comm.get().gather(z._tensor, 0) theta_x = beta_xr - theta_r diff --git a/crypten/nn/module.py b/crypten/nn/module.py index a6bde281..bb09569c 100644 --- a/crypten/nn/module.py +++ b/crypten/nn/module.py @@ -340,7 +340,7 @@ def update_parameters(self, learning_rate, grad_threshold=100): # Compute based on square value since abs is more expensive square_threshold = grad_threshold * grad_threshold grad = param.grad.mul( - param.grad.square().lt(square_threshold, _scale=False) + param.grad.square().lt(square_threshold) ) else: grad = param.grad diff --git a/test/test_distributions.py b/test/test_distributions.py index f1ae158a..f532e1c0 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -40,7 +40,7 @@ def _check_distribution( sample.size() == size, "Incorrect size for %s distribution" % name ) - plain_sample = sample.get_plain_text() + plain_sample = sample.get_plain_text().float() mean = plain_sample.mean() var = plain_sample.var() self.assertTrue( diff --git a/test/test_mpc.py b/test/test_mpc.py index 9c735f72..58f3dd62 100644 --- a/test/test_mpc.py +++ b/test/test_mpc.py @@ -776,35 +776,30 @@ def test_relu(self): def test_comparators(self): """Test comparators (>, >=, <, <=, ==, !=)""" - for _scale in [False, True]: - for comp in ["gt", "ge", "lt", "le", "eq", "ne"]: - for tensor_type in [lambda x: x, MPCTensor]: - tensor1 = self._get_random_test_tensor(is_float=True) - tensor2 = self._get_random_test_tensor(is_float=True) + for comp in ["gt", "ge", "lt", "le", "eq", "ne"]: + for tensor_type in [lambda x: x, MPCTensor]: + tensor1 = self._get_random_test_tensor(is_float=True) + tensor2 = self._get_random_test_tensor(is_float=True) - encrypted_tensor1 = MPCTensor(tensor1) - encrypted_tensor2 = tensor_type(tensor2) + encrypted_tensor1 = MPCTensor(tensor1) + encrypted_tensor2 = tensor_type(tensor2) - reference = getattr(tensor1, comp)(tensor2).float() - encrypted_out = getattr(encrypted_tensor1, comp)( - encrypted_tensor2, _scale=_scale - ) + reference = getattr(tensor1, comp)(tensor2).float() + encrypted_out = getattr(encrypted_tensor1, comp)(encrypted_tensor2) - self._check(encrypted_out, reference, "%s comparator failed" % comp) + self._check(encrypted_out, reference, "%s comparator failed" % comp) - # Check deterministic example to guarantee all combinations - tensor1 = torch.tensor([2.0, 3.0, 1.0, 2.0, 2.0]) - tensor2 = torch.tensor([2.0, 2.0, 2.0, 3.0, 1.0]) + # Check deterministic example to guarantee all combinations + tensor1 = torch.tensor([2.0, 3.0, 1.0, 2.0, 2.0]) + tensor2 = torch.tensor([2.0, 2.0, 2.0, 3.0, 1.0]) - encrypted_tensor1 = MPCTensor(tensor1) - encrypted_tensor2 = tensor_type(tensor2) + encrypted_tensor1 = MPCTensor(tensor1) + encrypted_tensor2 = tensor_type(tensor2) - reference = getattr(tensor1, comp)(tensor2).float() - encrypted_out = getattr(encrypted_tensor1, comp)( - encrypted_tensor2, _scale=_scale - ) + reference = getattr(tensor1, comp)(tensor2).float() + encrypted_out = getattr(encrypted_tensor1, comp)(encrypted_tensor2) - self._check(encrypted_out, reference, "%s comparator failed" % comp) + self._check(encrypted_out, reference, "%s comparator failed" % comp) def test_max_min_pairwise(self): """Tests max and min for the deterministic constant (n^2) algorithm"""