In [669]:
import hashlib
import numpy as np
from tqdm import tqdm
import aeskeyschedule as aes
from Crypto.Cipher import AES
import cProfile

##### Task 1a - Extract the primes of RSA-CRT (Bellcore)

In [845]:
s_bytes = bytes.fromhex("14 48 FA 66 0D 3D EE 69 3A 9A E1 0E 3D BE 17 6D D0 AE 96 37 F2 89 60 03 36 7F AB 3F 71 C1 BE 2C 8A 14 3D D9 E4 16 \
                        7C 9A 07 80 1E 66 6A C2 68 F3 76 EE 6B 9A 27 75 23 22 E1 BB D1 6F 8D DA 2B 90 05 8A 07 B1 AB 56 45 37 C8 00 95 3B D2 \
                        37 71 D8 CF D0 8B 96 D3 4B A6 01 3B 10 38 3B 19 D0 F2 63 E0 EE BC 7D 09 FD FE A0 03 D7 3D DA 88 5D 25 A3 C2 87 0C F8 E5 \
                         FF E7 20 1A E7 58 74 F3 83 09 7B")

s_int = int.from_bytes(s_bytes, byteorder='big', signed=False)

s_prime_bytes = bytes.fromhex("9A BD 38 BD D4 61 B4 F8 F6 A1 82 4B 1B 43 D4 1C 18 31 90 71 CA 60 28 86 55 76 C3 2A 25 85 32 BB 08 A4 49 F2 \
                              93 72 F4 BD B0 16 B5 A1 F5 7A AA 96 BE 66 B1 7E CF 0C D2 BF 89 C7 CA C7 7E 5B 1A 43 46 06 88 C3 ED BF 6D A6 EB CE \
                               EA 9B 57 97 A4 FC 28 EC 93 EF 18 DA 6E E5 4A 52 38 61 AB ED C8 2C 4A 14 8E C5 C8 8D E1 C5 1B 6C 81 3C 8C 13 17 3E 85 \
                               26 D0 03 5E 2A 37 5C F7 22 2A 18 C2 86 0B 1A ")

s_prime_int = int.from_bytes(s_prime_bytes, byteorder='big', signed=False)

N_bytes = bytes.fromhex("9B 1F 16 A7 69 6A C9 0F A7 AE 61 5A 1F 71 BD 1A C0 C3 1B 37 A9 F1 43 76 BE C7 FB 70 14 12 F0 E3 B7 9C AB 88 \
                         F9 06 B3 50 B5 21 57 87 66 C7 8C AC D2 E8 06 32 D0 93 5F 50 CD DC 41 5D C1 B0 46 EB 3B 35 56 62 4E B4 12 D0 56 \
                        F8 73 E5 A0 56 C2 B8 5B 36 4D 03 2B BA A9 27 67 57 B0 58 87 9B 02 CB 20 98 D6 3C 61 55 1D 47 53 B1 AE 18 90 D8 FF \
                        79 BA 10 F8 23 07 49 2A 77 5A 57 15 AD 60 5B 56 01")

N_int = int.from_bytes(N_bytes, byteorder='big', signed=False)

def get_x(s, s_prime, N):
    x = np.gcd((s_prime - s), N)
    if x > 1:
        return x
    
q = get_x(s_int, s_prime_int, N_int)
p = N_int // q
print("p prime value is: ", hex(p)[2:], "\n", "and prime q is: ", hex(q)[2:])

p prime value is:  f9555b790d60dcb3fdcdf464b88ab7bb629bfce037f4154927df19fcdb1b4c7327d41b17d848455cffbda7e8080c08600be3af126df6c481ab25da70bec471c0fb 
 and prime q is:  9f44ddf28f05904455669a629df988adf203812f56aa8047c7db9bb7b4e61dd67b027e80d8700a77471943cc76370ced07056ef808a12b2a467c159e586c33


##### Task 1b - Extract the primes of RSA-CRT (Lenstra)


In [850]:
N_bytes = bytes.fromhex("C2 D2 BE 8E 72 2A E5 BB D2 3D FA D3 62 A0 8B 4D 32 A4 51 15 54 2E 23 E4 9B 35 46 58 33 38 CD 8B 8B A4 2E \
                        F2 89 B2 E4 47 E9 BF 6E AF 7F 24 D0 25 65 D2 24 AB DD D6 D2 F4 4A 6F 28 16 A4 32 31 96 94 2D F2 0D ED 8F 10 \
                        02 45 24 E1 B2 F0 2F 4A D0 C1 CB F7 C7 78 27 0B CD 70 8E BF A0 49 38 4E DE EF 24 C0 84 DA 3C A2 EE 14 6C A5 79 \
                        CC 42 AE E7 F6 D4 B0 F5 9E 58 43 A5 19 32 9B EB 5F 97 66 07")

N_int = int.from_bytes(N_bytes, byteorder='big', signed=False)

e_bytes = bytes.fromhex("01 00 01")
e_int = int.from_bytes(e_bytes, byteorder='big', signed=False)

m_bytes = bytes.fromhex("1E 8A DB 08 E9 8A 58 01 2C 55 A8 C4 19 74 7B D8 D8 DB 40 FA C2 40 DA 92 BF 48 74 F7 9E 9A D7 3B 20 A9 34 07 0C \
                        AA 60 C7 67 25 41 68 AB EB 37 95 56 18 45 8F 6B F9 4B 2D 7B A8 92 1D E7 E8 4F A6 7A F7 E0 D6 FE 9E DD 55 4A BF 44 \
                        18 F7 AE E8 D8 29 E6 EC 12 45 CF CB AF 58 96 67 96 3B 53 1B 89 AF 63 87 9C 9A 65 31 76 A0 3B A6 89 BC 5D D4 5D A6 63 \
                         91 0A 19 FA 49 6A 6A EF B3 F9 AD FF F6 96")

m_int = int.from_bytes(m_bytes, byteorder='big', signed=False)

