# NuCypher

Notes:
- KEM: key encapsulation mechanism
- DEM: data encapsulation mechanism
- AEAD: authenticated encryption with associated data

## Setup

In [31]:
# utils 

import hashlib
import struct # for packing and unpacking
import random # for random numbers, not cryptographically secure
from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 # for AEAD
import os # for os.urandom for nonce

q = 7919 # prime group order; first 1000 primes: https://primes.utm.edu/lists/small/1000.txt
g = 1 # generator in Gq
U = 1 # second generator in Gq (used for the kFrag commitment U1 = rk * U)
# NOTE: the "group" here is the additive group Z_q with g = 1, so a public key
# sk * g equals its secret key. These are toy parameters — not secure.

def inv_mod(x, mod=q):
    """Return the multiplicative inverse of ``x`` modulo ``mod``.

    Uses Fermat's little theorem (x^(mod-2) == x^-1 mod p), so ``mod`` must be
    prime. ``mod == 0`` is treated as "no modulus" and returns the real
    reciprocal ``1 / x``.

    Raises:
        ZeroDivisionError: if ``x`` is 0 (or a multiple of ``mod``), which has
            no inverse. Returning 0 silently would corrupt Lagrange coefficients
            and key splitting downstream.
    """
    if mod == 0:
        return 1 / x
    if x % mod == 0:
        raise ZeroDivisionError(f"{x} has no inverse modulo {mod}")
    # reduce by the requested modulus, not the global q
    return pow(x, mod - 2, mod)

def add_mod(x, y, mod=q):
    """Return ``x + y`` reduced modulo ``mod``."""
    total = x + y
    return total % mod

def sub_mod(x, y, mod=q):
    """Return ``x - y`` reduced modulo ``mod``."""
    difference = x - y
    return difference % mod

def opp_mod(x, mod=q):
    """Return the additive inverse ``-x`` reduced modulo ``mod``."""
    negated = -x
    return negated % mod

def mul_mod(x, y, mod=q):
    """Return ``x * y`` reduced modulo ``mod``."""
    product = x * y
    return product % mod

def div_mod(x, y, mod=q):
    """Return ``x / y`` in Z_mod, i.e. ``x * inv_mod(y, mod)`` reduced modulo ``mod``."""
    y_inverse = inv_mod(y, mod)
    return (x * y_inverse) % mod

def evaluate_polynomial(coeff, val, mod=q):
    """Evaluate ``sum(coeff[i] * val**i)`` modulo ``mod``.

    ``coeff[0]`` is the constant term. An empty coefficient list evaluates
    to 0. Each term is reduced before summing, and the sum is reduced once
    more at the end.
    """
    reduced_terms = (
        (c * pow(val, power, mod)) % mod for power, c in enumerate(coeff)
    )
    return sum(reduced_terms) % mod

def list_to_bytes(l):
    """Serialize a list of ints as consecutive 4-byte little-endian signed ints.

    The byte order is pinned explicitly (the bare ``'i'`` format uses the
    platform's native order) because the result is used as AEAD associated
    data and must be byte-identical on whichever machine decrypts.

    Raises:
        struct.error: if a value does not fit in a signed 32-bit integer.
    """
    return b''.join(struct.pack('<i', value) for value in l)

def bytes_to_list(b):
    """Decode the output of ``list_to_bytes`` back into a list of ints.

    Raises:
        struct.error: if ``len(b)`` is not a multiple of 4.
    """
    return [value for (value,) in struct.iter_unpack('<i', b)]

def lambda_compute(s_vec, idx, mod=q):
    """Return the Lagrange coefficient of point ``idx``, evaluated at x = 0.

    ``s_vec`` holds the x-coordinates of the shares being combined (the ``sx``
    values), not the shares themselves. The result is

        lambda_idx = prod_{j != idx} s_j / (s_j - s_idx)   (mod ``mod``)

    so that ``sum(lambda_i * f(s_i)) == f(0)`` for any polynomial f of degree
    below ``len(s_vec)``. ``mod`` must be prime because the inverse is
    computed by ``inv_mod``.

    Raises:
        ValueError: if ``s_vec`` contains a repeated x-coordinate; the
            denominator would be 0 and interpolation is impossible.
    """
    s_idx = s_vec[idx]
    lam = 1
    for j, s_j in enumerate(s_vec):
        if j == idx:
            continue
        denominator = (s_j - s_idx) % mod
        if denominator == 0:
            raise ValueError(
                f"duplicate share x-coordinate {s_j} at positions {j} and {idx}"
            )
        # reduce at every step so the running product stays below mod**2
        lam = (lam * s_j * inv_mod(denominator, mod)) % mod
    return lam % mod


