In [None]:
from pynq import Overlay, MMIO
import pynq
import time

In [None]:
class AES_128_Decryption_IP:
    """
    Simple driver for the AES-128 decryption AXI4-Lite peripheral on PYNQ.

    Writes a 128-bit ciphertext block and a 128-bit key into the IP's
    registers, then reads back the 128-bit plaintext block.

    Register map (byte offsets from ``base_addr``; every field is four
    consecutive 32-bit words):

    * ``OFF_STATE`` (0x00) -- ciphertext input
    * ``OFF_KEY``   (0x10) -- key input
    * ``OFF_PLAIN`` (0x20) -- plaintext output
    """

    # Register map offsets (also reachable as ``self.OFF_*`` on instances).
    OFF_STATE = 0x00  # four 32-bit words of ciphertext
    OFF_KEY = 0x10    # four 32-bit words of key
    OFF_PLAIN = 0x20  # four 32-bit words of plaintext

    _BLOCK_BYTES = 16  # one AES block / AES-128 key
    _WORD_BYTES = 4    # width of one AXI4-Lite register

    def __init__(self, bitfile: str, base_addr: int, addr_range: int = 0x10000):
        """
        Load the overlay and map the AES IP's register window.

        bitfile:    path of the bitstream handed to ``pynq.Overlay``.
        base_addr:  physical base address of the AES IP.
        addr_range: size in bytes of the mapped register window.
        """
        # Load FPGA overlay
        self.overlay = Overlay(bitfile)
        # Lists the overlay's attributes; kept as a bring-up aid.
        print(dir(self.overlay))
        # Instantiate MMIO to the AES IP base address
        self.mmio = MMIO(base_addr, addr_range)

    def _write_block(self, offset: int, data: bytes) -> None:
        """Write 16 bytes to four consecutive registers starting at offset."""
        step = self._WORD_BYTES
        for pos in range(0, self._BLOCK_BYTES, step):
            # Each group of four bytes becomes one big-endian 32-bit word.
            word = int.from_bytes(data[pos:pos + step], 'big')
            # Every register goes through the public MMIO.write accessor so
            # ciphertext and key share the same access path.
            self.mmio.write(offset + pos, word)

    def _read_block(self, offset: int) -> bytes:
        """Read four consecutive registers starting at offset as 16 bytes."""
        step = self._WORD_BYTES
        return b"".join(
            self.mmio.read(offset + pos).to_bytes(step, 'big')
            for pos in range(0, self._BLOCK_BYTES, step)
        )

    def decrypt(self, key: bytes, ciphertext: bytes) -> bytes:
        """
        Decrypts a single 16-byte block with a 16-byte key.

        key, ciphertext: 16-byte sequences. Bytes 4*i .. 4*i+3 are packed
        as a big-endian 32-bit word and written to register word i of the
        corresponding field.

        Returns the 16-byte plaintext, unpacked from the output registers
        with the same word layout.

        Raises AssertionError if either argument is not exactly 16 bytes.
        """
        assert len(key) == 16, "Key must be 16 bytes"
        assert len(ciphertext) == 16, "Ciphertext must be 16 bytes"
        # Ciphertext goes in first, then the key (order kept as-is).
        self._write_block(self.OFF_STATE, ciphertext)
        self._write_block(self.OFF_KEY, key)
        # NOTE(review): the result is read immediately after the writes; if
        # the core needs time to finish, a delay or status poll belongs
        # here -- confirm against the IP's documentation.
        return self._read_block(self.OFF_PLAIN)

In [None]:
# Example usage: load the decryption overlay and attach the driver to the
# AES IP's base address.
aes = AES_128_Decryption_IP(
    base_addr=0x43C00000,
    bitfile='aes128_decryption_overlay.bit',
)

In [None]:
# Known-answer check: decrypt one ciphertext block under the all-zero key and
# compare against the expected all-zero plaintext block.
key = bytes.fromhex('00000000000000000000000000000000')
ct  = bytes.fromhex('66e94bd4ef8a2c3b884cfa59ca342b2e')
pt  = bytes.fromhex('00000000000000000000000000000000')
result = aes.decrypt(key, ct)
# First line: decrypted block as hex; second line: whether it matches pt.
print(result.hex(), result == pt, sep='\n')