# **Estruturas criptograficas: TP3 problema 2**

In [456]:
import os
from hashlib import sha3_256, sha3_512, shake_128, shake_256
from cryptography.hazmat.primitives import hashes
from functools import reduce


class KEM():
    
    def __init__(self, security_bits=512):
        
        if security_bits == 512:
            self.n = 256
            self.q = 3329
            self.k = 2
            self.eta1 = 3
            self.eta2 = 2
            self.du = 10
            self.dv = 4
        elif security_bits == 768:
            self.n = 256
            self.q = 3329
            self.k = 3
            self.eta1 = 2
            self.eta2 = 2
            self.du = 10
            self.dv = 4
        elif security_bits == 1024:
            self.n = 256
            self.q = 3329
            self.k = 4
            self.eta1 = 2
            self.eta2 = 2
            self.du = 11
            self.dv = 5
            
    
    def G(self, c):
        c_hash = sha3_512(c).digest()
        a = c_hash[:32]
        b = c_hash[32:]
        return a, b
    
    def H(self, m):
        return sha3_256(m).digest()
    
    def J(self, s):
        return shake_256(s).digest(32)
    
    # variable output length B32*B*B
    def XOF(self, rho, i, j):
        temp = shake_128(rho + int.to_bytes(i) + int.to_bytes(j)).digest(self.q)
        return temp
    
    def NTT(self, f):
        f_ = list(f)
        
        k = 1
        len = 128
        while len >= 2:
            start = 0
            while start < 256:
                zeta = mod(17^(self.BitReverse(k)), self.q) 
                k += 1
                for j in range(start, start + len):
                    t = mod(zeta * f_[j + len], self.q)
                    f_[j + len] = mod(f_[j] - t, self.q)
                    f_[j] = mod(f_[j] + t, self.q)
                    
                
                
                start = start + 2 * len
            len = len // 2
        
        return f_
    
    def NTTInverse(self, f_):
        f = list(f_)
        
        k = 127
        len = 2
        while len <= 128:
            start = 0
            while start < 256:
                zeta = mod(17^(self.BitReverse(k)), self.q)
                k -= 1
                for j in range(start, start + len):
                    t = f[j]
                    f[j] = mod(t + f[j + len], self.q)
                    f[j + len] = mod(zeta * (f[j + len] - t), self.q)
                
                start = start + 2 * len
            len = len * 2
        
        for i in range(256):
            f[i] = mod(f[i] * 3303, self.q)

        return f
    
    def SampleNTT(self, B):
        i = 0
        j = 0
        a = [[0] for _ in range(256)]
        while j < 256:
            d1 = B[i] + 256 * (B[i + 1] % 16) # % ou mod?
            d2 = floor(B[i + 1] / 16) + 16 * B[i + 2]
            if d1 < self.q:
                a[j] = d1
                j += 1
            if d2 < self.q and j < 256:
                a[j] = d2
                j += 1
            i += 3
            
        return a
    
    def PRF(self, s, b, eta):
        return shake_256(s + bytes(b)).digest(64 * eta)
    
    def BytesToBits(self, B):
        b = [0] * len(B) * 8 
        B = self.BytesToByteArray(B)
        for i in range(len(B)):
            for j in range(0,8):
                b[8*i+j] = mod(B[i], 2)
                B[i] = B[i] // 2
                
        return b
            
    def BitsToBytes(self, b):
        l = len(b) // 8
        B = [0] * l
        for i in range(0,8*l):
            B[i // 8] += ZZ(b[i]) * 2^(mod(i,8))
        return B
    
    def ByteArrayToBytes(self, B):
        return bytes(B)
    
    def BytesToByteArray(self, Bytes):
        return list(Bytes)
    
    def MatrixMultiplication(self, A, u):
        aux = A.copy()
        res = [0] * self.n
        
        for i in range(self.k):
            aux[i] = self.MultiplyNTTs(A[i], u[i])
            
        for i in range(self.k):    
            res = self.ArrayAddition(res, aux[i])

        return res
    
    def MatrixAddition(self, A, B):
        res = []
        for i in range(self.k):
            res.append(self.ArrayAddition(A[i], B[i]))
            
        return res
        
    
    def ArrayAddition(self, A, B):
        res = [0] * self.n
        for i in range(self.n):
            res[i] = A[i] + B[i]
        
        return res
    
    def ArraySubtraction(self, A, B):
        res = [0] * self.n
        for i in range(self.n):
            res[i] = A[i] - B[i]
        
        return res
    
    def MultiplyNTTs(self, f, g):
        h = [0] * self.n
        for i in range(128):
            h[2*i], h[2*i + 1] = self.BaseCaseMultiply(f[2*i], f[2*i + 1], g[2*i], g[2*i + 1], 17^(2* self.BitReverse(i) + 1))
        return h
    
    def BaseCaseMultiply(self, a0, a1, b0, b1, y):
        c0 = mod((a0 * b0) + (a1 * b1 * y), self.q)
        c1 = mod((a0 * b1) + (a1 * b0), self.q)
        return c0, c1 
    
    def BitReverse(self, i):
        return int('{:07b}'.format(i)[::-1], 2)
        
                    
    def SamplePolyCBD(self, B, eta):
        b = self.BytesToBits(B)
        
        f = [0] * 256
        
        for i in range(256):
            x = 0
            y = 0
            for j in range(eta):
                x += b[2 * i * eta + j]
            for j in range(eta):
                y += b[2 * i * eta + eta + j]   
                
            f[i] = mod((x - y), self.q)
            
        return f

     
    def ByteEncode(self, F, d):
        
        if d < 12:
            m = 2^d
        elif d == 12:
            m = self.q
            
        b = [0] * (256 * d)
        for i in range(256):
            a = mod(F[i], 2^d)

            for j in range(d):
                b[i*d + j] = a % 2
                a = (ZZ(a) - ZZ(b[i*d + j])) / 2
        
        B = self.BitsToBytes(b)
        return B
        
    def ByteDecode(self, B, d):
        if d < 12:
            m = 2^d
        elif d == 12:
            m = self.q

        b = self.BytesToBits(B)
        F = [0] * 256
        for i in range(256):
            for j in range(0, d):
                F[i] += mod(ZZ(b[i*d + j]) * ZZ(2^j), m)
        
        return F
        
    def Compress(self, x, d):
        z = list(x)
        for i in range(len(x)):
            z[i] = mod(round((2^d) / self.q * ZZ(z[i])), 2^d)

        return z


    def Decompress(self, y, d):
        z = list(y)
        for i in range(len(y)):
            z[i] = mod(round((self.q / 2^d) * ZZ(z[i])), self.q)
        
        return z
    

    def PKEKeyGen(self):
        d = os.urandom(32)
        rho, sigma = self.G(d)
        N = 0
        
        A = [[0] * self.k for _ in range(self.k)]
        for i in range(self.k):
            for j in range(self.k):
                A[i][j] = self.SampleNTT(self.XOF(rho, i, j))
        
        s = [[0] for _ in range(self.k)]
        for i in range(self.k):
            s[i] = self.SamplePolyCBD(self.PRF(sigma, N, self.eta1), self.eta1)
            N += 1
            
        
            
        e = [[0] for _ in range(self.k)]
        for i in range(self.k):
            e[i] = self.SamplePolyCBD(self.PRF(sigma, N, self.eta1), self.eta1)
            N += 1

        
        _s, _e = [], []
        
        for i in range(self.k):
            _s.append(self.NTT(s[i]))
            _e.append(self.NTT(e[i]))
 
        t = [
		reduce(self.ArrayAddition, [
                self.MultiplyNTTs(A[j][i], _s[j])
                for j in range(self.k)
            ] + [_e[i]])
            for i in range(self.k)
        ]

        ek_pke = b''
        for i in range(self.k):
            ek_pke += self.ByteArrayToBytes(self.ByteEncode(t[i], 12))
        
        ek_pke += rho
        
        dk_pke = b''
        for i in range(self.k):
            dk_pke += self.ByteArrayToBytes(self.ByteEncode(_s[i], 12))
        
        return ek_pke, dk_pke
    
    def PKEEncrypt(self, ek, m, r):
        N = 0
        
        t = []
        for i in range(self.k):
            t.append(self.ByteDecode(self.BytesToByteArray(ek[i * 384:(i+1) * 384]), 12))
        
        rho = ek[self.k * 384:]
        
        A = [[0] * self.k for _ in range(self.k)]
        for i in range(self.k):
            for j in range(self.k):
                A[i][j] = self.SampleNTT(self.XOF(rho, i, j))
        
        r_ = [0] * self.k 
        for i in range(self.k):
            r_[i] = self.SamplePolyCBD(self.PRF(r, N, self.eta1), self.eta1)
            N += 1
            
        e1 = [0] * self.k
        for i in range(self.k):
            e1[i] = self.SamplePolyCBD(self.PRF(r, N, self.eta2), self.eta2)
            N += 1
         
        e2 = self.SamplePolyCBD(self.PRF(r, N, self.eta2), self.eta2)
        

        
        _r = []
        for i in range(self.k):
            _r.append(self.NTT(r_[i]))


        u = [
            self.ArrayAddition(self.NTTInverse(reduce(self.ArrayAddition, [
                self.MultiplyNTTs(A[i][j], _r[j])
                for j in range(self.k)
            ])), e1[i])
            for i in range(self.k)
	    ]

        mu2 = self.Decompress(self.ByteDecode(m, 1), 1)

        v = self.ArrayAddition(self.ArrayAddition(self.NTTInverse(self.MatrixMultiplication(t, _r)), e2), mu2)
        
        c1 = b''
        for i in range(self.k):
            c1 += self.ByteArrayToBytes(self.ByteEncode(self.Compress(u[i], self.du), self.du))
            
        c2 = self.ByteArrayToBytes(self.ByteEncode(self.Compress(v, self.dv), self.dv))

        return c1 + c2
    
    def PKEDecrypt(self, dk_pke, c):
        # ATENÇAO A DIVIDIR
        # c1_ = self.BytesToByteArray(c[: 32 * self.du * self.k])
        # print(len(c1_))
        c1 = []
        for i in range(self.k):
            c1.append(self.BytesToByteArray(c[32 * i * self.du: 32 * (i + 1) * self.du]))
            
        c2 = self.BytesToByteArray(c[32 * self.du * self.k: 32 * (self.du * self.k + self.dv)]) 
        
        u = []
        for i in range(self.k):
            u.append(self.Decompress(self.ByteDecode(c1[i], self.du), self.du))

        v = self.Decompress(self.ByteDecode(c2, self.dv), self.dv)        
        
        
        s = []
        for i in range(self.k):
            s.append(self.ByteDecode(dk_pke[i * 384: (i+1) * 384], 12))  
            
        for i in range(self.k):
            u[i] = self.NTT(u[i])
            
        w = self.MatrixMultiplication(s, u)
        w = self.ArraySubtraction(v, self.NTTInverse(w))
        
        m = self.ByteEncode(self.Compress(w, 1), 1)

        return self.ByteArrayToBytes(m)
        
    def keygen(self):
        z = os.urandom(32)
        
        ek_pke, dk_pke = self.PKEKeyGen()
        
        ek = ek_pke
        dk = dk_pke + ek + self.H(ek) + z
        
        return ek, dk       
    
    def encaps(self, ek):
        m = os.urandom(32)
 
        K, r = self.G(m + self.H(ek))
        
        c = self.PKEEncrypt(ek, m, r)
        
        return K, c
    
    def decaps(self, c, dk):
        
        dk_pke = dk[:384 * self.k]
        ek_pke = dk[384 * self.k: 768 * self.k + 32]
        h = dk[768 * self.k + 32: 768 * self.k + 64]
        z = dk[768 * self.k + 64: 768 * self.k + 96]
        
        m_ = self.PKEDecrypt(dk_pke, c)
        
        K_, r_ = self.G(m_ + h)
        
        K = self.J(z + c)
        
        c_ = self.PKEEncrypt(ek_pke, m_, r_)
        
        if c_ != c:
            print('Error')
            K_ = K
        
        return K_
        

In [457]:
kem = KEM(1024)
ek, dk = kem.keygen()

K, c = kem.encaps(ek)

K_ = kem.decaps(c, dk)

print(K.hex())
print(K_.hex())

See https://github.com/sagemath/sage/issues/35473 for details.
  z[i] = mod(round((self.q / Integer(2)**d) * ZZ(z[i])), self.q)


bd13234d500c09307109a4fc74bbcd8af0ba9ad7b0569be93ae0d72c82d19802
bd13234d500c09307109a4fc74bbcd8af0ba9ad7b0569be93ae0d72c82d19802