# Sanity checks for the helpers above; the printed results appear below the cell.

# test inv_mod: x * inv_mod(x) should reduce to 1 modulo q
x = random.randint(1,q-1)
inv_num = inv_mod(x,q)

print(f"x = {x}, inv_num = {inv_num}, x*inv_num = {(x*inv_num) % q}")

# test evaluate_polynomial: p(0) is the constant term a[0], p(1) is sum(a) mod q
a = [5,6,45,45,3,2,3]
print(f"polynomial = a ={a}")
print(f"evaluate_polynomial(a,0) == a[0] = {evaluate_polynomial(a,0) == a[0]}")
print(f"evaluate_polynomial(a,1) == sum(a)%q = {evaluate_polynomial(a,1) == sum(a)%q}")

x = 6
# NOTE: this rebinds a module-level name `mod`; the helpers' `mod=q` defaults
# were bound when they were defined, so they are unaffected.
mod = 11
# test opp_mod with a small modulus: x + opp_mod(x) should be 0 mod 11
print(f"opp_mod({x},{mod}) = {opp_mod(x,mod)}")
print(f"add_mod(opp_mod({x},{mod}),{x},{mod}) == 0? {add_mod(opp_mod(x,mod),x,mod) == 0}")

# test list_to_bytes and bytes_to_list: the round trip should return test_list
test_list = [1,2,3,4,5,6,7,8,9,10]
test_bytes = list_to_bytes(test_list)
print(f"test_list = {test_list}")
print(f"test_bytes = {test_bytes}")
print(f"bytes_to_list(test_bytes) = {bytes_to_list(test_bytes)}")


x = 5931, inv_num = 6959, x*inv_num = 1
polynomial = a =[5, 6, 45, 45, 3, 2, 3]
evaluate_polynomial(a,0) == a[0] = True
evaluate_polynomial(a,1) == sum(a)%q = True
opp_mod(6,11) = 5
add_mod(opp_mod(6,11),6,11) == 0? True
test_list = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
test_bytes = b'\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00\x05\x00\x00\x00\x06\x00\x00\x00\x07\x00\x00\x00\x08\x00\x00\x00\t\x00\x00\x00\n\x00\x00\x00'
bytes_to_list(test_bytes) = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]


In [32]:
def hash_to_q(bytes_data, q=q):
    """Hash ``bytes_data`` with SHA-256 and reduce the digest into Z_q.

    The 32-byte digest is read as a big-endian unsigned integer and taken
    modulo ``q``.
    """
    digest = hashlib.sha256(bytes_data).digest()
    as_int = int.from_bytes(digest, byteorder='big')
    return as_int % q

def H2(num1, num2, q=q):
    """Hash two integers into Z_q.

    Each input is encoded as a 32-byte big-endian unsigned integer (negative
    or >256-bit inputs raise OverflowError). The concatenation is then
    reduced with ``hash_to_q``.
    """
    encoded = b''.join(n.to_bytes(32, byteorder='big') for n in (num1, num2))
    return hash_to_q(encoded, q)

def H3(num1, num2, num3, q=q):
    """Hash three integers into Z_q.

    The inputs are encoded in order as 32-byte big-endian unsigned integers
    (negative or >256-bit inputs raise OverflowError), concatenated, and
    reduced with ``hash_to_q``.
    """
    payload = bytearray()
    for number in (num1, num2, num3):
        payload += number.to_bytes(32, byteorder='big')
    return hash_to_q(bytes(payload), q)

