# Part 3, Topic 1, Lab B: AES256 Bootloader Attack (MAIN)

---
NOTE: This lab references some (commercial) training material on [ChipWhisperer.io](https://www.ChipWhisperer.io). You can freely execute and use the lab per the open-source license (including using it in your own courses if you distribute similarly), but you must maintain notice about this source location. Consider joining our training course to enjoy the full experience.

---

**SUMMARY:** *Through the previous labs, we've gained a lot of tools to attack unknown embedded devices: SPA, DPA, CPA, trace resynchronization, and more. In this lab, we'll be using some of those techniques to break a more realistic target: a bootloader. Note that there are two versions of this lab. In this one (Lab B), we'll start with no information that couldn't be revealed by watching code be sent to the bootloader. Everything else, we'll need to figure out for ourselves, such as what encryption algorithm the target is using, how it's using it, etc. In Lab A, this information is given and you'll just focus on the attack. It's up to you whether you want to run this lab or Lab A!*

**LEARNING OUTCOMES:**

* Observing power traces to figure out what encryption operation it's running
* Applying CPA and DPA to break different parts of the bootloader
* Understanding different operating modes for block ciphers

## Background

In the world of microcontrollers, a bootloader is a special piece of firmware that is made to let the user upload new programs into memory. This is especially useful for devices with complex code that may need to be patched or otherwise updated in the future - a bootloader makes it possible for the user to upload a patched version of the firmware onto the micro. The bootloader receives information from a communication line (a USB port, serial port, ethernet port, WiFi connection, etc...) and stores this data into program memory. Once the full firmware has been received, the micro can happily run its updated code.

There is one big security issue to worry about with bootloaders. A company may want to stop their customers from writing their own firmware and uploading it onto the micro. For example, this might be for protection reasons - hackers might be able to access parts of the device that weren't meant to be accessed. One way of stopping this is to add encryption. The company can add their own secret signature to the firmware code and encrypt it with a secret key. Then, the bootloader can decrypt the incoming firmware and confirm that the incoming firmware is correctly signed. Users will not know the secret key or the signature tied to the firmware, so they won't be able to "fake" their own.

In [None]:
SCOPETYPE = 'OPENADC'
PLATFORM = 'CWLITEARM'

In [None]:
%%bash -s "$PLATFORM" 
cd ../../../hardware/victims/firmware/bootloader-aes256
make PLATFORM=$1 CRYPTO_TARGET=NONE

In [None]:
%run "../../Setup_Scripts/Setup_Generic.ipynb"

In [None]:
fw_path = "../../../hardware/victims/firmware/bootloader-aes256/bootloader-aes256-{}.hex".format(PLATFORM)
cw.program_target(scope, prog, fw_path)

## The Situation

Simply put, we've got a target device running an encrypted bootloader (a program used to upload new code onto a device) and we want to see if we can get our own code running on the device. We've done a bit of sniffing on the serial lines when the device's firmware is being updated and we've learned the following:

* The device communicates over serial at 38400bps
* When writing memory, the first byte is always zero (probably a command byte)
* There's a 16 byte block of random looking memory (aka it doesn't look like firmware). This part is probably encrypted
* There's a 2 byte CRC at the end of each message
* There's no repetition in the ciphertext.

All together this looks like:

```
       |<-------- Encrypted block (16 bytes) ---------->|
       |                                                |
+------+------+------+------+------+------+ .... +------+------+------+
| 0x00 |              Random looking data               |   CRC-16    |
+------+------+------+------+------+------+ .... +------+------+------+
```

After sending data to the bootloader it responds with either `0xA4` or `0xA1`. The former only happened when we sent a bad CRC.

This time, we won't be triggering off of our trigger pins (you can remove them from the code if you'd like).

From our initial sniffing of the communication lines, we've got the first few messages that were sent:

In [None]:
import pickle
with open("./firmware.pickle", "rb") as f:
    encrypted_firmware = pickle.load(f)
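
The "no repetition" observation from the sniffing can be checked directly on captured data. The following is a minimal sketch, assuming the capture is a flat bytes-like sequence of concatenated 16-byte blocks (which is how `encrypted_firmware` is sliced later in this lab). The helper name `repeated_blocks` is hypothetical and not part of ChipWhisperer.

```python
from collections import Counter

def repeated_blocks(data, block_size=16):
    """Return {block: count} for every block that occurs more than once."""
    data = bytes(data)
    blocks = [data[i:i + block_size]
              for i in range(0, len(data) - block_size + 1, block_size)]
    return {blk: n for blk, n in Counter(blocks).items() if n > 1}

# Synthetic example (not captured data): two identical blocks, then a different one
sample = bytes(range(16)) * 2 + bytes(range(16, 32))
print(len(repeated_blocks(sample)))  # number of distinct blocks that repeat
```

An empty result on real firmware traffic is a hint that blocks are not being encrypted independently, since firmware images usually contain repeated content such as padding.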

## Doing Recon

Our first step will be to see if we can learn anything about the bootloader from looking at its power traces. Let's start with the boot sequence:

In [None]:
scope.trigger.triggers = "nrst"

In [None]:
scope.adc.samples = 24400

In [None]:
scope.arm()
reset_target(scope)
scope.capture()
wave = scope.get_last_trace()

In [None]:
cw.plot(wave)

The device does appear to be doing something, but it's clearly nothing major - no encryptions or anything. Every microcontroller has boot code and operations that run when it's reset. The device may even have its own bootloader running in ROM! Let's move on to the messages themselves. We do know that there's a CRC for data integrity. We can use the following code to calculate the CRC for us:

In [None]:
# Class Crc
#############################################################
# These CRC routines are copy-pasted from pycrc, which are:
# Copyright (c) 2006-2013 Thomas Pircher <tehpeh@gmx.net>
#
class Crc(object):
    """
    A base class for CRC routines.
    """

    def __init__(self, width, poly):
        """The Crc constructor.

        The parameters are as follows:
            width
            poly
            reflect_in
            xor_in
            reflect_out
            xor_out
        """
        self.Width = width
        self.Poly = poly


        self.MSB_Mask = 0x1 << (self.Width - 1)
        self.Mask = ((self.MSB_Mask - 1) << 1) | 1

        self.XorIn = 0x0000
        self.XorOut = 0x0000

        self.DirectInit = self.XorIn
        self.NonDirectInit = self.__get_nondirect_init(self.XorIn)
        if self.Width < 8:
            self.CrcShift = 8 - self.Width
        else:
            self.CrcShift = 0

    def __get_nondirect_init(self, init):
        """
        return the non-direct init if the direct algorithm has been selected.
        """
        crc = init
        for i in range(self.Width):
            bit = crc & 0x01
            if bit:
                crc ^= self.Poly
            crc >>= 1
            if bit:
                crc |= self.MSB_Mask
        return crc & self.Mask


    def bit_by_bit(self, in_data):
        """
        Classic simple and slow CRC implementation.  This function iterates bit
        by bit over the augmented input message and returns the calculated CRC
        value at the end.
        """
        # If the input data is a string, convert to bytes.
        if isinstance(in_data, str):
            in_data = [ord(c) for c in in_data]

        register = self.NonDirectInit
        for octet in in_data:
            for i in range(8):
                topbit = register & self.MSB_Mask
                register = ((register << 1) & self.Mask) | ((octet >> (7 - i)) & 0x01)
                if topbit:
                    register ^= self.Poly

        for i in range(self.Width):
            topbit = register & self.MSB_Mask
            register = ((register << 1) & self.Mask)
            if topbit:
                register ^= self.Poly

        return register ^ self.XorOut
    
bl_crc = Crc(width = 16, poly=0x1021)
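
For reference, with a 16-bit width, polynomial `0x1021`, and zero initial and final XOR values, the class above should produce the same values as the common CRC-16/XMODEM variant. The sketch below is a compact equivalent plus a helper that assembles one frame in the layout shown earlier. The names `crc16_xmodem` and `build_frame` are hypothetical, and the CRC is assumed to cover only the 16-byte block, as in the capture code used in this lab.

```python
def crc16_xmodem(data):
    """CRC-16 with polynomial 0x1021, initial value 0, no reflection, no final XOR."""
    crc = 0x0000
    for byte in data:
        crc ^= byte << 8
        for _ in range(8):
            if crc & 0x8000:
                crc = ((crc << 1) ^ 0x1021) & 0xFFFF
            else:
                crc = (crc << 1) & 0xFFFF
    return crc

def build_frame(enc_block):
    """Command byte 0x00, the 16-byte block, then the CRC-16 with the high byte first."""
    if len(enc_block) != 16:
        raise ValueError("expected a 16-byte block")
    crc = crc16_xmodem(enc_block)
    return bytes([0x00]) + bytes(enc_block) + bytes([crc >> 8, crc & 0xFF])

print(build_frame(bytes(range(16))).hex())
```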

Let's define a function to do the communication and capture a trace for us. We can try triggering off of our message. There's not much memory on the target, so it's probably decrypting on the fly instead of reading in a whole bunch of memory and then doing the decryption.

In [None]:
scope.trigger.triggers = "tio2"
scope.adc.samples = 24400
scope.adc.decimate = ??? # try to get the full encryption in a single trace, then set back to 1
scope.adc.offset = 0
import time

def cap_trace(enc_block):
    """Send one 16-byte block to the bootloader and capture a power trace."""
    target.read()  # discard any stale bytes from the target

    # Frame layout: command byte, encrypted block, CRC-16 (high byte first)
    message = [0x00]
    message.extend(enc_block)

    crc = bl_crc.bit_by_bit(enc_block)
    message.append(crc >> 8)
    message.append(crc & 0xFF)

    # Send everything except the last byte, arm the scope, then send the last byte
    target.write(message[:-1])
    time.sleep(0.01)
    scope.arm()
    target.write([crc & 0xFF])
    ret = scope.capture()
    if ret:
        print('Timeout happened during acquisition')
    response = target.read()
    if response:
        if ord(response[0]) != 0xA4:
            # Bad response, just skip
            #print("Bad response: {:02X}".format(ord(response[0])))
            return None

    return scope.get_last_trace()

ktp = cw.ktp.Basic()
key, text = ktp.next()  # ktp.next() returns (key, text)
wave = cap_trace(text)
wave2 = cap_trace(text)

In [None]:
cw.plot(wave) * cw.plot(wave2)

We've found the decryption. Some things to notice from this:

1. The target immediately goes from reading the ciphertext to running a decryption - there's no preprocessing done at all. This means the ciphertext is likely being immediately fed into the decryption algorithm.
1. We can see operations in the encryption being repeated. This looks a lot like AES - we can see a long distinct operation that's probably MixColumns in there and the overall structure looks a lot like what we saw when looking at TINYAES128C encryptions. Let's make a weak assumption that this is AES128 - we can adjust this as we learn more about the bootloader.
1. We can't see the full encryption.
1. There's some jitter here. We'll probably have to resync the traces if we run a CPA attack.

Since we can't see the full encryption, decimate the ADC (we don't care too much about the fine details here) and take another look...


## The full encryption

You should see that instead of 9 repetitions of MixColumns (or what we assume is MixColumns), there are actually 13. This rules out AES128, which has only 10 rounds. AES256, however, has 14 rounds, and since the final round skips MixColumns, that gives exactly 13 MixColumns operations. We can still attack AES256 without much issue: we basically just have to run two CPA attacks, one for the first half of the key and another for the second half. Again, we're not 100% sure about this, but it's a good starting point.
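
The round arithmetic behind that observation can be written down as a small sketch. The round counts are the standard ones for AES; the names `AES_ROUNDS` and `mixcolumns_count` are hypothetical.

```python
# Number of rounds for each AES key size
AES_ROUNDS = {128: 10, 192: 12, 256: 14}

def mixcolumns_count(key_bits):
    """Every round except the final one includes a (Inv)MixColumns step."""
    return AES_ROUNDS[key_bits] - 1

for bits in sorted(AES_ROUNDS):
    print("AES-{}: {} rounds, {} MixColumns".format(
        bits, AES_ROUNDS[bits], mixcolumns_count(bits)))
```

Counting the long repeated operation in a decimated trace and comparing against this table is a quick way to narrow down the key size before committing to an attack.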

As we mentioned in the briefing about the bootloader, the ciphertext never seems to repeat. It's pretty unlikely that this device is applying AES256 to each block independently (if it even is using AES256), since firmware usually has blocks that repeat, and identical plaintext blocks would then produce identical ciphertext blocks. More likely, AES is being used in a chained or stream-like [block cipher mode of operation](https://en.wikipedia.org/wiki/Block_cipher_mode_of_operation). This could pose an immediate problem for our attack efforts: we need to know either what's going into or coming out of the AES block to perform a CPA attack. However, some of the modes listed on that page (if it's even using one on that page) feed an IV or a counter into that block instead of the plaintext or the ciphertext. We didn't see any encryption operations happening on startup (which the device could've done if it was using one of these IV/counter modes), so we'll probably be okay with a normal CPA attack. Let's try it and see if we can get anything out of it:

In [None]:
from tqdm.notebook import trange
project = cw.create_project("projects/Tutorial_A5", overwrite=True)
scope.adc.offset = 0
scope.adc.decimate = 1
ktp = cw.ktp.Basic()
for i in trange(100):
    key, text = ktp.next()
    wave = cap_trace(text)
    if wave is None:
        continue  # bad response from the bootloader, skip this capture
    trace = cw.Trace(wave, text, bytearray([0]*16), bytearray([0]*16))
    project.traces.append(trace)

In [None]:
cw.plot(project.waves[0])

We can eliminate a lot of this jitter by using the resync SAD module:

In [None]:
import chipwhisperer as cw
import chipwhisperer.analyzer as cwa

leak_model = cwa.leakage_models.inverse_sbox_output
resync = cwa.preprocessing.ResyncSAD(project)
resync.enabled=True
resync.ref_trace = 0
resync.target_window = (???, ???)
resync.max_shift = 7000
new_proj = resync.preprocess()
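
For intuition about what the target window and maximum shift control, sum-of-absolute-differences (SAD) resynchronization can be sketched in plain Python as below. This is an illustration of the idea only, not ChipWhisperer's implementation, and `sad_best_shift` is a hypothetical name.

```python
def sad_best_shift(ref, trace, window, max_shift):
    """Return the shift s that minimizes sum(|ref[i] - trace[i + s]|) over the window."""
    start, end = window
    best_shift, best_sad = 0, float("inf")
    for shift in range(-max_shift, max_shift + 1):
        lo, hi = start + shift, end + shift
        if lo < 0 or hi > len(trace):
            continue  # shifted window would fall outside the trace
        sad = sum(abs(r - t) for r, t in zip(ref[start:end], trace[lo:hi]))
        if sad < best_sad:
            best_shift, best_sad = shift, sad
    return best_shift

# Synthetic example: the same pulse, arriving two samples late
ref = [0, 0, 0, 1, 5, 9, 5, 1, 0, 0, 0, 0, 0, 0]
delayed = [0, 0] + ref[:-2]
print(sad_best_shift(ref, delayed, window=(2, 9), max_shift=3))
```

The window should cover a distinctive feature of the reference trace, and the maximum shift needs to be at least as large as the worst jitter, otherwise the true alignment is never tested.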



Check to make sure your traces are resynced:

In [None]:
plt = cw.plot([])
for i in range(10):
    plt *= cw.plot(new_proj.waves[i])

plt

All that's left is to actually run the attack. We don't know the correct key, so it won't be highlighted in red.

In [None]:
attack = cwa.cpa(new_proj, leak_model)
#attack.point_range = [???, ???]

#key = [0xea, 0x79, 0x79, 0x20, 0xc8, 0x71, 0x44, 0x7d, 0x46, 0x62, 0x5f, 0x51, 0x85, 0xc1, 0x3b, 0xcb]

cb = cwa.get_jupyter_callback(attack)
attack_results = attack.run(cb)

The difference in correlation between the best key guess and the next best one makes this look very promising! We now know we're correct about two things:

1. The bootloader is actually decrypting the ciphertext
1. The device is using AES

In [None]:
calc_round_key = attack_results.key_guess()

With that done, we now need to get the second half of the key. Pop over to [Extending AES-128 Attacks to AES-256](Extending%20AES-128%20Attacks%20to%20AES-256.ipynb), since that page explains how to do that...

Back? Let's go through and see if that theory actually holds up.

To make a new model, we start off by inheriting the `AESLeakageHelper` class. We need to make a `leakage()` method that returns the intermediate value whose Hamming weight we use in the CPA attack. To get you started:

In [None]:
import chipwhisperer as cw
import chipwhisperer.analyzer as cwa
class AES256_Round13_Model(cwa.AESLeakageHelper):
    def leakage(self, pt, ct, guess, bnum):
        #You must put YOUR recovered 14th round key here - this example may not be accurate!
        calc_round_key = [0xea, 0x79, 0x79, 0x20, 0xc8, 0x71, 0x44, 0x7d, 0x46, 0x62, 0x5f, 0x51, 0x85, 0xc1, 0x3b, 0xcb]
        state = reverse_round_14(self, pt, calc_round_key)
        state = reverse_round_13(self, state) #reverse state just before inv_subbytes
        return self.inv_sbox(state[bnum] ^ guess[bnum])

Now we just need to make the `reverse_round_14()` and `reverse_round_13()` functions. By passing the class in, we get access to `self.inv_shiftrows()`, `self.inv_subbytes()`, and `self.inv_mixcolumns()`.

In [None]:
def reverse_round_14(self, pt, key):
    state = [pt[i] ^ key[i] for i in range(16)] #AddRoundKey
    state = ???
    state = ???
    return state # we're now at the end of decryption round 1

def reverse_round_13(self, state):
    state = ???
    state = ???
    return state

From there, just update the leakage model and rerun the attack:

In [None]:
leak_model = cwa.leakage_models.new_model(AES256_Round13_Model)
attack.leak_model = leak_model
cb = cwa.get_jupyter_callback(attack)
attack_results = attack.run(cb)

If you built your model correctly, you should again see a pretty likely guess for a key. All that's left now is to combine the 14th and 13th round keys, then use that to figure out the 0th and 1st round keys.

We'll start by getting the transformed 13th round key out of the attack results:

In [None]:
#calc_round_key = [0xea, 0x79, 0x79, 0x20, 0xc8, 0x71, 0x44, 0x7d, 0x46, 0x62, 0x5f, 0x51, 0x85, 0xc1, 0x3b, 0xcb]
rec_key = calc_round_key

In [None]:
rec_key2 = []
for bnum in attack_results.find_maximums():
    print("Best Guess = 0x{:02X}, Corr = {}".format(bnum[0][0], bnum[0][2]))
    rec_key2.append(bnum[0][0])

Now we need to transform that key into the actual 13th round key by running it through ShiftRows and MixColumns:

In [None]:
real_key2 = cwa.aes_funcs.shiftrows(rec_key2)
real_key2 = cwa.aes_funcs.mixcolumns(real_key2)

print("Recovered:", end="")
for subkey in real_key2:
    print(" {:02X}".format(subkey), end="")
print("")

Then we can combine the keys:

In [None]:
rec_key_comb = real_key2.copy()
rec_key_comb.extend(rec_key)

print("Key:", end="")
for subkey in rec_key_comb:
    print(" {:02X}".format(subkey), end="")
print("")

Finally, we use ChipWhisperer's built-in key scheduler to reverse them to the 0th and 1st round keys:

In [None]:
btldr_key = leak_model.key_schedule_rounds(rec_key_comb, 13, 0)
btldr_key.extend(leak_model.key_schedule_rounds(rec_key_comb, 13, 1))
print("Key:", end="")
for subkey in btldr_key:
    print(" {:02X}".format(subkey), end="")
print("")

So we were clearly right about the bootloader running AES256! However, if we try decrypting some of our firmware with the encryption key:

In [None]:
from Crypto.Cipher import AES
cipher = AES.new(bytes(btldr_key), AES.MODE_ECB)
print(bytearray(cipher.decrypt(encrypted_firmware[:16])))
print(bytearray(cipher.decrypt(encrypted_firmware[16:32])))

As we expected, we still get gibberish out of this - the target is definitely not decrypting each block independently. The question now is, which block cipher mode is it using? Well, we know the ciphertext is being decrypted. We can narrow it down a bit by looking at the end of the encryption. Increase the offset until you reach the end of the encryption.

In [None]:
#scope.adc.offset = 50000
#wave = cap_trace(text)
#key, text = ktp.next()
wave = cap_trace(text)

In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt
plt.figure()
plt.plot(wave)
plt.show()

It might be hard to pick out, but you should be able to find 16 short XOR operations just after the end of the last round of AES. It might be using CBC mode, which means it'll be using the ciphertext as part of the encryption and decryption of subsequent blocks. Let's do a quick check:

In [None]:
cipher = AES.new(bytes(btldr_key), AES.MODE_ECB)
dec_fw = cipher.decrypt(encrypted_firmware[16:32])
fw = [dec_fw[i] ^ encrypted_firmware[i] for i in range(16)]

In [None]:
print(fw)

It looks like we're finally getting some valid output. Let's try the next block:

In [None]:
cipher = AES.new(bytes(btldr_key), AES.MODE_ECB)
dec_fw = cipher.decrypt(encrypted_firmware[32:48])
fw = [dec_fw[i] ^ encrypted_firmware[i+16] for i in range(16)]
print(fw)

This is very promising! The same firmware block appears twice. The run of `FF` bytes is probably just empty flash memory. The beginning is more curious though:

* It's probably not a flash constant - otherwise it wouldn't be in two blocks in a row
* That's only enough room for a few instructions at most. Again, it's a little strange that it would be repeated like this

It might be some sort of signature. After all, the device doesn't want to write anything to memory unless the ciphertext has properly been decrypted.
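
The two checks above both follow the CBC decryption rule: each plaintext block is the block-cipher decryption of the current ciphertext block XORed with the previous ciphertext block, with the IV standing in for the "previous block" of the very first one. A generic sketch is below. The function names are hypothetical, and the toy block function is a stand-in used only so the example runs without any crypto library; it is not AES.

```python
def xor_bytes(a, b):
    return bytes(x ^ y for x, y in zip(a, b))

def cbc_decrypt(ciphertext, decrypt_block, iv, block_size=16):
    """P[i] = D(C[i]) XOR C[i-1], with C[-1] = IV."""
    prev, plaintext = bytes(iv), b""
    for i in range(0, len(ciphertext), block_size):
        block = bytes(ciphertext[i:i + block_size])
        plaintext += xor_bytes(decrypt_block(block), prev)
        prev = block
    return plaintext

def cbc_encrypt(plaintext, encrypt_block, iv, block_size=16):
    """C[i] = E(P[i] XOR C[i-1]), with C[-1] = IV."""
    prev, ciphertext = bytes(iv), b""
    for i in range(0, len(plaintext), block_size):
        prev = encrypt_block(xor_bytes(plaintext[i:i + block_size], prev))
        ciphertext += prev
    return ciphertext

# Toy stand-in for the block cipher (NOT AES): XOR with a constant, its own inverse
toy_block = lambda block: xor_bytes(block, bytes([0x5A]) * 16)

iv = bytes(range(16))
pt = bytes(16) * 2                      # two identical plaintext blocks
ct = cbc_encrypt(pt, toy_block, iv)
print(ct[:16] != ct[16:])               # chaining hides the repetition
print(cbc_decrypt(ct, toy_block, iv) == pt)
```

With the recovered key, the real block function would be the `decrypt` method of an ECB-mode cipher object from `Crypto.Cipher.AES`, as used in the cells above. Only the first block depends on the IV, which is why every later block already decrypts correctly.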

There's still the issue of the first block of memory though. There's another secret value called an initialization vector that we need to decrypt that block. To be able to recover that, we'll need to revisit the DPA attack:

### Attack Theory

The bootloader applies the IV to the AES decryption result (`DR`) by calculating


$\text{PT} = \text{DR} \oplus \text{IV}$

where DR is the decrypted ciphertext, IV is the secret vector, and PT is the plaintext that the bootloader will use later. We only have access to one of these: since we know the AES-256 key, we can calculate DR. This exclusive or will be visible in the power traces.

This is enough information for us to attack a single bit of the IV. Suppose we only wanted to get the first bit (bit 0) of the first byte (byte 0) of the IV. We could do the following:

* Split all of the traces into two groups: those with `(DR[0] & 0x01) = 0`, and those with `(DR[0] & 0x01) = 1`. 
* Calculate the average trace for both groups.
* Find the difference between the two averages. Provided we've got sufficient data, we should see a spike where the XOR is occurring.
* Look at the direction of the spike to decide if the IV bit is 0 `(PT[0] = DR[0])` or if the IV bit is 1 `(PT[0] = ~DR[0])`.

This is effectively a DPA attack on a single bit of the IV. We can repeat this attack across the whole IV by instead separating by `(DR[byte] & (1 << bit)) = 0` and `(DR[byte] & (1 << bit)) != 0`.
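
The procedure can be tried on simulated data before touching the hardware. The sketch below is built on stated assumptions, not on measurements: each "trace" is a single sample equal to the Hamming weight of `DR ^ IV` plus Gaussian noise, so a larger sample means more bits set. On a real target the polarity may be inverted, which is why the sign test used on captured traces can differ from the one here. All names are hypothetical.

```python
import random

def hamming_weight(value):
    return bin(value).count("1")

def recover_iv_byte(iv_byte, n_traces=2000, noise=1.0, seed=1):
    """Recover one IV byte from simulated single-sample traces by difference of means."""
    rng = random.Random(seed)
    observations = []
    for _ in range(n_traces):
        dr = rng.randrange(256)                       # known decryption result
        sample = hamming_weight(dr ^ iv_byte) + rng.gauss(0.0, noise)
        observations.append((dr, sample))

    recovered = 0
    for bit in range(8):
        groups = ([], [])                             # index = value of the DR bit
        for dr, sample in observations:
            groups[(dr >> bit) & 1].append(sample)
        diff = sum(groups[1]) / len(groups[1]) - sum(groups[0]) / len(groups[0])
        # IV bit 0: PT bit follows the DR bit, so group 1 leaks more (diff > 0).
        # IV bit 1: PT bit is the inverted DR bit, so group 1 leaks less (diff < 0).
        if diff < 0:
            recovered |= 1 << bit
    return recovered

print("{:02X}".format(recover_iv_byte(0xA5)))
```

The other seven bits of the byte act as extra noise for the bit under attack, which is one reason a fairly large number of traces is needed even though the XOR itself is a simple operation.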

We'll need to reset the device before every encryption since it only uses the IV in the first encryption. This leads to a slightly modified capture loop. You'll need to adjust your offset since we're now triggering near the beginning of the UART transmit instead of near the end. Run the loop, then interrupt it to get a wave. Then plot and adjust your offset. Repeat until you're near the end of the encryption again:

In [None]:
import time

import numpy as np
from Crypto.Cipher import AES
from tqdm.notebook import trange

project = cw.create_project("projects/Tutorial_A5_IV", overwrite=True)
scope.adc.decimate = 1
scope.adc.offset = 51000
scope.adc.timeout = 1
scope.trigger.triggers = "tio2"
for _ in trange(1000):
    # Reset the target so the bootloader starts over and uses the IV again
    scope.io.nrst = 0
    time.sleep(0.02)
    scope.io.nrst = "high_z"
    time.sleep(0.01)

    # Send a throwaway message until the bootloader answers with 0xA1
    okay = False
    while not okay:
        target.write('\0xxxxxxxxxxxxxxxxxx')
        time.sleep(0.005)
        response = target.read()
        if response and ord(response[0]) == 0xA1:
            okay = True

    target.flush()

    key, text = ktp.new_pair()  # manual creation of a key, text pair can be substituted here

    wave = cap_trace(text)
    if wave is None:
        continue  # bad response, skip this capture

    trace = cw.Trace(wave, text, bytearray([0]*16), bytearray([0]*16))
    project.traces.append(trace)

As you can see, we'll again need to resync to get rid of the jitter:

In [None]:
plt = cw.plot([])
for i in range(2):
    plt *= cw.plot(project.waves[i])
    
plt

Adjust the target window here as necessary:

In [None]:
import chipwhisperer as cw
import chipwhisperer.analyzer as cwa

leak_model = cwa.leakage_models.inverse_sbox_output
resync = cwa.preprocessing.ResyncSAD(project)
resync.enabled=True
resync.ref_trace = 0
resync.target_window = (???, ???)
resync.max_shift = 6000
new_proj = resync.preprocess()

In [None]:
plt = cw.plot([])
for i in range(10):
    plt *= cw.plot(new_proj.waves[i])
    
plt

Some numpy functions will be useful here, so we'll convert our ChipWhisperer project to numpy arrays:

In [None]:
trace_array = np.array([new_proj.waves[i] for i in range(len(new_proj.traces))])
textin_array = np.array([new_proj.textins[i] for i in range(len(new_proj.traces))])

We need to decrypt what we sent to the device to get one half of the input to the XOR:

In [None]:
knownkey = [0x94, 0x28, 0x5D, 0x4D, 0x6D, 0xCF, 0xEC, 0x08, 0xD8, 0xAC, 0xDD, 0xF6, 0xBE, 0x25, 0xA4, 0x99,
            0xC4, 0xD9, 0xD0, 0x1E, 0xC3, 0x40, 0x7E, 0xD7, 0xD5, 0x28, 0xD4, 0x09, 0xE9, 0xF0, 0x88, 0xA1]

knownkey = bytes(knownkey)
dr = []
aes = AES.new(knownkey, AES.MODE_ECB)
for i in range(len(new_proj.traces)):
    ct = bytes(textin_array[i])
    pt = aes.decrypt(ct)
    dr.append(list(pt))  # decryption result as a list of 16 ints

The basic idea is the same as the DPA attack: guess a bit and group traces based on that. We'll do the first byte here as an example. For each of the bits, you should see roughly half the traces fall into each group:

In [None]:
grouped_byte_traces = []
byte = ??? # start with zero, then come back later and change

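for bit in range(8):
    # Index 0 holds traces where the DR bit is 0, index 1 where it is 1,
    # matching the grouping used in the final IV recovery cell
    grouped_bit_traces = [], []
    for i in range(len(new_proj.traces)):
        if (dr[i][byte] & (1 << bit)):
            grouped_bit_traces[1].append(trace_array[i])
        else:
            grouped_bit_traces[0].append(trace_array[i])
    grouped_byte_traces.append(grouped_bit_traces)
    print(len(grouped_bit_traces[0]), len(grouped_bit_traces[1]))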

Now we need to figure out where the XOR operation is happening. You'll need to use the shape of the plot and the following plot of the difference of means for each bit. Each colour represents a different bit. For the correct part on the plot, you should see a distinct separation, with some bits peaking above zero, and others peaking below zero. If you see some colours in between peaks, it is probably not the right location. Repeat this for a few bytes - the location should change. Note down this change, as you'll have to use it to adjust the analysis location later:

In [None]:
# Find averages and differences
diffs = []
for i in range(8):
    means = np.average(grouped_byte_traces[i][0], axis=0), np.average(grouped_byte_traces[i][1], axis=0)
    diffs.append(means[1] - means[0])

In [None]:
# Plot the difference of means for each bit
colours = ["red", "blue", "green", "black"]
plt = cw.plot([])
for i in range(8):
    plt *= cw.plot(diffs[i]).opts(color=colours[i%4])
plt
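
As a rough aid to reading the plot, sample positions can be ranked by how clearly all eight bits separate at once. The heuristic below is an assumption for illustration, not part of the lab: it scores each sample by the smallest absolute difference of means across the bits, so a position only scores well if every bit shows a peak there. The name `rank_xor_candidates` is hypothetical, and visual inspection remains the final check.

```python
def rank_xor_candidates(diffs, top=3):
    """Rank sample indices by the weakest |difference of means| across all bits."""
    n_samples = len(diffs[0])
    scores = [min(abs(d[i]) for d in diffs) for i in range(n_samples)]
    order = sorted(range(n_samples), key=lambda i: scores[i], reverse=True)
    return order[:top]

# Synthetic example: at sample 3 every bit separates (some up, some down);
# at sample 5 seven bits show a large peak but one bit shows nothing
example = []
for bit in range(8):
    row = [0.0] * 6
    row[3] = 1.0 if bit % 2 else -1.0
    row[5] = 5.0 if bit < 7 else 0.0
    example.append(row)
print(rank_xor_candidates(example, top=1))
```

Passing the `diffs` list computed above should work the same way, since only indexing and `abs()` are used on each entry.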

Fill in the position of the XOR, as well as how much it changes for each byte. All we're doing here is going byte by byte and bit by bit, and seeing if the difference in means is greater than or less than zero:

In [None]:
btldr_IV = [0] * 16
for byte in range(16):
    location = ??? + byte * ???
    iv = 0
    for bit in range(8):
        pt_bits = [((dr[i][byte] >> (7-bit)) & 0x01) for i in range(len(new_proj.traces))]

        # Split traces into 2 groups
        groupedPoints = [[] for _ in range(2)]
        for i in range(len(new_proj.traces)):
            groupedPoints[pt_bits[i]].append(trace_array[i][location])
            
        means = []
        for i in range(2):
            means.append(np.average(groupedPoints[i]))
        diff = means[1] - means[0]
        
        iv_bit = 1 if diff > 0 else 0
        iv = (iv << 1) | iv_bit
        
        print(iv_bit, end = " ")
        
    print("{:02X}".format(iv))
    btldr_IV[byte] = iv
    
print(btldr_IV)

Finally, we can do the full decryption!

In [None]:
cipher = AES.new(bytes(btldr_key), AES.MODE_ECB)
first_pt = cipher.decrypt(encrypted_firmware[:16])
first_pt = [first_pt[i] ^ btldr_IV[i] for i in range(16)]
print(bytearray(first_pt))

As you can see, the first line of firmware is `deadbeefaabbccddeeff0011`.