In [94]:
# AES reduction polynomial m(x) = x^8 + x^4 + x^3 + x + 1 (0x11B).
IRRED_POLY = 0b1_0001_1011

In [95]:
import numpy as np

# Basic Functions

In [96]:
def process_key(key, Nk=4):
    """Parse a hexadecimal key string into an ``(Nk, 4)`` array of byte values.

    Spaces are stripped first, so ``'2b 7e 15 16 ...'`` and ``'2b7e1516...'``
    are equivalent. Row ``i`` of the result holds key bytes ``4*i .. 4*i+3``
    (one 32-bit key word per row).

    Parameters
    ----------
    key : str
        Hex string encoding exactly ``4 * Nk`` bytes.
    Nk : int, default 4
        Number of 32-bit words in the key (4 for AES-128).

    Returns
    -------
    numpy.ndarray or None
        Integer array of shape ``(Nk, 4)``. On malformed input (non-hex
        characters or wrong length) a message is printed and ``None`` is
        returned.

    Raises
    ------
    TypeError
        If ``key`` is not a string (a programming error, not bad user data).
    """
    if not isinstance(key, str):
        raise TypeError(f"key must be a hex string, got {type(key).__name__}")

    key = key.replace(" ", "")

    try:
        # bytes.fromhex accepts only hex-digit pairs. int(pair, 16) would also
        # accept pairs such as '-1' or '+f' and yield negative/bogus bytes.
        raw = bytes.fromhex(key)
    except ValueError:
        print("Password must be hexadecimal.")
        return None

    if len(raw) != 4 * Nk:
        print(f"Key must be exactly {4 * Nk} bytes ({8 * Nk} hex digits).")
        return None

    # Length is validated above, so -1 resolves to Nk rows.
    return np.array(list(raw)).reshape((-1, 4))

def process_input(input_text):
    """Parse a 16-byte hex plaintext block into the 4x4 AES state.

    The bytes are parsed row-wise by ``process_key`` and then transposed, so
    column ``c`` of the returned array holds input bytes ``4*c .. 4*c+3``
    (column-major state layout).

    Returns
    -------
    numpy.ndarray or None
        The ``(4, 4)`` state, or ``None`` (after printing a message) when
        ``input_text`` is not a valid 16-byte hex string.
    """
    words = process_key(input_text)
    if words is None:
        # process_key has already printed the specific parse problem.
        print("Input must be a 16-byte hexadecimal string.")
        return None
    return words.T

def bytes_to_word(bytes_list):
    """Pack four bytes (most significant first) into one 32-bit word.

    1 word = 4 bytes = 32 bits, so ``[0x2b, 0x7e, 0x15, 0x16]`` becomes
    ``0x2b7e1516``. Any length-4 sequence of ints works, including a NumPy row.

    Raises
    ------
    ValueError
        If ``bytes_list`` does not contain exactly four items.
    """
    if len(bytes_list) != 4:
        raise ValueError(f"expected 4 bytes, got {len(bytes_list)}")

    b0, b1, b2, b3 = bytes_list
    return (b0 << 24) | (b1 << 16) | (b2 << 8) | b3

def word_to_bytes(word):
    """Split a 32-bit word into its four bytes, most significant first."""
    return [(word >> shift) & 0xFF for shift in (24, 16, 8, 0)]

# Mathematics

In [97]:
# GF(2^8) multiplication
def gf_mult(a, b, poly=None):
    """Multiply two bytes in GF(2^8) using the shift-and-add (Russian-peasant) method.

    Parameters
    ----------
    a, b : int
        Field elements in ``0x00..0xFF``.
    poly : int, optional
        Degree-8 reduction polynomial including the x^8 term. Defaults to the
        module-level ``IRRED_POLY``, looked up at call time so a later
        reassignment of that constant is honoured.

    Returns
    -------
    int
        ``a * b`` reduced modulo ``poly``, in ``0x00..0xFF``.
    """
    if poly is None:
        poly = IRRED_POLY

    result = 0
    for _ in range(8):
        if b & 1:
            result ^= a        # add (XOR) the current multiple of a
        a <<= 1                # multiply a by x
        if a & 0x100:
            a ^= poly          # reduce as soon as a reaches degree 8
        b >>= 1
    return result & 0xFF  # Ensure the result is within 0x00 to 0xFF

In [98]:
# Power table of a GF(2^8) element - book Definition 4.3.4
def generate_gf2_8(primitive_element):
    """Return ``[0, g, g**2, ..., g**255]`` for ``g = primitive_element``.

    Powers are computed with ``gf_mult``. The list always has 256 entries,
    but it lists every field element exactly once only when ``g`` generates
    the multiplicative group. Under the AES polynomial 0x11B, 0x03 is such a
    generator; 0x02 is not (its multiplicative order is 51), so its table
    repeats values.
    """
    powers = [0]
    current = primitive_element
    while len(powers) < 256:
        powers.append(current)
        current = gf_mult(current, primitive_element)
    return powers