def H4(num1, num2, num3, num4, q=q):
    """Hash four integers into Z_q.

    Only ``num4`` is reduced modulo ``q`` before encoding; the first three
    are encoded as given. Every value becomes 32 big-endian bytes (negative
    or >256-bit values raise OverflowError). The concatenation is reduced
    with ``hash_to_q``.
    """
    fields = (num1, num2, num3, num4 % q)
    blob = b''
    for field in fields:
        blob += field.to_bytes(32, byteorder='big')
    return hash_to_q(blob, q)

def H6(num1, num2, num3, num4, num5, num6, q=q):
    """Hash six integers into Z_q.

    The last three arguments are reduced modulo ``q`` first; the first three
    are encoded as given. Each value becomes 32 big-endian bytes (negative or
    >256-bit values raise OverflowError). The concatenation is reduced with
    ``hash_to_q``.
    """
    reduced_tail = [n % q for n in (num4, num5, num6)]
    chunks = [
        n.to_bytes(32, byteorder='big')
        for n in (num1, num2, num3, *reduced_tail)
    ]
    return hash_to_q(b''.join(chunks), q)

def KDF(num, q=q):
    """Derive a 32-byte symmetric key from the integer ``num``.

    Returns ``SHA-256(num as 32 big-endian bytes || b'NuCypher')`` as raw
    bytes, suitable as a ChaCha20Poly1305 key. The ``q`` parameter is
    accepted but unused.
    """
    hasher = hashlib.sha256()
    hasher.update(num.to_bytes(32, byteorder='big'))
    hasher.update(b'NuCypher')  # fixed salt
    return hasher.digest()

# test hash_to_q on a fixed byte string (deterministic result)
bytes_data = b'hello world'
h = hash_to_q(bytes_data)
print(f"bytes_data = {bytes_data}\nhash_to_q(bytes_data) = {h}\n")

# hash 4 small fixed integers with H2/H3/H4 (results are deterministic)
a = 1
b = 2
c = 3
d = 4

print(f"a = {a}, b = {b}, c = {c}, d = {d}")

h2 = H2(a,b)
h3 = H3(a,b,c)
h4 = H4(a,b,c,d)

print(f"h2(a,b) = {h2}\nh3(a,b,c) = {h3}\nh4(a,b,c,d) = {h4}\n")

# KDF: derive a 32-byte key from the hash value h above
kdf = KDF(h)
print(kdf)

bytes_data = b'hello world'
hash_to_q(bytes_data) = 5576

a = 1, b = 2, c = 3, d = 4
h2(a,b) = 6762
h3(a,b,c) = 3127
h4(a,b,c,d) = 3110

b'\x8e\xe7\xe0C\xad&\xa4+\x99\xd8\xf9S\x172\x9d\xc1\xb3\xacv\xa4>K\xf46\x81\xec\x0fZL\x1av\x15'


## Key Generation

In [33]:
def key_gen(mod=q, g=g):
    """Generate a ``(sk, pk)`` key pair in Z_mod.

    ``sk`` is drawn uniformly from ``[1, mod - 1]`` and ``pk = sk * g mod mod``.
    With the module's ``g = 1`` this gives ``pk == sk`` (toy setting only).

    NOTE: uses ``random``, which is not a cryptographically secure RNG.
    """
    # sample within the requested modulus, not the global q
    sk = random.randint(1, mod - 1)
    pk = mul_mod(sk, g, mod)
    return sk, pk

