## Introduction

This notebook is a tutorial on the Differential Fault Attack (DFA) on AES-128, with faults injected just before the MixColumns step of the 9th round. For the original paper, see [Piret and Quisquater, CHES 2003](https://link.springer.com/chapter/10.1007/978-3-540-45238-6_7).

## Faulted AES

In [165]:
from random import randint


# AES substitution table: SBOX[v] is the substituted value of byte v.
# It is read by F_AES128._sub_bytes (SubBytes) and by the SubWord step of
# F_AES128.key_schedule.
SBOX = (
    0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
    0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
    0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
    0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
    0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
    0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
    0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
    0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
    0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
    0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
    0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
    0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
    0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
    0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
    0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
    0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16,
)

# Round constants for the key schedule: RCON[ri] is XORed into the first
# byte of the first word of round key ri. F_AES128.key_schedule only reads
# entries 1..10; entry 0 is never used there.
RCON = (
    0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40,
    0x80, 0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D, 0x9A,
    0x2F, 0x5E, 0xBC, 0x63, 0xC6, 0x97, 0x35, 0x6A,
    0xD4, 0xB3, 0x7D, 0xFA, 0xEF, 0xC5, 0x91, 0x39,
)

def xtime(a):
    '''
    Multiply the byte a by 2 (i.e. by the polynomial x) in the AES field.

    :param a: byte value (0-255)
    :return : a shifted left by one bit; when bit 7 of a was set, the
              result is reduced by XORing with 0x1B and truncated to 8 bits
    '''
    if a & 0x80:
        return ((a << 1) ^ 0x1B) & 0xFF
    return a << 1

class F_AES128:
    '''
    AES-128 encryption with an optional single-byte fault in round 9.

    When fault_position is given, every call to encrypt() XORs a random
    non-zero byte into that state position right after the ShiftRows of
    round 9, i.e. just before the last MixColumns. With fault_position=None
    the class behaves as a plain AES-128 encryptor.

    The state is a flat list of 16 integers stored column by column:
    indices 0-3 form column 0, indices 4-7 column 1, and so on.
    '''

    # _SHIFT_ROWS_SOURCE[i] is the index of the input byte that ShiftRows
    # moves to position i (row r of the state is rotated left by r columns).
    _SHIFT_ROWS_SOURCE = (0, 5, 10, 15, 4, 9, 14, 3, 8, 13, 2, 7, 12, 1, 6, 11)

    def __init__(self, key: bytes, fault_position=None):
        '''
        :param key: 16-byte AES-128 key
        :param fault_position: index of the state byte to corrupt in round 9,
                               or None for fault-free encryption
        :raises ValueError: if the key is not exactly 16 bytes long
        '''
        # The key schedule below is only valid for 16-byte keys; any other
        # length would either crash or silently yield wrong round keys.
        if len(key) != 16:
            raise ValueError(f"AES-128 needs a 16-byte key, got {len(key)} bytes")

        self.fault_location = fault_position
        self.Nr = 10
        self.Nk = 4
        self.round_keys = self.key_schedule(list(key))

    def key_schedule(self, key) -> list:
        '''
        Expand the cipher key into the round keys of rounds 0..Nr.

        :param key: the 16 key bytes
        :return : flat list of (Nr + 1) * 16 bytes; round key r occupies
                  indices 16*r .. 16*r + 15
        '''
        # Round 0: the cipher key itself
        round_keys = list(key)

        # Round ri: 1-10
        for ri in range(1, self.Nr + 1):
            prev = round_keys[-16:]     # previous round key
            last = prev[12:]            # its last word (column 3)

            # RotWord followed by SubWord on the last word, then Rcon
            word = [SBOX[last[1]], SBOX[last[2]], SBOX[last[3]], SBOX[last[0]]]
            word[0] ^= RCON[ri]

            # Each new word is the word one round key earlier XOR the word
            # produced just before it (the transformed word for column 0).
            for col in range(4):
                old_word = prev[4 * col: 4 * col + 4]
                word = [w ^ o for (w, o) in zip(word, old_word)]
                round_keys.extend(word)

        return round_keys

    def encrypt(self, plaintext: bytes) -> list:
        '''
        Encrypt one block, injecting the fault if one is configured.

        :param plaintext: 16 plaintext bytes
        :return : the ciphertext as a list of 16 integers
                  (wrap it in bytes() if a bytes object is needed)
        '''
        assert len(plaintext) == 16

        # Round 0
        state = self._add_round_key(plaintext, self.round_keys[:16])

        # Round 1-9
        for rn in range(1, 10):
            state = self._sub_bytes(state)
            state = self._shift_rows(state)
            # /!\ FAULT INJECTION /!\
            # Only in round 9, between ShiftRows and MixColumns. The mask is
            # any non-zero byte, so the faulted byte always changes.
            if rn == 9 and self.fault_location is not None:
                state[self.fault_location] ^= randint(1, 255)
            state = self._mix_columns(state)
            state = self._add_round_key(state, self.round_keys[rn*16 : (rn+1)*16])

        # Round 10 (no MixColumns)
        state = self._sub_bytes(state)
        state = self._shift_rows(state)
        state = self._add_round_key(state, self.round_keys[10*16:])

        return state

    def _add_round_key(self, state: list, round_key: list) -> list:
        '''Return the byte-wise XOR of the state and the round key.'''
        return [v ^ k for (v, k) in zip(state, round_key)]

    def _sub_bytes(self, state: list) -> list:
        '''Return the state with every byte replaced through SBOX.'''
        return [SBOX[v] for v in state]

    def _shift_rows(self, state: list) -> list:
        '''Return the state with row r rotated left by r columns.'''
        return [state[src] for src in self._SHIFT_ROWS_SOURCE]

    def _mix_columns(self, state: list) -> list:
        '''Return the state with MixColumns applied to each of its 4 columns.'''
        new_state = []
        for start in range(0, 16, 4):
            new_state += self._mix_single_column(state[start:start + 4])
        return new_state

    def _mix_single_column(self, column: list) -> list:
        '''
        Apply MixColumns to one column.

        :param column: the 4 bytes of a column; it is left unmodified
        :return : new list with the 4 mixed bytes
        '''
        a0, a1, a2, a3 = column
        t = a0 ^ a1 ^ a2 ^ a3
        return [
            a0 ^ t ^ xtime(a0 ^ a1),
            a1 ^ t ^ xtime(a1 ^ a2),
            a2 ^ t ^ xtime(a2 ^ a3),
            a3 ^ t ^ xtime(a3 ^ a0),
        ]

def test_f_aes128():
    '''
    Check the key schedule and the encryption against a known example, and
    check that a faulted encryption yields a different ciphertext.

    Test vectors are taken from https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.197-upd1.pdf
    '''

    key = bytes.fromhex("2b7e151628aed2a6abf7158809cf4f3c")
    m = bytes.fromhex("3243f6a8885a308d313198a2e0370734")
    last_round_key = bytes.fromhex("d014f9a8c9ee2589e13f0cc8b6630ca6")
    ciphertext = bytes.fromhex("3925841d02dc09fbdc118597196a0b32")

    ciphC = F_AES128(key)
    # Fault on state byte 1 (an explicit index instead of the bool True,
    # which Python would silently use as index 1).
    ciphF = F_AES128(key, fault_position=1)

    # Key Schedule (the fault must not influence it)
    assert bytes(ciphC.round_keys[-16:]) == last_round_key
    assert bytes(ciphF.round_keys[-16:]) == last_round_key

    # Encryption
    cc = ciphC.encrypt(m)
    fc = ciphF.encrypt(m)
    assert bytes(cc) == ciphertext
    assert bytes(fc) != ciphertext

test_f_aes128()

## Function for displaying differences

In [166]:
def print_diff(c0: bytes, c1: bytes):
    '''
    Print two equally long byte sequences and their XOR as an aligned table.

    :param c0: first byte sequence
    :param c1: second byte sequence, same length as c0

    The output has five lines: the byte indices, a separator, c0 in hex,
    c1 in hex, and the byte-wise difference c0 ^ c1 in hex.
    '''
    assert len(c0) == len(c1)
    width = len(c0)

    def render(cells, spec):
        # Every cell is 2 characters wide and followed by one space.
        return "".join(f"{cell:{spec}} " for cell in cells)

    print(render(range(width), "2d"))
    print("---" * width)
    print(render(c0, "2X"))
    print(render(c1, "2X"))
    print(render((a ^ b for (a, b) in zip(c0, c1)), "2X"))

def test_print_diff():
    '''
    Encrypt one block with and without a fault and display both ciphertexts
    together with their difference.
    '''
    key = bytes.fromhex("2b7e151628aed2a6abf7158809cf4f3c")
    m = bytes.fromhex("3243f6a8885a308d313198a2e0370734")
    last_round_key = bytes.fromhex("d014f9a8c9ee2589e13f0cc8b6630ca6")

    ciphC = F_AES128(key)
    # Fault on state byte 1 (an explicit index instead of the bool True,
    # which Python would silently use as index 1).
    ciphF = F_AES128(key, fault_position=1)

    # Key Schedule
    assert bytes(ciphC.round_keys[-16:]) == last_round_key
    assert bytes(ciphF.round_keys[-16:]) == last_round_key

    # Encryption
    cc = ciphC.encrypt(m)
    fc = ciphF.encrypt(m)

    print_diff(cc, fc)

test_print_diff()

 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 
------------------------------------------------
39 25 84 1D  2 DC  9 FB DC 11 85 97 19 6A  B 32 
44 25 84 1D  2 DC  9 2B DC 11 4E 97 19 31  B 32 
7D  0  0  0  0  0  0 D0  0  0 CB  0  0 5B  0  0 


## Function for computing backwards

Given 4 bytes of a key guess and the 4 matching bytes of a ciphertext, this function computes the corresponding 4 bytes (which form a column) of the state at the beginning of the last round.

In [167]:
# Inverse substitution table: INV_SBOX[SBOX[v]] == v for every byte v.
# compute_backwards uses it to undo the SubBytes of the last round.
INV_SBOX = [0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb,
            0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
            0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e,
            0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25,
            0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92,
            0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84,
            0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06,
            0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b,
            0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73,
            0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e,
            0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b,
            0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4,
            0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f,
            0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
            0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
            0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d]


def compute_backwards(c_qu: list, k_qu: list) -> list:
    '''
    Undo the last AddRoundKey and SubBytes on one quadruplet.

    :param c_qu: quadruplet of ciphertext bytes
    :param k_qu: quadruplet of last-round-key bytes at the same positions
    :return qu : quadruplet (column) of the state before the last SubBytes
    '''
    assert len(c_qu) == 4
    assert len(k_qu) == 4

    # XOR with the key removes AddRoundKey; INV_SBOX then removes SubBytes.
    return [INV_SBOX[c_byte ^ k_byte] for (c_byte, k_byte) in zip(c_qu, k_qu)]

def test_compute_backwards():
    '''
    Check compute_backwards on a known correct quadruplet, then display the
    difference between a correct and a faulty quadruplet before the last
    SubBytes.
    '''
    round_key_qu = [0xd0, 0x63, 0x0c, 0x89]
    expected_correct = [0xeb, 0x40, 0xf2, 0x1e]

    correct_before = compute_backwards([0x39, 0x6a, 0x85, 0xfb], round_key_qu)
    assert correct_before == expected_correct

    faulty_before = compute_backwards([0xce, 0xe0, 0x01, 0x2d], round_key_qu)
    print_diff(correct_before, faulty_before)

test_compute_backwards()

 0  1  2  3 
------------
EB 40 F2 1E 
E9 41 F3 1D 
 2  1  1  3 


## Key recovery

In [168]:
def is_diff_related(diff: list) -> bool:
    '''
    Check if the differences correspond to the relation of MixColumns.

    :param diff: 4 byte differences of one column
    :return    : True if, for some byte e, diff equals one of
                 (2e, e, e, 3e), (3e, 2e, e, e), (e, 3e, 2e, e), (e, e, 3e, 2e)
                 where 2e = xtime(e) and 3e = xtime(e) ^ e

    There are 4 possible positions of the fault inside the column, one per
    pattern above. If diff matches one of them, then return True.
    '''
    for shift in range(4):
        # e is the candidate fault value for this fault position.
        e = diff[(shift + 1) % 4]
        if (diff[(shift + 2) % 4] == e
                and diff[shift] == xtime(e)
                and diff[(shift + 3) % 4] == xtime(e) ^ e):
            return True
    return False
    
def get_quadruplet(cc, fc):
    '''
    Extract the quadruplet touched by the fault from a ciphertext pair.

    :param cc: correct ciphertext
    :param fc: faulty ciphertext
    :return qu_no: quadruplet number (0, 1, 2, 3)
    :return c_qu : correct quadruplet extracted from correct ciphertext
    :return f_qu : faulty quadruplet extracted from faulty ciphertext

    The positions where the two ciphertexts differ identify the quadruplet
    (Q0, Q1, Q2, Q3). Exactly 4 bytes must differ, and they must be one of
    the 4 known position sets, otherwise ValueError is raised.
    '''
    # Sorted differing positions -> (quadruplet number, extraction order).
    layouts = {
        (0, 7, 10, 13): (0, (0, 13, 10, 7)),
        (1, 4, 11, 14): (1, (4, 1, 14, 11)),
        (2, 5, 8, 15): (2, (8, 5, 2, 15)),
        (3, 6, 9, 12): (3, (12, 9, 6, 3)),
    }

    differing = tuple(pos for pos in range(16) if cc[pos] ^ fc[pos] != 0)
    assert len(differing) == 4

    if differing not in layouts:
        raise ValueError

    qu_no, order = layouts[differing]
    cqu = [cc[pos] for pos in order]
    fqu = [fc[pos] for pos in order]
    return qu_no, cqu, fqu
    

In [169]:
def recover_key(qu_no: int, cqus: list, fqus: list, known_key: list = None):
    '''
    Search the 2 unknown bytes of one last-round-key quadruplet.

    :param qu_no: quadruplet number (0, 1, 2, 3) corresponds to the column number
                  of the state at the beginning of the last round
    :param cqus : correct quadruplets
    :param fqus : faulty quadruplets, paired one-to-one with cqus
    :param known_key: for each quadruplet number, the last 2 key bytes of the
                      quadruplet (assumed known); defaults to the values of the
                      key used throughout this tutorial
    :return qukey_candidates: list of all key quadruplets [g0, g1, k2, k3]
                              that are consistent with every pair
    :raises ValueError: if no pair is given or the two lists differ in length

    A guess is kept when, for every pair, the difference between the correct
    and the faulty quadruplet computed backwards with that guess satisfies
    is_diff_related. Each kept guess is printed together with the difference
    obtained on the last pair.
    '''
    # Without at least one pair there is nothing to test a guess against, and
    # pairs dropped silently by zip() would weaken the filter unnoticed.
    if not cqus or len(cqus) != len(fqus):
        raise ValueError(
            "recover_key needs at least one correct quadruplet and exactly "
            "one faulty quadruplet per correct quadruplet"
        )

    if known_key is None:
        # For simplicity, we assume that the last 2 key bytes of each column are known
        # Indices:      10     7      14    11       2    15       6     3
        known_key = [[0x0c, 0x89], [0x0c, 0xc8], [0xf9, 0xa6], [0x25, 0xa8]]
    known_tail = list(known_key[qu_no])

    pairs = list(zip(cqus, fqus))
    qukey_candidates = []

    # g0, g1: guess for 2 remaining bytes in the quadruplet
    for g0 in range(256):
        for g1 in range(256):
            qukey_guess = [g0, g1] + known_tail

            for cqu, fqu in pairs:
                # Corresponding quadruplets at the beginning of the last round
                cqu_bw = compute_backwards(cqu, qukey_guess)
                fqu_bw = compute_backwards(fqu, qukey_guess)

                diff = [rv ^ fv for rv, fv in zip(cqu_bw, fqu_bw)]
                if not is_diff_related(diff):
                    break
            else:
                # No pair rejected the guess. cqu_bw / fqu_bw hold the values
                # of the last pair, which is the one displayed.
                qukey_candidates.append(qukey_guess)
                print(f"Good : {qukey_guess}")
                print("Diff : ")
                print_diff(cqu_bw, fqu_bw)
                print()

    return qukey_candidates


In [172]:
def randomize_encryption(key: bytes, N: int, fpos: int):
    '''
    Encrypt N random plaintexts, each once correctly and once with a fault.

    :param key: 16-byte AES-128 key used for all encryptions
    :param   N: number of encryptions
    :param fpos: fault position in [0, 15]
    :return ccpts: list of correct ciphertexts
    :return fcpts: list of faulty ciphertexts (fcpts[i] pairs with ccpts[i])
    '''

    ciphC = F_AES128(key)
    ciphF = F_AES128(key, fault_position=fpos)

    # Every plaintext byte is drawn from the full byte range 0..255.
    plts = [bytes(randint(0, 255) for _ in range(16)) for _ in range(N)]
    ccpts = [ciphC.encrypt(m) for m in plts]
    fcpts = [ciphF.encrypt(m) for m in plts]

    return ccpts, fcpts

In [171]:
def test_key_recovery():
    '''
    This test fixes the fault position to be the first byte of the state.
    However, the attack does not require the fault position to be fixed.
    It just requires that there are 2 pairs of correct and faulty ciphertexts
    such that the faults are injected in the same quadruplet.
    '''

    key = bytes.fromhex("2b7e151628aed2a6abf7158809cf4f3c")
    qukey_ref = [0xd0, 0x63, 0x0c, 0x89]

    fault_position = 0
    # The state is stored column by column, 4 bytes per column.
    qu_no = fault_position // 4

    ccpts, fcpts = randomize_encryption(key, N=2, fpos=fault_position)
    cqus, fqus = [], []

    for cc, fc in zip(ccpts, fcpts):
        _qu_no, cqu, fqu = get_quadruplet(cc, fc)
        assert _qu_no == qu_no
        cqus.append(cqu)
        fqus.append(fqu)

    recovered_candidates = recover_key(qu_no, cqus, fqus)
    # The correct key bytes satisfy the check for every pair, so they are
    # always among the candidates. With only 2 random pairs a wrong guess can
    # occasionally survive as well, so uniqueness is expected but not
    # guaranteed and is therefore not asserted.
    assert qukey_ref in recovered_candidates

test_key_recovery()

Good : [208, 99, 12, 137]
Diff : 
 0  1  2  3 
------------
CE AD F1 90 
C0 AA F6 99 
 E  7  7  9 
