# Implementation of Simplified AES



In [1]:
# S-AES 4-bit substitution box: S_box[n] is the replacement for nibble n.
S_box = [
    0x9, 0x4, 0xA, 0xB,
    0xD, 0x1, 0x8, 0x5,
    0x6, 0x2, 0x0, 0x3,
    0xC, 0xE, 0xF, 0x7,
]
# Round constants XORed into the key schedule (used by get_keys).
Rcon1 = 0x80  # 0b10000000
Rcon2 = 0x30  # 0b00110000

def rot_word(word):
    """Swap the two nibbles of an 8-bit key word (S-AES RotNib).

    Example: 0x12 -> 0x21.
    """
    # Low nibble moves up, everything above bit 3 moves down.
    return ((word & 0xF) << 4) + (word >> 4)

def sub_word(word):
    """Pass both nibbles of an 8-bit key word through S_box (S-AES SubNib)."""
    substituted_high = S_box[word >> 4]
    substituted_low = S_box[word & 0xF]
    return (substituted_high << 4) + substituted_low

def get_keys(key):
    """Expand a 16-bit S-AES key into the three 16-bit round keys.

    The key is split into two 8-bit words w0 (high byte) and w1 (low byte)
    and extended to w0..w5:

        w2 = w0 ^ Rcon1 ^ SubNib(RotNib(w1)),   w3 = w1 ^ w2
        w4 = w2 ^ Rcon2 ^ SubNib(RotNib(w3)),   w5 = w3 ^ w4

    Args:
        key: the cipher key, an integer in [0, 0xFFFF].

    Returns:
        [key0, key1, key2], where key_r packs w(2r) in the high byte and
        w(2r+1) in the low byte.

    Raises:
        ValueError: if ``key`` is outside the 16-bit range. Without this
            check, extra high bits flow straight into w0 (and thus into
            every round key), and negative keys yield negative round keys.
    """
    if not 0 <= key <= 0xFFFF:
        raise ValueError(f"S-AES key must be a 16-bit value, got {key!r}")
    w = [0] * 6
    w[0] = key >> 8      # high byte
    w[1] = key & 0xFF    # low byte
    w[2] = sub_word(rot_word(w[1])) ^ Rcon1 ^ w[0]
    w[3] = w[1] ^ w[2]
    w[4] = sub_word(rot_word(w[3])) ^ Rcon2 ^ w[2]
    w[5] = w[3] ^ w[4]
    # Pair consecutive 8-bit words into 16-bit round keys.
    return [(w[i] << 8) + w[i + 1] for i in range(0, 6, 2)]

def text_to_matrix(text):
    """Split a 16-bit block into its four nibbles, most significant first.

    The state "matrix" is represented as a flat list of four nibbles; the
    pairs at indices (0, 1) and (2, 3) are the two columns that mix()
    operates on. Round keys are converted the same way by add_round_key.

    Args:
        text: a plaintext, ciphertext or round key in [0, 0xFFFF].

    Returns:
        [n0, n1, n2, n3] with n0 taken from bits 15..12 and n3 from bits 3..0.

    Raises:
        ValueError: if ``text`` is not a 16-bit value. Previously, negative
            input was silently read as its low 16 bits (-1 -> 0xFFFF), and
            wider input produced a top "nibble" larger than 4 bits.
    """
    if not 0 <= text <= 0xFFFF:
        raise ValueError(f"expected a 16-bit value, got {text!r}")
    return [(text >> offset) & 0xF for offset in (12, 8, 4, 0)]

def matrix_to_text(matrix):
    """Pack the first four nibbles of a state list into one integer.

    Inverse of text_to_matrix: matrix[0] becomes the most significant
    nibble and matrix[3] the least significant.
    """
    packed = 0
    for idx in range(4):
        # Horner-style: shift what we have so far up one nibble, add the next.
        packed = (packed << 4) + matrix[idx]
    return packed