def re_key_gen(sk, pk, N, t, mod=q, g=g):
    """Split the delegation key from ``sk``'s owner to ``pk``'s owner into N kFrags.

    A random polynomial f of degree t-1 is built with ``f(0) = sk / d``, where
    ``d`` is derived from a fresh ephemeral key X and the delegatee's ``pk``.
    Each kFrag carries ``rk = f(sx)`` at ``sx = H2(id, D)``, so any t kFrags
    can later be combined by Lagrange interpolation (lambda_compute /
    decapsulateFrags).

    Args:
        sk: delegator's (Alice's) secret key.
        pk: delegatee's (Bob's) public key.
        N: number of kFrags to produce.
        t: threshold, the number of kFrags needed to recover the key (t <= N).
        mod: group modulus.
        g: group generator.

    Returns:
        A list of N kFrags ``[id, rk, X, U1, z1, z2]``.

    NOTE: prints intermediate values (d, D, sx) for the notebook. These are
    derived from secrets and should not be printed outside a demo.
    """
    assert t <= N
    assert sk < mod
    assert pk < mod
    # sanity bound: at most `mod` distinct sx values exist
    assert N < mod

    # Ephemeral key pair (x, X). f(0) = sk * d^-1 needs d invertible, so
    # resample in the rare case d hashes to 0.
    d = 0
    while d == 0:
        x = random.randint(1, mod - 1)
        X = mul_mod(x, g, mod)
        d = H3(X, pk, mul_mod(x, pk, mod))
    print("d: ",d)

    # f(z) = f0 + c1*z + ... + c_{t-1}*z^(t-1) with random higher coefficients
    f0 = mul_mod(sk, inv_mod(d, mod), mod)
    poly_coeffs = [f0] + [random.randint(1, mod - 1) for _ in range(t - 1)]

    # D binds the kFrags to this (delegator, delegatee) pair
    pkA = mul_mod(sk, g, mod)
    D = H3(pkA, pk, mul_mod(pk, sk, mod))
    print("D: ",D)

    KF = []
    used_sx = set()
    for i in range(N):
        y = random.randint(1, mod - 1)
        # sx values must be distinct: a repeated x-coordinate makes the
        # Lagrange denominator (s_j - s_i) zero when those kFrags are combined.
        while True:
            frag_id = random.randint(1, mod - 1)
            sx = H2(frag_id, D)
            if sx not in used_sx:
                break
        used_sx.add(sx)
        print("i", i, "sx: ",sx)

        Y = mul_mod(g, y, mod)
        rk = evaluate_polynomial(poly_coeffs, sx, mod)
        U1 = mul_mod(U, rk, mod)
        # Schnorr-style proof values over (Y, z1, z2); re_encapsulate does not verify them
        z1 = H6(Y, frag_id, pkA, pk, U1, X)
        z2 = (y - z1 * sk) % mod

        KF.append([frag_id, rk, X, U1, z1, z2])

    return KF


# Alice is the delegator, Bob the delegatee (with g = 1, each pk equals its sk)
alice_sk, alice_pk = key_gen()
bob_sk, bob_pk = key_gen()

print(f"Alice's sk: {alice_sk}, pk: {alice_pk}")
print(f"Bob's sk: {bob_sk}, pk: {bob_pk}")

# re-encryption key generation: N kFrags, any t of which can be combined
N = 10
t = 6
KF = re_key_gen(alice_sk, bob_pk, N, t)

print(KF)


Alice's sk: 2341, pk: 2341
Bob's sk: 7454, pk: 7454
d:  7050
D:  4697
i 0 sx:  2859
i 1 sx:  339
i 2 sx:  6498
i 3 sx:  1914
i 4 sx:  3323
i 5 sx:  6732
i 6 sx:  7898
i 7 sx:  1729
i 8 sx:  6588
i 9 sx:  6499
[[6509, 7221, 4829, 7221, 4740, 3185], [7026, 5969, 4829, 5969, 3593, 2457], [241, 2916, 4829, 2916, 5090, 1547], [1089, 3926, 4829, 3926, 7335, 2236], [5452, 2262, 4829, 2262, 4221, 7843], [1399, 2205, 4829, 2205, 398, 3997], [6850, 5104, 4829, 5104, 7163, 4899], [726, 6947, 4829, 6947, 5158, 4743], [7353, 2265, 4829, 2265, 423, 4891], [7562, 3971, 4829, 3971, 5035, 6921]]


## Encapsulation and Decapsulation

In [34]:
def encapsulate(alice_pk, mod=q, g=g):
    """Create a fresh symmetric key for ``alice_pk`` and the capsule that carries it.

    Picks random r, u and computes:
        E = r*g, V = u*g, s = u + r*H2(E, V)      (mod ``mod``)
        K = KDF((r + u) * alice_pk)

    Returns:
        ``(K, [E, V, s])``: the 32-byte key and the capsule.

    Prints the KDF input for the notebook (secret-derived; demo only).
    """
    r = random.randint(1, mod - 1)
    u = random.randint(1, mod - 1)

    E, V = mul_mod(g, r, mod), mul_mod(g, u, mod)
    challenge = H2(E, V)
    s = add_mod(u, mul_mod(r, challenge, mod), mod)

    shared_point = mul_mod(alice_pk, add_mod(r, u, mod), mod)
    print(f"iinKDF: {shared_point}")
    return KDF(shared_point), [E, V, s]

