# Hands-on differential fault analysis tutorial

### What is differential fault analysis?

Differential fault analysis (DFA) is a cryptanalysis technique that works against most modern cyphers. The starting point is that you can't bruteforce the cypher to get the key, because there are simply too many possibilities. However, if you can inject a fault (which basically means flipping one bit during the execution of the algorithm), it will propagate, and you might be able to retrieve valuable information and then deduce the secret key.

One could inject such faults with various techniques, such as targeting a chip's memory with a laser or inducing variations in the electric current it receives.

### What are we going to do exactly?

We're going to learn DFA with DES. DES stands for Data Encryption Standard. It was the standard encryption algorithm before AES replaced it. It uses a 56-bit key, which means that nowadays one can bruteforce it fairly easily.
However, 3DES, which consists of applying DES three times, is still in use: its security level is about 2^112.

So here is the plan:
* First we'll learn how DES works
* Then we'll implement it (you need a DES implementation in order to perform a DFA anyway)
* Then we'll attack the key through DFA!

All the implementations will be done in Python, so I recommend you have a minimal understanding of the language.

### Let's go

Let's start now! If you're already excited by the tutorial, don't forget to **star this repo**, thanks! <3

## How DES Works

We'll start with a bit of theory: we first need to understand how DES works.

I said earlier that DES uses a 56-bit key. That is not exactly right: DES takes a 64-bit key, but eight of those bits are just parity bits. They are not used during the computation and are only there to check that the key isn't corrupted.
This brings the effective key size down to 56 bits, meaning a bruteforce attack has a complexity of 2^56. This is a lot, but entirely feasible with today's hardware.

Let's start with a diagram of what DES does:

![title](des.png) 

This might seem overwhelming if you've never done any crypto before, but you will see it's a pretty simple design.
DES belongs to the family of *Feistel cyphers*, which is why it can be drawn as a twisted ladder, as you can see above. One of their main advantages, besides their simplicity, is that the same algorithm is used for encryption and decryption: the key and the algorithm stay the same, and only the order in which the subkeys are applied is reversed.

N.B.: all the permutations described below are hardcoded in the algorithm and are part of the standard. We won't detail how they were built, but they were chosen to strengthen the cypher. If you want to see all the constants somewhere else, I'd recommend Wikipedia (in French): https://fr.wikipedia.org/wiki/Constantes_du_DES

What DES basically does is:

* First, it takes a 64-bit message and applies an initial permutation (IP) to it, which just changes the order of the bits
* Then the main algorithm starts: cut the message into two halves, L0 and R0
* Derive a 48-bit subkey K1 from the 56-bit key
* Apply a function f (we'll see exactly what it does when implementing it) to *encrypt R0 with K1*
* XOR the result with the left half of the message (L0). This becomes the right half for the next round (R1)
* R0 becomes the left half of the new message (L1)
* Repeat this transformation (with subkeys K2, ..., K16) until you reach L16, R16
* Finally, swap the two halves and undo the initial permutation: apply IP^-1 to R16|L16 to obtain the cyphertext

**In short, each round computes $L_{i+1} = R_i$ and $R_{i+1} = L_i \oplus f(K_{i+1}, R_i)$, where the subkey $K_{i+1}$ is derived from the original key K.**
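To see why the same ladder can both encrypt and decrypt, here is a minimal sketch of a generic Feistel network. It is not DES: the round function `toyF` and the subkeys are made up for illustration. Only the structure matches DES, meaning the round equations above plus the final swap of the halves.

```python
# Toy Feistel network: same "ladder" as DES, but with a made-up round function (not DES's f)
MASK32 = 0xFFFFFFFF

def toyF(R, K):
    # hypothetical round function; a Feistel network never needs to invert it
    return ((R * 0x9E3779B1) ^ K) & MASK32

def feistel(L, R, subKeys):
    for K in subKeys:
        # L(i+1) = R(i), R(i+1) = L(i) xor f(K(i+1), R(i))
        L, R = R, L ^ toyF(R, K)
    return R, L  # final swap of the two halves, as DES does before IP^-1

subKeys = [0x1234, 0xBEEF, 0xCAFE, 0x0F0F]  # arbitrary example subkeys
L0, R0 = 0x01234567, 0x89ABCDEF

cypher = feistel(L0, R0, subKeys)
# Decryption is the very same function, fed the subkeys in reverse order
print(feistel(*cypher, list(reversed(subKeys))) == (L0, R0))  # True
```

Note that `toyF` is only ever evaluated forwards. This is why DES's f can be a non-invertible function (its SBoxes map 6 bits to 4).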

# Implementation

**Before anything else, remember that our DES function will take 64-bit integers as input messages.**

Let's start by implementing small helper functions.
We'll begin with the function that cuts our input into two halves, which we can then process separately.

In [1]:
def cutInHalves(entier):
    # split a 64-bit integer ("entier") into its left (high) and right (low) 32-bit halves
    return (entier >> 32) & 0xFFFFFFFF, entier & 0xFFFFFFFF

Then we need a function to perform permutations for us. A permutation table is basically a list of positions.
The i-th number in the table tells you which bit of the input goes to position i of the output. Positions are counted from 1, starting at the most significant bit.

For example, if I have the table [3, 4, 2, 1, 4, 3] and the input 0b<font color='red'>1</font><font color='blue'>0</font><font color='green'>1</font><font color='orange'>0</font>

Then my output will be: 0b<font color='green'>1</font><font color='orange'>0</font><font color='blue'>0</font><font color='red'>1</font><font color='orange'>0</font><font color='green'>1</font>

As you can see, the output has the same length as the table, some bits may be used more than once, and others may be dropped entirely.

In [2]:
# Permutes the bits of toPermute according to table.
# table[i] is the input position (1-indexed, counted from the most significant bit)
# of the bit that goes to output position i + 1.
def permutation(toPermute, table, inputSize, verbose=False):
    res = 0
    for i in range(len(table)):
        mask = 1 << (inputSize - table[i])  # only the bit that gets moved this round is set
        if verbose is True:
            print(bin(mask))
        # extract that bit (as 0 or 1) and shift it to its position after the permutation
        bitPermuted = bool(toPermute & mask) << (len(table) - i - 1)
        res |= bitPermuted  # remember that 0 | 1 = 1, 001 | 100 = 101, etc.
    return res

print("Number:", bin(0b1010), hex(0b1010), "table:", [3, 4, 2, 1, 4, 3])
print("permutation:", bin(permutation(0b1010, table=[3, 4, 2, 1, 4, 3], inputSize=4, verbose=True)))

Number: 0b1010 0xa table: [3, 4, 2, 1, 4, 3]
0b10
0b1
0b100
0b1000
0b1
0b10
permutation: 0b100101


We'll now define the expansion function. I mentioned earlier that the subkeys are 48 bits long, so we need a way to stretch the 32-bit halves to 48 bits. This is what the expansion function E is for: as you can see in the table below, it duplicates 16 of the 32 input bits.

In [3]:
E = [
    32, 1,  2,  3,  4,  5,
    4,  5,  6,  7,  8,  9,
    8,  9,  10, 11, 12, 13,
    12, 13, 14, 15, 16, 17,
    16, 17, 18, 19, 20, 21,
    20, 21, 22, 23, 24, 25,
    24, 25, 26, 27, 28, 29,
    28, 29, 30, 31, 32, 1
]

def expansion(inputMessage):
    return permutation(inputMessage, E, 32)

### Key Schedule

We will now implement the key schedule. The key schedule is the algorithm that tells us how to derive sixteen 48-bit subkeys Ki from the original 64-bit key K.

The first thing we will need is the leftShift function. After PC1 (described below), the key is a 56-bit value made of two 28-bit halves, C and D, and leftShift rotates each half separately. For example, 0x0000000100000002 splits into C = 0x10 and D = 0x2; rotating each half gives C = 0x20 and D = 0x4, i.e. 0x0000000200000004.

Note that these rotations are circular: the bit that leaves a half on the left comes back on its right. For example, 0x80000000000000 (only the top bit of C set) becomes 0x00000010000000 (only the bottom bit of C set).

In [4]:
# rotate a 28-bit integer one position to the left, circularly
def rotate(X):
    poidsFort = (X >> 27) & 1  # most significant bit of the half, re-injected on the right
    X = X << 1
    X = (X & 0x0FFFFFFF) | poidsFort
    return X

# rotate the two 28-bit halves (C and D) of a 56-bit integer separately
def leftShift(T, verbose=False):
    C = (T & 0xFFFFFFF0000000) >> 28
    D = T & 0x0000000FFFFFFF
    if verbose is True:
        print("before rotation:", hex(C), hex(D))
    C = rotate(C)
    D = rotate(D)
    if verbose is True:
        print("after rotation:", hex(C), hex(D))

    C = C << 28
    return (C | D)

a = 0x0000000100000002
print(hex(leftShift(a, True)))
print(hex(rotate(0x8000000)))

before rotation: 0x10 0x2
after rotation: 0x20 0x4
0x200000004
0x1


The key schedule uses two permutations, PC1 and PC2.

PC1 is used once before the loop, in the initialization step. It shrinks the key from 64 down to 56 bits, removing the parity bits. We'll call T that 56-bit version of K.

PC2 takes T as input and outputs the 48 bits of a Ki.
At each round, both halves of T are rotated either once or twice (which rounds get two shifts is also hardcoded in the standard). These shifts make each Ki draw on a different arrangement of the key bits.

After this function, we obtain a list of the 16 subkeys that we'll be able to use for encryption and decryption.
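One consequence is worth checking now. The 16 rotation amounts used by `keySchedule` add up to 28, the width of each half, so after the last round C and D are back where PC1 left them. The sketch below re-declares the shift list and uses a hypothetical helper `rotate28` (not part of the notebook) to check this on an arbitrary value.

```python
# Number of left rotations per round, as in keySchedule's nbPermutations list
shifts = [1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1]

def rotate28(x, n):
    # circular left rotation of a 28-bit value by n positions (hypothetical helper)
    return ((x << n) | (x >> (28 - n))) & 0x0FFFFFFF

print(sum(shifts))  # 28

C = 0x8A3F01C  # arbitrary 28-bit half, chosen for illustration
half = C
for s in shifts:
    half = rotate28(half, s)
print(half == C)  # True: 28 single-bit rotations make a full turn
```

In other words, the T used to build K16 is exactly the output of PC1. This makes going back from K16 to K much easier.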

In [5]:
PC1 = [
    57, 49, 41, 33, 25, 17, 9,
    1,  58, 50, 42, 34, 26, 18,
    10, 2,  59, 51, 43, 35, 27,
    19, 11, 3,  60, 52, 44, 36,
    63, 55, 47, 39, 31, 23, 15,
    7,  62, 54, 46, 38, 30, 22,
    14, 6,  61, 53, 45, 37, 29,
    21, 13, 5,  28, 20, 12, 4
]

PC2 = [
    14, 17, 11, 24, 1,  5,
    3,  28, 15, 6,  21, 10,
    23, 19, 12, 4,  26, 8,
    16, 7,  27, 20, 13, 2,
    41, 52, 31, 37, 47, 55,
    30, 40, 51, 45, 33, 48,
    44, 49, 39, 56, 34, 53,
    46, 42, 50, 36, 29, 32
]

In [6]:
def keySchedule(K):
    subKeys = []
    
    nbPermutations = [1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1]
    T = permutation(K, PC1, 64)

    for perm in nbPermutations:
        T = leftShift(T) if perm == 1 else leftShift(leftShift(T))
        subKeys.append(permutation(T, PC2, 56))
    return subKeys

## F function

We'll now dive into the F function. First, let's take a look at an overview of this function:

![title](F.png)

The first step is pretty simple: take your 32 bits, expand them to 48 bits, XOR them with Ki and pass the result to the SBoxes. But what is an SBox?

### The SBoxes

We will now introduce the concept of SBox. The SBoxes are the main reason why DES is a good algorithm. The S in SBox stands for substitution: SBoxes are lookup tables meant to introduce non-linearity into the algorithm.
Each one takes a 6-bit input (ranging from 0 to 63) and outputs a 4-bit block (ranging from 0 to 15).

Thus, after the expansion and the XOR with Ki, the message is divided into 8 chunks of 6 bits. Each SBox is fed with one chunk and outputs 4 bits.

Here is how one computes the output of an SBox: take the most and least significant bits of the block. Read together as a 2-bit number (most significant first), they tell you which line (row) your 4-bit output is on. Then take the 4 bits in the middle: they tell you which column to look in.

#### It will be clearer with an example:

Take the first SBox, Sbox0:  
                [14,    4,  13, 1,  2,  15, 11, 8,  3,  10, 6,  12, 5,  9,  0,  7],  
                [0, 15, 7,  4,  14, 2,  13, 1,  10, 6,  12, 11, 9,  5,  3,  8],  
                [4, 1,  14, 8,  13, 6,  2,  11, 15, 12, 9,  7,  3,  10, 5,  0],  
                [15,    12, 8,  2,  4,  9,  1,  7,  5,  11, 3,  14, 10, 0,  6,  13]  

Assume your chunk is 44 or 0b<font color='red'>1</font><font color='blue'>0110</font><font color='red'>0</font>.

Then the line number is: 0b<font color='red'>1</font><font color='red'>0</font>, that is 2  
And the column number is: 0b<font color='blue'>0110</font>, that is 6

Then you can find the output of that sbox for 44:  
[14, 4, 13,  1,   2,  15,<font color='blue'>11</font>,  8,   3, 10,  6, 12, 5,  9,  0,  7],  
[0,   15,   7,  4,  14,   2, <font color='blue'>13</font>,  1,  10,  6, 12, 11, 9,  5,  3,  8],  
<font color='red'>[4, 1,  14, 8,  13, 6,  </font><font color='green'>2</font>, <font color='red'>11, 15, 12, 9,  7,  3,  10, 5,  0]</font>,  
[15,    12, 8,  2,  4,  9,  <font color='blue'>1</font>,  7,  5,  11, 3,  14, 10, 0,  6,  13]  

This means that our block 0b101100 (44) becomes 0b0010 (2) after going through this SBox.
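The line/column rule above can be written as a small lookup function. This is a standalone sketch: `S1` is a copy of the first SBox (`Sbox[0]` in the implementation below), and `sboxLookup` is a hypothetical helper name. The notebook's own F function does the same computation inline.

```python
# First DES SBox (Sbox[0] in the notebook), repeated so this snippet runs on its own
S1 = [
    [14, 4, 13, 1, 2, 15, 11, 8, 3, 10, 6, 12, 5, 9, 0, 7],
    [0, 15, 7, 4, 14, 2, 13, 1, 10, 6, 12, 11, 9, 5, 3, 8],
    [4, 1, 14, 8, 13, 6, 2, 11, 15, 12, 9, 7, 3, 10, 5, 0],
    [15, 12, 8, 2, 4, 9, 1, 7, 5, 11, 3, 14, 10, 0, 6, 13],
]

def sboxLookup(table, block):
    block &= 0b111111                          # keep the 6 input bits only
    row = ((block >> 4) & 0b10) | (block & 1)  # outer bits, most significant first
    column = (block >> 1) & 0b1111             # the 4 middle bits
    return table[row][column]

print(sboxLookup(S1, 44))  # 2, as in the worked example above
```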

### Back to F

Let's recap what happens in F: first, the 32-bit right half R is expanded to 48 bits, then these 48 bits are XORed with the current subkey K.
These bits are then cut into chunks of 6 bits. Each chunk passes through one SBox, which gives us 4 * 8 = 32 bits.
There is one last permutation, P, to be performed on these 32 bits before we can finally output F(K, R).

##### Time to implement this now!

In [7]:
# There are 8 SBoxes

Sbox = [
    [
        [14,    4,  13, 1,  2,  15, 11, 8,  3,  10, 6,  12, 5,  9,  0,  7],
        [0, 15, 7,  4,  14, 2,  13, 1,  10, 6,  12, 11, 9,  5,  3,  8],
        [4, 1,  14, 8,  13, 6,  2,  11, 15, 12, 9,  7,  3,  10, 5,  0],
        [15,    12, 8,  2,  4,  9,  1,  7,  5,  11, 3,  14, 10, 0,  6,  13]
    ],
    [
        [15,    1,  8,  14, 6,  11, 3,  4,  9,  7,  2,  13, 12, 0,  5,  10],
        [3, 13, 4,  7,  15, 2,  8,  14, 12, 0,  1,  10, 6,  9,  11, 5],
        [0, 14, 7,  11, 10, 4,  13, 1,  5,  8,  12, 6,  9,  3,  2,  15],
        [13,    8,  10, 1,  3,  15, 4,  2,  11, 6,  7,  12, 0,  5,  14, 9]
    ],
    [
        [10,    0,  9,  14, 6,  3,  15, 5,  1,  13, 12, 7,  11, 4,  2,  8],
        [13,    7,  0,  9,  3,  4,  6,  10, 2,  8,  5,  14, 12, 11, 15, 1],
        [13,    6,  4,  9,  8,  15, 3,  0,  11, 1,  2,  12, 5,  10, 14, 7],
        [1, 10, 13, 0,  6,  9,  8,  7,  4,  15, 14, 3,  11, 5,  2,  12]
    ],
    [
        [7, 13, 14, 3,  0,  6,  9,  10, 1,  2,  8,  5,  11, 12, 4,  15],
        [13,    8,  11, 5,  6,  15, 0,  3,  4,  7,  2,  12, 1,  10, 14, 9],
        [10,    6,  9,  0,  12, 11, 7,  13, 15, 1,  3,  14, 5,  2,  8,  4],
        [3, 15, 0,  6,  10, 1,  13, 8,  9,  4,  5,  11, 12, 7,  2,  14]
    ],
    [
        [2, 12, 4,  1,  7,  10, 11, 6,  8,  5,  3,  15, 13, 0,  14, 9],
        [14,    11, 2,  12, 4,  7,  13, 1,  5,  0,  15, 10, 3,  9,  8,  6],
        [4, 2,  1,  11, 10, 13, 7,  8,  15, 9,  12, 5,  6,  3,  0,  14],
        [11,    8,  12, 7,  1,  14, 2,  13, 6,  15, 0,  9,  10, 4,  5,  3]
    ],
    [
        [12,    1,  10, 15, 9,  2,  6,  8,  0,  13, 3,  4,  14, 7,  5,  11],
        [10,    15, 4,  2,  7,  12, 9,  5,  6,  1,  13, 14, 0,  11, 3,  8],
        [9, 14, 15, 5,  2,  8,  12, 3,  7,  0,  4,  10, 1,  13, 11, 6],
        [4, 3,  2,  12, 9,  5,  15, 10, 11, 14, 1,  7,  6,  0,  8,  13]
    ],
    [
        [4, 11, 2,  14, 15, 0,  8,  13, 3,  12, 9,  7,  5,  10, 6,  1],
        [13,    0,  11, 7,  4,  9,  1,  10, 14, 3,  5,  12, 2,  15, 8,  6],
        [1, 4,  11, 13, 12, 3,  7,  14, 10, 15, 6,  8,  0,  5,  9,  2],
        [6, 11, 13, 8,  1,  4,  10, 7,  9,  5,  0,  15, 14, 2,  3,  12]
    ],
    [
        [13,    2,  8,  4,  6,  15, 11, 1,  10, 9,  3,  14, 5,  0,  12, 7],
        [1, 15, 13, 8,  10, 3,  7,  4,  12, 5,  6,  11, 0,  14, 9,  2],
        [7, 11, 4,  1,  9,  12, 14, 2,  0,  6,  10, 13, 15, 3,  5,  8],
        [2, 1,  14, 7,  4,  10, 8,  13, 15, 12, 9,  0,  3,  5,  6,  11]
    ]
]

In [8]:
P = [
    16, 7,  20, 21,
    29, 12, 28, 17,
    1,  15, 23, 26,
    5,  18, 31, 10,
    2,  8,  24, 14,
    32, 27, 3,  9,
    19, 13, 30, 6,
    22, 11, 4,  25
]

def F(R, sousCle):
    # after expansion, R is 48 bits long; XOR it with the 48-bit subkey (sousCle)
    T = expansion(R) ^ sousCle

    res = 0

    for i in range(1, 9):
        # 48 bits / 6 = 8 blocks of 6 bits, each handled separately
        block = (T >> (48 - i*6)) & 0b111111

        msbDoubled = (block & 0b100000) >> 4  # most significant bit, moved to weight 2
        lsb = block & 1

        line = msbDoubled + lsb               # line (row): the two outer bits
        column = (block & 0b011110) >> 1      # column: the 4 middle bits

        # each block goes through one SBox => 4 bits of output,
        # placed at the i-th 4-bit position of the 32-bit result
        res = res | (Sbox[i - 1][line][column] << (32 - i*4))
    return permutation(res, P, 32)

### Putting DES together

Now that we have all the building blocks, we can head towards implementing the DES algorithm as defined in the beginning of the tutorial. Here is the algorithm, as a reminder:

![title](des.png)

In [9]:
# First we need to define IP and IPinv

IP = [
    58, 50, 42, 34, 26, 18, 10, 2,
    60, 52, 44, 36, 28, 20, 12, 4,
    62, 54, 46, 38, 30, 22, 14, 6,
    64, 56, 48, 40, 32, 24, 16, 8,
    57, 49, 41, 33, 25, 17, 9,  1,
    59, 51, 43, 35, 27, 19, 11, 3,
    61, 53, 45, 37, 29, 21, 13, 5,
    63, 55, 47, 39, 31, 23, 15, 7
]
    
IPinv = [
    40, 8,  48, 16, 56, 24, 64, 32,
    39, 7,  47, 15, 55, 23, 63, 31,
    38, 6,  46, 14, 54, 22, 62, 30,
    37, 5,  45, 13, 53, 21, 61, 29,
    36, 4,  44, 12, 52, 20, 60, 28,
    35, 3,  43, 11, 51, 19, 59, 27,
    34, 2,  42, 10, 50, 18, 58, 26,
    33, 1,  41, 9,  49, 17, 57, 25
]

In [10]:
def DES(clear, K):
    # Deriving K into 16 subkeys
    subKeys = keySchedule(K)

    # Initial Permutation, IP
    clearIP = permutation(clear, IP, 64)
    L, R = cutInHalves(clearIP)

    # 1 Feistel round/subkey
    for i in subKeys:
        LiPlus1 = R
        RiPlus1 = L ^ F(R, i)

        L = LiPlus1
        R = RiPlus1
    
    # swap R16 and L16
    swapped = (R << 32) | L
    # apply IP^-1
    return permutation(swapped, IPinv, 64)

In [11]:
# Let's try it. There's a DES calculator here : http://www.emvlab.org/descalc/ if you want to verify
clear = 0x40f40e794a6b4c4c
key   = 0x16a8dda1065cc7ca

print("Cleartext = " + hex(clear) + " | Key = " + hex(key) + " | Cypher : " + hex(DES(clear, key)))

Cleartext = 0x40f40e794a6b4c4c | Key = 0x16a8dda1065cc7ca | Cypher : 0x94a18565dd7d071d


### Time to attack

It's time to attack our DES cypher. Remember that our goal is to find the key.  
Let's first load the plaintext/cyphertext pair (the variable names are French: `clair` is the plaintext and `chiffreJuste` the correct cyphertext):

In [12]:
with open("clairEtChiffreJuste.txt", "r") as f:
    lines = f.readlines()
    clair = int(lines[0].strip(), 16)
    chiffreJuste = int(lines[1].strip(), 16)

print("Clear Text:", hex(clair))
print("Cypher:    ", hex(chiffreJuste))

Clear Text: 0x40f40e794a6b4c4c
Cypher:     0x16a8dda1065cc7ca


Now let's load the faulted cyphertexts (`chiffreFaux`, i.e. "faulty cyphertexts"). They are the result of 32 executions of DES in which one bit of the R15 block was flipped right before the last round, so the faulty R15 was fed to the last call to the F function, which uses the K16 subkey.  
It might seem like you would never get 32 faulted cyphertexts in real life. However, one could, for example, attack a bank card by flipping one bit per execution with a laser, thus collecting faulted cyphertexts.

In [13]:
chiffreFaux = []
with open("chiffreFaux.txt", "r") as f:
    for chiffre in f.readlines():
        chiffreFaux.append(int(chiffre.strip(), 16))
        
print("Faulted cyphertexts:", [hex(i) for i in chiffreFaux])

Faulted cyphertexts: ['0x14a8dde5065cc7de', '0x16aadde5065dc7ca', '0x16b8dfe5065cc7ca', '0x17f8dda7065cc7ca', '0x17f8dde5045dc7ca', '0x17e8dda1165ec7ca', '0x16e8d9a1065cc5ca', '0x17a8dda05658c7c8', '0x1ee8d9a0465cc7ca', '0x16a0dda04648c7ca', '0x16a8d5a0064cc7ca', '0x56a8cda80618c7cb', '0x16a8cda14e58c7cb', '0x16a8cda10614c7ca', '0x56a8cda1065ccfcb', '0x56a8dda1021cc783', '0x36a8cda1065cd6cb', '0x1688dda1025cd6ca', '0x16a8fda1065cd68a', '0x12a89d81025cd68a', '0x12a89da1265cd68a', '0x2a89da1067cc7ca', '0x6a8dca1065ce7ca', '0x6a89db1075c87ea', '0x82a89db1075c83ca', '0x1628ddb1065cc3ca', '0x16a85da1075c83ca', '0x16acdd21075c87da', '0x16addda1875c83ca', '0x16acdda106dcc7ce', '0x16a8dda1065c47de', '0x16addde1065dc75a']


### A bit of theory again

##### So now we have 32 faulted cyphers. How are we going to use them?

Let's denote values from the faulted executions with a \* (like R16\*, L16\*, R15\*, etc.).  
First, let's calculate L16 ^ L16\*, which is the same as R15 ^ R15\* (since L16 = R15). This gives us the difference between R15 and R15\*, i.e. the position where the fault was injected.  
Suppose we have R15 ^ R15\* = 0x40000000: then we know that bit 2 was flipped (DES numbers bits from 1, starting at the most significant bit).  
We know that this bit went through the expansion step:  
  
E = [  
    32, 1,  2,  3,  4,  5,  
    4,  5,  6,  7,  8,  9,  
    8,  9,  10, 11, 12, 13,  
    12, 13, 14, 15, 16, 17,  
    16, 17, 18, 19, 20, 21,  
    20, 21, 22, 23, 24, 25,  
    24, 25, 26, 27, 28, 29,  
    28, 29, 30, 31, 32, 1  
]  

Looking at it, we find that bit 2 ended up in position 3 of the expanded message, so it went into the first SBox (remember that each SBox receives a block of 6 bits). Some bits, such as bits 1 and 4, appear twice in E and therefore reach two SBoxes.
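The reasoning above (find the flipped bit, then look up where E sends it) can be sketched as follows. `flippedBit` and `sboxesFedBy` are hypothetical helpers, not part of the notebook, and `flippedBit` assumes a single-bit difference. SBoxes are numbered from 0, as in the `sbox0` ... `sbox7` keys used by the code below.

```python
# Expansion table E, repeated from the notebook so the snippet runs on its own
E = [
    32, 1,  2,  3,  4,  5,
    4,  5,  6,  7,  8,  9,
    8,  9,  10, 11, 12, 13,
    12, 13, 14, 15, 16, 17,
    16, 17, 18, 19, 20, 21,
    20, 21, 22, 23, 24, 25,
    24, 25, 26, 27, 28, 29,
    28, 29, 30, 31, 32, 1
]

def flippedBit(diff):
    # diff = R15 ^ R15* with a single bit set; returns its DES position (1 = most significant)
    return 33 - diff.bit_length()

def sboxesFedBy(bit):
    # expanded position pos (0-indexed) feeds SBox pos // 6
    return sorted({pos // 6 for pos, src in enumerate(E) if src == bit})

print(flippedBit(0x40000000))  # 2
print(sboxesFedBy(2))          # [0]
print(sboxesFedBy(1))          # [0, 7]
```

Since E duplicates 16 of the 32 bits, those 16 bits reach two SBoxes each. This is consistent with the output of `selectFaults` below, which lists 8 × 6 = 48 (fault, SBox) pairs for 32 faults.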

Let's implement the function that tells us which fault went into which SBox:

In [14]:
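def selectFaults():
    # aAtq ("to attack"): for each SBox, the indexes of the faults that reach it
    aAtq = {'sbox{}'.format(i): [] for i in range(8)}

    for index, faulted in enumerate(chiffreFaux):
        # IP only moves bits around, so IP(c ^ c*) = IP(c) ^ IP(c*)
        xor = permutation(chiffreJuste ^ faulted, IP, 64)

        # IP(cypher) = R16 | L16 (DES swaps the halves before IP^-1),
        # so the right half is L16 ^ L16* = R15 ^ R15*
        _, R15xor = cutInHalves(xor)
        exp = bin(expansion(R15xor))[2:].zfill(48)
        expListe = [exp[j:j+6] for j in range(0, 48, 6)]

        # every non-zero 6-bit block means the flipped bit entered that SBox
        for idx, bloc in enumerate(expListe):
            if bloc != '000000':
                aAtq['sbox{}'.format(idx)].append(index)
    return aAtq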

print(selectFaults())

{'sbox3': [15, 16, 17, 18, 19, 20], 'sbox6': [3, 4, 5, 6, 7, 8], 'sbox0': [0, 27, 28, 29, 30, 31], 'sbox2': [19, 20, 21, 22, 23, 24], 'sbox5': [7, 8, 9, 10, 11, 12], 'sbox7': [0, 1, 2, 3, 4, 31], 'sbox4': [11, 12, 13, 14, 15, 16], 'sbox1': [23, 24, 25, 26, 27, 28]}


##### And now?

Okay, so we have our faulted cyphers and we know which SBoxes they reach.

Do you remember what happens after the expansion? In the last round we compute E(R15\*) ^ K16.  
Let's call B one of the 6-bit blocks of E(R15\*) ^ K16, which goes through an SBox S.  
The result of this operation is S(B). We don't observe S(B) directly: it is permuted by P and XORed with L15 to give R16. But L15 is the same in the correct and the faulted executions (the fault only hits R15), so it cancels out when we XOR the two cyphertexts together.  

Our goal here is basically to compute S-1(S(B)), i.e. to go backwards through the SBox.  
We want this value because S-1(S(B)) = B is part of E(R15\*) ^ K16.  
Better still, we know E(R15\*): if we get B, XORing it with the corresponding 6 bits of E(R15\*) gives us 6 bits of K16 (since E(R15\*) ^ K16 ^ E(R15\*) = K16)!

If we do this for every SBox, we'll recover the whole K16!

##### Let's find a set of possible solutions

Given the nature of an SBox (see below), we won't find the solution straight away. Instead, we will start by finding a set of possible solutions.  

**First, focus on finding potential solutions**

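Imagine we're attacking one particular SBox S with one particular faulted cypher.
Let's calculate R16 ^ R16\* and look at the 4 bits that come from that SBox (after undoing P).  
In the last round, R16 = L15 ^ P(S(E(R15) ^ K16)), and the fault does not touch L15. With P-1 being the inverse of the permutation P, we have:

$$R_{16} = L_{15} \oplus P\big(S(E(R_{15}) \oplus K_{16})\big)$$
$$R_{16}^* = L_{15} \oplus P\big(S(E(R_{15}^*) \oplus K_{16})\big)$$
$$P^{-1}(R_{16} \oplus R_{16}^*) = S(E(R_{15}) \oplus K_{16}) \oplus S(E(R_{15}^*) \oplus K_{16})$$

The beauty of that equation is that L15 cancels out, and we already know R15 and R15\* (they are L16 and L16\*, just go back to the diagram of DES if you don't trust me). The only unknown here is K16.

Be careful with the final swap: applying IP to a cyphertext gives R16 | L16, so its left half is R16 and its right half is L16 = R15. The code below names that left half `L16`.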

**How do we invert an SBox?**

True, that's an issue: an SBox outputs 4 bits from a 6-bit input, meaning that some of the information is lost in the process. We can't simply read it backwards.


**Let's bruteforce K16!**  

Yay! For a given SBox, the only unknown is the 6-bit block of K16 that enters it, which means there are only 64 candidates to try: easy!
We simply need to compute P-1(R16 ^ R16\*), keep the 4 bits produced by S, and compare them to S(E(R15) ^ x) ^ S(E(R15\*) ^ x) for every x from 0 to 63, where E(R15) and E(R15\*) are restricted to the 6 bits entering S. If the values are equal, then x is a candidate.  
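Here is a minimal sketch of that bruteforce for a single SBox, mirroring what `calcLigneColonneSbox` and `solutionIsValid` do in the full attack below. The scenario is hypothetical: the 6-bit inputs and the secret key block 9 are made up, so that we can check that the real key block always shows up among the candidates. `keyCandidates` is an illustrative name, not part of the notebook.

```python
# First DES SBox, repeated so this sketch runs on its own
S1 = [
    [14, 4, 13, 1, 2, 15, 11, 8, 3, 10, 6, 12, 5, 9, 0, 7],
    [0, 15, 7, 4, 14, 2, 13, 1, 10, 6, 12, 11, 9, 5, 3, 8],
    [4, 1, 14, 8, 13, 6, 2, 11, 15, 12, 9, 7, 3, 10, 5, 0],
    [15, 12, 8, 2, 4, 9, 1, 7, 5, 11, 3, 14, 10, 0, 6, 13],
]

def sboxLookup(table, block):
    block &= 0b111111
    row = ((block >> 4) & 0b10) | (block & 1)  # outer bits, most significant first
    column = (block >> 1) & 0b1111             # the 4 middle bits
    return table[row][column]

def keyCandidates(table, eIn, eInFaulty, outDiff):
    # eIn / eInFaulty: the 6 bits of E(R15) / E(R15*) that enter this SBox
    # outDiff: the 4 bits of P-1(R16 ^ R16*) that this SBox produced
    # returns every 6-bit key block x with S(eIn ^ x) ^ S(eInFaulty ^ x) == outDiff
    return [x for x in range(64)
            if sboxLookup(table, eIn ^ x) ^ sboxLookup(table, eInFaulty ^ x) == outDiff]

# Hypothetical scenario: secret key block 9, fault flipping one input bit of this SBox
secret, eIn, eInFaulty = 9, 0b010111, 0b010011
outDiff = sboxLookup(S1, eIn ^ secret) ^ sboxLookup(S1, eInFaulty ^ secret)

cands = keyCandidates(S1, eIn, eInFaulty, outDiff)
print(cands)
print(secret in cands)  # True: the real key block always satisfies the equation
```

Candidates always come in pairs: if x works, so does x ^ eIn ^ eInFaulty, because it just swaps the roles of the two executions. This is why the candidate lists printed by `findK16` have an even length (e.g. 9 ^ 20 = 41 ^ 52 = 29 in the first list for Sbox 1).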

The bruteforce gives us the list of all the potential values for the 6 bits of K16 that went through S. But there is an issue: there is usually more than one candidate (mathematicians would say **SBoxes are surjective but not injective**).
For example here, if the SBox outputs 4, we have this:

[14,    <font color='red'>4</font>,  13, 1,  2,  15, 11, 8,  3,  10, 6,  12, 5,  9,  0,  7],  
[0, 15, 7,  <font color='red'>4</font>,  14, 2,  13, 1,  10, 6,  12, 11, 9,  5,  3,  8],  
[<font color='red'>4</font>, 1,  14, 8,  13, 6,  2,  11, 15, 12, 9,  7,  3,  10, 5,  0],  
[15,    12, 8,  2,  <font color='red'>4</font>,  9,  1,  7,  5,  11, 3,  14, 10, 0,  6,  13]  

There are 4 different possible 6-bit values: 0b000010 = 2, 0b000111 = 7, 0b100000 = 32 and 0b101001 = 41.
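A short sketch that enumerates the preimages of a given SBox output confirms the four values above. `S1`, `sboxLookup` and `preimages` are standalone helpers for illustration. Each row of a DES SBox is a permutation of 0..15, so every 4-bit output has exactly one preimage per row, hence exactly four in total.

```python
# First DES SBox, repeated so this sketch runs on its own
S1 = [
    [14, 4, 13, 1, 2, 15, 11, 8, 3, 10, 6, 12, 5, 9, 0, 7],
    [0, 15, 7, 4, 14, 2, 13, 1, 10, 6, 12, 11, 9, 5, 3, 8],
    [4, 1, 14, 8, 13, 6, 2, 11, 15, 12, 9, 7, 3, 10, 5, 0],
    [15, 12, 8, 2, 4, 9, 1, 7, 5, 11, 3, 14, 10, 0, 6, 13],
]

def sboxLookup(table, block):
    row = ((block >> 4) & 0b10) | (block & 1)  # outer bits, most significant first
    column = (block >> 1) & 0b1111             # the 4 middle bits
    return table[row][column]

def preimages(table, output):
    # every 6-bit input that the SBox maps to `output`
    return [b for b in range(64) if sboxLookup(table, b) == output]

print(preimages(S1, 4))  # [2, 7, 32, 41]
```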

**So how do we know which one is the right solution?**

Happily, K16 doesn't change between executions of DES, so all we have to do is repeat the process with different faults. The sets of candidates will vary, but the right 6-bit block of K16 belongs to every one of them. Thus, given 6 faulted blocks, the right block of K16 should be the only value common to every set of candidates!
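The intersection step can be checked on real numbers. The six candidate lists below are copied from the `findK16` output printed further down for Sbox 1. Intersecting them leaves a single value, which matches the `Solution 1 = 0x9` reported there.

```python
# The six candidate lists printed for "Sbox 1" in the findK16 output of this notebook
candidateSets = [
    [9, 20, 41, 52],
    [2, 3, 8, 9, 12, 13, 30, 31, 44, 45, 60, 61],
    [9, 11, 24, 26, 53, 55, 60, 62],
    [1, 5, 8, 9, 12, 13, 40, 43, 44, 47],
    [1, 5, 9, 13, 21, 29, 49, 57],
    [7, 9, 12, 23, 25, 28, 37, 44, 45, 53, 60, 61],
]

remaining = set(candidateSets[0])
for cands in candidateSets[1:]:
    remaining &= set(cands)  # keep only values that also satisfy this fault's equation

print(sorted(remaining))  # [9]
```

Here two faults would already have been enough. The extra faults make the result robust when a particular pair of sets happens to share several values.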

#### Now we're ready to implement!

Of course, this process has to be done for every one of the eight SBoxes, but the principles remain the same.

In [15]:
# these are helper functions that we'll use during the attack

# Returns the value common to every list (the intersection of all candidate sets).
# Assumes the intersection holds exactly one value; with too few faults it may hold several.
def common(liste):
    result = set(liste[0])
    for l in liste[1:]:
        result.intersection_update(l)
    return result.pop()

# 48-bit mask isolating the 6 input bits of SBox number `sbox` (0 = leftmost)
def mask48(sbox):
    return 0b111111 << (42 - sbox * 6)

# 32-bit mask isolating the 4 output bits of SBox number `sbox` (0 = leftmost)
def mask32(sbox):
    return 0b1111 << (28 - sbox * 4)

In [16]:
def calcLigneColonneSbox(sbox, expanded, valueToTest):
    # only keep the 6 bits that go through the sbox
    tmp = expanded & mask48(sbox)
    # removing the zeros on the right
    tmp >>= (7 - sbox) * 6
    tmp ^= valueToTest  # valueToTest is the potential value of the current block of K16 that we're testing (0 to 63)
    
    # same way to calculate the line/column of the input as in DES
    msbDouble = (tmp & 0b100000) >> 4
    lsb = tmp & 1
    ligne = msbDouble + lsb 
    colonne = (tmp & 0b011110) >> 1
    return ligne, colonne

In [17]:
def solutionIsValid(sbox, P_1_L16_xor_L16f, line, column, linef, columnf):
    ver = P_1_L16_xor_L16f & mask32(sbox)  # keep the (4-bit) output of the current SBox
    ver >>= (7 - sbox) * 4                 # remove trailing zeros
    xor = Sbox[sbox][line][column] ^ Sbox[sbox][linef][columnf]
    # if S(E(R15) ^ x) ^ S(E(R15*) ^ x) = P-1(R16 ^ R16*), x is a potential solution
    # (findK16 calls the left half of IP(cypher) "L16", hence the argument's name)
    return ver == xor

In [18]:
Pinv = [
    9,  17, 23, 31,
    13, 28, 2,  18,
    24, 16, 30, 6,
    26, 20, 10, 1,
    8,  14, 25, 3,
    4,  29, 11, 19,
    32, 12, 22, 7,
    5,  27, 15, 21
]

In [19]:
import pprint

def findK16(chiffreJuste, chiffreFaux, verbose=False):
    pp = pprint.PrettyPrinter(indent=4)
    K16 = 0x000000000000
    sol = {"sbox{}".format(i): [] for i in range(8)}

    toAttack = selectFaults()
    
    if verbose is True:
        print("Indexes of the SBoxes to be attacked:")
        pp.pprint(toAttack)
        print()
    
    # Replacing the indexes with the values
    for sbox, liste in toAttack.items():
        lChiffreFaux = []
        for i in liste:
            lChiffreFaux.append(chiffreFaux[i])
        toAttack[sbox] = lChiffreFaux
    
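    # After IP the cyphertext reads R16 | L16 (DES swaps the halves before IP^-1).
    # The right half is L16 = R15; the left half, named L16 here, is R16 in the diagram's notation.
    L16, R15 = cutInHalves(permutation(chiffreJuste, IP, 64))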
    
    # Attacking the 8 sboxes with the equation we found
    for s in range(8):
        # Each SBox is attacked by 6 faulted cyphers
        for f in range(6): 
            L16f, R15f = cutInHalves(permutation(toAttack['sbox{}'.format(s)][f], IP, 64))
    
            # calculating the terms of the equation
            P_1_L16_xor_L16f = permutation(L16 ^ L16f, Pinv, 32)
            E_R15 = expansion(R15) 
            E_R15f = expansion(R15f)
            
            # currSols stores the possible solutions for the current SBox and faulted cypher
            currSols = []
            # trying every values for the current portion of K16 (6 bits = 64 values)
            for x in range(pow(2, 6)):
                
                ligne, colonne = calcLigneColonneSbox(s, E_R15, x)
                ligneFaux, colonneFaux = calcLigneColonneSbox(s, E_R15f, x)
                
                if (solutionIsValid(s, P_1_L16_xor_L16f, ligne, colonne, ligneFaux, colonneFaux)):
                    currSols.append(x)
            sol["sbox{}".format(s)].append(currSols)

        # the solution is the only block that is common to every set of solutions for the current sbox
        solution = common(sol["sbox{}".format(s)])
        K16 = K16 << 6
        K16 = K16 | solution

        if verbose is True:
            print("Sbox", s + 1)
            print("Potential solutions :")
            pp.pprint(sol['sbox{}'.format(s)])
            print("Solution", s+1 , "=", hex(solution))
            print("current K16 =", hex(K16))
            
    return K16

In [20]:
K16 = findK16(chiffreJuste, chiffreFaux, verbose=True)

Indexes of the SBoxes to be attacked:
{   'sbox0': [0, 27, 28, 29, 30, 31],
    'sbox1': [23, 24, 25, 26, 27, 28],
    'sbox2': [19, 20, 21, 22, 23, 24],
    'sbox3': [15, 16, 17, 18, 19, 20],
    'sbox4': [11, 12, 13, 14, 15, 16],
    'sbox5': [7, 8, 9, 10, 11, 12],
    'sbox6': [3, 4, 5, 6, 7, 8],
    'sbox7': [0, 1, 2, 3, 4, 31]}

Sbox 1
Potential solutions :
[   [9, 20, 41, 52],
    [2, 3, 8, 9, 12, 13, 30, 31, 44, 45, 60, 61],
    [9, 11, 24, 26, 53, 55, 60, 62],
    [1, 5, 8, 9, 12, 13, 40, 43, 44, 47],
    [1, 5, 9, 13, 21, 29, 49, 57],
    [7, 9, 12, 23, 25, 28, 37, 44, 45, 53, 60, 61]]
Solution 1 = 0x9
current K16 = 0x9
Sbox 2
Potential solutions :
[   [0, 1, 44, 45, 46, 47],
    [0, 2, 5, 7, 8, 10],
    [0, 4, 40, 44, 59, 63],
    [0, 2, 6, 8, 10, 14, 55, 63],
    [0, 12, 14, 16, 28, 30, 45, 46, 61, 62],
    [0, 13, 32, 45]]
Solution 2 = 0x0
current K16 = 0x240
Sbox 3
Potential solutions :
[   [4, 5, 12, 13, 20, 21, 22, 23, 30, 31, 36, 37, 38, 39],
    [4, 6, 61, 63],
    [0,

In [21]:
def setKthBit(n, k):
    # set bit k of n (k = 0 is the least significant bit)
    return (1 << k) | n

### Retrieving K from K16

**First, invert PC2**  
We did the biggest part of the job by finding K16. Now we need to use it to find the full key. Out of the 64 bits of K, 8 are parity bits, which leaves 56 bits that matter. K16 gives us 48 of them, so we only have 8 bits left to find.  

Remember how K16 was generated in the key schedule. The rotation amounts in `nbPermutations` add up to 28, a full turn of each 28-bit half, so the T used for K16 is exactly the T produced by PC1: K16 = PC2(PC1(K)). PC2 selected the 48 bits of K16 from those 56 bits, so we first need to compute PC2-1(K16), with PC2-1 being the inverse permutation of PC2. The output will be 56 bits, 8 of which are lost (kind of the same problem as with the SBox: we are trying to recover 56 bits from 48 bits, so some of the information is gone).  
  
However, here we know the positions of the lost bits: they are the 8 input positions that PC2 never selects (and PC2 never selects a bit twice).
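Inverse tables such as `PC2inv` can be derived mechanically rather than by hand. The sketch below shows one possible way to do it, with a hypothetical helper `inverseTable`. Input bits that the table never uses get a 0 entry, which is the convention `PC2inv` and `PC1inv` follow in the code below.

```python
# PC2, repeated from the notebook so the snippet runs on its own
PC2 = [
    14, 17, 11, 24, 1,  5,
    3,  28, 15, 6,  21, 10,
    23, 19, 12, 4,  26, 8,
    16, 7,  27, 20, 13, 2,
    41, 52, 31, 37, 47, 55,
    30, 40, 51, 45, 33, 48,
    44, 49, 39, 56, 34, 53,
    46, 42, 50, 36, 29, 32
]

def inverseTable(table, inputSize):
    # inverse[j - 1] = output position (1-indexed) where input bit j lands, or 0 if the
    # table drops that bit. Assumes no input bit is used twice (true for PC1, PC2 and P).
    inverse = [0] * inputSize
    for outPos, inPos in enumerate(table, start=1):
        inverse[inPos - 1] = outPos
    return inverse

PC2invSketch = inverseTable(PC2, 56)
lost = [j + 1 for j, pos in enumerate(PC2invSketch) if pos == 0]
print(lost)  # [9, 18, 22, 25, 35, 38, 43, 54]
```

With the same convention, `inverseTable(P, 32)` reproduces `Pinv`, and `inverseTable(PC1, 64)` leaves zeros exactly at the parity positions 8, 16, ..., 64.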

**Then PC1**

Don't forget that before the loop, the key schedule applies PC1. Again, it's simple to invert, since this permutation only removes the parity bits and shuffles the remaining 56 bits. This means that computing PC1-1(x) doesn't lose any information about x (the parity bits simply come out as 0 and can be recomputed at the end).

##### So how can we find K ?

Well, we know we have 48 bits that are in the right position, and 16 bits that are wrong. Among those 16 bits, 8 are parity bits that are useless for the DES computation. This leaves us with 8 bits to bruteforce (2^8 = 256 combinations).  

Let's just encrypt with every possible combination and compare the output with the cyphertext we have. If they match, we have found the 8 missing bits, i.e. we have found the key.

In [22]:
from itertools import product

PC1inv = [
    8,  16, 24, 56, 52, 44,
    36, 0,  7,  15, 23, 55,
    51, 43, 35, 0,  6,  14,
    22, 54, 50, 42, 34, 0,
    5,  13, 21, 53, 49, 41,
    33, 0,  4,  12, 20, 28,
    48, 40, 32, 0,  3,  11,
    19, 27, 47, 39, 31, 0,
    2,  10, 18, 26, 46, 38,
    30, 0,  1,  9,  17, 25,
    45, 37, 29, 0
]

# Bits 9, 18, 22, 25, 35, 38, 43 and 54 are lost and thus set to 0 here
# (a 0 entry makes permutation() read a bit beyond the input, which is always 0)
PC2inv = [
    5,  24, 7,  16, 6,  10,
    20, 18, 0,  12, 3,  15,
    23, 1,  9,  19, 2,  0,
    14, 22, 11, 0,  13, 4,
    0,  17, 21, 8,  47, 31,
    27, 48, 35, 41, 0,  46,
    28, 0,  39, 32, 25, 44,
    0,  37, 34, 43, 29, 36,
    38, 45, 33, 26, 42, 0,
    30,  40
]

def findK56(clear, chiffre, K16):
    # K16 = PC2(PC1(K)) because the 16 rotations add up to 28, a full turn of each half.
    # We put the 48 bits of K16 back at their position in K (going backwards in the key schedule):
    # Inverse of PC2: 48 -> 56 bits, 8 wrong bits whose positions are known
    # Inverse of PC1: 56 -> 64 bits, 8 more wrong bits (parity, no consequence on the calculation).
    # Our 8 wrong bits from PC2 are still wrong, but moved to different, still trackable positions
    K48 = permutation(permutation(K16, PC2inv, 48), PC1inv, 56)

    # Bruteforcing the 8 missing bits (2^8 = 256 combinations)
    masques = list(product([0, 1], repeat=8))
    positionsToRecover = [50, 49, 45, 44, 13, 10, 6, 4]  # bit indexes in K, 0 = least significant

    for mask in masques:
        hypothesis = 0
        for i in range(8):
            if mask[i] == 1:
                hypothesis = setKthBit(hypothesis, k=positionsToRecover[i])
        hypothesis |= K48
        if chiffre == DES(clear, hypothesis):
            return hypothesis

    raise ValueError("Could not recover K56")

In [23]:
K56 = findK56(clair, chiffreJuste, K16)

Finally, we have to compute the parity bits. DES uses odd parity: every byte of the key must contain an odd number of 1s. So for every byte, count the 1s among its 7 most significant bits. If the count is even, set the byte's least significant bit to 1; otherwise leave it at 0.
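As a sanity check on the rule above, here is an integer-only sketch that tests and sets odd parity. `hasOddParity` and `setOddParity` are hypothetical helpers, not the notebook's `parity` function. `setOddParity` assumes every byte's least significant bit starts at 0, as is the case after `PC1inv`.

```python
def hasOddParity(key):
    # True if every byte of the 64-bit key contains an odd number of 1s (the DES parity rule)
    return all(bin((key >> shift) & 0xFF).count("1") % 2 == 1 for shift in range(0, 64, 8))

def setOddParity(key):
    # Integer-only variant of the parity step; assumes every byte's least significant bit is 0
    for shift in range(0, 64, 8):
        if bin((key >> shift) & 0xFF).count("1") % 2 == 0:
            key |= 1 << shift  # set this byte's parity bit so its count of 1s becomes odd
    return key

print(hasOddParity(0x1c8529cea240ae4f))       # True: the key recovered in this notebook
print(hex(setOddParity(0x1c8428cea240ae4e)))  # 0x1c8529cea240ae4f
```

The second call starts from the recovered key with every parity bit cleared and rebuilds it.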

In [24]:
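def parity(K56b):
    # Set the least significant bit of every byte so that each byte has an odd number of 1s.
    # Assumes the parity bits of K56b are 0, which is the case after PC1inv.
    strCle = format(K56b, '064b')  # 64-character binary string, zero-padded on the left
    cle = ''
    for i in range(0, 64, 8):
        currByte = strCle[i:i+8]
        if currByte.count('1') % 2 == 0:
            currByte = currByte[:-1] + '1'
        cle += currByte
    return int(cle, 2)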

In [25]:
print("Found Key:", hex(parity(K56)))
print("Verifying the correctness:")
print()
print("Computed cypher:", hex(DES(clair, parity(K56))))
print("Right cypher:", hex(chiffreJuste))

Found Key: 0x1c8529cea240ae4f
Verifying the correctness:

Computed cypher: 0x16a8dda1065cc7ca
Right cypher: 0x16a8dda1065cc7ca




## Congratulations!

You managed to attack one of the most widely used cyphers of its time with a DFA. This attack is non-trivial but incredibly powerful. Worse, the same principles work against many other cyphers, including some implementations of AES and RSA.  

I hope you loved it. If that's the case, make sure to star this repo on GitHub or to upvote it on Hacker News!  
Don't hesitate to email me at maxime.elkael@gmail.com if you manage to extend or do anything cool with this tutorial !  
Thanks for reading,  
  
I hope you had as much fun reading as I had writing this !