def add_round_key(state, key_matrix):
    """XOR a round key into the state, nibble by nibble (S-AES AddRoundKey).

    Args:
        state: list of four state nibbles.
        key_matrix: despite the name, the round key as a 16-bit integer;
            it is split into nibbles with text_to_matrix here.

    Returns:
        A new list of XORed nibbles; the input state is not modified.
    """
    key_nibbles = text_to_matrix(key_matrix)
    combined = []
    for state_nibble, key_nibble in zip(state, key_nibbles):
        combined.append(state_nibble ^ key_nibble)
    return combined

def substitute(state):
    """Replace every state nibble with its S_box entry (S-AES NibbleSub)."""
    substituted = []
    for value in state:
        substituted.append(S_box[value])
    return substituted

def substitute_inv(state):
    """Undo substitute(): map each nibble back to its position in S_box.

    Uses a linear list.index lookup per nibble (S_box has only 16 entries);
    raises ValueError for a value that is not in S_box.
    """
    return list(map(S_box.index, state))

def shift(state):
    """S-AES ShiftRows: swap the nibbles at indices 1 and 3.

    The swap is its own inverse, which is why decryption reuses it.
    """
    return [state[idx] for idx in (0, 3, 2, 1)]

def nibble_to_bits(nibble):
    """Return bits 3..0 of ``nibble`` as a list of 0/1, MSB first.

    Bits above bit 3 are ignored.
    """
    return [(nibble >> position) & 1 for position in (3, 2, 1, 0)]

def bits_to_nibble(bits):
    """Pack the first four entries of ``bits`` (MSB first) into an integer.

    Inverse of nibble_to_bits; any entries past index 3 are ignored.
    """
    value = 0
    for idx in range(4):
        value = (value << 1) + bits[idx]
    return value

def mix(state):
    """S-AES MixColumns: multiply each column by [[1, 4], [4, 1]] in GF(2^4).

    The columns are the nibble pairs (state[0], state[1]) and
    (state[2], state[3]). Field products (reduction polynomial x^4 + x + 1)
    are expanded into bit-level XORs: b[0..3] are the bits of the column's
    top nibble and b[4..7] those of its bottom nibble, MSB first.
    """
    result = []
    for col in (0, 2):
        b = nibble_to_bits(state[col]) + nibble_to_bits(state[col + 1])
        # top' = top ^ (4 * bottom)
        top = bits_to_nibble([b[0] ^ b[6],
                              b[1] ^ b[4] ^ b[7],
                              b[2] ^ b[4] ^ b[5],
                              b[3] ^ b[5]])
        # bottom' = (4 * top) ^ bottom
        bottom = bits_to_nibble([b[2] ^ b[4],
                                 b[0] ^ b[3] ^ b[5],
                                 b[0] ^ b[1] ^ b[6],
                                 b[1] ^ b[7]])
        result += [top, bottom]
    return result

def mix_inv(state):
    """Inverse MixColumns: multiply each column by [[9, 2], [2, 9]] in GF(2^4).

    Undoes mix(). Same layout: c[0..3] are the bits of the column's top
    nibble and c[4..7] those of its bottom nibble, MSB first; the products
    (mod x^4 + x + 1) are expanded into the XORs below.
    """
    result = []
    for col in (0, 2):
        c = nibble_to_bits(state[col]) + nibble_to_bits(state[col + 1])
        # top' = (9 * top) ^ (2 * bottom)
        top = bits_to_nibble([c[3] ^ c[5],
                              c[0] ^ c[6],
                              c[1] ^ c[4] ^ c[7],
                              c[2] ^ c[3] ^ c[4]])
        # bottom' = (2 * top) ^ (9 * bottom)
        bottom = bits_to_nibble([c[1] ^ c[7],
                                 c[2] ^ c[4],
                                 c[0] ^ c[3] ^ c[5],
                                 c[0] ^ c[6] ^ c[7]])
        result += [top, bottom]
    return result

