In [89]:
IRRED_POLY = 0x11B

# GF Multiplication
def gf_mult(a, b):
    result = 0
    for i in range(8):
        if b & 1:
            result ^= a
        a <<= 1
        if a & 0x100:
            a ^= IRRED_POLY
        b >>= 1
    return result & 0xFF  # Ensure the result is within 0x00 to 0xFF

# Generate elements of GF(2^8)
def generate_gf2_8(primitive_element):
    field = [0]  # Start with 0
    element = primitive_element

    for _ in range(1, 256):
        field.append(element)
        element = gf_mult(element, primitive_element)

    return field

gfp2 = generate_gf2_8(2)
gfp3 = generate_gf2_8(3)

# SubByte Transformation

In [1]:
from BitVector import *
import numpy as np

def gen_sbox():
    irred_poly = BitVector(bitstring='100011011') # irreducible polynomial
    subBytesTable = [] 

    c = BitVector(bitstring='01100011')

    for i in range(0, 256):
        # Get 
        b = BitVector(intVal = i, size=8)

        # GF(2) Inverse
        b = b.gf_MI(irred_poly, 8) if i != 0 else BitVector(intVal=0)

        # Affine mapping
        b1,b2,b3,b4 = [b.deep_copy() for x in range(4)]
        b ^= (b1 >> 4) ^ (b2 >> 5) ^ (b3 >> 6) ^ (b4 >> 7) ^ c

        subBytesTable.append(int(b))
    return subBytesTable

In [2]:
sbox = gen_sbox()

In [92]:
def SubBytes(state):
    return [[sbox[byte] for byte in word] for word in state]

# ShiftRow transformation

In [93]:
def ShiftRows(state):
    Nb = len(state)
    n = [word[:] for word in state] # temp state
    # print(n)

    for i in range(Nb):
        for j in range(4): # shift rows by i
            n[i][j] = state[i][(i + j) % Nb]
            # print((i + j) % Nb)

    return n

# MixColumn transformation

In [94]:
def MixColumns(state):
    Nb = len(state)
    n = [word[:] for word in state]

    for i in range(Nb):
        n[i][0] = (gfp2[state[i][0]] ^ gfp3[state[i][1]] ^ state[i][2] ^ state[i][3])
        n[i][1] = (state[i][0] ^ gfp2[state[i][1]] ^ gfp3[state[i][2]] ^ state[i][3])
        n[i][2] = (state[i][0] ^ state[i][1] ^ gfp2[state[i][2]] ^ gfp3[state[i][3]])
        n[i][3] = (gfp3[state[i][0]] ^ state[i][1] ^ state[i][2] ^ gfp2[state[i][3]])

    return n

In [None]:
# --- Example Input ---
input_bytes = [
    0x32, 0x88, 0x31, 0xe0,
    0x43, 0x5a, 0x31, 0x37,
    0xf6, 0x30, 0x98, 0x07,
    0xa8, 0x8d, 0xa2, 0x34
]

# Convert input to state matrix
state = input_bytes
print("Original State:")
print(state)

# Apply ShiftRows
shifted_state = ShiftRows(state)
print("\nAfter ShiftRows:")
print(shifted_state)

# Apply MixColumns
shifted_state = MixColumns(state)
print("\nAfter MixColumns:")
print(shifted_state)

Original State:
[[50, 136, 49, 224], [67, 90, 49, 55], [246, 48, 152, 7], [168, 141, 162, 52]]

After ShiftRows:
[[50, 136, 49, 224], [90, 49, 55, 67], [152, 7, 246, 48], [52, 168, 141, 162]]

After MixColumns:
[[1, 154, 99, 42], [235, 203, 30, 210], [170, 66, 180, 138], [172, 3, 7, 175]]
