# Timing Attack on RSA with Montgomery Multiplication
Attacking the data-dependent final subtraction in Montgomery reduction, which creates a timing side channel.

In [None]:
import gmpy2
import random
import numpy as np

In [None]:
# Fixed seed: the generated key and (later) the message set are reproducible.
random.seed(2023)
L = 64     # bit length of modulus
R = 1 << L # Montgomery domain scaling factor, R = 2^L
# generate RSA modulus n
# Draw two primes and retry until their product has exactly L bits.
n = 0
while n.bit_length() != L:
    p = int(gmpy2.next_prime(random.randint(1<<(L//2-1), 1<<(L//2)))) # roughly L/2-bit random prime
    q = int(gmpy2.next_prime(random.randint(1<<(L//2-1), 1<<(L//2))))
    n = p*q
print(f"{R = }, {n = }, {p = }, {q = }, {n.bit_length() = }")
e = 31 # public exponent
# NOTE(review): pow(e, -1, m) raises ValueError when e is not invertible mod m;
# there is no retry for that case here.
d = pow(e, -1, (p-1)*(q-1)) # private exponent
ni = R - pow(n, -1, R) # N' = negative inverse of n mod R
R2 = R * R % n # R^2 mod n, used to convert values into the Montgomery domain
print(f"{e = }, {d = }, {ni = }, {R2 = }")

# Montgomery reduction
def REDC(T):
    """Return T * R^-1 mod n using the globals R, n and ni (= -n^-1 mod R).

    For 0 <= T < R*n the intermediate value is below 2*n, so at most one
    final subtraction of n is needed to land in [0, n).
    """
    low_part = T % R
    factor = (low_part * ni) % R
    # T + factor*n is divisible by R by construction, so the division is exact.
    reduced = (T + factor * n) // R
    return reduced - n if reduced >= n else reduced

# Instrumented Montgomery reduction
# returns also if final subtraction occurred
def REDC2(T):
    """Montgomery-reduce T and report whether the final subtraction happened.

    Returns a pair (value, flag): value is the reduced result, flag is 1 when
    n had to be subtracted at the end and 0 otherwise.
    """
    factor = ((T % R) * ni) % R
    reduced = (T + factor * n) // R
    if reduced < n:
        return (reduced, 0)
    return (reduced - n, 1)

# Montgomery multiplication
def MonMult(a, b):
    """Return the Montgomery product of a and b, i.e. REDC(a*b)."""
    product = a * b
    return REDC(product)

# Instrumented Montgomery multiplication
# returns also if final subtraction occurred
def MonMult2(a, b):
    """Return REDC2(a*b): the Montgomery product and the final-subtraction flag (0/1)."""
    product = a * b
    return REDC2(product)

# Square and multiply
def sm(a, d):
    """Compute a**d mod n with left-to-right square-and-multiply in the Montgomery domain.

    Args:
        a: base (an integer; presumably 0 <= a < n, as the callers use it).
        d: non-negative integer exponent.

    Returns:
        a**d mod n as an ordinary (non-Montgomery) integer.
    """
    if d == 0:
        # d has no bits to scan: without this guard the loop is skipped and
        # the base itself (a mod n) would be returned instead of a**0 = 1.
        return 1 % n
    aa = MonMult(a, R2)  # convert a into the Montgomery domain: a*R mod n
    x = aa               # accounts for the leading 1 bit of d
    # Scan the remaining bits from the second most significant down to bit 0.
    for i in reversed(range(d.bit_length() - 1)):
        x = MonMult(x, x)
        if (d >> i) & 1:
            x = MonMult(x, aa)
    # Multiplying by 1 strips the factor R, leaving the Montgomery domain.
    return MonMult(x, 1)

# Instrumented Square and multiply
# returns also the number of final subtractions
def sm2(a, d):
    """Compute a**d mod n and count the final subtractions in the main loop.

    The count stands in for the execution time an attacker would measure.
    Only the squarings and multiplications inside the exponent-scanning loop
    are counted; the conversions into and out of the Montgomery domain use
    the uninstrumented MonMult and do not contribute.

    Args:
        a: base (an integer; presumably 0 <= a < n, as the callers use it).
        d: non-negative integer exponent.

    Returns:
        A pair (a**d mod n, number of final subtractions).
    """
    if d == 0:
        # d has no bits to scan: without this guard the loop is skipped and
        # the base itself (a mod n) would be returned instead of a**0 = 1.
        return 1 % n, 0
    cnt = 0
    aa = MonMult(a, R2)  # convert a into the Montgomery domain: a*R mod n
    x = aa               # accounts for the leading 1 bit of d
    # Scan the remaining bits from the second most significant down to bit 0.
    for i in reversed(range(d.bit_length() - 1)):
        x, t = MonMult2(x, x)
        cnt += t
        if (d >> i) & 1:
            x, t = MonMult2(x, aa)
            cnt += t
    # Multiplying by 1 strips the factor R, leaving the Montgomery domain.
    return MonMult(x, 1), cnt
# Sanity check: the Montgomery square-and-multiply should agree with Python's
# built-in modular exponentiation (the cell is expected to display True).
sm(7, 5) == pow(7, 5, n)


## Prepare and measure 2 servers

In [None]:
# Two simulated servers: their private exponents equal d except for bit k-2
# (the bit right below the leading one), which is forced to 0 on server 0
# and to 1 on server 1.
print(f"{d.bit_length() = }")
probe_mask = 1 << (d.bit_length() - 2)
d0 = d & ~probe_mask
d1 = d | probe_mask
print(
    f"Original d   {d:b}",
    f"Server 0: d0 {d0:b}",
    f"Server 1: d1 {d1:b}",
    sep="\n",
)

In [None]:
number_of_messages = 10000
# Random messages in [0, n-1], stored with dtype=object so they stay Python ints.
msgs = np.array(
    [random.randint(0, n - 1) for _ in range(number_of_messages)],
    dtype=object,
)
# The "time" of each request is the number of final subtractions reported by sm2.
print("Sending messages to server 0")
times0 = np.array([sm2(c, d0)[1] for c in msgs])
print("Sending messages to server 1")
times1 = np.array([sm2(c, d1)[1] for c in msgs])
print("Responses received and times recorded")

## Attacking Multiplication

In [None]:
# Oracle about the presence of a final subtraction in c^2 * c (in the multiplication)
def orak(c):
    """Return 1 if multiplying c^2 by c in the Montgomery domain needs a final subtraction, else 0."""
    c_mont = MonMult(c, R2)                     # c in the Montgomery domain
    squared = MonMult(c_mont, c_mont)           # c^2
    _, subtracted = MonMult2(squared, c_mont)   # c^2 * c, keep only the flag
    return subtracted

In [None]:
print("Computing oracle about multiplication final subtraction")
# One oracle bit (0/1) per message.
oo = np.array([orak(c) for c in msgs])
oo.shape

In [None]:
# Attack server 0
# Split the measured times by the oracle's prediction and compare the means.
F10, F20 = times0[oo == 1], times0[oo == 0]
print(F10.shape, F20.shape)
F10.mean() - F20.mean()
# A small difference of means indicates d_{k-2} = 0
# Can you tell?

In [None]:
# Attack server 1
# Split the measured times by the oracle's prediction and compare the means.
F11, F21 = times1[oo == 1], times1[oo == 0]
print(F11.shape, F21.shape)
F11.mean() - F21.mean()
# A large difference of means indicates d_{k-2} = 1
# Can you tell?

## Attacking Squaring
Complete the attack on the squaring step for both servers and show that you have recovered a secret bit of each server's private exponent.

In [None]:
def ora1(c):
    """Squaring oracle under the hypothesis d_{k-2} = 1.

    Replays the first steps of the exponentiation for that hypothesis:
    square c, multiply by c, then square the result. Returns 1 if that last
    squaring needs a final subtraction, else 0.
    """
    cc = MonMult(c, R2)          # convert to the Montgomery domain
    tmp = MonMult(cc, cc)        # c^2
    tmp = MonMult(tmp, cc)       # c^3 (the multiplication done when the bit is 1)
    tmp, t = MonMult2(tmp, tmp)  # next squaring: record its final subtraction
    return t
def ora2(c):
    """Squaring oracle under the hypothesis d_{k-2} = 0.

    Replays the first steps of the exponentiation for that hypothesis:
    square c, skip the multiplication, then square again. Returns 1 if that
    second squaring needs a final subtraction, else 0.
    """
    cc = MonMult(c, R2)          # convert to the Montgomery domain
    tmp = MonMult(cc, cc)        # c^2 (no multiplication when the bit is 0)
    tmp, t = MonMult2(tmp, tmp)  # next squaring: record its final subtraction
    return t

In [None]:
# Evaluate the ora1 oracle on every message.
oo1 = np.vectorize(ora1)(msgs)
# oo2 = ...   (TODO: evaluate ora2 on the messages in the same way)
# Server 0 timings of the messages for which ora1 returned 1.
F10 = times0[oo1 == 1]
# ...   (TODO: the remaining partitions and the comparison of means, for both servers)

In [None]:
# ...