In [1]:
import numpy as np

# ABY2.0 Scalar Product

## Implementation

In [2]:
np.seterr(over='ignore') # Supress overflow warnings --> avoid warnings when modulo kicks in

def SP_Setup(l: int, dtype_Z2n: type = np.uint32, seed=42):
    # Set Ring info from dtype
    if not np.issubdtype(dtype_Z2n, np.integer):
        raise TypeError("dtype must be a numeric numpy type")
    Z2n_info = np.iinfo(dtype_Z2n)
    Z2n = (Z2n_info.min, Z2n_info.max)
    
    # Sample correlated input shares [δx] * [δy] = [δxy] 
    rng = np.random.default_rng(seed)
    δx0, δx1, δy0, δy1, δxy0 = \
        rng.integers(low=Z2n[0], high=Z2n[1], size=(5,l), dtype=dtype_Z2n)
    δxy1 = (δx0+δx1) * (δy0+δy1) - δxy0
    # Sample and output shares [δz]
    δz0, δz1 = rng.integers(low=Z2n[0], high=Z2n[1], size=(2), dtype=dtype_Z2n)
    return δx0, δx1, δy0, δy1, δxy0, δxy1, δz0, δz1

def SP_Online_local(j, Δx, δx_j, Δy, δy_j, δxy_j, δz_j):
    # Compute local share of [Δz]
    Δz_j = np.sum((Δx*Δy if j else 0) - Δx*δy_j - Δy*δx_j + δxy_j) + δz_j
    return Δz_j

def SP_Online(Δx, δx0, δx1, Δy, δy0, δy1, δxy0, δxy1, δz0, δz1):
    # P0
    Δz0 = SP_Online_local(0, Δx, δx0, Δy, δy0, δxy0, δz0)
    # P1
    Δz1 = SP_Online_local(1, Δx, δx1, Δy, δy1, δxy1, δz1)
    # Exchange shares and reconstruct common share Δy
    Δz = Δz0 + Δz1
    return Δz

def share(x, δx0, δx1):
    """Secret Share `x`, using the precomputed δx0 and δx1 to output the common Δx share"""
    return x+(δx0+δx1)

def reconstruct(Δx, δx0, δx1):
    """Reconstruct secret from triangular Secret Shares """
    return Δx-(δx0+δx1)

## Example

In [2]:
## SETUP
l         = 128
dtype_Z2n = np.uint32
δx0, δx1, δy0, δy1, δxy0, δxy1, δz0, δz1 = SP_Setup(l=128, dtype_Z2n=dtype_Z2n)

In [3]:
## ONLINE
# Data sharing
x = np.random.randint(0, 2**12, size=l, dtype=dtype_Z2n)
y = np.random.randint(0, 2**12, size=l, dtype=dtype_Z2n)
Δx = share(x, δx0, δx1)
Δy = share(y, δy0, δy1)

# SP computation
Δz = SP_Online(Δx, δx0, δx1, Δy, δy0, δy1, δxy0, δxy1, δz0, δz1)

In [4]:
# Reconstruct and open result
z = reconstruct(Δz, δz0, δz1)
z, x@y

(539056012, 539056012)

# FSS comparison1 (naive)

## Primitives  & utilities

