# SPA against RSA on XMEGA (8-bit implementation)

Supported setups:

SCOPES:

* OPENADC

PLATFORMS:

* CWLITEXMEGA or CW303

Note that this *only* works with an XMEGA target; this tutorial does not work with any other platform. This is because the RSA implementation in use is `avr-crypto-lib`, which has AVR assembly code to accelerate certain routines. A later tutorial will demonstrate a similar (but not *exactly* the same) attack on the MBED-TLS RSA implementation.

In [None]:
# Capture hardware, target board and crypto library used for the build below.
SCOPETYPE = "OPENADC"
PLATFORM = "CWLITEXMEGA"  # "CW303" is the other supported platform
CRYPTO_TARGET = "AVRCRYPTOLIB"

In [None]:
%%bash -s "$PLATFORM" "$CRYPTO_TARGET"
# Build the simpleserial-rsa firmware. $1 and $2 are the PLATFORM and
# CRYPTO_TARGET values passed in with "-s" on the magic line above.
cd ../../../hardware/victims/firmware/simpleserial-rsa
make PLATFORM=$1 CRYPTO_TARGET=$2

In [None]:
# Run the shared ChipWhisperer setup notebook. It presumably defines `cw`,
# `scope`, `target` and `prog`, which later cells use — TODO confirm.
%run "../../Setup_Scripts/Setup_Generic.ipynb"

In [None]:
# Firmware image produced by the `make` cell. The file name follows PLATFORM,
# so selecting CW303 (the other supported platform) flashes the matching image
# instead of a hard-coded CWLITEXMEGA build.
fw_path = '../../../hardware/victims/firmware/simpleserial-rsa/simpleserial-rsa-%s.hex' % PLATFORM

In [None]:
# Flash the firmware image built above onto the target board.
cw.program_target(scope, prog, fw_path)

## Communicating With Target and Testing

In [None]:
# ADC clock source; "clkgen_x1" presumably samples at 1x the generated
# target clock — confirm in the ChipWhisperer scope documentation.
scope.clock.adc_src = "clkgen_x1"
# Number of ADC samples recorded per capture (the analysis below works on
# offsets within this window).
scope.adc.samples = 10000

In [None]:
import matplotlib.pylab as plt
import matplotlib
import numpy as np


In [None]:
def capture_RSA_trace(scope, target, text):
    """Run one RSA operation on the target and return its power trace.

    Arms the scope, sends ``text`` with the simpleserial ``'p'`` command,
    waits for the capture to finish and for the target's ack, then returns
    the captured trace.

    Args:
        scope: ChipWhisperer scope (uses ``arm``, ``capture``,
            ``get_last_trace`` and, for OpenADC, ``adc.trig_count``).
        target: simpleserial target the command is written to.
        text: payload sent with the ``'p'`` command.

    Returns:
        Whatever ``scope.get_last_trace()`` returns for this capture.

    Raises:
        TimeoutError: if ``scope.capture()`` reports a timeout. Returning a
            message string here would be silently sliced/plotted by the
            analysis cells as if it were a trace.
    """
    scope.arm()
    target.simpleserial_write('p', text)

    # scope.capture() returns a truthy value when the capture timed out.
    if scope.capture():
        raise TimeoutError('Timeout happened')

    # Report the number of samples counted while the trigger was active.
    if SCOPETYPE == "OPENADC":
        print('#samples: ', scope.adc.trig_count)

    target.simpleserial_wait_ack()
    return scope.get_last_trace()

### Breaking RSA 1

In [None]:
key = bytearray([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x00, 0xab])
trace = capture_RSA_trace(scope, target, key)

%matplotlib ipympl
plt.figure()

plt.plot(trace, 'r')

In [None]:
%matplotlib ipympl
plt.figure()

#ref_trace = trace[3600:4100]
ref_trace = trace [3000:3600]
plt.plot(ref_trace, 'b')

In [None]:
import numpy as np

# Sum-of-absolute-differences (SAD) match: slide a 600-sample template taken
# at `start` over the whole trace. Low values mark offsets where the trace
# resembles the template.
start = 3000
rsa_one = trace[start:(start + 600)]
window = len(rsa_one)
# `+ 1` includes the last offset whose window still fits entirely in the
# trace (offset len(trace) - window).
diffs = [np.sum(np.abs(trace[i:(i + window)] - rsa_one))
         for i in range(len(trace) - window + 1)]

plt.figure()
plt.plot(diffs)
plt.title('SAD Match for RSA')
plt.ylabel('SAD Difference')
plt.xlabel('Offset')

