In [1]:
import numpy as np
import torch

In [2]:
# Conigs
HIGH_BIT = 16
LOW_BIT = 8
BITS = 32
PRECISION = 13
LOW_TYPE = np.uint32
HIGH_TYPE = np.uint64

MOD = 2 ** HIGH_BIT
P = 37

In [3]:
class RSS:
    def __init__(self, st, rd, mod):
        self.first = st
        self.second = rd
        self.mod = mod

def float2fxp(fl_x, f):
    int_x = fl_x * (2 ** f)
    ring_x = int_x % MOD

def RSS_share(x, mod):
    x1 = np.random.randint(MOD, dtype='uint16')
    x2 = np.random.randint(MOD, dtype='uint16')
    x3 = (x - x1 - x2) 
    return RSS(x1, x2, mod), RSS(x2, x3, mod), RSS(x3, x1, mod)

def RSS_reveal(x_s, mod):   
    rs1, rs2, rs3 = x_s
    return (rs1.first + rs2.first + rs3.first) % mod

def reduction(x_s):
    rs1, rs2, rs3 = x_s
    rs1_low = RSS(rs1.first.astype(LOW_TYPE), rs1.second.astype(LOW_TYPE), 2**LOW_BIT)
    rs2_low = RSS(rs2.first.astype(LOW_TYPE), rs2.second.astype(LOW_TYPE), 2**LOW_BIT)
    rs3_low = RSS(rs3.first.astype(LOW_TYPE), rs3.second.astype(LOW_TYPE), 2**LOW_BIT)
    return rs1_low, rs2_low, rs3_low

def extension(x_s):
    rs1, rs2, rs3 = x_s
    assert rs1.first.dtype == LOW_TYPE
    rs1_low = RSS(rs1.first.astype(HIGH_TYPE), rs1.second.astype(HIGH_TYPE), 2**HIGH_BIT)
    rs2_low = RSS(rs2.first.astype(HIGH_TYPE), rs2.second.astype(HIGH_TYPE), 2**HIGH_BIT)
    rs3_low = RSS(rs3.first.astype(HIGH_TYPE), rs3.second.astype(HIGH_TYPE), 2**HIGH_BIT)
    return rs1_low, rs2_low, rs3_low
    
def local_truncation(x_s, m):
    rs1, rs2, rs3 = x_s
    x0, x1, x2 = rs1.first, rs1.second, rs2.second

    def trunc(x0, x1, x2):
        return np.floor(x0 / 2**m), np.floor(x1 / 2**m), np.floor(x2 / 2**m)
    
    def trunc_red(x0, x1, x2):
        return np.floor(x0 / 2**m) %  2**LOW_BIT, np.floor(x1 / 2**m) %  2**LOW_BIT, np.floor(x2 / 2**m) %  2**LOW_BIT
    x0, x1, x2 = trunc_red(x0, x1, x2)
    rs1_tr = RSS(x0, x1, 2**HIGH_BIT)
    rs2_tr = RSS(x1, x2, 2**HIGH_BIT)
    rs3_tr = RSS(x2, x0, 2**HIGH_BIT)
    return rs1_tr, rs2_tr, rs3_tr


################   HELPER FUNCTIONS    ################
def print_rs(x_rs):
    print(f'first: {x_rs.first}, second: {x_rs.second}')

def print_xs(x_s):
    rs1, rs2, rs3 = x_s
    print_rs(rs1)
    print_rs(rs2)
    print_rs(rs3)



In [4]:
x = np.uint16(10)
x_s = RSS_share(x, mod=MOD)
print_xs(x_s)
print(RSS_reveal(x_s, MOD))

#### TEST Truncation ####
x_s_tr = local_truncation(x_s, 1)
print_xs(x_s_tr)
print(RSS_reveal(x_s_tr, 2**LOW_BIT))


#### Share Reduction ####
x_s_red = reduction(x_s)
print_xs(x_s_red)
print(RSS_reveal(x_s_red, mod=2**LOW_BIT))

x_s_red_ext = extension(x_s_red)
print_xs(x_s_red_ext)
print(RSS_reveal(x_s_red_ext, mod=2**HIGH_BIT))
print(RSS_reveal(x_s_red_ext, mod=2**HIGH_BIT) % (2 ** LOW_BIT))

first: 38861, second: 7877
first: 7877, second: 18788
first: 18788, second: 38861
65526
first: 230.0, second: 98.0
first: 98.0, second: 178.0
first: 178.0, second: 230.0
250.0
first: 38861, second: 7877
first: 7877, second: 18788
first: 18788, second: 38861
246
first: 38861, second: 7877
first: 7877, second: 18788
first: 18788, second: 38861
65526.0
246.0


In [5]:
(x - 9365 - 8571) % 2**16

47590