# **Estruturas criptograficas: TP3 problema 2**

In [579]:
def ntt_transform(f, s):
    
    # Calcula a transformada NTT de f
    trans = [f(s_i) for s_i in s]
    
    return trans

def ntt_inverse(trans, mu):

    # Calcula a transformada inversa
    f_inv = [mu[i] * trans[i] for i in range(len(trans))]

    return sum(f_inv)

def raizes(omega, N):
    
    s = [omega.nth_root(2) * omega^i for i in range(N)]
    
    return s

def base_mu(x, s, N):
    
    mods = [x - s[i] for i in range(N)]

    mu = CRT_basis(mods)
    
    return mu
    

# Escolha de valores
q = 3329
N = 256
omega = primitive_root(q)^2

# Polinômio de exemplo
f = PolynomialRing(GF(q), 'w')
w = f.gen()
# f = 1 + x -2*x^2 - x^3


# Cálculo das raízes
ra = raizes(omega, N)

mu = base_mu(w, ra, N)


In [580]:
import os
from hashlib import sha3_256, sha3_512, shake_128, shake_256
from cryptography.hazmat.primitives import hashes

class KEM():
    
    def __init__(self, security_bits=512):
        
        if security_bits == 512:
            self.n = 256
            self.q = 3329
            self.k = 2
            self.eta1 = 3
            self.eta2 = 2
            self.du = 10
            self.dv = 4
        elif security_bits == 768:
            self.n = 256
            self.q = 3329
            self.k = 3
            self.eta1 = 2
            self.eta2 = 2
            self.du = 10
            self.dv = 4
        elif security_bits == 1024:
            self.n = 256
            self.q = 3329
            self.k = 4
            self.eta1 = 2
            self.eta2 = 2
            self.du = 11
            self.dv = 5
        
        # Zq.<w> = GF(self.q)[]
        # fi = w^self.n + 1
        # Rq.<w> = QuotientRing(Zq, Zq.ideal(fi))
        f = PolynomialRing(GF(q), 'w')
        self.Rq = f
    
    def G(self, c):
        c_hash = sha3_512(c).digest()
        a = c_hash[:32]
        b = c_hash[32:]
        return a, b
    
    def H(self, m):
        return sha3_256(m).digest()
    
    def J(self, s):
        return shake_256(s).digest(32)
    
    # variable output length B32*B*B
    def XOF(self, rho, i, j):
        temp = shake_128(rho + int.to_bytes(i) + int.to_bytes(j)).digest(self.q)
        # print(len(temp))
        return temp
    
    def NTT(self, f):
        f_ = list(f)
        
        k = 1
        len = 128
        while len >= 2:
            start = 0
            while start < 256:
                zeta = mod(17^(self.BitReverse(k)), self.q) 
                k += 1
                for j in range(start, start + len):
                    t = mod(zeta * f_[j + len], self.q)
                    f_[j + len] = mod(f_[j] - t, self.q)
                    f_[j] = mod(f_[j] + t, self.q)
                    
                
                
                start = start + 2 * len
            len = len // 2
        
        return f_
    
    def NTTInverse(self, f_):
        f = list(f_)
        
        k = 127
        len = 2
        while len <= 128:
            start = 0
            while start < 256:
                zeta = mod(17^(self.BitReverse(k)), self.q)
                k -= 1
                for j in range(start, start + len):
                    t = zeta * f[j + len]
                    f[j] = mod(t + f[j + len], self.q)
                    f[j + len] = mod(f[j + len] - t, self.q)
                
                start = start + 2 * len
            len = len * 2
        
        f = [mod(f[i] * 3303, self.q) for i in range(256)]
        return f
    
    def SampleNTT(self, B):
        i = 0
        j = 0
        a = [[0] for _ in range(256)]
        while j < 256:
            d1 = B[i] + 256 * (B[i + 1] % 16)
            d2 = floor(B[i + 1] / 16) + 16 * B[i + 2]
            if d1 < self.q:
                a[j] = d1
                j += 1
            if d2 < self.q and j < 256:
                a[j] = d2
                j += 1
            i += 3
            
        return a
        # return self.Rq(a)
    
    def PRF(self, s, b, eta):
        return shake_256(s + bytes(b)).digest(64 * eta)
    
    def BytesToBits(self, B):
        b = [0] * len(B) * 8 
        B = self.BytesToByteArray(B)
        for i in range(len(B)):
            for j in range(0,8):
                b[8*i+j] = mod(B[i], 2)
                B[i] = B[i] // 2
                
        return b
            
    def BitsToBytes(self, b):
        l = len(b) // 8
        B = [0] * l
        for i in range(0,8*l):
            B[i // 8] += ZZ(b[i]) * 2^(mod(i,8))
        return B
    
    def ByteArrayToBytes(self, B):
        return bytes(B)
    
    def BytesToByteArray(self, Bytes):
        return list(Bytes)
    
    def MatrixMultiplication(self, A, u, Transpose=False):
        aux = A.copy()
        res = []
        
        if len(A[0]) == self.k:
            for i in range(self.k):
                for j in range(self.k):
                    if Transpose:
                        aux[i][j] = self.MultiplyNTTs(A[j][i], u[j])
                    else:
                        aux[i][j] = self.MultiplyNTTs(A[i][j], u[j])
                        # print(aux[i][j])
            
            for i in range(self.k):
                res.append(self.ArrayAddition(aux[i][0], aux[i][1]))
        
        elif len(A[0]) == self.n:
            for i in range(self.k):
                aux[i] = self.MultiplyNTTs(A[i], u[i])
                
            res = self.ArrayAddition(aux[0], aux[1])

                 
        return res
    
    def MatrixAddition(self, A, B):
        res = []
        for i in range(self.k):
            res.append(self.ArrayAddition(A[i], B[i]))
            
        return res
        
    
    def ArrayAddition(self, A, B):
        res = [0] * self.n
        for i in range(self.n):
            res[i] = A[i] + B[i]
        
        return res
    
    def ArraySubtraction(self, A, B):
        res = [0] * self.n
        for i in range(self.n):
            res[i] = A[i] - B[i]
        
        return res
    
    def MultiplyNTTs(self, f, g):#17
        h = [0] * self.n
        for i in range(128):
            h[2*i], h[2*i + 1] = self.BaseCaseMultiply(f[2*i], f[2*i + 1], g[2*i], g[2*i + 1], 17^(2* self.BitReverse(i) + 1))
        return h
    
    def BaseCaseMultiply(self, a0, a1, b0, b1, y):
        c0 = mod((a0 * b0) + (a1 * b1 * y), self.q)
        c1 = mod((a0 * b1) + (a1 * b0), self.q)
        return c0, c1 
    
    def BitReverse(self, i):
        return int('{:07b}'.format(i)[::-1], 2)
        
                    
    def SamplePolyCBD(self, B, eta):
        b = self.BytesToBits(B)
        
        f = [0] * 256
        
        for i in range(256):
            x = 0
            y = 0
            for j in range(eta):
                x += b[2 * i * eta + j]
            for j in range(eta):
                y += b[2 * i * eta + eta + j]   
                
            f[i] = mod((x - y), self.q)
            
        return f
        # return self.Rq(f)
     
    def ByteEncode(self, F, d):
        
        if d < 12:
            m = 2^d
        elif d == 12:
            m = self.q
            
        b = [0] * (256 * d)
        for i in range(256):
            a = mod(F[i], 2^d)
            # print(a)
            for j in range(d):
                b[i*d + j] = a % 2
                # print(b[i*d + j])
                # print(ZZ(a) - ZZ(b[i*d + j]))
                a = (ZZ(a) - ZZ(b[i*d + j])) / 2
        
        B = self.BitsToBytes(b)
        return B
        
    def ByteDecode(self, B, d):
        if d < 12:
            m = 2^d
        elif d == 12:
            m = self.q

        b = self.BytesToBits(B)
        F = [0] * 256
        for i in range(256):
            for j in range(0, d):
                F[i] += mod(ZZ(b[i*d + j]) * ZZ(2^j), m)
        
        return F
        
    def Compress(self, x, d):
        for i in range(len(x)):
            x[i] = mod(round((2^d) / self.q * ZZ(x[i])), 2^d)

        return x


    def Decompress(self, y, d):
        for i in range(len(y)):
            y[i] = mod(round((self.q / 2^d) * ZZ(y[i])), self.q)
        
        return y
                 
     
    def PKEKeyGen(self):
        d = os.urandom(32)
        rho, sigma = self.G(d)
        N = 0
        
        A = [[0] * self.k for _ in range(self.k)]
        for i in range(self.k):
            for j in range(self.k):
                A[i][j] = self.SampleNTT(self.XOF(rho, i, j))
                
        
        s = [[0] for _ in range(self.k)]
        for i in range(self.k):
            s[i] = self.SamplePolyCBD(self.PRF(sigma, N, self.eta1), self.eta1)
            N += 1
            
        e = [[0] for _ in range(self.k)]
        for i in range(self.k):
            e[i] = self.SamplePolyCBD(self.PRF(sigma, N, self.eta1), self.eta1)
            N += 1

        
        _s, _e = [], []
        
        for i in range(self.k):
            _s.append(self.NTT(s[i]))
            _e.append(self.NTT(e[i]))
            

        t = self.MatrixAddition(self.MatrixMultiplication(A, _s), _e)

        ek_pke = b''
        for i in range(self.k):
            ek_pke += self.ByteArrayToBytes(self.ByteEncode(t[i], 12))
        
        ek_pke += rho
        
        dk_pke = b''
        for i in range(self.k):
            dk_pke += self.ByteArrayToBytes(self.ByteEncode(_s[i], 12))
        
        return ek_pke, dk_pke
    
    def PKEEncrypt(self, ek, m, r):
        N = 0
        
        t = []
        for i in range(self.k):
            t.append(self.ByteDecode(self.BytesToByteArray(ek[i * 384:(i+1) * 384]), 12))
        
        rho = ek[self.k * 384:]
        
        A = [[0] * self.k for _ in range(self.k)]
        for i in range(self.k):
            for j in range(self.k):
                A[i][j] = self.SampleNTT(self.XOF(rho, i, j))
        
        r_ = [0] * self.k 
        for i in range(self.k):
            r_[i] = self.SamplePolyCBD(self.PRF(r, N, self.eta1), self.eta1)
            N += 1
            
        e1 = [0] * self.k
        for i in range(self.k):
            e1[i] = self.SamplePolyCBD(self.PRF(r, N, self.eta2), self.eta2)
            N += 1
         
        e2 = self.SamplePolyCBD(self.PRF(r, N, self.eta2), self.eta2)
        
        # print('AQUI')
        # print(r_[0])
        # x = ntt_transform(r_[0], ra)
        # print(x)
        # print(ntt_inverse(x, mu))
        
        _r = []
        for i in range(self.k):
            _r.append(self.NTT(r_[i]))
            
            
        u = self.MatrixMultiplication(A, _r, True)

        for i in range(self.k):
            u[i] = self.NTTInverse(u[i])
        
        
        u = self.MatrixAddition(u, e1)
        print('u', u)

        mu2 = self.Decompress(self.ByteDecode(m, 1), 1)
        
        v = self.ArrayAddition(self.ArrayAddition(self.NTTInverse(self.MatrixMultiplication(t, _r, True)), e2), mu2)
        print('v', v)
        
        c1 = b''
        for i in range(self.k):
            c1 += self.ByteArrayToBytes(self.ByteEncode(self.Compress(u[i], self.du), self.du))
            
        c2 = self.ByteArrayToBytes(self.ByteEncode(self.Compress(v, self.dv), self.dv))

        return c1 + c2
    
    def PKEDecrypt(self, dk_pke, c):
        # ATENÇAO A DIVIDIR
        c1 = self.BytesToByteArray(c[: 32 * self.du * self.k])
        c1 = [c1[:len(c1)/2], c1[len(c1)/2:]]
        c2 = self.BytesToByteArray(c[32 * self.du * self.k: 32 * (self.du * self.k + self.dv)]) 
        
        u = []
        for i in range(self.k):
            u.append(self.Decompress(self.ByteDecode(c1[i], self.du), self.du))

        v = self.Decompress(self.ByteDecode(c2, self.dv), self.dv)        
        
        # print(dk_pke)
        
        s = []
        for i in range(self.k):
            s.append(self.ByteDecode(dk_pke[i * 384: (i+1) * 384], 12))  
            
        for i in range(self.k):
            u[i] = self.NTT(u[i])
            
        w = self.MatrixMultiplication(s, u, True)
        w = self.ArraySubtraction(v, self.NTTInverse(w))
        
        m = self.ByteEncode(self.Compress(w, 1), 1)

        return self.ByteArrayToBytes(m)
        
    def keygen(self):
        z = os.urandom(32)
        
        ek_pke, dk_pke = self.PKEKeyGen()
        
        ek = ek_pke
        dk = dk_pke + ek + self.H(ek) + z
        
        return ek, dk       
    
    def encaps(self, ek):
        m = os.urandom(32)
        # print(m)
        # print(len(m))   
        K, r = self.G(m + self.H(ek))
        
        c = self.PKEEncrypt(ek, m, r)
        
        return K, c
    
    def decaps(self, c, dk):
        
        dk_pke = dk[:384 * self.k]
        ek_pke = dk[384 * self.k: 768 * self.k + 32]
        h = dk[768 * self.k + 32: 768 * self.k + 64]
        z = dk[768 * self.k + 64: 768 * self.k + 96]
        
        m_ = self.PKEDecrypt(dk_pke, c)
        
        # print(m_)
        # print(len(m_))  
        K_, r_ = self.G(m_ + h)
        
        K = self.J(z + c)
        
        c_ = self.PKEEncrypt(ek_pke, m_, r_)
        
        if c_ != c:
            print('Error')
            K_ = K
        
        return K_
        

In [581]:
kem = KEM()
ek, dk = kem.keygen()

K, c = kem.encaps(ek)

K_ = kem.decaps(c, dk)

print(K.hex())
print(K_.hex())

u [[2056, 2475, 2613, 1129, 829, 2019, 2963, 2645, 2264, 3090, 2797, 1693, 2528, 2098, 713, 1466, 3327, 2888, 2223, 393, 1657, 416, 1857, 1, 541, 2960, 3219, 329, 978, 619, 1351, 2040, 3181, 1401, 1612, 605, 2353, 2857, 2349, 2269, 2458, 1647, 1145, 1865, 1969, 3002, 471, 2511, 1916, 1869, 887, 558, 1542, 302, 3068, 1582, 3098, 2516, 947, 2573, 1563, 3243, 461, 1188, 1225, 163, 2241, 3185, 140, 459, 1357, 2922, 2316, 748, 2011, 2665, 457, 452, 1815, 338, 2433, 1449, 2277, 1672, 1901, 2994, 2011, 2518, 678, 2978, 138, 312, 2473, 670, 392, 884, 996, 3119, 965, 2038, 46, 3287, 3290, 779, 1444, 2541, 197, 2180, 1873, 2207, 1665, 1738, 1579, 2266, 3037, 207, 1142, 1425, 1945, 2804, 1728, 1012, 1793, 581, 1568, 649, 2306, 2754, 2278, 1819, 1306, 2083, 1458, 1270, 2034, 2570, 2178, 435, 1024, 723, 65, 2898, 682, 384, 1859, 149, 3158, 1349, 3045, 1529, 1732, 1601, 61, 2162, 2167, 2148, 171, 1688, 2809, 3310, 1289, 2912, 2555, 2590, 1430, 484, 1688, 1790, 2980, 1962, 1050, 1217, 2896, 2783, 297

See https://github.com/sagemath/sage/issues/35473 for details.
  y[i] = mod(round((self.q / Integer(2)**d) * ZZ(y[i])), self.q)


u [[2893, 1870, 1868, 928, 1537, 3079, 855, 963, 1129, 1225, 477, 2610, 2644, 2916, 239, 738, 1930, 2331, 3093, 888, 3249, 2164, 441, 2517, 393, 387, 1501, 2985, 1, 2762, 370, 136, 1536, 319, 486, 828, 1575, 902, 1930, 2898, 606, 2577, 203, 1779, 2269, 2861, 1683, 47, 1967, 1856, 3061, 2419, 23, 820, 1460, 2970, 558, 2641, 2834, 2093, 1581, 823, 815, 379, 1534, 1455, 3076, 558, 2679, 3011, 3167, 2133, 3185, 1896, 128, 534, 2922, 2044, 2582, 1514, 2730, 1243, 1644, 2226, 1630, 191, 1881, 202, 1672, 2398, 1103, 2678, 2519, 1246, 351, 373, 161, 1764, 2816, 1762, 1011, 1669, 211, 3326, 2038, 1476, 2628, 2017, 778, 854, 789, 1832, 165, 2822, 964, 2302, 2133, 781, 1065, 1527, 207, 2837, 2777, 1179, 2804, 2487, 2317, 2228, 3220, 2558, 2688, 1795, 798, 2809, 1511, 1202, 2083, 2548, 2589, 1435, 2571, 1671, 2894, 735, 2018, 2849, 1906, 2647, 1832, 1969, 3181, 839, 1348, 7, 1392, 2214, 1601, 18, 1167, 2945, 2527, 2792, 343, 1588, 3277, 143, 417, 2833, 2591, 299, 288, 105, 1790, 1955, 1369, 1962, 