# DFA_Theory

**SUMMARY:** *Fault attacks aren't limited to breaking past password checks and breaking bootloaders. In this lab, we'll look at how we can recover and AES key by inserting faults near the end of the encryption operation. In the next lab, we'll use this knowledge to actually recover an AES key from a device running TINYAES128C*

##  A quick overview of AES

AES is a symmetric (meaning it uses the same key for encryption and decryption) encryption algorithm that uses



## 

After that analysis, we are left with a system of equations:

$$O_0 + O_0^\prime = S(Y_0) + S(2Z + Y_0) \\
O_7 + O_7^\prime = S(Y_1) + S(3Z + Y_1) \\
O_{10} + O_{10}^\prime = S(Y_2) + S(Z + Y_2) \\
O_{13} + O_{13}^\prime = S(Y_3) + S(2Z + Y_3) \\
$$

This a non-linear system of equations with multiple solutions, so it's going to be much easier to just brute force it -  aka try every possible Z, Y_0, Y_1, Y_2, and Y_3 value in these equations, taking only the ones that work for all the equations. You can make this much faster by short circuiting - as soon as it fails one of these equations, there's no need to continue on from that point. For example, if you're going through the equations in the above sequence and the second one fails, there's no need to continue on with Y_2 and Y_3 for that particular combination of Z, Y_0, and Y_1.

At the end, you should have a list of combinations of Z, Y_0, Y_1, Y_2, and Y_3 that work for these equations. We can eliminate most (typically all) of the rest of the incorrect guesses by inserting a new fault in byte 0 with the same plaintext. This will leave Y_0, Y_1, Y_2, Y_3, and O_0 the same, but changes Z and O_0. Using this fact, we can insert a new fault and narrow down our set of Y_0, Y_1, Y_2, and Y_3. In fact, we only need that one additional fault to get down to a single guess!

In [384]:
def generate_glitch(pt, cipher):
    state = list(pt)
    state = state+[16-len(state)]*(16-len(state))
    cipher._add_round_key(state, 0)
    for i in range(1, 9):
        cipher._sub_bytes(state)
        cipher._shift_rows(state)
        cipher._mix_columns(state, False)
        cipher._add_round_key(state, i)
    cipher._sub_bytes(state)
    cipher._shift_rows(state)

    x = list(state)
    cipher._mix_columns(x, False)
    cipher._add_round_key(x, 9)
    cipher._sub_bytes(x)
    cipher._shift_rows(x)
    cipher._add_round_key(x, 10)

    import random
    random.seed()
    fault = random.getrandbits(8)
    state[0] = fault

    cipher._mix_columns(state, False)
    cipher._add_round_key(state, 9)
    cipher._sub_bytes(state)
    cipher._shift_rows(state)
    cipher._add_round_key(state, 10)
    return state, x

from tqdm.notebook import trange
def get_Y_guesses(state, x):    
    def check_Y(Z, Yn, n):
        lookup = [0, 7, 10, 13]
        lhs = state[lookup[n]] ^ x[lookup[n]]
        coeff = [2, 3, 1, 1]
        rhs = aes_tables.sbox[Yn] ^ aes_tables.sbox[((Z*coeff[n])&0xFF) ^ Yn]
        return lhs == rhs
    guesses = []
    
    for Z in trange(255):
        for Y0 in range(255):
            if check_Y(Z, Y0, 0):
                for Y1 in range(255):
                    if check_Y(Z, Y1, 1):
                        for Y2 in range(255):
                            if check_Y(Z, Y2, 2):
                                for Y3 in range(255):
                                    if check_Y(Z, Y3, 3):
                                        guesses.append((Y0, Y1, Y2, Y3))
    return guesses
    
def update_Y_guesses(Y_old, Y_new):
    updated_Y = []
    for Ys in Y_old:
        if Ys in Y_new:
            updated_Y.append(Ys)
    return updated_Y

def Y_to_key(x, Y):
    return aes_tables.sbox[Y] ^ x

In [385]:
pt = ktp.next()[1]
state, x = generate_glitch(pt, cipher)
print(bytearray(state), bytearray(x))

