#### Cryptography
#### Task № 4 - AES (Rijndael: 128-bit key, 128-bit block, 10 rounds)
#### Ivan Rybin - ITMO JB SE MA 2021

In [1]:
import numpy as np
import sympy
from sympy import poly
from sympy.abc import x as sympy_x
from pyfinite import ffield

In [2]:
# GF(2^8) with the AES reduction polynomial x^8 + x^4 + x^3 + x + 1
# (0b1_0001_1011 == 0x11b). useLUT=0 is forwarded to pyfinite
# (NOTE(review): presumably disables its lookup tables - confirm in pyfinite docs).
GF2_8 = ffield.FField(8, gen=0b1_0001_1011, useLUT=0)

In [3]:
# Sanity check of field multiplication: {57} * {83} = {c1} = x^7 + x^6 + 1,
# the worked multiplication example from FIPS-197.
GF2_8.ShowPolynomial(GF2_8.Multiply(0x57, 0x83))

'x^7 + x^6 + 1'

In [4]:
def xor_bytes(l, r):
    """XOR two byte sequences element by element.

    The result is a ``bytearray`` as long as the shorter of the two inputs.
    """
    return bytearray(left ^ right for left, right in zip(l, r))

def xor_bits(l, r):
    return "".join([str(int(lb) ^ int(rb)) for lb, rb in zip(l, r)])

def stobytes(s):
    """Encode the text ``s`` to UTF-8 bytes."""
    return s.encode(encoding='utf-8')

def sto64(bits):
    """Right-pad ``bits`` with '0' characters up to the next multiple of 64.

    A length that is already a multiple of 64 (including 0) is returned unchanged.
    """
    shortfall = -len(bits) % 64
    if shortfall:
        bits += '0' * shortfall
    return bits

def sto128(bits):
    """Right-pad ``bits`` with '0' characters up to the next multiple of 128.

    This is the zero padding applied to plaintext before AES block encryption.
    A length that is already a multiple of 128 (including 0) is returned unchanged.
    """
    shortfall = -len(bits) % 128
    if shortfall:
        bits += '0' * shortfall
    return bits

def to8bits(bits):
    """Left-pad ``bits`` with '0' characters to a width of 8.

    Inputs that already have 8 or more characters are returned unchanged.
    """
    missing = 8 - len(bits)
    if missing > 0:
        bits = '0' * missing + bits
    return bits

def to4bits(bits):
    """Left-pad ``bits`` with '0' characters to a width of 4.

    Inputs that already have 4 or more characters are returned unchanged.
    """
    missing = 4 - len(bits)
    if missing > 0:
        bits = '0' * missing + bits
    return bits

def stobits(s):
    """Convert text to its UTF-8 bit string, 8 '0'/'1' characters per byte (MSB first)."""
    return ''.join(
        bin(byte)[2:].rjust(8, '0')  # zero-pad every byte to 8 bits
        for byte in s.encode('utf-8')
    )

def btoi(b):
    """Interpret ``b`` as binary digits, most significant first, and return its integer value.

    Every element is passed through ``int()``. The last element is the least
    significant. An empty input yields 0.
    """
    value = 0
    for power, digit in enumerate(reversed(b)):
        value += int(digit) * 2 ** power
    return value

def bytestos(b):
    """Decode UTF-8 ``bytes``/``bytearray`` back to text."""
    return b.decode(encoding='utf-8')

def bitstos(bits):
    """Pack a bit string into bytes, 8 bits per byte (MSB first).

    Despite the name, this returns a ``bytearray``, not a ``str``. Pass the
    result to ``bytestos`` to obtain text. Trailing bits that do not fill a
    whole byte are dropped, because ``split_to_blocks`` discards partial blocks.
    """
    return bytearray(btoi(octet) for octet in split_to_blocks(bits, 8))

def perm(p, b):
    """Apply the permutation table ``p`` to ``b``.

    Output position ``k`` takes ``b[p[k]]``. The pieces are joined into a string.
    """
    return ''.join(b[index] for index in p)

def split_to_blocks(b, block_size):
    """Cut ``b`` into consecutive slices of exactly ``block_size`` elements.

    A trailing remainder shorter than ``block_size`` is discarded.
    """
    full_blocks = (len(b) + block_size) // block_size - 1
    return [b[k * block_size:(k + 1) * block_size] for k in range(full_blocks)]


