# Part 1, Topic 2: 1.5 Round AES Attack

---
NOTE: This lab references some (commercial) training material on [ChipWhisperer.io](https://www.ChipWhisperer.io). You can freely execute and use the lab per the open-source license (including using it in your own courses if you distribute similarly), but you must maintain notice about this source location. Consider joining our training course to enjoy the full experience.

---

**SUMMARY:** *In the last lab, we made a small modification to the code that allowed us to mount a simple AES round-skip fault. In the unmodified code, however, the loop has no check at its beginning, which prevented us from skipping the first AES round. In this lab, we'll break the original code using an attack generated by https://github.com/cbouilla/AES-attacks-finder.*

**LEARNING OUTCOMES:**
* Identifying the end of the first AES round

In [None]:
SCOPETYPE = 'OPENADC'
PLATFORM = 'CWLITEARM'
CRYPTO_TARGET='TINYAES128C'

In [None]:
%%bash -s "$PLATFORM" "$CRYPTO_TARGET"
cd ../../hardware/victims/firmware/simpleserial-aes
make PLATFORM=$1 CRYPTO_TARGET=$2

In [None]:
%run "../Setup_Scripts/Setup_Generic.ipynb"

In [None]:
fw_path = "../../hardware/victims/firmware/simpleserial-aes/simpleserial-aes-{}.hex".format(PLATFORM)
cw.program_target(scope, prog, fw_path)

In [None]:
if PLATFORM == "CWLITEXMEGA":
    def reboot_flush():            
        scope.io.pdic = False
        time.sleep(0.1)
        scope.io.pdic = "high_z"
        time.sleep(0.1)
        #Flush garbage too
        target.flush()
else:
    def reboot_flush():            
        scope.io.nrst = False
        time.sleep(0.05)
        scope.io.nrst = "high_z"
        time.sleep(0.05)
        #Flush garbage too
        target.flush()

Again, we'll start by collecting a reference ciphertext and power trace:

In [None]:
scope.clock.adc_src = "clkgen_x1"
reboot_flush()
scope.arm()
target.simpleserial_write('p', bytearray([0]*16))
ret = scope.capture()
if ret:
    print("No trigger!")

wave = scope.get_last_trace()

output = target.simpleserial_read_witherrors('r', 16)
gold_ct = output['payload']

print(gold_ct)

In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt
plt.figure()
plt.plot(wave[:2000])
plt.show()

Identify a good glitch range from the plot. We're still attacking just after an AddRoundKey operation, so look at the second AddRoundKey in the trace.

In [None]:
glitch_loc = range(870, 920)

In [None]:
if scope._is_husky:
    scope.glitch.enabled = True
scope.glitch.clk_src = "clkgen"
scope.glitch.output = "clock_xor"
scope.glitch.trigger_src = "ext_single"
scope.glitch.repeat = 1
scope.io.hs2 = "glitch"
# These width/offset numbers are for CW-Lite/Pro; for CW-Husky, convert as per Fault 1_1:
scope.glitch.width = 3
scope.glitch.offset = -12.8
print(scope.glitch)

Again, we'll use the power trace to check when we've got a good glitch. This time, you should see the first round complete before the encryption ends early:

In [None]:
from tqdm.notebook import tqdm, trange
wave = None
import logging
ktp = cw.ktp.Basic()
logging.getLogger().setLevel(logging.ERROR)
reboot_flush()
for i in trange(min(glitch_loc), max(glitch_loc) + 1):
    scope.adc.timeout = 0.2
    scope.glitch.ext_offset = i
    ack = None
    while ack is None:
        target.simpleserial_write('k', ktp.next()[0])
        ack = target.simpleserial_wait_ack()
        if ack is None:
            reboot_flush()
            time.sleep(0.1)
    
    scope.arm()
    
    pt = bytearray([0]*16)
    target.simpleserial_write('p', pt)
    ret = scope.capture()
    if ret:
        reboot_flush() # capture timed out: reset the target before the next attempt
        time.sleep(0.1)
        print("timed out!")
        continue
    output = target.simpleserial_read_witherrors('r', 16, glitch_timeout = 1)
    if output['valid']:
        if output['payload'] != gold_ct:
            print("Glitched at {}".format(i))
            wave = scope.get_last_trace()
            break
    else:
        reboot_flush()
        
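%matplotlib notebook
import matplotlib.pyplot as plt
if wave is None:
    # The loop finished without a valid, non-reference ciphertext
    print("No glitch found in glitch_loc; try a different range or glitch settings.")
else:
    plt.figure()
    plt.plot(wave)
    plt.show()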
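The loop above sorts every attempt into one of three outcomes: no valid response (reset the target), the reference ciphertext (no effect), or a valid but different ciphertext (a usable fault). As a minimal sketch, that decision can be pulled into a standalone function. `classify_result` is a hypothetical helper, not part of the ChipWhisperer API; it assumes `output` is the dict with `'valid'` and `'payload'` keys that the loop reads from `target.simpleserial_read_witherrors()`, with `None` standing in for a capture timeout.

```python
def classify_result(output, gold_ct):
    """Hypothetical helper mirroring the glitch loop's decision logic.

    `output` is assumed to be the dict the loop reads from
    target.simpleserial_read_witherrors() (keys 'valid' and 'payload'),
    or None to stand in for a capture timeout.
    """
    if output is None or not output.get('valid'):
        return 'reset'   # no usable response: the loop calls reboot_flush()
    if bytes(output['payload']) == bytes(gold_ct):
        return 'normal'  # encryption finished unaffected; try the next offset
    return 'fault'       # valid but different ciphertext: a candidate glitch


gold = bytes(16)
print(classify_result({'valid': True, 'payload': bytearray(16)}, gold))  # normal
print(classify_result({'valid': True, 'payload': bytearray([1] * 16)}, gold))  # fault
```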

In [None]:
glitched_ct0 = bytearray(output['payload'])

In [None]:
glitched_ct0

As it turns out, there are actually two possible `glitched_ct0`s you could have gotten here: in one, the second round key is added at the end of the encryption; in the other, the 0th round key is! This actually makes the attack simpler (and much faster), since we can recover the key from one plaintext and both of the possible faulty ciphertexts.
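To see why two different faulty outputs can appear, it helps to model the truncated encryption in software. The sketch below is a small pure-Python AES-128 (an illustration only, not the TinyAES code running on the target) plus a hypothetical `one_and_a_half_rounds()` function. It assumes the fault leaves the initial AddRoundKey and all of round 1 intact, runs SubBytes and ShiftRows of round 2, and then adds round key `final_key_index` (2 or 0, matching the two cases above).

```python
# Illustrative pure-Python AES-128 plus a hypothetical "1.5 round" fault model.
# This is NOT the TinyAES code running on the target.

def _xtime(a):
    # Multiply by x in GF(2^8), reducing modulo x^8 + x^4 + x^3 + x + 1
    a <<= 1
    return (a ^ 0x1B) & 0xFF if a & 0x100 else a

def _gmul(a, b):
    # Shift-and-add GF(2^8) multiplication built from repeated xtime
    p = 0
    while b:
        if b & 1:
            p ^= a
        a = _xtime(a)
        b >>= 1
    return p

def _rotl8(x, n):
    return ((x << n) | (x >> (8 - n))) & 0xFF

def _build_sbox():
    # S(x) = affine transform of the multiplicative inverse of x (0 maps to 0)
    sbox = []
    for x in range(256):
        inv = 0 if x == 0 else next(y for y in range(1, 256) if _gmul(x, y) == 1)
        sbox.append(inv ^ _rotl8(inv, 1) ^ _rotl8(inv, 2) ^ _rotl8(inv, 3) ^ _rotl8(inv, 4) ^ 0x63)
    return sbox

SBOX = _build_sbox()

def expand_key(key):
    """Return the 11 AES-128 round keys K0..K10, each a list of 16 ints."""
    words = [list(key[4 * i:4 * i + 4]) for i in range(4)]
    rcon = 1
    for i in range(4, 44):
        temp = list(words[i - 1])
        if i % 4 == 0:
            temp = [SBOX[b] for b in temp[1:] + temp[:1]]  # RotWord then SubWord
            temp[0] ^= rcon
            rcon = _xtime(rcon)
        words.append([a ^ b for a, b in zip(words[i - 4], temp)])
    return [sum(words[4 * r:4 * r + 4], []) for r in range(11)]

# State bytes are column-major: byte k sits at row k % 4, column k // 4.
def _sub_bytes(s):
    return [SBOX[b] for b in s]

def _shift_rows(s):
    # Row r is rotated left by r positions
    return [s[r + 4 * ((c + r) % 4)] for c in range(4) for r in range(4)]

def _mix_columns(s):
    out = []
    for c in range(4):
        a = s[4 * c:4 * c + 4]
        t = a[0] ^ a[1] ^ a[2] ^ a[3]
        out += [a[i] ^ t ^ _xtime(a[i] ^ a[(i + 1) % 4]) for i in range(4)]
    return out

def _add_round_key(s, k):
    return [a ^ b for a, b in zip(s, k)]

def encrypt(key, pt):
    """Full 10-round AES-128 encryption of one 16-byte block."""
    rk = expand_key(key)
    s = _add_round_key(list(pt), rk[0])
    for r in range(1, 10):
        s = _add_round_key(_mix_columns(_shift_rows(_sub_bytes(s))), rk[r])
    return bytes(_add_round_key(_shift_rows(_sub_bytes(s)), rk[10]))

def one_and_a_half_rounds(key, pt, final_key_index):
    """Hypothetical fault model: K0, a full round 1, then SubBytes/ShiftRows of
    round 2 followed by AddRoundKey with round key `final_key_index`."""
    rk = expand_key(key)
    s = _add_round_key(list(pt), rk[0])
    s = _add_round_key(_mix_columns(_shift_rows(_sub_bytes(s))), rk[1])
    return bytes(_add_round_key(_shift_rows(_sub_bytes(s)), rk[final_key_index]))

key = bytes.fromhex("2b7e151628aed2a6abf7158809cf4f3c")  # FIPS-197 example key
ct_k2 = one_and_a_half_rounds(key, bytes(16), 2)
ct_k0 = one_and_a_half_rounds(key, bytes(16), 0)
print(ct_k2.hex(), ct_k0.hex())
```

Under this model, XORing the two faulty outputs cancels the shared state and leaves only `K0 ^ K2`, which hints at why a pair of faulty ciphertexts is so much more useful than either one alone.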

Let's rerun our loop at the same glitch offset until we get another glitch that breaks out early but gives a ciphertext different from our first:

In [None]:
from tqdm.notebook import tqdm, trange
wave = None
import logging
ktp = cw.ktp.Basic()
logging.getLogger().setLevel(logging.ERROR)
reboot_flush()
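# Reuse the glitch offset `i` that produced the first fault. This loop runs until it
# sees a second, different faulty ciphertext; interrupt it if that takes too long.
while True:
    scope.adc.timeout = 0.2
    scope.glitch.ext_offset = i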
    ack = None
    while ack is None:
        target.simpleserial_write('k', ktp.next()[0])
        ack = target.simpleserial_wait_ack()
        if ack is None:
            reboot_flush()
            time.sleep(0.1)
    
    scope.arm()
    
    pt = bytearray([0]*16)
    target.simpleserial_write('p', pt)
    ret = scope.capture()
    if ret:
        reboot_flush() # capture timed out: reset the target before the next attempt
        time.sleep(0.1)
        print("timed out!")
        continue
    output = target.simpleserial_read_witherrors('r', 16, glitch_timeout = 1)
    if output['valid']:
        if output['payload'] != gold_ct:
            if output['payload'] != glitched_ct0:
                print("Glitched at {}".format(i))
                wave = scope.get_last_trace()
                break
    else:
        reboot_flush()
        


In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt
plt.figure()
plt.plot(wave)
plt.show()
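The second loop only stops once it sees a valid ciphertext that differs from both `gold_ct` and `glitched_ct0`, and it never gives up on its own. A hedged sketch of the same idea, generalized to collect any number of distinct faulty ciphertexts with an upper bound on attempts: `collect_distinct_faults` and `attempt` are hypothetical, with `attempt` standing in for one reset/arm/encrypt/read cycle that returns the response dict (or `None` on a timeout).

```python
def collect_distinct_faults(attempt, gold_ct, wanted=2, max_tries=10000):
    """Call `attempt()` until `wanted` distinct faulty ciphertexts are seen.

    `attempt` is a hypothetical callable performing one glitched encryption and
    returning a dict with 'valid' and 'payload' keys, or None on a timeout.
    """
    faults = []
    for _ in range(max_tries):
        output = attempt()
        if output is None or not output.get('valid'):
            continue  # the notebook's loop would call reboot_flush() here
        ct = bytes(output['payload'])
        if ct != bytes(gold_ct) and ct not in faults:
            faults.append(ct)
            if len(faults) == wanted:
                break
    return faults


# Usage with a canned sequence of responses instead of real hardware
responses = iter([
    None,
    {'valid': True, 'payload': bytearray(16)},        # matches gold_ct: ignored
    {'valid': True, 'payload': bytearray([1] * 16)},  # first fault
    {'valid': True, 'payload': bytearray([1] * 16)},  # duplicate: ignored
    {'valid': True, 'payload': bytearray([2] * 16)},  # second, distinct fault
])
print(collect_distinct_faults(lambda: next(responses), bytes(16)))
```

Bounding the number of tries means a bad glitch offset ends with a short (or empty) list rather than a cell that spins forever.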

In [None]:
print(output['payload'])
glitched_ct1 = bytearray(output['payload'])

We can now feed both faulty ciphertexts into our autogenerated solver. `Known` is an array of all the known information: both faulty ciphertexts and the plaintext (all 0's in our case):

In [None]:
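# Known layout (from the loop below): Known[0] = 1, Known[1..16] = glitched_ct1,
# Known[33..48] = glitched_ct0; Known[17..32] are left at 0, matching the all-zero plaintext.
Known = [0] * 49
Known[0] = 1

for i in range(4):
    for j in range(4):
        # Ciphertext bytes are transposed (i + 4*j -> 4*i + j) and stored back to front
        Known[48-(i*4+j)] = glitched_ct0[i+j*4]
        Known[16-(i*4+j)] = glitched_ct1[i+j*4]

from out2 import MakeTableMul2_8, Attack  # Python port of the generated C solver

MakeTableMul2_8()  # set up the solver's table before calling Attack()
kguess = Attack(Known)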
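The indexing in the cell above is easy to get wrong, for example when swapping the two ciphertexts. As a sketch, the same construction can be wrapped in a function, alongside a hypothetical `known_slot()` helper that reports which `Known` index holds a given ciphertext byte. Both only mirror the loop above; what each slot means to `Attack()` is defined by the generated solver, not by this code.

```python
def build_known(ct0, ct1):
    """Rebuild the 49-entry `Known` list the same way the cell above does."""
    known = [0] * 49
    known[0] = 1  # copied from the cell above; its meaning is defined by the solver
    for i in range(4):
        for j in range(4):
            p = i * 4 + j   # row-major position, counted back from the end of each block
            src = i + j * 4  # column-major byte index into the 16-byte ciphertext
            known[48 - p] = ct0[src]  # ct0 occupies Known[33..48]
            known[16 - p] = ct1[src]  # ct1 occupies Known[1..16]
    return known  # Known[17..32] stay 0 (the all-zero plaintext in this lab)


def known_slot(which, byte_index):
    """Hypothetical helper: Known index holding byte `byte_index` of ct0 (0) or ct1 (1)."""
    i, j = byte_index % 4, byte_index // 4  # inverts src = i + 4*j
    end = 48 if which == 0 else 16
    return end - (i * 4 + j)


print(known_slot(0, 0), known_slot(0, 1))  # 48 44
```

For example, byte 0 of `glitched_ct0` lands in `Known[48]` and byte 1 in `Known[44]`, because the loop stores the 4×4 state row by row while the ciphertext is laid out column by column. Swapping the ciphertexts is then just `build_known(glitched_ct1, glitched_ct0)`.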

If you don't get a solution, you'll probably just need to swap `glitched_ct0` and `glitched_ct1` (the order matters, but we don't know which faulty ciphertext is which). We can print our solution in a nicer format with the following:

In [None]:
for i in range(4):
    print([hex(kguess[i][j]) for j in range(4)])

## Variations on this Attack

This attack is specific to this implementation. What if, for example, we could break out at the same spot, but only the 2nd round key was used for the final AddRoundKey? In that case, we could run an attack with two plaintext/faulty-ciphertext pairs, though it is much slower than the attack we did here. Different situations have different requirements, and attacks quickly become infeasible the later in AES the fault lands (as you'd expect). The tool we used to generate the attack for this situation is available at https://github.com/cbouilla/AES-attacks-finder. If you're curious, try thinking up different glitch scenarios and use the tool to see whether attacks are possible! The tool outputs C code; for this attack, the generated C code was ported to Python with the help of `ctopy`.