In [1]:
# Import libraries
import pandas as pd
import numpy as np

import warnings
warnings.filterwarnings('ignore')

import ipywidgets as widgets
from ipywidgets import interactive, Tab, Layout
from IPython.display import display, clear_output, Image

In [4]:
#!jupyter nbextension enable --py widgetsnbextension --sys-prefix
#!jupyter serverextension enable voila --sys-prefix

# AES Demonstrator
---
This demonstrator is a learning tool for understanding how the AES algorithm works: it decomposes the algorithm into individual steps and lets users change the input to each step to visualise how that step's computation is carried out. In addition, it demonstrates the following:<br>
<li>How Galois field operations are used in AES and the stepwise computation of the GF polynomial arithmetic</li>
<li>How AES encryption and decryption work, by allowing the user to provide a plaintext input and see how it is encrypted and then decrypted back to the original plaintext</li>

In [3]:
# Widgets for Intro: static explanatory text plus a small table of AES variants.

heading_intro = widgets.HTML(value="<h1>What is AES?</h1>")
text_intro1 = widgets.HTMLMath(value=r"AES, the Advanced Encryption Standard, is a symmetric key cryptographic algorithm. It is a block cipher, with the plaintext input having a block size of 128 bits (16 bytes, e.g. 16 ASCII characters). The plaintext block is represented as a 4 $\times$ 4 column major order matrix of bytes, called the state array:<br><br>")
text_intro2 = widgets.HTMLMath(value=r"\begin{vmatrix} b_0 & b_4 & b_8 & b_{12} \\ b_1 & b_5 & b_9 & b_{13} \\ b_2 & b_6 & b_{10} & b_{14} \\ b_3 & b_7 & b_{11} & b_{15} \end{vmatrix}")
text_intro3 = widgets.HTML(value="There are 3 types of AES, with different key sizes and number of processing rounds:<br><br>")
text_intro4 = widgets.HTML(value="<br>For the encryption process, the state array goes through 4 kinds of transformations in each processing round (the final round omits MixColumns):")
text_intro5 = widgets.HTML(value="<li>SubBytes - Substitution of each byte of the state array using a substitution box (S-box)</li><li>ShiftRows - Shifting of each row of the state array to the left by a certain number of positions</li><li>MixColumns - Matrix multiplication between the state array and a predefined matrix</li><li>AddRoundKey - XOR of the state array and the round key</li>")
text_intro6 = widgets.HTML(value="<br>Similarly for decryption, the state array goes through inverse transformation functions in each processing round (the final round omits InvMixColumns):")
text_intro7 = widgets.HTML(value="<li>InvShiftRows</li><li>InvSubBytes</li><li>AddRoundKey</li><li>InvMixColumns</li>")
text_intro8 = widgets.HTML(value="<br>The AES algorithm also has a Key Expansion process which expands the original key into a key schedule, generating different round keys to be used in the AddRoundKey step for each processing round.<br><br>")
text_intro9 = widgets.HTML(value="The following flowchart summarises the AES Encryption and Decryption process:<br><br><img src='https://www.researchgate.net/profile/Prabhakar-T/publication/221958203/figure/fig2/AS:339550586064915@1457966576426/AES-Encryption-Decryption-Flowchart.png' alt='AES Encryption/Decryption Flowchart'/><br>AES Encryption/Decryption Flowchart")
# The padding description must agree with AddPadding / RemovePadding: the pad
# length ranges over 1..16 inclusive (a full block of 0x10 bytes is added when
# the plaintext is already block-aligned), and RemovePadding reports a bad pad
# by printing a message rather than raising an exception.
text_intro10 = widgets.HTMLMath(value=r"<br>This demo uses PKCS#7 padding (the 16-byte-block form of PKCS5 padding) to add $n$ padding bytes, each with value $n$ where $1 \leq n \leq 16$, to the plaintext before encryption. For padding removal, it checks that the last $N$ bytes of the decrypted data all have value $N$ with $1 \leq N \leq 16$; if so, the $N$ bytes are removed, otherwise a decryption error is reported.")

# Table of AES variants: key size (bits) and number of rounds.
tb = widgets.Output()
with tb:
    data = [[128, 192, 256], [10, 12, 14]]
    display(pd.DataFrame(data,
                         columns = ['AES-128','AES-192','AES-256'],
                         index = ['Key Size', 'No. of Rounds']))

vbox_intro = widgets.VBox([heading_intro,
                           text_intro1,
                           text_intro2,
                           text_intro3,
                           tb,
                           text_intro4,
                           text_intro5,
                           text_intro6,
                           text_intro7,
                           text_intro8,
                           text_intro9,
                           text_intro10])

In [4]:
def StrToHex(s):
    """Return the lowercase hex encoding of the UTF-8 bytes of ``s``.

    Every byte becomes exactly two hex digits, e.g. ``StrToHex("AB") == "4142"``.
    """
    return "".join(format(octet, "02x") for octet in s.encode("utf-8"))

def HexToStr(s_hex):
    """Decode a hex string (as produced by StrToHex) back into UTF-8 text."""
    raw_bytes = bytes.fromhex(s_hex)
    return str(raw_bytes, "utf-8")

def HexToByte(s_hex):
    """Split a hex string into consecutive 2-character byte strings.

    A trailing odd character, if any, becomes its own 1-character element.
    """
    return [s_hex[start:start + 2] for start in range(0, len(s_hex), 2)]

def HexToWord(s_hex):
    """Split a hex string into 8-character (4-byte) words; the last word may be shorter."""
    return [s_hex[start:start + 8] for start in range(0, len(s_hex), 8)]

def xor(s1, s2, length):
    """XOR two hex strings numerically and return lowercase hex.

    The result is left-padded with '0' to at least ``length`` characters and
    is never truncated.
    """
    combined = int(s1, 16) ^ int(s2, 16)
    return format(combined, 'x').rjust(length, "0")

