TODO intro

Much of this notebook is code written by [Morten Dahl](https://github.com/mortendahl) in his [Private ML repo](https://github.com/mortendahl/privateml/) and [Paillier post](https://github.com/mortendahl/mortendahl.github.io/blob/master/_drafts/2019-04-15-paillier-encryption.md). Putting it together required several conversations in which he helped me understand the properties of the scheme. Thank you, Morten, for the support, advice and chats.

In [1]:
import math
import random

## Required Primitives & Utility Functions

In [2]:
# see https://inventwithpython.com/rabinMiller.py

# Every prime below 1000, in ascending order (168 values, 2 through 997).
# Built by plain trial division; the range is small enough that this costs
# nothing noticeable at import time.
SMALL_PRIMES = [
    candidate
    for candidate in range(2, 1000)
    if all(candidate % divisor for divisor in range(2, candidate))
]

def rewrite(num):
    """Decompose ``num - 1`` as ``s * 2**t`` with ``s`` odd.

    Args:
        num: integer of at least 2.

    Returns:
        The pair ``(s, t)``.

    Raises:
        ValueError: if ``num`` is less than 2.  For ``num == 1`` the value
            ``num - 1`` is 0, which can be halved forever without ever
            becoming odd, so the loop below would never terminate.
    """
    if num < 2:
        raise ValueError("num must be at least 2, got {}".format(num))
    s = num - 1
    t = 0
    # Strip factors of two from num - 1, counting how many were removed.
    while s % 2 == 0:
        s //= 2
        t += 1
    return s, t

def rabin_miller(num, iterations=10):
    """Probabilistic Rabin-Miller primality test.

    Args:
        num: integer to test.
        iterations: number of random witnesses to try.

    Returns:
        False if ``num`` is certainly not prime, True if no witness proved it
        composite (i.e. it is prime with high probability).
    """
    # Handle the inputs the witness loop cannot: random.randrange(2, num - 1)
    # is empty for num <= 3, and for an even num the decomposition yields
    # t == 0, which would make the squaring loop run without a bound.
    if num < 2:
        return False
    if num < 4:
        return True
    if num % 2 == 0:
        return False

    s, t = rewrite(num)
    for _round in range(iterations):
        a = random.randrange(2, num - 1)
        v = pow(a, s, num)
        if v == 1 or v == num - 1:
            # This witness is consistent with num being prime.
            continue
        # Square up to t - 1 times looking for -1 (mod num).
        for _squaring in range(t - 1):
            v = pow(v, 2, num)
            if v == num - 1:
                break
        else:
            # Never reached -1: `a` proves that num is composite.
            return False
    return True

def is_prime(num):
    """Return whether ``num`` is (very probably) prime.

    Values below 2 are rejected outright.  Otherwise ``num`` is trial-divided
    by every entry of SMALL_PRIMES; if none divides it, the decision is
    delegated to rabin_miller.
    """
    if num < 2:
        return False
    for prime in SMALL_PRIMES:
        quotient, remainder = divmod(num, prime)
        if remainder == 0:
            # Divisible by a small prime: prime only if it *is* that prime.
            return quotient == 1
    return rabin_miller(num)

In [3]:
def sample_randomness(ek):
    """Sample encryption randomness ``r`` for the public key ``ek``.

    Draws values from ``[0, ek.n)`` until one is coprime to ``ek.n`` and
    returns it.

    Args:
        ek: object exposing the public modulus as ``ek.n``.

    Returns:
        An integer ``r`` with ``gcd(r, ek.n) == 1``.
    """
    # The randomness hides the plaintext, so it must be unpredictable: draw
    # it from the OS entropy source instead of the default seeded generator.
    rng = random.SystemRandom()
    while True:
        r = rng.randrange(ek.n)
        if math.gcd(r, ek.n) == 1:
            return r

In [4]:
def sample_prime(bitsize):
    """Return a random prime that is exactly ``bitsize`` bits long.

    Candidates are drawn from ``[2**(bitsize - 1), 2**bitsize)`` and tested
    with is_prime until one passes.

    Args:
        bitsize: bit length of the prime; must be at least 2.

    Returns:
        A (probable) prime in the range above.

    Raises:
        ValueError: if ``bitsize`` is less than 2.  The only 1-bit candidate
            is 1, which is not prime, so the search could never finish.
    """
    if bitsize < 2:
        raise ValueError("bitsize must be at least 2, got {}".format(bitsize))
    lower = 1 << (bitsize - 1)
    upper = 1 << bitsize
    # These primes become secret key material, so draw them from the OS
    # entropy source rather than the default (predictable) generator.
    rng = random.SystemRandom()
    while True:
        candidate = rng.randrange(lower, upper)
        if is_prime(candidate):
            return candidate

In [5]:
# from http://www.ucl.ac.uk/~ucahcjm/combopt/ext_gcd_python_programs.pdf
def egcd_binary(a, b):
    """Binary extended Euclidean algorithm.

    Args:
        a, b: positive integers.

    Returns:
        A tuple ``(g, s, t)`` where ``g`` is the greatest common divisor of
        ``a`` and ``b`` and ``s * a + t * b == g``.  The coefficients are not
        reduced and may be negative.

    Raises:
        ValueError: if either argument is not positive; the halving and
            subtraction loops below only terminate for positive inputs.
    """
    if a <= 0 or b <= 0:
        raise ValueError(
            "egcd_binary requires positive arguments, got {} and {}".format(a, b))

    u, v, s, t, r = 1, 0, 0, 1, 0
    # Remove the common power of two; it is multiplied back in at the end.
    while (a % 2 == 0) and (b % 2 == 0):
        a, b, r = a//2, b//2, r+1
    alpha, beta = a, b
    # Invariants from here on: a == u*alpha + v*beta and b == s*alpha + t*beta.
    while (a % 2 == 0):
        a = a//2
        if (u % 2 == 0) and (v % 2 == 0):
            u, v = u//2, v//2
        else:
            # Adding (beta, -alpha) leaves u*alpha + v*beta unchanged and
            # makes both coefficients even so they can be halved exactly.
            u, v = (u + beta)//2, (v - alpha)//2
    while a != b:
        if (b % 2 == 0):
            b = b//2
            if (s % 2 == 0) and (t % 2 == 0):
                s, t = s//2, t//2
            else:
                s, t = (s + beta)//2, (t - alpha)//2
        elif b < a:
            # Keep b as the larger value, swapping the coefficients with it.
            a, b, u, v, s, t = b, a, s, t, u, v
        else:
            b, s, t = b - a, s - u, t - v
    return (2 ** r) * a, s, t


def inverse(a, field):
    """Return the multiplicative inverse of ``a`` modulo ``field``.

    Args:
        a: integer to invert; it is reduced modulo ``field`` first.
        field: the modulus, at least 2.

    Returns:
        The inverse as an integer in ``[1, field)``.

    Raises:
        ValueError: if ``field`` is less than 2 or ``a`` has no inverse
            modulo ``field``.
    """
    if field < 2:
        raise ValueError("modulus must be at least 2, got {}".format(field))
    a %= field
    if a == 0:
        raise ValueError("0 has no inverse modulo {}".format(field))
    g, coefficient, _ = egcd_binary(a, field)
    if g != 1:
        raise ValueError("{} is not invertible modulo {}".format(a, field))
    # The Bezout coefficient can be negative; return the canonical residue.
    return coefficient % field

## Paillier Encryption & Decryption Keys

In [6]:
class EncryptionKey:
    """Public key: the modulus ``n`` plus the values derived from it."""

    def __init__(self, n):
        # n  -- public modulus
        # nn -- n squared, the modulus ciphertexts are reduced by
        # g  -- the generator, fixed to n + 1
        self.n = n
        self.nn = n * n
        self.g = n + 1

    def __repr__(self):
        template = "Encryption Key <n: {}, nn: {}, g: {}>"
        return template.format(self.n, self.nn, self.g)

In [7]:
class DecryptionKey:
    """Private key built from the two prime factors ``p`` and ``q`` of ``n``.

    Raises:
        ValueError: if ``p == q`` or if ``p * q`` shares a factor with
            ``(p - 1) * (q - 1)``.  In either case the inverses computed
            below do not give a key that decrypts correctly.
    """

    def __init__(self, p, q):
        if p == q:
            # (p - 1) * (q - 1) is not the order of the units mod p * p.
            raise ValueError("p and q must be distinct")
        n = p * q
        order_of_n = (p - 1) * (q - 1)
        if math.gcd(n, order_of_n) != 1:
            raise ValueError("p * q must be coprime to (p - 1) * (q - 1)")

        # Public parameters, mirrored so decryption needs no EncryptionKey.
        self.n = n
        self.nn = n * n
        self.g = 1 + n

        # d1: exponent applied to the ciphertext in dec().
        self.d1 = order_of_n
        # d2: inverse of d1 modulo n, undoing the factor d1 leaves behind.
        self.d2 = inverse(order_of_n, n)
        # e: inverse of n modulo d1, used by extract() to take an n-th root.
        self.e = inverse(n, order_of_n)

    def __repr__(self):
        return "Decryption-Key Key <n: {}, d1: {}, e: {}>".format(
            self.n, self.d1, self.e)

In [8]:
def keygen(n_bitlength=512): # should be 2048
    """Generate a matching ``(EncryptionKey, DecryptionKey)`` pair.

    Args:
        n_bitlength: target size of the modulus in bits; each prime factor
            gets ``n_bitlength // 2`` bits.

    Returns:
        The tuple ``(EncryptionKey(n), DecryptionKey(p, q))`` with
        ``n == p * q``.

    Raises:
        ValueError: if ``n_bitlength`` is below 6.  The only 2-bit primes are
            2 and 3, and ``n = 6`` shares a factor with
            ``(p - 1) * (q - 1) = 2``, so no usable key exists at that size.
    """
    prime_bits = n_bitlength // 2
    if prime_bits < 3:
        raise ValueError(
            "n_bitlength must be at least 6, got {}".format(n_bitlength))

    p = sample_prime(prime_bits)
    q = sample_prime(prime_bits)
    # The scheme needs two distinct primes; a collision is realistic when
    # the bit length is small.
    while q == p:
        q = sample_prime(prime_bits)
    n = p * q

    return EncryptionKey(n), DecryptionKey(p, q)

### Encrypting, Decrypting and Extracting the Randomness (r)

In [9]:
def enc(ek: EncryptionKey, x, r):
    """Encrypt ``x`` under ``ek`` with randomness ``r``.

    Returns ``g**x * r**n`` reduced modulo ``n**2``.
    """
    modulus = ek.nn
    message_part = pow(ek.g, x, modulus)
    blinding_part = pow(r, ek.n, modulus)
    return (message_part * blinding_part) % modulus

In [10]:
def dec(dk: DecryptionKey, c):
    """Decrypt ciphertext ``c`` with ``dk``; the result lies in ``[0, n)``."""
    # Raising to d1 removes the randomness and leaves g**(x * d1) mod n**2.
    stripped = pow(c, dk.d1, dk.nn)
    scaled_plaintext = dlog(stripped, dk.n)
    # Multiplying by d2 (the inverse of d1 mod n) cancels the factor d1.
    return (scaled_plaintext * dk.d2) % dk.n

In [11]:
def dlog(gy, n):
    """Return ``(gy - 1) // n``.

    For ``gy == 1 + y * n`` with ``0 <= y < n`` this recovers ``y``.
    """
    quotient, _ = divmod(gy - 1, n)
    return quotient

In [12]:
def extract(dk: DecryptionKey, c):
    """Recover the randomness ``r`` that was used to produce ciphertext ``c``.

    Decrypts ``c`` to get ``x``, divides ``g**x`` out of the ciphertext to be
    left with ``r**n`` modulo ``n**2``, then takes the n-th root modulo ``n``
    using the private exponent ``dk.e``.

    Only ``dk`` is consulted; no global encryption key is needed.
    """
    x = dec(dk, c)
    gx = pow(dk.g, x, dk.nn)
    gx_inv = inverse(gx, dk.nn)
    rn = (c * gx_inv) % dk.nn
    # dk.e is an inverse modulo dk.d1, so reducing it modulo dk.d1 does not
    # change the result but keeps the exponent non-negative (three-argument
    # pow only accepts negative exponents from Python 3.8 on).
    r = pow(rn, dk.e % dk.d1, dk.n)
    return r

## Encrypting and Decrypting with Paillier

In [13]:
# NOTE(review): `n` has not been assigned anywhere above this cell, which is
# why it fails with the NameError shown below.
EncryptionKey(n)

NameError: name 'n' is not defined

In [14]:
# Generate a key pair with keygen's default n_bitlength of 512 (the comment on
# keygen notes this should be 2048).
ek, dk = keygen()

In [15]:
type(dk.d2)

int

In [16]:
r = sample_randomness(ek)

In [17]:
msg = 4

In [18]:
ciphertext = enc(ek, msg, r)

In [19]:
ciphertext

10321439241051426613688577859887450659360732871846743524747505386116408678420713060698574909639828140864425239801996680664277916866170817060720967117577638490669230272283797353510295109478714690166282526796457792244014675094366354680771467227854015723961105427189081109441947429553481059513398227277852872863

In [20]:
# Round-trip check: decrypting the ciphertext should give back msg.
dec(dk, ciphertext) == msg

True

Note: you might want to extract the randomness provided as part of the encryption to prove the correctness of the decryption. For one implementation of how you might use this, take a look at [the tf-encrypted implementation of secure aggregation](https://medium.com/dropoutlabs/building-secure-aggregation-into-tensorflow-federated-4514fca40cc0) where it is used to prove correct decryption.

In [21]:
# The randomness recovered from the ciphertext should be the r passed to enc.
extract(dk, ciphertext) == r

True

## Homomorphic Properties

In [22]:
def add_cipher(ek, c1, c2):
    """Combine two ciphertexts so the result encrypts the sum of their plaintexts.

    The ciphertexts are multiplied modulo ``ek.nn``.
    """
    return (c1 * c2) % ek.nn

def add_plain(ek, c1, x2):
    """Add the public value ``x2`` to the plaintext inside ``c1``.

    Multiplies ``c1`` by ``g**x2`` modulo ``ek.nn``.
    """
    modulus = ek.nn
    shift = pow(ek.g, x2, modulus)
    return (c1 * shift) % modulus

def neg(ek, c):
    """Return a ciphertext encrypting the negation of the plaintext in ``c``.

    This is the multiplicative inverse of ``c`` modulo ``ek.nn``, reduced so
    the result is always a canonical residue in ``[0, ek.nn)`` even when the
    underlying inverse routine hands back a negative coefficient.
    """
    return inverse(c, ek.nn) % ek.nn

def sub_cipher(ek, c1, c2):
    """Combine ciphertexts so the result encrypts plaintext(c1) - plaintext(c2).

    Implemented as adding ``c1`` to the negation of ``c2``; the decrypted
    difference is a residue modulo ``n``.
    """
    negated = neg(ek, c2)
    return add_cipher(ek, c1, negated)

def sub_plain(ek, c1, x2):
    """Subtract the public value ``x2`` from the plaintext inside ``c1``.

    Implemented as adding ``ek.n - x2``, which equals ``-x2`` modulo ``n``.
    """
    complement = ek.n - x2
    return add_plain(ek, c1, complement)

def mul_plain(ek, c1, x2):
    """Multiply the plaintext inside ``c1`` by the public value ``x2``.

    Raises the ciphertext to the power ``x2`` modulo ``ek.nn``.
    """
    return pow(c1, x2, ek.nn)

In [23]:
msg_one, msg_two = 45, 234

In [24]:
r1 = sample_randomness(ek)

In [25]:
# NOTE(review): both ciphertexts are built with the same randomness r1.
# Outside of a demo, sample a separate r for every encryption.
c1 = enc(ek, msg_one, r1)
c2 = enc(ek, msg_two, r1)

In [26]:
result_addition = add_cipher(ek, c1, c2)

In [27]:
# The combined ciphertext should decrypt to the sum of the two plaintexts.
dec(dk, result_addition) == msg_one + msg_two

True

In [28]:
result_subtraction = sub_cipher(ek, c1, c2)

In [29]:
result_subtraction

22512405487098101026220377624269112858493878588296091001057972750295625379574544319399408730062645078505096557094407502760262516521376800542036087368920625374424405822290566778876980328144137236570779613000715493029746725260611612094758753615568842393279061434567694650250471138918593599528233942480830028821

In [30]:
# NOTE: this is not broken. dec() reduces its result modulo n, and msg_one is
# smaller than msg_two, so the negative difference comes back as its residue
# modulo n (a large positive integer). Compare against
# (msg_one - msg_two) % ek.n to check it.
dec(dk, result_subtraction)

4744723963214098404145090568423826168557536311067305708831251073494007699286048863782394795535880968531337850502932297689411099468684162931555993034477780

You can also perform plaintext operations, like multiplying, adding and subtracting publicly known values as part of the computation.

In [31]:
# Raising the ciphertext to the power 5 multiplies the plaintext by 5.
dec(dk, mul_plain(ek, c1, 5)) == 5 * msg_one

True

In [32]:
# Evaluates to False because dec() returns a residue in [0, n) while
# msg_two - 1000 is a negative Python int. Compare against
# (msg_two - 1000) % ek.n instead.
dec(dk, sub_plain(ek, c2, 1000)) == msg_two - 1000

False