# Full AES Round

**Module 03f** | Galois Fields and AES

*SubBytes, ShiftRows, MixColumns, AddRoundKey: the complete cipher round.*

> **Question:** You've built every piece of AES separately: the S-box (03d), MixColumns (03e), and the GF(256) arithmetic underneath (03a-03c). Now, can you put them together into a working AES round and trace a plaintext byte through all four operations?
>
> In this notebook, you'll build a complete AES round from scratch and watch the avalanche effect unfold.

## Objectives

By the end of this notebook you will be able to:

1. Implement all four AES round operations from scratch
2. Compose them into a complete AES round
3. Trace a byte through the round and explain each transformation
4. Demonstrate the avalanche effect: a one-bit change spreads across a column in one round and, over the full cipher, flips about half the output bits
5. Verify your implementation against known AES test vectors

## Bridge from 03e

In 03d you built SubBytes (nonlinear, per-byte). In 03e you built MixColumns (linear, per-column). Now we add the two remaining operations, ShiftRows and AddRoundKey, and compose all four into a single round. This is the heart of AES.

## Setup: GF(256) and S-box

In [None]:
# Complete AES setup: field, S-box, and utilities
P.<x> = GF(2)[]
F.<a> = GF(2^8, modulus=x^8 + x^4 + x^3 + x + 1)

def byte_to_gf(b):
    return sum(GF(2)((b >> i) & 1) * a^i for i in range(8))

def gf_to_byte(elem):
    p = elem.polynomial()
    return sum(int(p[i]) << i for i in range(8))

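def xtime(b):
    """Multiply byte b by x (0x02) in GF(2^8), reducing modulo 0x11B."""
    result = b << 1
    if result & 0x100:
        result ^^= 0x11B
    return result & 0xFF

def gf256_mul(u, v):
    """Multiply bytes u and v in GF(2^8) by shift-and-add."""
    # Parameters are named u, v so the field generator `a` is not shadowed.
    result = 0
    temp = u
    for i in range(8):
        if v & (1 << i):
            result ^^= temp
        temp = xtime(temp)
    return result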

# Build S-box
A_mat = matrix(GF(2), [
    [1,0,0,0,1,1,1,1],[1,1,0,0,0,1,1,1],[1,1,1,0,0,0,1,1],[1,1,1,1,0,0,0,1],
    [1,1,1,1,1,0,0,0],[0,1,1,1,1,1,0,0],[0,0,1,1,1,1,1,0],[0,0,0,1,1,1,1,1]
])
c_vec = vector(GF(2), [(0x63 >> i) & 1 for i in range(8)])

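SBOX = [0] * 256
for b in range(256):
    # 0 has no multiplicative inverse; AES maps it to 0 by convention
    inv_byte = 0 if b == 0 else int(gf_to_byte(byte_to_gf(b)^(-1)))
    inv_bits = vector(GF(2), [(inv_byte >> i) & 1 for i in range(8)])
    # Affine map over GF(2): matrix multiply, then add the constant 0x63
    result_bits = A_mat * inv_bits + c_vec
    SBOX[b] = sum(int(result_bits[i]) << i for i in range(8))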

print(f'S-box built. SBOX[0x00] = 0x{SBOX[0]:02X}, SBOX[0x53] = 0x{SBOX[0x53]:02X}')
print('All AES utilities ready.')

## The Four AES Round Operations

Each AES round applies four operations in order:

1. **SubBytes**: apply the S-box to each byte (nonlinear, per-byte)
2. **ShiftRows**: cyclically shift each row of the state (permutation)
3. **MixColumns**: matrix-multiply each column over GF(256) (linear, per-column)
4. **AddRoundKey**: XOR the state with the round key

The state is a 4×4 matrix of bytes, filled column by column: input byte $i$ goes to row $i \bmod 4$, column $\lfloor i/4 \rfloor$.

In [None]:
# AES state representation: 4x4 matrix of bytes
def bytes_to_state(data):
    """Convert 16 bytes to 4x4 state (column-major)."""
    state = [[0]*4 for _ in range(4)]
    for i in range(16):
        state[i % 4][i // 4] = data[i]
    return state

def state_to_bytes(state):
    """Convert 4x4 state back to 16 bytes."""
    return [state[i % 4][i // 4] for i in range(16)]

def print_state(state, label='State'):
    print(f'{label}:')
    for row in range(4):
        print(f'  [{" ".join(f"{state[row][col]:02X}" for col in range(4))}]')
    print()

# Test with FIPS 197 Appendix B input
plaintext = [0x32, 0x43, 0xF6, 0xA8, 0x88, 0x5A, 0x30, 0x8D,
             0x31, 0x31, 0x98, 0xA2, 0xE0, 0x37, 0x07, 0x34]
state = bytes_to_state(plaintext)
print_state(state, 'Input state (from FIPS 197)')

## Operation 1: SubBytes

Apply the S-box to every byte in the state. This is the only nonlinear operation.

In [None]:
def sub_bytes(state):
    """Apply S-box to every byte in the state."""
    return [[SBOX[state[r][c]] for c in range(4)] for r in range(4)]

# Demo
test_state = bytes_to_state(plaintext)
print_state(test_state, 'Before SubBytes')
after_sub = sub_bytes(test_state)
print_state(after_sub, 'After SubBytes')
print('Each byte independently replaced by its S-box image.')
print(f'Example: 0x32 → S-box[0x32] = 0x{SBOX[0x32]:02X}')

## Operation 2: ShiftRows

Cyclically shift row $i$ left by $i$ positions:
- Row 0: no shift
- Row 1: shift left by 1
- Row 2: shift left by 2
- Row 3: shift left by 3

Combined with MixColumns, this ensures that each output column of the round depends on bytes from **all four columns** of the input.

In [None]:
def shift_rows(state):
    """Cyclically shift row i left by i positions."""
    result = [row[:] for row in state]  # copy
    for i in range(1, 4):
        result[i] = state[i][i:] + state[i][:i]
    return result

# Demo
print_state(after_sub, 'Before ShiftRows')
after_shift = shift_rows(after_sub)
print_state(after_shift, 'After ShiftRows')
print('Row 0: unchanged')
print('Row 1: shifted left by 1')
print('Row 2: shifted left by 2')
print('Row 3: shifted left by 3')

> **Checkpoint:** ShiftRows is a simple permutation: no arithmetic, no field operations. But it's essential. Without it, MixColumns would only mix within columns, and bytes in different columns would never interact. ShiftRows ensures cross-column diffusion.
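
The cross-column claim can be checked without any field arithmetic. The sketch below is illustrative only: it fills a state with labels instead of bytes (each cell holds the index of the column it started in), applies the same row rotation, and reads off which input columns feed each output column. The names `origin`, `shift_rows_sketch`, and `columns_after` are hypothetical and chosen so they do not overwrite anything in the notebook.

```python
# Label every cell with the column it starts in: origin[r][c] == c
origin = [[c for c in range(4)] for _ in range(4)]

def shift_rows_sketch(state):
    """Cyclically shift row i left by i positions (same rule as ShiftRows)."""
    return [row[i:] + row[:i] for i, row in enumerate(state)]

shifted = shift_rows_sketch(origin)

# Column c of the shifted state, read top to bottom
columns_after = [[shifted[r][c] for r in range(4)] for c in range(4)]

for c, sources in enumerate(columns_after):
    print(f'output column {c} holds bytes from input columns {sources}')
```

Every output column ends up holding one byte from each input column, which is exactly what MixColumns needs in order to spread a change beyond the column it started in.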

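## Operation 3: MixColumns

Multiply each column of the state by the fixed 4×4 matrix `MC` from 03e, with all byte products and sums taken in GF(256). Each output byte is a combination of all four bytes of its input column.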

In [None]:
MC = [[0x02, 0x03, 0x01, 0x01],
      [0x01, 0x02, 0x03, 0x01],
      [0x01, 0x01, 0x02, 0x03],
      [0x03, 0x01, 0x01, 0x02]]

def mix_columns(state):
    """Apply MixColumns to each column of the state."""
    result = [[0]*4 for _ in range(4)]
    for col in range(4):
        for row in range(4):
            for k in range(4):
                result[row][col] ^^= gf256_mul(MC[row][k], state[k][col])
    return result

# Demo
print_state(after_shift, 'Before MixColumns')
after_mix = mix_columns(after_shift)
print_state(after_mix, 'After MixColumns')
print('Each column is now a mix of all four input bytes in that column.')
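
As a standalone sanity check of the column multiplication, the sketch below applies the same matrix to one column and compares the result with the first column of the round 1 MixColumns step in FIPS 197 Appendix B (`d4 bf 5d 30` → `04 66 81 e5`). The helper names ending in `_sketch` are hypothetical. XOR is written as `operator.xor` rather than an infix operator so the block should behave the same in plain Python and in a Sage cell, where `^` means exponentiation.

```python
from functools import reduce
from operator import xor  # bitwise XOR in both plain Python and Sage

MC_SKETCH = [[0x02, 0x03, 0x01, 0x01],
             [0x01, 0x02, 0x03, 0x01],
             [0x01, 0x01, 0x02, 0x03],
             [0x03, 0x01, 0x01, 0x02]]

def gmul_sketch(u, v):
    """Multiply two bytes in GF(2^8) modulo 0x11B by shift-and-add."""
    acc = 0
    for _ in range(8):
        if v & 1:
            acc = xor(acc, u)
        v >>= 1
        u <<= 1
        if u & 0x100:          # degree-8 term appeared: reduce
            u = xor(u, 0x11B)
    return acc

def mix_single_column(col):
    """Multiply one 4-byte column by the MixColumns matrix."""
    return [reduce(xor, [gmul_sketch(MC_SKETCH[r][k], col[k]) for k in range(4)])
            for r in range(4)]

mixed_col = mix_single_column([0xD4, 0xBF, 0x5D, 0x30])
print(' '.join(f'{b:02X}' for b in mixed_col))
```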

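## Operation 4: AddRoundKey

XOR every byte of the state with the matching byte of the round key. Here the round 1 key is taken directly from FIPS 197 Appendix B rather than computed from the cipher key.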

In [None]:
def add_round_key(state, round_key):
    """XOR the state with the round key."""
    return [[state[r][c] ^^ round_key[r][c] for c in range(4)] for r in range(4)]

# FIPS 197 Appendix B, Round 1 key
round_key_bytes = [0xA0, 0xFA, 0xFE, 0x17, 0x88, 0x54, 0x2C, 0xB1,
                   0x23, 0xA3, 0x39, 0x39, 0x2A, 0x6C, 0x76, 0x05]
rk = bytes_to_state(round_key_bytes)

print_state(after_mix, 'Before AddRoundKey')
print_state(rk, 'Round Key')
after_ark = add_round_key(after_mix, rk)
print_state(after_ark, 'After AddRoundKey')
print('AddRoundKey = XOR with round key. This is GF(2) vector addition.')
print('Without this step, AES would be a fixed permutation (no secret key).')
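
One property worth seeing directly is that AddRoundKey is its own inverse: XORing with the same round key twice returns the original state. The sketch below uses arbitrary illustrative byte values (not taken from FIPS 197) and hypothetical names (`add_round_key_sketch`, `demo_state`, `demo_key`).

```python
from operator import xor  # bitwise XOR in both plain Python and Sage

def add_round_key_sketch(state, round_key):
    """XOR two 4x4 byte matrices entry by entry."""
    return [[xor(state[r][c], round_key[r][c]) for c in range(4)] for r in range(4)]

# Arbitrary demo values, chosen only so that every byte is different
demo_state = [[(17 * (4 * r + c) + 3) & 0xFF for c in range(4)] for r in range(4)]
demo_key = [[(29 * (4 * r + c) + 101) & 0xFF for c in range(4)] for r in range(4)]

masked = add_round_key_sketch(demo_state, demo_key)
restored = add_round_key_sketch(masked, demo_key)

print('state changed by the key:', masked != demo_state)
print('second XOR restores it:  ', restored == demo_state)
```

This is why decryption can reuse AddRoundKey unchanged, while the other three operations each need a separate inverse.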

## Complete AES Round

Now let's compose all four operations into a single round function:

In [None]:
def aes_round(state, round_key):
    """Apply one complete AES round."""
    state = sub_bytes(state)
    state = shift_rows(state)
    state = mix_columns(state)
    state = add_round_key(state, round_key)
    return state

# Apply to FIPS 197 test vector
# First: AddRoundKey with initial key (round 0 = pre-whitening)
key_bytes = [0x2B, 0x7E, 0x15, 0x16, 0x28, 0xAE, 0xD2, 0xA6,
             0xAB, 0xF7, 0x15, 0x88, 0x09, 0xCF, 0x4F, 0x3C]
initial_key = bytes_to_state(key_bytes)
state = bytes_to_state(plaintext)

print_state(state, 'Plaintext')
state = add_round_key(state, initial_key)  # Round 0: pre-whitening
print_state(state, 'After initial AddRoundKey (round 0)')

# Round 1
state = aes_round(state, rk)
print_state(state, 'After Round 1')
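
To compare against the published test vector, the state after round 1 can be checked against the "start of round 2" value in FIPS 197 Appendix B. The block below is a self-contained sketch, not the notebook's own implementation: it rebuilds the S-box with a brute-force inverse and the rotation form of the affine map, and every name ending in `_sketch` is hypothetical. XOR is written as `operator.xor` so the block should behave the same in plain Python and in Sage.

```python
from functools import reduce
from operator import xor  # bitwise XOR in both plain Python and Sage

def gmul_sketch(u, v):
    """Multiply two bytes in GF(2^8) modulo 0x11B by shift-and-add."""
    acc = 0
    for _ in range(8):
        if v & 1:
            acc = xor(acc, u)
        v >>= 1
        u <<= 1
        if u & 0x100:
            u = xor(u, 0x11B)
    return acc

def rotl8(value, n):
    """Rotate an 8-bit value left by n bits."""
    return ((value << n) | (value >> (8 - n))) & 0xFF

def build_sbox_sketch():
    """S-box = field inverse (0 maps to 0) followed by the AES affine map."""
    sbox = []
    for b in range(256):
        inv = 0 if b == 0 else next(c for c in range(1, 256) if gmul_sketch(b, c) == 1)
        sbox.append(reduce(xor, [inv] + [rotl8(inv, n) for n in (1, 2, 3, 4)] + [0x63]))
    return sbox

SBOX_SKETCH = build_sbox_sketch()
MC_SKETCH = [[2, 3, 1, 1], [1, 2, 3, 1], [1, 1, 2, 3], [3, 1, 1, 2]]

def to_state(data):
    """16 bytes -> 4x4 state, column-major."""
    return [[data[r + 4 * c] for c in range(4)] for r in range(4)]

def from_state(state):
    """4x4 state -> 16 bytes, column-major."""
    return [state[i % 4][i // 4] for i in range(16)]

def round_sketch(state, round_key):
    """SubBytes -> ShiftRows -> MixColumns -> AddRoundKey."""
    state = [[SBOX_SKETCH[v] for v in row] for row in state]
    state = [row[i:] + row[:i] for i, row in enumerate(state)]
    mixed = [[reduce(xor, [gmul_sketch(MC_SKETCH[r][k], state[k][c]) for k in range(4)])
              for c in range(4)] for r in range(4)]
    return [[xor(mixed[r][c], round_key[r][c]) for c in range(4)] for r in range(4)]

# FIPS 197 Appendix B values
pt = bytes.fromhex('3243f6a8885a308d313198a2e0370734')
key0 = bytes.fromhex('2b7e151628aed2a6abf7158809cf4f3c')
key1 = bytes.fromhex('a0fafe1788542cb123a339392a6c7605')
expected = bytes.fromhex('a49c7ff2689f352b6b5bea43026a5049')  # start of round 2

start = [xor(p, k) for p, k in zip(pt, key0)]          # round 0: initial AddRoundKey
out = from_state(round_sketch(to_state(start), to_state(key1)))
print('matches FIPS 197 Appendix B:', out == list(expected))
```

If the notebook's own `aes_round` is correct, `state_to_bytes(state)` after Round 1 should equal the same 16 expected bytes.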

## The Avalanche Effect

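A good cipher should exhibit the **avalanche effect**: flipping one input bit should change approximately half the output bits. A single round cannot reach that figure. The flipped bit passes through one S-box, ShiftRows keeps it in one byte, and MixColumns spreads it over a single column, so at most 32 of the 128 output bits can differ. Let's measure how far the change gets after just one round.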

In [None]:
# Avalanche effect: flip one bit of plaintext, observe output change
pt_a = plaintext[:]
pt_b = plaintext[:]
pt_b[0] ^^= 0x01  # flip one bit in the first byte

# Apply initial key + one round to both
state_a = add_round_key(bytes_to_state(pt_a), initial_key)
state_a = aes_round(state_a, rk)

state_b = add_round_key(bytes_to_state(pt_b), initial_key)
state_b = aes_round(state_b, rk)

# Count differing bits
out_a = state_to_bytes(state_a)
out_b = state_to_bytes(state_b)
diff_bits = sum(bin(a ^^ b).count('1') for a, b in zip(out_a, out_b))

print(f'Plaintext A: {" ".join(f"{b:02X}" for b in pt_a)}')
print(f'Plaintext B: {" ".join(f"{b:02X}" for b in pt_b)}')
print(f'  (differ by 1 bit in byte 0)')
print()
print(f'After 1 round:')
print(f'Output A: {" ".join(f"{b:02X}" for b in out_a)}')
print(f'Output B: {" ".join(f"{b:02X}" for b in out_b)}')
print(f'  Differing bits: {diff_bits} / 128 ({float(100*diff_bits/128):.1f}%)')
print(f'  One-round ceiling: 32 / 128 (the change is confined to one column)')
print(f'  Full-cipher ideal: ~64 / 128 (50%)')
print()
print('After ONE round, a single bit change has spread across one column.')
print('After two rounds it reaches every byte; full AES-128 runs 10 rounds.')

> **Common mistake:** "More rounds = more security, so why not 100 rounds?" Each round adds computational cost. Full diffusion already takes only two rounds; AES-128 uses 10 to leave a security margin over the best known reduced-round attacks, a choice backed by extensive cryptanalysis. (AES-192 and AES-256 use 12 and 14 rounds.) Piling on far more rounds would slow encryption without a matching gain against known attacks.
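
How fast a change spreads can be followed at byte granularity without running the cipher at all. The sketch below tracks only which bytes differ: SubBytes and AddRoundKey act byte by byte and leave the pattern alone, ShiftRows moves it, and MixColumns turns a column with one differing byte into a column where all four differ (this holds for a single active byte because the branch number of MixColumns is 5). The function names are hypothetical.

```python
def shift_rows_pattern(active):
    """active[r][c] is True where the two states differ; rows rotate as in ShiftRows."""
    return [row[i:] + row[:i] for i, row in enumerate(active)]

def mix_columns_pattern(active):
    """A column containing a differing byte becomes fully differing."""
    col_active = [any(active[r][c] for r in range(4)) for c in range(4)]
    return [[col_active[c] for c in range(4)] for _ in range(4)]

def round_pattern(active):
    # SubBytes and AddRoundKey do not move or merge differences
    return mix_columns_pattern(shift_rows_pattern(active))

pattern = [[False] * 4 for _ in range(4)]
pattern[0][0] = True                      # the difference starts in byte 0

counts = []
for rnd in (1, 2):
    pattern = round_pattern(pattern)
    counts.append(sum(1 for row in pattern for cell in row if cell))
    print(f'after round {rnd}: {counts[-1]} of 16 bytes differ')
```

The jump from one column to the whole state in the second round is the reason a single-round measurement stays well below 50% of the bits.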

## Anatomy of a Round: Why Each Step Matters

| Operation | Type | Purpose |
|-----------|------|--------|
| SubBytes | Nonlinear, per-byte | Confusion: resists linear and differential attacks |
| ShiftRows | Permutation | Cross-column mixing: breaks column isolation |
| MixColumns | Linear, per-column | Diffusion: spreads each byte across its column |
| AddRoundKey | XOR with key | Key dependence: without it, the round is a fixed public permutation |

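Remove any one and the cipher breaks: without SubBytes the whole cipher is affine, without ShiftRows the four columns never interact, without MixColumns each output byte depends on a single input byte, and without AddRoundKey there is no secret.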
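
To make the SubBytes row concrete, the sketch below runs a round with SubBytes deliberately left out and checks the relation that characterises an affine map $f$: $f(s_1) \oplus f(s_2) \oplus f(s_3) = f(s_1 \oplus s_2 \oplus s_3)$. The inputs are arbitrary illustrative bytes and all names (`round_without_subbytes`, `make_demo_state`, and so on) are hypothetical. XOR is written as `operator.xor` so the block should behave the same in plain Python and in Sage.

```python
from functools import reduce
from operator import xor  # bitwise XOR in both plain Python and Sage

MC_SKETCH = [[2, 3, 1, 1], [1, 2, 3, 1], [1, 1, 2, 3], [3, 1, 1, 2]]

def gmul_sketch(u, v):
    """Multiply two bytes in GF(2^8) modulo 0x11B by shift-and-add."""
    acc = 0
    for _ in range(8):
        if v & 1:
            acc = xor(acc, u)
        v >>= 1
        u <<= 1
        if u & 0x100:
            u = xor(u, 0x11B)
    return acc

def round_without_subbytes(state, round_key):
    """ShiftRows -> MixColumns -> AddRoundKey, with SubBytes omitted on purpose."""
    shifted = [row[i:] + row[:i] for i, row in enumerate(state)]
    mixed = [[reduce(xor, [gmul_sketch(MC_SKETCH[r][k], shifted[k][c]) for k in range(4)])
              for c in range(4)] for r in range(4)]
    return [[xor(mixed[r][c], round_key[r][c]) for c in range(4)] for r in range(4)]

def xor_states(*states):
    """Entry-wise XOR of any number of 4x4 states."""
    return [[reduce(xor, [s[r][c] for s in states]) for c in range(4)] for r in range(4)]

def make_demo_state(seed):
    """Arbitrary illustrative bytes, not taken from any test vector."""
    return [[(seed * (4 * r + c + 1) + 7) & 0xFF for c in range(4)] for r in range(4)]

s1, s2, s3 = make_demo_state(3), make_demo_state(5), make_demo_state(11)
demo_rk = make_demo_state(13)

lhs = xor_states(*[round_without_subbytes(s, demo_rk) for s in (s1, s2, s3)])
rhs = round_without_subbytes(xor_states(s1, s2, s3), demo_rk)
print('affine relation holds without SubBytes:', lhs == rhs)
```

A cipher built from such rounds could be solved as a system of linear equations over GF(2); the S-box is what prevents this.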

## Exercises

### Exercise 1 (Worked)

Trace byte 0x32 (position [0,0] of the plaintext) through one complete round.

In [None]:
# Exercise 1 (Worked): trace a single byte through a round
print('Tracing byte at position [0,0] through Round 1:')
print()

# Start: after initial AddRoundKey
state = add_round_key(bytes_to_state(plaintext), initial_key)
val = state[0][0]
print(f'After initial ARK: state[0][0] = 0x{val:02X}')

# SubBytes
sb = SBOX[val]
print(f'After SubBytes:    SBOX[0x{val:02X}] = 0x{sb:02X}')

# ShiftRows: row 0 doesn't shift
print(f'After ShiftRows:   0x{sb:02X} (row 0 = no shift)')

# MixColumns: position [0,0] of the output depends on all 4 bytes of column 0
state_after_sub = sub_bytes(state)
state_after_shift = shift_rows(state_after_sub)
col = [state_after_shift[r][0] for r in range(4)]
print(f'MixColumns input column 0: [{" ".join(f"0x{b:02X}" for b in col)}]')
mc_val = 0
for k in range(4):
    term = gf256_mul(MC[0][k], col[k])
    mc_val ^^= term
    print(f'  0x{MC[0][k]:02X} × 0x{col[k]:02X} = 0x{term:02X}')
print(f'After MixColumns:  0x{mc_val:02X}')

# AddRoundKey
ark_val = mc_val ^^ rk[0][0]
print(f'After AddRoundKey: 0x{mc_val:02X} ⊕ 0x{rk[0][0]:02X} = 0x{ark_val:02X}')
print()
print(f'One byte traveled: 0x{plaintext[0]:02X} → 0x{ark_val:02X} in one round.')

### Exercise 2 (Guided)

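Implement the inverse round (for decryption). To undo `aes_round`, apply the inverse steps in reverse order: AddRoundKey, InvMixColumns, InvShiftRows, InvSubBytes (the last two commute, so their relative order does not matter). Apply it to the round 1 output and verify that you recover the state after the initial AddRoundKey (round 0).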

In [None]:
# Exercise 2 (Guided): inverse round

# Build inverse S-box
INV_SBOX = [0] * 256
for i in range(256):
    INV_SBOX[SBOX[i]] = i

def inv_sub_bytes(state):
    """Apply inverse S-box to every byte."""
    return [[INV_SBOX[state[r][c]] for c in range(4)] for r in range(4)]

def inv_shift_rows(state):
    """Shift row i RIGHT by i positions."""
    result = [row[:] for row in state]
    for i in range(1, 4):
        # TODO: shift row i right by i (= shift left by 4-i)
        result[i] = state[i][4-i:] + state[i][:4-i]  # TODO: verify this
    return result

# TODO: implement inv_mix_columns using the inverse MDS matrix
# The inverse matrix entries are: 0x0E, 0x0B, 0x0D, 0x09
# IMC = [[0x0E, 0x0B, 0x0D, 0x09], ...]

# TODO: compose into inv_aes_round and verify roundtrip
# Hint: order is AddRoundKey → InvMixColumns → InvSubBytes → InvShiftRows
#        (InvSubBytes and InvShiftRows commute, so those two may be swapped)

### Exercise 3 (Independent)

1. Run 4 rounds of AES (you'll need to implement a simple key schedule, or use fixed round keys). After how many rounds does a single-bit plaintext change affect all 128 output bits?

2. What happens if you remove ShiftRows? Apply SubBytes → MixColumns → AddRoundKey for 10 rounds. Can you identify a structural weakness? (Hint: each column stays independent.)

3. Compute the **branch number** of MixColumns experimentally: for random nonzero input differences, what is the minimum number of nonzero bytes in (input difference + output difference)?

In [None]:
# Exercise 3 (Independent): your code here


## Summary

- An AES round = **SubBytes** → **ShiftRows** → **MixColumns** → **AddRoundKey**
- SubBytes provides **confusion** (nonlinearity from GF(256) inversion + affine map)
- ShiftRows provides **cross-column mixing** (a simple permutation with deep consequences)
- MixColumns provides **diffusion** (MDS matrix multiplication over GF(256))
- AddRoundKey provides **key dependence** (XOR = GF(2) vector addition)
- After one round, a single bit change has spread across one column (at most 32 bits)
- After two rounds it reaches every byte; the 10 rounds of AES-128 add a security margin on top of full diffusion

**Every operation in AES is field theory in disguise:**
- Bytes = GF(256) elements
- S-box = GF(256) inversion + GF(2) affine map
- MixColumns = GF(256) matrix multiplication
- AddRoundKey = GF(2)$^{128}$ vector addition
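
As a small illustration of "bytes = GF(256) elements", the sketch below multiplies bytes as polynomials modulo $x^8 + x^4 + x^3 + x + 1$ and checks two known products: the worked example from FIPS 197 (`0x57 × 0x83 = 0xC1`) and an inverse pair (`0x53 × 0xCA = 0x01`). The name `gmul_sketch` is hypothetical, and XOR is written as `operator.xor` so the block should behave the same in plain Python and in Sage.

```python
from operator import xor  # bitwise XOR in both plain Python and Sage

def gmul_sketch(u, v):
    """Multiply two bytes in GF(2^8) modulo 0x11B by shift-and-add."""
    acc = 0
    for _ in range(8):
        if v & 1:
            acc = xor(acc, u)      # addition in GF(2^8) is XOR
        v >>= 1
        u <<= 1                    # multiply u by x
        if u & 0x100:
            u = xor(u, 0x11B)      # reduce modulo the AES polynomial
    return acc

product = gmul_sketch(0x57, 0x83)
inverse_check = gmul_sketch(0x53, 0xCA)
print(f'0x57 * 0x83 = 0x{product:02X}')
print(f'0x53 * 0xCA = 0x{inverse_check:02X}')
```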

> **Crypto foreshadowing:** AES is the most widely deployed symmetric cipher in the world: it protects TLS, Wi-Fi (WPA2 and WPA3), disk encryption, and more. In Module 04, you'll study RSA, which rests on a completely different mathematical foundation (number theory instead of Galois fields). But the underlying principle is the same: build cryptographic security on top of algebraic hardness.

**This completes Module 03.** Next: [Module 04: Number Theory and RSA](../../04-number-theory-rsa/)