In [1]:
from functools import reduce
import copy, hashlib, math, random

In [2]:
Q = 8380417

def H(v, d):
    H_object = hashlib.shake_256(bytes(v)).digest(d // 8)
    return [int(bit) for byte in H_object for bit in f'{byte:08b}']


def H128(v, d):
    H_128_object = hashlib.shake_128(bytes(v)).digest(d // 8)
    return [int(bit) for byte in H_128_object for bit in f'{byte:08b}']


def jth_byte(rho, j, hash):
    hash_object = hash(rho, 8 * (j + 1))
    hash_object_bits = [int(bit) for byte in hash_object for bit in f'{byte:08b}']

    return hash_object_bits[8 * j : 8 * j + 8]


def mod_plus_minus(m, alpha):
    if alpha % 2 == 0:
        lim = alpha // 2

    else:
         lim = (alpha - 1) // 2

    mod = m % alpha
    if mod > lim:
        mod -= alpha

    return mod


def brv(r):
    brv_r = 0
    for i in range(8):
        brv_r |= ((r >> i) & 1) << (7 - i)

    return brv_r


def vector_add(ac, bc):
	return [(x + y) % Q for x, y in zip(ac, bc)]


def vector_sub(ac, bc):
	return [(x - y) % Q for x, y in zip(ac, bc)]


def vector_mult(ac, bc):
    return [(x * y) % Q for x, y in zip(ac, bc)]


def matrix_vector_mult(Ac, bc):
    result = []
    for i in range(len(bc)):
        mid_result = []
        for j in range(len(Ac[i])):
            mid_result.append(vector_mult(Ac[i][j], bc[i]))
        result.append(reduce(vector_add, mid_result))

    return result


def infinity_norm(matrix):
    max = None
    for vector in matrix:
        for elem in vector:
            if max == None:
                max = elem

            else:
                aux = mod_plus_minus(elem, Q)
                if aux > max:
                    max = aux

    return max


def infinity_norm2(matrix):
    max = None
    for vector in matrix:
        for elem in vector:
            if max == None:
                max = elem

            else:
                aux = mod_plus_minus(elem, Q)
                if aux > max:
                    max = aux
    
    print(max)
    return max
        

In [3]:
class MLDSA():

    def __init__(self, tau, lmbda, gamma1, k, l, eta, omega):
        self.d = 13
        self.tau = tau
        self.lmbda = lmbda
        self.gamma1 = gamma1
        self.gamma2 = (Q - 1) // 88
        self.k, self.l = k, l
        self.eta = eta
        self.beta = self.tau * self.eta
        self.omega = omega
        self.zeta = 1753

    
    # 1 - DONE
    def ml_dsa_keygen(self):
        csi = [random.randint(0, 1) for _ in range(256)]
        H_csi = H(csi, 1024)
        rho, rhol, K = H_csi[:256], H_csi[256 : 768], H_csi[768:]

        Ac = self.expand_a(rho)
        s1, s2 = self.expand_s(rhol)
        
        ntt_s1 = [self.ntt(s1_elem) for s1_elem in s1]
        Ac_ntt_s1 = matrix_vector_mult(Ac, ntt_s1)
        ntt_inv_Ac_ntt_s1 = [self.ntt_inv(Ac_ntt_s1_elem) for Ac_ntt_s1_elem in Ac_ntt_s1]
        t = [vector_add(ntt_inv_Ac_ntt_s1[i], s2[i]) for i in range(len(s2))]

        t1, t0 = [], []
        for vector in t:
            r0_vector = []
            r1_vector = []
            for r in vector:
                r1, r0 = self.power_2_round(r)
                r0_vector.append(r0)
                r1_vector.append(r1)

            t1.append(r1_vector)
            t0.append(r0_vector)

        pk = self.pk_encode(rho, t1)
        tr = H(self.bytes_to_bits(pk), 512)
        sk = self.sk_encode(rho, K, tr, s1, s2, t0)

        return pk, sk
    

    # 2
    def ml_dsa_sign(self, sk, M):
        rho, K, tr, s1, s2, t0 = self.sk_decode(sk)
        s1c = [self.ntt(s1_elem) for s1_elem in s1]
        s2c = [self.ntt(s2_elem) for s2_elem in s2]
        t0c = [self.ntt(t0_elem) for t0_elem in t0]
        Ac = self.expand_a(rho)
        mi = H(tr + M, 512)
        rnd = [random.randint(0, 1) for _ in range(256)]
        rhol = H(K + rnd + mi, 512)
        k = 0

        z, h = None, None

        while z == None and h == None:
            y = self.expand_mask(rhol, k)
            ntt_y = [self.ntt(y_elem) for y_elem in y]
            Ac_ntt_y = matrix_vector_mult(Ac, ntt_y)
            w = [self.ntt_inv(Ac_ntt_y_elem) for Ac_ntt_y_elem in Ac_ntt_y]
            
            w1 = []
            for vector in w:
                w1_vector = []
                for elem in vector:
                    w1_vector.append(self.high_bits(elem))
                w1.append(w1_vector)

            ct = H(mi + self.w1_encode(w1), 2 * self.lmbda)
            c1t, c2t = ct[:256], ct[256:]
            c = self.sample_in_ball(c1t)
            cc = self.ntt(c)

            cc_s1c_mult = [vector_mult(s1c_elem, cc) for s1c_elem in s1c]
            cs1 = [self.ntt_inv(cc_s1c_mult_elem) for cc_s1c_mult_elem in cc_s1c_mult]

            cc_s2c_mult = [vector_mult(s2c_elem, cc) for s2c_elem in s2c]
            cs2 = [self.ntt_inv(cc_s2c_mult_elem) for cc_s2c_mult_elem in cc_s2c_mult]

            z = [vector_add(y[i], cs1[i]) for i in range(len(y))]
            w_sub_cs2 = [vector_sub(w[i], cs2[i]) for i in range(len(w))]

            r0 = []
            for vector in w_sub_cs2:
                r0_vector = []
                for elem in vector:
                    r0_vector.append(self.low_bits(elem))
                r0.append(r0_vector)

            print(f'{infinity_norm(r0)} : {self.gamma2 - self.beta}')
            if infinity_norm(z) >= (self.gamma1 - self.beta) and infinity_norm(r0) >= (self.gamma2 - self.beta):
                print('primeiro if')
                z, h = None, None

            else:
                cc_t0c_mult = [vector_mult(t0c_elem, cc) for t0c_elem in t0c]
                ct0 = [self.ntt_inv(cc_t0c_mult_elem) for cc_t0c_mult_elem in cc_t0c_mult]

                m_ct0 = []
                for vector in ct0:
                    m_ct0_aux = []
                    for elem in vector:
                        m_ct0_aux.append(elem * -1)
                    m_ct0.append(m_ct0_aux)

                cs2_ct0 = [vector_add(cs2[i], ct0[i]) for i in range(len(cs2))]
                w_m = [vector_sub(w[i], cs2_ct0[i]) for i in range(len(w))]
                
                h = []
                h_count = 0
                for i in range(len(w_m)):
                    h_vector = []
                    for j in range(len(w_m[i])):
                        aux = self.make_hint(w_m[i][j], m_ct0[i][j])
                        h_vector.append(aux)
                        h_count += aux

                    h.append(h_vector)
                
                print(f'{infinity_norm2(ct0)} : {self.gamma2}')
                print(f'{h_count} : {self.omega}')
                if infinity_norm(ct0) >= self.gamma2 or h_count > self.omega:
                    print('segundo if')
                    z, h = None, None
            
            k += self.l
            print(k)

        Z_modpmQ = []
        for vector in z:
            aux_vector = []
            for elem in vector:
                aux_vector.append(mod_plus_minus(elem, Q))
            Z_modpmQ.append(aux_vector)

        sigma = self.sig_encode(ct, Z_modpmQ, h)

        return sigma
    

    # 3
    def ml_dsa_verify(self, pk, M, sigma):
        rho, t1 = self.pk_decode(pk)
        ct, z, h = self.sig_decode(sigma)

        if h == None:
            return False
        
        Ac = self.expand_a(rho)
        tr = H(self.bytes_to_bits(pk), 512)
        mi = H(tr + M, 512)
        c1t, c2t = ct[:256], ct[256:]
        c = self.sample_in_ball(c1t)

        Ac_z = matrix_vector_mult(Ac, [self.ntt(z_elem) for z_elem in z])
        
        t1_2d = []
        for vector in t1:
            aux = []
            for elem in vector:
                aux.append(elem * (2 ** self.d))
            t1_2d.append(self.ntt(aux))

        c_t1_2d = [vector_mult(t1_2d_elem, self.ntt(c)) for t1_2d_elem in t1_2d]

        ntt_m = [vector_sub(Ac_z[i], c_t1_2d[i]) for i in range(len(Ac_z))]
        wlapprox = [self.ntt_inv(ntt_m_elem) for ntt_m_elem in ntt_m]

        wl1 = []
        count1 = 0
        for i in range(len(wlapprox)):
            wl1_aux = []
            for j in range(len(wlapprox[i])):
                val = h[i][j]
                count1 += val
                wl1_aux.append(self.use_hint(val, wlapprox[i][j]))
            wl1.append(wl1_aux)

        ctl = H(mi + self.w1_encode(wl1), 2 * self.lmbda)

        return (infinity_norm(z) < (self.gamma1 - self.beta)) and (ct == ctl) and (count1 <= self.omega)
    

    # 4 - DONE
    def integer_to_bits(self, x, alpha):
        y = [None for _ in range(alpha)]

        for i in range(alpha):
            y[i] = x % 2
            x //= 2

        return y
    

    # 5 - DONE
    def bits_to_integer(self, y, alpha):
        x = 0
        for i in range(1, alpha + 1):
            x = 2 * x + y[alpha - i]

        return x
    

    # 6 - DONE
    def bits_to_bytes(self, y):
        c = len(y)
        z_len = math.ceil(c // 8)
        z = [0 for _ in range(z_len)]
    
        for i in range(c):
            z[i // 8] += y[i] * 2 ** (i % 8)
    
        return z
    

    # 7 - DONE
    def bytes_to_bits(self, z):
        zz = copy.deepcopy(z)
        d = len(zz)
        y = [0 for _ in range(d * 8)]

        for i in range(d):
            for j in range(8):
                y[8 * i + j] = zz[i] % 2
                zz[i] //= 2

        return y
        

    # 8 - DONE
    def coeff_from_three_bytes(self, b0, b1, b2):
        if b2 > 127:
            b2 -= 128

        z = 2 ** 16 * b2 + 2 ** 8 * b1 + b0
        if z < Q:
            return z

        else:
            return None

    
    # 9 - DONE
    def coeff_from_half_byte(self, b):
        if self.eta == 2 and b < 15:
            return 2 - (b % 5)
        
        else:
            if self.eta == 4 and b < 9:
                return 4 - b  
            else:
                return None
            

    # 10 - DONE
    def simple_bit_pack(self, w, b):
        z = []

        for i in range(256):
            z += self.integer_to_bits(w[i], b.bit_length())

        return self.bits_to_bytes(z)
            

    # 11 - DONE
    def bit_pack(self, w, a, b):
        z = []

        for i in range(256):
            z += self.integer_to_bits(b - w[i], (a + b).bit_length())

        return self.bits_to_bytes(z)


    # 12 - DONE
    def simple_bit_unpack(self, v, b):
        c = b.bit_length()
        z = self.bytes_to_bits(v)
        w = [None for _ in range(256)]

        for i in range(256):
            w[i] = self.bits_to_integer(z[i * c : i * c + c], c)

        return w


    # 13 - DONE
    def bit_unpack(self, v, a, b):
        c = (a + b).bit_length()
        z = self.bytes_to_bits(v)
        w = [None for _ in range(256)]

        for i in range(256):
            w[i] = b - self.bits_to_integer(z[i * c : i * c + c], c)

        return w
    

    # 14 - DONE
    def hint_bit_pack(self, h):
        y = [0 for _ in range(self.omega + self.k)]
        index = 0

        for i in range(self.k):
            for j in range(256):
                if h[i][j] != 0:
                    y[index] = j
                    index += 1
            y[self.omega + i] = index

        return y


    # 15 - DONE
    def hint_bit_unpack(self, y):
        h = [[0 for _ in range(256)] for _ in range(self.k)]
        index = 0

        for i in range(self.k):
            if y[self.omega + i] < index or y[self.omega + i] > self.omega:
                return None

            while index < y[self.omega + i]:
                h[i][y[index]] = 1
                index += 1

        while index < self.omega:
            if y[index] != 0:
                return None
            index += 1

        return h
    

    # 16 - DONE
    def pk_encode(self, rho, t1):
        pk = self.bits_to_bytes(rho)
        
        for i in range(self.k):
            pk += self.simple_bit_pack(t1[i], 2 ** ((Q - 1).bit_length() - self.d) - 1)

        return pk
    

    # 17 - DONE
    def pk_decode(self, pk):
        y = pk[:32]

        pk_z = pk[32:]
        chunk_size = len(pk_z) // self.k
        z = [pk_z[i : i + chunk_size] for i in range(0, len(pk_z), chunk_size)]

        t1 = [None for _ in range(self.k)]

        rho = self.bytes_to_bits(y)

        for i in range(self.k):
            t1[i] = self.simple_bit_unpack(z[i], 2 ** ((Q - 1).bit_length() - self.d) - 1)

        return rho, t1


    # 18 - DONE
    def sk_encode(self, rho, K, tr, s1, s2, t0):
        sk = self.bits_to_bytes(rho) + self.bits_to_bytes(K) + self.bits_to_bytes(tr)

        for i in range(self.l):
            sk += self.bit_pack(s1[i], self.eta, self.eta)

        for i in range(self.k):
            sk += self.bit_pack(s2[i], self.eta, self.eta)

        for i in range(self.k):
            sk += self.bit_pack(t0[i], 2 ** (self.d - 1) - 1, 2 ** (self.d - 1))

        return sk
    

    # 19 - DONE
    def sk_decode(self, sk):
        f, g, h = sk[:32], sk[32 : 64], sk[64 : 128]
        sk_y_len = 32 * (2 * self.eta).bit_length() * self.l
        sk_y = sk[128 : 128 + sk_y_len]
        sk_z_len = 32 * (2 * self.eta).bit_length() * self.k
        sk_z = sk[128 + sk_y_len : 128 + sk_y_len + sk_z_len]
        sk_w_len = 32 * self.d * self.k
        sk_w = sk[128 + sk_y_len + sk_z_len : 128 + sk_y_len + sk_z_len + sk_w_len]

        y = [sk_y[i : i + len(sk_y) // self.l] for i in range(0, len(sk_y), len(sk_y) // self.l)]
        z = [sk_z[i : i + len(sk_z) // self.k] for i in range(0, len(sk_z), len(sk_z) // self.k)]
        w = [sk_w[i : i + len(sk_w) // self.k] for i in range(0, len(sk_w), len(sk_w) // self.k)]

        rho = self.bytes_to_bits(f)
        K = self.bytes_to_bits(g)
        tr = self.bytes_to_bits(h)

        s1 = [None for _ in range(self.l)]
        for i in range(self.l):
            s1[i] = self.bit_unpack(y[i], self.eta, self.eta)
        
        s2 = [None for _ in range(self.k)]
        for i in range(self.k):
            s2[i] = self.bit_unpack(z[i], self.eta, self.eta)
        
        t0 = [None for _ in range(self.k)]
        for i in range(self.k):
            t0[i] = self.bit_unpack(w[i], 2 ** (self.d - 1) - 1, 2 ** (self.d - 1))

        return rho, K, tr, s1, s2, t0
        
    
    # 20 - DONE
    def sig_encode(self, ct, z, h):
        sigma = self.bits_to_bytes(ct)

        for i in range(self.l):
            sigma += self.bit_pack(z[i], self.gamma1 - 1, self.gamma1)

        sigma += self.hint_bit_pack(h)

        return sigma
    

    # 21 - DONE
    def sig_decode(self, sigma):
        w = sigma[: self.lmbda // 4]
        sigma_x_len = self.l * 32 * (1 + (self.gamma1 - 1).bit_length())
        sigma_x = sigma[self.lmbda // 4 : self.lmbda // 4 + sigma_x_len]
        sigma_y_len = self.omega + self.k
        sigma_y = sigma[self.lmbda // 4 + sigma_x_len : self.lmbda // 4 + sigma_x_len + sigma_y_len]

        x = [sigma_x[i : i + len(sigma_x) // self.l] for i in range(0, len(sigma_x), len(sigma_x) // self.l)]

        ct = self.bytes_to_bits(w)

        z = [None for _ in range(self.l)]
        for i in range(self.l):
            z[i] = self.bit_unpack(x[i], self.gamma1 - 1, self.gamma1)

        h = self.hint_bit_unpack(sigma_y)

        return ct, z, h


    # 22 - DONE
    def w1_encode(self, w1):
        w1t = []

        for i in range(self.k):
            w1t += self.bytes_to_bits(self.simple_bit_pack(w1[i], int((Q - 1) / (2 * self.gamma2) - 1)))

        return w1t
    

    # 23 - DONE
    def sample_in_ball(self, rho):
        c = [0 for _ in range(256)]
        k = 8

        for i in range(256 - self.tau, 256):
            while self.bits_to_bytes(jth_byte(rho, k, H))[0] > i:
                k += 1

            j = self.bits_to_bytes(jth_byte(rho, k, H))[0]
            c[i] = c[j]
            c[j] = -1 ** H(rho, 8 * (i + self.tau - 256 + 1))[i + self.tau - 256]
            k += 1

        return c
    

    # 24 - DONE
    def rej_ntt_poly(self, rho):
        j = 0
        c = 0
        ac = [None for _ in range(256)]

        while j < 256:
            H_128_c = self.bits_to_bytes(jth_byte(rho, c, H128))[0]
            H_128_c1 = self.bits_to_bytes(jth_byte(rho, c + 1, H128))[0]
            H_128_c2 = self.bits_to_bytes(jth_byte(rho, c + 2, H128))[0]
            ac[j] = self.coeff_from_three_bytes(H_128_c, H_128_c1, H_128_c2)

            c += 3
            if ac[j] != None:
                j += 1

        return ac
    

    # 25 - DONE
    def rej_bounded_poly(self, rho):
        j = 0
        c = 0
        a = [None for _ in range(256)]

        while j < 256:
            z = self.bits_to_bytes(jth_byte(rho, c, H))[0]
            z0 = self.coeff_from_half_byte(z % 16)
            z1 = self.coeff_from_half_byte(z // 16)

            if z0 != None:
                a[j] = z0
                j += 1

            if z1 != None and j < 256:
                a[j] = z1
                j += 1

            c += 1

        return a
    

    # 26 - DONE
    def expand_a(self, rho):
        Ac = [[None for _ in range(self.l)] for _ in range(self.k)]

        for r in range(self.k):
            for s in range(self.l):
                Ac[r][s] = self.rej_ntt_poly(rho + self.integer_to_bits(s, 8) + self.integer_to_bits(r, 8))

        return Ac
    

    # 27 - DONE
    def expand_s(self, rho):
        s1 = [None for _ in range(self.l)]
        s2 = [None for _ in range(self.k)]

        for r in range(self.l):
            s1[r] = self.rej_bounded_poly(rho + self.integer_to_bits(r, 16))

        for r in range(self.k):
            s2[r] = self.rej_bounded_poly(rho + self.integer_to_bits(r + self.l, 16))

        return s1, s2


    # 28 - DONE
    def expand_mask(self, rho, mu):
        c = 1 + (self.gamma1 - 1).bit_length()
        s = [None for _ in range(self.l)]

        for r in range(self.l):
            n = self.integer_to_bits(mu + r, 16)
            v = [self.bits_to_bytes(jth_byte(rho + n, 32 * r * c + i, H))[0] for i in range(32 * c)]
            s[r] = self.bit_unpack(v, self.gamma1 - 1, self.gamma1)

        return s


    # 29 - DONE
    def power_2_round(self, r):
        rp = r % Q
        r0 = mod_plus_minus(rp, 2 ** self.d)

        return int((rp - r0) / 2 ** self.d), r0
    

    # 30 - DONE
    def decompose(self, r):
        rp = r % Q
        r0 = mod_plus_minus(rp, 2 * self.gamma2)

        if rp - r0 == Q - 1:
            r1 = 0
            r0 -= 1

        else:
            r1 = (rp - r0) / (2 * self.gamma2)

        return int(r1), int(r0)
    

    # 31 - DONE
    def high_bits(self, r):
        r1, r0 = self.decompose(r)
        return r1
    

    # 32 - DONE
    def low_bits(self, r):
        r1, r0 = self.decompose(r)
        return r0
    

    # 33 - DONE
    def make_hint(self, r, z):
        r1 = self.high_bits(r)
        v1 = self.high_bits(r + z)

        if r1 != v1:
            return 1
        
        return 0
    

    # 34 - DONE
    def use_hint(self, h, r):
        m = (Q - 1) // (2 * self.gamma2)
        r1, r0 = self.decompose(r)

        if h == 1 and r0 > 0:
            return (r1 + 1) % m
        
        if h == 1 and r0 <= 0:
            return (r1 - 1) % m
        
        return r1


    # 35 - DONE
    def ntt(self, w):
        wc = [None for _ in range(256)]
        for j in range(256):
            wc[j] = w[j]

        k = 0
        len = 128

        while len >= 1:
            start = 0
            while start < 256:
                k += 1
                zeta = pow(self.zeta, brv(k), Q)
                for j in range(start, start + len):
                    t = (zeta * wc[j + len]) % Q
                    wc[j + len] = (wc[j] - t) % Q
                    wc[j] = (wc[j] + t) % Q

                start += 2 * len
            
            len //= 2

        return wc
    

    # 36 - DONE
    def ntt_inv(self, wc):
        w = [None for _ in range(256)]
        for j in range(256):
            w[j] = wc[j]

        k = 256
        len = 1

        while len < 256:
            start = 0
            while start < 256:
                k -= 1
                zeta = -pow(self.zeta, brv(k), Q)

                for j in range(start, start + len):
                    t = w[j]
                    w[j] = (t + w[j + len]) % Q
                    w[j + len] = (t - w[j + len]) % Q
                    w[j + len] = (w[j + len] * zeta) % Q

                start += 2 * len

            len *= 2

        f = 8347681
        for j in range(256):
            w[j] = (w[j] * f) % Q

        return w

In [None]:
mldsa = MLDSA(39, 128, 2 ** 17, 4, 4, 2, 80)
pk, sk = mldsa.ml_dsa_keygen()

print(len(pk))
print(len(sk))

M = [0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1]
sigma = mldsa.ml_dsa_sign(sk, M)

a = mldsa.ml_dsa_verify(pk, M, sigma)
print(a)

1312
2560
94892 : 95154
8376019
8376019 : 95232
64 : 80
segundo if
4
95166 : 95154
primeiro if
8
