# **Estruturas criptográficas: TP3 problema 2**

In [73]:
def ntt_transform(f, s):
    """Evaluate ``f`` at every point of ``s``.

    Args:
        f: any callable (in this notebook, a polynomial over GF(q)).
        s: iterable of evaluation points.

    Returns:
        A list whose i-th entry is ``f(s[i])``.
    """
    return list(map(f, s))

def ntt_inverse(trans, mu):
    """Recombine point values ``trans`` with the basis ``mu``.

    Args:
        trans: sequence of values (e.g. the output of ``ntt_transform``).
        mu: basis elements; ``mu[i]`` is paired with ``trans[i]``.

    Returns:
        ``mu[0]*trans[0] + mu[1]*trans[1] + ...`` (0 when ``trans`` is empty).
    """
    return sum(mu[idx] * trans[idx] for idx in range(len(trans)))

def raizes(omega, N):
    """Return the N evaluation points ``r * omega**i`` for i in range(N).

    ``r`` is ``omega.nth_root(2)``, a square root of ``omega`` in its field.

    Args:
        omega: field element (in this notebook, a square in GF(q)).
        N: number of points to produce.

    Returns:
        List of N field elements; empty when N <= 0.
    """
    if N <= 0:
        # Nothing to generate: avoid computing (or failing on) the square root.
        return []
    # The square root does not depend on i, so compute it once instead of N times.
    root = omega.nth_root(2)
    return [root * omega**i for i in range(N)]

def base_mu(x, s, N):
    """Build the CRT basis for the linear moduli ``x - s[i]``, i in range(N).

    Args:
        x: polynomial-ring generator (``w`` in this notebook).
        s: evaluation points; only the first N are used.
        N: number of moduli.

    Returns:
        The result of Sage's ``CRT_basis`` on those moduli; this notebook
        passes it to ``ntt_inverse`` as the weights ``mu``.
    """
    return CRT_basis([x - s[idx] for idx in range(N)])
    

# Parameter choice: modulus q and number of evaluation points N.
q = 3329
N = 256
# omega is the square of a primitive root of GF(q).
omega = primitive_root(q) ** 2

# Polynomial ring GF(q)[w]: `f` names the ring and `w` is its generator.
f = PolynomialRing(GF(q), 'w')
w = f.gen()
# Example polynomial (kept for reference): f = 1 + x -2*x^2 - x^3

# Evaluation points `ra` and the matching CRT basis `mu`.
ra = raizes(omega, N)
mu = base_mu(w, ra, N)


In [74]:
# Matrix and vector operations
# Matrix addition
def sumMatrix(e1, e2, n):
    """Add ``e2`` to ``e1`` row by row, updating ``e1`` in place.

    Each row ``e1[i]`` is replaced with ``sumVector(e1[i], e2[i], n)``.
    The mutated ``e1`` is also returned.
    """
    for idx, row in enumerate(e1):
        e1[idx] = sumVector(row, e2[idx], n)
    return e1

# Matrix subtraction
def subMatrix(e1, e2, n):
    """Subtract ``e2`` from ``e1`` row by row, updating ``e1`` in place.

    Each row ``e1[i]`` is replaced with ``subVector(e1[i], e2[i], n)``.
    The mutated ``e1`` is also returned.
    """
    for idx, row in enumerate(e1):
        e1[idx] = subVector(row, e2[idx], n)
    return e1

# Inner product of two vectors of polynomials (element-wise products, summed)
def multMatrix(vec1, vec2, n):
    """Return the sum over i of the element-wise products ``vec1[i] * vec2[i]``.

    Args:
        vec1, vec2: sequences of equal length whose entries are indexable
            with at least ``n`` elements (``vec2`` may be longer).
        n: number of entries per element.

    Returns:
        A new list of length ``n``; entry c is
        ``0 + vec1[0][c]*vec2[0][c] + vec1[1][c]*vec2[1][c] + ...``.
        Neither input is modified.
    """
    acc = [0] * n
    for i in range(len(vec1)):
        left, right = vec1[i], vec2[i]
        # Same left-to-right accumulation as sumVector(acc, multVector(...)).
        acc = [acc[c] + left[c] * right[c] for c in range(n)]
    return acc

# Matrix-vector product (entries combined element-wise)
def multMatrixVector(M, v, k, n):
    """Multiply the matrix ``M`` by the vector ``v``, element-wise per entry.

    Row i of the result is the sum over j of the element-wise product
    ``M[i][j] * v[j]`` (each taken over the first ``n`` entries).

    Args:
        M: matrix given as a list of rows; each entry is indexable.
        v: vector whose j-th entry is paired with column j of ``M``.
        k: number of result rows (must be >= len(M)).
        n: number of entries per element.

    Returns:
        A new list of ``k`` independent lists of length ``n``. Rows beyond
        ``len(M)`` are zero. ``M`` and ``v`` are not modified.

    Raises:
        IndexError: if ``len(M) > k``.
    """
    # Build each row separately so the result rows are not aliased.
    result = [[0] * n for _ in range(k)]
    for i, row in enumerate(M):
        acc = [0] * n
        for j in range(len(row)):
            entry, factor = row[j], v[j]
            acc = [acc[c] + entry[c] * factor[c] for c in range(n)]
        result[i] = acc
    return result