In [None]:
# All offsets whose SAD score is below 34.
sad_threshold = 34.0
times = np.where(np.array(diffs) < sad_threshold)[0]
print(len(times), 'times: ', times)

# Ignore matches before offset 2000.
times = times[times >= 2000]
print(len(times), 'times: ', times)

In [None]:
# Gaps between consecutive SAD matches. Gaps below 2 samples are adjacent
# offsets of the same match peak, so they are dropped.
deltalist = [gap for gap in np.diff(times) if gap >= 2]

print(len(deltalist), 'deltas: ', deltalist)


In [None]:
# Merge the raw gaps into one delta per key bit.
# When two peaks belong to one bit (a '1' bit), keep the first peak by adding
# the short gap to the following one, so the bit always reads as "0 1" and
# never as "1 0".
threshold = 220           # gaps shorter than this are the first half of a '1' bit
add_to_last_delta = 600   # padding used when the final gap is short (no gap follows it)
add_to_first_delta = 150  # the first bit is detected as 0 but is always 1
deltas = []

i = 0
while i < len(deltalist):
    if deltalist[i] < threshold:
        if i + 1 < len(deltalist):
            deltas.append(deltalist[i] + deltalist[i + 1])
        else:
            # Last gap is short -> last bit is 1 -> turn it into a "long" delta.
            deltas.append(deltalist[i] + add_to_last_delta)
        i += 2
    else:
        deltas.append(deltalist[i])
        i += 1

# The first bit is detected as 0 but is always 1. Guard against an empty list
# (no matches found) so the cell reports that instead of raising IndexError.
if deltas:
    deltas[0] = deltas[0] + add_to_first_delta

print(len(deltas), 'merged deltas: ', deltas)

In [None]:
# One marker per recovered bit: x = merged time delta, y = bit index.
bit_index = range(len(deltas))
fig = plt.figure()
plt.plot(deltas, bit_index, 'or')
plt.grid(True)
plt.title('A Learned Comparison of RSA Execution Time')
plt.ylabel('Processing Bit Number')
plt.xlabel('Time Delta (based on SAD Match)')

In [None]:
# Decode 1: merged deltas longer than 750 samples are '1' bits.
key = "".join("1" if d > 750 else "0" for d in deltas)
print(key)
print("%04X" % int(key, 2))

# Decode 2: raw gaps between SAD matches (< 23, offsets >= 2000 only).
# Gaps below 2 samples belong to the same peak; gaps above 800 are '1' bits.
times = np.where(np.array(diffs) < 23)[0]
times = times[np.where(times >= 2000)]

bits = []
for first, second in zip(times[:-1], times[1:]):
    gap = second - first
    print(gap)
    if gap < 2:
        continue
    bits.append("1" if gap > 800 else "0")
# The last bit is not detectable from the gaps, so a 0 is appended.
key = "".join(bits) + "0"
print(key)
print("%04X" % int(key, 2))

## Breaking RSA
Now that we have a target we can capture power traces from, how do we break RSA? The easiest way is actually a "single-trace" attack. Let's capture a single RSA trace here:

In [None]:
# Same 16-byte payload as before: fifteen zero bytes, then 0xab.
key = bytearray(b"\x00" * 15 + b"\xab")
trace = capture_RSA_trace(scope, target, key)

In [None]:
%matplotlib ipympl
plt.figure()

plt.plot(trace, 'r')

In [None]:
%matplotlib ipympl
plt.figure()

#ref_trace = trace[3600:4100]
ref_trace = trace [3000:3600]
plt.plot(ref_trace, 'b')

In [None]:
import ipywidgets as widgets
import matplotlib.pyplot as plt
from IPython.display import display
import numpy as np

# Interactive check of the SAD match: slide `ref_trace` along `trace` and show
# the point-wise difference plus its SAD score. The slider stops at
# len(trace) - len(ref_trace) so the compared slice is always as long as
# ref_trace; larger offsets produced a shorter slice and a shape mismatch.
@widgets.interact(offset=(0, len(trace) - len(ref_trace)))
def plotsad(offset=3000):
    """Overlay ``ref_trace`` on ``trace`` at ``offset`` and print the SAD."""
    window = range(offset, offset + len(ref_trace))

    plt.figure()
    plt.plot(trace, 'r')
    plt.plot(window, ref_trace, 'b', alpha=0.6)
    plt.show()

    plt.figure()
    diff = ref_trace - trace[offset:(offset + len(ref_trace))]
    plt.plot(window, diff, 'g', alpha=0.6)
    plt.show()
    print(np.sum(abs(diff)))
    

