In [390]:
import numpy as np
import sbox as sb
from consts import *

KEY_SIZE = 128
WORD_ARRAY_SIZE = 4
WORD_SIZE = int(KEY_SIZE / WORD_ARRAY_SIZE)

In [391]:
def stringToHex(string: str) -> list:
  """
  Converts a string to a list of hex values
  """
  return [hex(ord(c)) for c in string]

In [392]:
test = stringToHex("Thats my Kung Fu")
print(np.array(test), type(test))

['0x54' '0x68' '0x61' '0x74' '0x73' '0x20' '0x6d' '0x79' '0x20' '0x4b'
 '0x75' '0x6e' '0x67' '0x20' '0x46' '0x75'] <class 'list'>


In [393]:
def textToMatrix(text: str) -> list:
  """
  Converts a string to a 4x4 matrix of Hex values
  """
  key = stringToHex(text)
  return np.matrix(np.reshape(key, (WORD_ARRAY_SIZE, int(KEY_SIZE / WORD_SIZE)))).tolist()

In [394]:
testMatrix = textToMatrix("Thats my Kung Fu")
print(np.array(testMatrix), type(testMatrix))

[['0x54' '0x68' '0x61' '0x74']
 ['0x73' '0x20' '0x6d' '0x79']
 ['0x20' '0x4b' '0x75' '0x6e']
 ['0x67' '0x20' '0x46' '0x75']] <class 'list'>


In [395]:
def roundConst(round: int) -> int:
  """
  Returns the round constant for the given round
  """
  if round == 1:
    return 0x01
  prevRoundConst = roundConst(round - 1)
  return (prevRoundConst << 1) ^ (0x11b & -(prevRoundConst >> 7)) 

In [396]:
def g(key: list, round: int) -> list:
  """
  Performs the g() operation on a key
  1. Rotates the key by 1 byte
  2. Applies the sbox to each byte
  3. Add the round constant (XOR)
  """
  key = np.concatenate((key[1:], key[:1]))
  key = [hex(sb.Sbox[int(x, 16)]) for x in key]
  key[0] = hex(int(key[0], 16) ^ roundConst(round))

  return key

In [397]:
testG = g(testMatrix[3], 1)
print(testG, type(testG))

['0xb6', '0x5a', '0x9d', '0x85'] <class 'list'>


In [398]:
def xor(array1: list, array2: list) -> list:
  """
  XORs two matrices
  args: array1, array2 - 1D arrays of hex values
  """
  key = []
  
  for i in range(WORD_ARRAY_SIZE):
    key.append(hex(int(array1[i], 16) ^ int(array2[i], 16)))
    
  return key

In [399]:
keyXor = xor(testMatrix[0], testG)
print(keyXor, type(keyXor))

['0xe2', '0x32', '0xfc', '0xf1'] <class 'list'>


In [400]:
def roundKey(round: int, prevkey: list) -> list:
  """
  Generates the next round's key from the previous round's key
  """
  w = [xor(prevkey[0], g(prevkey[3], round))]
  
  for i in range(1, WORD_ARRAY_SIZE):
    newKey = xor(w[i - 1], prevkey[i])
    w.append(newKey)
    
  return w

In [401]:
def convertTo128Bits(text: str) -> str:
  """
  Converts the given key to 128 bits
  """
  if len(text) > 16:
    return text[:16]
  elif len(text) < 16:
    return text + "0" * (16 - len(text))
  
  return text

In [402]:
def createAllKeys(text: str) -> list:
  """
  Generates all round keys from the initial key
  """
  text = convertTo128Bits(text)
  keys = [textToMatrix(text)]
  for i in range(10):
    keys.append(roundKey(i+1, keys[i]))
  return keys

In [403]:
allKeys = createAllKeys("Thats my Kung Fu")
print(np.array(allKeys))