def check_capsule(capsule, mod=q, g=g):
    """Return True iff the capsule ``[E, V, s]`` satisfies ``s*g == V + H2(E, V)*E`` (mod ``mod``).

    On failure, both sides of the equation are printed before returning False.
    """
    E, V, s = capsule
    challenge = H2(E, V)
    lhs = mul_mod(g, s, mod)
    rhs = add_mod(V, mul_mod(E, challenge, mod), mod)

    if lhs != rhs:
        print(f"mul_mod(g,s,mod) {lhs}")
        print(f"add_mod(V, mul_mod(E,h2,mod),mod: {rhs}")
        return False
    return True

def decapsulate(alice_sk, capsule, mod=q):
    """Recover the symmetric key from ``capsule`` using the owner's secret key.

    Computes ``K = KDF((E + V) * alice_sk)``, which matches encapsulate's
    ``KDF((r + u) * pk)``.

    Raises:
        ValueError: if the capsule fails ``check_capsule``.
    """
    # validate under the same modulus used for the key computation below
    if not check_capsule(capsule, mod):
        raise ValueError("Invalid capsule")
    E, V, _s = capsule
    shared_point = mul_mod(E + V, alice_sk, mod)
    print(f"inKDF: {shared_point}")
    return KDF(shared_point)


# round trip: encapsulate to Alice's pk, then recover the same key with her sk
K, capsule = encapsulate(alice_pk)
K2 = decapsulate(alice_sk, capsule)
print(f"is K == K2? {K == K2}")

iinKDF: 1539
inKDF: 1539
is K == K2? True


## Re-Encapsulation and Fragments Decapsulation

In [35]:
def re_encapsulate(kFrag, capsule, mod=q):
    """Re-encapsulate ``capsule`` with one kFrag, producing a cFrag.

    Returns ``[E1, V1, id, X]`` with ``E1 = rk*E`` and ``V1 = rk*V`` (mod ``mod``).
    The kFrag's proof values (U1, z1, z2) are not verified here.

    Raises:
        ValueError: if the capsule fails ``check_capsule``.
    """
    # validate under the same modulus used for the re-encryption below
    if not check_capsule(capsule, mod):
        raise ValueError("Invalid capsule")

    E, V, _s = capsule
    frag_id, rk, X, _U1, _z1, _z2 = kFrag

    E1 = mul_mod(E, rk, mod)
    V1 = mul_mod(V, rk, mod)
    return [E1, V1, frag_id, X]

def decapsulateFrags(skB, pkA, cFrag_vec, mod=q, g=g):
    """Combine cFrags into the symmetric key, as the delegatee.

    Args:
        skB: delegatee's (Bob's) secret key.
        pkA: delegator's (Alice's) public key.
        cFrag_vec: cFrags ``[E_i, V_i, id_i, X]`` from re_encapsulate. All
            must come from the same re_key_gen call (same X), and at least the
            threshold t of them are needed to recover the right key.
        mod: group modulus.
        g: generator, used to recompute Bob's public key from ``skB``.

    Returns:
        ``KDF((E' + V') * d)``, where E' and V' are the Lagrange-weighted sums
        of the cFrags' E_i and V_i.

    Raises:
        ValueError: if ``cFrag_vec`` is empty or mixes different X values.
    """
    if not cFrag_vec:
        raise ValueError("decapsulateFrags needs at least one cFrag")
    X = cFrag_vec[0][3]
    if any(cFrag[3] != X for cFrag in cFrag_vec):
        raise ValueError("cFrags come from different re-encryption key sets (X differs)")

    # Derive Bob's public key from skB instead of checking a notebook global,
    # so this works for any recipient.
    pkB = mul_mod(g, skB, mod)
    # same D that re_key_gen used for sx = H2(id, D)
    D = H3(pkA, pkB, mul_mod(pkA, skB, mod))
    print(f"D: {D}")
    S = [H2(cFrag[2], D) for cFrag in cFrag_vec]  # one x-coordinate per cFrag
    print(f"S: {S}")

    E_prime = 0
    V_prime = 0
    for i, cFrag in enumerate(cFrag_vec):
        lam = lambda_compute(S, i, mod)
        E_prime = add_mod(E_prime, mul_mod(cFrag[0], lam, mod), mod)
        V_prime = add_mod(V_prime, mul_mod(cFrag[1], lam, mod), mod)

    # X * skB equals x * pkB, so this reproduces re_key_gen's d
    d = H3(X, pkB, mul_mod(X, skB, mod))
    inKDF = mul_mod(add_mod(E_prime, V_prime, mod), d, mod)
    print(f"d: {d}")
    print(f"E_prime: {E_prime}")
    print(f"V_prime: {V_prime}")
    print(f"inKDF: {inKDF}")
    return KDF(inKDF)