CWbytearray(b'09 bc a9 cb 2f 1a 9f 19 1d 86 bd b9 e1 b9 34 ba') CWbytearray(b'10 bc a9 cb 2f 1a 9f 0b 1d 86 fd b9 e1 a6 34 ba')


In [386]:
Y_guesses = get_Y_guesses(state, x)
print(Y_guesses)

HBox(children=(FloatProgress(value=0.0, max=255.0), HTML(value='')))


[(35, 32, 41, 129), (35, 32, 41, 130), (35, 32, 42, 129), (35, 32, 42, 130), (35, 41, 41, 129), (35, 41, 41, 130), (35, 41, 42, 129), (35, 41, 42, 130), (37, 32, 41, 129), (37, 32, 41, 130), (37, 32, 42, 129), (37, 32, 42, 130), (37, 41, 41, 129), (37, 41, 41, 130), (37, 41, 42, 129), (37, 41, 42, 130), (155, 8, 128, 215), (155, 8, 128, 227), (155, 8, 180, 215), (155, 8, 180, 227), (155, 148, 128, 215), (155, 148, 128, 227), (155, 148, 180, 215), (155, 148, 180, 227), (243, 8, 128, 215), (243, 8, 128, 227), (243, 8, 180, 215), (243, 8, 180, 227), (243, 148, 128, 215), (243, 148, 128, 227), (243, 148, 180, 215), (243, 148, 180, 227), (50, 107, 6, 22), (50, 107, 6, 94), (50, 107, 78, 22), (50, 107, 78, 94), (50, 179, 6, 22), (50, 179, 6, 94), (50, 179, 78, 22), (50, 179, 78, 94), (162, 107, 6, 22), (162, 107, 6, 94), (162, 107, 78, 22), (162, 107, 78, 94), (162, 179, 6, 22), (162, 179, 6, 94), (162, 179, 78, 22), (162, 179, 78, 94), (7, 25, 3, 172), (7, 25, 3, 230), (7, 25, 73, 172), (7

In [387]:
state, x = generate_glitch(pt, cipher)
print(bytearray(state), bytearray(x))
Y_guesses = update_Y_guesses(Y_guesses, get_Y_guesses(state, x))
print(Y_guesses)
print(bytearray(Y_to_key(x, Y_guesses[0], n)))


CWbytearray(b'36 bc a9 cb 2f 1a 9f 90 1d 86 54 b9 e1 23 34 ba') CWbytearray(b'10 bc a9 cb 2f 1a 9f 0b 1d 86 fd b9 e1 a6 34 ba')


HBox(children=(FloatProgress(value=0.0, max=255.0), HTML(value='')))


[]


IndexError: list index out of range

In [None]:
K0 = 

In [65]:
from chipwhisperer.common.utils.aes_cipher import AESCipher, aes_tables
import chipwhisperer.analyzer as cwa

In [5]:
import chipwhisperer as cw
ktp = cw.ktp.Basic()
key = list(ktp.next()[0])
for i in range(10):
    key.extend(cwa.aes_funcs.key_schedule_rounds(key[0:16], 0, i+1))

In [6]:
key

[43,
 126,
 21,
 22,
 40,
 174,
 210,
 166,
 171,
 247,
 21,
 136,
 9,
 207,
 79,
 60,
 160,
 250,
 254,
 23,
 136,
 84,
 44,
 177,
 35,
 163,
 57,
 57,
 42,
 108,
 118,
 5,
 242,
 194,
 149,
 242,
 122,
 150,
 185,
 67,
 89,
 53,
 128,
 122,
 115,
 89,
 246,
 127,
 61,
 128,
 71,
 125,
 71,
 22,
 254,
 62,
 30,
 35,
 126,
 68,
 109,
 122,
 136,
 59,
 239,
 68,
 165,
 65,
 168,
 82,
 91,
 127,
 182,
 113,
 37,
 59,
 219,
 11,
 173,
 0,
 212,
 209,
 198,
 248,
 124,
 131,
 157,
 135,
 202,
 242,
 184,
 188,
 17,
 249,
 21,
 188,
 109,
 136,
 163,
 122,
 17,
 11,
 62,
 253,
 219,
 249,
 134,
 65,
 202,
 0,
 147,
 253,
 78,
 84,
 247,
 14,
 95,
 95,
 201,
 243,
 132,
 166,
 79,
 178,
 78,
 166,
 220,
 79,
 234,
 210,
 115,
 33,
 181,
 141,
 186,
 210,
 49,
 43,
 245,
 96,
 127,
 141,
 41,
 47,
 172,
 119,
 102,
 243,
 25,
 250,
 220,
 33,
 40,
 209,
 41,
 65,
 87,
 92,
 0,
 110,
 208,
 20,
 249,
 168,
 201,
 238,
 37,
 137,
 225,
 63,
 12,
 200,
 182,
 99,
 12,
 166]

In [7]:
cipher = AESCipher(key)

In [11]:
from Crypto.Cipher import AES
check_cipher = AES.new(ktp.next()[0], AES.MODE_ECB)

In [166]:
pt = ktp.next()[1]

In [167]:
ct1 = cipher.cipher_block(list(pt))
ct2 = check_cipher.encrypt(pt)

In [168]:
bytearray(ct1)

CWbytearray(b'17 2a 21 a7 f6 7a ca c3 0e 64 35 08 23 f8 6f bb')

In [169]:
bytearray(ct2)

CWbytearray(b'17 2a 21 a7 f6 7a ca c3 0e 64 35 08 23 f8 6f bb')

In [170]:
state = list(pt)
state = state+[16-len(state)]*(16-len(state))

In [171]:
cipher._add_round_key(state, 0)
for i in range(1, 9):
    cipher._sub_bytes(state)
    cipher._shift_rows(state)
    cipher._mix_columns(state, False)
    cipher._add_round_key(state, i)
cipher._sub_bytes(state)
cipher._shift_rows(state)

In [172]:
bytearray(state)

CWbytearray(b'53 7c c8 77 a3 96 b2 72 1d 45 83 3f 74 bc e5 28')

In [173]:
x = list(state)
cipher._mix_columns(x, False)
cipher._add_round_key(x, 9)
cipher._sub_bytes(x)
cipher._shift_rows(x)
cipher._add_round_key(x, 10)

In [174]:
bytearray(x)

CWbytearray(b'17 2a 21 a7 f6 7a ca c3 0e 64 35 08 23 f8 6f bb')

In [175]:
bytearray(state)

CWbytearray(b'53 7c c8 77 a3 96 b2 72 1d 45 83 3f 74 bc e5 28')

In [176]:
import random
random.seed()
fault = random.getrandbits(8)

In [177]:
state[0] = fault

In [178]:
cipher._mix_columns(state, False)
cipher._add_round_key(state, 9)
cipher._sub_bytes(state)
cipher._shift_rows(state)
cipher._add_round_key(state, 10)
bytearray(state)

CWbytearray(b'06 2a 21 a7 f6 7a ca 01 0e 64 e5 08 23 09 6f bb')

In [179]:
print(bytearray(state))
print(bytearray(x))

CWbytearray(b'06 2a 21 a7 f6 7a ca 01 0e 64 e5 08 23 09 6f bb')
CWbytearray(b'17 2a 21 a7 f6 7a ca c3 0e 64 35 08 23 f8 6f bb')


In [181]:
def gmul(a, b):
    p = 0
    while a and b:
        if b & 1:
            p ^= a
        if (a & 0x80):
            a = (a << 1) ^ 0x11b;
        else:
            a <<= 1
        b >>= 1
    return p
def check_Y(Z, Yn, n):
    lookup = [0, 7, 10, 13]
    lhs = state[lookup[n]] ^ x[lookup[n]]
    coeff = [2, 3, 1, 1]
    rhs = aes_tables.sbox[Yn] ^ aes_tables.sbox[gmul(Z, coeff[n]) ^ Yn]
    return lhs == rhs

guesses = []
from tqdm.notebook import trange
for Z in trange(255):
    for Y0 in range(255):
        if check_Y(Z, Y0, 0):
            for Y1 in range(255):
                if check_Y(Z, Y1, 1):
                    for Y2 in range(255):
                        if check_Y(Z, Y2, 2):
                            for Y3 in range(255):
                                if check_Y(Z, Y3, 3):
                                    guesses.append((Z, Y0, Y1, Y2, Y3))
print(guesses)                            

HBox(children=(FloatProgress(value=0.0, max=255.0), HTML(value='')))


[(4, 208, 4, 26, 248), (4, 208, 4, 26, 252), (4, 208, 4, 30, 248), (4, 208, 4, 30, 252), (4, 208, 8, 26, 248), (4, 208, 8, 26, 252), (4, 208, 8, 30, 248), (4, 208, 8, 30, 252), (4, 216, 4, 26, 248), (4, 216, 4, 26, 252), (4, 216, 4, 30, 248), (4, 216, 4, 30, 252), (4, 216, 8, 26, 248), (4, 216, 8, 26, 252), (4, 216, 8, 30, 248), (4, 216, 8, 30, 252), (15, 0, 142, 225, 212), (15, 0, 142, 225, 219), (15, 0, 142, 238, 212), (15, 0, 142, 238, 219), (15, 0, 159, 225, 212), (15, 0, 159, 225, 219), (15, 0, 159, 238, 212), (15, 0, 159, 238, 219), (15, 7, 142, 225, 212), (15, 7, 142, 225, 219), (15, 7, 142, 238, 212), (15, 7, 142, 238, 219), (15, 7, 159, 225, 212), (15, 7, 159, 225, 219), (15, 7, 159, 238, 212), (15, 7, 159, 238, 219), (15, 25, 142, 225, 212), (15, 25, 142, 225, 219), (15, 25, 142, 238, 212), (15, 25, 142, 238, 219), (15, 25, 159, 225, 212), (15, 25, 159, 225, 219), (15, 25, 159, 238, 212), (15, 25, 159, 238, 219), (15, 30, 142, 225, 212), (15, 30, 142, 225, 219), (15, 30, 142

In [182]:
K0s = set([])
K7s = set([])
K10s = set([])
K13s = set([])
for Z, Y0, Y1, Y2, Y3 in guesses:
    K0s.add(state[0] ^ aes_tables.sbox[Y0])
    K7s.add(state[7] ^ aes_tables.sbox[Y1])
    K10s.add(state[10] ^ aes_tables.sbox[Y2])
    K13s.add(state[13] ^ aes_tables.sbox[Y3])

In [183]:
K0s

{13,
 15,
 28,
 30,
 45,
 47,
 60,
 62,
 69,
 71,
 84,
 86,
 101,
 103,
 116,
 118,
 137,
 139,
 152,
 154,
 169,
 171,
 184,
 186,
 193,
 195,
 208,
 210,
 225,
 227,
 240,
 242}

In [198]:
pt2 = ktp.next()[1]
state2 = list(pt2)
state2 = state2+[16-len(state2)]*(16-len(state2))

cipher._add_round_key(state2, 0)
for i in range(1, 9):
    cipher._sub_bytes(state2)
    cipher._shift_rows(state2)
    cipher._mix_columns(state2, False)
    cipher._add_round_key(state2, i)
cipher._sub_bytes(state2)
cipher._shift_rows(state2)

x2 = list(state2)
cipher._mix_columns(x2, False)
cipher._add_round_key(x2, 9)
cipher._sub_bytes(x2)
cipher._shift_rows(x2)
cipher._add_round_key(x2, 10)

import random
random.seed()
fault = random.getrandbits(8)
state2[0] = fault

cipher._mix_columns(state2, False)
cipher._add_round_key(state2, 9)
cipher._sub_bytes(state2)
cipher._shift_rows(state2)
cipher._add_round_key(state2, 10)

print(bytearray(state2))
print(bytearray(x2))

CWbytearray(b'80 cd 7a b8 48 ed ab 26 73 f3 b3 20 54 11 28 fc')
CWbytearray(b'01 cd 7a b8 48 ed ab 5c 73 f3 89 20 54 3e 28 fc')


In [199]:
def check_Y2(Z, Yn, n):
    lookup = [0, 7, 10, 13]
    lhs = state2[lookup[n]] ^ x2[lookup[n]]
    coeff = [2, 3, 1, 1]
    rhs = aes_tables.sbox[Yn] ^ aes_tables.sbox[gmul(Z, coeff[n]) ^ Yn]
    return lhs == rhs

guesses2 = []
from tqdm.notebook import trange
for Z in trange(255):
    for Y0 in range(255):
        if check_Y2(Z, Y0, 0):
            for Y1 in range(255):
                if check_Y2(Z, Y1, 1):
                    for Y2 in range(255):
                        if check_Y2(Z, Y2, 2):
                            for Y3 in range(255):
                                if check_Y2(Z, Y3, 3):
                                    guesses2.append((Z, Y0, Y1, Y2, Y3))

HBox(children=(FloatProgress(value=0.0, max=255.0), HTML(value='')))




In [200]:
guesses2

[(27, 155, 150, 111, 193),
 (27, 155, 150, 111, 218),
 (27, 155, 150, 116, 193),
 (27, 155, 150, 116, 218),
 (27, 155, 187, 111, 193),
 (27, 155, 187, 111, 218),
 (27, 155, 187, 116, 193),
 (27, 155, 187, 116, 218),
 (27, 173, 150, 111, 193),
 (27, 173, 150, 111, 218),
 (27, 173, 150, 116, 193),
 (27, 173, 150, 116, 218),
 (27, 173, 187, 111, 193),
 (27, 173, 187, 111, 218),
 (27, 173, 187, 116, 193),
 (27, 173, 187, 116, 218),
 (28, 21, 88, 106, 78),
 (28, 21, 88, 106, 82),
 (28, 21, 88, 118, 78),
 (28, 21, 88, 118, 82),
 (28, 21, 124, 106, 78),
 (28, 21, 124, 106, 82),
 (28, 21, 124, 118, 78),
 (28, 21, 124, 118, 82),
 (28, 45, 88, 106, 78),
 (28, 45, 88, 106, 82),
 (28, 45, 88, 118, 78),
 (28, 45, 88, 118, 82),
 (28, 45, 124, 106, 78),
 (28, 45, 124, 106, 82),
 (28, 45, 124, 118, 78),
 (28, 45, 124, 118, 82),
 (50, 60, 147, 156, 209),
 (50, 60, 147, 156, 227),
 (50, 60, 147, 174, 209),
 (50, 60, 147, 174, 227),
 (50, 60, 197, 156, 209),
 (50, 60, 197, 156, 227),
 (50, 60, 197, 174, 

In [201]:
K0s2 = set([])
K7s2 = set([])
K10s2 = set([])
K13s2 = set([])
for Z, Y0, Y1, Y2, Y3 in guesses2:
    K0s2.add(state2[0] ^ aes_tables.sbox[Y0])
    K7s2.add(state2[7] ^ aes_tables.sbox[Y1])
    K10s2.add(state2[10] ^ aes_tables.sbox[Y2])
    K13s2.add(state2[13] ^ aes_tables.sbox[Y3])

In [202]:
K0 = []
for K in K0s:
    if K in K0s2:
        K0.append(K)

In [203]:
bytearray(K0)

CWbytearray(b'98 1c ab 2f d0 54 67')

In [204]:
bytearray(key)

CWbytearray(b'2b 7e 15 16 28 ae d2 a6 ab f7 15 88 09 cf 4f 3c a0 fa fe 17 88 54 2c b1 23 a3 39 39 2a 6c 76 05 f2 c2 95 f2 7a 96 b9 43 59 35 80 7a 73 59 f6 7f 3d 80 47 7d 47 16 fe 3e 1e 23 7e 44 6d 7a 88 3b ef 44 a5 41 a8 52 5b 7f b6 71 25 3b db 0b ad 00 d4 d1 c6 f8 7c 83 9d 87 ca f2 b8 bc 11 f9 15 bc 6d 88 a3 7a 11 0b 3e fd db f9 86 41 ca 00 93 fd 4e 54 f7 0e 5f 5f c9 f3 84 a6 4f b2 4e a6 dc 4f ea d2 73 21 b5 8d ba d2 31 2b f5 60 7f 8d 29 2f ac 77 66 f3 19 fa dc 21 28 d1 29 41 57 5c 00 6e d0 14 f9 a8 c9 ee 25 89 e1 3f 0c c8 b6 63 0c a6')

In [219]:
pt2 = ktp.next()[1]
state2 = list(pt2)
state2 = state2+[16-len(state2)]*(16-len(state2))

cipher._add_round_key(state2, 0)
for i in range(1, 9):
    cipher._sub_bytes(state2)
    cipher._shift_rows(state2)
    cipher._mix_columns(state2, False)
    cipher._add_round_key(state2, i)
cipher._sub_bytes(state2)
cipher._shift_rows(state2)

x2 = list(state2)
cipher._mix_columns(x2, False)
cipher._add_round_key(x2, 9)
cipher._sub_bytes(x2)
cipher._shift_rows(x2)
cipher._add_round_key(x2, 10)

import random
random.seed()
fault = random.getrandbits(8)
state2[0] = fault

cipher._mix_columns(state2, False)
cipher._add_round_key(state2, 9)
cipher._sub_bytes(state2)
cipher._shift_rows(state2)
cipher._add_round_key(state2, 10)

print(bytearray(state2))
print(bytearray(x2))

CWbytearray(b'c6 69 0f 4c eb 04 a8 ec 41 17 5c 26 02 94 70 bd')
CWbytearray(b'8c 69 0f 4c eb 04 a8 61 41 17 05 26 02 04 70 bd')


In [315]:
pt = ktp.next()[1]
state, x = generate_glitch(pt, cipher)
print(bytearray(state), bytearray(x))

CWbytearray(b'6c 11 86 14 7c cf c5 df c5 48 8a 22 1e 20 24 00') CWbytearray(b'23 11 86 14 7c cf c5 ec c5 48 92 22 1e e6 24 00')


In [316]:
key_guesses = get_key_guesses(state, x)
print(key_guesses)

HBox(children=(FloatProgress(value=0.0, max=255.0), HTML(value='')))


[{4, 134, 141, 148, 22, 29, 159, 37, 167, 172, 46, 181, 55, 60, 190, 194, 201, 75, 208, 82, 89, 219, 97, 227, 232, 106, 241, 115, 120, 250}, {136, 137, 142, 18, 19, 20, 21, 32, 33, 38, 39, 186, 187, 189, 194, 195, 196, 197, 88, 89, 94, 95, 106, 107, 108, 109, 240, 241, 246, 247}, {138, 11, 12, 141, 146, 19, 20, 149, 168, 41, 46, 175, 176, 49, 54, 183, 74, 203, 204, 77, 82, 211, 212, 85, 104, 233, 238, 111, 112, 241, 246, 119}, {5, 14, 142, 16, 144, 27, 155, 37, 165, 174, 46, 176, 48, 187, 59, 195, 200, 72, 214, 86, 221, 93, 99, 227, 104, 232, 118, 246, 253, 125}]


In [320]:
state, x = generate_glitch(pt, cipher)
print(bytearray(state), bytearray(x))
key_guesses = update_keys(key_guesses, get_key_guesses(state, x))
print(key_guesses)

CWbytearray(b'b7 11 86 14 7c cf c5 00 c5 48 d5 22 1e 2f 24 00') CWbytearray(b'23 11 86 14 7c cf c5 ec c5 48 92 22 1e e6 24 00')


HBox(children=(FloatProgress(value=0.0, max=255.0), HTML(value='')))


[[208], [137], [12], [99]]


In [None]:
K0 = []
for K in K0s:
    if K in K0s:
        K0.append(K)                          