# Element-wise vector addition
def sumVector(ff1, ff2, n):
    """Return a new list whose i-th entry is ``ff1[i] + ff2[i]`` for i in range(n)."""
    return [ff1[idx] + ff2[idx] for idx in range(n)]

# Element-wise vector multiplication
def multVector(ff1, ff2, n):
    """Return a new list whose i-th entry is ``ff1[i] * ff2[i]`` for i in range(n)."""
    return [ff1[idx] * ff2[idx] for idx in range(n)]

# Element-wise vector subtraction
def subVector(ff1, ff2, n):
    """Return a new list whose i-th entry is ``ff1[i] - ff2[i]`` for i in range(n)."""
    return [ff1[idx] - ff2[idx] for idx in range(n)]

In [75]:
import os
from hashlib import sha3_256, sha3_512, shake_128, shake_256
from cryptography.hazmat.primitives import hashes

class KEM():
    """Work-in-progress Kyber / ML-KEM key-encapsulation mechanism.

    Only key generation is implemented; ``encaps`` is an unfinished stub.
    Polynomials live in a SageMath ``PolynomialRing`` over GF(q).
    ``PKEKeyGen`` also uses the notebook-level evaluation points ``ra``.
    """

    # (k, eta1, eta2, du, dv) per supported security level; n = 256 and
    # q = 3329 are shared by every level.
    _PARAMS = {
        512: (2, 3, 2, 10, 4),
        768: (3, 2, 2, 10, 4),
        1024: (4, 2, 2, 11, 5),
    }

    def __init__(self, security_bits=512):
        """Select the parameter set for ``security_bits``.

        Raises:
            ValueError: if ``security_bits`` is not 512, 768 or 1024.
        """
        if security_bits not in self._PARAMS:
            raise ValueError(
                "unsupported security_bits: {!r} (expected 512, 768 or 1024)".format(
                    security_bits
                )
            )
        self.n = 256
        self.q = 3329
        self.k, self.eta1, self.eta2, self.du, self.dv = self._PARAMS[security_bits]

        # Plain polynomial ring Z_q[w].  It is built from self.q, not the
        # notebook-global `q`.  The quotient ring Z_q[w]/(w^n + 1) is not used.
        self.Rq = PolynomialRing(GF(self.q), 'w')

    def G(self, c):
        """Hash ``c`` with SHA3-512; return the two 32-byte halves of the digest."""
        digest = sha3_512(c).digest()
        return digest[:32], digest[32:]

    def H(self, m):
        """Return the 32-byte SHA3-256 digest of ``m``."""
        return sha3_256(m).digest()

    def XOF(self, rho, i, j):
        """Return ``self.q`` bytes of SHAKE-128(rho || i || j).

        ``i`` and ``j`` are appended as single bytes and must be in 0..255.
        ``bytes([i, j])`` is used because the one-argument
        ``int.to_bytes(i)`` only works on Python 3.11+.
        """
        return shake_128(rho + bytes([i, j])).digest(self.q)

    def SampleNTT(self, B):
        """Rejection-sample 256 coefficients in [0, q) from the byte string ``B``.

        Each group of 3 bytes gives two 12-bit candidates.  Candidates >= q
        are discarded.  Raises IndexError if ``B`` runs out first.
        """
        coeffs = [0] * 256
        count = 0
        pos = 0
        while count < 256:
            b0, b1, b2 = B[pos], B[pos + 1], B[pos + 2]
            d1 = b0 + 256 * (b1 % 16)
            d2 = b1 // 16 + 16 * b2
            if d1 < self.q:
                coeffs[count] = d1
                count += 1
            if d2 < self.q and count < 256:
                coeffs[count] = d2
                count += 1
            pos += 3
        return self.Rq(coeffs)

    def PRF(self, s, b, eta):
        """Return 64*eta bytes of SHAKE-256(s || b).

        ``b`` is a counter in 0..255, appended as a single byte.  Note that
        ``bytes(b)`` would instead produce ``b`` zero bytes.
        """
        return shake_256(s + bytes([b])).digest(64 * eta)

    def bbb(self, byteArray):
        """Expand byte values into bits (ints 0/1), least-significant bit first.

        Each byte contributes exactly 8 bits.
        """
        return [(elem >> k) & 1 for elem in byteArray for k in range(8)]

    def BytesToBits(self, B):
        """Convert bytes to a list of bits (ints 0/1), least-significant bit first.

        Plain ints are returned (not elements of Z/2Z).  This lets
        SamplePolyCBD add bits as ordinary integers instead of mod 2.
        """
        return [(byte >> k) & 1 for byte in self.BytesToByteArray(B) for k in range(8)]

    def BitsToBytes(self, b):
        """Pack bits (least-significant first) into a list of byte values.

        Bits that do not fill a whole trailing byte are ignored.
        """
        out = [0] * (len(b) // 8)
        for idx in range(8 * len(out)):
            out[idx // 8] += int(b[idx]) << (idx % 8)
        return out

    def ByteArrayToBytes(self, B):
        """Convert a sequence of byte values (0..255) into a ``bytes`` object."""
        return bytes(int(value) for value in B)

    def BytesToByteArray(self, Bytes):
        """Return the byte values of ``Bytes`` as a list of ints."""
        return list(Bytes)

    def SamplePolyCBD(self, B, eta):
        """Sample a polynomial with centred-binomial coefficients from ``B``.

        Coefficient i is ``(x - y) mod q``.  ``x`` is the sum of bits
        2*i*eta .. 2*i*eta+eta-1 of ``B``; ``y`` is the sum of the next
        ``eta`` bits.  ``B`` must provide at least 64*eta bytes.
        """
        bits = self.BytesToBits(B)
        coeffs = []
        for i in range(256):
            base = 2 * i * eta
            x = sum(bits[base + j] for j in range(eta))
            y = sum(bits[base + eta + j] for j in range(eta))
            coeffs.append((x - y) % self.q)
        return self.Rq(coeffs)

    def ByteEncode(self, F, d):
        """Encode 256 values as d-bit little-endian fields packed into bytes.

        Values are reduced mod 2^d when d < 12 and mod q when d == 12.

        Returns:
            List of 32*d byte values.

        Raises:
            ValueError: if ``d`` is not in 1..12.
        """
        if not 1 <= d <= 12:
            raise ValueError("ByteEncode requires 1 <= d <= 12, got {!r}".format(d))
        m = 2 ** d if d < 12 else self.q
        bits = [0] * (256 * d)
        for i in range(256):
            # int() lifts GF(q)/Z_m elements to ordinary integers.
            a = int(F[i]) % m
            for j in range(d):
                bits[i * d + j] = a & 1
                a >>= 1
        return self.BitsToBytes(bits)

    def PKEKeyGen(self):
        """Generate a K-PKE key pair ``(ek_pke, dk_pke)`` as bytes.

        The layouts are:
          ek_pke = ByteEncode_12(t[0]) || ... || ByteEncode_12(t[k-1]) || rho
          dk_pke = ByteEncode_12(s_hat[0]) || ... || ByteEncode_12(s_hat[k-1])

        NOTE: s and e are transformed with the notebook-global points ``ra``,
        which must be defined before this method is called.
        """
        d = os.urandom(32)
        rho, sigma = self.G(d)
        N = 0

        # Entry (i, j) of A is sampled from XOF(rho, j, i) = SHAKE-128(rho || j || i),
        # the index order specified by FIPS 203 (K-PKE.KeyGen).
        A = [[self.SampleNTT(self.XOF(rho, j, i)) for j in range(self.k)]
             for i in range(self.k)]

        # Secret s and error e both come from CBD_eta1.  Each polynomial uses
        # a distinct PRF counter N.
        s = []
        for _ in range(self.k):
            s.append(self.SamplePolyCBD(self.PRF(sigma, N, self.eta1), self.eta1))
            N += 1
        e = []
        for _ in range(self.k):
            e.append(self.SamplePolyCBD(self.PRF(sigma, N, self.eta1), self.eta1))
            N += 1

        s_hat = [ntt_transform(poly, ra) for poly in s]
        e_hat = [ntt_transform(poly, ra) for poly in e]

        # t = A * s_hat + e_hat, with entries combined element-wise.
        t = sumMatrix(multMatrixVector(A, s_hat, self.k, self.n), e_hat, self.n)

        ek_pke = b''.join(
            self.ByteArrayToBytes(self.ByteEncode(t[i], 12)) for i in range(self.k)
        ) + rho
        dk_pke = b''.join(
            self.ByteArrayToBytes(self.ByteEncode(s_hat[i], 12)) for i in range(self.k)
        )
        return ek_pke, dk_pke

    def keygen(self):
        """Generate a key pair ``(ek, dk)``.

        ``ek`` is the K-PKE encryption key.  ``dk`` is
        dk_pke || ek || H(ek) || z, where z is a fresh random 32-byte value.
        """
        z = os.urandom(32)
        ek_pke, dk_pke = self.PKEKeyGen()
        ek = ek_pke
        dk = dk_pke + ek + self.H(ek) + z
        return ek, dk

    def encaps(self, ek):
        """Unfinished stub.

        Draws a random 32-byte message but does not yet produce a shared key
        or ciphertext; returns None.
        """
        # TODO: implement encapsulation of m under ek.
        m = os.urandom(32)
        
         
        

In [76]:
# Build a KEM with the 512-bit parameter set (the default) and generate a key pair.
kem = KEM(security_bits=512)
key = kem.keygen()