def MatrixForm(s_byte):
    """Print a flat list of byte strings as a 4-row, column-major state matrix.

    Byte k is placed at row ``k % 4``, column ``k // 4`` (16 bytes give the
    4 x 4 AES state array; longer multiples of 4 give extra columns).

    Raises:
        ValueError: if ``s_byte`` is empty or its length is not a multiple of 4.
    """
    n_bytes = len(s_byte)
    # Integer arithmetic instead of the old float ``len / 16`` + array_split,
    # which silently produced odd shapes for lengths that are not multiples of 4.
    if n_bytes == 0 or n_bytes % 4 != 0:
        raise ValueError("MatrixForm expects a non-empty list whose length is a multiple of 4")
    # Each run of 4 consecutive bytes is one column; transpose to get rows.
    columns = np.array(s_byte).reshape(n_bytes // 4, 4)
    matrix = pd.DataFrame(columns.T)
    print(matrix)

In [5]:
# Demos for Processing Steps

def SubBytesDemo(input_str):
    """Print a step-by-step SubBytes walkthrough for a 16-character input.

    Raises:
        TypeError: if ``input_str`` is not a string.
        ValueError: if ``input_str`` is not 16 characters long, or its UTF-8
            encoding is not exactly 16 bytes (non-ASCII characters take
            several bytes and would not fit the 16-byte state).
    """
    print("SubBytes Demo")
    print("-------------")
    # Check for invalid input
    if not isinstance(input_str, str):
        raise TypeError("Input must be a string")
    if len(input_str) != 16:
        raise ValueError("Input should have 16 characters")
    # The state is 16 bytes, not 16 characters: reject multi-byte characters
    if len(input_str.encode("utf-8")) != 16:
        raise ValueError("Input should encode to exactly 16 bytes (use ASCII characters only)")

    print("Input:", input_str)
    input_hex = StrToHex(input_str)
    print("Input (in hex):", input_hex)
    MatrixForm(HexToByte(input_hex))
    print()
    print("Substituting each byte using S-box =>")
    output_hex = SubBytes(input_hex)
    MatrixForm(HexToByte(output_hex))
    print()
    print("Output (in hex):", output_hex)

def ShiftRowsDemo(input_str):
    """Print a step-by-step ShiftRows walkthrough for a 16-character input.

    Raises:
        TypeError: if ``input_str`` is not a string.
        ValueError: if ``input_str`` is not 16 characters long, or its UTF-8
            encoding is not exactly 16 bytes (non-ASCII characters take
            several bytes and would not fit the 16-byte state).
    """
    print("ShiftRows Demo")
    print("--------------")
    # Check for invalid input
    if not isinstance(input_str, str):
        raise TypeError("Input must be a string")
    if len(input_str) != 16:
        raise ValueError("Input should have 16 characters")
    # The state is 16 bytes, not 16 characters: reject multi-byte characters
    if len(input_str.encode("utf-8")) != 16:
        raise ValueError("Input should encode to exactly 16 bytes (use ASCII characters only)")

    print("Input:", input_str)
    input_hex = StrToHex(input_str)
    print("Input (in hex):", input_hex)
    print()
    print("Shifting row i to the left by i positions")
    print()
    print("Before:")
    MatrixForm(HexToByte(input_hex))
    print()
    output_hex = ShiftRows(input_hex)
    print("After:")
    MatrixForm(HexToByte(output_hex))
    print()
    print("Output (in hex):", output_hex)

def MixColumnsDemo(input_str):
    """Print a step-by-step MixColumns walkthrough for a 16-character input.

    The product is shown as fixed matrix x state array. That is the order
    MixColumns uses: each output column is the fixed matrix times the
    corresponding input column.

    Raises:
        TypeError: if ``input_str`` is not a string.
        ValueError: if ``input_str`` is not 16 characters long, or its UTF-8
            encoding is not exactly 16 bytes (non-ASCII characters take
            several bytes and would not fit the 16-byte state).
    """
    print("MixColumns Demo")
    print("---------------")
    # Check for invalid input
    if not isinstance(input_str, str):
        raise TypeError("Input must be a string")
    if len(input_str) != 16:
        raise ValueError("Input should have 16 characters")
    # The state is 16 bytes, not 16 characters: reject multi-byte characters
    if len(input_str.encode("utf-8")) != 16:
        raise ValueError("Input should encode to exactly 16 bytes (use ASCII characters only)")

    print("Input:", input_str)
    input_hex = StrToHex(input_str)
    print("Input (in hex):", input_hex)
    print()
    print("Matrix multiplication carried out in GF(2^8) arithmetic:")
    print()
    # Matrix multiplication is not commutative; MixColumns computes
    # (fixed matrix) x (state), so the fixed matrix is printed on the left.
    MatrixForm(['02','01','01','03','03','02','01','01','01','03','02','01','01','01','03','02'])
    print("         x")
    MatrixForm(HexToByte(input_hex))
    print("         =")
    output_hex = MixColumns(input_hex)
    MatrixForm(HexToByte(output_hex))
    print()
    print("Output (in hex):", output_hex)

def AddRoundKeyDemo(input_str, key):
    """Print a step-by-step AddRoundKey (state XOR key) walkthrough.

    Args:
        input_str: 16-character plaintext block.
        key: 16-character round key.

    Raises:
        TypeError: if either argument is not a string.
        ValueError: if either argument is not 16 characters long, or its UTF-8
            encoding is not exactly 16 bytes (non-ASCII characters take
            several bytes and would not fit a 16-byte block).
    """
    print("AddRoundKey Demo")
    print("----------------")
    # Check for invalid input
    if not isinstance(input_str, str):
        raise TypeError("Input must be a string")
    if not isinstance(key, str):
        raise TypeError("Key must be a string")
    if len(input_str) != 16:
        raise ValueError("Input should have 16 characters")
    if len(key) != 16:
        raise ValueError("Key should have 16 characters")
    # Blocks are 16 bytes, not 16 characters: reject multi-byte characters
    if len(input_str.encode("utf-8")) != 16:
        raise ValueError("Input should encode to exactly 16 bytes (use ASCII characters only)")
    if len(key.encode("utf-8")) != 16:
        raise ValueError("Key should encode to exactly 16 bytes (use ASCII characters only)")

    print("Input:", input_str)
    print("Key:", key)
    input_hex = StrToHex(input_str)
    print("Input (in hex):", input_hex)
    key_hex = StrToHex(key)
    print("Key (in hex):", key_hex)
    print()
    MatrixForm(HexToByte(input_hex))
    print("         +       (XOR)")
    MatrixForm(HexToByte(key_hex))
    print("         =")
    output_hex = xor(input_hex, key_hex, len(input_hex))
    MatrixForm(HexToByte(output_hex))
    print()
    print("Output (in hex):", output_hex)

def NumOfRounds(key_length):
    """Return the number of AES rounds for a key size in bits.

    128 -> 10, 192 -> 12, 256 -> 14; any other value returns None.
    """
    for bits, rounds in ((128, 10), (192, 12), (256, 14)):
        if key_length == bits:
            return rounds
    return None

def AddPadding(plaintext):
    """Return the hex of ``plaintext`` (UTF-8) with block padding appended.

    Between 1 and 16 padding bytes are added so that the byte length becomes
    a multiple of 16. Every padding byte holds the padding length, so an
    already-aligned input gets a full extra block of 0x10 bytes.
    """
    data_hex = StrToHex(plaintext)
    pad_len = 16 - (len(data_hex) // 2) % 16
    # Each pad byte is the two-digit hex of pad_len ('01'..'0f', or '10' for 16)
    return data_hex + format(pad_len, '02x') * pad_len

def RemovePadding(s):
    """Strip block padding from a hex string and return the unpadded hex.

    The last byte gives the padding length N. The last N bytes must all equal N,
    with 1 <= N <= 16; the hex comparison ignores case. On any violation a
    "Decryption Error" message is printed and None is returned.
    """
    # At least one whole byte (2 hex chars) is needed to read the pad length.
    if len(s) < 2:
        print("Decryption Error: input too short to contain padding")
        return None

    n = int(s[-2:], 16)
    if not 1 <= n <= 16:
        print("Decryption Error: n <= 0 or >= 17")
        return None

    # Expected padding: N bytes, each equal to N (e.g. '030303' for N = 3).
    padding = format(n, '02x') * n
    padbytes = s[-2 * n:]
    if padbytes.lower() != padding:
        print("Decryption Error: padbytes != padding")
        return None
    return s[:-2 * n]

def pad_or_truncate(string):
    """Force ``string`` to exactly 16 characters.

    Longer input is cut to its first 16 characters; shorter input is
    right-padded with '0' characters.
    """
    return string[:16].ljust(16, '0')

In [6]:
# Widgets for Galois Field

heading_gf = widgets.HTML(value="<h1>Math Prerequisite: Galois Field Operations</h1>")
text_gf1 = widgets.HTMLMath(value=r"A Galois Field (GF), or finite field, contains a finite number of elements and is denoted as GF($p^n$) with $p^n$ elements where $p$ is prime and $n$ is a positive integer.<br><br>")
text_gf2 = widgets.HTMLMath(value=r"When $n = 1$, GF($p$) (finite field of order $p$) contains elements that are integers modulo $p$. For example, GF($5$) consists of elements 0 to 4. For $n > 1$, the elements are represented as polynomials of degree $n-1$ or less, over GF($p$).<br><br>")
text_gf3 = widgets.HTMLMath(value=r"E.g. GF($2^3$) consists of 8 elements: 0, 1, 2, 3, 4, 5, 6, 7<br>Binary representation of elements: 0, 1, 10, 11, 100, 101, 110, 111<br>Polynomial representation: $0, 1, x, x+1, x^2, x^2+1, x^2+x, x^2+x+1$<br><br>")
# Closing tag fixed (was a second opening <h2>, leaving the heading unclosed)
text_gf4 = widgets.HTML(value="<h2>GF Polynomial Arithmetic</h2>")
text_gf5 = widgets.HTMLMath(value=r"For AES, operations are done in GF($2^8$), and $ x^8+x^4+x^3+x+1 $ (0x11b in hexadecimal) is chosen as the irreducible polynomial (IP), i.e. a polynomial that cannot be factored into polynomials of lower degree over GF($2$). Arithmetic operations of elements in a GF are done differently from the usual polynomial arithmetic:<br>")
text_gf6 = widgets.HTML(value="<ul><li>Addition: Polynomial addition with coefficients being added modulo 2; this is the same as XORing the 2 polynomials together</li><li>Multiplication: Polynomial multiplication is performed normally with coefficients modulo 2, then if the degree of the resulting polynomial is the same as or higher than that of the IP, the result is divided by the IP and the remainder is taken</li><li>Multiplicative Inverse: To find the multiplicative inverse of a polynomial, the Extended Euclidean algorithm is used; the method used by this demo is as follows:</li></ul>")
text_gf7 = widgets.HTMLMath(value=r"To find the multiplicative inverse of $x$ in GF($2^8$):<ol><li>Let $r_0$ be the Irreducible Polynomial (IP) $x^8+x^4+x^3+x+1$ and $r_1$ be $x$ in polynomial form</li><li>Let $ s_0 = 1, s_1 = 0, s_i = s_{i-2} - q_{i-1}s_{i-1} $<br>and $ t_0 = 0, t_1 = 1, t_i = t_{i-2} - q_{i-1}t_{i-1} $</li><li>$q$ (quotient) and $r$ (remainder) values are obtained by dividing $r_{i-2}$ by $r_{i-1}$<br>i.e. $q_{i-1} = r_{i-2} \phantom{0} / \phantom{0} r_{i-1}$ and $r_i = r_{i-2} \bmod r_{i-1}$</li><li>$q$, $r$, $s$ and $t$ values are then computed and stored in a table until $r_i$ reaches 0</li><li>From the table, when $r_k = 1$, multiplicative inverse of $x$ can be found at $t_k$</li></ol>")

vbox_gf_desc = widgets.VBox([heading_gf,
                               text_gf1,
                               text_gf2,
                               text_gf3,
                               text_gf4,
                               text_gf5,
                               text_gf6,
                               text_gf7])

In [7]:
def PolynomialForm(a):
    """Render a non-negative integer as a polynomial over GF(2).

    Bit k of ``a`` becomes the term ``x^k`` (written ``x`` for k = 1 and ``1``
    for k = 0). Terms are listed from highest to lowest degree, joined with
    '+', and zero renders as "0". Example: 11 (0b1011) -> "x^3+x+1".
    """
    if a == 0:
        return "0"
    # Least-significant bit first, so list position == power of x.
    bits = [int(ch) for ch in reversed(bin(a)[2:])]
    terms = []
    for power in range(len(bits) - 1, -1, -1):
        if bits[power] != 1:
            continue
        if power == 0:
            terms.append("1")
        elif power == 1:
            terms.append("x")
        else:
            terms.append("x^" + str(power))
    return "+".join(terms)

# Demo for GF Multiplication

def gf_mul_for_demo(x, y):
    """Verbose GF(2^8) multiplication used by the demo.

    Prints one partial product x * 2^k per set bit k of ``y`` (highest bit
    first), their modulo-2 sum, and then the reduction modulo 0x11b done by
    gf_mod_for_demo. Returns the reduced product; ``y == 1`` returns ``x``
    immediately without printing anything.
    """
    if y == 1:
        return x
    bits = bin(y)[2:]
    top_degree = len(bits) - 1
    partials = [x * 2 ** (top_degree - pos) for pos, bit in enumerate(bits) if int(bit) == 1]

    result = 0
    for term in partials:
        # Separate consecutive terms with '+' (nothing before the first one)
        if result != 0:
            print("+")
        print(PolynomialForm(term))
        result ^= term
    print()
    print("Coefficients being added modulo 2:", PolynomialForm(result))
    return gf_mod_for_demo(result, 0x11b)

def gf_mod_for_demo(fx, gx):
    """Verbose GF(2) polynomial reduction used by the multiplication demo.

    Reduces ``fx`` modulo ``gx`` (both encoded as integers, bit k = coefficient
    of x^k). When a reduction is needed it prints an explanation first; it
    always prints the result in polynomial, binary and decimal form.

    Returns:
        The remainder of ``fx`` divided by ``gx``.

    Raises:
        ValueError: if ``gx`` is not a positive (non-zero) polynomial.
    """
    fx, gx = int(fx), int(gx)
    if gx <= 0:
        raise ValueError("gx must be a non-zero polynomial")
    # Exact degrees via int.bit_length(); float np.log2 can round up for large
    # integers, which picks a wrong shift and can make the loop never finish.
    gx_deg = gx.bit_length() - 1
    threshold = 1 << gx_deg
    if fx >= threshold:
        print()
        print("Since degree of resulting polynomial >= degree of IP,")
        print("thus required to modulo the IP")
        print()
        print("({}) mod ({})".format(PolynomialForm(fx), PolynomialForm(gx)))
    while fx >= threshold:
        # Cancel fx's leading term with gx shifted to the same degree
        fx ^= gx << (fx.bit_length() - 1 - gx_deg)
    print("=", PolynomialForm(fx))
    print("=", bin(fx))
    print("=", fx)
    return fx

def GFMulDemo(x, y):
    """Print a worked GF(2^8) multiplication of ``x`` and ``y``.

    Shows both operands in decimal, binary and polynomial form, then the
    step-by-step product from gf_mul_for_demo, followed by the final value.
    """
    print("GF Multiplication Demo")
    print("----------------------")
    for operand in (x, y):
        print("{} = {} => {} in polynomial form".format(operand, bin(operand), PolynomialForm(operand)))
    print()
    x_poly, y_poly = PolynomialForm(x), PolynomialForm(y)
    print("{} x {} => ({}) * ({})".format(x, y, x_poly, y_poly))
    print("=")
    product = gf_mul_for_demo(x, y)
    print(product)

heading_demo_gf_mul = widgets.HTML(value="<h2>GF Multiplication Demo</h2>")
# The instruction must name the button exactly as labelled below ('Multiplication')
text_demo_gf_mul = widgets.HTML(value="Please enter 2 numbers from 1 to 255 before clicking on the 'Multiplication' button:")

input_num1 = widgets.Text(
    value='0',
    placeholder='Type something here',
    description='Number 1:',
    disabled=False
)
input_num2 = widgets.Text(
    value='0',
    placeholder='Type something here',
    description='Number 2:',
    disabled=False
)

vbox_gf_mul = widgets.VBox([heading_demo_gf_mul,
                            text_demo_gf_mul,
                            input_num1,
                            input_num2])

button_gf_mul = widgets.Button(
                description='Multiplication',
                style={'description_width': 'initial'}
            )

output_gf_mul = widgets.Output()

def on_button_clicked_gf_mul(event):
    """Button callback: validate both inputs, then run GFMulDemo in output_gf_mul.

    Prints a message instead of running the demo when either input is not an
    integer or is outside 1..255.
    """
    output_gf_mul.clear_output()
    with output_gf_mul:
        # Reject empty / non-integer text before range-checking
        try:
            x, y = int(input_num1.value), int(input_num2.value)
        except ValueError:
            print("Please input 2 integers for the demo")
            return

        if not (1 <= x <= 255 and 1 <= y <= 255):
            print("Please input a valid number from 1 to 255")
            return
        GFMulDemo(x, y)

button_gf_mul.on_click(on_button_clicked_gf_mul)

# Inputs, button and output stacked vertically for the multiplication demo
vbox_gf_mul_demo = widgets.VBox([vbox_gf_mul, button_gf_mul, output_gf_mul])

In [8]:
# Galois Field Multiplication
def gf_mul(x, y):
    """Multiply two GF(2^8) elements (AES field, modulus 0x11b).

    Carry-less shift-and-XOR multiplication followed by reduction with gf_mod.
    Shortcut: ``y == 1`` returns ``x`` unchanged (not reduced).
    """
    if y == 1:
        return x
    bits = [int(b) for b in bin(y)[2:]]
    highest = len(bits) - 1
    product = 0
    for offset, bit in enumerate(bits):
        if bit == 1:
            # Partial product for this bit: x * x^(degree of the bit)
            product ^= x * 2 ** (highest - offset)
    return gf_mod(product, 0x11b)

# Galois Field Modulo
def gf_mod(fx, gx):
    """Return the remainder of polynomial ``fx`` divided by ``gx`` over GF(2).

    Polynomials are encoded as integers (bit k = coefficient of x^k) and
    subtraction is XOR. Degrees come from int.bit_length(), which is exact for
    every integer. The old float np.log2 could round up near powers of two and
    make the reduction loop pick the wrong shift (or spin forever).

    Raises:
        ValueError: if ``gx`` is not a positive (non-zero) polynomial.
    """
    fx, gx = int(fx), int(gx)
    if gx <= 0:
        raise ValueError("gx must be a non-zero polynomial")
    gx_deg = gx.bit_length() - 1
    threshold = 1 << gx_deg
    while fx >= threshold:
        # Cancel fx's leading term with gx shifted up to the same degree
        fx ^= gx << (fx.bit_length() - 1 - gx_deg)
    return fx

# Galois Field Division
def gf_div(fx, gx):
    """Return the quotient of polynomial ``fx`` divided by ``gx`` over GF(2).

    Polynomials are encoded as integers (bit k = coefficient of x^k). Uses
    exact int.bit_length() degrees instead of float np.log2, which can round
    up for large integers and corrupt the long division.

    Raises:
        ValueError: if ``gx`` is not a positive (non-zero) polynomial.
    """
    fx, gx = int(fx), int(gx)
    if gx <= 0:
        raise ValueError("gx must be a non-zero polynomial")
    gx_deg = gx.bit_length() - 1
    threshold = 1 << gx_deg
    qx = 0
    while fx >= threshold:
        shift = fx.bit_length() - 1 - gx_deg
        fx ^= gx << shift
        # Shifts strictly decrease, so each quotient bit is set exactly once
        qx |= 1 << shift
    return qx

# Galois Field Multiplicative Inverse
def gf_mul_inv(fx, gx):
    """Return the multiplicative inverse of ``gx`` modulo the polynomial ``fx``.

    Extended Euclidean algorithm over GF(2), where subtraction is XOR. Only
    the t coefficients are tracked, since they alone give the inverse (the s
    column the old version computed was never used). The loop stops as soon
    as a remainder equals 1; the matching t is the inverse.

    NOTE: intermediate products use gf_mul, which always reduces modulo
    0x11b, so ``fx`` is expected to be the AES polynomial 0x11b.

    Raises:
        ValueError: if ``gx`` has no inverse (e.g. ``gx == 0``), detected when
            the remainder sequence reaches 0 without passing through 1.
    """
    r_prev, r_cur = fx, gx
    t_prev, t_cur = 0, 1
    while r_cur != 1:
        if r_cur == 0:
            raise ValueError("{} has no multiplicative inverse modulo {}".format(gx, fx))
        q = gf_div(r_prev, r_cur)
        r_prev, r_cur = r_cur, gf_mod(r_prev, r_cur)
        t_prev, t_cur = t_cur, t_prev ^ gf_mul(q, t_cur)
    return t_cur

In [9]:
# Demo for GF Multiplicative Inverse

def gf_mul_inv_tb(q, r, s, t):
    """Display the Extended Euclidean algorithm table with every value in polynomial form.

    Each of ``q``, ``r``, ``s`` and ``t`` becomes one labelled row. Rows may
    differ in length (``q`` is one entry shorter); pandas leaves the missing
    cells empty. The argument lists are not modified: the old version aliased
    them and overwrote the caller's integers with strings.
    """
    rows = [[PolynomialForm(value) for value in row] for row in (q, r, s, t)]
    print()
    display(pd.DataFrame(rows, index = ['𝑞', '𝑟', '𝑠', '𝑡']))

def gf_mul_inv_for_demo(fx, gx):
    """Demo variant of gf_mul_inv that also shows its working.

    Runs the Extended Euclidean algorithm for the inverse of ``gx`` modulo
    ``fx``, keeping the full q/r/s/t history. It displays that history as a
    table, prints where the inverse is read from, and returns the inverse.
    """
    q, r = [0], [fx, gx]
    s, t = [1, 0], [0, 1]
    # Stop as soon as the newest remainder is 1; its t entry is the inverse.
    while r[-1] != 1:
        quotient = gf_div(r[-2], r[-1])
        remainder = gf_mod(r[-2], r[-1])
        q.append(quotient)
        r.append(remainder)
        s.append(s[-2] ^ gf_mul(quotient, s[-1]))
        t.append(t[-2] ^ gf_mul(quotient, t[-1]))
    result = t[-1]
    gf_mul_inv_tb(q, r, s, t)
    k = len(r) - 1
    print("The multiplicative inverse can be found at 𝑡[𝑘] when 𝑟[𝑘] = 1,")
    print("so when 𝑟[{}] = 1,".format(k))
    print("𝑡[{}] = {} ≡ {} = {}".format(k, PolynomialForm(result), bin(result), result))
    return result

def GFMulInvDemo(x):
    """Print a worked computation of the multiplicative inverse of ``x`` in GF(2^8).

    Explains the Extended Euclidean setup and shows the q/r/s/t table via
    gf_mul_inv_for_demo. The table stops at the row where r = 1, because that
    row's t value is the inverse.
    """
    print("GF Multiplicative Inverse Demo")
    print("------------------------------")
    print()
    print("Finding multiplicative inverse of {} in GF(2^8):".format(x))
    irreducible_poly = 0x11b
    print()
    print("{} = {} => {} in polynomial form".format(x, bin(x), PolynomialForm(x)))
    print()
    print("Let 𝑟[0] be the Irreducible Polynomial (IP) {} and 𝑟[1] be {}".format(PolynomialForm(irreducible_poly), PolynomialForm(x)))
    print()
    print("Let 𝑠[0] = 1, 𝑠[1] = 0, 𝑠[𝑖] = 𝑠[𝑖-2] - 𝑞[𝑖-1]*𝑠[𝑖-1]")
    print("Let 𝑡[0] = 0, 𝑡[1] = 1, 𝑡[𝑖] = 𝑡[𝑖-2] - 𝑞[𝑖-1]*𝑡[𝑖-1]")
    print()
    # gf_mul_inv_for_demo stops when a remainder equals 1, so say "reaches 1"
    # (the old text said 0, which the displayed table never shows).
    print("Computing 𝑞, 𝑟, 𝑠 and 𝑡 values, until 𝑟[𝑖] reaches 1:")
    res = gf_mul_inv_for_demo(irreducible_poly, x)
    print("∴ multiplicative inverse of {} = {}".format(x, res))

# Widgets for the GF multiplicative-inverse demo: heading, instructions, one number input
heading_demo_gf_mulinv = widgets.HTML(value="<h2>GF Multiplicative Inverse Demo</h2>")
text_demo_gf_mulinv = widgets.HTML(
    value="Please enter a number from 1 to 255 before clicking on the 'Multiplicative Inv' button:"
)

input_num = widgets.Text(
    description='Number:',
    placeholder='Type something here',
    value='0',
    disabled=False,
)

vbox_gf_mulinv = widgets.VBox(
    [heading_demo_gf_mulinv, text_demo_gf_mulinv, input_num]
)

button_gf_mulinv = widgets.Button(
    description='Multiplicative Inv',
    style={'description_width': 'initial'},
)

output_gf_mulinv = widgets.Output()

def on_button_clicked_gf_mulinv(event):
    """Button callback: validate the number, then run GFMulInvDemo in output_gf_mulinv.

    Prints a message instead of running the demo when the input is not an
    integer or is outside 1..255.
    """
    output_gf_mulinv.clear_output()
    with output_gf_mulinv:
        # Reject empty / non-integer text before range-checking
        try:
            x = int(input_num.value)
        except ValueError:
            print("Please input a number for the demo")
            return

        if not 1 <= x <= 255:
            print("Please input a valid number from 1 to 255")
            return
        GFMulInvDemo(x)

button_gf_mulinv.on_click(on_button_clicked_gf_mulinv)

# Both GF demos side by side, under the GF description
vbox_gf_mulinv_demo = widgets.VBox([vbox_gf_mulinv, button_gf_mulinv, output_gf_mulinv])
hbox_gf_demo = widgets.HBox([vbox_gf_mul_demo, vbox_gf_mulinv_demo])
vbox_gf = widgets.VBox([vbox_gf_desc, hbox_gf_demo])

In [10]:
# Generation of S-box
def SBoxElements(i):
    """Return S-box entry ``i`` as a two-digit lowercase hex string.

    The entry is the GF(2^8) multiplicative inverse of ``i`` (modulo 0x11b)
    passed through the AES affine transformation with constant 0x63.
    ``i`` must be non-zero; the caller special-cases entry 0.
    """
    inverse = gf_mul_inv(0x11b, i)
    # Bits of the inverse, most significant first, left-padded to 8 positions
    b = [int(bit) for bit in format(inverse, '08b')]
    const = [int(bit) for bit in "01100011"]  # {63}
    # Affine transform in MSB-first indexing: output bit k is the XOR of input
    # bit k, the next four bits (cyclically), and the constant bit k.
    out_bits = []
    for k in range(8):
        acc = const[k]
        for step in range(5):
            acc ^= b[(k + step) % 8]
        out_bits.append(str(acc))
    return format(int("".join(out_bits), 2), '02x')

# S-box entries as 2-digit hex strings, indexed by input byte 0..255.
s_box_array = []
for i in range(0, 256):
    # 0 has no multiplicative inverse; its S-box value is defined as 0x63
    s_box_array.append('63' if i == 0 else SBoxElements(i))

# 16 x 16 table: row label = high hex digit, column label = low hex digit.
# A plain list is passed as the index; the old ``index=[list(...)]`` (a list
# wrapping a list) made pandas build a one-level MultiIndex instead.
chunks = np.array_split(s_box_array, 16)
s_box = pd.DataFrame(np.array(chunks),
                     columns = list('0123456789abcdef'),
                     index = list('0123456789abcdef'))

In [11]:
# Generation of Inverse S-box
# The S-box is a permutation of 0..255, so its inverse just swaps input and
# output: inv_s_box_array[S[j]] = j. A single pass replaces a 256 x 256
# nested search.
inv_s_box_array = [None] * len(s_box_array)
for j, sbox_hex in enumerate(s_box_array):
    inv_s_box_array[int(sbox_hex, 16)] = format(j, '02x')
assert None not in inv_s_box_array, "S-box is not a permutation; cannot invert it"

# 16 x 16 table with plain (single-level) row and column labels
chunks = np.array_split(inv_s_box_array, 16)
inv_s_box = pd.DataFrame(np.array(chunks),
                         columns = list('0123456789abcdef'),
                         index = list('0123456789abcdef'))

In [12]:
# SubBytes
def SubBytes(s):
    """Substitute every byte of hex string ``s`` through the AES S-box.

    Each 2-character hex pair is replaced by ``s_box_array[byte]`` and a new
    lowercase hex string is returned. The output is built with one join; the
    old version rebuilt the whole string once per byte, which is quadratic
    for long multi-block inputs.
    """
    return "".join(s_box_array[int(s[pos:pos + 2], 16)] for pos in range(0, len(s), 2))

# InvSubBytes
def InvSubBytes(s):
    """Substitute every byte of hex string ``s`` through the inverse AES S-box.

    Each 2-character hex pair is replaced by ``inv_s_box_array[byte]``. The
    output is built with one join rather than rebuilding the string once per
    byte.
    """
    return "".join(inv_s_box_array[int(s[pos:pos + 2], 16)] for pos in range(0, len(s), 2))

# ShiftRows
def ShiftRows(s):
    """Apply AES ShiftRows to each 16-byte block of hex string ``s``.

    The state is column-major (byte k sits in row k % 4). Row r is rotated
    left by r positions, so output byte k comes from input byte
    ``(k + 4 * (k % 4)) % 16`` of the same block.
    """
    s_byte = HexToByte(s)
    source = [(k + 4 * (k % 4)) % 16 for k in range(16)]
    shifted = []
    for block in range(0, len(s_byte), 16):
        shifted += [s_byte[block + src] for src in source]
    return "".join(shifted)

# InvShiftRows
def InvShiftRows(s):
    """Undo AES ShiftRows on each 16-byte block of hex string ``s``.

    Row r of the column-major state is rotated right by r positions, so output
    byte k comes from input byte ``(k - 4 * (k % 4)) % 16`` of the same block.
    """
    s_byte = HexToByte(s)
    source = [(k - 4 * (k % 4)) % 16 for k in range(16)]
    shifted = []
    for block in range(0, len(s_byte), 16):
        shifted += [s_byte[block + src] for src in source]
    return "".join(shifted)

# MixColumns
def MixColumns(s):
    """Apply AES MixColumns to hex string ``s``.

    Each group of 4 bytes (one state column) is multiplied in GF(2^8) by the
    fixed matrix [[2,3,1,1],[1,2,3,1],[1,1,2,3],[3,1,1,2]]. Returns lowercase
    hex with 2 digits per byte.
    """
    matrix = ((2, 3, 1, 1),
              (1, 2, 3, 1),
              (1, 1, 2, 3),
              (3, 1, 1, 2))
    values = [int(pair, 16) for pair in HexToByte(s)]
    mixed = []
    for col_start in range(0, len(values), 4):
        column = [values[col_start + offset] for offset in range(4)]
        for coeffs in matrix:
            acc = 0
            for coeff, byte in zip(coeffs, column):
                acc ^= gf_mul(byte, coeff)
            mixed.append(format(acc, '02x'))
    return "".join(mixed)

# InvMixColumns
def InvMixColumns(s):
    """Apply AES InvMixColumns to hex string ``s``.

    Each group of 4 bytes (one state column) is multiplied in GF(2^8) by
    [[14,11,13,9],[9,14,11,13],[13,9,14,11],[11,13,9,14]]. Returns lowercase
    hex with 2 digits per byte.
    """
    matrix = ((14, 11, 13, 9),
              (9, 14, 11, 13),
              (13, 9, 14, 11),
              (11, 13, 9, 14))
    values = [int(pair, 16) for pair in HexToByte(s)]
    mixed = []
    for col_start in range(0, len(values), 4):
        column = [values[col_start + offset] for offset in range(4)]
        for coeffs in matrix:
            acc = 0
            for coeff, byte in zip(coeffs, column):
                acc ^= gf_mul(byte, coeff)
            mixed.append(format(acc, '02x'))
    return "".join(mixed)

# AddRoundKey
def AddRoundKey(s, round_keys, r):
    """XOR hex state ``s`` with round key ``round_keys[r]``.

    The round key (its parts joined into one hex string) is repeated once per
    key-length chunk of ``s``, so a multi-block state is XORed block by block.
    The result is zero-padded to ``len(s)``.
    """
    round_key = "".join(round_keys[r])
    repeats = len(s) // len(round_key)
    return xor(s, round_key * repeats, len(s))

In [13]:
# Widgets for S-box
heading_sbox = widgets.HTML(value="<h2>S-box & Inverse S-box Generation</h2>")
text_sbox1 = widgets.HTML(
    value="The SubBytes step uses a substitution box (S-box), which is a lookup table constructed from the following 2 operations on bytes 0x00 - 0xff:"
)
text_sbox2 = widgets.HTMLMath(
    value=r"<li>Finding the multiplicative inverse in GF($2^8$)</li><li>Applying the following affine transformation on the multiplicative inverse as shown below:</li>"
)
text_sbox3 = widgets.HTML(
    value="<br><img src='https://wikimedia.org/api/rest_v1/media/math/render/svg/78c56ffe89890582a7060845e131a788266cbd59' alt='Affine Transformation'/>"
)
text_sbox4 = widgets.HTMLMath(value=r"where $s$ is the S-box output and $b$ is the multiplicative inverse<br>")
text_sbox5 = widgets.HTML(value="Affine Transformation<br>")
text_sbox6 = widgets.HTML(
    value="<br>The S-box is designed to be invertible, thus the inverse S-box can be generated from the S-box by thinking of the S-box as a Python dictionary and swapping all the keys and values."
)

# Widgets for SubBytes Demo (heading_demo / text_demo are shared by every step tab)

heading_sb = widgets.HTML(value="<h1>SubBytes</h1>")
text_sb = widgets.HTML(
    value="The SubBytes step substitutes each byte of the state array with a new one using the S-box to obtain a new state:"
)
heading_demo = widgets.HTML(value="<h2>Demo:</h2>")
text_demo = widgets.HTML(
    value="Please enter the input for the demo. You can also change the input to view the output change in real time:"
)

img3 = widgets.HTML(
    value="<img src='https://upload.wikimedia.org/wikipedia/commons/thumb/a/a4/AES-SubBytes.svg/400px-AES-SubBytes.svg.png'/>"
)
label3 = widgets.Label(value='AES SubBytes Step', style={'description_width': 'initial'})
vbox_img3 = widgets.VBox([img3, label3])

def sb(Plaintext):
    """interactive() callback: run SubBytesDemo on the text forced to 16 characters.

    Returns the unmodified input text.
    """
    SubBytesDemo(pad_or_truncate(Plaintext))
    return Plaintext
demo_sb = interactive(sb, Plaintext='ExamplePlaintext')

vbox_sb = widgets.VBox([heading_sb, text_sb, vbox_img3,
                        heading_demo, text_demo, demo_sb])

# Widgets for ShiftRows Demo

heading_sr = widgets.HTML(value="<h1>ShiftRows</h1>")
text_sr = widgets.HTML(
    value="The ShiftRows step cyclically shifts each row of the state array to the left by a certain number of positions; the 1st row is not shifted, and the 2nd, 3rd and 4th rows are shifted by 1, 2 and 3 to the left respectively:"
)

img4 = widgets.HTML(
    value="<img src='https://upload.wikimedia.org/wikipedia/commons/thumb/6/66/AES-ShiftRows.svg/400px-AES-ShiftRows.svg.png'/>"
)
label4 = widgets.Label(value='AES ShiftRows Step', style={'description_width': 'initial'})
vbox_img4 = widgets.VBox([img4, label4])

def sr(Plaintext):
    """interactive() callback: run ShiftRowsDemo on the text forced to 16 characters.

    Returns the unmodified input text.
    """
    ShiftRowsDemo(pad_or_truncate(Plaintext))
    return Plaintext
demo_sr = interactive(sr, Plaintext='ExamplePlaintext')

vbox_sr = widgets.VBox([heading_sr, text_sr, vbox_img4,
                        heading_demo, text_demo, demo_sr])

# Widgets for MixColumns Demo

heading_mc = widgets.HTML(value="<h1>MixColumns</h1>")
text_mc = widgets.HTML(
    value="The MixColumns step performs a matrix multiplication in GF(2^8) between the state array and a predefined fixed matrix, multiplying each column of the state array with the fixed matrix to obtain the column for the output:"
)

img5 = widgets.HTML(
    value="<img src='https://wikimedia.org/api/rest_v1/media/math/render/svg/b35516e14dcf7ed323058752cfbe832f2db5f305'/>"
)
label5 = widgets.Label(value='AES MixColumns Step', style={'description_width': 'initial'})
vbox_img5 = widgets.VBox([img5, label5])

def mc(Plaintext):
    """interactive() callback: run MixColumnsDemo on the text forced to 16 characters.

    Returns the unmodified input text.
    """
    MixColumnsDemo(pad_or_truncate(Plaintext))
    return Plaintext
demo_mc = interactive(mc, Plaintext='ExamplePlaintext')

vbox_mc = widgets.VBox([heading_mc, text_mc, vbox_img5,
                        heading_demo, text_demo, demo_mc])

# Widgets for AddRoundKey Demo

heading_ark = widgets.HTML(value="<h1>AddRoundKey</h1>")
text_ark = widgets.HTML(
    value="A round key for each processing round is generated from the key expansion process. The AddRoundKey step is just a simple XOR of the state array and the round key, since addition is equivalent to XOR in GF arithmetic:"
)

img6 = widgets.HTML(
    value="<img src='https://upload.wikimedia.org/wikipedia/commons/thumb/a/ad/AES-AddRoundKey.svg/400px-AES-AddRoundKey.svg.png'/>"
)
label6 = widgets.Label(value='AES AddRoundKey Step', style={'description_width': 'initial'})
vbox_img6 = widgets.VBox([img6, label6])

def ark(Plaintext, Key):
    """interactive() callback: run AddRoundKeyDemo with plaintext and key each forced to 16 characters."""
    AddRoundKeyDemo(pad_or_truncate(Plaintext), pad_or_truncate(Key))
demo_ark = interactive(ark, Plaintext='ExamplePlaintext', Key='Example_RoundKey')

vbox_ark = widgets.VBox([heading_ark, text_ark, vbox_img6,
                         heading_demo, text_demo, demo_ark])

# Tab Widget for Processing Steps: one tab per step demo

tab_titles = ['SubBytes', 'ShiftRows', 'MixColumns', 'AddRoundKey']
tab_contents = [vbox_sb, vbox_sr, vbox_mc, vbox_ark]
tab = widgets.Tab(
    children=[widgets.VBox([content]) for content in tab_contents],
    layout=Layout(width='65%'),
)
for i, title in enumerate(tab_titles):
    tab.set_title(i, title)

heading_steps = widgets.HTML(value="<h1>Processing Steps</h1>")
text_steps = widgets.HTML(value="<li>SubBytes - Substitution of each byte of the state array using a substitution box (S-box)</li><li>ShiftRows - Shifting of each row of the state array to the left by a certain number of positions</li><li>MixColumns - Matrix multiplication between the state array and a predefined matrix</li><li>AddRoundKey - XOR of the state array and the round key</li>")

# Overview: summary of the steps followed by the S-box generation explanation
vbox_steps = widgets.VBox([
    heading_steps, text_steps,
    heading_sbox, text_sbox1, text_sbox2, text_sbox3,
    text_sbox4, text_sbox5, text_sbox6,
])

In [14]:
def gfunc(w, rc):
    """Key-schedule g function applied to one 32-bit word given as 8 hex characters.

    Steps: rotate the 4 bytes left by one (RotWord), substitute each byte
    through the S-box (SubWord), then XOR the most significant byte with the
    round constant ``rc``. Returns the new word as 8 lowercase hex characters.
    """
    w_byte = HexToByte(w)

    # Left shift (RotWord): [b0, b1, b2, b3] -> [b1, b2, b3, b0]
    w_byte_shifted = [w_byte[1], w_byte[2], w_byte[3], w_byte[0]]
    w_shifted = "".join(w_byte_shifted)

    # SubBytes (SubWord)
    w_subbed = SubBytes(w_shifted)

    # XOR most significant byte with round constant. The byte must stay two
    # hex digits: an unpadded result < 0x10 gave a 7-character word and
    # corrupted every later word of the key schedule.
    msb = int(w_subbed[0:2], 16) ^ rc
    w_new = format(msb, '02x') + w_subbed[2:]
    return w_new

def KeyExpansion(key_hex):
    """Expand a hex-encoded cipher key into the AES key schedule, printing each step.

    The key is split into 4-byte words (``HexToWord``). Each expansion round
    derives a new group of as many words as the key has, from the previous group:
      * first word  = g(last word of previous group) XOR first word of previous group
      * for keys of more than 6 words (AES-256), the 5th word
                    = SubBytes(previous word) XOR corresponding word of previous group
      * other words = previous word XOR corresponding word of previous group
    Every new group is printed as it is produced.

    Parameters
    ----------
    key_hex : str
        Cipher key as a hex string.

    Returns
    -------
    str
        The expanded key as a hex string, trimmed to ``c * n + 4`` words, where
        ``n`` is the number of key words and ``c = 4 * (n + 6) // n`` is the
        number of expansion rounds (44 / 52 / 60 words for 4 / 6 / 8-word keys).
    """
    print("Key Expansion")
    print("-------------")
    key_words = HexToWord(key_hex)
    n_words = len(key_words)
    print("Word 1 to {}: {}".format(n_words, key_words))

    # Number of expansion rounds: 10, 8 and 7 for 4-, 6- and 8-word keys.
    n_rounds = 4 * (n_words + 6) // n_words

    # Round constants: rc_1 = 1 and rc_i = 2 * rc_{i-1} in GF(2^8); a doubling
    # that overflows a byte is reduced with the AES polynomial 0x11b.
    round_constants = []
    for _ in range(n_rounds):
        if round_constants:
            doubled = 2 * round_constants[-1]
            round_constants.append(doubled ^ 0x11B if doubled >= 256 else doubled)
        else:
            round_constants.append(1)

    # The container returned by HexToWord for the original key is reused,
    # slot by slot, to hold each newly generated group.
    group = key_words
    for round_index, rc in enumerate(round_constants):
        prev_words = HexToWord(key_hex)
        g_word = gfunc(prev_words[-1], rc)
        for j in range(n_words):
            corresponding = prev_words[j - n_words]
            if j == 0:
                group[j] = xor(g_word, corresponding, 8)
            elif n_words > 6 and j == 4:
                # Extra SubBytes on the 5th word of each group (AES-256 only).
                group[j] = xor(SubBytes(group[j - 1]), corresponding, 8)
            else:
                group[j] = xor(group[j - 1], corresponding, 8)

        first = n_words * (round_index + 1) + 1
        print("Word {} to {}: {}".format(first, first + n_words - 1, group))
        key_hex += "".join(group)

    print("")
    print("Final expanded key:")
    # AES-192/256 produce a few surplus words in the last round; drop them.
    required = n_rounds * n_words + 4
    surplus = len(HexToWord(key_hex)) - required
    if surplus == 0:
        print(key_hex)
        return key_hex
    trimmed = key_hex[:-8 * surplus]
    print(trimmed)
    return trimmed

In [15]:
# Widgets for Key Expansion

heading_ke = widgets.HTML(value="<h1>Key Expansion</h1>")
text_ke1 = widgets.HTML(value="The Key Expansion process expands the original key into a key schedule, generating different round keys to be used in the AddRoundKey step for each processing round. The longer the key length, the more the number of round keys required to be generated. The key schedule consists of the initial round key before the processing rounds as well as the round keys for all the rounds, in the form of 4-byte words.<br><br>The table below shows the number of round keys required and the number of words in the expanded key schedule for each type of AES:<br>")
text_ke2 = widgets.HTML(value="To expand a key into a key schedule, the initial key is used to generate the next few words, and those few words are then used to generate the next few; this will continue until all the round keys have been generated to form the finished key schedule. So how are the next few words generated from the previous group of words?<br>")
text_ke3 = widgets.HTML(value="<h2>g() Function</h2>")
text_ke4 = widgets.HTML(value="The first word of the new group is obtained by XORing the first word of the previous group with the result of the function g() applied to the last word of the previous group. The g() does the following:<ol><li>Left shift of the bytes</li><li>Substitution of each byte using the S-box</li><li>XORing the most significant byte with a round constant</li></ol>")
text_ke5 = widgets.HTML(value="<h2>Round Constant</h2>")
text_ke6 = widgets.HTMLMath(value=r"The round constant is determined as follows:<br><br>$ rc_1 = 1 $<br>$ rc_i = 2 \times rc_{i-1} $ (Multiplication in GF($2^8$))<br>")
text_ke7 = widgets.HTMLMath(value=r"Each AES type requires a different number of round constants for the key expansion:<ul><li>AES-128: up to $rc_{10}$</li><li>AES-192: up to $rc_8$</li><li>AES-256: up to $rc_7$</li></ul>")
text_ke8 = widgets.HTML(value="<h2>Generating New Words</h2>")
text_ke9 = widgets.HTMLMath(value=r"The remaining words of the new group are obtained by XORing the previously generated word with the corresponding word in the previous group.<br><br>In AES-128 for example, to generate $w_{i+4}, w_{i+5}, w_{i+6}, w_{i+7}$ from $w_i, w_{i+1}, w_{i+2}, w_{i+3}$:<br>")
text_ke10 = widgets.HTMLMath(value=r"$ w_{i+4} = w_i \oplus g(w_{i+3}) $<br>$ w_{i+5} = w_{i+4} \oplus w_{i+1} $<br>$ w_{i+6} = w_{i+5} \oplus w_{i+2} $<br>$ w_{i+7} = w_{i+6} \oplus w_{i+3} $<br><br>")
text_ke11 = widgets.HTML(value="For AES-256, there is an additional operation when generating the 5th word of each new group; an extra SubBytes function is performed on the previously generated word before XORing with the corresponding word in the previous group:<br>")
text_ke12 = widgets.HTMLMath(value=r"$ w_{i+8} = w_i \oplus g(w_{i+7}) $<br>$ w_{i+9} = w_{i+8} \oplus w_{i+1} $<br>$ w_{i+10} = w_{i+9} \oplus w_{i+2} $<br>$ w_{i+11} = w_{i+10} \oplus w_{i+3} $<br><br>$ w_{i+12} = SubBytes(w_{i+11}) \oplus w_{i+4} $<br>$ w_{i+13} = w_{i+12} \oplus w_{i+5} $<br>$ w_{i+14} = w_{i+13} \oplus w_{i+6} $<br>$ w_{i+15} = w_{i+14} \oplus w_{i+7} $<br>")
text_ke13 = widgets.HTML(value="<br>For AES-192 and AES-256, as the total number of words after the key expansion process exceeds the number of words required in the key schedule, the extra generated words are removed (2 for AES-192 and 4 for AES-256):<br>")


def _table_output(rows, columns, index):
    """Return an Output widget that displays ``rows`` as a pandas DataFrame.

    Each reference table is built through this helper, so its row data stays
    local instead of reassigning one shared module-level ``data`` variable
    for every table.
    """
    table_output = widgets.Output()
    with table_output:
        display(pd.DataFrame(rows, columns=columns, index=index))
    return table_output


# Key size, rounds and key-schedule size per AES type.
tb_ke1 = _table_output(
    [['128-bit', '192-bit', '256-bit'],
     ['4 words', '6 words', '8 words'],
     [10, 12, 14],
     [11, 13, 15],
     ['11 x 4 = 44 words', '13 x 4 = 52 words', '15 x 4 = 60 words']],
    columns=['AES-128', 'AES-192', 'AES-256'],
    index=['Key Size/Length',
           '',
           'No. of Processing Rounds',
           'No. of Round Keys',
           'Key Schedule'],
)

# Round constants rc_1 .. rc_10 (headerless columns).
tb_ke2 = _table_output(
    [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
     ['0x01', '0x02', '0x04', '0x08', '0x10', '0x20', '0x40', '0x80', '0x1B', '0x36']],
    columns=[''] * 10,
    index=['i', 'rc[i] (in hex)'],
)

# Why AES-192/256 generate surplus words that must be trimmed.
tb_ke3 = _table_output(
    [['4 words', '6 words', '8 words'],
     [44, 52, 60],
     [10, 8, 7],
     ['4 + (10 x 4) = 44', '6 + (8 x 6) = 54', '8 + (7 x 8) = 64'],
     ['44 - 44 = 0', '54 - 52 = 2', '64 - 60 = 4']],
    columns=['AES-128', 'AES-192', 'AES-256'],
    index=['Key Length',
           'Req. No. of Words in Key Schedule',
           'No. of Key Expansion Rounds',
           'No. of Words after Key Expansion',
           'Extra Words to be Removed'],
)

vbox_ke = widgets.VBox([heading_ke,
                        text_ke1,
                        tb_ke1,
                        text_ke2,
                        text_ke3,
                        text_ke4,
                        text_ke5,
                        text_ke6,
                        tb_ke2,
                        text_ke7,
                        text_ke8,
                        text_ke9,
                        text_ke10,
                        text_ke11,
                        text_ke12,
                        text_ke13,
                        tb_ke3])

In [16]:
# AES Encryption
def Encryption(s_hex, expanded_key_hex, rounds, demo):
    """Encrypt one state block with AES.

    Round 0 applies AddRoundKey with the first round key. Rounds 1 .. rounds-1
    apply SubBytes, ShiftRows, MixColumns and AddRoundKey. The final round
    omits MixColumns.

    Parameters
    ----------
    s_hex : str
        State block as a hex string.
    expanded_key_hex : str
        Key schedule produced by ``KeyExpansion``; it is split into
        ``rounds + 1`` round keys with ``np.array_split``.
    rounds : int
        Number of processing rounds.
    demo : bool
        When ``demo == True``, print the state after every round.

    Returns
    -------
    The state after the final round (the ciphertext block).
    """
    round_keys = np.array_split(HexToWord(expanded_key_hex), rounds + 1)
    state = s_hex
    for rnd in range(rounds + 1):
        if rnd == 0:
            # Initial round: AddRoundKey only.
            state = AddRoundKey(state, round_keys, 0)
        if 0 < rnd < rounds:
            state = MixColumns(ShiftRows(SubBytes(state)))
            state = AddRoundKey(state, round_keys, rnd)
        if rnd == rounds:
            # Final round: no MixColumns.
            state = AddRoundKey(ShiftRows(SubBytes(state)), round_keys, rounds)
            encrypted = state
        if demo == True:
            print("Round", rnd, "Output:", state)
    return encrypted

# AES Decryption
def Decryption(s_hex, expanded_key_hex, rounds, demo):
    """Decrypt one state block with the AES inverse cipher.

    Round 0 applies AddRoundKey with the last round key. Rounds 1 .. rounds-1
    apply InvShiftRows, InvSubBytes, AddRoundKey (round key ``rounds - i``)
    and InvMixColumns. The final round omits InvMixColumns and uses the first
    round key.

    Parameters
    ----------
    s_hex : str
        Ciphertext block as a hex string.
    expanded_key_hex : str
        Key schedule produced by ``KeyExpansion``; it is split into
        ``rounds + 1`` round keys with ``np.array_split``.
    rounds : int
        Number of processing rounds.
    demo : bool
        When ``demo == True``, print the state after every round.

    Returns
    -------
    The state after the final round (the decrypted block, still padded).
    """
    round_keys = np.array_split(HexToWord(expanded_key_hex), rounds + 1)
    state = s_hex
    for rnd in range(rounds + 1):
        if rnd == 0:
            # Initial round: AddRoundKey with the last round key.
            state = AddRoundKey(state, round_keys, rounds)
        if 0 < rnd < rounds:
            state = InvSubBytes(InvShiftRows(state))
            state = InvMixColumns(AddRoundKey(state, round_keys, rounds - rnd))
        if rnd == rounds:
            # Final round: no InvMixColumns.
            state = AddRoundKey(InvSubBytes(InvShiftRows(state)), round_keys, 0)
            decrypted = state
        if demo == True:
            print("Round", rnd, "Output:", state)
    return decrypted

In [17]:
# Widgets for Encryption/Decryption Demo

# --- Explanation of the encryption round structure ---------------------------
heading_e = widgets.HTML(value="<h1>AES Encryption Process</h1>")
text_e1 = widgets.HTML(
    value="For the initial round (Round 0), AddRoundKey is applied to the state array and the first 4 words of the key schedule (initial key). From round 1 to the second last round, the 4 processing steps are then performed as follows:"
)
text_e2 = widgets.HTML(value="<ol><li>SubBytes</li><li>ShiftRows</li><li>MixColumns</li><li>AddRoundKey</li></ol>")
text_e3 = widgets.HTML(value="For the final round, there is no MixColumns step:")
text_e4 = widgets.HTML(value="<ol><li>SubBytes</li><li>ShiftRows</li><li>AddRoundKey</li></ol>")
vbox_enc = widgets.VBox([heading_e, text_e1, text_e2, text_e3, text_e4])

# --- Explanation of the decryption round structure ---------------------------
heading_d = widgets.HTML(value="<h1>AES Decryption Process</h1>")
text_d1 = widgets.HTML(
    value="For the initial round (Round 0), AddRoundKey is applied to the state array and the final 4 words of the key schedule. From round 1 to the second last round, the 4 processing steps are then performed as follows:"
)
text_d2 = widgets.HTML(value="<ol><li>InvShiftRows</li><li>InvSubBytes</li><li>AddRoundKey</li><li>InvMixColumns</li></ol>")
text_d3 = widgets.HTML(value="For the final round, there is no InvMixColumns step:")
text_d4 = widgets.HTML(value="<ol><li>InvShiftRows</li><li>InvSubBytes</li><li>AddRoundKey</li></ol>")
vbox_dec = widgets.VBox([heading_d, text_d1, text_d2, text_d3, text_d4])

# --- Demo inputs -------------------------------------------------------------
heading_demo_ed = widgets.HTML(value="<h2>Encryption/Decryption Demo:</h2>")
text_demo_ed = widgets.HTML(
    value="Please select the AES type and enter your inputs before clicking on the 'Encrypt and Decrypt' button to view the Encryption and Decryption process:"
)

select_key_length_ed = widgets.Dropdown(
    options=[('AES-128', 128), ('AES-192', 192), ('AES-256', 256)],
    value=128,
    description='AES Type:',
)
input_key_ed = widgets.Text(
    value='thisisakeyforaes',
    description='Key:',
    placeholder='Type something here',
)
input_plaintext_ed = widgets.Text(
    value='Hello, World!',
    description='Plaintext:',
    placeholder='Type something here',
)

# --- Layout, trigger button and output area ---------------------------------
hbox_enc_dec = widgets.HBox([vbox_enc, vbox_dec])
vbox_enc_dec = widgets.VBox([
    hbox_enc_dec,
    heading_demo_ed,
    text_demo_ed,
    select_key_length_ed,
    input_key_ed,
    input_plaintext_ed,
])

button_enc_dec = widgets.Button(
    description='Encrypt and Decrypt',
    tooltip='Send',
    style={'description_width': 'initial'},
)

# Receives everything printed by the Encryption/Decryption demo handler.
output = widgets.Output()

def EncryptionDemo(key_length, key, plaintext):
    """Print a walkthrough of expanding ``key`` and encrypting ``plaintext``.

    Parameters
    ----------
    key_length : int
        AES key size in bits (128, 192 or 256), used to look up the round count.
    key : str
        Cipher key text; converted to hex with ``StrToHex``.
    plaintext : str
        Text to encrypt; padded with ``AddPadding`` before encryption.

    Returns
    -------
    tuple
        ``(encrypted, expanded_key_hex, rounds)`` so the caller can run the
        matching decryption demo.
    """
    n_rounds = NumOfRounds(key_length)

    # Key and its expansion
    print("Key:", key)
    key_as_hex = StrToHex(key)
    print("Key (in hex):", key_as_hex)
    print()
    schedule_hex = KeyExpansion(key_as_hex)
    print()

    # Plaintext padding
    print("Plaintext:", plaintext)
    padded_hex = AddPadding(plaintext)
    print("Plaintext after padding (in hex):", padded_hex)
    print()

    # Round-by-round encryption
    print("Encryption", "----------", "", sep="\n")
    print("No. of round keys:", n_rounds + 1)
    print()
    ciphertext = Encryption(padded_hex, schedule_hex, n_rounds, True)
    print()
    print("Encrypted:", ciphertext)
    return ciphertext, schedule_hex, n_rounds

def DecryptionDemo(encrypted, expanded_key_hex, rounds):
    """Print the round-by-round decryption of ``encrypted`` and the recovered text.

    The output of ``Decryption`` is passed through ``RemovePadding`` and then
    converted back to text with ``HexToStr``.
    """
    print("Decryption", "----------", sep="\n")
    unpadded_hex = RemovePadding(Decryption(encrypted, expanded_key_hex, rounds, True))
    print()
    print("Padding removed (in hex):", unpadded_hex)
    print("Decrypted:", HexToStr(unpadded_hex))

def on_button_clicked_enc_dec(event):
    """Click handler for the Encryption/Decryption demo's 'Encrypt and Decrypt' button.

    Clears ``output`` and writes into it: either a validation message (empty
    plaintext, or a key whose character count differs from key_length // 8),
    or the full encryption walkthrough followed by the decryption walkthrough.
    """
    output.clear_output()
    with output:
        key_length = select_key_length_ed.value
        key = input_key_ed.value
        plaintext = input_plaintext_ed.value
        required_key_chars = key_length // 8

        # Reject invalid plaintext/key input lengths before doing any work.
        if not plaintext:
            print("Please input a plaintext string for the demo")
            return
        if len(key) != required_key_chars:
            print("Key must have {} characters for AES-{}".format(required_key_chars, key_length))
            return

        encrypted, expanded_key_hex, rounds = EncryptionDemo(key_length, key, plaintext)
        print()
        DecryptionDemo(encrypted, expanded_key_hex, rounds)


button_enc_dec.on_click(on_button_clicked_enc_dec)

vbox_result_enc_dec = widgets.VBox([vbox_enc_dec, button_enc_dec, output])

In [18]:
# ECB MODE
def AESEncryption_ECB(p_hex, expanded_key_hex, rounds):
    """Encrypt a padded hex plaintext in ECB mode.

    The words of ``p_hex`` are split into ``len(p_hex) // 32`` groups (32 hex
    digits per block). Each group is encrypted independently with
    ``Encryption``, which prints its per-round output.

    Returns the concatenated ciphertext blocks as one hex string.
    """
    n_blocks = len(p_hex) // 32
    word_groups = np.array_split(HexToWord(p_hex), n_blocks)
    return "".join(
        Encryption("".join(words), expanded_key_hex, rounds, True)
        for words in word_groups
    )

def AESDecryption_ECB(p_hex, expanded_key_hex, rounds):
    """Decrypt an ECB-mode hex ciphertext.

    The words of ``p_hex`` are split into ``len(p_hex) // 32`` groups. Each
    group is decrypted independently with ``Decryption``, which prints its
    per-round output.

    Returns the concatenated (still padded) plaintext blocks as one hex string.
    """
    n_blocks = len(p_hex) // 32
    word_groups = np.array_split(HexToWord(p_hex), n_blocks)
    return "".join(
        Decryption("".join(words), expanded_key_hex, rounds, True)
        for words in word_groups
    )

# CBC MODE
def AESEncryption_CBC(p_hex, expanded_key_hex, rounds, iv):
    """Encrypt a padded hex plaintext in CBC mode.

    Each plaintext block is XORed (via ``xor`` with width 32) with the previous
    ciphertext block before encryption; the first block is XORed with ``iv``.

    Returns the concatenated ciphertext blocks as one hex string.
    """
    n_blocks = len(p_hex) // 32
    ciphertext_blocks = []
    chain = iv  # C[-1] = IV
    for words in np.array_split(HexToWord(p_hex), n_blocks):
        # C[i] = Encryption(P[i] ^ C[i-1])
        chain = Encryption(xor("".join(words), chain, 32), expanded_key_hex, rounds, True)
        ciphertext_blocks.append(chain)
    return "".join(ciphertext_blocks)

def AESDecryption_CBC(p_hex, expanded_key_hex, rounds, iv):
    """Decrypt a CBC-mode hex ciphertext.

    Each ciphertext block is decrypted and then XORed (via ``xor`` with width 32)
    with the previous ciphertext block; the first block is XORed with ``iv``.

    Returns the concatenated (still padded) plaintext blocks as one hex string.
    """
    n_blocks = len(p_hex) // 32
    plaintext_blocks = []
    chain = iv  # C[-1] = IV
    for words in np.array_split(HexToWord(p_hex), n_blocks):
        ciphertext_block = "".join(words)
        # P[i] = Decryption(C[i]) ^ C[i-1]
        decrypted_block = Decryption(ciphertext_block, expanded_key_hex, rounds, True)
        plaintext_blocks.append(xor(decrypted_block, chain, 32))
        chain = ciphertext_block
    return "".join(plaintext_blocks)

In [19]:
# Widgets for Cipher Modes Demo

# --- ECB explanation ---------------------------------------------------------
heading_ecb = widgets.HTML(value="<h1>Electronic Codebook (ECB) Mode</h1>")
text_ecb1 = widgets.HTML(
    value="For ECB mode, the plaintext is divided into blocks to be encrypted individually to obtain the ciphertext blocks. Similarly for decryption, the ciphertext is split into blocks before being decrypted to obtain the original plaintext."
)
text_ecb2 = widgets.HTML(
    value="<img src='https://upload.wikimedia.org/wikipedia/commons/thumb/d/d6/ECB_encryption.svg/400px-ECB_encryption.svg.png' alt='ECB Encryption'/><br>ECB Encryption"
)
text_ecb3 = widgets.HTML(
    value="<img src='https://upload.wikimedia.org/wikipedia/commons/thumb/e/e6/ECB_decryption.svg/400px-ECB_decryption.svg.png' alt='ECB Decryption'/><br>ECB Decryption"
)
text_ecb4 = widgets.HTML(value="")
vbox_ecb = widgets.VBox([heading_ecb, text_ecb1, text_ecb2, text_ecb3, text_ecb4])

# --- CBC explanation ---------------------------------------------------------
heading_cbc = widgets.HTML(value="<h1>Cipher Block Chaining (CBC) Mode</h1>")
text_cbc1 = widgets.HTML(
    value="For CBC mode, the plaintext is divided into blocks and each block is XORed with the previous ciphertext block before encryption; an initialization vector (IV) is used to XOR with the first plaintext block. For decryption, each ciphertext block is decrypted before XORing with the previous ciphertext block, with the first being XORed with the IV, to obtain the plaintext blocks."
)
text_cbc2 = widgets.HTML(
    value="<img src='https://upload.wikimedia.org/wikipedia/commons/thumb/8/80/CBC_encryption.svg/400px-CBC_encryption.svg.png' alt='CBC Encryption'/><br>CBC Encryption"
)
text_cbc3 = widgets.HTML(
    value="<img src='https://upload.wikimedia.org/wikipedia/commons/thumb/2/2a/CBC_decryption.svg/400px-CBC_decryption.svg.png' alt='CBC Decryption'/><br>CBC Decryption"
)
text_cbc4 = widgets.HTML(value="")
vbox_cbc = widgets.VBox([heading_cbc, text_cbc1, text_cbc2, text_cbc3, text_cbc4])

# --- Demo inputs -------------------------------------------------------------
heading_demo_cm = widgets.HTML(value="<h2>ECB & CBC Mode Demo:</h2>")
text_demo_cm = widgets.HTML(
    value="Please select the AES type and enter your inputs before clicking on the 'Encrypt and Decrypt' button to view the Encryption and Decryption process for the 2 modes:"
)

select_key_length_cm = widgets.Dropdown(
    options=[('AES-128', 128), ('AES-192', 192), ('AES-256', 256)],
    value=128,
    description='AES Type:',
)
input_key_cm = widgets.Text(
    value='thisisakeyforaes',
    description='Key:',
    placeholder='Type something here',
)
input_plaintext_cm = widgets.Text(
    value='Hello, World!',
    description='Plaintext:',
    placeholder='Type something here',
)
input_iv_cm = widgets.Text(
    value='0000000000000000',
    description='IV (for CBC only):',
    placeholder='Type something here',
    style={'description_width': 'initial'},
)

# --- Layout, trigger button and output area ---------------------------------
hbox_cm = widgets.HBox([vbox_ecb, vbox_cbc])
vbox_cm = widgets.VBox([
    hbox_cm,
    heading_demo_cm,
    text_demo_cm,
    select_key_length_cm,
    input_key_cm,
    input_plaintext_cm,
    input_iv_cm,
])

button_cm = widgets.Button(
    description='Encrypt and Decrypt',
    style={'description_width': 'initial'},
)

# Receives everything printed by the Cipher Modes demo handler.
output_cm = widgets.Output()

def EncryptionDemo_CM(key_length, key, plaintext):
    """Print the setup shared by the ECB and CBC demos: key expansion and padding.

    Returns
    -------
    tuple
        ``(s_hex, expanded_key_hex, rounds)``: the padded plaintext in hex,
        the key schedule and the number of rounds for ``key_length``.
    """
    n_rounds = NumOfRounds(key_length)

    # Key and its expansion
    print("Key:", key)
    key_as_hex = StrToHex(key)
    print("Key (in hex):", key_as_hex)
    print()
    schedule_hex = KeyExpansion(key_as_hex)
    print()

    # Plaintext padding
    print("Plaintext:", plaintext)
    padded_hex = AddPadding(plaintext)
    print("Plaintext after padding (in hex):", padded_hex)
    return padded_hex, schedule_hex, n_rounds

def Demo_ECB(s_hex, expanded_key_hex, rounds):
    """Print an ECB-mode encryption/decryption round trip of the padded hex plaintext ``s_hex``."""
    print("ECB Mode Encryption", "-------------------", "", sep="\n")
    ciphertext = AESEncryption_ECB(s_hex, expanded_key_hex, rounds)
    print()
    print("Encrypted:", ciphertext)
    print()

    print("ECB Mode Decryption", "-------------------", sep="\n")
    unpadded_hex = RemovePadding(AESDecryption_ECB(ciphertext, expanded_key_hex, rounds))
    print()
    print("Padding removed (in hex):", unpadded_hex)
    print("Decrypted:", HexToStr(unpadded_hex))

def Demo_CBC(s_hex, expanded_key_hex, rounds, iv):
    """Print a CBC-mode encryption/decryption round trip of ``s_hex`` using ``iv`` (hex)."""
    print("CBC Mode Encryption", "-------------------", "", sep="\n")
    print("IV:", iv)
    print()
    ciphertext = AESEncryption_CBC(s_hex, expanded_key_hex, rounds, iv)
    print()
    print("Encrypted:", ciphertext)
    print()

    print("CBC Mode Decryption", "-------------------", sep="\n")
    unpadded_hex = RemovePadding(AESDecryption_CBC(ciphertext, expanded_key_hex, rounds, iv))
    print()
    print("Padding removed (in hex):", unpadded_hex)
    print("Decrypted:", HexToStr(unpadded_hex))

def on_button_clicked_cm(event):
    """Click handler for the Cipher Modes demo's 'Encrypt and Decrypt' button.

    Clears ``output_cm`` and writes into it either a validation message
    (empty plaintext, key length not key_length // 8 characters, or IV not
    16 characters) or the shared setup followed by the ECB and CBC
    walkthroughs. The IV text is converted with ``StrToHex`` for CBC.
    """
    output_cm.clear_output()
    with output_cm:
        key_length = select_key_length_cm.value
        key = input_key_cm.value
        plaintext = input_plaintext_cm.value
        iv = input_iv_cm.value
        required_key_chars = key_length // 8

        # Reject invalid plaintext/key/IV input lengths before doing any work.
        if not plaintext:
            print("Please input a plaintext string for the demo")
            return
        if len(key) != required_key_chars:
            print("Key must have {} characters for AES-{}".format(required_key_chars, key_length))
            return
        if len(iv) != 16:
            print("IV must have 16 characters")
            return

        s_hex, expanded_key_hex, rounds = EncryptionDemo_CM(key_length, key, plaintext)
        print()
        Demo_ECB(s_hex, expanded_key_hex, rounds)
        print()
        Demo_CBC(s_hex, expanded_key_hex, rounds, StrToHex(iv))


button_cm.on_click(on_button_clicked_cm)

vbox_result_cm = widgets.VBox([vbox_cm, button_cm, output_cm])

In [20]:
# Menu Buttons
#
# The "Cipher Modes" menu entry is named button_modes / on_button_clicked_modes.
# Reusing button_cm / on_button_clicked_cm here would rebind the names of the
# Cipher Modes demo's own "Encrypt and Decrypt" button and click handler from
# the cipher-modes cell, so those names would silently stop referring to the
# demo's objects once this cell has run.


def _menu_button(description):
    """Create a top-level navigation button with the shared menu styling."""
    return widgets.Button(
        description=description,
        style={'description_width': 'initial'})


button_intro = _menu_button('Introduction to AES')
button_gf = _menu_button('Galois Field')
button_steps = _menu_button('Processing Steps')
button_ke = _menu_button('Key Expansion')
button_ed = _menu_button('Encryption/Decryption')
button_modes = _menu_button('Cipher Modes')

# Shared area where the selected section is rendered.
main_output = widgets.Output()


def _show_section(*items):
    """Clear ``main_output`` and display ``items`` inside it."""
    main_output.clear_output()
    with main_output:
        display(*items)


def on_button_clicked_intro(event):
    """Show the Introduction section."""
    _show_section(vbox_intro)


def on_button_clicked_gf(event):
    """Show the Galois Field section."""
    _show_section(vbox_gf)


def on_button_clicked_steps(event):
    """Show the Processing Steps section followed by the S-box and inverse S-box tables."""
    # Mixes printed captions with displayed objects, so it is written out in full.
    main_output.clear_output()
    with main_output:
        print()
        display(widgets.HBox([vbox_steps, tab]))
        print("S-box:")
        display(s_box)
        print()
        print("Inverse S-box:")
        display(inv_s_box)


def on_button_clicked_ke(event):
    """Show the Key Expansion section."""
    _show_section(vbox_ke)


def on_button_clicked_ed(event):
    """Show the Encryption/Decryption demo."""
    _show_section(vbox_result_enc_dec)


def on_button_clicked_modes(event):
    """Show the Cipher Modes (ECB/CBC) demo."""
    _show_section(vbox_result_cm)


button_intro.on_click(on_button_clicked_intro)
button_gf.on_click(on_button_clicked_gf)
button_steps.on_click(on_button_clicked_steps)
button_ke.on_click(on_button_clicked_ke)
button_ed.on_click(on_button_clicked_ed)
button_modes.on_click(on_button_clicked_modes)

hbox_buttons = widgets.HBox([button_intro, button_gf, button_steps, button_ke, button_ed, button_modes])
vbox_main = widgets.VBox([hbox_buttons, main_output])
vbox_main

VBox(children=(HBox(children=(Button(description='Introduction to AES', style=ButtonStyle()), Button(descripti…

<br><br><br><br><br><br><br><br><br><br>

---
Images retrieved from:<br><br>
\[1\] Design and Analysis of Multimedia Communication System - Scientific Figure on ResearchGate. Available from: https://www.researchgate.net/figure/AES-Encryption-Decryption-Flowchart_fig2_221958203 \[accessed 22 Jul, 2021\]
<br>
\[2\] https://en.wikipedia.org/wiki/Rijndael_S-box
<br>
\[3\] https://en.wikipedia.org/wiki/Advanced_Encryption_Standard
<br>
\[4\] https://en.wikipedia.org/wiki/Block_cipher_mode_of_operation

In [2]:
#!pip freeze > requirements.txt