In [3]:
# ----------------------------CONVERSION UTILITIES---------------------------- #
def _itoBs(x: int, n=32) -> bytes:
    '''Converts an int into a bytestring'''
    return int(x).to_bytes(length=n//8, byteorder='big', signed=False)

def _Bstoi(x: bytes)     -> int:
    '''Converts a bytestring into an int'''
    return int.from_bytes(x, byteorder='big', signed=False)

def _BstoB(x: bytes)      -> np.ndarray:
    '''Convert a bytestring into an uint8 numpy array'''
    return np.array([byte for byte in x], dtype=np.uint8)

def _BtoBs(x: np.ndarray) -> bytes:
    '''Converts an uint8 numpy array into a bytestring'''
    return x.tobytes()

def _itoB(x: int, n=32) -> np.ndarray:
    '''Converts an int into an uint8 numpy array of bytes'''
    return _BstoB(_itoBs(x, n))

def _Btoi(x: np.ndarray)     -> int:
    '''Converts an uint8 numpy array of bytes into an int'''
    return _Bstoi(_BtoBs(x))

def _itob(x: int, n=32) -> np.ndarray:
    '''bit decomposition of a n-bit input, treated as unsigned integer.'''
    bytes_arr  = _BstoB(_itoBs(x, n))
    return np.unpackbits(bytes_arr)

# -----------------------------OPERATOR UTILITIES----------------------------- #
def sample(rng, nbits):
    '''Sample a certain number of bits out of a PRNG. Must be multiple of 8'''
    return _BstoB(rng.bytes(nbits//8))

from functools import reduce
def xor(*args):
    '''Elementwise ZOR for several lists of elements'''
    xor_2 = lambda a,b: a^b
    return [reduce(xor_2, arg)  for arg in zip(*args)]

def secret_share(x: int, rng):
    '''Secret share a value'''
    share_0 = int(rng.integers(*Z2n))
    share_1 = (x - share_0) % (2**(n))
    return share_0, share_1

In [4]:
n = 32
Z2n = (0, 2**(n)-1)
λ   = 128
zeros_λ = np.zeros(λ//8, dtype=np.uint8)
zeros_n = np.zeros(n//8, dtype=np.uint8)
zeros_CW = (zeros_λ, 0, zeros_λ, 0,    zeros_n, 0, zeros_n, 0)

def G(seed, λ=128, n=32):
    '''Pseudo-random generator for FSS comparison'''
    rng = np.random.Generator(np.random.PCG64(seed=[x for x in seed]))
    tL, tR = rng.integers(2, size=2)
    sL, sR = sample(rng, λ), sample(rng, λ)
    τL, τR = rng.integers(2, size=2)
    σL, σR = sample(rng, n), sample(rng, n)
    return sL, tL, sR, tR, σL, τL, σR, τR

## Implementation: (Algo 3, Algo 4)

In [5]:
def FSS_Comparison_KeyGen():
    rng = np.random.default_rng(seed=42)
    
    # Sample random α ← Z2n
    α = int(rng.integers(*Z2n))
    α_bits = _itob(α)

    # Sample random s(1)j ← {0, 1}λ and set t(1)j ← j, for j = 0, 1
    s0, s1 = [sample(rng, λ)], [sample(rng, λ)]
    t0, t1 = [0], [1]

    CWleaf, CW, cw = [], [], []

    # Loop over each node in the tree
    for i in range(n):
        # L/R state words
        s0L, t0L, s0R, t0R, σ0L, τ0L, σ0R, τ0R = Gs0i = G(s0[i])
        s1L, t1L, s1R, t1R, σ1L, τ1L, σ1R, τ1R = Gs1i = G(s1[i])
        # cw based on bit α_bits[i]
        if α_bits[i]:
            cw.append([zeros_λ, 0, s0L^s1L, 1,    σ0R^σ1R, 1, zeros_n, 0])
        else:
            cw.append([s0R^s1R, 1, zeros_λ, 0,    zeros_n, 0, σ0L^σ1L, 1])
        # mask cw[i] with G(s0[i]) and G(s1[i]) inside CW
        CW.append(xor(cw[i], Gs0i, Gs1i))
        # compute each party's next state -> unmask his state word only if  tj[i]==1, else all zeros 
        s0α0, t0α0, s0α1, t0α1, σ0α1, τ0α1, σ0α0, τ0α0, = xor(Gs0i, CW[i] if t0[i] else zeros_CW)
        s1α0, t1α0, s1α1, t1α1, σ1α1, τ1α1, σ1α0, τ1α0, = xor(Gs1i, CW[i] if t1[i] else zeros_CW)
        #  and append to the list of states 
        #    Parse sj tj
        s0.append(s0α1 if α_bits[i] else s0α0);  t0.append(t0α1 if α_bits[i] else t0α0)
        s1.append(s1α1 if α_bits[i] else s1α0);  t1.append(t1α1 if α_bits[i] else t1α0)
        #    Parse σj, τj
        σ0, τ0 = (σ0α1, τ0α1) if α_bits[i] else (σ0α0, τ0α0)
        σ1, τ1 = (σ1α1, τ1α1) if α_bits[i] else (σ1α0, τ1α0)
        
        CWleaf.append( ((-1)**(τ1) * (α_bits[i] - _Btoi(σ0) + _Btoi(σ1)) ) % (2**(n)) )
    CWleaf.append(  ((-1)**(t1[n]) * (1   - _Btoi(s0[n]) + _Btoi(s1[n])) ) % (2**(n)) )

    # build FSS keys for each party
    α_ss0, α_ss1 = secret_share(α, rng)
    k0 = (α_ss0, s0[0], CW, CWleaf)
    k1 = (α_ss1, s1[0], CW, CWleaf)
    return k0, k1

def FSS_Comparison_Eval(j, kj, x):
    x_bits = _itob(x)
    
    # Parse kj
    α_ss, s, CW, CWleaf = kj
    t = j
    
    # tree evaluation
    out = []
    for i in range(n):
        sx0, tx0, sx1, tx1, σx0, τx0, σx1, τx1 = xor(G(s), CW[i] if t else zeros_CW)
        s, t, σ, τ = (sx1, tx1, σx1, τx1) if x_bits[i] else (sx0, tx0, σx0, τx0)
        out.append( ((-1)**j * (τ * CWleaf[i] + _Btoi(σ))) % (2**n) )
    out.append( ((-1)**j * (t * CWleaf[n] + _Btoi(s))) % (2**n) )
    return sum(out) % (2**n)

In [None]:
L, t0L, s0R, t0R, σ0L, τ0L, σ0R, τ0R = Gs0i = G(s0[i])
        s1L, t1L, s1R, t1R, σ1L, τ1L, σ1R, τ1R = Gs1i = G(s1[i])
        # cw based on bit α_bits[i]
        if α_bits[i]:
            cw.append([zeros_λ, 0, s0L^s1L, 1,    σ0R^σ1R, 1, zeros_n, 0])
        else:
            cw.append([s0R^s1R, 1, zeros_λ, 0,    zeros_n, 0, σ0L^σ1L, 1])
        # mask cw[i] with G(s0[i]) and G(s1[i]) inside CW
        CW.append(xor(cw[i], Gs0i, Gs1i))
        # compute each party's next state -> unmask his state word only if  tj[i]==1, else all zeros 
        s0α0, t0α0, s0α1, t0α1, σ0α1, τ0α1, σ0α0, τ0α0, = xor(Gs0i, CW[i] if t0[i] else zeros_CW)
        s1α0, t1α0, s1α1, t1α1, σ1α1, τ1α1, σ1α0, τ1α0, = xor(Gs1i, CW[i] if t1[i] else zeros_CW)
        #  and append to the list of states 
        #    Parse sj tj
        s0.append(s0α1 if α_bits[i] else s0α0);  t0.append(t0α1 if α_bits[i] else t0α0)
        s1.append(s1α1 if α_bits[i] else s1α0);  t1.append(t1α1 if α_bits[i] else t1α0)
        #    Parse σj, τj
        σ0, τ0 = (σ0α1, τ0α1) if α_bits[i] else (σ0α0, τ0α0)
        σ1, τ1 = (σ1α1, τ1α1) if α_bits[i] else (σ1α0, τ1α0)
        
        CWleaf.append( ((-1)**(τ1) * (α_bits[i] - _Btoi(σ0) + _Btoi(σ1)) ) % (2**(n)) )
    CWleaf.append(  ((-1)**(t1[n]) * (1   - _Btoi(s0[n]) + _Btoi(s1[n])) ) % (2**(n)) )

    # build FSS keys for each party
    α_ss0, α_ss1 = secret_share(α, rng)
    k0 = (α_ss0, s0[0], CW, CWleaf)
    k1 = (α_ss1, s1[0], CW, CWleaf)
    return k0, k1

def FSS_Comparison_Eval(j, kj, x):
    x_bits = _itob(x)
    
    # Parse kj
    α_ss, s, CW, CWleaf = kj
    t = j
    
    # tree evaluation
    out = []
    for i in range(n):
        sx0, tx0, sx1, tx1, σx0, τx0, σx1, τx1 = xor(G(s), CW[i] if t else zeros_CW)
        s, t, σ, τ = (sx1, tx1, σx1, τx1) if x_bits[i] else (sx0, tx0, σx0, τx0)
        out.append( ((-1)**j * (τ * CWleaf[i] + _Btoi(σ))) % (2**n) )
    out.append( ((-1)**j * (t * CWleaf[n] + _Btoi(s))) % (2**n) )
    return sum(out) % (2**n)

## Example

In [6]:
k0, k1 = FSS_Comparison_KeyGen()
α = (k0[0] + k1[0]) % (2**n)
α

383329927

In [10]:
for y in range(-32, 33):
    x = (y + α) % (2**n)
    res = (1 - FSS_Comparison_Eval(0, k0, x) - FSS_Comparison_Eval(1, k1, x)) % (2**n)
    print(f'{y}: {res}')

-32: 0
-31: 0
-30: 0
-29: 0
-28: 0
-27: 0
-26: 0
-25: 0
-24: 0
-23: 0
-22: 0
-21: 0
-20: 0
-19: 0
-18: 0
-17: 0
-16: 0
-15: 0
-14: 0
-13: 0
-12: 0
-11: 0
-10: 0
-9: 0
-8: 0
-7: 0
-6: 0
-5: 0
-4: 0
-3: 0
-2: 0
-1: 0
0: 0
1: 1
2: 1
3: 1
4: 1
5: 1
6: 1
7: 1
8: 1
9: 1
10: 1
11: 1
12: 1
13: 1
14: 1
15: 1
16: 1
17: 1
18: 1
19: 1
20: 1
21: 1
22: 1
23: 1
24: 1
25: 1
26: 1
27: 1
28: 1
29: 1
30: 1
31: 1
32: 1


# FSS comparison (Cython)

In [7]:
%load_ext cython

## Optimized PRG

### Numpy-based G

In [58]:
%%cython -+ -c=/O2
cimport numpy as np
import numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef np.ndarray[np.uint8_t, ndim=1] G_cy(
    np.ndarray[np.uint8_t, ndim=1] seed,
    Py_ssize_t size
):
    '''Numpy Optimized Pseudo-random generator for FSS comparison'''
    cdef Py_ssize_t i
    # Seed the PRG
    np.random.seed(seed)
    # Sample everything at once as bytes:
    #  sL, sR, sigmaL, sigmaR, tL, tR, tauL, tauR
    cdef np.ndarray[np.uint8_t, ndim=1] dest = \
        np.random.randint(low=0, high=2**8, size=size, dtype=np.uint8)
    # Cast last 4 bytes to boolean by keeping only the MSB
    for i in range(size-4, size):
        dest[i] = (dest[i] >= (1 <<7))
    return dest

Content of stdout:
_cython_magic_cbcbb99f073c99ac3d568a6b1106025605785a9a.cpp
   Creating library C:\Users\alberiba\.ipython\cython\Users\alberiba\.ipython\cython\_cython_magic_cbcbb99f073c99ac3d568a6b1106025605785a9a.cp39-win_amd64.lib and object C:\Users\alberiba\.ipython\cython\Users\alberiba\.ipython\cython\_cython_magic_cbcbb99f073c99ac3d568a6b1106025605785a9a.cp39-win_amd64.exp
Generating code
Finished generating code

### Miyaguchi–Preneel PRG

In [8]:
%%cython -+ -c=/std:c++17 -I . -S AES.cpp
cimport numpy as np
import numpy as np
cimport cython

np.import_array()

cdef extern from "AES.h":
    cdef cppclass AES:
        AES()
        AES(int keyLength)
        void EncryptECB(const unsigned char datain[], const unsigned char key[],
                        unsigned char dataout[],      unsigned int inLen,) nogil except * 
        void Xor3Blocks(const unsigned char *a, const unsigned char *b,
                        unsigned char *in_out,  unsigned int inLen) nogil except * 
        
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
@cython.initializedcheck(False)
@cython.nonecheck(False)
@cython.embedsignature(True)
cdef class CPRG():
    """Cryptographically-secure Pseudo-Random Generator.
    
    Implemented using a Miyaguchi–Preneel construction with AES-ECB as block
    cipher, and an arbitrary IV as first key.
    
    See Also:
        https://en.wikipedia.org/wiki/One-way_compression_function
    """
    cdef np.uint8_t[::1] H_0
    cdef const unsigned char * H_0_p
    cdef unsigned int aes_Bsize, hash_Bsize, n_blocks, last_n_Btob, i
    cdef AES* my_aes
    def __cinit__(
        self,
        unsigned int H0_seed    = 42,
        unsigned int aes_Bsize  = 16,
        unsigned int hash_Bsize = 44,
        unsigned int last_n_Btob= 4,
    ):
        assert aes_Bsize in (16, 24, 32), "aes_Bsize must be in {16 (128b), 24 (192b), 32 (256b)} "
        self.my_aes     = new AES(aes_Bsize*8)
        self.aes_Bsize  = aes_Bsize
        self.hash_Bsize = hash_Bsize
        self.n_blocks   = (hash_Bsize + aes_Bsize - 1) / aes_Bsize
        self.last_n_Btob= last_n_Btob
        rng = np.random.default_rng(seed=H0_seed)
        self.H_0        = rng.integers(0, 2**8, size=(aes_Bsize), dtype=np.uint8)
        self.H_0_p      = <const unsigned char *>&self.H_0[0]
 
    def __init__(self, unsigned int H0_seed    = 42, unsigned int aes_Bsize  = 16,
                       unsigned int hash_Bsize = 44, unsigned int last_n_Btob= 4):
        pass
    
    cpdef void G(self,
        np.ndarray[np.uint8_t, ndim=1] datain,
        np.ndarray[np.uint8_t, ndim=1] dataout,
    ):
        '''AES Pseudo-random generation step'''
        cdef Py_ssize_t j, k
        cdef const unsigned char * datain_p = <const unsigned char *>datain.data
        assert datain.shape[0]  >= self.aes_Bsize, \
            f"datain must have length>={self.aes_Bsize}"
        assert dataout.shape[0] >= self.aes_Bsize*self.n_blocks, \
            f"dataout must have length>={self.aes_Bsize * self.n_blocks}"
        self.MP_block(datain_p, self.H_0_p, <unsigned char *>&dataout[0])
        for j in range(1, self.n_blocks):
            self.MP_block(datain_p,
                          <const unsigned char *>&dataout[(j-1)*self.aes_Bsize],
                          <unsigned char *>&dataout[j*self.aes_Bsize])
        # Cast last n Bytes to bits (boolean) by keeping only the MSB
        for k in range(self.hash_Bsize-self.last_n_Btob, self.hash_Bsize):
            dataout[k] = (dataout[k] >= (1 <<7))
    
    cdef inline void MP_block (self,
        const unsigned char * datain_p,
        const unsigned char * key_p, 
        unsigned char * dataout_p,
    ):
        self.my_aes.EncryptECB(datain_p, key_p, dataout_p, self.aes_Bsize)
        self.my_aes.Xor3Blocks(datain_p, key_p, dataout_p, self.aes_Bsize)
        
    @property
    def H_0(self):
        return np.asarray(self.H_0)        
    @property
    def aes_Bsize(self):
        return self.aes_Bsize
    @property
    def hash_Bsize(self):
        return self.hash_Bsize
    @property
    def n_blocks(self):
        return self.n_blocks
    @property
    def last_n_Btob(self):
        return self.last_n_Btob
    
    def __repr__(self):
        return "CPRG at {} [G {}b --> {}b (w/ {} bools)]".format(
            hex(id(self)),
            self.aes_Bsize*8,
            self.hash_Bsize*8,
            self.last_n_Btob
        )

### Comparison

In [149]:
np.random.seed(42)
s_in = np.random.randint(0, 2**8, size=16, dtype=np.uint8)

In [56]:
%%timeit
G(s_in)

562 µs ± 131 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [59]:
%%timeit
G_cy(s_in, size=44)

97.9 µs ± 6.16 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [148]:
prg = CPRG()
s_out = np.empty(shape=48, dtype=np.uint8)

In [61]:
%%timeit
prg.G(s_in, s_out)

24.3 µs ± 2.94 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


## KeyGen & Eval

### Python W/ CPRG

In [22]:
def FSS_Comparison_KeyGen_cprg():
    rng = np.random.default_rng(seed=42)

    # Sample random α ← Z2n
    α = int(rng.integers(*Z2n))
    α_bits = _itob(α)

    # Sample random s(1)j ← {0, 1}λ and set t(1)j ← j, for j = 0, 1
    s0, s1 = [sample(rng, λ)], [sample(rng, λ)]
    t0, t1 = [0], [1]

    CWleaf, CW, cw = [], [], []

    # set secure PRG
    prg_0 = CPRG(aes_Bsize=16, hash_Bsize=44, last_n_Btob=4)
    prg_1 = CPRG(aes_Bsize=16, hash_Bsize=44, last_n_Btob=4)
    Gs0i_buffer  = np.empty(shape=(48), dtype=np.uint8)
    Gs1i_buffer  = np.empty(shape=(48), dtype=np.uint8)
    # Loop over each node in the tree
    for i in range(n):
        # L/R state words
        prg_0.G(s0[i], Gs0i_buffer)
        prg_1.G(s1[i], Gs1i_buffer)
        Gs0i = Gs0i_buffer[:44]
        Gs1i = Gs1i_buffer[:44]
        s0L, t0L, s0R, t0R, σ0L, τ0L, σ0R, τ0R = Gs0i = Gs0i[:16], Gs0i[40], Gs0i[16:32], Gs0i[41], Gs0i[32:36], Gs0i[42], Gs0i[36:40], Gs0i[43]
        s1L, t1L, s1R, t1R, σ1L, τ1L, σ1R, τ1R = Gs1i = Gs1i[:16], Gs1i[40], Gs1i[16:32], Gs1i[41], Gs1i[32:36], Gs1i[42], Gs1i[36:40], Gs1i[43]
        # cw based on bit α_bits[i]
        if α_bits[i]:
            cw.append([zeros_λ, 0, s0L^s1L, 1,    σ0R^σ1R, 1, zeros_n, 0])
        else:
            cw.append([s0R^s1R, 1, zeros_λ, 0,    zeros_n, 0, σ0L^σ1L, 1])
        # mask cw[i] with G(s0[i]) and G(s1[i]) inside CW
        CW.append(xor(cw[i], Gs0i, Gs1i))
        # compute each party's next state -> unmask his state word only if  tj[i]==1, else all zeros 
        s0α0, t0α0, s0α1, t0α1, σ0α1, τ0α1, σ0α0, τ0α0, = xor(Gs0i, CW[i] if t0[i] else zeros_CW)
        s1α0, t1α0, s1α1, t1α1, σ1α1, τ1α1, σ1α0, τ1α0, = xor(Gs1i, CW[i] if t1[i] else zeros_CW)
        #  and append to the list of states 
        #    Parse sj tj
        s0.append(s0α1 if α_bits[i] else s0α0);  t0.append(t0α1 if α_bits[i] else t0α0)
        s1.append(s1α1 if α_bits[i] else s1α0);  t1.append(t1α1 if α_bits[i] else t1α0)
        #    Parse σj, τj
        σ0, τ0 = (σ0α1, τ0α1) if α_bits[i] else (σ0α0, τ0α0)
        σ1, τ1 = (σ1α1, τ1α1) if α_bits[i] else (σ1α0, τ1α0)

        CWleaf.append( ((-1)**(τ1) * (α_bits[i] - σ0.view(np.uint32)[0] + σ1.view(np.uint32)[0]) ) % (2**(n)) )
    CWleaf.append(  ((-1)**(t1[n]) * (1   - s0[n].view(np.uint32)[0] + s1[n].view(np.uint32)[0]) ) % (2**(n)) )

    # build FSS keys for each party
    α_ss0, α_ss1 = secret_share(α, rng)
    k0 = (α_ss0, s0[0], CW, CWleaf)
    k1 = (α_ss1, s1[0], CW, CWleaf)
    return k0, k1

def FSS_Comparison_Eval_cprg(j, kj, x):
    x_bits = _itob(x)
    
    # setup prg
    prg = CPRG(aes_Bsize=16, hash_Bsize=44, last_n_Btob=4)
    Gsi = np.empty(shape=48, dtype=np.uint8)
    
    # Parse kj
    α_ss, s, CW, CWleaf = kj
    t = j
    
    # tree evaluation
    out = []
    for i in range(n):
#         print(f"STEP {i}")
#         print(s)
        prg.G(s, Gsi)
#         print(s)
#         print(Gsi)
        Gs = (Gsi[:16], Gsi[40], Gsi[16:32], Gsi[41], Gsi[32:36], Gsi[42], Gsi[36:40], Gsi[43])
        sx0, tx0, sx1, tx1, σx0, τx0, σx1, τx1 = xor(Gs, CW[i] if t else zeros_CW)
        s, t, σ, τ = (sx1, tx1, σx1, τx1) if x_bits[i] else (sx0, tx0, σx0, τx0)
#         print(s, t, σ, τ)
#         print(CWleaf[i], σ.view(np.uint32)[0])
#         print(((-1)**j * (τ * CWleaf[i] + σ.view(np.uint32)[0])) % (2**n))
        out.append( ((-1)**j * (τ * CWleaf[i] + σ.view(np.uint32)[0])) % (2**n) )
    out.append( ((-1)**j * (t * CWleaf[n] + s.view(np.uint32)[0])) % (2**n) )
    return sum(out) % (2**n)

In [16]:
k0, k1 = FSS_Comparison_KeyGen_cprg()
α = (k0[0] + k1[0]) % (2**n)
print(α)
for y in range(-32, 33):
    x = (y + α) % (2**n)
    res = (1 - FSS_Comparison_Eval_cprg(0, k0, x) - FSS_Comparison_Eval_cprg(1, k1, x)) % (2**n)
    print(f'{y}: {res}')

383329927
-32: 2429614149
-31: 2036975329
-30: 3480040011
-29: 3961328540
-28: 1891293520
-27: 429457174
-26: 2215754146
-25: 1987251054
-24: 619660358
-23: 3640839299
-22: 2527304570
-21: 1870808484
-20: 3265802498
-19: 1817559098
-18: 484609121
-17: 1483383007
-16: 775911302
-15: 686719920
-14: 3909237266
-13: 560638211
-12: 1267639208
-11: 2436643639
-10: 2523165036
-9: 347473290
-8: 2148204890
-7: 2432421308
-6: 4063381346
-5: 2773163353
-4: 4166399832
-3: 1008724013
-2: 205326570
-1: 0
0: 0
1: 913080894
2: 3821948291
3: 1943387139
4: 360105497
5: 3455449245
6: 1816241986
7: 2196977706
8: 1293658785
9: 1192861070
10: 3658120851
11: 2735704380
12: 691310758
13: 976426775
14: 1024958432
15: 1616171260
16: 3007787612
17: 829152441
18: 1750204780
19: 1911165090
20: 3765591923
21: 4138084763
22: 2478327289
23: 1725981538
24: 4166390769
25: 2234379880
26: 1225777207
27: 327143047
28: 1733294935
29: 1048103285
30: 3374747732
31: 2707629269
32: 808212498


Now we add a conversion of the FSS keys to pure numpy arrays:

In [9]:
def _to_np_CW(CW: list, n:int = 32, w_size:int = 44) -> np.ndarray:
    return np.concatenate([np.concatenate([np.array(k).flatten() for k in 
                                           (kl[0], kl[2], kl[4], kl[6], kl[1], kl[3], kl[5], kl[7])])
                           for kl in CW]).reshape((n, w_size)).astype(np.uint8)
def _to_np_FSS_k(k: list):
    return (np.uint32(k[0]), k[1], _to_np_CW(k[2]), np.array(k[3]).astype(np.uint32))

In [18]:
def FSS_Comparison_Eval_cprg_2(j, kj, x):
    x_bits = _itob(x)
    
    # setup prg
    prg = CPRG(aes_Bsize=16, hash_Bsize=44, last_n_Btob=4)
    Gsi = np.empty(shape=48, dtype=np.uint8)
    
    # Parse kj
    α_ss, s, CW, CWleaf = kj
    CWleaf = np.array(CWleaf).astype(np.uint32)
    t = j
    
    # tree evaluation
    out = []
    for i in range(n):
#         print(f"STEP {i}")
#         print(s)
        s = np.copy(s)
        prg.G(s, Gsi)
#         print(s)
#         print(Gsi)
        Gs = (Gsi[:16], Gsi[40], Gsi[16:32], Gsi[41], Gsi[32:36], Gsi[42], Gsi[36:40], Gsi[43])
        CWi = CW[i]
        Gs = (CWi[:16]   ^ Gs[0], CWi[40] ^ Gs[1], CWi[16:32] ^ Gs[2], CWi[41] ^ Gs[3], 
              CWi[32:36] ^ Gs[4], CWi[42] ^ Gs[5], CWi[36:40] ^ Gs[6], CWi[43] ^ Gs[7]) if t else Gs
        sx0, tx0, sx1, tx1, σx0, τx0, σx1, τx1 = Gs
        s, t, σ, τ = (sx1, tx1, σx1, τx1) if x_bits[i] else (sx0, tx0, σx0, τx0)
#         print(s, t, σ, τ)
#         print( CWleaf[i], σ.view(np.uint32)[0])
#         print( ((-1)**j * (τ * CWleaf[i] + σ.view(np.uint32)[0])) % (2**n) )
        out.append( ((-1)**j * (τ * CWleaf[i] + σ.view(np.uint32)[0])) % (2**n) )
    out.append( ((-1)**j * (t * CWleaf[n] + s.view(np.uint32)[0])) % (2**n) )
    return sum(out) % (2**n)

In [19]:
key0, key1 = FSS_Comparison_KeyGen_cprg()
key0, key1 = _to_np_FSS_k(key0), _to_np_FSS_k(key1)
α = np.uint32(key0[0]) + np.uint32(key1[0])

for y in range(-32, 33):
    x = (y + α) % (2**n)
    res = (1 - FSS_Comparison_Eval_cprg_2(0, key0, x) - FSS_Comparison_Eval_cprg_2(1, key1, x)) % (2**n)
    print(f'{y}: {res}', end='  |  ')

-32: 2429614149  |  -31: 2036975329  |  -30: 3480040011  |  -29: 3961328540  |  -28: 1891293520  |  -27: 429457174  |  -26: 2215754146  |  -25: 1987251054  |  -24: 619660358  |  -23: 3640839299  |  -22: 2527304570  |  -21: 1870808484  |  -20: 3265802498  |  -19: 1817559098  |  -18: 484609121  |  -17: 1483383007  |  -16: 775911302  |  -15: 686719920  |  -14: 3909237266  |  -13: 560638211  |  -12: 1267639208  |  -11: 2436643639  |  -10: 2523165036  |  -9: 347473290  |  -8: 2148204890  |  -7: 2432421308  |  -6: 4063381346  |  -5: 2773163353  |  -4: 4166399832  |  -3: 1008724013  |  -2: 205326570  |  -1: 0  |  0: 0  |  1: 913080894  |  2: 3821948291  |  3: 1943387139  |  4: 360105497  |  5: 3455449245  |  6: 1816241986  |  7: 2196977706  |  8: 1293658785  |  9: 1192861070  |  10: 3658120851  |  11: 2735704380  |  12: 691310758  |  13: 976426775  |  14: 1024958432  |  15: 1616171260  |  16: 3007787612  |  17: 829152441  |  18: 1750204780  |  19: 1911165090  |  20: 3765591923  |  21: 4138084

In [20]:
%%timeit
(1 - FSS_Comparison_Eval_cprg(0, k0, x) - FSS_Comparison_Eval_cprg(1, k1, x)) % (2**n)

2.62 ms ± 90.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [281]:
k0, k1 = _to_np_FSS_k(k0), _to_np_FSS_k(k1)

### Cython Eval w/ CPRG

In [10]:
%%cython -+ -c=/std:c++17 -I . -S AES.cpp -a
cimport numpy as np
import numpy as np
cimport cython

from libc.string cimport memcpy
from cython.operator cimport dereference as deref

np.import_array()

cdef extern from "AES.h":
    cdef cppclass AES:
        AES()
        AES(int keyLength)
        void EncryptECB(const unsigned char datain[], const unsigned char key[],
                        unsigned char dataout[],      unsigned int inLen,) nogil except * 
        void Xor3Blocks(const unsigned char *a, const unsigned char *b,
                        unsigned char *in_out,  unsigned int inLen) nogil except * 
        
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
@cython.initializedcheck(False)
@cython.nonecheck(False)
@cython.embedsignature(True)
cdef class CPRG():
    """Cryptographically-secure Pseudo-Random Generator.
    
    Implemented using a Miyaguchi–Preneel construction with AES-ECB as block
    cipher, and an arbitrary IV as first key.
    
    See Also:
        https://en.wikipedia.org/wiki/One-way_compression_function
    """
    cdef np.uint8_t[::1] H_0
    cdef const unsigned char * H_0_p
    cdef unsigned int aes_Bsize, hash_Bsize, n_blocks, last_n_Btob, i
    cdef AES* my_aes
    def __cinit__(
        self,
        unsigned int H0_seed    = 42,
        unsigned int aes_Bsize  = 16,
        unsigned int hash_Bsize = 44,
        unsigned int last_n_Btob= 4,
    ):
        assert aes_Bsize in (16, 24, 32), "aes_Bsize must be in {16 (128b), 24 (192b), 32 (256b)} "
        self.my_aes     = new AES(aes_Bsize*8)
        self.aes_Bsize  = aes_Bsize
        self.hash_Bsize = hash_Bsize
        self.n_blocks   = (hash_Bsize + aes_Bsize - 1) / aes_Bsize
        self.last_n_Btob= last_n_Btob
        rng = np.random.default_rng(seed=H0_seed)
        self.H_0        = rng.integers(0, 2**8, size=(aes_Bsize), dtype=np.uint8)
        self.H_0_p      = <const unsigned char *>&self.H_0[0]
 
    def __init__(self, unsigned int H0_seed    = 42, unsigned int aes_Bsize  = 16,
                       unsigned int hash_Bsize = 44, unsigned int last_n_Btob= 4):
        pass
    
    cpdef void G(self,
        np.ndarray[np.uint8_t, ndim=1] datain,
        np.ndarray[np.uint8_t, ndim=1] dataout,
    ):
        '''AES Pseudo-random generation step'''
        cdef Py_ssize_t j, k
        cdef const unsigned char * datain_p = <const unsigned char *>datain.data
        assert datain.shape[0]  >= self.aes_Bsize, \
            f"datain must have length>={self.aes_Bsize}"
        assert dataout.shape[0] >= self.aes_Bsize*self.n_blocks, \
            f"dataout must have length>={self.aes_Bsize * self.n_blocks}"
        self.MP_block(datain_p, self.H_0_p, <unsigned char *>&dataout[0])
        for j in range(1, self.n_blocks):
            self.MP_block(datain_p,
                          <const unsigned char *>&dataout[(j-1)*self.aes_Bsize],
                          <unsigned char *>&dataout[j*self.aes_Bsize])
        # Cast last n Bytes to bits (boolean) by keeping only the MSB
        for k in range(self.hash_Bsize-self.last_n_Btob, self.hash_Bsize):
            dataout[k] = (dataout[k] >= (1 <<7))
    
    cdef inline void MP_block (self,
        const unsigned char * datain_p,
        const unsigned char * key_p, 
        unsigned char * dataout_p,
    ):
        self.my_aes.EncryptECB(datain_p, key_p, dataout_p, self.aes_Bsize)
        self.my_aes.Xor3Blocks(datain_p, key_p, dataout_p, self.aes_Bsize)
        
    @property
    def H_0(self):
        return np.asarray(self.H_0)
    
    def __repr__(self):
        return "CPRG at {} [G {}b --> {}b (w/ {} bools)]".format(
            hex(id(self)),
            self.aes_Bsize*8,
            self.hash_Bsize*8,
            self.last_n_Btob
        )


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
@cython.initializedcheck(False)
@cython.nonecheck(False)
@cython.embedsignature(True)
cpdef np.ndarray[np.uint8_t, ndim=1] i2b(np.uint64_t x, Py_ssize_t n=32):
    """Decompose a 64bit integer `x` into an uint8 array of `n` bits, MSB first"""
    cdef np.ndarray[np.uint8_t, ndim=1] decomp = np.empty(shape=n, dtype=np.uint8)
    cdef Py_ssize_t i
    for i in range(n):
        decomp[i] = ((x & 1ULL<<(n-1-i)) != 0)
    return decomp
    
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
@cython.initializedcheck(False)
@cython.nonecheck(False)
@cython.embedsignature(True)    
cdef inline void xor_cy(np.uint8_t[::1] a_in, np.uint8_t[::1] b_inout):
    cdef Py_ssize_t i
    for i in a_in.shape[0]:
        b_inout[i] =  a_in[i] ^ b_inout[i]
        
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
@cython.initializedcheck(False)
@cython.nonecheck(False)
@cython.embedsignature(True)    
cpdef FSS_Comparison_Eval_cy(np.uint8_t j, tuple kj, np.uint64_t x):
    cdef np.ndarray[np.uint8_t, ndim=1] x_bits = i2b(x)
    
    # setup prg
    cdef CPRG prg = CPRG(aes_Bsize=16, hash_Bsize=44, last_n_Btob=4)
    cdef np.ndarray[np.uint8_t, ndim=1] Gsi    = np.empty(shape=48, dtype=np.uint8)
    cdef np.uint32_t sigma
    cdef np.uint8_t tau
    
    # Parse kj
    cdef np.ndarray[np.uint8_t, ndim=1] s  = np.copy(kj[1])
    cdef np.ndarray[np.uint8_t, ndim=2] CW = kj[2]
    cdef np.ndarray[np.uint32_t, ndim=1] CWleaf = kj[3]
    cdef np.uint8_t t = j
    
    # tree evaluation
    cdef np.uint32_t out = 0
    cdef Py_ssize_t i, k
    for i in range(32):
        prg.G(s, Gsi)
        if t:    # xor Gsi with CW[i]
            for k in range(44):
                Gsi[k] =  Gsi[k] ^ CW[i,k]
        if x_bits[i]:
            memcpy(&s[0], &Gsi[16], 16)          # s = Gsi[16:32]
            sigma = deref(<np.uint32_t*>&Gsi[36]) # sigma = Gsi[36:40]
            t     = Gsi[41]
            tau   = Gsi[43]
        else:
            memcpy(&s[0], &Gsi[0], 16)   # s = Gsi[0:16]
            sigma = deref(<np.uint32_t*>&Gsi[32]) # sigma = Gsi[36:40]
            t     = Gsi[40]
            tau   = Gsi[42]
        out += <np.uint32_t>((-1)**j * (tau * CWleaf[i] + sigma))
    out += <np.uint32_t>((-1)**j * (t * CWleaf[32] + deref(<np.uint32_t*>&s[0])))
    return out


In [23]:
key0, key1 = FSS_Comparison_KeyGen_cprg()
key0, key1 = _to_np_FSS_k(key0), _to_np_FSS_k(key1)
α = np.uint32(k0[0]) + np.uint32(k1[0])

for y in range(2, 33):
    x = (y + α) % (2**n)
    res = (1 - FSS_Comparison_Eval_cprg_2(0, key0, x) - FSS_Comparison_Eval_cprg_2(1, key1, x)) % (2**n)
    print(f'{y}: {res}', end='  |  ')

2: 3821948291  |  3: 1943387139  |  4: 360105497  |  5: 3455449245  |  6: 1816241986  |  7: 2196977706  |  8: 1293658785  |  9: 1192861070  |  10: 3658120851  |  11: 2735704380  |  12: 691310758  |  13: 976426775  |  14: 1024958432  |  15: 1616171260  |  16: 3007787612  |  17: 829152441  |  18: 1750204780  |  19: 1911165090  |  20: 3765591923  |  21: 4138084763  |  22: 2478327289  |  23: 1725981538  |  24: 4166390769  |  25: 2234379880  |  26: 1225777207  |  27: 327143047  |  28: 1733294935  |  29: 1048103285  |  30: 3374747732  |  31: 2707629269  |  32: 808212498  |  

In [24]:
key0, key1 = FSS_Comparison_KeyGen_cprg()
key0, key1 = _to_np_FSS_k(key0), _to_np_FSS_k(key1)
α = key0[0] + key1[0]

for y in range(-32, 33):
    x = (y + α) % (2**n)
    res = (1 - FSS_Comparison_Eval_cy(0, key0, x) - FSS_Comparison_Eval_cy(1, key1, x)) % (2**n)
    print(f'{y}: {res}', end='  |  ')

-32: 1972876404  |  -31: 2549295253  |  -30: 4182167450  |  -29: 896150936  |  -28: 2244074828  |  -27: 654874466  |  -26: 1598693658  |  -25: 1098163257  |  -24: 2074724107  |  -23: 2228242450  |  -22: 18523783  |  -21: 3454456559  |  -20: 2850271572  |  -19: 3263520401  |  -18: 3493554737  |  -17: 3876612621  |  -16: 2279498498  |  -15: 581691629  |  -14: 1902460567  |  -13: 866130084  |  -12: 2278628266  |  -11: 2766914729  |  -10: 2447238960  |  -9: 1777342711  |  -8: 4293130736  |  -7: 430474760  |  -6: 3082896950  |  -5: 1943858731  |  -4: 2765517969  |  -3: 757941582  |  -2: 4126139801  |  -1: 4064528458  |  0: 62341074  |  1: 1619177790  |  2: 1485500515  |  3: 1585776463  |  4: 727047418  |  5: 1764726127  |  6: 4171685926  |  7: 3364045485  |  8: 1391783309  |  9: 1513414957  |  10: 3859921569  |  11: 266779032  |  12: 1494694170  |  13: 3361930772  |  14: 2579735777  |  15: 3497370935  |  16: 857716196  |  17: 3072430231  |  18: 1503916347  |  19: 3163896969  |  20: 40735469

In [27]:
%%timeit
(1 - FSS_Comparison_Eval_cy(0, key0, x) - FSS_Comparison_Eval_cy(1, key1, x))

900 µs ± 26.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### (WIP) Cython KeyGen w/ CPRG

In [11]:
%%cython -+ -c=/std:c++17 -I . -S AES.cpp
cimport numpy as np
import numpy as np
cimport cython

from libc.string cimport memcpy, memset

np.import_array()

cdef extern from "AES.h":
    cdef cppclass AES:
        AES()
        AES(int keyLength)
        void EncryptECB(const unsigned char datain[], const unsigned char key[],
                        unsigned char dataout[],      unsigned int inLen,) nogil except * 
        void Xor3Blocks(const unsigned char *a, const unsigned char *b,
                        unsigned char *in_out,  unsigned int inLen) nogil except * 
        
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
@cython.initializedcheck(False)
@cython.nonecheck(False)
@cython.embedsignature(True)
cdef class CPRG():
    """Cryptographically-secure Pseudo-Random Generator.
    
    Implemented using a Miyaguchi–Preneel construction with AES-ECB as block
    cipher, and an arbitrary IV as first key.
    
    See Also:
        https://en.wikipedia.org/wiki/One-way_compression_function
    """
    cdef np.uint8_t[::1] H_i
    cdef unsigned int aes_Bsize, hash_Bsize, n_blocks, last_n_Btob, i
    cdef AES* my_aes
    def __cinit__(
        self,
        unsigned int H0_seed    = 42,
        unsigned int aes_Bsize  = 16,
        unsigned int hash_Bsize = 44,
        unsigned int last_n_Btob= 4,        
    ):
        rng = np.random.default_rng(seed=H0_seed)
        self.aes_Bsize  = aes_Bsize
        self.hash_Bsize = hash_Bsize
        self.n_blocks   = (hash_Bsize + aes_Bsize - 1) / aes_Bsize
        self.last_n_Btob= last_n_Btob
        self.i          = 0
        self.H_i        = \
            rng.integers(0, 2**8, size=(self.n_blocks*aes_Bsize), dtype=np.uint8)
        assert aes_Bsize in (16, 24, 32), "aes_Bsize must be in {16 (128b), 24 (192b), 32 (256b)} "
        self.my_aes     = new AES(aes_Bsize*8)
    
    cpdef void G(self,
        np.ndarray[np.uint8_t, ndim=1] datain,
        np.ndarray[np.uint8_t, ndim=1] dataout,
    ):
        '''AES Pseudo-random generation step'''
        cdef Py_ssize_t j, k
        assert datain.shape[0]  >= self.aes_Bsize, \
            f"datain must have length>={self.aes_Bsize}"
        assert dataout.shape[0] >= self.hash_Bsize, \
            f"dataout must have length>={self.aes_Bsize * self.n_blocks}"
        for j in range(0, self.n_blocks):
            self.my_aes.EncryptECB(
                <const unsigned char *>datain.data,
                <const unsigned char *>&self.H_i[j*self.aes_Bsize],
                <unsigned char *>&dataout[j*self.aes_Bsize],
                self.aes_Bsize,
            )
            self.my_aes.Xor3Blocks(
                <const unsigned char *>datain.data,
                <const unsigned char *>&self.H_i[j*self.aes_Bsize],
                <unsigned char *>&dataout[j*self.aes_Bsize],
                self.aes_Bsize
            )
            memcpy(<void *>&self.H_i[j*self.aes_Bsize],
                   <void *>&dataout[j*self.aes_Bsize],
                   self.aes_Bsize)
        self.i += 1
        # Cast last n Bytes to bits (boolean) by keeping only the MSB
        for k in range(self.hash_Bsize-self.last_n_Btob, self.hash_Bsize):
            dataout[k] = (dataout[k] >= (1 <<7))



@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
@cython.initializedcheck(False)
@cython.nonecheck(False)
@cython.embedsignature(True)
cpdef np.ndarray[np.uint8_t, ndim=1] i2b(np.uint64_t x, Py_ssize_t n=32):
    """Decompose a 64bit integer `x` into an uint8 array of `n` bits, MSB first"""
    cdef np.ndarray[np.uint8_t, ndim=1] decomp = np.empty(shape=n, dtype=np.uint8)
    cdef Py_ssize_t i
    for i in range(n):
        decomp[i] = ((x & 1ULL<<(n-1-i)) != 0)
    return decomp


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
@cython.initializedcheck(False)
@cython.nonecheck(False)
@cython.embedsignature(True)
cdef void xor(np.uint8_t[::1] x1, np.uint8_t[::1] x2, np.uint8_t[::1] dest):
    cdef Py_ssize_t i
    for i in range(x1.shape[0]):
        dest[i] = x1[i] ^ x2[i]


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
@cython.initializedcheck(False)
@cython.nonecheck(False)
@cython.embedsignature(True)
cpdef void xor3(np.uint8_t[::1] x1, np.uint8_t[::1] x2, np.uint8_t[::1] x3, np.uint8_t[::1] dest):
    cdef Py_ssize_t i
    for i in range(x1.shape[0]):
        dest[i] = x1[i] ^ x2[i] ^ x3[i]

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
@cython.initializedcheck(False)
@cython.nonecheck(False)
@cython.embedsignature(True)
cpdef FSS_Comp_KeyGen_cy(lambd: int=128, dtype_Z2n: type=np.uint32, seed: int=42):
    # Set Ring info from dtype
    if not np.issubdtype(dtype_Z2n, np.integer):
        raise TypeError("dtype must be a numeric numpy type")
    Z2n_info = np.iinfo(dtype_Z2n)
    Z2n = (Z2n_info.min, Z2n_info.max)
    cdef Py_ssize_t n   = Z2n_info.bits
    
    # Set convenient relative positions in the cw/CW arrays for each parameter
    cdef Py_ssize_t lambd_B  = lambd//8,  n_B = n//8,\
                    CW_B     = 2*lambd_B + 2*n_B + 4,\
                    CW_B_aes = ((CW_B + lambd_B - 1) // lambd_B) * lambd_B
    cdef Py_ssize_t sL_ = 0,                        _sL = lambd_B,\
                    sR_ = lambd_B,                  _sR = 2*lambd_B,\
                    sigmaL_ = 2*lambd_B,            _sigmaL = 2*lambd_B + n_B, \
                    sigmaR_ = 2*lambd_B + n_B,      _sigmaR = 2*lambd_B + 2*n_B, \
                    tL_     = 2*lambd_B + 2*n_B,\
                    tR_     = 2*lambd_B + 2*n_B +1,\
                    tauL_   = 2*lambd_B + 2*n_B + 2,\
                    tauR_   = 2*lambd_B + 2*n_B + 3
    
    # set seed for reproducibility
    np.random.seed(seed)
    
    # set secure PRG
    cdef CPRG prg_0 = CPRG(aes_Bsize=lambd_B, hash_Bsize=CW_B, last_n_Btob=4)
    cdef CPRG prg_1 = CPRG(aes_Bsize=lambd_B, hash_Bsize=CW_B, last_n_Btob=4)
    
    # Sample random mask alpha <- Z2n, decompose it in bits and split in shares
    cdef np.int64_t alpha     = np.random.randint(low=Z2n[0], high=Z2n[1], dtype=dtype_Z2n)
    cdef np.int64_t alpha_ss0 = np.random.randint(low=Z2n[0], high=Z2n[1], dtype=dtype_Z2n)
    cdef np.int64_t alpha_ss1 = dtype_Z2n(alpha - alpha_ss0)
    cdef np.ndarray[np.uint8_t, ndim=1] alpha_bits = i2b(alpha, n)
    
    # Sample initial random s(1)j <- {0, 1}**lambd and set t(1)j <- j, for j = 0, 1
    cdef np.ndarray[np.uint8_t, ndim=1] s0 = np.random.randint(low=0, high=2**8, size=lambd_B, dtype=np.uint8)
    cdef np.ndarray[np.uint8_t, ndim=1] s1 = np.random.randint(low=0, high=2**8, size=lambd_B, dtype=np.uint8)
    cdef np.uint8_t t0 = 0, t1 = 1
 
    # Store initial states to append them to the functional keys
    cdef np.ndarray[np.uint8_t, ndim=1] s0_init = np.copy(s0)
    cdef np.ndarray[np.uint8_t, ndim=1] s1_init = np.copy(s1)
    
    # Initilize correction words
    cdef np.ndarray[np.uint8_t, ndim=2] CWleaf = np.zeros(shape=(n+1, n_B),  dtype=np.uint8)
    cdef np.ndarray[np.uint8_t, ndim=2] CW     = np.zeros(shape=(n, CW_B),   dtype=np.uint8)
    cdef np.ndarray[np.uint8_t, ndim=1] cw     = np.zeros(shape=(CW_B),      dtype=np.uint8)
    
    # Initialize intermediate values
    cdef np.ndarray[np.uint8_t, ndim=1] Gs0    = np.zeros(shape=(CW_B_aes), dtype=np.uint8)
    cdef np.ndarray[np.uint8_t, ndim=1] Gs1    = np.zeros(shape=(CW_B_aes), dtype=np.uint8)
    cdef np.ndarray[np.uint8_t, ndim=1] sigma0 = np.zeros(shape=(n_B),      dtype=np.uint8)
    cdef np.ndarray[np.uint8_t, ndim=1] sigma1 = np.zeros(shape=(n_B),      dtype=np.uint8)
    cdef np.uint8_t tau0 = 0, tau1 = 1
    
    # Loop over each node in the tree
    cdef Py_ssize_t i
    for i in range(n):
        print(f"STEP {i}")
        # Gs = [sL, sR, sigmaL, sigmaR, tL, tR, tauL, tauR]
        prg_0.G(datain=s0, dataout=Gs0)
        prg_1.G(datain=s1, dataout=Gs1)
        print(f"Gs0: {Gs0}, Gs1: {Gs1}")
        # cw based on bit α_bits[i]
        memset(&cw[0], 0, CW_B)  # Reset cw
        if alpha_bits[i]:
            xor(Gs0[sL_:_sL],         Gs1[sL_:_sL],         dest=cw[sR_:_sR])
            xor(Gs0[sigmaR_:_sigmaR], Gs1[sigmaR_:_sigmaR], dest=cw[sigmaL_:_sigmaL])
            cw[tR_]             = 1
            cw[tauL_]           = 1
        else:
            xor(Gs0[sR_:_sR],         Gs1[sR_:_sR],         dest=cw[sL_:_sL])
            xor(Gs0[sigmaL_:_sigmaL], Gs1[sigmaL_:_sigmaL], dest=cw[sigmaR_:_sigmaR])
            cw[tL_]             = 1
            cw[tauR_]           = 1
        print(f"cw: {cw}")
        # mask cw with G(s0) and G(s1) inside CW
        xor3(cw, Gs0, Gs1, dest=CW[i])
        print(f"CW[i]: {CW[i]}")
        # compute each party's next state -> unmask his state word only if  tj[i]==1, else all zeros 
        if t0:           xor(CW[i], Gs0, dest=Gs0) 
        else:            memset(&Gs0[0], 0, CW_B)
        if t1:           xor(CW[i], Gs1, dest=Gs1)
        else:            memset(&Gs1[0], 0, CW_B)
        print(f"Gs0_p: {Gs0}, Gs1_p: {Gs1}")
        #  and append to the list of states
        #    Parse sj tj
        if alpha_bits[i]:
            s0        = Gs0[sR_:_sR];         t0   = Gs0[tR_]
            s1        = Gs1[sR_:_sR];         t1   = Gs1[tR_]
            sigma0    = Gs0[sigmaL_:_sigmaL]; tau0 = Gs0[tauL_]
            sigma1    = Gs1[sigmaL_:_sigmaL]; tau1 = Gs1[tauL_]
        else:
            s0        = Gs0[sL_:_sL];         t0   = Gs0[tL_]
            s1        = Gs1[sL_:_sL];         t1   = Gs1[tL_]
            sigma0    = Gs0[sigmaR_:_sigmaR]; tau0 = Gs0[tauR_]
            sigma1    = Gs1[sigmaR_:_sigmaR]; tau1 = Gs1[tauR_]
        print(f"s0: {s0}, s1: {s1}, sigma0: {sigma0}, sigma1: {sigma1}, t0: {t0}, t1: {t1}, tau0: {tau0}, tau1: {tau1}")
        print(f"{tau1}, {alpha_bits[i]}, {sigma0.view(dtype_Z2n)}, {sigma1.view(dtype_Z2n)}")
        CWleaf[i] = dtype_Z2n((-1)**(tau1) * (alpha_bits[i] - sigma0.view(dtype_Z2n) + sigma1.view(dtype_Z2n))).flatten().view(np.uint8)
    CWleaf[n]     = dtype_Z2n((-1)**(t1)   * (1             - s0.view(dtype_Z2n)[0]  + s1.view(dtype_Z2n)[0])).flatten().view(np.uint8)

    # build FSS keys for each party
    k0 = (alpha_ss0, s0_init, CW, CWleaf)
    k1 = (alpha_ss1, s1_init, CW, CWleaf)
    return k0, k1

### Example

In [12]:
k0, k1 = FSS_Comp_KeyGen_cy()
α = np.uint32(k0[0]) + np.uint32(k1[0])
α

STEP 0
Gs0: [243  15   5  87   3  32  14 164  14 255 177 129 104 164 105  40 248  24
 123 138 235 196   5  50  79 205  81   6 242 205 226 189  71  49 112  69
 208 226 152  86   1   1   0   1 106 129 191   3], Gs1: [ 97 179 168  22 181 137 117  42  58 186 166 176 218 158 194 131 135 165
  24 244 212  44 236 178 149 212  69 210 218  68 209 222  46 164  62  88
   1 155  89  29   1   0   1   0 152 148  93 134]
cw: [127 189  99 126  63 232 233 128 218  25  20 212  40 137  51  99   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
 105 149  78  29   1   0   0   1]
CW[i]: [237   1 206  63 137  65 146  14 238  92   3 229 154 179 152 200 127 189
  99 126  63 232 233 128 218  25  20 212  40 137  51  99 105 149  78  29
 184 236 143  86   1   1   1   0]
Gs0_p: [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0 106 129 191   3], Gs1_p: [140 178

1608637542

In [254]:
i2b(α)

array([0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1,
       0, 0, 0, 1, 1, 0, 0, 1, 1, 0], dtype=uint8)

In [25]:
for y in range(-32, 33):
    x = (y + α) % (2**n)
    res = (1 - FSS_Comparison_Eval(0, k0, x) - FSS_Comparison_Eval(1, k1, x)) % (2**n)
    print(f'{y}: {res}')

-32: [1338678893 1338678398 1338678334 1338678398]
-31: [3878951822 3878951326 3878951262 3878951326]
-30: [1680832434 1680831937 1680831873 1680831937]
-29: [450150991 450150495 450150431 450150495]
-28: [1334643339 1334642842 1334642778 1334642842]
-27: [3053888208 3053887711 3053887647 3053887711]
-26: [2033930678 2033930180 2033930116 2033930180]
-25: [4076664120 4076663623 4076663559 4076663623]
-24: [1103225254 1103224754 1103224690 1103224754]
-23: [2132075133 2132074634 2132074570 2132074634]
-22: [2700399493 2700398996 2700398932 2700398996]
-21: [1520836166 1520835669 1520835605 1520835669]
-20: [3553978144 3553977647 3553977583 3553977647]
-19: [742839779 742839280 742839216 742839280]
-18: [219841912 219841414 219841350 219841414]
-17: [3908243459 3908242960 3908242896 3908242960]
-16: [2197714996 2197714497 2197714433 2197714497]
-15: [1117139674 1117139175 1117139111 1117139175]
-14: [979658782 979658285 979658221 979658285]
-13: [2686683153 2686682656 2686682592 26866826

# FSS Comparison2 (correctness!)

In [52]:
%%cython -+ -c=/std:c++17 -I . -S AES.cpp -a
cimport numpy as np
import numpy as np
cimport cython

from libc.string cimport memcpy
from cython.operator cimport dereference as deref

np.import_array()

cdef extern from "AES.h":
    cdef cppclass AES:
        AES()
        AES(int keyLength)
        void EncryptECB(const unsigned char datain[], const unsigned char key[],
                        unsigned char dataout[],      unsigned int inLen,) nogil except * 
        void Xor3Blocks(const unsigned char *a, const unsigned char *b,
                        unsigned char *in_out,  unsigned int inLen) nogil except * 
        
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
@cython.initializedcheck(False)
@cython.nonecheck(False)
@cython.embedsignature(True)
cdef class CPRG():
    """Cryptographically-secure Pseudo-Random Generator.
    
    Implemented using a Miyaguchi–Preneel construction with AES-ECB as block
    cipher, and an arbitrary IV as first key.
    
    See Also:
        https://en.wikipedia.org/wiki/One-way_compression_function
    """
    cdef np.uint8_t[::1] H_0
    cdef const unsigned char * H_0_p
    cdef unsigned int aes_Bsize, hash_Bsize, n_blocks, last_n_Btob, i
    cdef AES* my_aes
    def __cinit__(
        self,
        unsigned int H0_seed    = 42,
        unsigned int aes_Bsize  = 16,
        unsigned int hash_Bsize = 44,
        unsigned int last_n_Btob= 4,
    ):
        assert aes_Bsize in (16, 24, 32), "aes_Bsize must be in {16 (128b), 24 (192b), 32 (256b)} "
        self.my_aes     = new AES(aes_Bsize*8)
        self.aes_Bsize  = aes_Bsize
        self.hash_Bsize = hash_Bsize
        self.n_blocks   = (hash_Bsize + aes_Bsize - 1) / aes_Bsize
        self.last_n_Btob= last_n_Btob
        rng = np.random.default_rng(seed=H0_seed)
        self.H_0        = rng.integers(0, 2**8, size=(aes_Bsize), dtype=np.uint8)
        self.H_0_p      = <const unsigned char *>&self.H_0[0]
 
    def __init__(self, unsigned int H0_seed    = 42, unsigned int aes_Bsize  = 16,
                       unsigned int hash_Bsize = 44, unsigned int last_n_Btob= 4):
        pass
    
    cpdef void G(self,
        np.ndarray[np.uint8_t, ndim=1] datain,
        np.ndarray[np.uint8_t, ndim=1] dataout,
    ):
        '''AES Pseudo-random generation step'''
        cdef Py_ssize_t j, k
        cdef const unsigned char * datain_p = <const unsigned char *>datain.data
        assert datain.shape[0]  >= self.aes_Bsize, \
            f"datain must have length>={self.aes_Bsize}"
        assert dataout.shape[0] >= self.aes_Bsize*self.n_blocks, \
            f"dataout must have length>={self.aes_Bsize * self.n_blocks}"
        self.MP_block(datain_p, self.H_0_p, <unsigned char *>&dataout[0])
        for j in range(1, self.n_blocks):
            self.MP_block(datain_p,
                          <const unsigned char *>&dataout[(j-1)*self.aes_Bsize],
                          <unsigned char *>&dataout[j*self.aes_Bsize])
        # Cast last n Bytes to bits (boolean) by keeping only the MSB
        for k in range(self.hash_Bsize-self.last_n_Btob, self.hash_Bsize):
            dataout[k] = (dataout[k] >= (1 <<7))
    
    cdef inline void MP_block (self,
        const unsigned char * datain_p,
        const unsigned char * key_p, 
        unsigned char * dataout_p,
    ):
        self.my_aes.EncryptECB(datain_p, key_p, dataout_p, self.aes_Bsize)
        self.my_aes.Xor3Blocks(datain_p, key_p, dataout_p, self.aes_Bsize)
        
    @property
    def H_0(self):
        return np.asarray(self.H_0)
    
    def __repr__(self):
        return "CPRG at {} [G {}b --> {}b (w/ {} bools)]".format(
            hex(id(self)),
            self.aes_Bsize*8,
            self.hash_Bsize*8,
            self.last_n_Btob
        )


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
@cython.initializedcheck(False)
@cython.nonecheck(False)
@cython.embedsignature(True)
cpdef np.ndarray[np.uint8_t, ndim=1] i2b(np.uint64_t x, Py_ssize_t n=32):
    """Decompose a 64bit integer `x` into an uint8 array of `n` bits, MSB first"""
    cdef np.ndarray[np.uint8_t, ndim=1] decomp = np.empty(shape=n, dtype=np.uint8)
    cdef Py_ssize_t i
    for i in range(n):
        decomp[i] = ((x & 1ULL<<(n-1-i)) != 0)
    return decomp
    
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
@cython.initializedcheck(False)
@cython.nonecheck(False)
@cython.embedsignature(True)    
cdef inline void xor_cy(np.uint8_t[::1] a_in, np.uint8_t[::1] b_inout):
    cdef Py_ssize_t i
    for i in a_in.shape[0]:
        b_inout[i] =  a_in[i] ^ b_inout[i]
        
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
@cython.initializedcheck(False)
@cython.nonecheck(False)
@cython.embedsignature(True)    
cpdef LT_Gen(np.uint64_t n, np.uint64_t lambd, np.uint64_t alpha, np.uint64_t alpha, group_type g):
    cdef np.ndarray[np.uint8_t, ndim=1] x_bits = i2b(x)
    
    # setup prg
    cdef CPRG prg = CPRG(aes_Bsize=16, hash_Bsize=44, last_n_Btob=4)
    cdef np.ndarray[np.uint8_t, ndim=1] Gsi    = np.empty(shape=48, dtype=np.uint8)
    cdef np.uint32_t sigma
    cdef np.uint8_t tau
    
    # Parse kj
    cdef np.ndarray[np.uint8_t, ndim=1] s  = np.copy(kj[1])
    cdef np.ndarray[np.uint8_t, ndim=2] CW = kj[2]
    cdef np.ndarray[np.uint32_t, ndim=1] CWleaf = kj[3]
    cdef np.uint8_t t = j
    
    # tree evaluation
    cdef np.uint32_t out = 0
    cdef Py_ssize_t i, k
    for i in range(32):
        prg.G(s, Gsi)
        if t:    # xor Gsi with CW[i]
            for k in range(44):
                Gsi[k] =  Gsi[k] ^ CW[i,k]
        if x_bits[i]:
            memcpy(&s[0], &Gsi[16], 16)          # s = Gsi[16:32]
            sigma = deref(<np.uint32_t*>&Gsi[36]) # sigma = Gsi[36:40]
            t     = Gsi[41]
            tau   = Gsi[43]
        else:
            memcpy(&s[0], &Gsi[0], 16)   # s = Gsi[0:16]
            sigma = deref(<np.uint32_t*>&Gsi[32]) # sigma = Gsi[36:40]
            t     = Gsi[40]
            tau   = Gsi[42]
        out += <np.uint32_t>((-1)**j * (tau * CWleaf[i] + sigma))
    out += <np.uint32_t>((-1)**j * (t * CWleaf[32] + deref(<np.uint32_t*>&s[0])))
    return out


In [55]:
(1 - FSS_Comparison_Eval_cy(0, key0, x) - FSS_Comparison_Eval_cy(1, key1, x)) % (2**n)

1

# Funshade: SP & FSS_comp 

In [242]:
l         = 128
threshold = 0.4

rng = np.random.default_rng(seed=4242)
def sample_biometric_template():
    template = rng.exponential(size=l)   # choosing an arbitraty element distribution for the example.
    return template / np.linalg.norm(template)

# Biometric data generation
live_template = sample_biometric_template()
ref_template  = sample_biometric_template()

# Biometric data rescaling --> Fixed-point approximation
s = 2**12
live_template_s = list((live_template * s).astype(int))
ref_template_s  = list((ref_template  * s).astype(int))
threshold_s     = int(threshold * s**2)

In [35]:
seed = np.random.randint(2**8, size=16, dtype=np.uint8)

In [None]:
%%cython -+ -c=/O2 -a
cimport numpy as np
import numpy as np
cimport cython

from libc.string cimport memset

np.import_array()