# pick exactly t of the N kFrags at random; any t of them should recover K
tprime =t
L = random.sample(range(N), tprime)
# alternative fixed selections used while debugging:
# L = [0,1,2]
# L=sorted([7, 8, 6, 4, 3, 5, 2])
print(f"selected indices: {L}")
cFrag_vec = [re_encapsulate(KF[i], capsule) for i in L]

for i in range(tprime):
    print(f"cFrag {L[i]}: {cFrag_vec[i]}")

print()

# Bob combines the cFrags; the result should equal Alice's key K from above
K3 = decapsulateFrags(bob_sk, alice_pk, cFrag_vec)
print(f"is K == K3? **{K == K3}**")



selected indices: [4, 8, 5, 0, 7, 6]
cFrag 4: [4210, 6424, 5452, 4829]
cFrag 8: [1905, 4374, 7353, 4829]
cFrag 5: [491, 5779, 1399, 4829]
cFrag 0: [3084, 7106, 6509, 4829]
cFrag 7: [2434, 6923, 726, 4829]
cFrag 6: [3611, 1906, 6850, 4829]

D: 4697
S: [3323, 6588, 6732, 2859, 1729, 7898]
d: 7050
E_prime: 4105
V_prime: 7430
inKDF: 1539
is K == K3? **True**


In [36]:
## testing the ChaCha20Poly1305 AEAD
# the AAD is authenticated but sent in the clear; the data is encrypted
test_aead_data = b"a secret message"
test_aead_aad = b"authenticated but unencrypted data"
test_aead_key = ChaCha20Poly1305.generate_key()

print(f"test_aead_key: {test_aead_key}")
test_aead_chacha = ChaCha20Poly1305(test_aead_key)
# ChaCha20Poly1305 takes a 12-byte nonce; a fresh random one per encryption
test_aead_nonce = os.urandom(12)
test_aead_ct = test_aead_chacha.encrypt(test_aead_nonce, test_aead_data, test_aead_aad)
# last expression of the cell: the notebook displays the decrypted bytes
test_aead_chacha.decrypt(test_aead_nonce, test_aead_ct, test_aead_aad)
# b'a secret message'
# b'a secret message'

test_aead_key: b'\x0f\xe6\x16\xab\xf6\tR\xda\xa7{JC"E\x0f\x00w\x07\xf6\x124G\x1fh\xe2\xa8\xb2\x0f\x16\x87\xfdo'


b'a secret message'

## KEM/DEM construction

In [38]:
def encrypt(pkA, M):
    """Hybrid-encrypt the bytes ``M`` for ``pkA`` (KEM + DEM).

    The KEM step (encapsulate) yields a key and a capsule. The DEM step
    encrypts ``M`` with ChaCha20-Poly1305 under that key, binding the
    serialized capsule as associated data.

    Returns:
        ``(capsule, enc_data, nonce)``.
    """
    K, capsule = encapsulate(pkA)
    associated_data = list_to_bytes(capsule)
    nonce = os.urandom(12)
    cipher = ChaCha20Poly1305(K)
    enc_data = cipher.encrypt(nonce, M, associated_data)
    return capsule, enc_data, nonce