In [None]:
import numpy as np

# Sum-of-absolute-differences (SAD) match: slide a 600-sample template taken
# at `start` over the whole trace. Low values mark offsets where the trace
# resembles the template.
start = 3000
rsa_one = trace[start:(start + 600)]
window = len(rsa_one)
# `+ 1` includes the last offset whose window still fits entirely in the
# trace (offset len(trace) - window).
diffs = [np.sum(np.abs(trace[i:(i + window)] - rsa_one))
         for i in range(len(trace) - window + 1)]

plt.figure()
plt.plot(diffs)
plt.title('SAD Match for RSA')
plt.ylabel('SAD Difference')
plt.xlabel('Offset')

In [None]:
# Original ("og") version: every gap between consecutive matches with SAD < 20.
times = np.where(np.array(diffs) < 20.0)[0]
print(len(times), 'times: ', times)
deltalist = list(np.diff(times))
print(len(deltalist), 'deltas: ', deltalist)

In [None]:
# "Optimized?" version: a gap shorter than 300 samples is merged with the
# following match (presumably so each entry covers one key bit — confirm).

times = np.where(np.array(diffs) < 21)[0]
print(len(times), 'times: ', times)
deltalist = []
i = 0
while i < len(times)-1:
    y = i+1
    delta = times[y] - times[i]
    print(delta, times[i], times[y])
    # Short gap: measure to the match after next instead and skip one match.
    if delta < 300:
        y = y + 1
        if y > len(times)-1:
            # No match left to merge with: stop (this short gap is not appended).
            break
        delta = times[y] - times[i]
        print('N: ', delta, times[i], times[y])
        i+=1
    i+=1
        
    deltalist.append(delta)

print(len(deltalist), 'deltas: ', deltalist)

And we can then plot the time deltas:

In [None]:
# One marker per delta: x = time delta, y = bit index.
bit_index = range(len(deltalist))
fig = plt.figure()
plt.plot(deltalist, bit_index, 'or')
plt.grid(True)
plt.title('A Learned Comparison of RSA Execution Time')
plt.ylabel('Processing Bit Number')
plt.xlabel('Time Delta (based on SAD Match)')

In [None]:
import numpy as np
key = ""
# Offsets of all SAD matches below 23.
times = np.where(np.array(diffs) < 23.0)[0]

# Walk the matches. A short gap (< 300) is merged with the next match: the
# delta is measured to the match after next and one extra match is consumed.
# Deltas longer than 800 decode to '1', all others to '0'.
i = 0
while i < len(times)-1:
    y = i+1
    delta = times[y] - times[i]
    if delta < 300:
        y = y + 1
        if y > len(times)-1:
            # No match left to merge with; advancing i ends the loop.
            i+=1
            continue
        delta = times[y] - times[i]
        i+=1
    i+=1
    if delta > 800:
        key += "1"
    else:
        key += "0"
    
# The last bit is not recovered from the gaps; append 0.
key += "0"
print(key)
print("%04X"%int(key, 2))

In [None]:
import numpy as np

# Decode without merging: every gap between matches with SAD < 21 that is
# longer than 800 samples is a '1' bit, every other gap a '0' bit.
times = np.where(np.array(diffs) < 21)[0]
key = "".join("1" if gap > 800 else "0" for gap in np.diff(times))
# The last bit is not detectable from the gaps, so just append a 0.
key += "0"
print(key)
print("%04X" % int(key, 2))

Hopefully that recovered the encryption key you set earlier! The last caveat is that the *last bit* isn't recovered. Can you figure out a way to recover it? Why isn't it recovered?

## Conclusion

This tutorial has demonstrated the use of the power side-channel for attacking an 8-bit RSA implementation. We attacked it both by using a SAD match to find the interesting points and by applying a bandwidth-specific filter to make it more obvious whether a section corresponds to a 1 or a 0 (the filter-based attack is not included in this version of the notebook).

In [None]:
# Disconnect from the scope, then from the target.
scope.dis()
target.dis()

## Tests

In [None]:
# `key_filt` would come from the band-pass-filter attack, which no cell of this
# notebook defines. Only check it when an earlier cell produced it, instead of
# always failing with a NameError.
if "key_filt" in globals():
    assert int(key_filt, 2) == 0x8AB0, "Failed to break key with filter, adjust maximum"
else:
    print("Skipping filter check: key_filt was not computed in this notebook")

In [None]:
# Key value the SAD-match decoding is expected to recover.
expected_key = 0x8AB0
assert int(key, 2) == expected_key, "Failed to break key with SAD Match"