s_prime_bytes = bytes.fromhex("25 72 EE 15 57 9D 2E 18 72 4E 98 A1 37 BC 82 CC 46 65 4E 04 E0 AF 22 7C 36 D7 B0 C2 9E F4 9D 1B 77 57 A3 \
                              67 71 2E BC 6C 8D AD 7E 52 66 78 86 0C CD 44 AF FB E0 C3 79 1F 4E 0B A3 E1 86 33 03 E8 07 CC 4B D8 A8 95 42 \
                              B2 21 58 D6 7D 99 DC 93 05 0A CA 58 4D 2D 06 95 0B 6D C6 15 7E 47 CF ED 4D C6 D8 77 E4 7A 0C 7F 1A 09 FE A4 11 \
                              5E BF 67 EF DA F4 A8 40 96 89 05 43 66 E5 87 86 E7 4D 2A BD ")

s_prime_int = int.from_bytes(s_prime_bytes, byteorder='big', signed=False)

def get_q(N, m, e, s_prime):
    inner = pow(s_prime, e, N)
    q = np.gcd(N, (m - inner))
    
    return q

q = get_q(N_int, m_int, e_int, s_prime_int)
p = N_int // q
d = pow(e_int, -1, (p-1)*(q-1))
print("p prime value is:", hex(p)[2:], "\n", "and prime q is:", hex(q)[2:], "\n", "and d is:", hex(d)[2:])

p prime value is: d4693216ca3210f1491477d556e709141f6b5ea57e8b64a51011190d607b6b92a601857e4ad26e2b45123804ebdd08ccd15b0e50edcdc8754d5b2bb99dc8286087 
 and prime q is: eacd987fce4c2815b8e1f6557a4120cd822763baa732e6fbd2d35d61b85f8278263ce068cddf6099ba885cda0b4ed1c2374de5d34b265fec3358611905ae81 
 and d is: 6f7345e591342f162230a8b392814b0f4f80268e7008b129cf0c4c009ad4cce91e6a3f1d2a5eb72ed86e55b079a8a2963248640819b1121f0411d0ba1b1647445bf2438288738d9ecaef90ed7b3a4d1170f22f28cc60396854b9508df0cc39397bbc4eae35428edb25416b6370bda32bc8d58ac6fceef0b3d94a0d65c7ef1501


##### Task 2a

In [670]:
# load in the dataset and check the successful glitches
ctext = np.fromfile('task2/full_dfa_data/ctext.bin', dtype=np.uint8).reshape(1000, 16)
ftext = np.fromfile('task2/full_dfa_data/ftext.bin', dtype=np.uint8).reshape(1000, 16)
ptext = np.fromfile('task2/full_dfa_data/ptext.bin', dtype=np.uint8).reshape(1000, 16)

def get_successful_glitches():
    no_glitches = []
    for i in tqdm(range(1000)):
        if np.all(ctext[i] == ftext[i]) == False:
            no_glitches.append(i)
    return no_glitches

print("The number of successful glitches in the dataset is", len(get_successful_glitches()))


'''A success rate of 402/10000 = 4.02% is obtained. If we consider the difficulty of the attack and the complexity involved, 
this seems like a high success rate. That means that on average, every 25th glitch of AES-128 is successful.'''

100%|██████████| 1000/1000 [00:00<00:00, 264091.68it/s]

The number of successful glitches in the dataset is 402





'A success rate of 402/10000 = 4.02% is obtained. If we consider the difficulty of the attack and the complexity involved, \nthis seems like a high success rate. That means that on average, every 25th glitch of AES-128 is successful.'

##### Task 2b.

In [671]:
'''To perform the DFA attack described by Piret and Quisquater, we will need 8 ciphertext/faultytext pairs to 
successfully recover all 16 bytes of the 10th round key. We will use 2 pairs each to recover 4 key bytes.'''

# action steps:
# 1. loop through all the cipher text and faulty text pairs
# 2. for each pair, xor them to see if they differ in known bytes
# 3. if they differ, add them to a list of pairs of their type, set a counter for each type
# 4. At the end of the loop, add up the counters for all types and check if the sum equals the total number of glitches to 
# confirm that all glitches happened in the first column
# 5. If the sum is equal to the total number of glitches, then print out two pairs of each type in hex format

def get_ctext_ftext_pairs(ext_col=float('inf')):
    col_1 = [] # 0, 7, 10, 13
    col_2 = [] # 1, 4, 11, 14
    col_3 = [] # 2, 5, 8, 15
    col_4 = [] # 3, 6, 9, 12

    counter = 0
    for i in range(1000):
        if np.all(ctext[i] == ftext[i]) == False:
            xor = np.bitwise_xor(ctext[i], ftext[i])
            if (xor[0] != 0 and xor[7] != 0 and xor[10] != 0 and xor[13] != 0):
                col_1.append([ctext[i], ftext[i]])
                counter += 1
            if (xor[1] != 0 and xor[4] != 0 and xor[11] != 0 and xor[14] != 0):
                col_2.append([ctext[i], ftext[i]])
                counter += 1
            if (xor[2] != 0 and xor[5] != 0 and xor[8] != 0 and xor[15] != 0):
                col_3.append([ctext[i], ftext[i]])
                counter += 1
            if (xor[3] != 0 and xor[6] != 0 and xor[9] != 0 and xor[12] != 0):
                col_4.append([ctext[i], ftext[i]])
                counter += 1
    
    if counter != 402:
        raise ValueError("Not all glitches happened in the first row, re-write the function to correctly extract all the keypairs \
                         since the fault injection is random")

    if ext_col == 1: return col_1
    elif ext_col == 2: return col_2
    elif ext_col == 3: return col_3
    elif ext_col == 4: return col_4


def extract_needed_pairs(pairs=2, col=float('inf')):
    pairs_for_col = get_ctext_ftext_pairs(ext_col=col)[:pairs]
    return pairs_for_col
    

def print_pairs_in_hex(column_no):
    arr_pairs = extract_needed_pairs(col=column_no)

    hex_arr_pairs = []
    for i in range(len(arr_pairs)):
        ctext = arr_pairs[i][0]
        ctext_hex_arr = [hex(x) for x in ctext]
        ftext = arr_pairs[i][1]
        ftext_hex_arr = [hex(x) for x in ftext]
        hex_arr_pairs.append([ctext_hex_arr, ftext_hex_arr])

    return hex_arr_pairs