def cyclic_offset(b, offset, is_left=True):
    """Rotate the sliceable sequence ``b`` cyclically by ``offset`` positions.

    :param b: sequence to rotate (bit strings in this notebook; lists also work).
    :param offset: number of positions to rotate by; reduced modulo ``len(b)``.
    :param is_left: rotate towards the front when True, towards the end otherwise.
    :return: the rotated sequence, of the same type as ``b``. An empty ``b``
        is returned unchanged.
    """
    if not b:
        # Rotating an empty sequence is a no-op; this also avoids ``offset % 0``.
        return b
    size = len(b)
    offset %= size
    if not is_left:
        # A right rotation by k is a left rotation by size - k. This replaces a
        # one-element-per-iteration loop with a single slice.
        offset = (size - offset) % size
    return b[offset:] + b[:offset]


def cyclic_offset_list(l, offset, is_left=True):
    """Rotate the list ``l`` cyclically by ``offset`` positions.

    :param l: list (or other sliceable sequence) to rotate.
    :param offset: number of positions to rotate by; reduced modulo ``len(l)``.
    :param is_left: rotate towards the front when True, towards the end otherwise.
    :return: a new rotated list. An empty ``l`` is returned unchanged
        instead of raising ZeroDivisionError.
    """
    if not l:
        return l
    size = len(l)
    offset %= size
    if is_left:
        return l[offset:] + l[:offset]
    return l[size - offset:] + l[:size - offset]

In [5]:
def forward_S_BOX(b):
    """Compute the S-box output for one byte given as an 8-character bit string.

    The byte is replaced by its multiplicative inverse in GF2_8. The inverse is
    then XORed with its own left rotations by 1, 2, 3 and 4 bits and with the
    constant 0x63 (the AES affine transformation).

    :return: the substituted byte as an 8-character bit string.
    """
    inverse_bits = to8bits(bin(GF2_8.DoInverseForBigField(btoi(b)))[2:])
    result = inverse_bits
    for shift in (1, 2, 3, 4):
        result = xor_bits(result, cyclic_offset(inverse_bits, shift))
    return xor_bits(result, to8bits(bin(0x63)[2:]))


def create_S_BOXES():
    """Tabulate ``forward_S_BOX`` over all 256 byte values.

    :return: ``(s_box, s_box_inv)``, two dicts keyed by 8-character bit strings.
        ``s_box_inv`` is built by inverting ``s_box``, so if the forward map were
        not injective, later entries would silently overwrite earlier ones.
    """
    forward, inverse = {}, {}
    for value in range(256):
        byte_bits = to8bits(bin(value)[2:])
        substituted = forward_S_BOX(byte_bits)
        forward[byte_bits] = substituted
        inverse[substituted] = byte_bits
    return forward, inverse


# Build the forward and inverse S-box lookup tables once, at module level.
S_BOX, S_BOX_INV = create_S_BOXES()

# Sanity check: both tables must have 256 entries (the inverse table shrinks
# if two inputs map to the same output).
print(*map(len, (S_BOX, S_BOX_INV)), sep='\n')

256
256


In [6]:
def poly_4x_mult(poly1, poly2):
    """Multiply two cubic polynomials given as coefficient lists ``[c3, c2, c1, c0]``.

    Only the first four coefficients of each list are read. The product is a
    sympy ``Poly`` in ``sympy_x``. NOTE(review): the coefficients are plain
    numbers, so sympy works over its default (integer) domain, not GF(2^8).
    This helper is not called by the visible AES code.
    """
    def as_cubic(coeffs):
        # Build c3*x^3 + c2*x^2 + c1*x + c0 as a sympy Poly in sympy_x.
        expr = (coeffs[0] * sympy_x ** 3 + coeffs[1] * sympy_x ** 2
                + coeffs[2] * sympy_x + coeffs[3])
        return poly(expr, sympy_x)

    left = as_cubic(poly1)
    right = as_cubic(poly2)
    return left * right


def poly_modulo(p, m):
    """Return the remainder of dividing polynomial ``p`` by ``m`` (via ``sympy.rem``)."""
    remainder = sympy.rem(p, m)
    return remainder