def decrypt(skA, cyphertext):
    """Decrypt a ``(capsule, enc_data, nonce)`` triple produced by ``encrypt``.

    The key is recovered with ``decapsulate`` (ValueError on an invalid
    capsule). The AEAD decrypt then authenticates ``enc_data`` together with
    the serialized capsule; cryptography raises InvalidTag if that check fails.
    """
    capsule, enc_data, nonce = cyphertext
    associated_data = list_to_bytes(capsule)
    key = decapsulate(skA, capsule)
    cipher = ChaCha20Poly1305(key)
    return cipher.decrypt(nonce, enc_data, associated_data)




# end-to-end KEM/DEM round trip for Alice (no re-encryption yet)
test_message = b"this is a test message"
capsule, enc_data, nonce = encrypt(alice_pk, test_message)
print(f"enc_data: {enc_data}")
print(f"capsule: {capsule}")
print(f"nonce: {nonce}")

decrypted_message = decrypt(alice_sk, (capsule, enc_data,nonce ))
print(f"decrypted_message: {decrypted_message}")
print(f"test_message == decrypted_message? {test_message == decrypted_message}")

iinKDF: 5675
enc_data: b'f(\x1d\x8a\x8c&\x9f\xa5\x9a\xca?\xea\xf6 \x16h#\x7f\xfc\x81Lw\xbc\xea\xe4\xb8\xd1\x08\x8a?v\xbc\xfc\x1b\x19I\xfa\xb6'
capsule: [5299, 471, 3599]
nonce: b'\x1e\xae,\xc1\xb8m{\xd9f\x9eR\xc7'
inKDF: 5675
decrypted_message: b'this is a test message'
test_message == decrypted_message? True


In [55]:
def re_encrypt(kFrag, cyphertext):
    """Re-encrypt a ``(capsule, enc_data, nonce)`` ciphertext with one kFrag.

    Only the capsule is transformed (into a cFrag via ``re_encapsulate``).
    ``enc_data`` and ``nonce`` pass through unchanged.
    """
    capsule, enc_data, nonce = cyphertext
    return re_encapsulate(kFrag, capsule), enc_data, nonce


def decrypt_frags(skB, pkA, cyphertexts, capsule, mod=q, g=g):
    """Decrypt as the delegatee from a list of re-encrypted ciphertexts.

    Args:
        skB: delegatee's (Bob's) secret key.
        pkA: delegator's (Alice's) public key.
        cyphertexts: ``(cFrag, enc_data, nonce)`` triples from ``re_encrypt``,
            all re-encrypting the same ciphertext (enc_data/nonce are read
            from the first).
        capsule: the original capsule. It is needed because its serialization
            was the AEAD associated data at encryption time.
        mod, g: group modulus and generator, forwarded to decapsulateFrags.

    Raises:
        ValueError: if ``cyphertexts`` is empty.
    """
    if not cyphertexts:
        raise ValueError("decrypt_frags needs at least one re-encrypted ciphertext")
    cFrag_vec = [ct[0] for ct in cyphertexts]
    print(f"cFrag_vec: {cFrag_vec}")
    # forward this function's own parameters (previously hard-wired to mod=q)
    K = decapsulateFrags(skB, pkA, cFrag_vec, mod=mod, g=g)
    _, enc_data, nonce = cyphertexts[0]
    bytes_capsule = list_to_bytes(capsule)
    return ChaCha20Poly1305(K).decrypt(nonce, enc_data, bytes_capsule)

# re-encrypt with the first tprime kFrags (any tprime of the N would do)
cyphertexts = [re_encrypt(KF[i], (capsule, enc_data, nonce)) for i in range(tprime)]

# Bob decrypts from the re-encrypted pieces; should recover test_message
decrypted_message2 = decrypt_frags(bob_sk, alice_pk, cyphertexts, capsule)
print(f"decrypted_message2: {decrypted_message2}")
print(f"test_message == decrypted_message2? {test_message == decrypted_message2}")


cFrag_vec: [[7390, 3840, 6509, 4829], [1245, 154, 7026, 4829], [1915, 3449, 241, 4829], [661, 4019, 1089, 4829], [4891, 4256, 5452, 4829], [3770, 1166, 1399, 4829]]
D: 4697
S: [2859, 339, 6498, 1914, 3323, 6732]
d: 7050
E_prime: 5117
V_prime: 6313
inKDF: 5675
decrypted_message2: b'this is a test message'
test_message == decrypted_message2? True
