In [1]:
import random
from functools import reduce
import numpy as np
from datetime import datetime
from math import log

def log2(x):
    """Return the base-2 logarithm of x as a float (computed as log(x) / log(2))."""
    return log(x) / log(2)


def prod(xs):
    """Return the product of the items of the iterable xs.

    Items are multiplied left to right, so any type supporting ``*`` works.
    The empty product is 1.
    """
    items = iter(xs)
    try:
        first = next(items)
    except StopIteration:
        return 1
    # Seed reduce with the first item (not 1) so non-numeric operand types
    # never see an int on the left of ``*``.
    return reduce(lambda acc, item: acc * item, items, first)

# Number theory

All we really need from number theory is a way to compute modular inverses, which the extended Euclidean algorithm provides.

In [2]:
def egcd(a, b):
    """Extended Euclidean algorithm.

    For non-negative a, b returns (g, x, y) with g = gcd(a, b) and
    a*x + b*y == g.

    Implemented iteratively so that very large inputs (hundreds of bits and
    more) cannot exceed Python's recursion limit. The coefficients are exactly
    those of the classic recursive formulation egcd(a, b) -> egcd(b % a, a).
    """
    # Walk down the remainder sequence, remembering every quotient.
    quotients = []
    while a != 0:
        quotients.append(b // a)
        a, b = b % a, a
    # Base case egcd(0, b) = (b, 0, 1); unwind innermost-first, applying the
    # same back-substitution the recursive version performs on return.
    x, y = 0, 1
    for q in reversed(quotients):
        x, y = y - q * x, x
    return (b, x, y)
    
def gcd(a, b):
    """Greatest common divisor: the first component of egcd(a, b)."""
    return egcd(a, b)[0]

def inverse(a, m):
    """Return the inverse of a modulo m (assumes m > 0), as a value in [0, m).

    Raises ValueError when gcd(a, m) != 1, because no inverse exists then.
    """
    # Reduce a first: with both egcd arguments non-negative the returned gcd is
    # non-negative. For negative a, egcd can otherwise return g == -1, and then
    # the Bezout coefficient is the *negated* inverse.
    g, b, _ = egcd(a % m, m)
    if g != 1:
        raise ValueError('{} is not invertible modulo {}'.format(a, m))
    return b % m

# Ring

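Define the ring $\mathbb{Z}_M$ in which we're working. We want to operate on numbers of at least 120 bits, split via the CRT into residues modulo several pairwise-coprime moduli $m_i$, each of which fits in a 64-bit word.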

In [3]:
# CRT moduli for the ring Z_M; gen_crt checks that they are pairwise coprime.
ms = [89702869, 78489023, 69973811, 70736797, 79637461]
# Exact integer bounds: the float-based log2 can round either way at the boundary.
for mi in ms: assert mi < 2 ** 63, '{} does not fit in 63 bits'.format(mi)

M = prod(ms)
assert M >= 2 ** 120, 'M = {} is smaller than 2**120'.format(M)

# Modulus used by the scalars' mod() and divisor used by truncation
# (presumably the fixed-point scaling factor).
k = 100

# Scalars

## Naive number representation

In [4]:
class TypicalScalar:
    """ Uses the typical built-in representation of numbers.

    An element of Z_M stored as a single Python int. The value is always kept
    reduced into [0, M), so the representation is canonical, like CrtScalar,
    which stores residues.
    """

    def __init__(self, value):
        # Reduce on construction. Every arithmetic result was already reduced
        # mod M, so the constructor was the only source of non-canonical
        # (negative or >= M) representatives.
        self.value = value % M

    def __repr__(self):
        return 'TypicalScalar({})'.format(self.unwrap())

    def unwrap(self):
        """Return the canonical integer representative in [0, M)."""
        return self.value

    def __add__(x, y):
        # the constructor performs the reduction mod M
        return TypicalScalar(x.value + y.value)

    def __sub__(x, y):
        return TypicalScalar(x.value - y.value)

    def __mul__(x, y):
        return TypicalScalar(x.value * y.value)

    def mod(x):
        """Return the representative in [0, M) reduced modulo k.

        NOTE(review): the CRT-based mod rounds its overflow estimate, so it
        appears to treat values >= M/2 as negative (giving (x - M) mod k);
        for such inputs the two scalar types disagree. Confirm which
        semantics is intended.
        """
        return TypicalScalar(x.value % k)

    @staticmethod
    def sample():
        """Return a uniformly random element of Z_M."""
        return TypicalScalar(random.randrange(M))

    
a, b = 1000000000, 123456789

x = TypicalScalar(a)
print(x)
y = TypicalScalar(b)
print(y)

# Ring operations must agree with plain integer arithmetic for small operands.
z = x + y
assert z.unwrap() == a + b, z
z = x - y
assert z.unwrap() == a - b, z
z = x * y
assert z.unwrap() == a * b, z
# 123456789 mod 100 == 89
z = y.mod()
assert z.unwrap() == 89, z

TypicalScalar(1000000000)
TypicalScalar(123456789)


## CRT number representation

In [5]:
def gen_crt():
    """Build the CRT encoder/decoder pair for the moduli in ``ms``.

    Returns ``(decompose, recombine)``:
      * ``decompose(x)`` -> list of residues of x modulo each modulus;
      * ``recombine(xs)`` -> the integer in [0, M) with those residues.
    """
    # The CRT is only a bijection when the moduli are pairwise coprime.
    count = len(ms)
    for i in range(count):
        for j in range(i + 1, count):
            mi, mj = ms[i], ms[j]
            assert gcd(mi, mj) == 1, '{} and {} are not coprime'.format(mi, mj)

    # Weight for modulus m: (M/m) * ((M/m)^-1 mod m), reduced mod M. It is
    # 1 modulo m and 0 modulo every other modulus.
    weights = []
    for modulus in ms:
        cofactor = M // modulus
        weights.append(cofactor * inverse(cofactor, modulus) % M)

    def decompose(x):
        return [x % modulus for modulus in ms]

    def recombine(xs):
        total = 0
        for residue, weight in zip(xs, weights):
            total += residue * weight
        return total % M

    return decompose, recombine

decompose, recombine = gen_crt()

# Round-trip check on an interior value and on both ends of [0, M).
assert recombine(decompose(123456789)) == 123456789
assert recombine(decompose(0)) == 0
assert recombine(decompose(M - 1)) == M - 1

In [6]:
def gen_mod():
    """Build ``mod(xs)``: reduce a CRT-encoded value modulo k, returning CRT form.

    The value is reconstructed *modulo k* from its residues, without forming
    the full M-sized integer. The number of times the raw CRT sum overshoots
    M (alpha) is estimated in floating point and rounded to the nearest
    integer. As a result, inputs in the upper half of [0, M) come out as
    (x - M) mod k, i.e. they are treated as negative numbers (up to float
    error for inputs very close to M/2).
    """
    # q_i = (M/m_i)^-1 mod m_i
    crt_inverses = [inverse(M // modulus, modulus) for modulus in ms]
    # M and each cofactor M/m_i reduced mod k, so the CRT sum can be taken mod k
    M_mod_k = M % k
    cofactors_mod_k = [(M // modulus) % k for modulus in ms]

    def mod(xs):
        # t_i = x_i * q_i mod m_i, so that sum(t_i * M/m_i) = x + alpha * M
        ts = []
        for residue, q, modulus in zip(xs, crt_inverses, ms):
            ts.append((residue * q) % modulus)

        # sum(t_i / m_i) = alpha + x/M; rounding yields alpha, or alpha + 1 when x > M/2
        fraction = 0
        for t, modulus in zip(ts, ms):
            fraction += float(t) / float(modulus)
        alpha = round(fraction)

        v = int(sum(t * c for t, c in zip(ts, cofactors_mod_k)) - M_mod_k * alpha)

        # v is small and congruent to the (signed) value modulo k
        assert abs(v) < k * sum(ms)  # TODO express in bit length

        return decompose(v % k)  # TODO inline decompose?

    return mod

mod = gen_mod()

# Small values reduce like ordinary integers.
assert mod(decompose(123456789)) == decompose(89)
assert mod(decompose(0)) == decompose(0)
# Values in the upper half of [0, M) are treated as negative: M - 1 acts as -1.
assert mod(decompose(M - 1)) == decompose((-1) % k)

In [7]:
class CrtScalar:
    """An element of Z_M stored as its residues modulo each modulus in ms.

    ``CrtScalar(v)`` encodes the integer v; ``CrtScalar(None, residues)``
    wraps an existing residue list as-is.
    """

    def __init__(self, value, values=None):
        self.values = values if value is None else decompose(value)

    def __repr__(self):
        return 'CrtScalar({}, {})'.format(self.unwrap(), self.values)

    def unwrap(self):
        """Recombine the residues into the integer representative in [0, M)."""
        return recombine(self.values)

    def _lanewise(self, other, op):
        # Apply op to each pair of residues and reduce by that lane's modulus.
        # Lanes are independent, so this could be done in parallel.
        residues = []
        for mine, theirs, modulus in zip(self.values, other.values, ms):
            residues.append(op(mine, theirs) % modulus)
        return CrtScalar(None, residues)

    def __add__(x, y):
        return x._lanewise(y, lambda u, v: u + v)

    def __sub__(x, y):
        return x._lanewise(y, lambda u, v: u - v)

    def __mul__(x, y):
        return x._lanewise(y, lambda u, v: u * v)

    def mod(x):
        """Reduce modulo k, staying in CRT form (uses the module-level mod from gen_mod)."""
        return CrtScalar(None, mod(x.values))

    @staticmethod
    def sample():
        """Random element, drawing each residue independently and uniformly."""
        return CrtScalar(None, [random.randrange(modulus) for modulus in ms])


a, b = 1000000000, 123456789

x = CrtScalar(a)
print(x)
y = CrtScalar(b)
print(y)

# CRT arithmetic must match plain integer arithmetic for small operands.
z = x + y
assert z.unwrap() == a + b, z
z = x - y
assert z.unwrap() == a - b, z
z = x * y
assert z.unwrap() == a * b, z
# 123456789 mod 100 == 89
z = y.mod()
assert z.unwrap() == 89, z

CrtScalar(1000000000, [13268441, 58131724, 20366646, 9684842, 44350468])
CrtScalar(123456789, [33753920, 44967766, 53482978, 52719992, 43819328])


## Secure scalar

The shares can be represented with either the typical or the CRT scalar type; `gen_secure_scalar` takes the chosen type as a parameter.

In [8]:
def gen_secure_scalar(scalar_type):
    """Build a class of values additively secret-shared between two parties.

    Each value is held as two shares of type ``scalar_type`` (e.g.
    TypicalScalar or CrtScalar) whose sum is the value. ``scalar_type`` must
    be constructible from an int and provide +, -, *, ``mod()``, ``unwrap()``
    and a static ``sample()``.
    """

    # precomputation for truncation
    k_inv = scalar_type(inverse(k, M))
    M_wrapped = scalar_type(M)

    def raw_truncate(x):
        # Remove the remainder mod k, then divide by k by multiplying with
        # k^-1 mod M (exact once the remainder has been removed).
        y = x - x.mod()
        return y * k_inv

    class AbstractSecureScalar:
        """A secret value held as two additive shares: value = share0 + share1."""

        def __init__(self, value, share0=None, share1=None):
            # Either share a plaintext value, or (value=None) wrap existing shares.
            if value is not None:
                value = scalar_type(value)
                # share0 comes from scalar_type.sample(); share1 makes the sum equal value
                share0 = scalar_type.sample()
                share1 = value - share0
            self.share0 = share0
            self.share1 = share1

        def __repr__(self):
            return 'SecureScalar({}, {}, {})'.format(self.unwrap(), self.share0, self.share1)

        def unwrap(self):
            """Reconstruct the value and return it as a plain integer."""
            return self.reconstruct().unwrap()

        def reconstruct(self):
            """Combine the two shares into a single scalar_type value."""
            return self.share0 + self.share1

        def __add__(x, y):
            # component-wise operation that can be done in parallel
            return AbstractSecureScalar(None,
                share0 = x.share0 + y.share0,
                share1 = x.share1 + y.share1
            )

        def __sub__(x, y):
            # component-wise operation that can be done in parallel
            return AbstractSecureScalar(None,
                share0 = x.share0 - y.share0,
                share1 = x.share1 - y.share1
            )

        def __mul__(x, k):
            # Multiply by a public integer constant; each share is scaled locally.
            # Note: the parameter name k shadows the module-level k in this method only.
            factor = scalar_type(k)
            return AbstractSecureScalar(None,
                share0 = x.share0 * factor,
                share1 = x.share1 * factor
            )

        # Multiplication by a public constant commutes, so `c * x` works too.
        __rmul__ = __mul__

        def truncate(x):
            """Divide the shared value by k, truncating each share separately.

            share0 is truncated directly and share1 through its negation
            (M - share1). NOTE(review): the result can presumably be off by one
            (value // k or value // k + 1), and is wrong with small probability
            depending on the random shares. Confirm the intended bounds.
            """
            return AbstractSecureScalar(None,
                share0 = raw_truncate(x.share0),
                share1 = M_wrapped - raw_truncate(M_wrapped - x.share1)
            )

    return AbstractSecureScalar
        
SecureScalar = gen_secure_scalar(CrtScalar)

# Fix the RNG so the sampled shares (and therefore the printed output) are
# reproducible under Restart & Run All, and the probabilistic truncation check
# gives the same outcome on every run.
random.seed(0)

a = 1000000000
b = 123456789
x = SecureScalar(a); print(x)
y = SecureScalar(b); print(y)
z = x + y; assert z.unwrap() == a+b, z
z = x - y; assert z.unwrap() == a-b, z
z = x * b; assert z.unwrap() == a*b, z
# Share-wise truncation may be off by one.
z = y.truncate(); assert z.unwrap() in [b // k, b // k + 1], z

SecureScalar(1000000000, CrtScalar(2493260233914690991069346241564458421632, [40377348, 4277563, 60050979, 47286623, 23170354]), CrtScalar(282063058213580005080066617023291421937, [62593962, 53854161, 30289478, 33135016, 21180114]))
SecureScalar(123456789, CrtScalar(134899795294854066712558849575381951685, [48470190, 12951725, 63506258, 32426020, 67520658]), CrtScalar(2640423496833416929436854009011491348673, [74986599, 32016041, 59950531, 20293972, 55936131]))
