from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

def padding_oracle(ciphertext: bytes) -> bool:
    """
    Report whether ``ciphertext`` decrypts to correctly PKCS#7-padded data.

    The input is read as ``IV || C1 || ... || Cn``: the first 16 bytes are the
    CBC initialization vector and the rest is decrypted with AES-CBC under the
    module-level ``KEY``.

    Parameters:
        ciphertext (bytes): IV followed by the encrypted blocks.

    Returns:
        bool: True when unpadding succeeds; False when the length is not a
        multiple of 16 or when decryption/unpadding raises ValueError or
        TypeError.
    """
    BLOCK_SIZE = 16  # AES block size in bytes

    # Anything that is not a whole number of blocks is rejected up front.
    if len(ciphertext) % BLOCK_SIZE:
        return False

    try:
        iv, body = ciphertext[:BLOCK_SIZE], ciphertext[BLOCK_SIZE:]

        decryptor = Cipher(algorithms.AES(KEY), modes.CBC(iv)).decryptor()
        plaintext = decryptor.update(body) + decryptor.finalize()

        # Only success/failure of unpadding matters; the unpadded bytes are discarded.
        unpadder = padding.PKCS7(BLOCK_SIZE * 8).unpadder()
        unpadder.update(plaintext)
        unpadder.finalize()
    except (ValueError, TypeError):
        return False  # malformed padding (or unusable input)

    return True



## Task 1 – Understand the Components

### Q1: How does the `padding_oracle()` function check if padding is valid?

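The `padding_oracle()` function treats the first 16 bytes of its input as the IV and decrypts the remaining bytes with AES in CBC mode.  
After decryption, it tries to remove the PKCS#7 padding.  
If the padding is correct, unpadding succeeds and the function returns True.  
If the padding is invalid, the unpadder raises a `ValueError` (the function also catches `TypeError`). The exception is caught and the function returns False. Inputs whose length is not a multiple of 16 are rejected with False before any decryption happens.  
This True/False answer reveals whether the decrypted message has valid padding — exactly the one bit of information that the padding oracle attack exploits.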

### Q2: What is the purpose of the IV in CBC mode?

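The IV (Initialization Vector) is used to randomize encryption.  
In CBC mode, the IV is XORed with the first plaintext block before it is encrypted.  
As a result, encrypting the same message twice with different IVs produces different ciphertexts.  
This keeps attackers from spotting patterns in repeated messages.  
During decryption the relation is reversed: $P_1 = D_K(C_1) \oplus IV$, and in general $P_i = D_K(C_i) \oplus C_{i-1}$. The IV therefore acts as a "block zero" ($C_0$). Changing a byte of the previous block (or of the IV) changes the same byte of the next plaintext block, which is exactly what the padding oracle attack relies on.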


### Q3: Why must the ciphertext length be a multiple of the block size?

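AES operates on blocks of exactly 16 bytes.  
If the ciphertext length isn't a multiple of 16, it cannot be split into whole blocks, so CBC decryption is impossible.  
A valid ciphertext always meets this requirement because PKCS#7 padding is added before encryption. It extends the plaintext to a multiple of 16 bytes, adding a full block of padding if the plaintext is already aligned. The ciphertext (IV + blocks) is therefore always a multiple of 16, and `padding_oracle()` rejects any other length immediately.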


In [1]:
def split_blocks(data: bytes, block_size: int = 16) -> list[bytes]:
    """
    Split data into consecutive blocks of ``block_size`` bytes.

    Parameters:
        data (bytes): The input data to split.
        block_size (int): Block length in bytes. Defaults to 16, the AES
            block size, so existing callers are unaffected.

    Returns:
        list[bytes]: The blocks in order. Every block is ``block_size`` bytes
        long except possibly the last one, which holds the remainder when
        ``len(data)`` is not a multiple of ``block_size``. Empty input gives
        an empty list.

    Raises:
        ValueError: If ``block_size`` is not a positive integer.
    """
    # A zero step would make range() fail and a negative one would silently
    # return [], so reject both with a clear message.
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    return [data[i:i + block_size] for i in range(0, len(data), block_size)]


In [2]:
# Sanity-check split_blocks() on exactly two full 16-byte blocks (32 bytes).
test_data = b"ABCDEFGHIJKLMNOP" b"QRSTUVWXYZ123456"  # adjacent literals concatenate
blocks = split_blocks(test_data)

# Show each block with a 1-based position.
for i, block in enumerate(blocks):
    print(f"Block {i+1}: {block}")


Block 1: b'ABCDEFGHIJKLMNOP'
Block 2: b'QRSTUVWXYZ123456'


In [3]:
def decrypt_block(prev_block: bytes, target_block: bytes, oracle=None) -> bytes:
    """
    Decrypt a single block using the padding oracle attack.

    Byte positions are attacked from last to first. For each one, a forged
    "previous block" is sent to the oracle together with ``target_block``.
    When the oracle accepts the padding, the forged byte reveals the
    intermediate value D_k(target_block) at that position. XORing it with
    the real ``prev_block`` gives the plaintext byte.

    Parameters:
        prev_block (bytes): The previous ciphertext block (or IV), 16 bytes.
        target_block (bytes): The ciphertext block to decrypt, 16 bytes.
        oracle (callable, optional): Called with ``forged_block + target_block``;
            must return True when the resulting padding is valid. Defaults to
            the module-level ``padding_oracle``.

    Returns:
        bytes: The recovered 16-byte plaintext block.

    Raises:
        ValueError: If either block is not 16 bytes long, or if no guess
            yields valid padding for some byte (unreliable oracle).
    """
    block_size = 16
    if len(prev_block) != block_size or len(target_block) != block_size:
        raise ValueError(
            f"blocks must be {block_size} bytes, got "
            f"{len(prev_block)} and {len(target_block)}"
        )
    if oracle is None:
        oracle = padding_oracle  # looked up at call time

    intermediate = bytearray(block_size)  # D_k(target_block), filled right-to-left

    for byte_index in range(block_size - 1, -1, -1):
        padding_value = block_size - byte_index

        # Force every already-solved byte to decrypt to `padding_value`;
        # positions not yet attacked are left as zero bytes.
        suffix = bytes(b ^ padding_value for b in intermediate[byte_index + 1:])
        prefix = bytes(byte_index)

        for guess in range(256):
            fake_block = prefix + bytes([guess]) + suffix
            if not oracle(fake_block + target_block):
                continue

            if byte_index == block_size - 1:
                # A hit on the last byte may be longer padding (e.g. b"\x02\x02")
                # if the preceding byte happens to decrypt to a matching value.
                # Changing that byte breaks any padding longer than one byte, so
                # a second acceptance confirms the padding really is b"\x01".
                tweaked = bytearray(fake_block)
                tweaked[byte_index - 1] ^= 0x01
                if not oracle(bytes(tweaked) + target_block):
                    continue

            intermediate[byte_index] = guess ^ padding_value
            break
        else:
            # No guess worked: returning a zero byte would silently corrupt the result.
            raise ValueError(
                f"no guess produced valid padding at byte {byte_index}; "
                "the oracle may be unreliable"
            )

    return bytes(i ^ p for i, p in zip(intermediate, prev_block))


In [4]:
def padding_oracle_attack(ciphertext: bytes) -> bytes:
    """
    Perform the padding oracle attack on the full ciphertext.

    The input is ``IV || C1 || ... || Cn``. Each block ``Ci`` is decrypted with
    ``decrypt_block`` using the block before it (the IV for ``C1``) as the
    previous block.

    Parameters:
        ciphertext (bytes): Full ciphertext including IV + encrypted blocks.

    Returns:
        bytes: The recovered plaintext, still carrying its PKCS#7 padding.

    Raises:
        ValueError: If the length is not a multiple of 16, or if there is no
            ciphertext block after the IV.
    """
    block_size = 16
    if len(ciphertext) % block_size != 0:
        raise ValueError(
            f"ciphertext length {len(ciphertext)} is not a multiple of {block_size}"
        )
    if len(ciphertext) < 2 * block_size:
        raise ValueError("ciphertext must contain an IV and at least one block")

    blocks = split_blocks(ciphertext)
    recovered_blocks = []

    # Pair each block with its predecessor: (IV, C1), (C1, C2), ...
    for i, (prev, curr) in enumerate(zip(blocks, blocks[1:]), start=1):
        print(f"[*] Decrypting block {i}...")
        recovered_blocks.append(decrypt_block(prev, curr))

    # join avoids re-copying the accumulated bytes on every block.
    return b"".join(recovered_blocks)


In [None]:
def unpad_and_decode(plaintext: bytes) -> str:
    """
    Strip PKCS#7 padding (16-byte blocks) and decode the result as UTF-8.

    Parameters:
        plaintext (bytes): Recovered plaintext that still carries its padding.

    Returns:
        str: The message text without padding.

    Raises:
        ValueError: From the unpadder when the padding is invalid.
        UnicodeDecodeError: If the unpadded bytes are not valid UTF-8.
    """
    block_bits = 16 * 8  # PKCS7() takes the block size in bits
    unpadder = padding.PKCS7(block_bits).unpadder()
    # update() runs before finalize(): tuple elements evaluate left to right.
    stripped = b"".join((unpadder.update(plaintext), unpadder.finalize()))
    return stripped.decode('utf-8')


In [None]:
if __name__ == "__main__":
    # Imported here so the script does not depend on an import made in another cell.
    # NOTE(review): CIPHERTEXT_HEX (and KEY, used by padding_oracle) are expected
    # to be defined earlier in the notebook.
    from binascii import unhexlify

    try:
        ciphertext = unhexlify(CIPHERTEXT_HEX)

        print(f"[*] Ciphertext length: {len(ciphertext)} bytes")
        # The IV is the first AES block (16 bytes). A literal is used because, in
        # the code shown here, BLOCK_SIZE is only a local of padding_oracle().
        print(f"[*] IV: {ciphertext[:16].hex()}")

        recovered = padding_oracle_attack(ciphertext)

        print("\n[+] Decryption complete!")
        print(f"Recovered plaintext (raw bytes): {recovered}")
        print(f"Hex: {recovered.hex()}")

        # The attack recovers the PKCS#7 padding too; strip it before decoding.
        decoded = unpad_and_decode(recovered)
        print("\nFinal plaintext:")
        print(decoded)

    except Exception as e:
        # Top-level boundary: report the failure instead of a traceback.
        print(f"\nError occurred: {e}")

In [None]:
## Final Output

In [None]:
### Final Decrypted Message:
This is a top secret message. Decrypt me if you can!