# ML-KEM (FIPS 203)

Implementation of ML-KEM in Python (we will considre ML-KEM-512 parameters)

In [None]:
n = 256
q = 3329
k = 2
eta1 = 3
eta2 = 2
du = 10
dv = 4

## Auxiliary Algorithms

### Cryptographic Functions

Implementation of pseudorandom function $\texttt{PRF}$ and hash functions $\texttt{H}$, $\texttt{J}$ and $\texttt{G}$.

In [1]:
import secrets

d = 32  # Example byte length
secrets.token_bytes(d)

b'&\x1d.\xbe\xb8\xec]\xe9\x8a\x0f\x08\xb7\xb6f0t8\xec\x89W \x98\x17\x02\x9azv}1A\xc0\xf4'

In [955]:
import hashlib

def PRF(s, b, eta):
    shake = hashlib.shake_256()
    shake.update(s+b)
    return shake.digest(64*eta)

In [956]:
import random

s = random.randbytes(32)
b = random.randbytes(1)

r = PRF(s,b,eta1)
print(len(r))

192


In [957]:
def H(s):
    return hashlib.sha3_256(s).digest()

def J(s):
    shake = hashlib.shake_256()
    shake.update(s)
    return shake.digest(32)

def G(c):
    d = hashlib.sha3_512(c).digest()
    return d[:32], d[32:]

In [958]:
print(H(s))
print(J(s))
print(G(s))

b'I\x01\xe6\xf2k)\xe5\x04,\xb1g\x9c\xa7\xc76;j\xe1\xb4]\x97\x82\r\xba66H\x88\xf5\x1d\x87\x8e'
b'<\xc5t\xfd\x18\xce\x19\xec\xf6x\xcfx3[75\xf7\xf7\xf4G\x87 -\xa1)\x9d?8\x1a1\x14`'
(b'\x80\xe9\xebv\x0f\x95,\x1f\x83\xd9Y\x19\xb1vY)\xc07f*\xa0\x94\xf5\xd0\x8a\xd5\xd0\x90\xa7C\x91"', b'*Z\x17GS\x7f\xdb\x9f]9\xf7fg\xe5B\x0e;M\x8d\xef\xba\xa8\x86\x93\x95\x88>\xadR\x06\xb8\n')


### General Algorithms

#### Conversion and Compression Algorithms

Implementation of conversion and compression algorithms of Kyber such as $\texttt{encode}$, $\texttt{decode}$, $\texttt{compress}$ and $\texttt{decompress}$.

In [959]:
def compress(x, d):
    return round((2**d/q) * x) % 2**d

def decompress(y, d):
    return round(q/2**d * y)

In [960]:
x = 2675
print(f"x = {x} - ", bin(x)[2:], f"[{len(bin(x)[2:])}]")
for d in range(11,5,-1):
    cx = compress(x,d)
    print(f"compress(x,{d}) = {cx} -", bin(cx)[2:], f"[{len(bin(cx)[2:])}]")

x = 2675 -  101001110011 [12]
compress(x,11) = 1646 - 11001101110 [11]
compress(x,10) = 823 - 1100110111 [10]
compress(x,9) = 411 - 110011011 [9]
compress(x,8) = 206 - 11001110 [8]
compress(x,7) = 103 - 1100111 [7]
compress(x,6) = 51 - 110011 [6]


In [961]:
for d in range(11,5,-1):
    cx = compress(x,d)
    dx = decompress(cx,d)
    print(f"decompress(compress(x,{d}),{d}) = {dx} -", bin(dx)[2:], f"[{len(bin(dx)[2:])}]")

decompress(compress(x,11),11) = 2676 - 101001110100 [12]
decompress(compress(x,10),10) = 2676 - 101001110100 [12]
decompress(compress(x,9),9) = 2672 - 101001110000 [12]
decompress(compress(x,8),8) = 2679 - 101001110111 [12]
decompress(compress(x,7),7) = 2679 - 101001110111 [12]
decompress(compress(x,6),6) = 2653 - 101001011101 [12]


In [962]:
import math

def BitsToBytes(b):
    l = int(len(b)/8)
    B = [0] * l
    for i in range(len(b)):
        B[math.floor(i/8)] = B[math.floor(i/8)] + b[i]*2**(i%8)
    return bytes(B)

