# CBC Padding Oracle Attack

Padding might seem harmless, but if we can learn just one thing, whether a ciphertext decrypts to valid PKCS#5/PKCS#7 padding, we can decrypt the entire message. A padding oracle is any function that tells us whether the padding of a given ciphertext is correct. In practice, this is a very common vulnerability in web applications! The oracle might be an explicit error returned to the caller, or a side channel, such as the server taking longer to respond when the padding is correct. AWS's s2n implementation of TLS ended up disabling support for CBC mode altogether after it was found to be vulnerable to the timing version of the attack in 2015... a full 13 years after the attack was first published!

## PKCS#7 Padding

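Let's start by defining our pad/unpad functions! PKCS#7 appends n bytes, each with the value n, where n is the number of bytes needed to reach the next multiple of the block size. If the plaintext is already aligned, a whole block of padding is added, so the last byte always tells us how much to strip.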

In [119]:
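def pad(plaintext: bytes, block_size: int) -> bytes:
    # Always adds between 1 and block_size bytes, so aligned input gains a full block
    padding = block_size - (len(plaintext) % block_size)
    return plaintext + bytes([padding] * padding)


def unpad(padded: bytes) -> bytes:
    # Empty input has no final byte to read the padding length from
    if not padded:
        raise ValueError("Invalid padding")
    padding = padded[-1]
    if padding == 0 or padding > len(padded):
        raise ValueError("Invalid padding")
    # Every one of the last `padding` bytes must equal `padding`
    if any(p != padding for p in padded[-padding:]):
        raise ValueError("Invalid padding")
    return padded[:-padding]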


def demo_padding():
    plaintext = b"Hello, world!"
    block_size = 16
    padded = pad(plaintext, block_size)
    indent = 10
    print("Plaintext:".ljust(indent), plaintext.hex())
    print("Padded:".ljust(indent), padded.hex())
    print("Unpadded:".ljust(indent), unpad(padded).hex())


demo_padding()

Plaintext: 48656c6c6f2c20776f726c6421
Padded:    48656c6c6f2c20776f726c6421030303
Unpadded:  48656c6c6f2c20776f726c6421
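
Two edge cases are worth checking before moving on, since the attack leans on both. The sketch below restates `pad` and `unpad` so that it runs on its own, and adds a small `is_valid` helper (a name introduced here purely for illustration). It shows that an already-aligned plaintext gains a whole extra block of padding, and that malformed padding is rejected.

```python
BLOCK_SIZE = 16  # AES block size in bytes


def pad(plaintext: bytes, block_size: int) -> bytes:
    padding = block_size - (len(plaintext) % block_size)
    return plaintext + bytes([padding] * padding)


def unpad(padded: bytes) -> bytes:
    if not padded:
        raise ValueError("Invalid padding")
    padding = padded[-1]
    if padding == 0 or padding > len(padded):
        raise ValueError("Invalid padding")
    if any(p != padding for p in padded[-padding:]):
        raise ValueError("Invalid padding")
    return padded[:-padding]


def is_valid(padded: bytes) -> bool:
    # Illustrative helper: turns unpad's exception into a yes/no answer
    try:
        unpad(padded)
        return True
    except ValueError:
        return False


# A 16-byte plaintext is already aligned, so a full block of 0x10 bytes is added
aligned = pad(b"YELLOW SUBMARINE", BLOCK_SIZE)
print(len(aligned))            # 32
print(aligned[-16:].hex())     # 10101010101010101010101010101010

print(is_valid(b"Hello, world!\x03\x03\x03"))  # True
print(is_valid(b"Hello, world!\x01\x02\x03"))  # False: bytes disagree
print(is_valid(b"Hello, world!!!\x00"))        # False: zero is never valid
```

That yes/no answer is all the information a padding oracle gives away.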


## The Padding Oracle

Now, let's define our padding oracle. We pass it some ciphertext (with the IV prepended), and it decrypts it and checks the plaintext's padding. It returns `True` if the padding is correct, and `False` otherwise. This simulates a real-life server; in a real attack, we would have to send the ciphertext to the server and check the response, or measure how long it takes to respond.

In [120]:
import secrets
from Crypto.Cipher import AES

# The secret key used for encryption/decryption by the server
# We won't have access to this key during the attack
secret_key = secrets.token_bytes(AES.key_size[0])

def xor_bytes(a: bytes, b: bytes) -> bytes:
    return bytes(x ^ y for x, y in zip(a, b))

# We can't use this function in the attack
def aes_cbc_encrypt(plaintext: bytes, key: bytes) -> bytes:
    iv = secrets.token_bytes(AES.block_size)
    cipher = AES.new(key, AES.MODE_CBC, iv=iv)
    return iv + cipher.encrypt(plaintext)


# We can't use this function in the attack
def aes_cbc_decrypt(ciphertext: bytes, key: bytes) -> bytes:
    iv, ciphertext = ciphertext[:AES.block_size], ciphertext[AES.block_size:]
    cipher = AES.new(key, AES.MODE_CBC, iv=iv)
    return cipher.decrypt(ciphertext)


# This is the only encryption-related function we can use in the attack!
# Note that the IV is prepended to the ciphertext
def padding_oracle(ciphertext: bytes) -> bool:
    plaintext = aes_cbc_decrypt(ciphertext, secret_key)
    try:
        unpad(plaintext)
        return True
    except ValueError:
        return False

## The Intercepted Message

We've intercepted an IV and ciphertext that we want to decrypt. Let's generate an example first, then proceed to the attack!

In [121]:
def encrypt_secret_message() -> bytes:
    plaintext = b"Comic Sans is the best font! This is top-secret info that I'd rather not reveal..."
    plaintext = pad(plaintext, AES.block_size)
    ciphertext = aes_cbc_encrypt(plaintext, secret_key)
    return ciphertext

intercepted_ciphertext = encrypt_secret_message()
print(intercepted_ciphertext.hex(' ', 2))

ad38 e9c3 3307 2dca e93f cd1c 7027 866f d877 ce47 7578 ed1f e275 220c 10ee 46f6 1666 487d af0e b1b2 8714 9ca9 c25c b4b7 647f 464d f193 64c9 4e01 b6b0 9a58 fd46 8661 80d7 d815 bb97 a289 c644 f321 37c6 f762 a59c 0dc3 8bec d656 3cfe 7d8e 0dff ca9d 0f7d 0f33 ce6b 683d 9128 7b56 54c6
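
To make the layout of the intercepted bytes concrete, here is a small sketch that parses the hex dump printed above and splits it into the IV and the ciphertext blocks. The variable names are chosen for illustration, and the bytes are specific to this one run, since the key and IV are random.

```python
BLOCK_SIZE = 16  # AES block size in bytes

# Hex dump copied from the output above (bytes.fromhex skips the spaces)
intercepted_hex = (
    "ad38 e9c3 3307 2dca e93f cd1c 7027 866f "
    "d877 ce47 7578 ed1f e275 220c 10ee 46f6 "
    "1666 487d af0e b1b2 8714 9ca9 c25c b4b7 "
    "647f 464d f193 64c9 4e01 b6b0 9a58 fd46 "
    "8661 80d7 d815 bb97 a289 c644 f321 37c6 "
    "f762 a59c 0dc3 8bec d656 3cfe 7d8e 0dff "
    "ca9d 0f7d 0f33 ce6b 683d 9128 7b56 54c6"
)
intercepted = bytes.fromhex(intercepted_hex)

# The first block is the IV; everything after it is ciphertext
iv, *ciphertext_blocks = [
    intercepted[i : i + BLOCK_SIZE]
    for i in range(0, len(intercepted), BLOCK_SIZE)
]

print(len(intercepted))        # 112
print(iv.hex())                # ad38e9c333072dcae93fcd1c7027866f
print(len(ciphertext_blocks))  # 6
```

The 82-byte message is padded to 96 bytes, which gives six ciphertext blocks plus the IV. Since each byte takes at most 256 guesses, the attack needs on the order of 6 × 16 × 256 = 24,576 oracle queries in the worst case.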


## The Attack

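Okay, this is the part you've been waiting for! Let's exploit the padding oracle and see if we can crack the message we've intercepted. Knowing whether a message's padding is valid leaks information about the plaintext. In CBC mode, each plaintext block is the block cipher's output XORed with the previous ciphertext block, so if we choose that previous block ourselves and watch for the moment the padding becomes valid, we learn the block cipher's output one byte at a time. It's not as drastically obvious as our attack on ECB mode, and there's more maths involved, but it's just as effective at cracking CBC mode! It's proven to be fiendishly difficult to patch in real-world systems too.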
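
Before the full attack, it may help to see the core step in isolation. The sketch below is a toy, not the real thing: `make_toy_oracle` and `recover_last_byte` are names invented for this example, and a fixed `hidden_intermediate` value stands in for the block cipher output D_k(C) so that no AES library is needed. The oracle only ever answers valid or invalid, yet that is enough to recover the last byte.

```python
BLOCK_SIZE = 16


def has_valid_padding(block: bytes) -> bool:
    n = block[-1]
    return 1 <= n <= len(block) and all(b == n for b in block[-n:])


def make_toy_oracle(intermediate: bytes):
    """Stand-in for the server: `intermediate` plays the role of D_k(C)."""

    def oracle(r: bytes) -> bool:
        # CBC decryption of C with previous block R yields D_k(C) XOR R
        candidate = bytes(x ^ y for x, y in zip(intermediate, r))
        return has_valid_padding(candidate)

    return oracle


def recover_last_byte(oracle) -> int:
    r = bytearray(BLOCK_SIZE)
    for guess in range(256):
        r[-1] = guess
        if not oracle(bytes(r)):
            continue
        # Valid padding might be 01, or a longer run such as 02 02.
        # Disturb the second-to-last byte: only 01 padding survives that.
        r[-2] ^= 0xFF
        is_single_byte_padding = oracle(bytes(r))
        r[-2] ^= 0xFF
        if is_single_byte_padding:
            # D_k(C)[-1] XOR guess == 0x01, so D_k(C)[-1] == guess XOR 0x01
            return guess ^ 0x01
    raise ValueError("Oracle never reported valid padding")


# Hypothetical values, fixed so the example is reproducible
hidden_intermediate = bytes.fromhex("00112233445566778899aabbccdd0203")
prev_c = bytes.fromhex("ffeeddccbbaa99887766554433221122")

oracle = make_toy_oracle(hidden_intermediate)
recovered = recover_last_byte(oracle)
print(hex(recovered))               # 0x3
print(hex(recovered ^ prev_c[-1]))  # 0x21
```

The second-to-last byte of `hidden_intermediate` is chosen so that the guess `0x01` produces `02 02` padding first; the extra check rejects it and the search continues until it finds genuine `01` padding. XORing the recovered byte with the real previous ciphertext block then gives the plaintext byte.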

In [122]:
def blocks(ciphertext: bytes):
    return [
        ciphertext[i : i + AES.block_size]
        for i in range(0, len(ciphertext), AES.block_size)
    ]


def find_padding_size(iv: bytes, r: bytes, c: bytes) -> int:
    # We could treat all of these as preconditions, but it's a useful sanity check!
    if len(r) != AES.block_size:
        raise ValueError("R must be a single block")
    if len(c) != AES.block_size:
        raise ValueError("C must be a single block")
    if not padding_oracle(iv + r + c):
        raise ValueError("R || C must have valid padding")

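    # Copy R so we don't modify the original
    r = bytearray(r)
    for i in range(len(r)):
        # Flip a bit in the i-th byte of R
        r[i] ^= 1
        # If that invalidates the padding, then the i-th byte of R is the start of the padding
        # (we pass the same IV as in the sanity check above, so every query has the same shape)
        if not padding_oracle(iv + r + c):
            return len(r) - i
    # Not expected to be reached: changing the last byte should always break the padding
    raise ValueError("Could not determine the padding size")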


def padding_oracle_attack():
    original_iv, *original_ciphertext = blocks(intercepted_ciphertext)
    prev_ciphertext = [original_iv] + original_ciphertext[:-1]
    plaintext = b""

    # We're going to use this null block as the IV when asking the oracle about padding.
    # This is to illustrate that it really doesn't matter! You can make this any value and it'll still work.
    # In fact, even if we don't know the IV for the intercepted message, we can still decrypt all but the first block!
    zero_iv = bytearray(AES.block_size)

    for c, prev_c in zip(original_ciphertext, prev_ciphertext):
        dc = bytearray(AES.block_size)
        r = bytearray(AES.block_size)
        # Try different values for the last byte of R until r || c has valid padding
        while not padding_oracle(zero_iv + r + c):
            r[-1] += 1
        # Now we know that the padding is valid, but we don't know if it's 01, 02 02, 03 03 03 or larger!
        padding_size = find_padding_size(zero_iv, r, c)
        # With n bytes of padding, we can recover the last n bytes of the block cipher output!
        dc[-padding_size:] = xor_bytes(
            r[-padding_size:], bytes([padding_size] * padding_size)
        )
        # Now we can recover the rest of the block cipher output
        for i in range(AES.block_size - padding_size - 1, -1, -1):
            # Change R so that the previous padding bytes are incremented by 1
            padding = AES.block_size - i
            r[i + 1 :] = xor_bytes(dc[i + 1 :], bytes([padding] * len(r[i + 1 :])))
            # Find the value for the i-th byte of R that results in valid padding
            while not padding_oracle(zero_iv + r + c):
                r[i] += 1
            # Recover the i-th byte of the block cipher output
            dc[i] = r[i] ^ padding

        # XOR the decrypted block with the previous ciphertext block to recover the plaintext
        plaintext += xor_bytes(dc, prev_c)

    return unpad(plaintext).decode()


padding_oracle_attack()

"Comic Sans is the best font! This is top-secret info that I'd rather not reveal..."