def get_pairs_for_all_bytes(pairs=2):
    _0_7_10_13 = get_ctext_ftext_pairs(ext_col=1)[:pairs]
    # _0_7_10_13 = get_ctext_ftext_pairs(ext_col=1)[5:7]
    _1_4_11_14 = get_ctext_ftext_pairs(ext_col=2)[:pairs]
    _2_5_8_15 = get_ctext_ftext_pairs(ext_col=3)[:pairs]
    _3_6_9_12 = get_ctext_ftext_pairs(ext_col=4)[:pairs]
    dict = {'col_1': _0_7_10_13, 'col_2': _1_4_11_14, 'col_3': _2_5_8_15, 'col_4': _3_6_9_12}
    return dict

print the pairs needed to complete the full attack in hex format

In [672]:
print("---ciphertext/ftext pairs (in hex) to extract the first 4 bytes of the 10th round key corresponding to the first column are---")
print("Ctext:",print_pairs_in_hex(1)[0][0], "\n" , "Ftext:",print_pairs_in_hex(1)[0][1], "\n")
print("Ctext:",print_pairs_in_hex(1)[1][0], "\n" , "Ftext:",print_pairs_in_hex(1)[1][1], "\n")

print("---ciphertext/ftext pairs (in hex) to extract the next 4 bytes of the 10th round key corresponding to the second column are---")
print("Ctext:",print_pairs_in_hex(2)[0][0], "\n" , "Ftext:",print_pairs_in_hex(2)[0][1], "\n")
print("Ctext:",print_pairs_in_hex(2)[1][0], "\n" , "Ftext:",print_pairs_in_hex(2)[1][1], "\n")

print("---ciphertext/ftext pairs (in hex) to extract the next 4 bytes of the 10th round key corresponding to the third column are---")
print("Ctext:",print_pairs_in_hex(3)[0][0], "\n" , "Ftext:",print_pairs_in_hex(3)[0][1], "\n")
print("Ctext:",print_pairs_in_hex(3)[1][0], "\n" , "Ftext:",print_pairs_in_hex(3)[1][1], "\n")

print("---ciphertext/ftext pairs (in hex) to extract the next 4 bytes of the 10th round key corresponding to the fourth column are---")
print("Ctext:",print_pairs_in_hex(4)[0][0], "\n" , "Ftext:",print_pairs_in_hex(4)[0][1], "\n")
print("Ctext:",print_pairs_in_hex(4)[1][0], "\n" , "Ftext:",print_pairs_in_hex(4)[1][1], "\n")

'''Since I have been able to ascertain the location of the fault injection to be in the first row, I suppose I wouldn't need more than 8 pairs total to
extract all the 16 key bytes. However, if the location of the glitch was unknown or if the faults injected are not strictly in the first row, 
perhaps if the attacker is injecting faults at a location in the state where the effect of the fault is diffused over multiple rounds, then it may 
require more pairs of plaintext/ciphertext to isolate the fault effect and identify the correct key.
In addition, if the fault injection technique is imprecise and the faults injected are not strong enough to cause a distinguishable effect, 
then it may require more pairs.'''

---ciphertext/ftext pairs (in hex) to extract the first 4 bytes of the 10th round key corresponding to the first column are---
Ctext: ['0x23', '0xdb', '0x99', '0xde', '0xfc', '0x61', '0x74', '0x7b', '0x8c', '0x5f', '0x36', '0x65', '0x1b', '0x7f', '0x26', '0x1e'] 
 Ftext: ['0x33', '0xdb', '0x99', '0xde', '0xfc', '0x61', '0x74', '0xf6', '0x8c', '0x5f', '0xa3', '0x65', '0x1b', '0x2e', '0x26', '0x1e'] 

Ctext: ['0xe5', '0x99', '0x2a', '0xab', '0xc2', '0xd1', '0x22', '0xa6', '0xde', '0x2', '0xb', '0xcd', '0x53', '0x7e', '0xad', '0xf'] 
 Ftext: ['0x95', '0x99', '0x2a', '0xab', '0xc2', '0xd1', '0x22', '0x9a', '0xde', '0x2', '0x22', '0xcd', '0x53', '0x56', '0xad', '0xf'] 

