
In August 2023, NIST published a draft of the FIPS 203 standard for a Key Encapsulation Mechanism (KEM) derived from the KYBER algorithms.
The preamble of the draft reads:

*A key-encapsulation mechanism (or KEM) is a set of algorithms that, under certain conditions, can be used by two parties to establish a shared secret key over a public channel. A shared secret key that is securely established using a KEM can then be used with symmetric-key cryptographic algorithms to perform basic tasks in secure communications, such as encryption and authentication. This standard specifies a key-encapsulation mechanism called ML-KEM. The security of ML-KEM is related to the computational difficulty of the so-called Module Learning with Errors problem. At present, ML-KEM is believed to be secure even against adversaries who possess a quantum computer.*


This work implements a prototype of this standard in SageMath, parameterized according to the three variants defined in the standard (ML-KEM-512, ML-KEM-768 and ML-KEM-1024).

ML-KEM is a recently standardized lattice-based key encapsulation mechanism [FIPS203].


ML-KEM uses Module Learning with Errors as its underlying primitive, a structured-lattice variant that offers good performance and relatively small, balanced key and ciphertext sizes. ML-KEM was standardized with three parameter sets: ML-KEM-512, ML-KEM-768, and ML-KEM-1024. NIST mapped these to security levels 1, 3, and 5 of the NIST PQC Project, which correspond to the hardness of breaking AES-128, AES-192 and AES-256 respectively.

This specification introduces ML-KEM-768 and ML-KEM-1024 to IKEv2 key exchanges as conservative security level parameters which will not have material performance impact on IKEv2/IPsec tunnels which usually stay up for long periods of time. Since the ML-KEM-768 and ML-KEM-1024 public key and ciphertext sizes can exceed the typical network MTU, these key exchanges will usually require two or three network IP packets from both the initiator and the responder.

(https://www.ietf.org/archive/id/draft-kampanakis-ml-kem-ikev2-01.html)
(https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.203.ipd.pdf)
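
The size remark above can be checked with a little arithmetic. The sketch below assumes the byte encodings used by FIPS 203 (12 bits per coefficient of the public vector plus a 32-byte seed, and DU or DV bits per coefficient of the ciphertext). The helper names `encaps_key_size` and `ciphertext_size` are illustrative and do not appear elsewhere in this notebook.

```python
# (K, DU, DV) for each ML-KEM parameter set.
PARAMS = {
    "ML-KEM-512": (2, 10, 4),
    "ML-KEM-768": (3, 10, 4),
    "ML-KEM-1024": (4, 11, 5),
}


def encaps_key_size(k: int) -> int:
    # k polynomials of 256 coefficients at 12 bits each, plus the 32-byte seed rho
    return 384 * k + 32


def ciphertext_size(k: int, du: int, dv: int) -> int:
    # k polynomials compressed to du bits each, plus one compressed to dv bits
    return 32 * (du * k + dv)


for name, (k, du, dv) in PARAMS.items():
    print(f"{name}: ek = {encaps_key_size(k)} bytes, "
          f"ciphertext = {ciphertext_size(k, du, dv)} bytes")
```

Against a typical 1500-byte Ethernet MTU, the raw ML-KEM-768 sizes come close to the limit and the ML-KEM-1024 sizes exceed it, before any protocol overhead is counted.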

In [52]:
from typing import List, Tuple
import hashlib
import os
from functools import reduce
#import sage.all

FIPS 203 specifies ML-KEM, which is based on Kyber, the only KEM selected in the NIST Post-Quantum Competition. ML-KEM stands for Module Lattice-based Key Encapsulation Mechanism. It defines three parameter sets, each at a different level of security:

1. ML-KEM-512 (security equivalence to AES-128)
2. ML-KEM-768 (security equivalence to AES-192)
3. ML-KEM-1024 (security equivalence to AES-256)

ML-KEM is appropriate as a general replacement for quantum-vulnerable key exchange algorithms such as ECDH or FFDH. Note that ECDH and FFDH happen to be Non-Interactive Key Exchange (NIKE) algorithms, whereas ML-KEM is not. For applications where non-interactivity is a requirement, ML-KEM is therefore NOT an appropriate drop-in replacement. While the performance of ML-KEM is very good, the cryptographic artifact sizes are larger than those of ECDH and FFDH.

In [53]:
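# Active parameter set (ML-KEM-768). The functions below read these
# module-level names; rebind them to switch to another parameter set.
N = 256
Q = 3329
K = 3
ETA1 = 2
ETA2 = 2
DU = 10
DV = 4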

# # ML-KEM-512 params:
# N_512 = 256
# Q_512 = 3329
# K_512 = 2
# ETA1_512 = 3
# ETA2_512 = 2
# DU_512 = 10
# DV_512 = 4

# # ML-KEM-768 params:
# N_768 = 256
# Q_768 = 3329
# K_768 = 3
# ETA1_768 = 2
# ETA2_768 = 2
# DU_768 = 10
# DV_768 = 4

# # ML-KEM-1024 params:
# N_1024 = 256
# Q_1024 = 3329
# K_1024 = 4
# ETA1_1024 = 2
# ETA2_1024 = 2
# DU_1024 = 11
# DV_1024 = 5

DEFAULT_PARAMETERS = {
    "kyber_512" : {
        "N" : 256,
        "K" : 2,
        "Q" : 3329,
        "ETA1" : 3,
        "ETA2" : 2,
        "DU" : 10,
        "DV" : 4,
    },
    "kyber_768" : {
        "N" : 256,
        "K" : 3,
        "Q" : 3329,
        "ETA1" : 2,
        "ETA2" : 2,
        "DU" : 10,
        "DV" : 4,
    },
    "kyber_1024" : {
        "N" : 256,
        "K" : 4,
        "Q" : 3329,
        "ETA1" : 2,
        "ETA2" : 2,
        "DU" : 11,
        "DV" : 5,
    }
}


In [54]:
from shakestream import ShakeStream
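
The `shakestream` module is not shown in this notebook. Judging only from how it is used below (the constructor receives a SHAKE object's `digest` method, and `read(n)` returns the next `n` output bytes), a minimal stand-in could look like the following. The class name `ShakeStreamSketch` and its internals are assumptions, not the author's implementation.

```python
import hashlib
from typing import Callable


class ShakeStreamSketch:
    """Hypothetical stand-in that reads SHAKE output incrementally."""

    def __init__(self, digest_fn: Callable[[int], bytes]) -> None:
        self._digest_fn = digest_fn
        self._buf = digest_fn(32)  # initial chunk of output
        self._pos = 0

    def read(self, n: int) -> bytes:
        end = self._pos + n
        while end > len(self._buf):
            # A longer SHAKE digest starts with the shorter one, so asking
            # for more bytes extends the stream without changing its prefix.
            self._buf = self._digest_fn(2 * len(self._buf))
        out = self._buf[self._pos:end]
        self._pos = end
        return out


# Usage mirroring sample_ntt: pull three bytes at a time.
stream = ShakeStreamSketch(hashlib.shake_128(b"seed" + bytes([0, 1])).digest)
a, b, c = stream.read(3)
```

Recomputing the digest with a doubled length is simple rather than efficient; it is adequate for a prototype because `hashlib` does not expose an incremental squeeze interface.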

In [55]:
def bitrev7(n: int) -> int:
    return int(f"{n:07b}"[::-1], 2)

# ZETA and GAMMA depend only on Q, which is 3329 for every ML-KEM parameter
# set, so one 128-entry table of each serves ML-KEM-512, -768 and -1024.

ZETA = [pow(17, bitrev7(k), Q) for k in range(128)] # used in ntt and ntt_inv
GAMMA = [pow(17, 2*bitrev7(k)+1, Q) for k in range(128)] # used in ntt_mul

# can be reused for NTT representatives
def poly256_add(a: List[int], b: List[int]) -> List[int]:
	return [(x + y) % Q for x, y in zip(a, b)]

def poly256_sub(a: List[int], b: List[int]) -> List[int]:
	return [(x - y) % Q for x, y in zip(a, b)]

# naive O(n^2) multiplication algorithm for testing/comparison purposes.
# this is not used for the main impl.
def poly256_slow_mul(a: List[int], b: List[int]) -> List[int]:
	c = [0] * 511

	# textbook multiplication, without carry
	for i in range(256):
		for j in range(256):
			c[i+j] = (c[i+j] + a[j] * b[i]) % Q

	# now for reduction mod X^256 + 1
	for i in range(255):
		c[i] = (c[i] - c[i+256]) % Q
		# we could explicitly zero c[i+256] here, but there's no need
	
	# because we're about to truncate c
	return c[:256]


# this is O(n logn)
def ntt(f_in: List[int]) -> List[int]:
	f_out = f_in.copy()
	k = 1
	for log2len in range(7, 0, -1):
		length = 2**log2len
		for start in range(0, 256, 2 * length):
			zeta = ZETA[k]
			k += 1
			for j in range(start, start + length):
				t = (zeta * f_out[j + length]) % Q
				f_out[j + length] = (f_out[j] - t) % Q
				f_out[j] = (f_out[j] + t) % Q
	return f_out


# as well as this
def ntt_inv(f_in: List[int]) -> List[int]:
	f_out = f_in.copy()
	k = 127
	for log2len in range(1, 8):
		length = 2**log2len
		for start in range(0, 256, 2 * length):
			zeta = ZETA[k]
			k -= 1
			for j in range(start, start + length):
				t = f_out[j]
				f_out[j] = (t + f_out[j + length]) % Q
				f_out[j + length] = (zeta * (f_out[j + length] - t)) % Q

	for i in range(256):
		f_out[i] = (f_out[i] * 3303) % Q  # 3303 == pow(128, -1, Q)

	return f_out

ntt_add = poly256_add  # it's just elementwise addition

#  O(n)
def ntt_mul(a: List[int], b: List[int]) -> List[int]:
	c = []
	for i in range(128):
		a0, a1 = a[2 * i: 2 * i + 2]
		b0, b1 = b[2 * i: 2 * i + 2]
		c.append((a0 * b0 + a1 * b1 * GAMMA[i]) % Q)
		c.append((a0 * b1 + a1 * b0) % Q)
	return c


# crypto functions

def mlkem_prf(eta: int, data: bytes, b: int) -> bytes:
	return hashlib.shake_256(data + bytes([b])).digest(64 * eta)

def mlkem_xof(data: bytes, i: int, j: int) -> ShakeStream:
	return ShakeStream(hashlib.shake_128(data + bytes([i, j])).digest)

def mlkem_hash_H(data: bytes) -> bytes:
	return hashlib.sha3_256(data).digest()

def mlkem_hash_J(data: bytes) -> bytes:
	return hashlib.shake_256(data).digest(32)

def mlkem_hash_G(data: bytes) -> bytes:
	return hashlib.sha3_512(data).digest()
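
The constants hard-coded in the NTT code above can be checked independently. This is a small self-contained sketch that redefines `Q` and `bitrev7` locally so it runs on its own. It confirms that 17 is a primitive 256th root of unity modulo Q, that 3303 is the inverse of 128, and that every GAMMA entry is a root of X^128 + 1 modulo Q.

```python
Q = 3329


def bitrev7(n: int) -> int:
    # reverse the 7-bit binary representation of n
    return int(f"{n:07b}"[::-1], 2)


# 17^128 = -1 (mod Q), so 17 has multiplicative order exactly 256.
assert pow(17, 128, Q) == Q - 1
assert pow(17, 256, Q) == 1

# 3303 is the scaling factor applied at the end of ntt_inv.
assert (128 * 3303) % Q == 1
assert pow(128, -1, Q) == 3303  # modular inverse via pow needs Python 3.8+

# Each gamma = 17^(2*bitrev7(i)+1) is an odd power of 17, so gamma^128 = -1.
gammas = [pow(17, 2 * bitrev7(i) + 1, Q) for i in range(128)]
assert all(pow(g, 128, Q) == Q - 1 for g in gammas)

# Bit reversal is its own inverse on 0..127.
assert all(bitrev7(bitrev7(k)) == k for k in range(128))
```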


In [56]:
# encode/decode logic

def bits_to_bytes(bits: List[int]) -> bytes:
	assert(len(bits) % 8 == 0)
	return bytes(
		sum(bits[i + j] << j for j in range(8))
		for i in range(0, len(bits), 8)
	)

def bytes_to_bits(data: bytes) -> List[int]:
	bits = []
	for word in data:
		for i in range(8):
			bits.append((word >> i) & 1)
	return bits

def byte_encode(d: int, f: List[int]) -> bytes:
	assert(len(f) == 256)
	bits = []
	for a in f:
		for i in range(d):
			bits.append((a >> i) & 1)
	return bits_to_bytes(bits)

def byte_decode(d: int, data: bytes) -> List[int]:
	bits = bytes_to_bits(data)
	return [sum(bits[i * d + j] << j for j in range(d)) for i in range(256)]

def compress(d: int, x: List[int]) -> List[int]:
	return [(((n * 2**d) + Q // 2 ) // Q) % (2**d) for n in x]

def decompress(d: int, x: List[int]) -> List[int]:
	return [(((n * Q) + 2**(d-1) ) // 2**d) % Q for n in x]
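
Compression is lossy, and the scheme relies on the round-trip error staying small. The sketch below uses scalar versions of the two functions above (the names `compress_one`, `decompress_one` and `centered_error` are illustrative) and measures the worst-case error over all residues for the bit widths used in this notebook.

```python
Q = 3329


def compress_one(d: int, x: int) -> int:
    # round(2^d / Q * x) mod 2^d, in integer arithmetic
    return (((x << d) + Q // 2) // Q) % (1 << d)


def decompress_one(d: int, y: int) -> int:
    # round(Q / 2^d * y), in integer arithmetic
    return ((y * Q) + (1 << (d - 1))) >> d


def centered_error(a: int, b: int) -> int:
    # distance between a and b on the circle Z_Q
    diff = (a - b) % Q
    return min(diff, Q - diff)


for d in (1, 4, 5, 10, 11):
    worst = max(
        centered_error(decompress_one(d, compress_one(d, x)), x)
        for x in range(Q)
    )
    bound = (Q + (1 << d)) >> (d + 1)  # Q / 2^(d+1), rounded to nearest
    print(f"d={d}: worst error {worst}, bound {bound}")
```

With `d = 1` the pair acts as the message encoding: bit 1 decompresses to 1665 (about Q/2), and any value within roughly Q/4 of that compresses back to 1.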



In [57]:
# sampling

def sample_ntt(xof: ShakeStream):
	res = []
	while len(res) < 256:
		a, b, c = xof.read(3)
		d1 = ((b & 0xf) << 8) | a
		d2 = c << 4 | b >> 4
		if d1 < Q:
			res.append(d1)
		if d2 < Q and len(res) < 256:
			res.append(d2)
	return res


def sample_poly_cbd(eta: int, data: bytes) -> List[int]:
	assert(len(data) == 64 * eta)
	bits = bytes_to_bits(data)
	f = []
	for i in range(256):
		x = sum(bits[2*i*eta+j] for j in range(eta))
		y = sum(bits[2*i*eta+eta+j] for j in range(eta))
		f.append((x - y) % Q)
	return f


# K-PKE

def kpke_keygen(seed: bytes=None) -> Tuple[bytes, bytes]:
	d = os.urandom(32) if seed is None else seed
	ghash = mlkem_hash_G(d)
	rho, sigma = ghash[:32], ghash[32:]

	ahat = []
	for i in range(K):
		row = []
		for j in range(K):
			row.append(sample_ntt(mlkem_xof(rho, i, j)))
		ahat.append(row)
	
	shat = [
		ntt(sample_poly_cbd(ETA1, mlkem_prf(ETA1, sigma, i)))
		for i in range(K)
	]
	ehat = [
		ntt(sample_poly_cbd(ETA1, mlkem_prf(ETA1, sigma, i+K)))
		for i in range(K)
	]
	that = [ # t = a * s + e
		reduce(ntt_add, [
			ntt_mul(ahat[j][i], shat[j])
			for j in range(K)
		] + [ehat[i]])
		for i in range(K)
	]
	ek_pke = b"".join(byte_encode(12, s) for s in that) + rho
	dk_pke = b"".join(byte_encode(12, s) for s in shat)
	return ek_pke, dk_pke


def kpke_encrypt(ek_pke: bytes, m: bytes, r: bytes) -> bytes:
	# each 12-bit-encoded polynomial occupies 384 bytes, whatever the value of K
	that = [byte_decode(12, ek_pke[i*384:(i+1)*384]) for i in range(K)]
	rho = ek_pke[-32:]

	# ahat is regenerated exactly as in kpke_keygen
	ahat = []
	for i in range(K):
		row = []
		for j in range(K):
			row.append(sample_ntt(mlkem_xof(rho, i, j)))
		ahat.append(row)
	
	rhat = [
		ntt(sample_poly_cbd(ETA1, mlkem_prf(ETA1, r, i)))
		for i in range(K)
	]
	e1 = [
		sample_poly_cbd(ETA2, mlkem_prf(ETA2, r, i+K))
		for i in range(K)
	]
	e2 = sample_poly_cbd(ETA2, mlkem_prf(ETA2, r, 2*K))

	u = [ # u = ntt-1(AT*r)+e1
		poly256_add(ntt_inv(reduce(ntt_add, [
			ntt_mul(ahat[i][j], rhat[j]) # i,j are reversed here
			for j in range(K)
		])), e1[i])
		for i in range(K)
	]
	mu = decompress(1, byte_decode(1, m))
	v = poly256_add(ntt_inv(reduce(ntt_add, [
		ntt_mul(that[i], rhat[i])
		for i in range(K)
	])), poly256_add(e2, mu))

	c1 = b"".join(byte_encode(DU, compress(DU, u[i])) for i in range(K))
	c2 = byte_encode(DV, compress(DV, v))
	return c1 + c2


def kpke_decrypt(dk_pke: bytes, c: bytes) -> bytes:
	c1 = c[:32*DU*K]
	c2 = c[32*DU*K:]
	u = [
		decompress(DU, byte_decode(DU, c1[i*32*DU:(i+1)*32*DU]))
		for i in range(K)
	]
	v = decompress(DV, byte_decode(DV, c2))
	shat = [byte_decode(12, dk_pke[i*384:(i+1)*384]) for i in range(K)]
	w = poly256_sub(v, ntt_inv(reduce(ntt_add, [
		ntt_mul(shat[i], ntt(u[i]))
		for i in range(K)
	])))
	m = byte_encode(1, compress(1, w))
	return m


# KEM

def mlkem_keygen(seed1=None, seed2=None):
	z = os.urandom(32) if seed1 is None else seed1
	ek_pke, dk_pke = kpke_keygen(seed2)
	ek = ek_pke
	dk = dk_pke + ek + mlkem_hash_H(ek) + z
	return ek, dk


def mlkem_encaps(ek: bytes, seed=None) -> Tuple[bytes, bytes]:
	m = os.urandom(32) if seed is None else seed
	ghash = mlkem_hash_G(m + mlkem_hash_H(ek))
	k = ghash[:32]
	r = ghash[32:]
	c = kpke_encrypt(ek, m, r)
	return k, c


def mlkem_decaps(c: bytes, dk: bytes) -> bytes:
	dk_pke = dk[:384*K]
	ek_pke = dk[384*K : 768*K + 32]
	h = dk[768*K + 32 : 768*K + 64]
	z = dk[768*K + 64 : 768*K + 96]
	mdash = kpke_decrypt(dk_pke, c)
	ghash = mlkem_hash_G(mdash + h)
	kdash = ghash[:32]
	rdash = ghash[32:]
	kbar = mlkem_hash_J(z + c)
	cdash = kpke_encrypt(ek_pke, mdash, rdash)
	if cdash != c:
		return kbar
	return kdash
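
`mlkem_decaps` above compares `cdash` and `c` with `!=`, whose running time can depend on where the two byte strings first differ. That is fine for a prototype. As a hedged sketch of the same selection step using the standard library's `hmac.compare_digest`, the hypothetical helper below could be used instead; the name `select_shared_key` is not part of the notebook.

```python
import hmac


def select_shared_key(c: bytes, cdash: bytes, kdash: bytes, kbar: bytes) -> bytes:
    # compare_digest is designed so that its running time does not reveal
    # the position of the first differing byte
    matches = hmac.compare_digest(c, cdash)
    # implicit rejection: return the pseudorandom kbar when re-encryption fails
    return kdash if matches else kbar


# Usage with placeholder values
good = select_shared_key(b"ct", b"ct", b"K" * 32, b"R" * 32)
bad = select_shared_key(b"ct", b"cx", b"K" * 32, b"R" * 32)
```

The final conditional is still an ordinary Python branch, so this only narrows the timing leak and does not make the function constant-time.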


In [58]:
# if __name__ == "__main__":
# 	a = list(range(256))
# 	b = list(range(1024, 1024+256))

# 	ntt_res = ntt_inv(ntt_add(ntt(a), ntt(b)))
# 	poly_res = poly256_add(a, b)

# 	assert(ntt_res == poly_res)

# 	ntt_prod = ntt_inv(ntt_mul(ntt(a), ntt(b)))
# 	poly_prod = poly256_slow_mul(a, b)

# 	assert(ntt_prod == poly_prod)


# 	ek_pke, dk_pke = kpke_keygen(b"SEED"*8)

# 	msg = b"This is a demonstration message."
# 	ct = kpke_encrypt(ek_pke, msg, b"RAND"*8)
# 	pt = kpke_decrypt(dk_pke, ct)
# 	print(pt)
# 	assert(pt == msg)


# 	ek, dk = mlkem_keygen()
# 	k1, c = mlkem_encaps(ek)
# 	print("encapsulated:", k1.hex())

# 	k2 = mlkem_decaps(c, dk)
# 	print("decapsulated:", k2.hex())

# 	assert(k1 == k2)
	
if __name__ == "__main__":
    # The functions above read module-level parameters (K, ETA1, DU, ...),
    # so each test first rebinds them from DEFAULT_PARAMETERS.

    # Test for ML-KEM-512
    globals().update(DEFAULT_PARAMETERS["kyber_512"])
    print("Test for ML-KEM-512:")
    ek_512, dk_512 = mlkem_keygen()
    k1_512, c_512 = mlkem_encaps(ek_512)
    k2_512 = mlkem_decaps(c_512, dk_512)
    print("Original key:", k1_512.hex())
    print("Decapsulated key:", k2_512.hex())
    assert k1_512 == k2_512

    # Test for ML-KEM-768
    globals().update(DEFAULT_PARAMETERS["kyber_768"])
    print("\nTest for ML-KEM-768:")
    ek_768, dk_768 = mlkem_keygen()
    k1_768, c_768 = mlkem_encaps(ek_768)
    k2_768 = mlkem_decaps(c_768, dk_768)
    print("Original key:", k1_768.hex())
    print("Decapsulated key:", k2_768.hex())
    assert k1_768 == k2_768

    # Test for ML-KEM-1024
    globals().update(DEFAULT_PARAMETERS["kyber_1024"])
    print("\nTest for ML-KEM-1024:")
    ek_1024, dk_1024 = mlkem_keygen()
    k1_1024, c_1024 = mlkem_encaps(ek_1024)
    k2_1024 = mlkem_decaps(c_1024, dk_1024)
    print("Original key:", k1_1024.hex())
    print("Decapsulated key:", k2_1024.hex())
    assert k1_1024 == k2_1024


# # Test the Kyber variants
# if __name__ == "__main__":
#     # Test with Kyber512
#     print("Kyber512 parameters:")
#     print("n:", Kyber512.N)
#     print("q:", Kyber512.Q)
#     print("k:", Kyber512.K)
#     print("eta1:", Kyber512.ETA1)
#     print("eta2:", Kyber512.ETA2)
#     print("du:", Kyber512.DU)
#     print("dv:", Kyber512.DV)
#     print()

#     # Test with Kyber768
#     print("Kyber768 parameters:")
#     print("n:", Kyber768.N)
#     print("q:", Kyber768.Q)
#     print("k:", Kyber768.K)
#     print("eta1:", Kyber768.ETA1)
#     print("eta2:", Kyber768.ETA2)
#     print("du:", Kyber768.DU)
#     print("dv:", Kyber768.DV)
#     print()

#     # Test with Kyber1024
#     print("Kyber1024 parameters:")
#     print("n:", Kyber1024.N)
#     print("q:", Kyber1024.Q)
#     print("k:", Kyber1024.K)
#     print("eta1:", Kyber1024.ETA1)
#     print("eta2:", Kyber1024.ETA2)
#     print("du:", Kyber1024.DU)
#     print("dv:", Kyber1024.DV)
#     print()
    

Test for ML-KEM-512:
Original key: b9c2b61c4de611ad58d9413047fbaa01fc7433a46abe2ea5c4d7f61d5e5b4ea2
Decapsulated key: b9c2b61c4de611ad58d9413047fbaa01fc7433a46abe2ea5c4d7f61d5e5b4ea2

Test for ML-KEM-768:
Original key: 94a82d0f13113ea313bc10053e52d86a70bda6e6a5599ccd656b602093cfacac
Decapsulated key: 94a82d0f13113ea313bc10053e52d86a70bda6e6a5599ccd656b602093cfacac

Test for ML-KEM-1024:
Original key: 9f80f3285663c3b0ae1c20641070721a331ad74a8b2ddeb5fb921186c92e0baa
Decapsulated key: 9f80f3285663c3b0ae1c20641070721a331ad74a8b2ddeb5fb921186c92e0baa