def encrypt_round1(state, key1):
    """First S-AES round: NibbleSub, ShiftRows, MixColumns, AddRoundKey(key1)."""
    s = substitute(state)
    s = shift(s)
    s = mix(s)
    return add_round_key(s, key1)

def encrypt_round2(state, key2):
    """Final S-AES round: NibbleSub, ShiftRows, AddRoundKey(key2) — no MixColumns."""
    s = substitute(state)
    s = shift(s)
    return add_round_key(s, key2)

def decrypt_round2(state, key2):
    """Undo encrypt_round2: AddRoundKey(key2), ShiftRows (self-inverse), inverse NibbleSub."""
    s = add_round_key(state, key2)
    s = shift(s)
    return substitute_inv(s)

def decrypt_round1(state, key1):
    """Undo encrypt_round1: AddRoundKey(key1), inverse MixColumns, ShiftRows, inverse NibbleSub."""
    s = add_round_key(state, key1)
    s = mix_inv(s)
    s = shift(s)
    return substitute_inv(s)

def encrypt(plaintext, key):
    """Encrypt one 16-bit block with S-AES and return the ciphertext as an int.

    Pipeline: AddRoundKey(key0) -> round 1 (key1) -> round 2 (key2).
    """
    round_keys = get_keys(key)
    state = text_to_matrix(plaintext)
    state = add_round_key(state, round_keys[0])  # initial key whitening
    state = encrypt_round1(state, round_keys[1])
    state = encrypt_round2(state, round_keys[2])
    return matrix_to_text(state)

def decrypt(ciphertext, key):
    """Decrypt one 16-bit S-AES block; inverse of encrypt for the same key.

    Runs the rounds in reverse: round 2 (key2) -> round 1 (key1) ->
    AddRoundKey(key0).
    """
    round_keys = get_keys(key)
    state = text_to_matrix(ciphertext)
    state = decrypt_round2(state, round_keys[2])
    state = decrypt_round1(state, round_keys[1])
    state = add_round_key(state, round_keys[0])
    return matrix_to_text(state)

In [2]:
import random
# Fixed seed so the key, plaintexts and printed output are reproducible.
random.seed(23)

# One random 16-bit key and its three derived round keys.
key = random.randrange(0, 2**16)
key0, key1, key2 = get_keys(key)
print(f"Key0:   {key0:016b}")
print(f"Key1:   {key1:016b}")
print(f"Key2:   {key2:016b}")

plaintexts = []
ciphertexts = []
plaintexts2 = []

# Encrypt ten random blocks, decrypt them again and show each pair.
for i in range(10):
    plaintexts.append(random.randrange(0, 2**16))
    ciphertexts.append(encrypt(plaintexts[-1], key))
    plaintexts2.append(decrypt(ciphertexts[-1], key))
    print(f"\nPair {i + 1}")
    print(f"Plain:  {plaintexts[-1]:016b}")
    print(f"Cipher: {ciphertexts[-1]:016b}")

# Round-trip check: every ciphertext must decrypt back to its plaintext.
assert plaintexts == plaintexts2

Key0:   1001010001101111
Key1:   0110110000000011
Key2:   1110010111100110

Pair 1
Plain:  0010101011000111
Cipher: 0001100111001110

Pair 2
Plain:  0000100010111110
Cipher: 1011111111011111

Pair 3
Plain:  1001110100001101
Cipher: 1010000001111011

Pair 4
Plain:  1101100011110101
Cipher: 1101000010111000

Pair 5
Plain:  1100001000100100
Cipher: 1101010111010101

Pair 6
Plain:  1011011101010110
Cipher: 0011111110101111

Pair 7
Plain:  0100001010110111
Cipher: 0101110110000011

Pair 8
Plain:  0110001001001101
Cipher: 1001010000100100

Pair 9
Plain:  1000100011100110
Cipher: 1110100001011000

Pair 10
Plain:  1110001110011111
Cipher: 1010011101111100