In [7]:
def multiply_vec_by_vec(v1, v2):
    """Dot product of two 4-element integer vectors over GF2_8.

    Products come from ``GF2_8.Multiply`` and are summed with XOR, which is
    addition in GF(2^8). Only indices 0..3 are used.

    :return: the result as an 8-character bit string.
    """
    accumulator = 0
    for k in range(4):
        accumulator ^= GF2_8.Multiply(v1[k], v2[k])
    return to8bits(bin(accumulator)[2:])


def state_to_int(bin_state):
    """Convert a state of bit-string cells into a new state of integer cells (row by row)."""
    return [[btoi(cell) for cell in row] for row in bin_state]


def state_to_bin(int_state):
    """Convert a state of integer cells into a new state of 8-character bit strings."""
    return [[bin(value)[2:].rjust(8, '0') for value in row] for row in int_state]

In [8]:
# MixColumns coefficients over GF(2^8). Each row is the previous row rotated
# right by one position (a circulant matrix).
mix_col_matrix = [
    [0x02, 0x03, 0x01, 0x01],
    [0x01, 0x02, 0x03, 0x01],
    [0x01, 0x01, 0x02, 0x03],
    [0x03, 0x01, 0x01, 0x02],
]

# InvMixColumns coefficients: the inverse of mix_col_matrix over GF(2^8).
inverse_mix_col_matrix = [
    [0x0e, 0x0b, 0x0d, 0x09],
    [0x09, 0x0e, 0x0b, 0x0d],
    [0x0d, 0x09, 0x0e, 0x0b],
    [0x0b, 0x0d, 0x09, 0x0e],
]

# Round constants as 32-bit words: the leading byte holds the constant and the
# remaining three bytes are zero. Index 0 is a placeholder; KeyExpansion
# indexes Rcon with i // 4 >= 1.
Rcon = [
    format(round_constant, '08b') + '0' * 24
    for round_constant in (0x00, 0x01, 0x02, 0x04, 0x08, 0x10,
                           0x20, 0x40, 0x80, 0x1b, 0x36)
]

In [9]:
def block_to_state(data_block):
    """Lay a 128-bit block out as a 4x4 state of 8-character bit strings.

    Bytes fill the state column by column: byte ``4*col + row`` lands in
    ``state[row][col]``. The result is a list of 4 rows.
    """
    return [
        [data_block[(4 * col + row) * 8:(4 * col + row + 1) * 8] for col in range(4)]
        for row in range(4)
    ]


def state_to_block(state):
    """Flatten a 4x4 state back into one string, reading it column by column.

    This is the inverse of ``block_to_state``.
    """
    return ''.join(state[row][col] for col in range(4) for row in range(4))

In [10]:
def SubByte4(word):
    """Apply the S-box to each of the four bytes of a 32-bit word (bit string).

    This is SubWord in the key schedule.
    """
    return ''.join(S_BOX[word[8 * k:8 * (k + 1)]] for k in range(4))



def RotByte(word):
    """Rotate a word (bit string) left by 8 positions, i.e. one byte.

    This is RotWord in the key schedule. The shift is taken modulo ``len(word)``.
    """
    shift = 8 % len(word)
    return word[shift:] + word[:shift]
    
    

