In [60]:
from pynq import Overlay, MMIO
import pynq
import time

In [61]:
class AES128IP:
    """
    Simple driver for the AES-128 AXI4-Lite peripheral on PYNQ.

    Writes a 128-bit key and a 128-bit plaintext block into the peripheral's
    registers and reads back the 128-bit ciphertext.

    Register map (byte offsets from ``base_addr``), four 32-bit words each:
        0x00-0x0C  plaintext (state)
        0x10-0x1C  key
        0x20-0x2C  ciphertext (read back)
    """

    # Register map offsets; each region spans four consecutive 32-bit words.
    OFF_STATE = 0x00  # plaintext
    OFF_KEY = 0x10    # key
    OFF_CIPH = 0x20   # ciphertext

    # Size of an AES block / AES-128 key, in bytes.
    BLOCK_BYTES = 16
    # Width of one AXI4-Lite data register, in bytes.
    WORD_BYTES = 4

    def __init__(self, bitfile: str, base_addr: int, addr_range: int = 0x10000):
        """
        Load the FPGA overlay and map the AES peripheral's register window.

        Args:
            bitfile: path of the bitstream handed to ``pynq.Overlay``.
            base_addr: physical base address of the AES peripheral.
            addr_range: size in bytes of the MMIO window to map.
        """
        # Load the FPGA overlay before touching any of its registers.
        self.overlay = Overlay(bitfile)
        # Direct register access to the AES IP at its base address.
        self.mmio = MMIO(base_addr, addr_range)

    def encrypt(self, key: bytes, plaintext: bytes) -> bytes:
        """
        Encrypt a single 16-byte block with a 16-byte AES-128 key.

        Both arguments are in standard AES byte order (the order used by the
        FIPS-197 example vectors): byte 0 is the first byte of the block/key.
        Each group of four bytes is packed big-endian into one 32-bit
        register, so bytes 0-3 go to the register at the lowest offset.

        Args:
            key: 16-byte AES-128 key (any bytes-like object).
            plaintext: 16-byte plaintext block (any bytes-like object).

        Returns:
            The 16-byte ciphertext block, in the same byte order.

        Raises:
            ValueError: if ``key`` or ``plaintext`` is not exactly 16 bytes.
        """
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if len(key) != self.BLOCK_BYTES:
            raise ValueError("Key must be 16 bytes")
        if len(plaintext) != self.BLOCK_BYTES:
            raise ValueError("Plaintext must be 16 bytes")

        # Write order (plaintext first, then key) is kept from the original;
        # the core may depend on it, so do not reorder without checking the IP.
        self._write_block(self.OFF_STATE, plaintext)
        self._write_block(self.OFF_KEY, key)

        # NOTE(review): the ciphertext is read immediately after the last key
        # write, with no status/done polling. This relies on the core's result
        # being ready by the first read -- confirm against the IP's latency or
        # handshake before reusing this driver with a different core.
        return self._read_block(self.OFF_CIPH)

    def _write_block(self, offset: int, data: bytes) -> None:
        """Write 16 bytes as four big-endian 32-bit words starting at ``offset``."""
        for pos in range(0, self.BLOCK_BYTES, self.WORD_BYTES):
            word = int.from_bytes(data[pos:pos + self.WORD_BYTES], 'big')
            self.mmio.write(offset + pos, word)

    def _read_block(self, offset: int) -> bytes:
        """Read four 32-bit words starting at ``offset`` and unpack them big-endian."""
        # int() is defensive, in case the MMIO read returns a non-`int` integer type.
        return b''.join(
            int(self.mmio.read(offset + pos)).to_bytes(self.WORD_BYTES, 'big')
            for pos in range(0, self.BLOCK_BYTES, self.WORD_BYTES)
        )

In [62]:
# Example usage: program the PL with the AES bitstream and bind the driver.
AES_BITFILE = 'aes_128_overlay.bit'
# Physical address of the AES IP -- presumably taken from the Vivado address
# editor; confirm it matches the block design if the bitstream changes.
AES_BASE_ADDR = 0x43C00000
aes = AES128IP(bitfile=AES_BITFILE, base_addr=AES_BASE_ADDR)

In [63]:
# Known-answer test: FIPS-197 Appendix B example (key, plaintext, ciphertext).
key = bytes.fromhex('2b7e151628aed2a6abf7158809cf4f3c')
pt  = bytes.fromhex('3243f6a8885a308d313198a2e0370734')
ct  = bytes.fromhex('3925841d02dc09fbdc118597196a0b32')
result = aes.encrypt(key, pt)
print(result.hex())
print(result == ct)
# Fail loudly so Restart & Run All stops on a hardware mismatch.
assert result == ct, f"AES KAT mismatch: got {result.hex()}, expected {ct.hex()}"

3925841d02dc09fbdc118597196a0b32
True


In [64]:
# Known-answer test: all-zero key and all-zero plaintext block.
key = bytes(16)
pt  = bytes(16)
ct  = bytes.fromhex('66e94bd4ef8a2c3b884cfa59ca342b2e')
result = aes.encrypt(key, pt)
print(result.hex())
print(result == ct)
# Fail loudly so Restart & Run All stops on a hardware mismatch.
assert result == ct, f"AES KAT mismatch: got {result.hex()}, expected {ct.hex()}"

66e94bd4ef8a2c3b884cfa59ca342b2e
True