In [99]:
# Power tables of 0x02 and 0x03. NOTE: 0x02 does not generate GF(2^8)* under
# 0x11B, so gfp2 contains repeated values. Neither table is referenced by the
# AES cells visible in this notebook.
gfp2, gfp3 = (generate_gf2_8(base) for base in (2, 3))

# AES Functions

## SubBytes transformation

In [100]:
from BitVector import *

def gen_sbox():
    """Build the 256-entry AES S-box as a list of ints.

    Entry ``v`` is the GF(2^8) multiplicative inverse of ``v`` (with 0 mapped
    to 0) passed through the AES affine transformation, computed with
    BitVector.
    """
    modulus = BitVector(bitstring='100011011')   # x^8 + x^4 + x^3 + x + 1
    affine_c = BitVector(bitstring='01100011')   # affine constant 0x63

    table = []
    for value in range(256):
        if value == 0:
            inv = BitVector(intVal=0)
        else:
            inv = BitVector(intVal=value, size=8).gf_MI(modulus, 8)

        # BitVector's `>>` rotates the vector in place (and returns it), so
        # each rotation is applied to an independent copy of `inv`.
        rotated = [inv.deep_copy() >> amount for amount in (4, 5, 6, 7)]
        inv ^= rotated[0] ^ rotated[1] ^ rotated[2] ^ rotated[3] ^ affine_c

        table.append(int(inv))
    return table

In [101]:
sbox = gen_sbox()

# Sanity-check against FIPS-197: the S-box is a permutation of 0..255 with
# S(0x00) = 0x63 and S(0x53) = 0xED.
assert len(sbox) == 256 and sorted(sbox) == list(range(256))
assert sbox[0x00] == 0x63 and sbox[0x53] == 0xED

In [102]:
def SubBytes(state):
    """Replace every byte of ``state`` with its S-box value.

    Returns a new list of lists with the same shape; ``state`` is not modified.
    """
    substituted = []
    for row in state:
        substituted.append([sbox[value] for value in row])
    return substituted

## ShiftRows transformation

In [103]:
def ShiftRows(state):
    """Cyclically shift row ``r`` of the state left by ``r`` positions.

    ``state`` is indexed ``state[row][column]``. Lists of lists and 2-D NumPy
    arrays are both accepted; a fresh list of lists is always returned and
    the input is never modified.
    """
    if len(state) == 0:
        return []

    n_cols = len(state[0])
    # Build brand-new rows: copying rows with `row[:]` and writing into them
    # is unsafe for NumPy input, where the slice is a view of the row being
    # read (and of the caller's array).
    return [
        [row[(r + c) % n_cols] for c in range(n_cols)]
        for r, row in enumerate(state)
    ]

## MixColumns transformation

In [104]:
def MixColumns(state):
    """Mix each column of the 4-row AES state with the fixed MDS matrix.

    Each column ``(s0, s1, s2, s3)`` is multiplied in GF(2^8) by::

        | 2 3 1 1 |
        | 1 2 3 1 |
        | 1 1 2 3 |
        | 3 1 1 2 |

    ``state`` is indexed ``state[row][column]``. It must have four rows but
    may have any number of columns. Returns a new list of lists.
    """
    if len(state) == 0:
        return []

    # The loop runs over columns, so size it by the column count
    # (len(state[0])) rather than the row count (len(state)).
    n_cols = len(state[0])
    mixed = [[0] * n_cols for _ in range(4)]

    for c in range(n_cols):
        s0, s1, s2, s3 = (state[r][c] for r in range(4))

        mixed[0][c] = gf_mult(2, s0) ^ gf_mult(3, s1) ^ s2 ^ s3
        mixed[1][c] = s0 ^ gf_mult(2, s1) ^ gf_mult(3, s2) ^ s3
        mixed[2][c] = s0 ^ s1 ^ gf_mult(2, s2) ^ gf_mult(3, s3)
        mixed[3][c] = gf_mult(3, s0) ^ s1 ^ s2 ^ gf_mult(2, s3)

    return mixed

## Key Expansion

### Generate round constants (Rcon)

In [105]:
def generate_rc_values(Nr=10):
    """Return the AES round constants RC[1..Nr] as ints.

    RC[1] = 0x01 and each later constant is the previous one multiplied by
    x (0x02) in GF(2^8); for Nr=10 this gives 01 02 04 08 10 20 40 80 1b 36.
    At least one constant is always returned.
    """
    constants = [0x01]
    while len(constants) < Nr:
        constants.append(gf_mult(constants[-1], 0x02))
    return constants

# Round constants read by keyExpansion (via g).
rc_values = generate_rc_values()

### g function

In [106]:
def g(word, rc):
    """Key-schedule core: RotWord, then SubWord, then XOR ``rc`` into the leading byte.

    ``word`` and the return value are 32-bit ints.
    """
    b0, b1, b2, b3 = word_to_bytes(word)
    rotated = (b1, b2, b3, b0)                     # RotWord: rotate left by one byte
    substituted = [sbox[b] for b in rotated]       # SubWord
    substituted[0] ^= rc                           # Rcon only affects the first byte
    return bytes_to_word(substituted)

