Skip to content

Commit

Permalink
Update TFE Primitives to new TF Big API (tf-encrypted#820)
Browse files Browse the repository at this point in the history
* correct new shape support

* update to new tf-big API

* export using tf.unit8 by default everywhere

* tiny bit of type checking
  • Loading branch information
mortendahl committed Aug 5, 2020
1 parent 0be81ad commit 284da44
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 35 deletions.
83 changes: 53 additions & 30 deletions primitives/tf_encrypted/primitives/paillier/primitives.py
@@ -1,31 +1,49 @@
from typing import Optional
from typing import Tuple

import tensorflow as tf
import tf_big

tf_big.set_secure_default(True)


def _import_maybe_limbs(tensor):
if isinstance(tensor, tf_big.Tensor):
return tensor
if isinstance(tensor, tf.Tensor):
if tensor.dtype == tf.string:
return tf_big.import_tensor(tensor)
return tf_big.import_limbs_tensor(tensor)
raise ValueError("Don't know how to import tensors of type {}".format(type(tensor)))


def _export_maybe_limbs(tensor, dtype):
assert isinstance(tensor, tf_big.Tensor), type(tensor)
if dtype == tf.string:
return tf_big.export_tensor(tensor, dtype=dtype)
return tf_big.export_limbs_tensor(tensor, dtype=dtype)


class EncryptionKey:
"""Paillier encryption key.
Note that the generator `g` has been fixed to `1 + n`.
"""

def __init__(self, n):
n = tf_big.convert_to_tensor(n)
def __init__(self, n: tf.Tensor):
n = _import_maybe_limbs(n)

self.n = n
self.nn = n * n

def export(self, dtype: tf.DType = tf.string):
return tf_big.convert_from_tensor(self.n, dtype=dtype)
def export(self, dtype: tf.DType = tf.uint8) -> tf.Tensor:
return _export_maybe_limbs(self.n, dtype)


class DecryptionKey:
def __init__(self, p, q):
self.p = tf_big.convert_to_tensor(p)
self.q = tf_big.convert_to_tensor(q)
def __init__(self, p: tf.Tensor, q: tf.Tensor):
self.p = _import_maybe_limbs(p)
self.q = _import_maybe_limbs(q)

self.n = self.p * self.q
self.nn = self.n * self.n
Expand All @@ -35,49 +53,52 @@ def __init__(self, p, q):
self.d2 = tf_big.inv(order_of_n, self.n)
self.e = tf_big.inv(self.n, order_of_n)

def export(self, dtype: tf.DType = tf.string):
def export(self, dtype: tf.DType = tf.uint8) -> Tuple[tf.Tensor, tf.Tensor]:
return (
tf_big.convert_from_tensor(self.p, dtype=dtype),
tf_big.convert_from_tensor(self.q, dtype=dtype),
_export_maybe_limbs(self.p, dtype),
_export_maybe_limbs(self.q, dtype),
)


def gen_keypair(bitlength=2048):
def gen_keypair(bitlength=2048) -> Tuple[EncryptionKey, DecryptionKey]:
p, q, n = tf_big.random_rsa_modulus(bitlength=bitlength)
ek = EncryptionKey(n)
dk = DecryptionKey(p, q)
return ek, dk


class Randomness:
def __init__(self, raw_randomness):
self.raw = tf_big.convert_to_tensor(raw_randomness)
def __init__(self, raw_randomness: tf.Tensor):
self.raw = _import_maybe_limbs(raw_randomness)

def export(self, dtype: tf.DType = tf.string):
return tf_big.convert_from_tensor(self.raw, dtype=dtype)
def export(self, dtype: tf.DType = tf.uint8) -> tf.Tensor:
return _export_maybe_limbs(self.raw, dtype=dtype)


def gen_randomness(ek, shape):
def gen_randomness(ek: EncryptionKey, shape) -> Randomness:
return Randomness(tf_big.random_uniform(shape=shape, maxval=ek.n))


class Ciphertext:
def __init__(self, ek: EncryptionKey, raw_ciphertext):
def __init__(self, ek: EncryptionKey, raw_ciphertext: tf.Tensor):
self.ek = ek
self.raw = tf_big.convert_to_tensor(raw_ciphertext)
self.raw = _import_maybe_limbs(raw_ciphertext)

def export(self, dtype: tf.DType = tf.string):
return tf_big.convert_from_tensor(self.raw, dtype=dtype)
def export(self, dtype: tf.DType = tf.uint8) -> tf.Tensor:
return _export_maybe_limbs(self.raw, dtype=dtype)

def __add__(self, other):
assert self.ek == other.ek
return add(self.ek, self, other)

def __mul__(self, other):
return mul(self.ek, self, other)


def encrypt(
ek: EncryptionKey, plaintext: tf.Tensor, randomness: Optional[Randomness] = None,
):
x = tf_big.convert_to_tensor(plaintext)
) -> Ciphertext:
x = tf_big.import_tensor(plaintext)