def KeyExpansion(key):
    """Expand a 128-bit AES key into the 44-word round-key schedule.

    :param key: the cipher key as a string of exactly 128 '0'/'1' characters.
    :return: a list of 44 words, each a 32-character bit string. Words
        ``4*r .. 4*r+3`` form the round key for round ``r`` (0..10).
    :raises ValueError: if ``key`` is not exactly 128 bits long.
    """
    if len(key) != 128:
        # A shorter key yields short words that xor_bits silently truncates,
        # and a longer key would silently lose key material.
        raise ValueError(f'AES-128 key must be 128 bits long, got {len(key)}')

    w = [key[i * 32:(i + 1) * 32] for i in range(4)]
    for i in range(4, 4 * 11):
        tmp = w[i - 1]
        if i % 4 == 0:
            # Every fourth word: RotWord, SubWord, then XOR with the round constant.
            tmp = xor_bits(SubByte4(RotByte(tmp)), Rcon[i // 4])
        w.append(xor_bits(w[i - 4], tmp))
    return w

In [11]:
def SubBytes(state):
    """Replace every cell of the 4x4 state with its S-box image.

    The state is modified in place and also returned.
    """
    for cell in range(16):
        row, col = divmod(cell, 4)
        state[row][col] = S_BOX[state[row][col]]
    return state


def InvSubBytes(state):
    """Replace every cell of the 4x4 state with its inverse-S-box image.

    The state is modified in place and also returned.
    """
    for cell in range(16):
        row, col = divmod(cell, 4)
        state[row][col] = S_BOX_INV[state[row][col]]
    return state


def ShiftRows(state):
    """Rotate row ``r`` of the state left by ``r`` positions (row 0 is unchanged).

    The outer state list is updated in place and also returned.
    """
    for row in range(1, 4):
        cells = state[row]
        shift = row % len(cells)
        state[row] = cells[shift:] + cells[:shift]
    return state


def InvShiftRows(state):
    """Rotate row ``r`` of the state right by ``r`` positions, undoing ShiftRows.

    The outer state list is updated in place and also returned.
    """
    for row in range(1, 4):
        cells = state[row]
        cut = len(cells) - row % len(cells)
        state[row] = cells[cut:] + cells[:cut]
    return state


def MixColumns(state):
    """Multiply every state column by ``mix_col_matrix`` over GF(2^8).

    Takes a 4x4 state of bit-string cells and returns a new state of the same
    form. Cells are converted to integers for the arithmetic.
    """
    values = state_to_int(state)
    mixed = [[] for _ in range(4)]
    for col in range(4):
        column = [values[row][col] for row in range(4)]
        for row in range(4):
            byte_bits = multiply_vec_by_vec(mix_col_matrix[row], column)
            mixed[row].append(btoi(byte_bits))
    return state_to_bin(mixed)


def InvMixColumns(state):
    """Multiply every state column by ``inverse_mix_col_matrix`` over GF(2^8).

    This undoes MixColumns. Takes a 4x4 state of bit-string cells and returns
    a new state of the same form.
    """
    values = state_to_int(state)
    unmixed = [[] for _ in range(4)]
    for col in range(4):
        column = [values[row][col] for row in range(4)]
        for row in range(4):
            byte_bits = multiply_vec_by_vec(inverse_mix_col_matrix[row], column)
            unmixed[row].append(btoi(byte_bits))
    return state_to_bin(unmixed)


def AddRoundKey(state, key_by32):
    """XOR each state column with the matching 32-bit round-key word.

    :param state: 4x4 state of 8-character bit strings (not modified).
    :param key_by32: four 32-character bit strings, one per column.
    :return: a new 4x4 state holding the XORed bytes.
    """
    xored_columns = []
    for col in range(4):
        column_bits = ''.join(state[row][col] for row in range(4))
        key_word = key_by32[col]
        xored_columns.append(''.join(
            str(int(a) ^ int(b)) for a, b in zip(column_bits, key_word)
        ))
    return [
        [xored_columns[col][row * 8:(row + 1) * 8] for col in range(4)]
        for row in range(4)
    ]

In [12]:
def AES_ENCRYPT(data, key):
    """Encrypt a bit string with AES-128, one 128-bit block at a time.

    Every block is encrypted independently with the same key schedule (ECB
    mode), so identical plaintext blocks give identical ciphertext blocks.

    :param data: plaintext as a '0'/'1' string whose length is a multiple of
        128 (pad it first, e.g. with ``sto128``).
    :param key: the 128-bit key as a '0'/'1' string.
    :return: the ciphertext bit string, the same length as ``data``.
    :raises ValueError: if ``len(data)`` is not a multiple of 128.
    """
    if len(data) % 128 != 0:
        # A trailing partial block cannot be encrypted; reject it rather than
        # silently dropping part of the plaintext.
        raise ValueError(
            f'AES plaintext length must be a multiple of 128 bits, got {len(data)}'
        )
    w = KeyExpansion(key)

    blocks = []
    for bid in range(len(data) // 128):
        state = block_to_state(data[bid * 128:(bid + 1) * 128])

        # Initial AddRoundKey with round key 0.
        state = AddRoundKey(state, w[0:4])

        # Rounds 1..9: full round.
        for r in range(1, 10):
            state = SubBytes(state)
            state = ShiftRows(state)
            state = MixColumns(state)
            state = AddRoundKey(state, w[r * 4:(r + 1) * 4])

        # Round 10: no MixColumns.
        state = SubBytes(state)
        state = ShiftRows(state)
        state = AddRoundKey(state, w[10 * 4:11 * 4])

        blocks.append(state_to_block(state))

    return ''.join(blocks)


def AES_DECRYPT(data, key):
    """Decrypt an AES-128 ciphertext produced block by block (ECB mode).

    Applies the inverse cipher to each 128-bit block. Round keys are used in
    reverse order, together with InvShiftRows, InvSubBytes and InvMixColumns.

    :param data: ciphertext as a '0'/'1' string whose length is a multiple of 128.
    :param key: the same 128-bit key (bit string) that was used for encryption.
    :return: the plaintext bit string, the same length as ``data``. Any padding
        added before encryption is still present.
    :raises ValueError: if ``len(data)`` is not a multiple of 128.
    """
    if len(data) % 128 != 0:
        # A ciphertext that is not block-aligned is malformed; reject it rather
        # than silently dropping the tail.
        raise ValueError(
            f'AES ciphertext length must be a multiple of 128 bits, got {len(data)}'
        )
    w = KeyExpansion(key)

    blocks = []
    for bid in range(len(data) // 128):
        state = block_to_state(data[bid * 128:(bid + 1) * 128])

        # Undo the last round, which has no MixColumns step.
        state = AddRoundKey(state, w[10 * 4:11 * 4])
        state = InvShiftRows(state)
        state = InvSubBytes(state)

        # Undo rounds 9..1.
        for r in range(9, 0, -1):
            state = AddRoundKey(state, w[r * 4:(r + 1) * 4])
            state = InvMixColumns(state)
            state = InvShiftRows(state)
            state = InvSubBytes(state)

        # Undo the initial AddRoundKey with round key 0.
        state = AddRoundKey(state, w[0:4])

        blocks.append(state_to_block(state))

    return ''.join(blocks)

In [17]:
def run_all():
    """Encrypt and then decrypt a fixed demo message, printing every stage.

    The message is zero-padded to a multiple of 128 bits with ``sto128``, so
    the decoded message printed last still carries that padding as NUL
    characters.
    """
    plain_text = 'this is my message 123 HELLO'
    key_text = 'crypto42_AES_key'

    plain_bits = sto128(stobits(plain_text))
    key_bits = stobits(key_text)

    cipher_bits = AES_ENCRYPT(plain_bits, key_bits)
    recovered_bits = AES_DECRYPT(cipher_bits, key_bits)

    # Plaintext
    print(f'msg: {plain_text}')
    print(f'msg len: {len(plain_bits)}')
    print(f'msg: {plain_bits}\n')

    # Key
    print(f'enc key: {key_text}')
    print(f'key len: {len(key_text)}')
    print(f'len bits: {len(key_bits)}')
    print(f'key bits: {key_bits}\n')

    # Ciphertext
    print('ENCRYPTED')
    print(f'enc len: {len(cipher_bits)}')
    print(f'enc: {cipher_bits}\n')

    # Round trip
    print('DECRYPTED')
    print(f'IS DECRYPTED == MSG BITS: {plain_bits == recovered_bits}\n')
    print(f'dec len: {len(recovered_bits)}')
    print(f'dec bits: {recovered_bits}\n')
    print(f'dec msg: {bytestos(bitstos(recovered_bits))}')

In [18]:
# Run the end-to-end AES demo: encrypt, decrypt, and print every stage.
run_all()

msg: this is my message 123 HELLO
msg len: 256
msg: 0111010001101000011010010111001100100000011010010111001100100000011011010111100100100000011011010110010101110011011100110110000101100111011001010010000000110001001100100011001100100000010010000100010101001100010011000100111100000000000000000000000000000000

enc key: crypto42_AES_key
key len: 16
len bits: 128
key bits: 01100011011100100111100101110000011101000110111100110100001100100101111101000001010001010101001101011111011010110110010101111001

ENCRYPTED
enc len: 256
enc: 1110000100111111001011011110010011100011110001011000001101101010100101010001101000000111001000000110001111010001010001101011011100100001010111101100100001100001000101010010101110000101110111111100111100010001000011010011000011001001000111000011100000010111

DECRYPTED
IS DECRYPTED == MSG BITS: True

dec len: 256
dec bits: 011101000110100001101001011100110010000001101001011100110010000001101101011110010010000001101101011001010111001101110011011000010110011101100101001