def BytesToBits(B):
    B = list(B)
    l = len(B)
    C = list(B)
    b = [0] * (l*8)
    for i in range(l):
        for j in range(8):
            b[i*8 + j] = C[i] % 2
            C[i] = math.floor(C[i]/2)  
    return b
    
def ByteEncode(F, d):
    b = [0] * (256 * d)
    for i in range(256):
        a = F[i]
        for j in range(d):
            b[i*d + j] = int(a % 2)
            a = (a - b[i*d + j]) / 2  
    B = BitsToBytes(b) 
    return B

def ByteDecode(B, d):
    m = q if d == 12 else 2**d
    
    b = BytesToBits(B)
    F = [0] * 256
    for i in range(256):
        sum = 0
        for j in range(d):
            sum = sum + b[i*d + j] * 2**j % m
        F[i] = int(sum)
    return F

In [963]:
F = [random.randint(0,1) for _ in range(256)]
print(F)

B = ByteEncode(F, 1)
print(B)

FF = ByteDecode(B, 1)
print(FF)

print(F == FF)


[1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0]
b'\x19]\x1a\xe6\xd8B\xfa\xf4\x12+\x9e\x83\x04Cp9\x83y\xa4\xc1\x07E\x03\xa5V\xda\xa4a64\xfb\x13'
[1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0,

#### Sampling algorithms

Now, let's try to implement sampling algorithms $\texttt{SampleNTT}$ and $\texttt{SamplePolyCBD}$.

In [964]:
def SampleNTT(B):
    shake = hashlib.shake_128()
    shake.update(B)
    cnt = 1
    j = 0
    a = [0] * 256 
    while j<256:
        digest = shake.digest(cnt*3)
        C = digest[-3:]
        d1 = C[0] + (256 * (C[1] % 16))
        d2 = math.floor(C[1]/16) + 16*C[2]
        if d1 < q:
            a[j] = d1
            j += 1
        if d2 < q and j < 256:
            a[j] = d2
            j += 1
        cnt += 1
    return a

In [965]:
B = random.randbytes(34)
a_NTT = SampleNTT(B)
print(a_NTT)

[2986, 3090, 1910, 3011, 711, 751, 235, 388, 650, 202, 2480, 381, 1594, 2251, 197, 1179, 264, 882, 2373, 1915, 257, 972, 2006, 1529, 2335, 2915, 1304, 399, 2970, 2920, 2660, 1204, 3290, 2172, 674, 244, 2928, 361, 405, 2143, 957, 2031, 2226, 3245, 668, 2098, 527, 2478, 2794, 1850, 1298, 2032, 1015, 2879, 1839, 2939, 377, 768, 156, 1831, 603, 2482, 1152, 1888, 3149, 877, 536, 460, 2509, 2861, 914, 3219, 1803, 3130, 243, 1712, 2342, 1728, 2356, 2787, 2706, 2923, 232, 1113, 1020, 1931, 3135, 735, 681, 562, 2091, 3130, 1828, 2125, 3052, 2371, 2184, 2770, 2223, 25, 2305, 362, 192, 336, 2095, 2048, 243, 2375, 1234, 427, 275, 1087, 196, 511, 534, 2225, 2379, 3026, 2154, 890, 2039, 733, 528, 2873, 2816, 435, 2395, 16, 2368, 2465, 2629, 3034, 1720, 918, 2448, 2932, 268, 3199, 2867, 68, 961, 378, 2998, 1021, 876, 447, 2779, 1362, 1736, 291, 325, 359, 320, 1729, 2504, 1263, 2208, 2703, 1810, 928, 2607, 1391, 1616, 2993, 1339, 2172, 2559, 286, 2335, 1597, 1746, 3315, 1786, 2816, 2561, 231, 2460, 21

In [966]:
def SamplePolyCBD(B, eta):
    b = BytesToBits(B)
    f = [0] * 256
    for i in range(256):
        x = 0
        y = 0
        for j in range(eta):
            x += b[2*i*eta + j]
            y += b[2*i*eta + eta + j]
        f[i] = x - y % q
    return f

In [967]:
eta = 2
B = random.randbytes(64*eta)
s = SamplePolyCBD(B, eta)
print(s)

[1, -1, -1, -1, -2, 0, 0, -2, 1, -2, 1, -1, 0, -1, 0, 0, -1, 0, 0, 0, 2, -1, 1, 1, -2, 1, 1, 1, -1, 0, 1, 0, 0, 0, 0, 1, -2, 0, -1, -1, -1, -1, -2, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, -2, 0, 0, 1, 1, 2, 0, -2, 0, -1, 1, 1, 0, 0, -1, 1, -1, 0, 0, -1, 0, 0, -1, 0, 2, 1, -1, 0, 1, -2, -2, 0, 0, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, -2, -2, -1, 0, 0, 0, -1, -2, 0, -1, 0, 0, 1, 1, 1, 2, 0, -1, -1, -1, 0, 0, -1, 0, 1, 2, 0, 0, 0, 2, -1, -1, -1, 0, 0, -1, 0, 0, 0, -2, 0, 1, 0, 1, -1, 0, 0, 1, 1, 1, -1, 0, 0, 1, 1, 1, 0, -1, 0, 1, 2, 0, 0, 0, 2, 0, 0, -1, 0, 2, -2, 0, -2, 1, 0, -1, 0, 1, -1, 1, 2, -1, 1, -1, -1, 0, -2, 2, -1, 0, 1, 0, -1, 0, -1, 0, -2, -2, 1, 0, -2, 0, 1, 0, -2, -1, 0, 1, -2, -1, -1, 0, 0, 0, 0, -1, -1, -1, -2, 1, 1, 0, -1, -1, 0, 1, 0, 0, -1, -2, 0, -1, 0, -1, -1, 1, -1, -1, 1, 0, -2, 0, 0, 0, 2, 0, -1, 0, 2, 1, 1, -2, 1, 0, -2, -1, 0, 0, 0]


### The Number-Theoric Transform

In [968]:
# BitRev7(i)
BitRev7 = [
    1, 1729, 2580, 3289, 2642, 630, 1897, 848,
    1062, 1919, 193, 797, 2786, 3260, 569, 1746,
    296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
    1426, 2094, 535, 2882, 2393, 2879, 1974, 821,
    289, 331, 3253, 1756, 1197, 2304, 2277, 2055,
    650, 1977, 2513, 632, 2865, 33, 1320, 1915,
    2319, 1435, 807, 452, 1438, 2868, 1534, 2402,
    2647, 2617, 1481, 648, 2474, 3110, 1227, 910,
    17, 2761, 583, 2649, 1637, 723, 2288, 1100,
    1409, 2662, 3281, 233, 756, 2156, 3015, 3050,
    1703, 1651, 2789, 1789, 1847, 952, 1461, 2687,
    939, 2308, 2437, 2388, 733, 2337, 268, 641,
    1584, 2298, 2037, 3220, 375, 2549, 2090, 1645,
    1063, 319, 2773, 757, 2099, 561, 2466, 2594,
    2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
    1722, 1212, 1874, 1029, 2110, 2935, 885, 2154
]

# 2*BitRev7(i) + 1
TwoBitRev7Plus1 = [
    17, -17, 2761, -2761, 583, -583, 2649, -2649,
    1637, -1637, 723, -723, 2288, -2288, 1100, -1100,
    1409, -1409, 2662, -2662, 3281, -3281, 233, -233,
    756, -756, 2156, -2156, 3015, -3015, 3050, -3050,
    1703, -1703, 1651, -1651, 2789, -2789, 1789, -1789,
    1847, -1847, 952, -952, 1461, -1461, 2687, -2687,
    939, -939, 2308, -2308, 2437, -2437, 2388, -2388,
    733, -733, 2337, -2337, 268, -268, 641, -641,
    1584, -1584, 2298, -2298, 2037, -2037, 3220, -3220,
    375, -375, 2549, -2549, 2090, -2090, 1645, -1645,
    1063, -1063, 319, -319, 2773, -2773, 757, -757,
    2099, -2099, 561, -561, 2466, -2466, 2594, -2594,
    2804, -2804, 1092, -1092, 403, -403, 1026, -1026,
    1143, -1143, 2150, -2150, 2775, -2775, 886, -886,
    1722, -1722, 1212, -1212, 1874, -1874, 1029, -1029,
    2110, -2110, 2935, -2935, 885, -885, 2154, -2154
]

In [969]:
def NTT(f):
    f_ntt = f
    i = 1
    length = 128
    while length >= 2:
        start = 0
        while start < 256:
            zeta = BitRev7[i] % q
            i += 1
            for j in range(start, start+length, 1):
                t = (zeta * f_ntt[j+length]) % q
                f_ntt[j+length] = (f_ntt[j] - t) % q
                f_ntt[j] = (f_ntt[j] + t) % q
            start += 2*length
        length //= 2 
    return f_ntt

def inv_NTT(f_ntt):
    f = f_ntt
    i = 127
    length = 2
    while length <= 128:
        start = 0
        while start < 256:
            zeta = BitRev7[i] % q
            i -= 1
            for j in range(start, start+length, 1):
                t = f[j] % q
                f[j] = (t + f[j+length]) % q
                f[j+length] = (zeta * (f[j+length]-t)) % q
            start += 2*length 
        length *= 2
    for i in range(256):
        f[i] = (f[i]*3303) % q
    return f

In [970]:
f = SamplePolyCBD(B, eta)
print("f =", f[:30])

f_ntt = NTT(f)
print("f_ntt =", f_ntt[:30])

ff = inv_NTT(f_ntt)
print("ff =", ff[:30])

f = [1, -1, -1, -1, -2, 0, 0, -2, 1, -2, 1, -1, 0, -1, 0, 0, -1, 0, 0, 0, 2, -1, 1, 1, -2, 1, 1, 1, -1, 0]
f_ntt = [438, 439, 908, 1968, 2531, 1697, 1522, 670, 635, 1916, 593, 1835, 1781, 3046, 3130, 1010, 504, 484, 1654, 2035, 321, 2920, 2251, 774, 2049, 693, 583, 116, 2036, 1494]
ff = [1, 3328, 3328, 3328, 3327, 0, 0, 3327, 1, 3327, 1, 3328, 0, 3328, 0, 0, 3328, 0, 0, 0, 2, 3328, 1, 1, 3327, 1, 1, 1, 3328, 0]


In [971]:
def BaseCaseMultiply(a0, a1, b0, b1, gamma):
    c = [0] * 2
    c[0] = (a0*b0 + a1*b1*gamma) % q
    c[1] = (a0*b1 + a1*b0) % q
    return c

def MultiplyNTTs(f, g):
    h = [0] * 256
    for i in range(128):
        h[2*i],h[2*i+1] = BaseCaseMultiply(f[2*i], f[2*i+1], g[2*i], g[2*i+1], TwoBitRev7Plus1[i])
    return h

In [972]:
B = random.randbytes(34)
a = SampleNTT(B)

B = random.randbytes(34)
b = SampleNTT(B)

c = MultiplyNTTs(a, b)
print("c =", c)

c = [1325, 1663, 135, 1430, 1437, 3244, 1509, 2649, 2278, 2482, 1834, 2122, 2044, 627, 393, 2756, 1193, 420, 2419, 510, 542, 2231, 501, 2822, 2329, 2085, 656, 1513, 1450, 1500, 3164, 1475, 2202, 3218, 2628, 1994, 1625, 335, 3193, 3310, 2074, 607, 828, 1899, 553, 1233, 2592, 1962, 486, 1721, 1818, 2588, 1783, 628, 2110, 2211, 2970, 542, 64, 1056, 1464, 763, 185, 1455, 3289, 1036, 2320, 525, 2740, 3326, 1246, 1003, 800, 3109, 342, 785, 107, 526, 604, 1096, 3081, 288, 1766, 2179, 917, 1905, 1057, 3094, 1705, 3094, 3031, 1780, 734, 3091, 2401, 1383, 2939, 1359, 1651, 2404, 988, 1130, 2697, 1136, 2996, 1531, 1012, 1547, 1828, 317, 577, 1201, 2478, 780, 1347, 398, 2284, 2002, 405, 2114, 1554, 1147, 1631, 2199, 3047, 37, 1427, 1362, 905, 597, 2165, 1946, 594, 2032, 665, 1363, 289, 2064, 3282, 3250, 396, 1030, 3095, 916, 3073, 1443, 1044, 1142, 1222, 1847, 2719, 2955, 2694, 1596, 1716, 871, 1918, 1467, 2226, 2932, 2722, 2207, 2379, 2432, 610, 2857, 1794, 2096, 2374, 692, 2782, 2102, 71, 1817, 

## The K-PKE Compopnent Scheme

Implementation of the internal PKE algorithms $\texttt{K-PKE.KeyGen}$, $\texttt{K-PKE.Encrypt}$ and $\texttt{K-PKE.Decrypt}$.

In [973]:
def AddPolynomials(a, b):
     return [(a[i] + b[i]) % q for i in range(256)]

def SubPolynomials(a, b):
     return [(a[i] - b[i]) % q for i in range(256)]

In [974]:
def PKE_KeyGen(d):
    rho, sigma = G(d + bytes([k]))

    N = 0
    A_ntt = [[0] * k for _ in range(k)]
    for i in range(k):
        for j in range(k):
            A_ntt[i][j] = SampleNTT(rho + bytes([j]) + bytes([i]))
    
    s_ntt = [0] * k
    for i in range(k):
        s = SamplePolyCBD(PRF(sigma, bytes([N]), eta1), eta1)
        s_ntt[i] = NTT(s)
        N += 1

    e_ntt = [0] * k
    for i in range(k):
        e = SamplePolyCBD(PRF(sigma, bytes([N]), eta1), eta1)
        e_ntt[i] = NTT(e)
        N += 1

    t_ntt = [0] * k
    for i in range(k):
        sum = [0] * 256
        for j in range(k):
            prod = MultiplyNTTs(A_ntt[i][j], s_ntt[j])
            sum = AddPolynomials(sum, prod)
        t_ntt[i] = AddPolynomials(sum, e_ntt[i])

    ek_pke = b""
    dk_pke = b""
    for i in range(k):
        ek_pke += ByteEncode(t_ntt[i], 12)
        dk_pke += ByteEncode(s_ntt[i], 12)
    ek_pke += rho
    return ek_pke, dk_pke

In [975]:
d = random.randbytes(32)
ek_pke, dk_pke = PKE_KeyGen(d)
len(dk_pke)

768

In [976]:
def PKE_Encrypt(ek_pke, m, r):
    N = 0
    t_ntt = [0] * k
    for i in range(k):
        t_ntt[i] = ByteDecode(ek_pke[384*i : 384*i + 384], 12)
    
    rho = ek_pke[384*k :]

    A_ntt = [[0] * k for _ in range(k)]
    for i in range(k):
        for j in range(k):
            A_ntt[i][j] = SampleNTT(rho + bytes([j]) + bytes([i]))

    y_ntt = [0] * k
    for i in range(k):
        y = SamplePolyCBD(PRF(r, bytes([N]), eta1), eta1)
        y_ntt[i] = NTT(y)
        N += 1

    e1 = [0] * k
    for i in range(k):
        e1[i] = SamplePolyCBD(PRF(r, bytes([N]), eta2), eta2)
        N += 1

    e2 = SamplePolyCBD(PRF(r, bytes([N]), eta2), eta2)
    
    u = [0] * k
    for i in range(k):
        sum_ntt = [0] * 256
        for j in range(k):
            prod = MultiplyNTTs(A_ntt[j][i], y_ntt[j]) # transpose A good?
            sum_ntt = AddPolynomials(sum_ntt, prod)
        u[i] = AddPolynomials(inv_NTT(sum_ntt), e1[i])

    mu = [0] * 256
    m_decode = ByteDecode(m, 1)
    for i in range(256):
        mu[i] = decompress(m_decode[i], 1)
    
    sum_ntt = [0] * 256
    for i in range(k):
        prod = MultiplyNTTs(t_ntt[i], y_ntt[i])
        sum_ntt = AddPolynomials(sum_ntt, prod)
    v = AddPolynomials(inv_NTT(sum_ntt), e2)
    v = AddPolynomials(v, mu)

    c1 = b""
    for i in range(k):
        u_compress = [0] * 256
        for j in range(256):
            u_compress[j] = compress(u[i][j], du)
        c1 += ByteEncode(u_compress, du)
    
    
    v_compress = [0] * 256
    for i in range(256):
        v_compress[i] = compress(v[i], dv)
    c2 = ByteEncode(v_compress, dv)

    return c1 + c2

In [977]:
m = random.randbytes(32)
r = random.randbytes(32)

c = PKE_Encrypt(ek_pke, m, r)

In [978]:
len(c)

768

In [979]:
def PKE_Decrypt(dk_pke, c):
    c1 = c[:32*du*k]
    c2 = c[32*du*k:]

    u_prime = [[0] * 256 for _ in range(k)]
    for i in range(k):
        c1_decode = ByteDecode(c1[32*du*i : 32*du*i + 32*du], du)
        for j in range(256):
            u_prime[i][j] = decompress(c1_decode[j], du)
    
    v_prime = [0] * 256
    c2_decode = ByteDecode(c2, dv)
    for i in range(256):
        v_prime[i] = decompress(c2_decode[i], dv)

    s_ntt = [0] * k
    for i in range(k):
        s_ntt[i] = ByteDecode(dk_pke[384*i : 384*i + 384], 12)

    u_prime_ntt = [0] * k 
    for i in range(k):
        u_prime_ntt[i] = NTT(u_prime[i])

    sum_ntt = [0] * 256
    for i in range(k):
        m = MultiplyNTTs(s_ntt[i], u_prime_ntt[i])
        sum_ntt = AddPolynomials(sum_ntt, m)
    omega = SubPolynomials(v_prime, inv_NTT(sum_ntt))

    omega_compress = [0] * 256
    for i in range(256):
        omega_compress[i] = compress(omega[i], 1)
    m = ByteEncode(omega_compress, 1)

    return m

In [980]:
m_decrypt = PKE_Decrypt(dk_pke, c)

In [981]:
m == m_decrypt

True

## Main Internal Algorithms

Implementation of the internal algorithms $\texttt{ML-KEM.KeyGen\_internal}$, $\texttt{ML-KEM.Encaps\_internal}$ and $\texttt{ML-KEM.Decaps\_internal}$.

In [982]:
def KEM_KeyGen_internal(d, z):
    ek_pke, dk_pke = PKE_KeyGen(d)
    ek = ek_pke
    dk = dk_pke + ek + H(ek) + z
    return ek, dk

In [983]:
d = random.randbytes(32)
z = random.randbytes(32)

ek, dk = KEM_KeyGen_internal(d, z)

In [984]:
len(dk) # 384k + (384k + 32) + 32 + 32

1632

In [985]:
def KEM_Encaps_internal(ek, m):
    K, r = G(m + H(ek))
    c = PKE_Encrypt(ek, m, r)
    return K, c

In [986]:
m = random.randbytes(32)
K, c = KEM_Encaps_internal(ek, m)

In [987]:
def KEM_Decaps_internal(dk, c):
    dk_pke = dk[0 : 384*k]
    ek_pke = dk[384*k : 768*k+32]
    h = dk[768*k+32 : 768*k+64]
    z = dk[768*k+64 : 768*k+96]

    m_prime = PKE_Decrypt(dk_pke, c)
    K_prime, r_prime = G(m_prime + h)

    K_bar = J(z + c)
    c_prime = PKE_Encrypt(ek_pke, m_prime, r_prime)
    
    if c_prime != c:
        K_prime = K_bar
    return K_prime

In [988]:
K_prime = KEM_Decaps_internal(dk, c)

In [989]:
K == K_prime

True

## ML-KEM Key-Encapsulation Mechanism

Implementation of the three main algorithms of the ML-KEM scheme $\texttt{ML-KEM.KeyGen}$, $\texttt{ML-KEM.Encaps}$ and $\texttt{ML-KEM.Decaps}$

In [990]:
def KEM_KeyGen():
    d = random.randbytes(32)
    z = random.randbytes(32)
    if d is None or z is None:
        return None
    ek, dk = KEM_KeyGen_internal(d, z) 
    return ek, dk

In [991]:
def KEM_Encaps(ek):
    m = random.randbytes(32)
    if m is None:
        return None
    K,c = KEM_Encaps_internal(ek, m)
    return K, c

In [992]:
def KEM_Decaps(dk, c):
    K_prime = KEM_Decaps_internal(dk, c)
    return K_prime

# Testing

In [999]:
n = 256
q = 3329

# ML-KEM-1024 parameters
k = 4
eta1 = 2
eta2 = 2
du = 11
dv = 5

ek, dk = KEM_KeyGen()
K, c = KEM_Encaps(ek)
K_prime = KEM_Decaps(dk, c)

K == K_prime

True