[[['0x54' '0x68' '0x61' '0x74']
  ['0x73' '0x20' '0x6d' '0x79']
  ['0x20' '0x4b' '0x75' '0x6e']
  ['0x67' '0x20' '0x46' '0x75']]

 [['0xe2' '0x32' '0xfc' '0xf1']
  ['0x91' '0x12' '0x91' '0x88']
  ['0xb1' '0x59' '0xe4' '0xe6']
  ['0xd6' '0x79' '0xa2' '0x93']]

 [['0x56' '0x8' '0x20' '0x7']
  ['0xc7' '0x1a' '0xb1' '0x8f']
  ['0x76' '0x43' '0x55' '0x69']
  ['0xa0' '0x3a' '0xf7' '0xfa']]

 [['0xd2' '0x60' '0xd' '0xe7']
  ['0x15' '0x7a' '0xbc' '0x68']
  ['0x63' '0x39' '0xe9' '0x1']
  ['0xc3' '0x3' '0x1e' '0xfb']]

 [['0xa1' '0x12' '0x2' '0xc9']
  ['0xb4' '0x68' '0xbe' '0xa1']
  ['0xd7' '0x51' '0x57' '0xa0']
  ['0x14' '0x52' '0x49' '0x5b']]

 [['0xb1' '0x29' '0x3b' '0x33']
  ['0x5' '0x41' '0x85' '0x92']
  ['0xd2' '0x10' '0xd2' '0x32']
  ['0xc6' '0x42' '0x9b' '0x69']]

 [['0xbd' '0x3d' '0xc2' '0x87']
  ['0xb8' '0x7c' '0x47' '0x15']
  ['0x6a' '0x6c' '0x95' '0x27']
  ['0xac' '0x2e' '0xe' '0x4e']]

 [['0xcc' '0x96' '0xed' '0x16']
  ['0x74' '0xea' '0xaa' '0x3']
  ['0x1e' '0x86' '0x3f' '0x24']
  [

In [404]:
def stateMatrix(text: str) -> list:
  """
  Converts the given text to a 4x4 matrix in column major order
  args: text: str - the text to convert. It is assumed that the text is 128 bits. 
    If the text is less than 128 bits, it is padded with 0s. If it is greater than 128 bits,
    it is truncated to 128 bits.
  """
  text = convertTo128Bits(text)
  matrix = textToMatrix(text)
  
  return np.array(matrix).T.tolist()
  

In [405]:
stateMat = stateMatrix("Two One Nine Two")
print(np.array(stateMat))

round0key = allKeys[0]
print(np.array(round0key).T)

[['0x54' '0x4f' '0x4e' '0x20']
 ['0x77' '0x6e' '0x69' '0x54']
 ['0x6f' '0x65' '0x6e' '0x77']
 ['0x20' '0x20' '0x65' '0x6f']]
[['0x54' '0x73' '0x20' '0x67']
 ['0x68' '0x20' '0x4b' '0x20']
 ['0x61' '0x6d' '0x75' '0x46']
 ['0x74' '0x79' '0x6e' '0x75']]


In [406]:
def addRoundKey(stateMatrix: list, roundKey: list) -> list:
  """
  Adds the round key to the state matrix
  
  args: stateMatrix: list - the state matrix (assumed to be in column major order)
        roundKey: list - the round key (assumed to be in row major order)
  """
  roundKey = np.array(roundKey).T.tolist()
  newStateMatrix = []
  
  for i in range(WORD_ARRAY_SIZE):
    newStateMatrix.append(xor(stateMatrix[i], roundKey[i]))
    
  return newStateMatrix

In [407]:
stateMat = addRoundKey(stateMat, round0key)
print(np.array(stateMat))

[['0x0' '0x3c' '0x6e' '0x47']
 ['0x1f' '0x4e' '0x22' '0x74']
 ['0xe' '0x8' '0x1b' '0x31']
 ['0x54' '0x59' '0xb' '0x1a']]


In [408]:
def subBytes(stateMatrix: list) -> list:
  """
  Applies the sbox to each byte in the state matrix
  """
  newStateMatrix = []
  
  for i in range(WORD_ARRAY_SIZE):
    newStateMatrix.append([hex(sb.Sbox[int(x, 16)]) for x in stateMatrix[i]])
    
  return newStateMatrix

In [409]:
stateMat = subBytes(stateMat)
print(np.array(stateMat))

[['0x63' '0xeb' '0x9f' '0xa0']
 ['0xc0' '0x2f' '0x93' '0x92']
 ['0xab' '0x30' '0xaf' '0xc7']
 ['0x20' '0xcb' '0x2b' '0xa2']]


In [410]:
def shiftRow(stateMat: list):
  """
  Round Shifts the rows of the state matrix to the left by the row number
  """
  newStateMatrix = []
  for i in range(WORD_ARRAY_SIZE):
    newStateMatrix.append(np.roll(stateMat[i], -i))
    
  return newStateMatrix
  

In [411]:
stateMat = shiftRow(stateMat)
print(np.array(stateMat))

[['0x63' '0xeb' '0x9f' '0xa0']
 ['0x2f' '0x93' '0x92' '0xc0']
 ['0xaf' '0xc7' '0xab' '0x30']
 ['0xa2' '0x20' '0xcb' '0x2b']]


In [412]:
def mixColumn(stateMat: list) -> list:
  """
  Mixes the columns of the state matrix with a fixed matrix
  """
  stateMat = np.array(stateMat).T
  print(stateMat)
  newStateMatrix = []
  
  for i in range(WORD_ARRAY_SIZE):
    row = []
    for j in range(int(KEY_SIZE / WORD_SIZE)):
      dotProd = 0
      
      for k in range(WORD_ARRAY_SIZE):
        bv1 = Mixer[i][k]
        bv2 = BitVector(hexstring = str(int(stateMat[j][k], 16)))
        dotProd ^= (bv1.gf_multiply_modular(bv2, AES_modulus, 8)).int_val()
      
      row.append(hex(dotProd))
    
    newStateMatrix.append(row)
  
  return newStateMatrix  
  
  

In [413]:
stateMat = mixColumn(stateMat)
print(np.array(stateMat))

[['0x63' '0x2f' '0xaf' '0xa2']
 ['0xeb' '0x93' '0xc7' '0x20']
 ['0x9f' '0x92' '0xab' '0xcb']
 ['0xa0' '0xc0' '0x30' '0x2b']]
[['0xf7' '0x52' '0x3c' '0x7d']
 ['0xdc' '0x14' '0x73' '0xe9']
 ['0x89' '0x16' '0x94' '0xa7']
 ['0x6b' '0xbf' '0x9b' '0xca']]
