Handle all scaling logic in AST
Summary: Remove the need for the `_scale` kwarg in logic functions by handling all encoding logic in `ArithmeticSharedTensor`

Reviewed By: lvdmaaten

Differential Revision: D29064976

fbshipit-source-id: 9f0f73b979422cfc979f04069868cd34f2fd3452
knottb authored and facebook-github-bot committed Jun 23, 2021
1 parent 81001e7 commit a632bf4
Showing 8 changed files with 101 additions and 61 deletions.
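For context on the API change: comparison and sign functions previously took a private `_scale` kwarg so internal callers could request an unscaled (raw 0/1) result; this commit removes that kwarg and lets `ArithmeticSharedTensor` track encodings itself. A minimal sketch of the resulting user-facing API (tensor values are illustrative):

import crypten
import torch

crypten.init()

x = crypten.cryptensor(torch.tensor([-1.0, 0.0, 2.0]))
y = crypten.cryptensor(torch.tensor([0.5, 0.0, 1.0]))

# Comparators no longer accept a `_scale` kwarg; the encoding of the
# 0/1 result is handled inside ArithmeticSharedTensor.
result = x.gt(y)
print(result.get_plain_text())  # tensor([0., 0., 1.])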
crypten/common/approximations.py (4 changes: 2 additions & 2 deletions)
@@ -203,7 +203,7 @@ def reciprocal(self, input_in_01=False):

     method = config.reciprocal_method
     if not config.reciprocal_all_pos:
-        sgn = self.sign(_scale=False)
+        sgn = self.sign()
         pos = sgn * self
         with ConfigManager("reciprocal_all_pos", True):
             return sgn * reciprocal(pos)
@@ -348,7 +348,7 @@ def sigmoid(self):
         tanh_approx = tanh(self.div(2))
         return tanh_approx.div(2) + 0.5
     elif method == "reciprocal":
-        ltz = self._ltz(_scale=False)
+        ltz = self._ltz()
         sign = 1 - 2 * ltz
 
         pos_input = self.mul(sign)
crypten/mpc/max_helper.py (2 changes: 1 addition & 1 deletion)
@@ -28,7 +28,7 @@ def _argmax_helper_pairwise(enc_tensor, dim=None):

     # Use either prod or sum & comparison depending on size
     if row_length - 1 < torch.iinfo(torch.long).bits * 2:
-        pairwise_comparisons = a.ge(b, _scale=False)
+        pairwise_comparisons = a.ge(b)
         result = pairwise_comparisons.prod(0)
         result.share *= enc_tensor.encoder._scale
         result.encoder = enc_tensor.encoder
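For context: `_argmax_helper_pairwise` multiplies many 0/1 comparison outcomes together with `prod`, so those bits must stay unscaled; multiplying k fixed-point encodings would accumulate a factor of scale**k. After this change `ge` returns a result whose encoder scale is 1, and the helper rescales once at the end (`result.share *= enc_tensor.encoder._scale`). A plaintext sketch of why the intermediate bits cannot carry the encoding (scale value is illustrative):

import math

scale = 2 ** 16
bits = [1, 1, 1]  # plaintext outcomes of pairwise comparisons

unscaled_prod = math.prod(bits)                    # stays 1
encoded_prod = math.prod(b * scale for b in bits)  # scale**3 = 2**48

print(unscaled_prod, encoded_prod)  # 1 281474976710656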
crypten/mpc/mpc.py (59 changes: 28 additions & 31 deletions)
@@ -413,87 +413,84 @@ def bernoulli(self):

     # Comparators
     @mode(Ptype.binary)
-    def _ltz(self, _scale=True):
+    def _ltz(self):
         """Returns 1 for elements that are < 0 and 0 otherwise"""
         shift = torch.iinfo(torch.long).bits - 1
-        result = (self >> shift).to(Ptype.arithmetic, bits=1)
-        if _scale:
-            return result * result.encoder._scale
-        else:
-            result.encoder._scale = 1
-            return result
+        precision = 0 if self.encoder.scale == 1 else None
+        result = (self >> shift).to(Ptype.arithmetic, precision=precision, bits=1)
+        result.encoder._scale = 1
+        return result
 
     @mode(Ptype.arithmetic)
-    def ge(self, y, _scale=True):
+    def ge(self, y):
         """Returns self >= y"""
-        return 1 - self.lt(y, _scale=_scale)
+        return 1 - self.lt(y)
 
     @mode(Ptype.arithmetic)
-    def gt(self, y, _scale=True):
+    def gt(self, y):
         """Returns self > y"""
-        return (-self + y)._ltz(_scale=_scale)
+        return (-self + y)._ltz()
 
     @mode(Ptype.arithmetic)
-    def le(self, y, _scale=True):
+    def le(self, y):
         """Returns self <= y"""
-        return 1 - self.gt(y, _scale=_scale)
+        return 1 - self.gt(y)
 
     @mode(Ptype.arithmetic)
-    def lt(self, y, _scale=True):
+    def lt(self, y):
         """Returns self < y"""
-        return (self - y)._ltz(_scale=_scale)
+        return (self - y)._ltz()
 
     @mode(Ptype.arithmetic)
-    def eq(self, y, _scale=True):
+    def eq(self, y):
         """Returns self == y"""
         if comm.get().get_world_size() == 2:
-            return (self - y)._eqz_2PC(_scale=_scale)
+            return (self - y)._eqz_2PC()
 
-        return 1 - self.ne(y, _scale=_scale)
+        return 1 - self.ne(y)
 
     @mode(Ptype.arithmetic)
-    def ne(self, y, _scale=True):
+    def ne(self, y):
         """Returns self != y"""
         if comm.get().get_world_size() == 2:
-            return 1 - self.eq(y, _scale=_scale)
+            return 1 - self.eq(y)
 
         difference = self - y
         difference.share = torch_stack([difference.share, -(difference.share)])
-        return difference._ltz(_scale=_scale).sum(0)
+        return difference._ltz().sum(0)
 
     @mode(Ptype.arithmetic)
-    def _eqz_2PC(self, _scale=True):
+    def _eqz_2PC(self):
         """Returns self == 0"""
         # Create BinarySharedTensors from shares
         x0 = MPCTensor(self.share, src=0, ptype=Ptype.binary)
         x1 = MPCTensor(-self.share, src=1, ptype=Ptype.binary)
 
         # Perform equality testing using binary shares
         x0._tensor = x0._tensor.eq(x1._tensor)
-        x0.encoder = x0.encoder if _scale else self.encoder
+        x0.encoder = self.encoder
 
         # Convert to Arithmetic sharing
         result = x0.to(Ptype.arithmetic, bits=1)
 
-        if not _scale:
-            result.encoder._scale = 1
+        result.encoder._scale = 1
 
         return result
 
     @mode(Ptype.arithmetic)
-    def sign(self, _scale=True):
+    def sign(self):
         """Computes the sign value of a tensor (0 is considered positive)"""
-        return 1 - 2 * self._ltz(_scale=_scale)
+        return 1 - 2 * self._ltz()
 
     @mode(Ptype.arithmetic)
     def abs(self):
         """Computes the absolute value of a tensor"""
-        return self * self.sign(_scale=False)
+        return self * self.sign()
 
     @mode(Ptype.arithmetic)
     def relu(self):
         """Compute a Rectified Linear function on the input tensor."""
-        return self * self.ge(0, _scale=False)
+        return self * self.ge(0)
 
     @mode(Ptype.arithmetic)
     def weighted_index(self, dim=None):
@@ -521,7 +518,7 @@ def weighted_index(self, dim=None):
         )
         r = MPCTensor.rand(max_weight.size(), device=self.device) * max_weight
 
-        gt = x.gt(r, _scale=False)
+        gt = x.gt(r)
         shifted = gt.roll(1, dims=dim)
         shifted.share.index_fill_(dim, torch.tensor(0, device=self.device), 0)

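For context: `_ltz` relies on shares being signed 64-bit integers, so bit 63 is set exactly for negative values; shifting right by `bits - 1` exposes that bit, and the binary-to-arithmetic conversion turns it into a secret-shared 0/1 tensor. A plaintext torch sketch of the same bit trick (illustrative only; in CrypTen the shift and conversion happen on secret shares):

import torch

bits = torch.iinfo(torch.long).bits  # 64

def ltz_plaintext(x):
    """Plaintext analogue of _ltz: 1 where x < 0, else 0."""
    # Arithmetic right shift yields -1 for negatives and 0 otherwise;
    # masking the low bit maps that to 1 / 0.
    return (x >> (bits - 1)) & 1

x = torch.tensor([-5, 0, 7])
print(ltz_plaintext(x))          # tensor([1, 0, 0])
print(1 - 2 * ltz_plaintext(x))  # sign(): tensor([-1,  1,  1])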
crypten/mpc/primitives/arithmetic.py (29 changes: 29 additions & 0 deletions)
@@ -310,6 +310,30 @@ def get_plain_text(self, dst=None):
             return torch.empty(self.share.size())
         return self.encoder.decode(self.reveal(dst=dst))
 
+    def encode_(self, new_encoder):
+        """Rescales the input to a new encoding in-place"""
+        if self.encoder.scale == new_encoder.scale:
+            return self
+        elif self.encoder.scale < new_encoder.scale:
+            scale_factor = new_encoder.scale // self.encoder.scale
+            self.share *= scale_factor
+        else:
+            scale_factor = self.encoder.scale // new_encoder.scale
+            self = self.div_(scale_factor)
+        self.encoder = new_encoder
+        return self
+
+    def encode(self, new_encoder):
+        """Rescales the input to a new encoding"""
+        return self.clone().encode_(new_encoder)
+
+    def encode_as_(self, other):
+        """Rescales self to have the same encoding as other"""
+        return self.encode_(other.encoder)
+
+    def encode_as(self, other):
+        return self.encode(other.encoder)
+
     def _arithmetic_function_(self, y, op, *args, **kwargs):
         return self._arithmetic_function(y, op, inplace=True, *args, **kwargs)

@@ -350,6 +374,11 @@ def _arithmetic_function(self, y, op, inplace=False, *args, **kwargs):  # noqa:C
             result.share = getattr(torch, op)(result.share, y, *args, **kwargs)
         elif private:
             if additive_func:  # ['add', 'sub', 'add_', 'sub_']
+                # Re-encode if necessary:
+                if self.encoder.scale > y.encoder.scale:
+                    y.encode_as_(result)
+                elif self.encoder.scale < y.encoder.scale:
+                    result.encode_as_(y)
                 result.share = getattr(result.share, op)(y.share)
             else:  # ['mul', 'matmul', 'convNd', 'conv_transposeNd']
                 # NOTE: 'mul_' calls 'mul' here
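For context: `encode_` assumes the two encoders' scales are powers of the same base, so one scale always divides the other evenly; upscaling is a cheap local multiply of the raw share, while downscaling goes through `div_` (a secure truncation). A standalone sketch of the integer arithmetic (hypothetical helper, not part of this diff):

def rescale(raw, old_scale, new_scale):
    """Sketch of encode_'s arithmetic on a raw fixed-point integer."""
    if old_scale == new_scale:
        return raw
    if old_scale < new_scale:
        return raw * (new_scale // old_scale)  # exact: scales divide evenly
    return raw // (old_scale // new_scale)     # truncating division

# 2.5 encoded at scale 2**16, re-encoded to scale 2**20:
raw = int(2.5 * 2 ** 16)
print(rescale(raw, 2 ** 16, 2 ** 20) / 2 ** 20)  # 2.5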
crypten/mpc/primitives/beaver.py (25 changes: 22 additions & 3 deletions)
@@ -11,6 +11,22 @@
 from crypten.common.util import count_wraps
 
 
+class IgnoreEncodings:
+    """Context Manager to ignore tensor encodings"""
+
+    def __init__(self, list_of_tensors):
+        self.list_of_tensors = list_of_tensors
+        self.encodings_cache = [tensor.encoder.scale for tensor in list_of_tensors]
+
+    def __enter__(self):
+        for tensor in self.list_of_tensors:
+            tensor.encoder._scale = 1
+
+    def __exit__(self, exc_type, exc_value, exc_traceback):
+        for i, tensor in enumerate(self.list_of_tensors):
+            tensor.encoder._scale = self.encodings_cache[i]
+
+
 def __beaver_protocol(op, x, y, *args, **kwargs):
     """Performs Beaver protocol for additively secret-shared tensors x and y
@@ -58,7 +74,8 @@ def __beaver_protocol(op, x, y, *args, **kwargs):
         raise ValueError("Beaver Triples verification failed!")
 
     # Vectorized reveal to reduce rounds of communication
-    epsilon, delta = ArithmeticSharedTensor.reveal_batch([x - a, y - b])
+    with IgnoreEncodings([a, b, x, y]):
+        epsilon, delta = ArithmeticSharedTensor.reveal_batch([x - a, y - b])
 
     # z = c + (a * delta) + (epsilon * b) + epsilon * delta
     c._tensor += getattr(torch, op)(epsilon, b._tensor, *args, **kwargs)
@@ -103,7 +120,8 @@ def square(x):
     provider = crypten.mpc.get_default_provider()
     r, r2 = provider.square(x.size(), device=x.device)
 
-    epsilon = (x - r).reveal()
+    with IgnoreEncodings([x, r]):
+        epsilon = (x - r).reveal()
     return r2 + 2 * r * epsilon + epsilon * epsilon


@@ -125,7 +143,8 @@ def wraps(x):
     beta_xr = theta_r.clone()
     beta_xr._tensor = count_wraps([x._tensor, r._tensor])
 
-    z = x + r
+    with IgnoreEncodings([x, r]):
+        z = x + r
     theta_z = comm.get().gather(z._tensor, 0)
     theta_x = beta_xr - theta_r

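For context: now that `_arithmetic_function` re-encodes operands whose scales differ, the Beaver masking steps need an escape hatch, since triple shares like `a`, `b`, and `r` are raw ring elements with scale 1 while `x` and `y` carry the fixed-point scale. `IgnoreEncodings` temporarily forces every listed tensor's scale to 1 so `x - a` is plain share subtraction, then restores the cached scales on exit. A toy sketch of the save/restore behavior (stand-in classes; assumes the `IgnoreEncodings` class added above is in scope):

class FakeEncoder:
    def __init__(self, scale):
        self._scale = scale

    @property
    def scale(self):
        return self._scale

class FakeTensor:
    def __init__(self, scale):
        self.encoder = FakeEncoder(scale)

x, a = FakeTensor(2 ** 16), FakeTensor(1)

with IgnoreEncodings([x, a]):
    # Both scales are 1 here, so no re-encoding is triggered.
    assert x.encoder.scale == a.encoder.scale == 1
assert x.encoder.scale == 2 ** 16  # original encoding restored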
crypten/nn/module.py (2 changes: 1 addition & 1 deletion)
@@ -340,7 +340,7 @@ def update_parameters(self, learning_rate, grad_threshold=100):
                 # Compute based on square value since abs is more expensive
                 square_threshold = grad_threshold * grad_threshold
                 grad = param.grad.mul(
-                    param.grad.square().lt(square_threshold, _scale=False)
+                    param.grad.square().lt(square_threshold)
                 )
             else:
                 grad = param.grad
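For context: the gradient-threshold test compares squares because `abs` on an MPCTensor needs a comparison plus a multiply (per the `abs` definition in this diff), while `square` is a single Beaver round; for a positive threshold t, |g| < t and g*g < t*t select the same elements. A plaintext check of that equivalence:

import torch

grad_threshold = 100.0
square_threshold = grad_threshold * grad_threshold

g = torch.tensor([-150.0, -50.0, 99.9, 120.0])
assert torch.equal(
    g.abs().lt(grad_threshold),       # |g| < t
    g.square().lt(square_threshold),  # g*g < t*t, avoids abs()
)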
test/test_distributions.py (2 changes: 1 addition & 1 deletion)
@@ -40,7 +40,7 @@ def _check_distribution(
             sample.size() == size, "Incorrect size for %s distribution" % name
         )
 
-        plain_sample = sample.get_plain_text()
+        plain_sample = sample.get_plain_text().float()
         mean = plain_sample.mean()
         var = plain_sample.var()
         self.assertTrue(
test/test_mpc.py (39 changes: 17 additions & 22 deletions)
@@ -776,35 +776,30 @@ def test_relu(self):

     def test_comparators(self):
         """Test comparators (>, >=, <, <=, ==, !=)"""
-        for _scale in [False, True]:
-            for comp in ["gt", "ge", "lt", "le", "eq", "ne"]:
-                for tensor_type in [lambda x: x, MPCTensor]:
-                    tensor1 = self._get_random_test_tensor(is_float=True)
-                    tensor2 = self._get_random_test_tensor(is_float=True)
+        for comp in ["gt", "ge", "lt", "le", "eq", "ne"]:
+            for tensor_type in [lambda x: x, MPCTensor]:
+                tensor1 = self._get_random_test_tensor(is_float=True)
+                tensor2 = self._get_random_test_tensor(is_float=True)
 
-                    encrypted_tensor1 = MPCTensor(tensor1)
-                    encrypted_tensor2 = tensor_type(tensor2)
+                encrypted_tensor1 = MPCTensor(tensor1)
+                encrypted_tensor2 = tensor_type(tensor2)
 
-                    reference = getattr(tensor1, comp)(tensor2).float()
-                    encrypted_out = getattr(encrypted_tensor1, comp)(
-                        encrypted_tensor2, _scale=_scale
-                    )
+                reference = getattr(tensor1, comp)(tensor2).float()
+                encrypted_out = getattr(encrypted_tensor1, comp)(encrypted_tensor2)
 
-                    self._check(encrypted_out, reference, "%s comparator failed" % comp)
+                self._check(encrypted_out, reference, "%s comparator failed" % comp)
 
-                    # Check deterministic example to guarantee all combinations
-                    tensor1 = torch.tensor([2.0, 3.0, 1.0, 2.0, 2.0])
-                    tensor2 = torch.tensor([2.0, 2.0, 2.0, 3.0, 1.0])
+                # Check deterministic example to guarantee all combinations
+                tensor1 = torch.tensor([2.0, 3.0, 1.0, 2.0, 2.0])
+                tensor2 = torch.tensor([2.0, 2.0, 2.0, 3.0, 1.0])
 
-                    encrypted_tensor1 = MPCTensor(tensor1)
-                    encrypted_tensor2 = tensor_type(tensor2)
+                encrypted_tensor1 = MPCTensor(tensor1)
+                encrypted_tensor2 = tensor_type(tensor2)
 
-                    reference = getattr(tensor1, comp)(tensor2).float()
-                    encrypted_out = getattr(encrypted_tensor1, comp)(
-                        encrypted_tensor2, _scale=_scale
-                    )
+                reference = getattr(tensor1, comp)(tensor2).float()
+                encrypted_out = getattr(encrypted_tensor1, comp)(encrypted_tensor2)
 
-                    self._check(encrypted_out, reference, "%s comparator failed" % comp)
+                self._check(encrypted_out, reference, "%s comparator failed" % comp)
 
     def test_max_min_pairwise(self):
         """Tests max and min for the deterministic constant (n^2) algorithm"""
