In [None]:
from test_aes import *

# AES Cipher

The following figure shows the AES cipher in more detail, indicating the sequence of transformations in each round and the corresponding decryption function. In this notebook, we implement the AES cipher, which performs both encryption and decryption.

<img src='aes_images/aes.png' width=50%>


Notice that:

- The first $N - 1$ rounds consist of four distinct transformation functions: SubBytes, ShiftRows, MixColumns, and AddRoundKey.
- The final round contains only three transformations: SubBytes, ShiftRows, and AddRoundKey (MixColumns is omitted).
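The round structure described above can be written down as data. The sketch below is illustrative only: `encryption_schedule` is a hypothetical helper, not part of the course API, and it lists transformation names rather than applying them. Round 0 is the lone initial AddRoundKey, rounds 1 to $N-1$ use all four transformations, and round $N$ drops MixColumns.

```python
def encryption_schedule(num_rounds):
    """Return the ordered transformation names for each AES encryption round.

    Hypothetical helper for illustration: index 0 is the initial AddRoundKey
    (Round 0), indices 1..N-1 are full rounds, and index N is the final round.
    """
    schedule = [["AddRoundKey"]]  # Round 0: initial key whitening only
    for _ in range(1, num_rounds):
        # Full rounds 1 .. N-1 apply all four transformations
        schedule.append(["SubBytes", "ShiftRows", "MixColumns", "AddRoundKey"])
    # Final round N omits MixColumns
    schedule.append(["SubBytes", "ShiftRows", "AddRoundKey"])
    return schedule
```

For example, `encryption_schedule(10)` has 11 entries (Round 0 plus ten cipher rounds), so AddRoundKey is applied 11 times and the cipher needs 11 round keys.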




## Import AES Round API

**Exercise:** As homework, we implemented the AES round procedure in the `AES_round_function` notebook. Reuse it as follows:
1. Back up your original `AES_round_function.ipynb`.
2. Remove `from test_AES import *`, all sanity checks, and the grade calculation cells from `AES_round_function.ipynb`, then save the notebook.
3. Download `AES_round_function.ipynb` as a Python script named `aes_round_api.py`.
4. Import `aes_round_enc`, `aes_round_dec`, and `add_round_key` from `aes_round_api`.
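Steps 2 and 3 can also be scripted. The sketch below assumes a hypothetical helper, `notebook_to_script`, whose `skip_markers` default is a guess at what the test and grading cells contain; adjust the markers to match your own sanity-check cells. It reads the `.ipynb` JSON with the standard library, keeps only code cells, and drops any cell containing one of the markers. Jupyter's `jupyter nbconvert --to script` performs the plain conversion but does not drop the test cells, and step 1 (the backup) is still up to you.

```python
import json


def notebook_to_script(nb_path, py_path,
                       skip_markers=("from test_", "evaluate(", "get_module_functions(")):
    """Write the code cells of a notebook to a .py file, skipping test cells.

    Hypothetical helper: a cell is skipped if its source contains any of
    skip_markers (an assumption about what the grading cells look like).
    """
    with open(nb_path, encoding="utf-8") as f:
        nb = json.load(f)

    chunks = []
    for cell in nb.get("cells", []):
        if cell.get("cell_type") != "code":
            continue  # markdown and raw cells are not exported
        source = cell.get("source", "")
        if isinstance(source, list):  # notebooks may store source as a list of lines
            source = "".join(source)
        if any(marker in source for marker in skip_markers):
            continue  # drop imports of the test module and grading cells
        chunks.append(source)

    with open(py_path, "w", encoding="utf-8") as f:
        f.write("\n\n".join(chunks) + "\n")
```

A call such as `notebook_to_script("AES_round_function.ipynb", "aes_round_api.py")` would produce the module imported in the next cell, provided the markers catch every test cell.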


In [2]:
# Import aes_round_enc, aes_round_dec and add_round_key from aes_round_api
from aes_round_api import aes_round_enc, aes_round_dec, add_round_key

## Import Key Expansion API

**Exercise:** As homework, we implemented the AES key expansion in the `aes_key_expansion` notebook. Reuse it as follows:
1. Back up your original `aes_key_expansion.ipynb`.
2. Remove `from test_aes_key_expansion import *`, all sanity checks, and the grade calculation cells from `aes_key_expansion.ipynb`, then save the notebook.
3. Download `aes_key_expansion.ipynb` as a Python script named `aes_key_api.py`.
4. Import `expand_key_128`, `expand_key_192`, and `expand_key_256` from `aes_key_api`.



In [3]:
# Import expand_key_128, expand_key_192 and expand_key_256 from aes_key_api
from aes_key_api import expand_key_128, expand_key_192, expand_key_256

## Implement AES Cipher

**Exercise:** Implement `aes_cipher(input_block, initial_key, mode)` to perform AES encryption and decryption for each supported key length, where:

- `input_block` is the plaintext when `mode` is `'E'` (encryption) and the ciphertext when `mode` is `'D'` (decryption).
- `initial_key` is the cipher key.

The cipher consists of $N$ rounds, where the number of rounds depends on the key length: 10 rounds for a 16-byte key, 12 rounds for a 24-byte key, and 14 rounds for a 32-byte key as listed below:

<img src='aes_images/aes_table.png' width=50%>
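The round counts in the table follow a simple pattern: with $N_k$ the key length in 32-bit words, $N = N_k + 6$. Because the cipher uses $N + 1$ round keys of 16 bytes each, the expanded key size follows directly. The sketch below uses two hypothetical helpers (`rounds_for_key`, `expanded_key_bytes`) that are not part of the course API; they only compute sizes.

```python
def rounds_for_key(key_len_bytes):
    """Number of rounds N for an AES key of the given length in bytes."""
    if key_len_bytes not in (16, 24, 32):
        raise ValueError("AES keys must be 16, 24, or 32 bytes")
    nk = key_len_bytes // 4  # key length in 32-bit words
    return nk + 6


def expanded_key_bytes(key_len_bytes):
    """Size of the key schedule: N + 1 round keys of 16 bytes each."""
    return 16 * (rounds_for_key(key_len_bytes) + 1)


for k in (16, 24, 32):
    print(k, rounds_for_key(k), expanded_key_bytes(k))
# 16 10 176
# 24 12 208
# 32 14 240
```

These sizes are consistent with slicing round key $i$ out of the expanded key as bytes `16*i` to `16*i + 16`, assuming the key expansion returns a flat byte sequence.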

To implement `aes_cipher()`:

- Use `aes_round_enc` and `aes_round_dec` for the encryption and decryption round operations.
- Use `expand_key_128`, `expand_key_192`, and `expand_key_256` to expand the key for each key length.
- Note that a single initial transformation (AddRoundKey) is applied before the first round; it can be considered Round 0 (see Figure 6.3). Use `add_round_key` for this step. In decryption, this initial AddRoundKey uses the last round key, and the round keys are then consumed in reverse order.
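The order in which round keys are consumed is the main difference between the two modes. The sketch below assumes the expanded key is a flat byte sequence holding round keys $0, \dots, N$ back to back, as the solution cell below treats it; `round_key_indices` and `round_key` are hypothetical helpers, not part of the course API.

```python
def round_key_indices(num_rounds, mode):
    """Order in which round keys 0..N are used by the cipher.

    Encryption uses keys 0, 1, ..., N; decryption uses the same
    keys in reverse order, N, N-1, ..., 0.
    """
    if mode == 'E':
        return list(range(num_rounds + 1))
    if mode == 'D':
        return list(range(num_rounds, -1, -1))
    raise ValueError("mode must be 'E' or 'D'")


def round_key(expanded_key, i):
    """Return round key i (16 bytes), assuming a flat expanded key."""
    return expanded_key[16 * i:16 * (i + 1)]
```

For AES-128, `round_key_indices(10, 'D')` starts with `[10, 9, 8, ...]`: the first key used in decryption (the initial AddRoundKey) is the last key used in encryption.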

In [146]:
def aes_cipher(input_block, initial_key, mode):
    # Number of rounds N for each valid key length in bytes
    keylen_round_table = {16: 10, 24: 12, 32: 14}
    # Key-expansion function for each valid key length in bytes
    expand_key_table = {16: expand_key_128, 24: expand_key_192, 32: expand_key_256}

    # Verify that mode is either 'E' (encryption) or 'D' (decryption)
    assert mode in ('E', 'D'), "Must be E or D"

    # Verify that the key has a valid length (see the table above),
    # then expand it into N + 1 round keys of 16 bytes each
    assert len(initial_key) in keylen_round_table, "Invalid Key Length"
    num_rounds = keylen_round_table[len(initial_key)]
    expanded_key = expand_key_table[len(initial_key)](initial_key)

    # Verify that the input is exactly one 16-byte AES block
    assert len(input_block) == 16, "Invalid Block Size"

    if mode == 'E':
        # Round 0: the initial AddRoundKey with round key 0
        state = add_round_key(input_block, expanded_key[:16])
        # Rounds 1 .. N-1: full rounds with round keys 1 .. N-1
        for rnd in range(1, num_rounds):
            state = aes_round_enc(state, expanded_key[rnd * 16:rnd * 16 + 16])
        # Round N: final round (flag True) with round key N
        state = aes_round_enc(state, expanded_key[num_rounds * 16:num_rounds * 16 + 16], True)
    else:
        # Decryption consumes the round keys in reverse: start with round key N
        state = add_round_key(input_block, expanded_key[num_rounds * 16:num_rounds * 16 + 16])
        # Rounds N-1 .. 1: full inverse rounds
        for rnd in range(num_rounds - 1, 0, -1):
            state = aes_round_dec(state, expanded_key[rnd * 16:rnd * 16 + 16])
        # Final inverse round (flag True) with round key 0
        state = aes_round_dec(state, expanded_key[:16], True)

    return state
            

### Grade
Run the following cell to calculate your grade.

In [147]:
exercise_functions = get_module_functions(sys.modules[__name__])
evaluate(exercise_functions)

+-------+----------+--------+-------+
| Index | Exercise | Passed | Grade |
+-------+----------+--------+-------+
| 0     | aes_128  | True   | 10    |
| 1     | aes_192  | True   | 10    |
| 2     | aes_256  | True   | 10    |
| 3     | key_size | True   | 10    |
| 4     | mode     | True   | 10    |
+-------+----------+--------+-------+
Grade: 100.00
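Beyond the grader, a known-answer test against the published example vectors in FIPS 197, Appendix C, is a cheap independent check. The sketch below is a hypothetical harness (`FIPS197_VECTORS` and `known_answer_failures` are not part of the course material); it takes any function with the same signature as `aes_cipher` and assumes that function accepts `bytearray` inputs and returns a 16-byte sequence.

```python
# FIPS 197 Appendix C example vectors: (plaintext, key, expected ciphertext), hex
FIPS197_VECTORS = {
    128: ("00112233445566778899aabbccddeeff",
          "000102030405060708090a0b0c0d0e0f",
          "69c4e0d86a7b0430d8cdb78070b4c55a"),
    192: ("00112233445566778899aabbccddeeff",
          "000102030405060708090a0b0c0d0e0f1011121314151617",
          "dda97ca4864cdfe06eaf70a0ec0d7191"),
    256: ("00112233445566778899aabbccddeeff",
          "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f",
          "8ea2b7ca516745bfeafc49904b496089"),
}


def known_answer_failures(cipher):
    """Run cipher(block, key, mode) on the FIPS 197 vectors in both directions.

    Returns a list of (key_bits, mode) pairs that did not match;
    an empty list means every check passed.
    """
    failures = []
    for bits, (pt_hex, key_hex, ct_hex) in FIPS197_VECTORS.items():
        pt, key, ct = bytes.fromhex(pt_hex), bytes.fromhex(key_hex), bytes.fromhex(ct_hex)
        if bytes(cipher(bytearray(pt), bytearray(key), 'E')) != ct:
            failures.append((bits, 'E'))
        if bytes(cipher(bytearray(ct), bytearray(key), 'D')) != pt:
            failures.append((bits, 'D'))
    return failures
```

With a correct implementation, `known_answer_failures(aes_cipher)` should return an empty list; any entry points to the key size and direction that disagree with the standard.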