---ciphertext/ftext pairs (in hex) to extract the next 4 bytes of the 10th round key corresponding to the second column are---
Ctext: ['0x72', '0x90', '0x2', '0x6', '0x1e', '0x13', '0x83', '0xe8', '0xfc', '0xa6', '0x51', '0x42', '0xe9', '0x2', '0xca', '0xb4'] 
 Ftext: ['0x72', '0xd7', '0x2', '0x6', '0x44', '0x13', '0x83', '0

"Since I have been able to ascertain the location of the fault injection to be in the first row, I suppose I wouldn't need more than 8 pairs total to\nextract all the 16 key bytes. However, if the location of the glitch was unknown or if the faults injected are not strictly in the first row, \nperhaps if the attacker is injecting faults at a location in the state where the effect of the fault is diffused over multiple rounds, then it may \nrequire more pairs of plaintext/ciphertext to isolate the fault effect and identify the correct key.\nIn addition, if the fault injection technique is imprecise and the faults injected are not strong enough to cause a distinguishable effect, \nthen it may require more pairs."

get the pairs needed to complete the full attack

In [673]:
_1st_4_bytes_pairs = get_pairs_for_all_bytes()['col_1']
_2nd_4_bytes_pairs = get_pairs_for_all_bytes()['col_2']
_3rd_4_bytes_pairs = get_pairs_for_all_bytes()['col_3']
_4th_4_bytes_pairs = get_pairs_for_all_bytes()['col_4']

##### Task 2c - Simple Piret and Quisquater DFA

define some utilities funcs

In [851]:
# Check if you got the right answer for the 4 keybytes in the first column (key at indices [0, 13, 10, 7])
# Note: We're referring to the round 10 keybytes (no need to rewind with the the key schedule)
# Note: The same key is used both for the `simple` and `full` DFA

def check_keybytes(k_0: int, k_13: int, k_10: int, k_7: int):
    keybytes = bytes([k_0, k_13, k_10, k_7])
    hasher = hashlib.sha3_256()
    hasher.update(keybytes)
    key_hash = hasher.hexdigest()
    if key_hash == '4409976e63e88e6d0ef93405e6b6d678c2a498d22dcaa72b28c8c9cd6233ec7f':
        print("Congratulations! Correct 4 keybytes found")
        return True
    
    print("Not quite right")
    return False

SBOX = np.array([
        0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
        0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
        0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
        0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
        0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
        0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
        0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
        0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
        0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
        0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
        0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
        0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
        0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
        0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
        0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
        0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16
], dtype=np.uint8)

# Inverse of AES SBOX
ISBOX = SBOX.argsort()

# AES MixCols matrix
MIXCOLS = np.array([[2, 3, 1, 1],
                    [1, 2, 3, 1],
                    [1, 1, 2, 3],
                    [3, 1, 1, 2]])

# Galois multiplication by 2 (for MixCols)
def galois_mult_2(a):
    temp = (a << 1) & 0xff

    if (a & 0x80):
        temp ^= 0x1b

    return temp

# Galois multiplication by 3 (for MixCols)
def galois_mult_3(a):
    return galois_mult_2(a) ^ a

# AES ShiftRows
# `mat` is 4x4 AES matrix
def shift(mat):
    shifted = np.zeros_like(mat)
    for i in range(4):
        shifted[i] = mat[i, np.arange(i, 4+i) % 4]
    return shifted

# TODO: implement your own. Should undo shift(mat)
def unshift(mat):
    pass

# Precompute all possible mixcols(glitch)
D = []
mixcol = MIXCOLS[:, 0]
for x in range(1, 255+1):
    D_element = []
    for j in range(4):
        out = None
        if mixcol[j] == 1:
            out = x
        if mixcol[j] == 2:
            out = galois_mult_2(x)
        if mixcol[j] == 3:
            out = galois_mult_3(x)
        D_element.append(out)
    D.append(D_element)

D = np.array(D)

print("Length of lookup table:", len(D))

# get the bytes at each position of interest in the AES state
def get_bytes_at_resp_pos(ciphertext, faultytext, indices):
    c_1st = ciphertext[np.unravel_index(indices[0], shape=(4,4), order='F')]
    c__1st_prime = faultytext[np.unravel_index(indices[0], shape=(4,4), order='F')]
    c_2nd = ciphertext[np.unravel_index(indices[1], shape=(4,4), order='F')]
    c_2nd_prime = faultytext[np.unravel_index(indices[1], shape=(4,4), order='F')]
    c_3rd = ciphertext[np.unravel_index(indices[2], shape=(4,4), order='F')]
    c_3rd_prime = faultytext[np.unravel_index(indices[2], shape=(4,4), order='F')]
    c_4th = ciphertext[np.unravel_index(indices[3], shape=(4,4), order='F')]
    c_4th_prime = faultytext[np.unravel_index(indices[3], shape=(4,4), order='F')]
    
    return c_1st, c__1st_prime, c_2nd, c_2nd_prime, c_3rd, c_3rd_prime, c_4th, c_4th_prime

Length of lookup table: 255


get the cipher/ftext pairs needed for the simple attack

In [811]:
# Two pairs of ciphertext/faulty texts.
# The fault is injected in the *first byte* before the MixCols in the 9th round
# Note: use for `simple` DFA (part C)

simple_ctxt1 = [174, 44, 204, 43, 18, 196, 238, 88, 3, 227, 92, 0, 137, 106, 205, 88]
simple_ftxt1 = [128, 44, 204, 43, 18, 196, 238, 171, 3, 227, 159, 0, 137, 186, 205, 88]

simple_ctxt2 = [41, 4, 148, 29, 23, 74, 41, 127, 125, 148, 36, 219, 29, 127, 4, 58]
simple_ftxt2 = [186, 4, 148, 29, 23, 74, 41, 160, 125, 148, 59, 219, 29, 172, 4, 58]


# Load ctext/ftext pairs in the correct AES column order
simple_ctxt1 = np.reshape(simple_ctxt1, (4, 4), order='F').astype(np.uint8)
simple_ftxt1 = np.reshape(simple_ftxt1, (4, 4), order='F').astype(np.uint8)
first_pair_simple = (simple_ctxt1, simple_ftxt1)

simple_ctxt2 = np.reshape(simple_ctxt2, (4, 4), order='F').astype(np.uint8)
simple_ftxt2 = np.reshape(simple_ftxt2, (4, 4), order='F').astype(np.uint8)
second_pair_simple = (simple_ctxt2, simple_ftxt2)

In [815]:
"""precomputing the D lookup table for fast lookup"""
# convert D to a set of strings of 2 digits each for fast lookup
D_in_2s = []
for arr in range(len(D)):
    D_in_2s.append("".join(D[arr][:2].astype(str)))
D_in_2s_set = set(D_in_2s)

# convert D to a set of strings of 3 digits each
D_in_3s = []
for arr in range(len(D)):
    D_in_3s.append("".join(D[arr][:3].astype(str)))
D_in_3s_set = set(D_in_3s)

# convert D to a set of strings of 4 digits each
D_in_4s = []
for arr in range(len(D)):
    D_in_4s.append("".join(D[arr][:4].astype(str)))
D_in_4s_set = set(D_in_4s)

In [817]:
# simple DFA to bruteforce 4 key bytes extraction
def simple_attack(_1stpair, _2ndpair):
    first_ctxt, first_ftxt = _1stpair
    second_ctxt, second_ftxt = _2ndpair

    # get the 0th, 7th, 10th and 13th byte of the first pair
    c0, c0_prime, c7, c7_prime, c10, c10_prime, c13, c13_prime = get_bytes_at_resp_pos(first_ctxt, first_ftxt, [0, 7, 10, 13])

    # get the 0th, 7th, 10th and 13th byte of the second pair
    c0_x, c0_prime_x, c7_x, c7_prime_x, c10_x, c10_prime_x, c13_x, c13_prime_x = get_bytes_at_resp_pos(second_ctxt, second_ftxt, [0, 7, 10, 13])

    # do the first filtering of key candidates
    K_c_2 = []
    for k0 in tqdm(range(256)):
        for k13 in range(256):
            str_1 = str(ISBOX[k0 ^ c0] ^ ISBOX[k0 ^ c0_prime]) + str(ISBOX[k13 ^ c13] ^ ISBOX[k13 ^ c13_prime])
            # check if the xor of the inverse sbox of ciphertext/faultytext with keys k0 and k10 is equal the lookup table
            if str_1 in D_in_2s_set:
                K_c_2.append((k0, k13))

    # do the second filtering of key candidates
    K_c_3 = []
    for k10 in tqdm(range(256)):
        for k0, k13 in K_c_2:
            str_2 = str(ISBOX[k0 ^ c0] ^ ISBOX[k0 ^ c0_prime]) + str(ISBOX[k13 ^ c13] ^ ISBOX[k13 ^ c13_prime]) + \
                    str(ISBOX[k10 ^ c10] ^ ISBOX[k10 ^ c10_prime])
            # check if the xor of the inverse sbox of ciphertext/faultytext with keys k0, k10 and k13 is equal the lookup table
            if str_2 in D_in_3s_set:
                K_c_3.append((k0, k13, k10))

    # do the third filtering of key candidates
    K_c_4 = []
    for k7 in tqdm(range(256)):
        for k0, k13, k10 in K_c_3:
            str_3 = str(ISBOX[k0 ^ c0] ^ ISBOX[k0 ^ c0_prime]) + str(ISBOX[k13 ^ c13] ^ ISBOX[k13 ^ c13_prime]) + \
                    str(ISBOX[k10 ^ c10] ^ ISBOX[k10 ^ c10_prime]) + str(ISBOX[k7 ^ c7] ^ ISBOX[k7 ^ c7_prime])
            # check if the xor of the inverse sbox of ciphertext/faultytext with keys k0, k7, k10 and k13 is equal the lookup table
            if str_3 in D_in_4s_set:
                K_c_4.append((k0, k13, k10, k7))

    # do the fourth filtering of key candidates using the second pairs of ctext and ftext
    K_c_5 = []
    for k0, k13, k10, k7 in tqdm(K_c_4):
        str_4 = str(ISBOX[k0 ^ c0_x] ^ ISBOX[k0 ^ c0_prime_x]) + str(ISBOX[k13 ^ c13_x] ^ ISBOX[k13 ^ c13_prime_x]) + \
                    str(ISBOX[k10 ^ c10_x] ^ ISBOX[k10 ^ c10_prime_x]) + str(ISBOX[k7 ^ c7_x] ^ ISBOX[k7 ^ c7_prime_x])
        # check if the xor of the inverse sbox of ciphertext/faultytext with keys k0, k7, k10 and k13 is equal the lookup table
        if str_4 in D_in_4s_set:
            K_c_5.append((k0, k13, k10, k7))
    
    return K_c_5

In [814]:
last_res = simple_attack(first_pair_simple, second_pair_simple)
print('the key bytes for the simple attack are: ', last_res)

100%|██████████| 256/256 [00:00<00:00, 583.53it/s]
100%|██████████| 256/256 [00:00<00:00, 406.93it/s]
100%|██████████| 256/256 [00:00<00:00, 333.51it/s]
100%|██████████| 240/240 [00:00<00:00, 75172.35it/s]

the key bytes for the simple attack are:  [(168, 138, 164, 45)]





In [679]:
check_keybytes(k_0=168, k_13=138, k_10=164, k_7=45)

Congratulations! Correct 4 keybytes found


True

##### Task 2d - Full Piret and Quisquater DFA attack #####

In [809]:
""" some precomputations done to speed up the key extraction for the full attack"""
# for full dfa, D is modified to the form below as the glitch can happen in any row of the columns of the corresponding four bytes
D2 = []
for i in range(4):
    mixcol = MIXCOLS[:, i]
    for x in range(1, 255+1):
        D_element = []
        for j in range(4):
            out = None
            if mixcol[j] == 1:
                out = x
            if mixcol[j] == 2:
                out = galois_mult_2(x)
            if mixcol[j] == 3:
                out = galois_mult_3(x)
            D_element.append(out)
        D2.append(D_element)

D2 = np.array(D2)

# convert D to a set of strings of 2 digits each for fast lookup
D2_in_2s = []
for arr in range(len(D2)):
    D2_in_2s.append("".join(D2[arr][:2].astype(str)))
D2_in_2s_set = set(D2_in_2s)

# convert D to a set of strings of 3 digits each
D2_in_3s = []
for arr in range(len(D2)):
    D2_in_3s.append("".join(D2[arr][:3].astype(str)))
D2_in_3s_set = set(D2_in_3s)

# convert D to a set of strings of 4 digits each
D2_in_4s = []
for arr in range(len(D2)):
    D2_in_4s.append("".join(D2[arr][:4].astype(str)))
D2_in_4s_set = set(D2_in_4s)

# get the cipher text and the faulty cipher text corresponding to the first bytes set in the correct AES column order
def find_pairs(arr):
    ctext_byte1 = np.reshape(arr[0][0], (4, 4), order='F').astype(np.uint8)
    ftext_byte1 = np.reshape(arr[0][1], (4, 4), order='F').astype(np.uint8)
    first_pair = (ctext_byte1, ftext_byte1)

    ctext2_byte1 = np.reshape(arr[1][0], (4, 4), order='F').astype(np.uint8)
    ftext2_byte1 = np.reshape(arr[1][1], (4, 4), order='F').astype(np.uint8)
    second_pair = (ctext2_byte1, ftext2_byte1)
    return first_pair, second_pair

##### N.B: For the full key, the simple attack is repeated four times each time to retrieve the four key bytes corresponding to each column
first byte key extraction

In [782]:
# full DFA to bruteforce the first 4 key bytes, the D is modified to the form shown above
def full_attack_1(_1stpair, _2ndpair):
    first_ctxt, first_ftxt = _1stpair
    second_ctxt, second_ftxt = _2ndpair

    # get the 0th, 7th, 10th and 13th byte of the first pair
    c0, c0_prime, c7, c7_prime, c10, c10_prime, c13, c13_prime = get_bytes_at_resp_pos(first_ctxt, first_ftxt, [0, 7, 10, 13])

    # get the 0th, 7th, 10th and 13th byte of the second pair
    c0_x, c0_prime_x, c7_x, c7_prime_x, c10_x, c10_prime_x, c13_x, c13_prime_x = get_bytes_at_resp_pos(second_ctxt, second_ftxt, [0, 7, 10, 13])

    # do the first filtering of key candidates
    K_c_2 = []
    for k0 in tqdm(range(256)):
        for k13 in range(256):
            str_1 = str(ISBOX[k0 ^ c0] ^ ISBOX[k0 ^ c0_prime]) + str(ISBOX[k13 ^ c13] ^ ISBOX[k13 ^ c13_prime])
            # check if the xor of the inverse sbox of ciphertext/faultytext with keys k0 and k10 is equal the lookup table
            if str_1 in D2_in_2s_set:
                K_c_2.append((k0, k13))
    
    # do the second filtering of key candidates
    K_c_3 = []
    for k10 in tqdm(range(256)):
        for k0, k13 in K_c_2:
            str_2 = str(ISBOX[k0 ^ c0] ^ ISBOX[k0 ^ c0_prime]) + str(ISBOX[k13 ^ c13] ^ ISBOX[k13 ^ c13_prime]) + \
                    str(ISBOX[k10 ^ c10] ^ ISBOX[k10 ^ c10_prime])
            # check if the xor of the inverse sbox of ciphertext/faultytext with keys k0, k10 and k13 is equal the lookup table
            if str_2 in D2_in_3s_set:
                K_c_3.append((k0, k13, k10))

    # do the third filtering of key candidates
    K_c_4 = []
    for k7 in tqdm(range(256)):
        for k0, k13, k10 in K_c_3:
            str_3 = str(ISBOX[k0 ^ c0] ^ ISBOX[k0 ^ c0_prime]) + str(ISBOX[k13 ^ c13] ^ ISBOX[k13 ^ c13_prime]) + \
                    str(ISBOX[k10 ^ c10] ^ ISBOX[k10 ^ c10_prime]) + str(ISBOX[k7 ^ c7] ^ ISBOX[k7 ^ c7_prime])
            # check if the xor of the inverse sbox of ciphertext/faultytext with keys k0, k7, k10 and k13 is equal the lookup table
            if str_3 in D2_in_4s_set:
                K_c_4.append((k0, k13, k10, k7))

    # do the fourth filtering of key candidates using the second pairs of ctext and ftext
    K_c_5 = []
    for k0, k13, k10, k7 in tqdm(K_c_4):
            str_4 = str(ISBOX[k0 ^ c0_x] ^ ISBOX[k0 ^ c0_prime_x]) + str(ISBOX[k13 ^ c13_x] ^ ISBOX[k13 ^ c13_prime_x]) + \
                    str(ISBOX[k10 ^ c10_x] ^ ISBOX[k10 ^ c10_prime_x]) + str(ISBOX[k7 ^ c7_x] ^ ISBOX[k7 ^ c7_prime_x])
            # check if the xor of the inverse sbox of ciphertext/faultytext with keys k0, k7, k10 and k13 is equal the lookup table
            if str_4 in D2_in_4s_set:
                K_c_5.append((k0, k13, k10, k7))

    return K_c_5

In [808]:
firstpair, secondpair = find_pairs(_1st_4_bytes_pairs)
_res = full_attack_1(firstpair, secondpair)
print('the first set of key bytes for the full attack are: ', _res)

100%|██████████| 256/256 [00:00<00:00, 645.87it/s]
100%|██████████| 256/256 [00:02<00:00, 98.12it/s]
100%|██████████| 256/256 [00:03<00:00, 75.50it/s]
100%|██████████| 1360/1360 [00:00<00:00, 75504.02it/s]

the first set of key bytes for the full attack are:  [(168, 138, 164, 45)]





second byte key extraction

In [787]:
# full DFA to bruteforce the next 4 set of key bytes using the modified D above
def full_attack_2(_1stpair, _2ndpair):
    first_ctxt, first_ftxt = _1stpair
    second_ctxt, second_ftxt = _2ndpair

    # get the 1st, 4th, 11th and 14th byte of the first pair
    c1, c1_prime, c4, c4_prime, c11, c11_prime, c14, c14_prime = get_bytes_at_resp_pos(first_ctxt, first_ftxt, [1, 4, 11, 14])

    # get the 1st, 4th, 11th and 14th byte of the second pair
    c1_x, c1_prime_x, c4_x, c4_prime_x, c11_x, c11_prime_x, c14_x, c14_prime_x = get_bytes_at_resp_pos(second_ctxt, second_ftxt, [1, 4, 11, 14])

    # do the first filtering of key candidates
    K_c_2 = []
    for k4 in tqdm(range(256)):
        for k1 in range(256):
            str_1 = str(ISBOX[k4 ^ c4] ^ ISBOX[k4 ^ c4_prime]) + str(ISBOX[k1 ^ c1] ^ ISBOX[k1 ^ c1_prime])
            # check if the xor of the inverse sbox of ciphertext/faultytext with keys k4 and k1 is equal the lookup table
            if str_1 in D2_in_2s_set:
                K_c_2.append((k4, k1))
    
    # do the second filtering of key candidates
    K_c_3 = []
    for k14 in tqdm(range(256)):
        for k4, k1 in K_c_2:
            str_2 = str(ISBOX[k4 ^ c4] ^ ISBOX[k4 ^ c4_prime]) + str(ISBOX[k1 ^ c1] ^ ISBOX[k1 ^ c1_prime]) + \
                        str(ISBOX[k14 ^ c14] ^ ISBOX[k14 ^ c14_prime])
            # check if the xor of the inverse sbox of ciphertext/faultytext with keys k4, k1 and k14 is equal the lookup table
            if str_2 in D2_in_3s_set:
                K_c_3.append((k4, k1, k14))

    # do the third filtering of key candidates
    K_c_4 = []
    for k11 in tqdm(range(256)):
        for k4, k1, k14 in K_c_3:
            str_3 = str(ISBOX[k4 ^ c4] ^ ISBOX[k4 ^ c4_prime]) + str(ISBOX[k1 ^ c1] ^ ISBOX[k1 ^ c1_prime]) + \
                str(ISBOX[k14 ^ c14] ^ ISBOX[k14 ^ c14_prime]) + str(ISBOX[k11 ^ c11] ^ ISBOX[k11 ^ c11_prime])
            # check if the xor of the inverse sbox of ciphertext/faultytext with keys k4, k1, k14 and k11 is equal the lookup table
            if str_3 in D2_in_4s_set:
                K_c_4.append((k4, k1, k14, k11))

    # do the fourth filtering of key candidates using the second pairs of ctext and ftext
    K_c_5 = []
    for k4, k1, k14, k11 in tqdm(K_c_4):
        str_4 = str(ISBOX[k4 ^ c4_x] ^ ISBOX[k4 ^ c4_prime_x]) + str(ISBOX[k1 ^ c1_x] ^ ISBOX[k1 ^ c1_prime_x]) + \
            str(ISBOX[k14 ^ c14_x] ^ ISBOX[k14 ^ c14_prime_x]) + str(ISBOX[k11 ^ c11_x] ^ ISBOX[k11 ^ c11_prime_x])
        # check if the xor of the inverse sbox of ciphertext/faultytext with keys k4, k1, k14 and k11 is equal the lookup table
        if str_4 in D2_in_4s_set:
            K_c_5.append((k4, k1, k14, k11))

    return K_c_5

In [788]:
first_pair_2, second_pair_2 = find_pairs(_2nd_4_bytes_pairs)
res_set2 = full_attack_2(first_pair_2, second_pair_2)
print('the second set of key bytes for the full attack are: ', res_set2)

100%|██████████| 256/256 [00:00<00:00, 601.27it/s]
100%|██████████| 256/256 [00:02<00:00, 93.68it/s]
100%|██████████| 256/256 [00:03<00:00, 68.86it/s]
100%|██████████| 1168/1168 [00:00<00:00, 85916.29it/s]

the second set of key bytes for the full attack are:  [(53, 73, 46, 0)]





third byte key extraction

In [799]:
# full DFA to bruteforce the third set of 4 key bytes using the modified D above
def full_attack_3(_1stpair, _2ndpair):
    first_ctxt, first_ftxt = _1stpair
    second_ctxt, second_ftxt = _2ndpair

    # get the 8th, 5th, 2nd and 15th byte of the first pair
    c8, c8_prime, c5, c5_prime, c2, c2_prime, c15, c15_prime = get_bytes_at_resp_pos(first_ctxt, first_ftxt, [8, 5, 2, 15])

    # get the 8th, 5th, 2nd and 15th byte of the second pair
    c8_x, c8_prime_x, c5_x, c5_prime_x, c2_x, c2_prime_x, c15_x, c15_prime_x = get_bytes_at_resp_pos(second_ctxt, second_ftxt, [8, 5, 2, 15])

    # do the first filtering of key candidates
    K_c_2 = []
    for k8 in tqdm(range(256)):
        for k5 in range(256):
            str_1 = str(ISBOX[k8 ^ c8] ^ ISBOX[k8 ^ c8_prime]) + str(ISBOX[k5 ^ c5] ^ ISBOX[k5 ^ c5_prime])
            # check if the xor of the inverse sbox of ciphertext/faultytext with keys k8 and k5 is equal the lookup table
            if str_1 in D2_in_2s_set:
                K_c_2.append((k8, k5))
    
    # do the second filtering of key candidates
    K_c_3 = []
    for k2 in tqdm(range(256)):
        for k8, k5 in K_c_2:
            str_2 = str(ISBOX[k8 ^ c8] ^ ISBOX[k8 ^ c8_prime]) + str(ISBOX[k5 ^ c5] ^ ISBOX[k5 ^ c5_prime]) + \
                str(ISBOX[k2 ^ c2] ^ ISBOX[k2 ^ c2_prime])
            # check if the xor of the inverse sbox of ciphertext/faultytext with keys k8, k5 and k2 is equal the lookup table
            if str_2 in D2_in_3s_set:
                K_c_3.append((k8, k5, k2))
    
    # do the third filtering of key candidates
    K_c_4 = []
    for k15 in tqdm(range(256)):
        for k8, k5, k2 in K_c_3:
            str_3 = str(ISBOX[k8 ^ c8] ^ ISBOX[k8 ^ c8_prime]) + str(ISBOX[k5 ^ c5] ^ ISBOX[k5 ^ c5_prime]) + \
                str(ISBOX[k2 ^ c2] ^ ISBOX[k2 ^ c2_prime]) + str(ISBOX[k15 ^ c15] ^ ISBOX[k15 ^ c15_prime])
            # check if the xor of the inverse sbox of ciphertext/faultytext with keys k8, k5, k2 and k15 is equal the lookup table
            if str_3 in D2_in_4s_set:
                K_c_4.append((k8, k5, k2, k15))
    
    # do the fourth filtering of key candidates using the second pairs of ctext and ftext
    K_c_5 = []
    for k8, k5, k2, k15 in tqdm(K_c_4):
        str_4 = str(ISBOX[k8 ^ c8_x] ^ ISBOX[k8 ^ c8_prime_x]) + str(ISBOX[k5 ^ c5_x] ^ ISBOX[k5 ^ c5_prime_x]) + \
            str(ISBOX[k2 ^ c2_x] ^ ISBOX[k2 ^ c2_prime_x]) + str(ISBOX[k15 ^ c15_x] ^ ISBOX[k15 ^ c15_prime_x])
        # check if the xor of the inverse sbox of ciphertext/faultytext with keys k8, k5, k2 and k15 is equal the lookup table
        if str_4 in D2_in_4s_set:
            K_c_5.append((k8, k5, k2, k15))

    return K_c_5

In [800]:
first_pair_3, second_pair_3 = find_pairs(_3rd_4_bytes_pairs)
res_set3 = full_attack_3(first_pair_3, second_pair_3)
print('the third set of key bytes for the full attack are: ', res_set3)

100%|██████████| 256/256 [00:00<00:00, 653.02it/s]
100%|██████████| 256/256 [00:02<00:00, 95.79it/s]
100%|██████████| 256/256 [00:03<00:00, 72.75it/s]
100%|██████████| 1120/1120 [00:00<00:00, 76866.52it/s]

the third set of key bytes for the full attack are:  [(93, 213, 55, 198)]





fourth byte key extraction

In [803]:
# full DFA to bruteforce the fourth set of 4 key bytes using the modified D above
def full_attack_4(_1stpair, _2ndpair):
    first_ctxt, first_ftxt = _1stpair
    second_ctxt, second_ftxt = _2ndpair

    # get the 12th, 9th, 6th and 3rd byte of the first pair
    c12, c12_prime, c9, c9_prime, c6, c6_prime, c3, c3_prime = get_bytes_at_resp_pos(first_ctxt, first_ftxt, [12, 9, 6, 3])

    # get the 12th, 9th, 6th and 3rd byte of the second pair
    c12_x, c12_prime_x, c9_x, c9_prime_x, c6_x, c6_prime_x, c3_x, c3_prime_x = get_bytes_at_resp_pos(second_ctxt, second_ftxt, [12, 9, 6, 3])

    # do the first filtering of key candidates
    K_c_2 = []
    for k12 in tqdm(range(256)):
        for k9 in range(256):
            str_1 = str(ISBOX[k12 ^ c12] ^ ISBOX[k12 ^ c12_prime]) + str(ISBOX[k9 ^ c9] ^ ISBOX[k9 ^ c9_prime])
            # check if the xor of the inverse sbox of ciphertext/faultytext with keys k12 and k9 is equal the lookup table
            if str_1 in D2_in_2s_set:
                K_c_2.append((k12, k9))
    
    # do the second filtering of key candidates
    K_c_3 = []
    for k6 in tqdm(range(256)):
        for k12, k9 in K_c_2:  
            str_2 = str(ISBOX[k12 ^ c12] ^ ISBOX[k12 ^ c12_prime]) + str(ISBOX[k9 ^ c9] ^ ISBOX[k9 ^ c9_prime]) + \
                str(ISBOX[k6 ^ c6] ^ ISBOX[k6 ^ c6_prime])
            # check if the xor of the inverse sbox of ciphertext/faultytext with keys k12, k9 and k6 is equal the lookup table
            if str_2 in D2_in_3s_set:
                K_c_3.append((k12, k9, k6))
    
    # do the third filtering of key candidates
    K_c_4 = []
    for k3 in tqdm(range(256)):
        for k12, k9, k6 in K_c_3:       
            str_3 = str(ISBOX[k12 ^ c12] ^ ISBOX[k12 ^ c12_prime]) + str(ISBOX[k9 ^ c9] ^ ISBOX[k9 ^ c9_prime]) + \
                str(ISBOX[k6 ^ c6] ^ ISBOX[k6 ^ c6_prime]) + str(ISBOX[k3 ^ c3] ^ ISBOX[k3 ^ c3_prime])
            # check if the xor of the inverse sbox of ciphertext/faultytext with keys k12, k9, k6 and k3 is equal the lookup table
            if str_3 in D2_in_4s_set:
                K_c_4.append((k12, k9, k6, k3))

    # do the fourth filtering of key candidates using the second pairs of ctext and ftext
    K_c_5 = []
    for k12, k9, k6, k3 in tqdm(K_c_4):
        str_4 = str(ISBOX[k12 ^ c12_x] ^ ISBOX[k12 ^ c12_prime_x]) + str(ISBOX[k9 ^ c9_x] ^ ISBOX[k9 ^ c9_prime_x]) + \
            str(ISBOX[k6 ^ c6_x] ^ ISBOX[k6 ^ c6_prime_x]) + str(ISBOX[k3 ^ c3_x] ^ ISBOX[k3 ^ c3_prime_x])
        # check if the xor of the inverse sbox of ciphertext/faultytext with keys k8, k5, k2 and k15 is equal the lookup table
        if str_4 in D2_in_4s_set:
            K_c_5.append((k12, k9, k6, k3))

    return K_c_5

In [804]:
first_pair_4, second_pair_4 = find_pairs(_4th_4_bytes_pairs)
res_set4 = full_attack_4(first_pair_4, second_pair_4)
print('the fourth set of key bytes for the full attack are: ', res_set4)

100%|██████████| 256/256 [00:00<00:00, 617.45it/s]
100%|██████████| 256/256 [00:02<00:00, 93.19it/s]
100%|██████████| 256/256 [00:04<00:00, 56.26it/s]
100%|██████████| 1456/1456 [00:00<00:00, 81874.09it/s]

the fourth set of key bytes for the full attack are:  [(170, 35, 50, 172)]





get the original key from the 10th round key using the aeskeyschedule function

In [805]:
all_keys = [168, 73, 55, 172, 53, 213, 50, 45, 93, 35, 164, 0, 170, 138, 46, 198]
keys_hex = [hex(x)[2:].zfill(2) for x in all_keys]  # Add leading zeros and remove "0x"
keys_str = "".join(keys_hex)  # Join hexadecimal strings
keys_byte = bytes.fromhex(keys_str)  # Convert to bytes

original_key = aes.reverse_key_schedule(keys_byte, 10)
print("The original key is:", original_key)

The original key is: b'=\x83\xa4\x01t\xa3Xg;l=\x99\xdcS\x92\xc3'


In [806]:
# After recovering the full key, decrypt the secret message:
# You can use the `pycryptodome` library 
secret = bytes.fromhex("2a92fc6ad8006b658f49062c2843ad99")
secret

b'*\x92\xfcj\xd8\x00ke\x8fI\x06,(C\xad\x99'

In [807]:
cipher = AES.new(original_key, AES.MODE_ECB)
plaintext = cipher.decrypt(secret)
print("The secret message is:", plaintext.decode('utf-8'))

The secret message is: DFAIsAFunAttack!