In [107]:
def keyExpansion(key, Nb=4, Nk=4, Nr=10):
    """Expand a cipher key into the AES key schedule (FIPS-197 KeyExpansion).

    Parameters
    ----------
    key : sequence of Nk length-4 byte sequences
        The key as produced by ``process_key``; row ``i`` is key word ``i``.
    Nb : int, default 4
        Number of 32-bit words per round key (columns of the state).
    Nk : int, default 4
        Number of 32-bit words in the key (4, 6 or 8 for AES-128/192/256).
    Nr : int, default 10
        Number of rounds (10, 12 or 14 for AES-128/192/256).

    Returns
    -------
    list of int
        ``Nb * (Nr + 1)`` 32-bit words; round ``r`` uses ``w[Nb*r : Nb*(r+1)]``.
        (The previous length formula ``Nb + Nk*Nr`` only matched this for Nk=4.)

    Raises
    ------
    ValueError
        If ``key`` does not contain exactly ``Nk`` words.

    Uses the module-level ``rc_values``; its default 10 constants cover
    AES-128, -192 and -256.
    """
    if len(key) != Nk:
        raise ValueError(f"expected a key of {Nk} words, got {len(key)}")

    # The first Nk schedule words are the key words themselves.
    w = [bytes_to_word(word) for word in key]

    for i in range(Nk, Nb * (Nr + 1)):
        temp = w[i - 1]
        if i % Nk == 0:
            # First word of each key-length block: RotWord + SubWord + Rcon.
            temp = g(temp, rc=rc_values[i // Nk - 1])
        elif Nk > 6 and i % Nk == 4:
            # Nk > 6 (AES-256): an extra SubWord half-way through each block.
            temp = bytes_to_word([sbox[b] for b in word_to_bytes(temp)])
        w.append(w[i - Nk] ^ temp)

    return w

## AddRoundKey transformation

In [108]:
def AddRoundKey(state, keys):
    """XOR one round key into the state.

    ``state`` is indexed ``state[row][column]``; ``keys`` holds one 32-bit
    word per column, and byte ``r`` (most significant first) of word ``c`` is
    XORed into ``state[r][c]``. Returns a new list of lists.
    """
    key_columns = [word_to_bytes(word) for word in keys]

    # Size rows and columns independently; len(state) alone is the row count.
    n_rows = len(state)
    n_cols = len(state[0]) if n_rows else 0
    return [
        [state[r][c] ^ key_columns[c][r] for c in range(n_cols)]
        for r in range(n_rows)
    ]

# AES-128

In [109]:
# Variables for AES-128
Nb = 4          # 32-bit words (columns) in the state
Nk = 128 // 32  # key length in 32-bit words; `//` keeps this an int (not 4.0)
Nr = 10         # number of rounds

# Re-assert the defaults set in earlier cells (identical values).
IRRED_POLY = 0x11B
sbox = gen_sbox()

# FIPS-197 Appendix B test vector.
key = '2b 7e 15 16 28 ae d2 a6 ab f7 15 88 09 cf 4f 3c'
input_str = '32 43 f6 a8 88 5a 30 8d 31 31 98 a2 e0 37 07 34'


In [110]:
# Keep the hex string in `key` untouched so this cell can be re-run safely.
key_matrix = process_key(key)
assert key_matrix is not None, "invalid key"
W = keyExpansion(key_matrix)

In [111]:
block = process_input(input_str)
# Fail fast here instead of with a confusing error in the round cells.
assert block is not None, "invalid input block"

In [112]:
def print_state_hex(state):
    """Print ``state`` one row per line as two-digit lowercase hex bytes.

    Every byte, including the last in a row, is followed by one space, and
    each row ends with a newline.
    """
    for row in state:
        print("".join(f"{value:02x} " for value in row))

In [113]:
# Show the initial state row by row; column c holds input bytes 4c..4c+3.
print_state_hex(block)

32 88 31 e0 
43 5a 31 37 
f6 30 98 07 
a8 8d a2 34 


In [114]:
# Encrypt into a separate `state` so `block` keeps the plaintext and this cell
# can be re-run without encrypting the ciphertext a second time.

# Initial round
state = AddRoundKey(block, W[:Nb])

# Main rounds 1 .. Nr-1
for r in range(1, Nr):
    state = SubBytes(state)
    state = ShiftRows(state)
    state = MixColumns(state)
    state = AddRoundKey(state, W[Nb * r : Nb * (r + 1)])

# Final round (no MixColumns)
state = SubBytes(state)
state = ShiftRows(state)
state = AddRoundKey(state, W[-Nb:])

print_state_hex(state)

# FIPS-197 Appendix B ciphertext 3925841d02dc09fbdc118597196a0b32 as a state.
expected = [
    [0x39, 0x02, 0xdc, 0x19],
    [0x25, 0xdc, 0x11, 0x6a],
    [0x84, 0x09, 0x85, 0x0b],
    [0x1d, 0xfb, 0x97, 0x32],
]
assert [[int(b) for b in row] for row in state] == expected

39 02 dc 19 
25 dc 11 6a 
84 09 85 0b 
1d fb 97 32 
