## Assignment 2
Simplified Advanced Encryption Standard (S-AES)

[Reference](https://sandilands.info/sgordon/teaching/css322y12s2/unprotected/CSS322Y12S2H02-Simplified-AES-Example.pdf)

In [None]:
import numpy as np

### Helper functions

In [None]:
def bin_to_dec(x):
  """Parse a string of binary digits (e.g. "1011") and return its integer value."""
  decimal = int(x, base=2)
  return decimal
def dec_to_bin(x):
  """Return integer *x* as binary digits, without the "0b" prefix and unpadded."""
  # bin() yields "0b101" or "-0b101"; keep whatever precedes the prefix (a sign).
  sign, _, digits = bin(x).rpartition("0b")
  return sign + digits
def hex_to_bin(x):
  """Convert a hex string to binary digits, zero-padded to 4 bits per hex character.

  Example: "0A" -> "00001010" (leading hex zeros become leading zero nibbles).
  """
  value = int(x, 16)
  width = 4 * len(x)
  return assert_value_size(dec_to_bin(value), width)
def bin_to_hex(x):
  """Convert a binary-digit string to hex text in hex() form, i.e. with a "0x" prefix."""
  return f"{bin_to_dec(x):#x}"

In [None]:
def assert_value_size(x, s):
  """Left-pad string *x* with "0" characters until it is at least *s* long.

  Despite the name, nothing is asserted. Strings already *s* characters or
  longer are returned unchanged (never truncated).
  """
  return x.rjust(s, "0")

In [None]:
def xor(a, b):
  """Return the bitwise XOR of two equal-length binary-digit strings.

  Characters are compared position by position: equal -> "0", different -> "1".

  Raises:
    ValueError: if ``a`` and ``b`` have different lengths.
  """
  # The operands are fixed-width blocks (nibbles, key halves, round constants),
  # so a length mismatch is a caller bug. Report it instead of raising an
  # unrelated IndexError or silently dropping the tail of ``b``.
  if len(a) != len(b):
    raise ValueError(f"xor operands differ in length: {len(a)} != {len(b)}")
  return "".join("0" if x == y else "1" for x, y in zip(a, b))

In [None]:
def split_str(val):
  """Split *val* into (first half, second half).

  For odd lengths, the middle character goes to the second half.
  """
  return val[:len(val) // 2], val[len(val) // 2:]

In [None]:
def get_indices(nib):
  """Map a nibble to S-box coordinates.

  The first two bits give the row index and the remaining bits give the column index.
  """
  row_bits, col_bits = nib[:2], nib[2:]
  return bin_to_dec(row_bits), bin_to_dec(col_bits)

In [None]:
def nibble_list(x):
  """Zero-pad *x* to at least 16 bits and cut it into consecutive 4-character nibbles."""
  bits = assert_value_size(x, 16)
  return [bits[pos:pos + 4] for pos in range(0, len(bits), 4)]

def list_to_mat(l):
  """Arrange the first four items of *l* column-major into a 2x2 matrix.

  [n0, n1, n2, n3] -> [[n0, n2], [n1, n3]]
  """
  return [[l[row], l[row + 2]] for row in (0, 1)]

def mat_to_list(m):
  """Read a 2x2 matrix back out column by column (the inverse of list_to_mat)."""
  return [m[row][col] for col in (0, 1) for row in (0, 1)]

In [None]:
def rot_nib(val):
  """Swap the two halves of *val*.

  This is RotNib in the key expansion when *val* is an 8-bit word.
  """
  mid = len(val) // 2
  head, tail = val[:mid], val[mid:]
  return tail + head

In [None]:
def mul_nib(nib1, nib2):
  """Multiply two digit strings as polynomials with coefficients in GF(2).

  Each character is one coefficient, most significant first, so "0110" is
  x^2 + x. The product is returned as a string of "0"/"1" coefficients, most
  significant first. It is NOT reduced modulo x^4 + x + 1; add_nib performs
  that reduction. The length may vary because np.polymul drops leading zero
  coefficients.
  """
  p1 = [int(c) for c in nib1]
  p2 = [int(c) for c in nib2]
  product = np.polymul(p1, p2)
  # np.polymul works over the integers, so coefficients can reach 2, 3, 4, ...
  # Reducing mod 2 keeps the result a genuine binary string (one character per
  # coefficient) instead of leaking non-binary digits.
  return "".join(str(int(c) % 2) for c in product)

def add_nib(nib1, nib2):
  """Add two digit-string polynomials mod 2, then reduce modulo x^4 + x + 1.

  Each character of nib1/nib2 is one coefficient, most significant first. The
  strings may differ in length; np.polyadd right-aligns them. The sum is taken
  mod 2 and divided by x^4 + x + 1 ([1, 0, 0, 1, 1]). The remainder is returned
  as a 4-character binary string. Because of this reduction, adding two
  mul_nib products yields a GF(2^4) result.
  """
  total = np.polyadd([int(d) for d in nib1], [int(d) for d in nib2]) % 2
  # np.polydiv divides over the reals. The divisor is monic, so the remainder
  # has integer coefficients, and taking them mod 2 gives the GF(2) remainder.
  _, remainder = np.polydiv(list(total), [1, 0, 0, 1, 1])
  bits = "".join(str(int(coef % 2)) for coef in remainder)
  # Keep at most the 4 low-order coefficients, then left-pad to a full nibble.
  return assert_value_size(bits[-4:], 4)

### Constants

In [None]:
def gen_inv_s_box(s):
  """Build the inverse of the 4x4 S-box *s* (entries are single hex digits).

  If position (i, j) of *s* holds value v, the inverse stores the position
  number 4*i + j, as a lowercase hex digit, at the row/column that v encodes.
  The result starts as a row-by-row copy of *s*. Any cell never written (if
  *s* is not a permutation) keeps the value copied from *s*.
  """
  inverse = [row[:] for row in s]
  for pos in range(16):
    i, j = divmod(pos, 4)
    row, col = get_indices(hex_to_bin(s[i][j]))
    inverse[row][col] = format(pos, "x")
  return inverse

In [None]:
# S-box for SubNibbles: each entry is one hex digit, looked up by the
# (row, column) bits of the input nibble.
# NOTE(review): this table maps n -> n + 1 (mod 16). The standard S-AES S-box
# (9 4 A B / D 1 8 5 / 6 2 0 3 / C E F 7) is different. Confirm that this is
# the table the assignment specifies.
S = [
    ["1", "2", "3", "4"],
    ["5", "6", "7", "8"],
    ["9", "A", "B", "C"],
    ["D", "E", "F", "0"]
]
INV_S = gen_inv_s_box(S)
# MixColumns matrices over GF(2^4). mixcol hands every entry to mul_nib, which
# reads each character as one polynomial coefficient, so entries must be
# 4-bit binary strings. A bare hex digit such as "4" or "2" would be read as
# the integer constant 4 or 2 (== 0 mod 2). That would collapse both matrices
# to the identity, and the round-trip test would still pass.
M = [
    ["0001", "0100"],  # hex [[1, 4],
    ["0100", "0001"]   #      [4, 1]]
]
INV_M = [
    ["1001", "0010"],  # hex [[9, 2],
    ["0010", "1001"]   #      [2, 9]]
]
print(INV_S)

[['f', '0', '1', '2'], ['3', '4', '5', '6'], ['7', '8', '9', 'a'], ['b', 'c', 'd', 'e']]


### Core S-AES algorithm functions

In [None]:
def sub_nib(x, s):
  """Pass every 4-bit nibble of binary string *x* through S-box *s*.

  The first two bits of a nibble pick the row and the last two the column.
  The hex digit found there is expanded back to binary and appended.
  """
  pieces = []
  for start in range(0, len(x), 4):
    row, col = get_indices(x[start:start + 4])
    pieces.append(hex_to_bin(s[row][col]))
  return "".join(pieces)

def sub_nibs(x, s):
  """Apply sub_nib with S-box *s* to every cell of matrix *x*, in place.

  The same (mutated) matrix object is returned for chaining.
  """
  for row in x:
    for idx, cell in enumerate(row):
      row[idx] = sub_nib(cell, s)
  return x

In [None]:
def mixcol(A, B):
  """Return the 2x2 matrix product A . B, with arithmetic done by mul_nib/add_nib.

  Entry (row, col) is add_nib(mul_nib(A[row][0], B[0][col]),
  mul_nib(A[row][1], B[1][col])). mul_nib reads every character of an entry
  as one polynomial coefficient, so entries are expected to be binary-digit
  strings. A new matrix is built; A and B are left untouched.
  """
  return [
      [
          add_nib(mul_nib(A[row][0], B[0][col]), mul_nib(A[row][1], B[1][col]))
          for col in (0, 1)
      ]
      for row in (0, 1)
  ]

In [None]:
def shift_row(state):
  """ShiftRows for the 2x2 state: swap the two nibbles of the second row, in place.

  Swapping twice restores the original, so the same function also serves as
  the inverse in decrypt. The mutated state is returned.
  """
  lower = state[1]
  lower[0], lower[1] = lower[1], lower[0]
  return state

In [None]:
def add_round_key(state, key):
  """XOR a round key into the 2x2 state matrix, in place.

  *key* is a binary string. nibble_list pads it to 16 bits, and its nibbles
  are laid out column-major by list_to_mat to match the state. The mutated
  state is returned.
  """
  key_mat = list_to_mat(nibble_list(key))
  for r in range(2):
    row = state[r]
    for c in range(2):
      row[c] = xor(row[c], key_mat[r][c])
  return state

In [None]:
def get_subkey(prev_key, t):
  """Derive the next round key from *prev_key* using round constant *t*.

  With prev_key = w0 || w1 (its two halves):
    w2 = w0 XOR t XOR SubNib(RotNib(w1))   (SubNib uses the global S-box S)
    w3 = w2 XOR w1
  The result is w2 || w3.
  """
  w0, w1 = split_str(prev_key)
  w2 = xor(xor(w0, t), sub_nib(rot_nib(w1), S))
  return w2 + xor(w2, w1)

In [None]:
def gen_subkeys(key):
  """Expand *key* into the three round keys (key0, key1, key2).

  key0 is *key* itself. Each later key is get_subkey of the previous one
  with the next round constant (hex 80, then hex 60).
  NOTE(review): the textbook S-AES key schedule uses 0x80 and 0x30. Confirm
  that 0x60 is what the assignment specifies.
  """
  round_keys = [key]
  for rcon in ("80", "60"):
    round_keys.append(get_subkey(round_keys[-1], hex_to_bin(rcon)))
  return tuple(round_keys)

In [None]:
def encrypt(plaintext, key):
  """Encrypt one block with S-AES and return the ciphertext as a binary string.

  The plaintext (zero-padded to 16 bits by nibble_list) fills the 2x2 state
  column by column.
    Round 0: AddRoundKey(key0)
    Round 1: SubNibbles, ShiftRows, MixColumns, AddRoundKey(key1)
    Round 2: SubNibbles, ShiftRows, AddRoundKey(key2)  (no MixColumns)
  """
  k0, k1, k2 = gen_subkeys(key)
  state = add_round_key(list_to_mat(nibble_list(plaintext)), k0)

  # Round 1: full round.
  state = shift_row(sub_nibs(state, S))
  state = add_round_key(mixcol(M, state), k1)

  # Round 2: final round omits MixColumns.
  state = add_round_key(shift_row(sub_nibs(state, S)), k2)

  return "".join(mat_to_list(state))

In [None]:
def decrypt(ciphertext, key):
  """Invert encrypt(): recover the binary-string plaintext block from *ciphertext*.

  Steps run in reverse order with the inverse tables:
    AddRoundKey(key2), ShiftRows, InvSubNibbles
    AddRoundKey(key1), InvMixColumns, ShiftRows, InvSubNibbles
    AddRoundKey(key0)
  XOR with a key and the row swap are each their own inverse, so
  add_round_key and shift_row are reused unchanged.
  """
  k0, k1, k2 = gen_subkeys(key)
  state = list_to_mat(nibble_list(ciphertext))

  # Undo round 2.
  state = sub_nibs(shift_row(add_round_key(state, k2)), INV_S)

  # Undo round 1.
  state = mixcol(INV_M, add_round_key(state, k1))
  state = sub_nibs(shift_row(state), INV_S)

  # Undo round 0.
  state = add_round_key(state, k0)
  return "".join(mat_to_list(state))

### Testing

In [None]:
# Test vector: a 16-bit key and plaintext block, written in hex and expanded to binary strings.
key = hex_to_bin("2B85")
plaintext = hex_to_bin("BC78")

In [None]:
# Round trip: encrypt the block, then decrypt the result with the same key.
c = encrypt(plaintext=plaintext, key=key)
p = decrypt(ciphertext=c, key=key)

In [None]:
# Decryption must invert encryption under the same key. assert is a statement,
# not a function: parenthesising it invites the always-true `assert (cond, msg)` tuple bug.
assert p == plaintext, f"round trip failed: got {p}, expected {plaintext}"