randomness = randomness or gen_randomness(ek=ek, shape=x.shape)
r = randomness.raw
Expand All @@ -89,7 +110,9 @@ def encrypt(
return Ciphertext(ek, c)


def decrypt(dk: DecryptionKey, ciphertext: Ciphertext, dtype: tf.DType = tf.int32):
def decrypt(
dk: DecryptionKey, ciphertext: Ciphertext, dtype: tf.DType = tf.int32
) -> tf.Tensor:
c = ciphertext.raw

gxd = tf_big.pow(c, dk.d1, dk.nn)
Expand All @@ -99,10 +122,10 @@ def decrypt(dk: DecryptionKey, ciphertext: Ciphertext, dtype: tf.DType = tf.int3
if dtype == tf.variant:
return x

return tf_big.convert_from_tensor(x, dtype=dtype)
return tf_big.export_tensor(x, dtype=dtype)


def refresh(ek: EncryptionKey, ciphertext: Ciphertext):
def refresh(ek: EncryptionKey, ciphertext: Ciphertext) -> Ciphertext:
c = ciphertext.raw
s = gen_randomness(ek=ek, shape=c.shape).raw
sn = tf_big.pow(s, ek.n, ek.nn)
Expand All @@ -112,9 +135,9 @@ def refresh(ek: EncryptionKey, ciphertext: Ciphertext):

def add(
ek: EncryptionKey, lhs: Ciphertext, rhs: Ciphertext, do_refresh: bool = True,
):
c0 = tf_big.convert_to_tensor(lhs.raw)
c1 = tf_big.convert_to_tensor(rhs.raw)
) -> Ciphertext:
c0 = lhs.raw
c1 = rhs.raw
c = (c0 * c1) % ek.nn
res = Ciphertext(ek, c)

Expand All @@ -125,9 +148,9 @@ def add(

def mul(
ek: EncryptionKey, lhs: Ciphertext, rhs: tf.Tensor, do_refresh: bool = True,
):
) -> Ciphertext:
c = lhs.raw
k = tf_big.convert_to_tensor(rhs)
k = tf_big.import_tensor(rhs)
d = tf_big.pow(c, k) % ek.nn
res = Ciphertext(ek, d)

Expand Down
10 changes: 5 additions & 5 deletions primitives/tf_encrypted/primitives/paillier/primitives_test.py
Expand Up @@ -31,14 +31,14 @@ def test_export(self, run_eagerly, export_dtype, export_expansion):
n_exported = ek.export(export_dtype)
assert isinstance(n_exported, tf.Tensor)
assert n_exported.dtype == export_dtype
assert n_exported.shape == ()
assert n_exported.shape == (1, 1), n_exported.shape
p_exported, q_exported = dk.export(export_dtype)
assert isinstance(p_exported, tf.Tensor)
assert p_exported.dtype == export_dtype
assert p_exported.shape == ()
assert p_exported.shape == (1, 1), p_exported.shape
assert isinstance(q_exported, tf.Tensor)
assert q_exported.dtype == export_dtype
assert q_exported.shape == ()
assert q_exported.shape == (1, 1), q_exported.shape

r = paillier.gen_randomness(ek, shape=x.shape)
assert isinstance(r, paillier.Randomness)
Expand Down Expand Up @@ -71,9 +71,9 @@ def test_correctness(self, run_eagerly):
context = tf_execution_context(run_eagerly)
with context.scope():

ek = paillier.EncryptionKey(str(n))
ek = paillier.EncryptionKey(tf.constant([[str(n)]]))
plaintext = np.array([[x]]).astype(str)
randomness = paillier.Randomness(np.array([[r]]).astype(str))
randomness = paillier.Randomness(tf.constant([[str(r)]]))
ciphertext = paillier.encrypt(ek, plaintext, randomness)

expected = np.array([[c]]).astype(str)
Expand Down

0 comments on commit 284da44

Please sign in to comment.