In [1]:
import time
from datetime import datetime, timedelta
from functools import reduce
from math import log

import numpy as np

# math.log2 is exact for powers of two; log(x)/log(2) is not (e.g. 2**29 -> 29.000000000000004).
from math import log2

def prod(xs):
    """Return the product of all items in `xs` (1 for an empty iterable)."""
    # The initializer 1 makes the empty product well-defined instead of raising TypeError.
    return reduce(lambda x, y: x * y, xs, 1)

# Number theory

In [2]:
def egcd(a, b):
    """Extended Euclid: return (g, x, y) with a*x + b*y == g == gcd(a, b).

    Iterative, so very large inputs cannot hit Python's recursion limit. The
    returned coefficients are the same as those of the classic recursive
    formulation egcd(a, b) -> egcd(b % a, a).
    """
    quotients = []
    while a != 0:
        quotients.append(b // a)
        a, b = b % a, a
    g, x, y = b, 0, 1
    # Unwind the recursion: each level maps (x, y) -> (y - q*x, x).
    for q in reversed(quotients):
        x, y = y - q * x, x
    return (g, x, y)

def gcd(a, b):
    """Return gcd(a, b) as computed by `egcd`."""
    return egcd(a, b)[0]

def inverse(a, m):
    """Return the multiplicative inverse of `a` modulo `m`, in the range [0, m).

    `a` is reduced modulo `m` first, so negative or oversized values work.

    Raises:
        ValueError: if gcd(a, m) != 1, i.e. `a` has no inverse modulo `m`.
    """
    # Without the reduction a negative `a` makes egcd return g == -1, and the
    # coefficient would then be the *negated* inverse.
    g, b, _ = egcd(a % m, m)
    if g != 1:
        raise ValueError('{} has no inverse modulo {}'.format(a, m))
    return b % m

# CRT mapping

In [84]:
def decomposer(ms):
    """Build a function mapping an integer `x` to its residues modulo each of `ms`.

    The moduli are read from `ms`, in order, on every call; the residues are
    returned as a tuple.
    """
    def decompose(x):
        residues = []
        for modulus in ms:
            residues.append(x % modulus)
        return tuple(residues)

    return decompose

def recombiner(ms):
    """Build the CRT inverse of `decomposer(ms)`.

    With M = prod(ms) and M_i = M // m_i, the basis element
    l_i = M_i * (M_i^-1 mod m_i) mod M is 1 modulo m_i and 0 modulo every other
    modulus (assuming pairwise coprime moduli). The returned function computes
    sum_i x_i * l_i mod M from a sequence of residues.
    """
    M = prod(ms)
    basis = []
    for mi in ms:
        Mi = M // mi
        basis.append(Mi * inverse(Mi, mi) % M)

    def recombine(xs):
        total = sum(xi * li for xi, li in zip(xs, basis))
        return total % M

    return recombine

# def explicit_recombiner(ms):
    

def crt(ms):
    """Return (decompose, recombine) for the pairwise-coprime moduli `ms`.

    Raises AssertionError naming the first pair of moduli that share a factor.
    """
    count = len(ms)
    for i in range(count):
        for j in range(i + 1, count):
            assert gcd(ms[i], ms[j]) == 1, '{} and {} are not coprime'.format(ms[i], ms[j])
    return decomposer(ms), recombiner(ms)
    
ms = [100, 10001, 12347]
decompose, recombine = crt(ms)

x = 123456
assert recombine(decompose(x)) == x
# A single mid-range value never exercises the boundary residues; check 0, 1 and M - 1 too.
assert all(recombine(decompose(v)) == v for v in (0, 1, prod(ms) - 1))

In [98]:
ps = ms
P = prod(ps)

u = 1234567891
assert 0 <= u < P

us = decompose(u)

# precompute q_i = (P / p_i)^-1 mod p_i
qs = [ inverse(P//pi, pi) for pi in ps ]
for qi, pi in zip(qs, ps): assert qi < pi

# recombine: with t_i = u_i * q_i mod p_i we have u ≡ sum_i t_i * (P / p_i)  (mod P)
ts = [ ui * qi % pi for ui, qi, pi in zip(us, qs, ps) ]
for ti, pi in zip(ts, ps): assert ti < pi

# r_i = t_i / p_i lies in [0, 1), so alpha = sum_i r_i lies in [0, len(ps))
rs = [ ti / pi for ti, pi in zip(ts, ps) ]
for ri in rs: assert 0 <= ri < 1

alpha = sum(rs)
assert alpha < len(ps)

# u = P * frac(alpha). Take floor(alpha) (int() of a non-negative float), not
# round(alpha): round() yields u - P for every u >= P/2. Round the final float to
# the nearest integer instead of truncating it, since P*alpha carries rounding error.
# NOTE: float precision limits this to moduli products far below 2**53 / len(ps).
k = int(alpha)
v = round(P*alpha - P*k)

assert u == v, (u, v)

In [99]:
# Explicit-CRT float reconstruction swept over boundary values of u, in
# particular u >= P/2, where using round(alpha) instead of floor(alpha) would
# return u - P instead of u.
ps = ms
P = prod(ps)

# precompute q_i = (P / p_i)^-1 mod p_i
qs = [ inverse(P//pi, pi) for pi in ps ]
for qi, pi in zip(qs, ps): assert qi < pi

for u in [0, 1, 1234567891, P // 2, P - 1]:
    assert 0 <= u < P
    us = decompose(u)

    # t_i = u_i * q_i mod p_i, so u ≡ sum_i t_i * (P / p_i)  (mod P)
    ts = [ ui * qi % pi for ui, qi, pi in zip(us, qs, ps) ]
    for ti, pi in zip(ts, ps): assert ti < pi

    rs = [ ti / pi for ti, pi in zip(ts, ps) ]
    for ri in rs: assert 0 <= ri < 1

    alpha = sum(rs)
    assert alpha < len(ps)

    # u = P * frac(alpha): floor (not round) alpha, then round the float result
    k = int(alpha)
    v = round(P*alpha - P*k)

    assert u == v, (u, v)

In [30]:
x = 123456

# time.perf_counter is monotonic and high-resolution; datetime.now() is a
# wall-clock reading that can jump if the system clock is adjusted.
start = time.perf_counter()
for _ in range(10000):
    recombine(decompose(x))
end = time.perf_counter()

# Printed as a timedelta to keep the H:MM:SS.ffffff format of the recorded output.
print(timedelta(seconds=end - start))

0:00:00.042136


# Basic performance test

We want to do a dot product in a ~120-bit ring between numbers that are likely to be of that same size (secret shares). What is the performance difference between working on multi-precision numbers in the naive representation and in a CRT representation whose components fit in 64-bit words?

In [4]:
# 1024 multi-precision (Python int) entries of 2**120, stored in object arrays.
x = np.array([2**120] * 1024, dtype=object)
y = np.array([2**120] * 1024, dtype=object)

## Naive representation

NOTE: since we're only doing one reduction, at the end of the whole dot product, we should also test what results we get if we do reductions as part of each multiplication and before/during the additions.

In [5]:
Q = 2657003489534545107915232808830590043
# Exact integer check; a float log2 can round the wrong way right at the boundary.
assert Q > 2**120

# time.perf_counter is monotonic and high-resolution, unlike datetime.now().
start = time.perf_counter()
for _ in range(10000):
    # a single reduction, after the full multi-precision dot product
    z = np.dot(x, y) % Q
end = time.perf_counter()

print(timedelta(seconds=end - start))

0:00:01.471238


## Using the CRT

We want every component to fit in a 64-bit signed integer after summing 1024 products of two residues: `log2(m * m) + log2(1024) < 63`, i.e. `2 * log2(m) + 10 < 63`, meaning `log2(m) ~ 26` works.

In [6]:
ms = [89702869, 78489023, 69973811, 70736797, 79637461]

# make sure that we can add up to 1024 single multiplications and still have room
# in a 64-bit signed word; compare exact integers rather than float log2 values,
# which can round the wrong way at the boundary
for m in ms:
    assert m * m * 1024 < 2**63, '{} is too large: 1024 products would overflow int64'.format(m)

M = prod(ms)
assert M > 2**120, 'CRT modulus M has only {} bits'.format(M.bit_length())

In [17]:
decompose, recombine = crt(ms)

# np.vectorize applies the scalar decompose element-wise; because decompose
# returns a tuple, the result is one array per modulus.
vec_decompose = np.vectorize(decompose)
# recombine only uses *, sum and %, so it already works element-wise on those
# per-modulus arrays without a vectorize wrapper.
# NOTE(review): xi * li mixes the residue arrays with ~130-bit Python ints;
# whether that promotes to object dtype or raises OverflowError depends on the
# NumPy version — confirm.
vec_recombine = recombine

assert np.all(recombine(vec_decompose(x)) == x)
assert np.all(recombine(vec_decompose(y)) == y)

How long do the decomposition and recombination take? This is the hidden price we pay and need to minimize.

In [31]:
# time.perf_counter is monotonic and high-resolution, unlike datetime.now().
start = time.perf_counter()
for _ in range(10000):
    _ = vec_decompose(x)
    _ = vec_decompose(y)
end = time.perf_counter()

# timedelta keeps the H:MM:SS.ffffff format of the recorded output.
print(timedelta(seconds=end - start))

0:00:22.590107


In [32]:
xs = vec_decompose(x)
ys = vec_decompose(y)

# time.perf_counter is monotonic and high-resolution, unlike datetime.now().
start = time.perf_counter()
for _ in range(10000):
    _ = vec_recombine(ys)
end = time.perf_counter()

# timedelta keeps the H:MM:SS.ffffff format of the recorded output.
print(timedelta(seconds=end - start))

0:00:07.837588


How long does the dot product itself take?

In [26]:
xs = vec_decompose(x)
ys = vec_decompose(y)

# time.perf_counter is monotonic and high-resolution, unlike datetime.now().
start = time.perf_counter()
for _ in range(10000):
    # One dot product per CRT component, reduced once at the end. The bound
    # asserted on ms keeps each 1024-term sum below 2**63.
    # NOTE(review): assumes the residue arrays have a 64-bit integer dtype
    # (platform default int) — confirm on the target platform.
    for xi, yi, mi in zip(xs, ys, ms):
        zi = np.dot(xi, yi) % mi
end = time.perf_counter()

# timedelta keeps the H:MM:SS.ffffff format of the recorded output.
print(timedelta(seconds=end - start))

0:00:00.132995


# Arithmetic

In [39]:
def add(a, b):
    """Component-wise sum of two residue tuples, each entry reduced modulo the matching m_i of the global `ms`."""
    return tuple(map(lambda ai, bi, mi: (ai + bi) % mi, a, b, ms))

def sub(a, b):
    """Component-wise difference of two residue tuples, each entry reduced modulo the matching m_i of the global `ms`."""
    return tuple(map(lambda ai, bi, mi: (ai - bi) % mi, a, b, ms))

def mul(a, b):
    """Component-wise product of two residue tuples, each entry reduced modulo the matching m_i of the global `ms`."""
    return tuple(map(lambda ai, bi, mi: (ai * bi) % mi, a, b, ms))
    
# Arithmetic on residue tuples must agree with arithmetic on the integers.
x, y = 1234, 5
a, b = decompose(x), decompose(y)

c = add(a, b)
assert recombine(c) == x + y
c = sub(a, b)
assert recombine(c) == x - y
c = mul(a, b)
assert recombine(c) == x * y

# Modulus and truncation

In [40]:
xs = decompose(12345)

def naive_mod(xs, m):
    """Reduce a CRT-decomposed value modulo `m` by round-tripping through the integer.

    Recombines `xs` with the global `recombine`, takes the plain integer
    remainder modulo `m`, and decomposes the result again with `decompose`.
    """
    return decompose(recombine(xs) % m)

ys = naive_mod(xs, 100)
assert recombine(ys) == 45
zs = sub(xs, ys)
assert recombine(zs) == 12300

In [48]:
def truncator(k):
    """Return a function that divides a CRT-decomposed value by `k`, dropping the remainder.

    Uses k^-1 modulo the global CRT modulus `M`, so `k` must be coprime to `M`.
    """
    k_inv = inverse(k, M)
    assert (k * k_inv) % M == 1
    k_inv_residues = decompose(k_inv)

    def truncate(xs):
        # xs - (xs mod k) is an exact multiple of k, so multiplying it by
        # k^-1 mod M performs the exact division.
        multiple_of_k = sub(xs, naive_mod(xs, k))
        return mul(multiple_of_k, k_inv_residues)

    return truncate

truncate = truncator(10)
recombine(truncate(xs))

1234

# Mixing CRT and secret sharing

In [None]:
import random

In [None]:
def share(secret, modulus):
    """Split `secret` into two additive shares modulo `modulus`.

    The first share is uniformly random in [0, modulus); the second is chosen
    so that both shares sum to `secret` modulo `modulus`.
    """
    mask = random.randrange(modulus)
    return mask, (secret - mask) % modulus

def reconstruct(x0, x1, modulus):
    """Recombine two additive shares into the secret modulo `modulus`."""
    total = x0 + x1
    return total % modulus

In [None]:
# NOTE(review): this cell and the ones below appear to be left over from an
# earlier two-modulus setup (moduli M and N). N is never defined in this
# notebook, so this raises NameError on a fresh kernel.
x = 12345
# m = M^-1 mod N; presumably used to divide the N-component by M — confirm.
m = inverse(M, N)

In [None]:
# NOTE(review): assumes a two-modulus CRT where decompose returns
# (x mod M, x mod N) and recombine takes two residues. The decompose/recombine
# defined above return more than two components and take a single argument,
# so this cell fails (the test() cell below shows the same ValueError).
a, b = decompose(x)
ab = recombine(a, b)
assert ab == x, ab

# Zero the M-component and subtract a = x mod M in the N-component, so (c, d)
# represents x - (x mod M).
c, d = (0, (b - a) % N)
cd = recombine(c, d)
assert cd in [12300], cd

# Multiplying by m = M^-1 mod N divides that exact multiple of M by M in Z_N;
# e is the same value reduced modulo M.
e, f = (d * m) % N % M, (d * m) % N
ef = recombine(e, f)
assert ef in [123], ef

In [27]:
def test():
    """Randomized check of CRT truncation by M on additively shared residues.

    Decomposes the global `x`, secret-shares the M- and N-components separately,
    and tries to compute x // M from the shares (presumably M == 100, given that
    123 and 124 are the accepted results for x = 12345 — confirm). Returns True
    when the reconstructed result is one of the accepted values.

    NOTE(review): relies on globals x, M, N and m (= inverse(M, N)) from the
    two-modulus setup. N is never defined in this notebook, and decompose()
    returns more than two components, hence the ValueError recorded below.
    """
    
    a, b = decompose(x)
    ab = recombine(a, b)
    assert ab == x, ab

    # Share each CRT component additively in its own modulus.
    a0, a1 = share(a, M)
    b0, b1 = share(b, N)
    a = reconstruct(a0, a1, M)
    b = reconstruct(b0, b1, N)
    ab = recombine(a, b)
    assert ab == x, ab

    # Share-wise x - (x mod M). a0 + a1 may or may not wrap around modulo M; the
    # + M makes d come out as x - (x mod M) or that plus M, hence the two
    # accepted values below.
    c0, c1 = share(0, M)
    d0, d1 = (b0 - a0 + M) % N, (b1 - a1) % N
    c = reconstruct(c0, c1, M)
    d = reconstruct(d0, d1, N)
    cd = recombine(c, d)
    assert cd in [12300, 12400], cd

    # Each party locally divides its N-share by M (multiplication by m = M^-1 mod N).
    f0, f1 = (d0 * m) % N, (d1 * m) % N
#     if f0 < N//2 or f1 < N//2:
#         print(f0, f1, f0 + f1 < N)
    f0_patched = f0 #if f0 > N/2 else f0 + N//2
    f1_patched = f1 #if f1 > N/2 else f1 + N//2

    # option A: works
    # If f0 + f1 did not wrap modulo N, add N//2 and N//2 + 1 (together N when N
    # is odd), so that f0_patched + f1_patched always equals f + N.
    # NOTE(review): this inspects both shares at once, which no single party
    # could do in a real protocol.
    if f0 + f1 < N:
        f0_patched += N//2
        f1_patched += N//2 + 1
        
    # option B: doesn't work
    # ********* MAYBE TEST AGAINST B INSTEAD???? **********
#     if f0 < N//2:
#         f0_patched += N//2 + M
#     if f1 < N//2:
#         f1_patched += N//2 + M
#     assert f0 % M == f0_patched % M
#     assert f1 % M == f1_patched % M

    # option C with one bit (one field element encrypted) of communication: not sure works
#     if f0 < B and f1 < B:
#         f0_patched += N//2
#         f1_patched += N//2 + 1

#     assert f0_patched + f1_patched > N

    # Convert the N-shares to M-shares. Since the patched shares sum to f + N,
    # e0 + e1 ≡ f + N (mod M); subtracting 1 cancels N only if N ≡ 1 (mod M).
    # NOTE(review): confirm that assumption about N.
    e0, e1 = f0_patched % M, f1_patched % M
    e0 = (e0 - 1) % M
    
    e = reconstruct(e0, e1, M)
    f = reconstruct(f0, f1, N)
    ef = recombine(e, f)

    # Accept x // M or one more (the possible extra M carried from d).
    correct_result = ef in [123, 124]
#     f_wrap_around = (f0 + f1) >= N
#     f_wrap_around = (f0_patched + f1_patched) >= N
    

#     if not f_wrap_around:
#         print(correct_result, f_wrap_around, f0, f1, f0_patched, f1_patched, e, f, f%M)
#     assert correct_result == f_wrap_around, (correct_result, f_wrap_around, ef)
#     assert e == f % M, (e, f%M)
#     assert (f0_patched + f1_patched) >= N
    return correct_result

# Count how many of 200000 randomized runs produce an accepted result
# (rebinding the global x that test() reads).
for x in [12345]: # range(1, B)[:10]:
    print(sum(test() for _ in range(200000)))

ValueError: too many values to unpack (expected 2)

In [None]:
# NOTE(review): scratch cell from the two-modulus setup; it fails on a fresh kernel:
# - MN is never defined (presumably M * N; if so, M has no inverse modulo MN);
# - share() is called with one argument, but share(secret, modulus) needs two;
# - zip(*share(...)) expects share to return per-component share pairs;
# - recombine() is called with two arguments, but the recombine above takes one.
shifter = inverse(M, MN)
print(shifter)

(a0, a1), (b0, b1) = zip(*share(1234))

# Same share-wise x - (x mod M) as in test(), but with the + M on the second
# party's N-share; then both N-shares are multiplied by `shifter`.
c0 = 0
c1 = 0
t0 = (b0 - a0)  % N
t1 = (b1 - a1 + M) % N
d0 = (t0 * shifter) % N
d1 = (t1 * shifter) % N

recombine((c0 + c1) % M, (d0 + d1) % N)

In [None]:
# NOTE(review): N is never defined in this notebook (NameError on a fresh kernel).
inverse(100, N)

In [None]:
# Scratch: remainders modulo 200 vs 201, as in the notes in the last cell.
1200 % 200

In [None]:
1200 % 201

In [None]:
# Scratch: inspect the residues of a few values.
decompose(1234)

In [None]:
decompose(1200)

In [None]:
decompose(12)

In [None]:
# Scratch: only displays the function object; nothing is computed.
mul

In [None]:
# 1234 = a * 200 + b
# 1234 = c * 201 + d = c * 200 + c + d = c * 200 + e   (with e = c + d)
#
# a * 200 + b = c * 200 + e
# a * 200 - c * 200 = e - b,  i.e.  (a - c) * 200 = e - b


