# Kyber / ML-KEM Overview

**Module 08** | 08-lattices-post-quantum

*Module-LWE, NIST FIPS 203, KEM flow, toy implementation*

## The Post-Quantum Moment

> **Motivating Question:** Quantum computers will break RSA and elliptic curves. What replaces them?
>
> The answer is not hypothetical. It is already standardized: **NIST selected Kyber in July 2022 and published it as ML-KEM in FIPS 203 in August 2024**. Chrome (over TLS 1.3), Signal, and iMessage already deploy it, or its Kyber predecessor, in hybrid key exchange.

Shor's algorithm factors integers and computes discrete logarithms in polynomial time on a quantum computer. This breaks:

| Scheme | Hardness assumption | Broken by Shor? |
|--------|--------------------|-----------------|
| RSA | Integer factorization | Yes |
| Diffie-Hellman | Discrete log in $\mathbb{Z}_p^*$ | Yes |
| ECDH / ECDSA | Discrete log on elliptic curves | Yes |
| **ML-KEM (Kyber)** | **Module-LWE on lattices** | **No** |

This notebook is the **climax of Module 08**. Everything you built in 08a through 08e—lattices, SVP, LLL, LWE, Ring-LWE—comes together here in the real-world scheme that is replacing ECDH as we speak.

## Objectives

By the end of this notebook you will be able to:

1. Explain why Module-LWE offers a better efficiency/security tradeoff than plain LWE or Ring-LWE.
2. Describe the three phases of a KEM: **KeyGen**, **Encapsulate**, **Decapsulate**.
3. State the NIST parameter sets for ML-KEM-512, ML-KEM-768, and ML-KEM-1024.
4. Implement a **toy Kyber-like KEM** from scratch in SageMath.
5. Compare ML-KEM key/ciphertext sizes to classical schemes.

## Prerequisites

- Completion of [Ring-LWE](08e-ring-lwe.ipynb) (the polynomial ring $R_q = \mathbb{Z}_q[x]/(x^n+1)$).
- Familiarity with LWE (08d): given $(\mathbf{A}, \mathbf{b} = \mathbf{A}\mathbf{s} + \mathbf{e})$, finding $\mathbf{s}$ is hard.
- Understanding of lattice hardness (08a-08c): SVP, LLL reduction.

**Bridge from 08e:** In Ring-LWE we moved from random matrices $\mathbf{A} \in \mathbb{Z}_q^{m \times n}$ to a single polynomial $a \in R_q$. This gave us compact keys but concentrated trust in a single algebraic structure. Module-LWE is the sweet spot: it works with **small matrices of polynomials**, combining Ring-LWE's efficiency with LWE's flexibility.

## 1. From Ring-LWE to Module-LWE

Recall the progression of LWE variants:

| Variant | Public matrix $\mathbf{A}$ | Key size | Security argument |
|---------|---------------------------|----------|-------------------|
| **LWE** | Random $\mathbf{A} \in \mathbb{Z}_q^{m \times n}$ | $O(n^2)$ | Strong (unstructured) |
| **Ring-LWE** | Single $a \in R_q$ | $O(n)$ | Relies on ring structure |
| **Module-LWE** | $\mathbf{A} \in R_q^{k \times k}$ (matrix of ring elements) | $O(k \cdot n)$ | Tunable via $k$ |

Module-LWE works in the **free module** $R_q^k$, where $R_q = \mathbb{Z}_q[x]/(x^n+1)$. The matrix $\mathbf{A}$ is now a $k \times k$ matrix whose entries are polynomials in $R_q$. Increasing $k$ adds security without changing the underlying ring.

**Why this matters:** If a structural weakness is found in $R_q$ for a specific $n$, we can increase $k$ to compensate without redesigning the whole scheme. This flexibility, together with the ability to scale security by changing only $k$, is a main reason Kyber is built on Module-LWE rather than Ring-LWE.
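
To make "a matrix of polynomials" concrete outside SageMath, here is a minimal plain-Python sketch of arithmetic in $R_q = \mathbb{Z}_q[x]/(x^n+1)$ and of a matrix-vector product over $R_q^k$. The helper names `polymul_negacyclic` and `matvec` are hypothetical (they do not come from this notebook), and the schoolbook multiplication is chosen for readability, not speed.

```python
def polymul_negacyclic(a, b, q):
    """Multiply two coefficient lists in Z_q[x]/(x^n + 1), lowest degree first."""
    n = len(a)
    assert len(b) == n
    result = [0] * n
    for i, ai in enumerate(a):
        for j, bj in enumerate(b):
            d = i + j
            if d < n:
                result[d] = (result[d] + ai * bj) % q
            else:
                # x^n = -1 in this ring, so x^(n + d) wraps around to -x^d
                result[d - n] = (result[d - n] - ai * bj) % q
    return result


def matvec(A, s, q):
    """Multiply a k x k matrix of polynomials by a length-k vector of polynomials."""
    n = len(s[0])
    out = []
    for row in A:
        acc = [0] * n
        for a_ij, s_j in zip(row, s):
            prod = polymul_negacyclic(a_ij, s_j, q)
            acc = [(x + y) % q for x, y in zip(acc, prod)]
        out.append(acc)
    return out


q = 17
x1 = [0, 1, 0, 0]   # the polynomial x
x3 = [0, 0, 0, 1]   # the polynomial x^3
print(polymul_negacyclic(x1, x3, q))   # x^4 = -1, i.e. [16, 0, 0, 0] mod 17
```

Increasing $k$ only changes the sizes of `A` and `s`; the ring multiplication itself is untouched, which is the sense in which the rank is a tunable knob.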

In [None]:
# === Setup: The polynomial ring R_q used in Kyber ===
#
# Kyber fixes n = 256, q = 3329.
# R_q = Z_q[x] / (x^256 + 1)
#
# For our toy examples we use MUCH smaller parameters.
# Toy: n = 4, q = 17 (so we can see what's happening)

# --- Real Kyber parameters (for reference) ---
n_real = 256
q_real = 3329
print(f"Real Kyber: n = {n_real}, q = {q_real}")
print(f"  q is prime: {is_prime(q_real)}")
print(f"  q mod n = {q_real % n_real}  (= 1: primitive 256th roots of unity exist, enough for Kyber's NTT)")

# --- Toy parameters for this notebook ---
n_toy = 4
q_toy = 17  # prime, and 17 mod 8 = 1 (NTT-friendly for n=4)

# Build the ring R_q = Z_q[x] / (x^n + 1)
Zq = Integers(q_toy)
Rx.<x> = PolynomialRing(Zq)
f_mod = x^n_toy + 1
Rq = Rx.quotient(f_mod, 'xbar')
xbar = Rq.gen()

print(f"\nToy parameters: n = {n_toy}, q = {q_toy}")
print(f"R_q = Z_{q_toy}[x] / ({f_mod})")
print(f"Example element: {Rq([3, 1, 4, 1])} = 3 + x + 4x^2 + x^3")

In [None]:
# === Module-LWE: matrices of polynomials ===
#
# In Module-LWE with parameter k, the public matrix A is k x k
# where each entry is a polynomial in R_q.
#
# Let's see what this looks like for k=2 (Kyber-512 security level).

import random

def random_poly(ring, n, q):
    """Sample a uniformly random polynomial in R_q."""
    coeffs = [ZZ.random_element(0, q) for _ in range(n)]
    return ring(coeffs)

def small_poly(ring, n, bound=1):
    """Sample a 'small' polynomial with coefficients in [-bound, bound]."""
    coeffs = [ZZ.random_element(-bound, bound + 1) for _ in range(n)]
    return ring(coeffs)

k = 2  # Module rank (Kyber-512 uses k=2)

# Generate a k x k matrix of random polynomials
A = matrix(Rq, k, k, lambda i, j: random_poly(Rq, n_toy, q_toy))

print("Public matrix A (2x2 matrix of polynomials in R_q):")
print("="*55)
for i in range(k):
    for j in range(k):
        print(f"  A[{i},{j}] = {A[i,j]}")
print()
print(f"Each entry is a polynomial of degree < {n_toy} with coefficients in Z_{q_toy}.")
print(f"Total 'randomness' in A: {k}*{k}*{n_toy} = {k*k*n_toy} field elements.")

> **Checkpoint:** Before continuing, make sure you understand the difference:
> - In **LWE**, $\mathbf{A}$ is a matrix of integers mod $q$.
> - In **Ring-LWE**, $a$ is a single polynomial in $R_q$.
> - In **Module-LWE**, $\mathbf{A}$ is a matrix of polynomials in $R_q$.
>
> *Quick check:* For Kyber-768 ($k=3$, $n=256$), how many total coefficients are in the matrix $\mathbf{A}$? Answer: $3 \times 3 \times 256 = 2304$. Compare this to plain LWE with the same security, which would need a matrix with $\sim 768^2 = 589{,}824$ entries.
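
The quick check above can be reproduced with a few lines of Python. This is only a counting sketch: the function names are hypothetical, and "plain LWE of dimension $k \cdot n$" is used as a rough stand-in for "the same security".

```python
def mlwe_matrix_coeffs(k, n):
    """Number of Z_q coefficients in a k x k matrix of degree-<n polynomials."""
    return k * k * n


def lwe_matrix_entries(dim):
    """Number of Z_q entries in a square plain-LWE matrix of the given dimension."""
    return dim * dim


k, n = 3, 256
module_count = mlwe_matrix_coeffs(k, n)
plain_count = lwe_matrix_entries(k * n)
print(module_count, plain_count, plain_count // module_count)   # 2304 589824 256
```

The ratio is exactly $n$: each polynomial entry stands in for an $n \times n$ block of a plain LWE matrix.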

## 2. What Is a KEM?

> **Misconception Alert:** "Kyber is a public-key encryption scheme." Not quite! Kyber is a **Key Encapsulation Mechanism (KEM)**. It does not encrypt arbitrary messages. Instead, it establishes a **shared symmetric key** between two parties.

A KEM has three algorithms:

1. **KeyGen()** $\to$ (public key $pk$, secret key $sk$)
2. **Encapsulate($pk$)** $\to$ (ciphertext $ct$, shared secret $K$)
3. **Decapsulate($sk$, $ct$)** $\to$ shared secret $K$

The crucial property: the shared secret $K$ produced by Encapsulate and Decapsulate is the **same** value. Alice and Bob can then use $K$ as the key for AES-256 or ChaCha20.

```
    Alice                              Bob
    -----                              ---
    (pk, sk) = KeyGen()
    send pk ─────────────────────────►
                                       (ct, K) = Encapsulate(pk)
                          ◄──────────── send ct
    K = Decapsulate(sk, ct)
    
    Both now share K for symmetric encryption.
```

**Why KEM instead of PKE?** KEMs are simpler to build securely. A KEM only needs to produce a random-looking key, not encrypt an arbitrary message. This makes the IND-CCA2 security proof cleaner via the Fujisaki-Okamoto transform.
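
The three-algorithm interface is independent of lattices. As an illustration of the API shape only, the sketch below builds a KEM from classical Diffie-Hellman. The group parameters `P` and `G`, the helper `_kdf`, and all function names are assumptions made for this example; the group is far too small for real use, and this construction is not quantum-resistant.

```python
import hashlib
import secrets

# Toy group parameters (assumption for illustration, NOT secure).
P = 2**127 - 1   # a Mersenne prime
G = 3


def _kdf(shared_int):
    """Hash the raw shared group element into a 32-byte symmetric key."""
    return hashlib.sha256(shared_int.to_bytes(16, "big")).digest()


def keygen():
    sk = secrets.randbelow(P - 2) + 1
    pk = pow(G, sk, P)
    return pk, sk


def encapsulate(pk):
    eph = secrets.randbelow(P - 2) + 1   # fresh ephemeral secret
    ct = pow(G, eph, P)                  # ciphertext sent to the key owner
    K = _kdf(pow(pk, eph, P))            # shared secret kept by the sender
    return ct, K


def decapsulate(sk, ct):
    return _kdf(pow(ct, sk, P))


pk, sk = keygen()
ct, K_bob = encapsulate(pk)
K_alice = decapsulate(sk, ct)
print(K_alice == K_bob)   # True
```

Kyber fills the same three slots with Module-LWE operations: no message is passed in by the caller, and the only outputs are a ciphertext and a key.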

## 3. Kyber KeyGen

KeyGen creates a Module-LWE instance:

1. Sample $\mathbf{A} \xleftarrow{\$} R_q^{k \times k}$ (public, random matrix of polynomials)
2. Sample $\mathbf{s} \xleftarrow{\small\text{CBD}} R_q^k$ (secret vector, **small** coefficients)
3. Sample $\mathbf{e} \xleftarrow{\small\text{CBD}} R_q^k$ (error vector, **small** coefficients)
4. Compute $\mathbf{t} = \mathbf{A} \mathbf{s} + \mathbf{e}$

**Public key:** $(\mathbf{A}, \mathbf{t})$ 
**Secret key:** $\mathbf{s}$

This is exactly the LWE problem from 08d, but over the module $R_q^k$ instead of $\mathbb{Z}_q^n$! The hardness assumption is: given $(\mathbf{A}, \mathbf{t})$, it is computationally infeasible to recover $\mathbf{s}$.

*Note: CBD stands for Centered Binomial Distribution. In Kyber, small coefficients are sampled from $\text{CBD}_\eta$, which outputs integers in $[-\eta, \eta]$. For our toy version, we simply sample uniformly from $\{-1, 0, 1\}$.*
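
For comparison with the uniform toy sampler, here is a minimal Python sketch of $\text{CBD}_\eta$: add $\eta$ random bits and subtract $\eta$ other random bits. The function name `cbd` is hypothetical, and Python's `random` module stands in for the hash-based bit source that a real implementation would use.

```python
import random
from collections import Counter


def cbd(eta, rng):
    """Sample from the centered binomial distribution with parameter eta."""
    a = sum(rng.getrandbits(1) for _ in range(eta))
    b = sum(rng.getrandbits(1) for _ in range(eta))
    return a - b   # always in [-eta, eta]


rng = random.Random(0)   # seeded only so the experiment is repeatable
samples = [cbd(2, rng) for _ in range(10000)]
counts = Counter(samples)

# For eta = 2 the exact probabilities are 1/16, 4/16, 6/16, 4/16, 1/16
# for the values -2, -1, 0, 1, 2.
for value in sorted(counts):
    print(value, counts[value] / len(samples))
```

Unlike uniform sampling from $\{-\eta, \ldots, \eta\}$, the mass is concentrated at 0, which keeps the accumulated decryption noise small on average.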

In [None]:
# === Toy Kyber KeyGen ===

def toy_kyber_keygen(Rq, n, q, k):
    """
    Generate a Kyber-like key pair.
    
    Returns: (pk, sk) where pk = (A, t) and sk = s
    """
    # Step 1: Random public matrix A (k x k of polynomials)
    A = matrix(Rq, k, k, lambda i, j: random_poly(Rq, n, q))
    
    # Step 2: Secret vector s (k polynomials with SMALL coefficients)
    s = vector(Rq, [small_poly(Rq, n) for _ in range(k)])
    
    # Step 3: Error vector e (k polynomials with SMALL coefficients)
    e = vector(Rq, [small_poly(Rq, n) for _ in range(k)])
    
    # Step 4: t = A*s + e
    t = A * s + e
    
    pk = (A, t)
    sk = s
    return pk, sk

# Generate keys
set_random_seed(42)  # For reproducibility
pk, sk = toy_kyber_keygen(Rq, n_toy, q_toy, k)
A, t = pk
s = sk

print("=== KeyGen ===")
print(f"Module rank k = {k}")
print()
print("Secret key s (small coefficients!):")
for i in range(k):
    print(f"  s[{i}] = {s[i]}")
print()
print("Error vector e was added to mask s.")
print()
print("Public key t = A*s + e:")
for i in range(k):
    print(f"  t[{i}] = {t[i]}")
print()
print("Notice: t looks random (big coefficients), even though s and e are small.")
print("This is the Module-LWE guarantee: (A, A*s+e) is indistinguishable from (A, random).")

> **Checkpoint:** After KeyGen, what does a ciphertext need to contain?
>
> *Think about it:* Bob wants to send Alice a shared secret. He has Alice's public key $(\mathbf{A}, \mathbf{t})$. He needs to create something that:
> 1. Encodes a message (or random seed) that Alice can recover.
> 2. Looks random to anyone without $\mathbf{s}$.
>
> The answer: Bob creates his **own** Module-LWE instance using the same $\mathbf{A}$, and uses $\mathbf{t}$ to "mix in" a message. The ciphertext will contain two parts: $(\mathbf{u}, v)$.

## 4. Kyber Encapsulation

To encapsulate (create a ciphertext and shared secret):

1. Choose a random message $m \in \{0,1\}^{256}$ (in real Kyber, 32 random bytes)
2. Derive randomness from $m$ (we skip this in our toy version)
3. Sample $\mathbf{r} \xleftarrow{\small\text{CBD}} R_q^k$ (small random vector)
4. Sample $\mathbf{e}_1 \xleftarrow{\small\text{CBD}} R_q^k$ and $e_2 \xleftarrow{\small\text{CBD}} R_q$ (small errors)
5. Compute:
   - $\mathbf{u} = \mathbf{A}^T \mathbf{r} + \mathbf{e}_1$ (vector of $k$ polynomials)
   - $v = \mathbf{t}^T \mathbf{r} + e_2 + \lceil q/2 \rfloor \cdot m$ (single polynomial)

**Ciphertext:** $(\mathbf{u}, v)$

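**Shared secret:** $K = \text{Hash}(m)$ in this simplified description. ML-KEM itself derives $K$ (and the encryption randomness) by hashing $m$ together with a hash of the public key.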

The term $\lceil q/2 \rfloor \cdot m$ encodes the message bit into the "upper half" of $\mathbb{Z}_q$. For our toy $q = 17$, this means $\lceil 17/2 \rfloor = 9$. A message bit of 1 adds 9 to the coefficient; a bit of 0 adds nothing.
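
The encoding and its inverse can be sketched for a single coefficient in plain Python. The names `encode_bit` and `decode_coeff` are hypothetical; the rounding rule mirrors the one described above, with $\lceil q/2 \rfloor$ computed as `(q + 1) // 2` for odd $q$.

```python
def encode_bit(bit, q):
    """Map a bit to 0 or round(q/2)."""
    return bit * ((q + 1) // 2)


def decode_coeff(c, q):
    """Return 0 if c is closer (mod q) to 0, else 1 for closer to round(q/2)."""
    half_q = (q + 1) // 2
    c %= q
    dist_to_0 = min(c, q - c)
    dist_to_half = min(abs(c - half_q), q - abs(c - half_q))
    return 0 if dist_to_0 <= dist_to_half else 1


q = 17
print(encode_bit(1, q))                       # 9
print(decode_coeff(encode_bit(1, q) + 2, q))  # 1: noise +2 is tolerated
print(decode_coeff(encode_bit(0, q) - 3, q))  # 0: noise -3 is tolerated
print(decode_coeff(encode_bit(1, q) + 5, q))  # 0: noise +5 flips the bit
print(encode_bit(1, 3329))                    # 1665
```

The last toy case shows why the noise must stay below roughly $q/4$: once it crosses the midpoint between the two code points, the decoder picks the wrong one.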

In [None]:
# === Toy Kyber Encapsulation ===

def encode_message(Rq, n, q, msg_bits):
    """
    Encode binary message into a polynomial.
    Each bit b is mapped to b * round(q/2).
    """
    half_q = (q + 1) // 2  # round(q/2)
    coeffs = [int(b) * half_q for b in msg_bits]
    # Pad with zeros if message is shorter than n
    coeffs += [0] * (n - len(coeffs))
    return Rq(coeffs[:n])

def toy_kyber_encapsulate(Rq, n, q, k, pk, msg_bits):
    """
    Encapsulate: produce ciphertext (u, v) encoding msg_bits.
    In real Kyber, msg_bits is random and the shared secret is Hash(msg).
    """
    A, t = pk
    
    # Sample small randomness
    r  = vector(Rq, [small_poly(Rq, n) for _ in range(k)])
    e1 = vector(Rq, [small_poly(Rq, n) for _ in range(k)])
    e2 = small_poly(Rq, n)
    
    # Encode the message
    m_encoded = encode_message(Rq, n, q, msg_bits)
    
    # Compute ciphertext
    u = A.transpose() * r + e1    # k polynomials
    v = t * r + e2 + m_encoded    # 1 polynomial (dot product t^T * r)
    
    return (u, v)

# Encapsulate with a toy message
msg = [1, 0, 1, 1]  # 4 bits (one per coefficient, since n=4)
ct = toy_kyber_encapsulate(Rq, n_toy, q_toy, k, pk, msg)
u, v = ct

print("=== Encapsulation ===")
print(f"Message bits: {msg}")
print(f"Encoded as: each '1' bit -> coefficient {(q_toy+1)//2}, each '0' bit -> 0")
print()
print("Ciphertext u (vector of k polynomials):")
for i in range(k):
    print(f"  u[{i}] = {u[i]}")
print(f"\nCiphertext v (single polynomial):")
print(f"  v = {v}")
print()
print("The ciphertext looks random. But Alice, who knows s, can decrypt it.")

## 5. Kyber Decapsulation

Alice receives $(\mathbf{u}, v)$ and uses her secret key $\mathbf{s}$ to recover the message:

1. Compute $v - \mathbf{s}^T \mathbf{u}$
2. For each coefficient, round to the nearest value in $\{0, \lceil q/2 \rfloor\}$ to recover the message bit.

**Why does this work?** Let's expand:

$$v - \mathbf{s}^T \mathbf{u} = (\mathbf{t}^T \mathbf{r} + e_2 + \lceil q/2 \rfloor \cdot m) - \mathbf{s}^T(\mathbf{A}^T \mathbf{r} + \mathbf{e}_1)$$

Since $\mathbf{t} = \mathbf{A}\mathbf{s} + \mathbf{e}$, we have $\mathbf{t}^T \mathbf{r} = \mathbf{s}^T \mathbf{A}^T \mathbf{r} + \mathbf{e}^T \mathbf{r}$. Substituting:

$$= \mathbf{s}^T \mathbf{A}^T \mathbf{r} + \mathbf{e}^T \mathbf{r} + e_2 + \lceil q/2 \rfloor \cdot m - \mathbf{s}^T \mathbf{A}^T \mathbf{r} - \mathbf{s}^T \mathbf{e}_1$$

$$= \lceil q/2 \rfloor \cdot m + \underbrace{\mathbf{e}^T \mathbf{r} + e_2 - \mathbf{s}^T \mathbf{e}_1}_{\text{small noise}}$$

The $\mathbf{s}^T \mathbf{A}^T \mathbf{r}$ terms cancel! What remains is the encoded message plus a small noise term. Since all of $\mathbf{e}, \mathbf{r}, e_2, \mathbf{s}, \mathbf{e}_1$ have small coefficients, their products and sums are still small relative to $\lceil q/2 \rfloor$. So we can round to recover $m$ exactly.
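
The cancellation can be checked by hand in the smallest possible setting. The sketch below uses the degenerate case $n = 1$, $k = 1$, where every "polynomial" is just a number mod $q$ and $\mathbf{A}^T = \mathbf{A}$. All concrete values are hypothetical and chosen only to keep the arithmetic easy to follow.

```python
q = 17
half_q = (q + 1) // 2              # 9

# KeyGen (hypothetical small values)
A, s, e = 5, 1, -1
t = (A * s + e) % q                # 4

# Encapsulate the bit m = 1
m, r, e1, e2 = 1, 1, 1, 0
u = (A * r + e1) % q               # 6
v = (t * r + e2 + half_q * m) % q  # 13

# Decapsulate
noisy = (v - s * u) % q            # 7
noise = e * r + e2 - s * e1        # -2

# The derivation predicts noisy = round(q/2) * m + noise (mod q).
assert noisy == (half_q * m + noise) % q

dist_to_0 = min(noisy, q - noisy)                                    # 7
dist_to_half = min(abs(noisy - half_q), q - abs(noisy - half_q))     # 2
recovered = 0 if dist_to_0 <= dist_to_half else 1
print(noisy, noise, recovered)     # 7 -2 1
```

Here the $\mathbf{s}^T \mathbf{A}^T \mathbf{r}$ term is $1 \cdot 5 \cdot 1 = 5$ on both sides, so it drops out and only $9 - 2 = 7$ remains.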

In [None]:
# === Toy Kyber Decapsulation ===

def decode_message(noisy_poly, n, q):
    """
    Decode polynomial back to message bits.
    Each coefficient is rounded to nearest {0, round(q/2)}.
    """
    half_q = (q + 1) // 2
    # Lift coefficients to integers in [0, q-1]
    coeffs = [ZZ(c) for c in list(noisy_poly.lift())]
    # Pad if needed
    coeffs += [0] * (n - len(coeffs))
    
    bits = []
    for c in coeffs[:n]:
        # Distance to 0 vs distance to half_q
        dist_to_0 = min(c, q - c)       # handles wraparound
        dist_to_half = min(abs(c - half_q), q - abs(c - half_q))
        bits.append(0 if dist_to_0 <= dist_to_half else 1)
    return bits

def toy_kyber_decapsulate(Rq, n, q, k, sk, ct):
    """
    Decapsulate: recover message from ciphertext using secret key.
    """
    s = sk
    u, v = ct
    
    # Core computation: v - s^T * u
    noisy_message = v - s * u  # dot product s^T * u, then subtract from v
    
    # Decode: round each coefficient to recover message bits
    msg_recovered = decode_message(noisy_message, n, q)
    
    return msg_recovered, noisy_message

# Decapsulate
msg_recovered, noisy = toy_kyber_decapsulate(Rq, n_toy, q_toy, k, sk, ct)

print("=== Decapsulation ===")
print(f"Computed v - s^T*u = {noisy}")
print()

# Show the rounding process
half_q = (q_toy + 1) // 2
coeffs = [ZZ(c) for c in list(noisy.lift())]
coeffs += [0] * (n_toy - len(coeffs))
print(f"Rounding each coefficient (q={q_toy}, half_q={half_q}):")
for i, c in enumerate(coeffs[:n_toy]):
    dist_0 = min(c, q_toy - c)
    dist_half = min(abs(c - half_q), q_toy - abs(c - half_q))
    closer = "0 (bit=0)" if dist_0 <= dist_half else f"{half_q} (bit=1)"
    print(f"  coeff[{i}] = {c:2d}  |  dist to 0: {dist_0}, dist to {half_q}: {dist_half}  ->  closer to {closer}")

print(f"\nOriginal message:  {msg}")
print(f"Recovered message: {msg_recovered}")
print(f"\nCorrect: {msg == msg_recovered}")

> **Misconception Callout:** "Post-quantum cryptography is theoretical and not deployed anywhere yet."
>
> **Wrong.** ML-KEM (Kyber) is NIST FIPS 203, standardized August 2024. It is already deployed in:
> - **Google Chrome** (hybrid X25519 + ML-KEM-768 since 2024)
> - **Signal Protocol** (PQXDH: X25519 + Kyber-1024, the pre-standard form of ML-KEM-1024, since September 2023)
> - **TLS 1.3** (hybrid key exchange groups such as X25519MLKEM768, defined in IETF Internet-Drafts)
> - **Apple iMessage** (PQ3 protocol since March 2024)
>
> You are studying a scheme that is protecting billions of messages *right now*.

## 6. NIST Parameter Sets

Kyber (ML-KEM) defines three security levels by varying the module rank $k$. All share $n=256$ and $q=3329$.
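
The byte sizes of the three parameter sets are not arbitrary: they follow from $k$, $n$, and the compression parameters $d_u$, $d_v$. The sketch below recomputes them under the FIPS 203 encoding as I understand it (12 bits per uncompressed coefficient, a 32-byte seed for $\mathbf{A}$, and a secret key that also stores the public key, its hash, and a 32-byte rejection value). The function name `mlkem_sizes` is hypothetical.

```python
def mlkem_sizes(k, du, dv, n=256):
    """Return (public key, secret key, ciphertext) sizes in bytes."""
    poly_bytes = 12 * n // 8           # 12 bits per coefficient -> 384 bytes
    pk = poly_bytes * k + 32           # vector t plus the seed that expands to A
    ct = (du * k + dv) * n // 8        # compressed u (k polys) and v (1 poly)
    sk = poly_bytes * k + pk + 32 + 32 # s, embedded pk, hash of pk, rejection value
    return pk, sk, ct


for name, (k, du, dv) in {
    "ML-KEM-512": (2, 10, 4),
    "ML-KEM-768": (3, 10, 4),
    "ML-KEM-1024": (4, 11, 5),
}.items():
    print(name, mlkem_sizes(k, du, dv))
```

Because $\mathbf{A}$ is regenerated from a seed, the public key grows linearly in $k$ even though the matrix itself has $k^2$ entries.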

In [None]:
# === NIST ML-KEM Parameter Comparison ===

params = {
    'ML-KEM-512':  {'k': 2, 'eta1': 3, 'eta2': 2, 'du': 10, 'dv': 4,
                    'pk_bytes': 800,  'sk_bytes': 1632, 'ct_bytes': 768,
                    'security': '~128-bit (NIST Level 1)'},
    'ML-KEM-768':  {'k': 3, 'eta1': 2, 'eta2': 2, 'du': 10, 'dv': 4,
                    'pk_bytes': 1184, 'sk_bytes': 2400, 'ct_bytes': 1088,
                    'security': '~192-bit (NIST Level 3)'},
    'ML-KEM-1024': {'k': 4, 'eta1': 2, 'eta2': 2, 'du': 11, 'dv': 5,
                    'pk_bytes': 1568, 'sk_bytes': 3168, 'ct_bytes': 1568,
                    'security': '~256-bit (NIST Level 5)'},
}

# Display as a clean table
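header = f"{'Parameter Set':<13} {'k':>2} {'eta1':>4} {'eta2':>4} {'PK (bytes)':>10} {'CT (bytes)':>10}  Security"
print(header)
print("-" * len(header))
for name, p in params.items():
    print(f"{name:<13} {p['k']:>2} {p['eta1']:>4} {p['eta2']:>4} {p['pk_bytes']:>10} {p['ct_bytes']:>10}  {p['security']}")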

print()
print("All parameter sets use n = 256, q = 3329.")
print("The main difference is the module rank k; eta1 and the compression parameters (du, dv) also vary.")
print("Higher k = more rows in the matrix = harder Module-LWE instance = more security.")

In [None]:
# === Key Size Comparison: ML-KEM vs Classical ===

import matplotlib.pyplot as plt

schemes = [
    ('X25519\n(ECDH)', 32, 32, '~128-bit'),
    ('RSA-2048', 256, 256, '~112-bit'),
    ('RSA-4096', 512, 512, '~140-bit'),
    ('ML-KEM-512', 800, 768, '~128-bit'),
    ('ML-KEM-768', 1184, 1088, '~192-bit'),
    ('ML-KEM-1024', 1568, 1568, '~256-bit'),
]

names = [s[0] for s in schemes]
pk_sizes = [s[1] for s in schemes]
ct_sizes = [s[2] for s in schemes]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

colors = ['#2ecc71', '#e74c3c', '#e74c3c', '#3498db', '#3498db', '#3498db']

# Public key sizes
bars1 = ax1.bar(names, pk_sizes, color=colors, edgecolor='black', linewidth=0.5)
ax1.set_ylabel('Bytes', fontsize=12)
ax1.set_title('Public Key Size Comparison', fontsize=14, fontweight='bold')
ax1.set_ylim(0, max(pk_sizes) * 1.2)
for bar, size in zip(bars1, pk_sizes):
    ax1.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 30,
             f'{size}', ha='center', va='bottom', fontsize=10, fontweight='bold')

# Ciphertext sizes
bars2 = ax2.bar(names, ct_sizes, color=colors, edgecolor='black', linewidth=0.5)
ax2.set_ylabel('Bytes', fontsize=12)
ax2.set_title('Ciphertext Size Comparison', fontsize=14, fontweight='bold')
ax2.set_ylim(0, max(ct_sizes) * 1.2)
for bar, size in zip(bars2, ct_sizes):
    ax2.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 30,
             f'{size}', ha='center', va='bottom', fontsize=10, fontweight='bold')

plt.tight_layout()
plt.savefig('kyber_size_comparison.png', dpi=100, bbox_inches='tight')
plt.show()

print("\nKey takeaway:")
print(f"  ML-KEM-768 public key = {1184} bytes (vs X25519 = 32 bytes: {1184/32:.0f}x larger)")
print(f"  ML-KEM-768 is also larger than RSA-4096 ({512} bytes, {1184/512:.1f}x), but targets a higher security level.")
print(f"  The ~1 KB overhead is very acceptable for post-quantum security.")

## 7. Complete Toy Kyber: End-to-End

Let's run the full KEM flow with slightly larger toy parameters to see everything work together. We will use $n=8$, $q=97$, $k=2$ (still tiny compared to real Kyber, but large enough to be interesting).
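
Both toy moduli in this notebook ($q = 17$ for $n = 4$, $q = 97$ for $n = 8$) satisfy $q \equiv 1 \pmod{2n}$, which guarantees a primitive $2n$-th root of unity mod $q$ and hence a full NTT. The sketch below searches for such a root; the function name `primitive_root_of_unity` is hypothetical, and the test used inside it is only valid when the requested order is a power of two and $q$ is prime.

```python
def primitive_root_of_unity(order, q):
    """Smallest z of multiplicative order `order` mod prime q, or None.

    Assumes `order` is a power of two, so z has exact order `order`
    precisely when z^(order/2) = -1 mod q.
    """
    if (q - 1) % order != 0:
        return None            # the group Z_q^* has no element of this order
    for z in range(2, q):
        if pow(z, order // 2, q) == q - 1:
            return z
    return None


print(primitive_root_of_unity(8, 17))      # 2, since 2^4 = 16 = -1 mod 17
print(primitive_root_of_unity(16, 97))     # a 16th root exists for the n = 8 toy ring
print(primitive_root_of_unity(512, 3329))  # None: 512 does not divide 3328
print(primitive_root_of_unity(256, 3329))  # a 256th root exists for real Kyber
```

For the real parameters only a 256th root of unity exists, not a 512th one, so Kyber's NTT stops one level early and finishes with products of degree-1 polynomials.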

In [None]:
# === Complete Toy Kyber KEM: Larger Example ===

# Larger toy parameters
n2 = 8
q2 = 97  # prime, 97 mod 16 = 1 (NTT-friendly)
k2 = 2

# Build the ring
Zq2 = Integers(q2)
Rx2.<y> = PolynomialRing(Zq2)
Rq2 = Rx2.quotient(y^n2 + 1, 'ybar')

# Helper functions for this ring
def rand_poly2(ring, n, q):
    return ring([ZZ.random_element(0, q) for _ in range(n)])

def small_poly2(ring, n, bound=1):
    return ring([ZZ.random_element(-bound, bound+1) for _ in range(n)])

def encode_msg2(ring, n, q, bits):
    half_q = (q + 1) // 2
    coeffs = [int(b) * half_q for b in bits] + [0]*(n - len(bits))
    return ring(coeffs[:n])

def decode_msg2(poly, n, q):
    half_q = (q + 1) // 2
    coeffs = [ZZ(c) for c in list(poly.lift())] + [0]*(n)
    bits = []
    for c in coeffs[:n]:
        d0 = min(c, q - c)
        dh = min(abs(c - half_q), q - abs(c - half_q))
        bits.append(0 if d0 <= dh else 1)
    return bits

print("="*60)
print(f"TOY KYBER KEM  (n={n2}, q={q2}, k={k2})")
print("="*60)

# --- KeyGen (Alice) ---
set_random_seed(2024)
A2 = matrix(Rq2, k2, k2, lambda i,j: rand_poly2(Rq2, n2, q2))
s2 = vector(Rq2, [small_poly2(Rq2, n2) for _ in range(k2)])
e2_vec = vector(Rq2, [small_poly2(Rq2, n2) for _ in range(k2)])
t2 = A2 * s2 + e2_vec

print("\n--- Alice: KeyGen ---")
print(f"Secret key s has {k2} polynomials of degree < {n2}")
print(f"Public key (A, t) published.")

# --- Encapsulate (Bob) ---
message = [1, 0, 1, 1, 0, 0, 1, 0]  # 8 bits = 1 byte

r2 = vector(Rq2, [small_poly2(Rq2, n2) for _ in range(k2)])
e1_2 = vector(Rq2, [small_poly2(Rq2, n2) for _ in range(k2)])
e2_2 = small_poly2(Rq2, n2)
m_enc = encode_msg2(Rq2, n2, q2, message)

u2 = A2.transpose() * r2 + e1_2
v2 = t2 * r2 + e2_2 + m_enc

print(f"\n--- Bob: Encapsulate ---")
print(f"Message bits:    {message}")
print(f"Ciphertext (u, v) computed and sent to Alice.")

# --- Decapsulate (Alice) ---
noisy2 = v2 - s2 * u2
recovered = decode_msg2(noisy2, n2, q2)

print(f"\n--- Alice: Decapsulate ---")
print(f"Recovered bits:  {recovered}")
print(f"\nOriginal:  {message}")
print(f"Recovered: {recovered}")
print(f"Match: {message == recovered}")

# In this toy version the shared secret is simply a hash of the message.
# (ML-KEM also hashes in a digest of the public key.)
import hashlib
shared_secret = hashlib.sha256(bytes(message)).hexdigest()[:32]
print(f"\nShared secret K = first 16 bytes of SHA-256(message) = {shared_secret}")
print("Both Alice and Bob now use K as the symmetric key for AES/ChaCha20.")

## 8. Why Decryption Works: Noise Analysis

The correctness of Kyber depends on the accumulated noise being small enough to not flip any message bits during decoding. Let's visualize this.

In [None]:
# === Noise Analysis: Why decryption succeeds ===
#
# After decapsulation, each coefficient has the form:
#   round(q/2) * m_i + noise_i
# Decryption succeeds when |noise_i| < q/4.

# Let's run many trials and observe the noise distribution.

noise_samples = []
failures = 0
num_trials = 500

for trial in range(num_trials):
    # KeyGen
    A_t = matrix(Rq2, k2, k2, lambda i,j: rand_poly2(Rq2, n2, q2))
    s_t = vector(Rq2, [small_poly2(Rq2, n2) for _ in range(k2)])
    e_t = vector(Rq2, [small_poly2(Rq2, n2) for _ in range(k2)])
    t_t = A_t * s_t + e_t
    
    # Encapsulate (message = all zeros for clean noise measurement)
    r_t = vector(Rq2, [small_poly2(Rq2, n2) for _ in range(k2)])
    e1_t = vector(Rq2, [small_poly2(Rq2, n2) for _ in range(k2)])
    e2_t = small_poly2(Rq2, n2)
    
    u_t = A_t.transpose() * r_t + e1_t
    v_t = t_t * r_t + e2_t  # No message encoding -> pure noise after decaps
    
    # Decapsulate: should get ~0 for each coefficient
    noisy_t = v_t - s_t * u_t
    coeffs_t = [ZZ(c) for c in list(noisy_t.lift())] + [0]*n2
    
    for c in coeffs_t[:n2]:
        # Center around 0: map [0, q-1] to [-(q-1)/2, (q-1)/2]
        centered = c if c <= q2//2 else c - q2
        noise_samples.append(centered)
        if abs(centered) >= q2 // 4:
            failures += 1

print(f"Noise statistics over {num_trials} trials ({num_trials * n2} total coefficients):")
print(f"  Min noise:     {min(noise_samples)}")
print(f"  Max noise:     {max(noise_samples)}")
print(f"  Mean noise:    {sum(noise_samples)/len(noise_samples):.2f}")
print(f"  Decoding threshold: |noise| < {q2//4} = q/4")
print(f"  Coefficients exceeding threshold: {failures}/{len(noise_samples)} ({100*failures/len(noise_samples):.2f}%)")

# Histogram
fig, ax = plt.subplots(figsize=(10, 4))
ax.hist(noise_samples, bins=range(min(noise_samples)-1, max(noise_samples)+2),
        color='#3498db', edgecolor='black', linewidth=0.5, alpha=0.8)
ax.axvline(x=q2//4, color='red', linestyle='--', linewidth=2, label=f'Threshold = +{q2//4}')
ax.axvline(x=-(q2//4), color='red', linestyle='--', linewidth=2, label=f'Threshold = -{q2//4}')
ax.set_xlabel('Noise value', fontsize=12)
ax.set_ylabel('Count', fontsize=12)
ax.set_title(f'Decapsulation Noise Distribution (n={n2}, q={q2}, k={k2})', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
plt.tight_layout()
plt.show()

if failures == 0:
    print("\nAll noise values are within threshold. Decryption always succeeds!")
else:
    print(f"\nNote: {failures} values exceeded threshold. With toy parameters, some failures are expected.")
    print("Real Kyber parameters are chosen so the failure probability is < 2^{-139}.")

> **Checkpoint:** The noise grows as a product of small polynomials. In real Kyber:
> - Coefficients of $\mathbf{s}, \mathbf{e}, \mathbf{r}, \mathbf{e}_1, e_2$ are in $\{-\eta, \ldots, \eta\}$ with $\eta \leq 3$.
> - After multiplication and accumulation, each of the two inner products ($\mathbf{e}^T\mathbf{r}$ and $\mathbf{s}^T\mathbf{e}_1$) contributes at most $k \cdot n \cdot \eta^2$ per coefficient, so the worst case is roughly $2 \cdot k \cdot n \cdot \eta^2$.
> - For Kyber-768: $k=3$, $n=256$, $\eta=2$, so worst-case noise $\approx 2 \cdot 3 \cdot 256 \cdot 4 = 6144$.
> - The threshold is $q/4 = 3329/4 \approx 832$.
>
> Wait, $6144 > 832$?! The **worst case** exceeds the threshold, but the **probabilistic** bound (using the fact that CBD noise is concentrated near 0) gives a failure probability of $< 2^{-139}$. This concentration is one reason Kyber samples from the Centered Binomial Distribution rather than uniformly.
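
A back-of-the-envelope calculation shows how far apart the worst case and the typical case are. The sketch below counts both inner products in the noise term $\mathbf{e}^T \mathbf{r} + e_2 - \mathbf{s}^T \mathbf{e}_1$ and uses the fact that a $\text{CBD}_\eta$ sample has variance $\eta/2$. It is a heuristic under stated assumptions: all coefficients are treated as independent, a single $\eta$ is used throughout, and the rounding noise from ciphertext compression in real ML-KEM is ignored, so it does not reproduce the official failure probability. The function name `noise_summary` is hypothetical.

```python
import math


def noise_summary(k, n, eta, q):
    """Return (worst-case |noise|, heuristic standard deviation, threshold q/4)."""
    # Two inner products of k*n terms each, every term at most eta^2, plus e2.
    worst_case = 2 * k * n * eta**2 + eta
    var_coeff = eta / 2                      # variance of one CBD_eta sample
    # Each product of two independent zero-mean samples has variance var_coeff^2.
    var_noise = 2 * k * n * var_coeff**2 + var_coeff
    return worst_case, math.sqrt(var_noise), q / 4


worst, std, threshold = noise_summary(k=3, n=256, eta=2, q=3329)
print(worst)                      # 6146
print(round(std, 1))              # 39.2
print(threshold)                  # 832.25
print(round(threshold / std, 1))  # 21.2
```

Under these assumptions the threshold sits about 21 standard deviations from the mean, which is why failures are astronomically rare even though the worst case is far above $q/4$.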

## 9. Performance Comparison

Post-quantum schemes must be practical. How does ML-KEM compare to what it replaces?

In [None]:
# === Performance Comparison Table ===
# (Representative benchmarks from liboqs / OpenSSL, x86-64 with AVX2)

perf_data = [
    ('X25519 (ECDH)',     'Classical',  32,    32,    '~125 us',    'Yes (Shor)'),
    ('RSA-2048',          'Classical',  256,   256,   '~1600 us',   'Yes (Shor)'),
    ('RSA-4096',          'Classical',  512,   512,   '~8000 us',   'Yes (Shor)'),
    ('ML-KEM-512',        'PQ',         800,   768,   '~150 us',    'No'),
    ('ML-KEM-768',        'PQ',         1184,  1088,  '~200 us',    'No'),
    ('ML-KEM-1024',       'PQ',         1568,  1568,  '~270 us',    'No'),
]

print(f"{'Scheme':<15} {'Type':<10} {'PK (B)':>6} {'CT (B)':>6} {'KeyGen+Encaps':>14}  Broken by QC?")
for row in perf_data:
    print(f"{row[0]:<15} {row[1]:<10} {row[2]:>6} {row[3]:>6} {row[4]:>14}  {row[5]}")

print()
print("Key observations:")
print("  1. ML-KEM-768 is only ~1.5x slower than X25519, but quantum-resistant.")
print("  2. ML-KEM is MUCH faster than RSA for comparable security.")
print("  3. The main cost is bandwidth: ~1 KB keys vs 32 bytes for X25519.")
print("  4. Deployments run it in HYBRID mode (X25519 + ML-KEM-768) so the key stays safe if either scheme is broken.")

## Exercises

### Exercise 1: Verify the Noise Cancellation (Fully Worked)

Manually expand $v - \mathbf{s}^T \mathbf{u}$ for our toy parameters and verify that the $\mathbf{s}^T \mathbf{A}^T \mathbf{r}$ terms cancel.

In [None]:
# === Exercise 1 (Fully Worked): Verify Noise Cancellation ===
#
# We'll create a fresh instance and show every intermediate value.

set_random_seed(100)

# KeyGen
A_ex = matrix(Rq, k, k, lambda i,j: random_poly(Rq, n_toy, q_toy))
s_ex = vector(Rq, [small_poly(Rq, n_toy) for _ in range(k)])
e_ex = vector(Rq, [small_poly(Rq, n_toy) for _ in range(k)])
t_ex = A_ex * s_ex + e_ex

# Encapsulate with message [1, 0, 1, 0]
msg_ex = [1, 0, 1, 0]
r_ex  = vector(Rq, [small_poly(Rq, n_toy) for _ in range(k)])
e1_ex = vector(Rq, [small_poly(Rq, n_toy) for _ in range(k)])
e2_ex = small_poly(Rq, n_toy)
m_enc_ex = encode_message(Rq, n_toy, q_toy, msg_ex)

u_ex = A_ex.transpose() * r_ex + e1_ex
v_ex = t_ex * r_ex + e2_ex + m_enc_ex

# Now expand v - s^T * u step by step
print("Step-by-step expansion of v - s^T * u:")

# v = t^T r + e2 + m_enc
#   = (As + e)^T r + e2 + m_enc
#   = s^T A^T r + e^T r + e2 + m_enc
term_sATr_from_v = s_ex * (A_ex.transpose() * r_ex)  # s^T A^T r (from v)
term_eTr = e_ex * r_ex  # e^T r

# s^T u = s^T (A^T r + e1) = s^T A^T r + s^T e1
term_sATr_from_u = s_ex * (A_ex.transpose() * r_ex)  # s^T A^T r (from s^T u)
term_sTe1 = s_ex * e1_ex  # s^T e1

print(f"\n1. v = t^T*r + e2 + m_encoded")
print(f"   = (A*s + e)^T * r + e2 + m_encoded")
print(f"   = s^T*A^T*r + e^T*r + e2 + m_encoded")
print(f"\n2. s^T*u = s^T*(A^T*r + e1)")
print(f"   = s^T*A^T*r + s^T*e1")
print(f"\n3. v - s^T*u = (s^T*A^T*r + e^T*r + e2 + m_encoded) - (s^T*A^T*r + s^T*e1)")
print(f"             = m_encoded + (e^T*r + e2 - s^T*e1)")
print(f"               ^^^^^^^^^^^   ^^^^^^^^^^^^^^^^^^^^^^^^^")
print(f"                message           small noise")

# Verify numerically
noise_term = term_eTr + e2_ex - term_sTe1
expected = m_enc_ex + noise_term
actual = v_ex - s_ex * u_ex

print(f"\nNumerical verification:")
print(f"  m_encoded   = {m_enc_ex}")
print(f"  noise term  = {noise_term}")
print(f"  Expected    = {expected}")
print(f"  Actual v-su = {actual}")
print(f"  Match: {expected == actual}")

# Decode
recovered_ex = decode_message(actual, n_toy, q_toy)
print(f"\n  Original msg:  {msg_ex}")
print(f"  Recovered msg: {recovered_ex}")
print(f"  Correct: {msg_ex == recovered_ex}")

### Exercise 2: What Happens When Noise Is Too Large? (Guided)

Increase the noise bound in the small polynomial sampling from $\pm 1$ to $\pm 4$. Run KeyGen, Encapsulate, and Decapsulate. Does decryption still succeed? Why or why not?

**Hints:**
- Modify the `bound` parameter in `small_poly()` from 1 to 4.
- With $q=17$ and large noise, the accumulated error may exceed $q/4 \approx 4$.
- Run multiple trials (at least 20) and count how many fail.

In [None]:
# === Exercise 2 (Guided): Large Noise Experiment ===
#
# TODO: Fill in the missing parts marked with ???

def small_poly_large(ring, n, bound):
    """Sample a polynomial with coefficients in [-bound, bound]."""
    coeffs = [ZZ.random_element(-bound, bound + 1) for _ in range(n)]
    return ring(coeffs)

noise_bound = 4  # ??? Try different values: 1, 2, 3, 4
num_trials_ex2 = 50
successes = 0

for trial in range(num_trials_ex2):
    # KeyGen with large noise
    A_ex2 = matrix(Rq, k, k, lambda i,j: random_poly(Rq, n_toy, q_toy))
    s_ex2 = vector(Rq, [small_poly_large(Rq, n_toy, noise_bound) for _ in range(k)])
    e_ex2 = vector(Rq, [small_poly_large(Rq, n_toy, noise_bound) for _ in range(k)])
    t_ex2 = A_ex2 * s_ex2 + e_ex2
    
    # Encapsulate
    msg_ex2 = [1, 0, 1, 1]
    r_ex2  = vector(Rq, [small_poly_large(Rq, n_toy, noise_bound) for _ in range(k)])
    e1_ex2 = vector(Rq, [small_poly_large(Rq, n_toy, noise_bound) for _ in range(k)])
    e2_ex2 = small_poly_large(Rq, n_toy, noise_bound)
    m_enc2 = encode_message(Rq, n_toy, q_toy, msg_ex2)
    
    u_ex2 = A_ex2.transpose() * r_ex2 + e1_ex2
    v_ex2 = t_ex2 * r_ex2 + e2_ex2 + m_enc2
    
    # Decapsulate
    noisy_ex2 = v_ex2 - s_ex2 * u_ex2
    recovered_ex2 = decode_message(noisy_ex2, n_toy, q_toy)
    
    if recovered_ex2 == msg_ex2:
        successes += 1

print(f"Noise bound: +/- {noise_bound}")
print(f"Successes: {successes}/{num_trials_ex2} ({100*successes/num_trials_ex2:.0f}%)")
print(f"Failures:  {num_trials_ex2 - successes}/{num_trials_ex2}")
print()
print("# ??? Explain your observation:")
print("# Why does increasing the noise bound cause more failures?")
print("# What is the relationship between noise_bound, q, and decryption correctness?")
print("# Hint: the decoding threshold is q/4 =", q_toy // 4)

### Exercise 3: Implement Kyber with $k=3$ (Independent)

Modify the toy Kyber implementation to use module rank $k=3$ (corresponding to ML-KEM-768 security level). Use the toy parameters $n=4$, $q=17$.

1. Generate a $3 \times 3$ matrix of polynomials for $\mathbf{A}$.
2. Run KeyGen, Encapsulate, and Decapsulate.
3. Verify that decryption succeeds.
4. Compare the number of polynomial elements in the public key for $k=2$ vs $k=3$.

*No hints. You have all the building blocks from this notebook.*

In [None]:
# === Exercise 3 (Independent): Kyber with k=3 ===
#
# Your code here. Use the helper functions defined earlier in this notebook.
# 
# Steps:
#   1. Set k3 = 3
#   2. KeyGen: A is 3x3, s and e are length-3 vectors
#   3. Encapsulate: r and e1 are length-3 vectors
#   4. Decapsulate: verify message recovery
#   5. Print the public key sizes for k=2 vs k=3



## What Comes Next

This notebook gave you the **overview** of Kyber / ML-KEM. The journey continues:

- **Module 08 Break notebooks:** We will attack toy Kyber with deliberately weakened parameters. What happens if $q$ is too small? If the noise distribution is wrong? If the polynomial ring has a bad structure? These "break" exercises teach you *why* each design choice matters.

- **Module 08 Connect notebooks:** ML-KEM in the wild. How does TLS 1.3 hybrid key exchange work? What is ML-DSA (Dilithium), the lattice-based *signature* scheme that complements ML-KEM? How do Signal and iMessage combine classical and post-quantum primitives?

- **Rust implementation (Module 08 project):** Implement ML-KEM from scratch in Rust, including NTT-based polynomial multiplication, CBD sampling, and compression/decompression of ciphertext.

You now understand the core idea: **Module-LWE gives us a trapdoor**. Publishing $(\mathbf{A}, \mathbf{A}\mathbf{s}+\mathbf{e})$ hides $\mathbf{s}$, but knowing $\mathbf{s}$ lets you strip away the randomness and recover a message. This is the same encrypt/decrypt paradigm as RSA and ElGamal, but built on a problem for which no efficient quantum algorithm is known.

## Summary

In this notebook we explored **Kyber / ML-KEM**, the NIST-standardized post-quantum KEM. Key takeaways:

1. **Module-LWE** is the sweet spot between LWE (general but slow) and Ring-LWE (fast but inflexible). Kyber works in the module $R_q^k$, with a public matrix $\mathbf{A} \in R_q^{k \times k}$ of polynomials.

2. **KEM flow:** KeyGen produces $(\mathbf{A}, \mathbf{t}=\mathbf{A}\mathbf{s}+\mathbf{e})$. Encapsulate creates $(\mathbf{u}, v)$ encoding a random message. Decapsulate uses $\mathbf{s}$ to recover the message, and both parties derive the same shared secret $K$.

3. **Correctness** depends on noise cancellation: $v - \mathbf{s}^T\mathbf{u} = \lceil q/2 \rfloor \cdot m + \text{small noise}$. The $\mathbf{s}^T\mathbf{A}^T\mathbf{r}$ terms cancel algebraically.

4. **NIST parameters:** ML-KEM-512/768/1024 use $k=2/3/4$ with $n=256$, $q=3329$. Key sizes are ~1 KB (vs 32 bytes for X25519), but performance is competitive.

5. **This is deployed NOW.** ML-KEM (FIPS 203), or its Kyber predecessor, protects Chrome, Signal, iMessage, and other TLS 1.3 traffic today.

**Connection to the full module:** Notebooks 08a-08c built your lattice intuition (bases, SVP, LLL). Notebook 08d introduced LWE as a hard problem on lattices. Notebook 08e moved to Ring-LWE for efficiency. This notebook (08f) assembled all the pieces into the real-world scheme. The progression was: *geometry* (08a-08c) $\to$ *hardness* (08d) $\to$ *efficiency* (08e) $\to$ *deployment* (08f).