<div style="text-align: right">Paul Novaes<br>August 2018</div> 

# RSA-3

The goal of this notebook is to present RSA in the special (and common) case where the exponent of the public key is $e=3$.

This simplifies the implementation and some of the math quite a bit. It also highlights the fact that RSA is based on the difficulty of computing the $k$-th root of a number modulo $n$. This difficulty is a bit surprising, because the same problem is easy over the ordinary integers.

Moreover, in this case the public key is just $n$ and the secret key is just $(p, q)$, where $n=pq$. This shows that the real secret is not some magic exponent, but the fact that one of the parties knows the factorization of $n$.
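To illustrate why cube roots are easy over the ordinary integers, here is a minimal sketch that finds an exact integer cube root by binary search. The helper `integer_cube_root` is hypothetical, written only for this illustration. The same approach fails modulo $n$, because the reduction hides the true cube.

```python
def integer_cube_root(c):
    """Return x such that x**3 == c, or None if c is not a perfect cube."""
    if c < 0:
        raise ValueError("c must be non-negative")
    lo, hi = 0, c
    # Binary search: roughly log2(c) iterations, fast even for very large c.
    while lo <= hi:
        mid = (lo + hi) // 2
        cube = mid ** 3
        if cube == c:
            return mid
        if cube < c:
            lo = mid + 1
        else:
            hi = mid - 1
    return None

m = 123456789123456789
print(integer_cube_root(m ** 3))         # 123456789123456789
# Modulo n = 55 the structure is lost: 4**3 = 64 = 9 (mod 55), and 9 is not
# a perfect cube, even though 4 is its cube root modulo 55.
print(integer_cube_root(pow(4, 3, 55)))  # None
```

Conversely, when $m^3 < n$ no reduction happens, and this function recovers $m$ directly from the ciphertext (for example `integer_cube_root(pow(2, 3, 55))` returns `2`). This is one reason why RSA with a small exponent is never used without padding in practice.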

## Public-Key Cryptography

A cryptosystem allows two people to communicate privately, even in the presence of an eavesdropper.

Anybody, for example Alice, who wants to receive private messages generates two keys:
* a public key $pk$, that she publishes for anybody to see (and use)
* a secret key $sk$, that she keeps to herself

Anybody, for example Bob, who wants to send a message $m$ to Alice uses an encryption function $E$ and the public key $pk$ to encrypt the message:

$$c=E(m, pk)$$

$E$ is such that recovering $m$ from $c$ is very difficult: $E$ is a __one-way function__, easy to compute but difficult to invert unless one knows the secret key.

Alice uses a decryption function $D$ and her secret key $sk$ to decrypt $c$:

$$m = D(c, sk)$$


## RSA-3 Cryptosystem

In RSA-3, Alice chooses a number $n=pq$, the product of two distinct primes of the form $6k + 5$ (that is, odd primes congruent to $2$ modulo $3$).
* the public key is $pk=n$
* the secret key is $sk=(p, q)$
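The condition on the primes can be checked by brute force. The nonzero residues modulo a prime $p$ form a cyclic group of order $p-1$, so cubing is one-to-one exactly when $3 \nmid p-1$; for primes $p \geq 5$ this means $p = 6k+5$, and then every residue has a unique cube root modulo $p$. The sketch below verifies this for the primes below 200; the helpers `is_prime` and `cubing_is_bijection` are hypothetical, written only for this check.

```python
def is_prime(k):
    """Naive trial division, fine for the small numbers used here."""
    return k >= 2 and all(k % d for d in range(2, int(k ** 0.5) + 1))

def cubing_is_bijection(p):
    """True if x -> x**3 mod p permutes {0, 1, ..., p - 1}."""
    return len({pow(x, 3, p) for x in range(p)}) == p

for p in range(5, 200):
    if is_prime(p):
        # For primes p >= 5, cubing is a bijection exactly when p = 6k + 5.
        assert cubing_is_bijection(p) == (p % 6 == 5)
print("checked all primes between 5 and 199")
```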

The encrypting and decrypting functions are defined by:
* $E_n(m) = m^3 \bmod n$
* $D_n(c) = \sqrt[3] c \bmod n$ 

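In fact, we will show that $\sqrt[3] c \bmod n = c^{(2\phi(n) + 1)/3} \bmod n$, where $\phi(n) = (p-1)(q-1)$. The exponent is an integer: $p - 1 = 6k + 4 \equiv 1 \pmod 3$, and likewise for $q$, so $\phi(n) \equiv 1 \pmod 3$ and hence $2\phi(n) + 1 \equiv 0 \pmod 3$.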

RSA-3 relies on the fact that computing the cube root of the ciphertext is easy for Alice, who knows $(p,q)$, but hard for Bob or an eavesdropper, who only know $n$.
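The decryption exponent $d = (2\phi(n)+1)/3$ satisfies $3d = 2\phi(n) + 1 \equiv 1 \pmod{\phi(n)}$ and $0 < d < \phi(n)$, so it coincides with the usual RSA private exponent $e^{-1} \bmod \phi(n)$ for $e = 3$, computed here from $(p, q)$ alone. A small sketch checking this, assuming Python 3.8+ for `pow(x, -1, m)`; the helper name `rsa3_exponent` is hypothetical.

```python
def rsa3_exponent(p, q):
    """Decryption exponent (2*phi(n) + 1) / 3, as computed in decrypt()."""
    phi = (p - 1) * (q - 1)
    # p = q = 2 (mod 3) gives phi = 1 (mod 3), so 2*phi + 1 is divisible by 3.
    assert (2 * phi + 1) % 3 == 0
    return (2 * phi + 1) // 3

for p, q in [(5, 11), (101, 107), (1013, 1019)]:
    phi = (p - 1) * (q - 1)
    d = rsa3_exponent(p, q)
    # 3*d = 2*phi + 1 = 1 (mod phi) and 0 < d < phi: d is the inverse of 3 mod phi.
    assert d == pow(3, -1, phi)  # pow(x, -1, m) requires Python 3.8+

print(rsa3_exponent(5, 11))                          # 27
print(pow(pow(7, 3, 55), rsa3_exponent(5, 11), 55))  # 7
```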

## Proof

Let's show that $$D_n(E_n(m)) \equiv_n m$$

We have $D_n(E_n(m)) \equiv_n (m^3)^{(2\phi(n)+1)/3} = m^{2\phi(n) + 1}$. Let's first work modulo $p$.

If $m \equiv_p 0$, both sides are $0$ modulo $p$. Otherwise, Fermat's Little Theorem gives $m^{p-1} \equiv_p 1$, and since $p-1$ divides $\phi(n) = (p-1)(q-1)$, we also have $m^{\phi(n)} \equiv_p 1$. Therefore:

$$D_n(E_n(m)) \equiv_p m^{2\phi(n) + 1} \equiv_p m \cdot m^{\phi(n)} \cdot m^{\phi(n)} \equiv_p m$$

So for any $m$, $D_n(E_n(m)) \equiv_p m$, and by the same argument $D_n(E_n(m)) \equiv_q m$.

Since $p$ and $q$ are distinct primes, the Chinese Remainder Theorem (CRT) gives $$D_n(E_n(m)) \equiv_n m$$
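The argument can be checked numerically on the small modulus $n = 5 \cdot 11$ used later in the notebook. The sketch below computes $m^{2\phi(n)+1} \bmod n$ directly and checks the congruences modulo $p$ and modulo $q$ separately, including messages divisible by $p$ or $q$ (such as $m = 10$), for which Fermat's Little Theorem does not apply modulo that prime.

```python
p, q = 5, 11
n = p * q
phi = (p - 1) * (q - 1)  # 40

for m in range(n):
    x = pow(m, 2 * phi + 1, n)  # D_n(E_n(m)) = m^(2*phi(n) + 1) mod n
    for r in (p, q):
        if m % r != 0:
            # Fermat's Little Theorem, with (r - 1) dividing phi(n).
            assert pow(m, phi, r) == 1
        # Congruent modulo r, also when r divides m (both sides are 0).
        assert x % r == m % r
    # By the CRT, agreeing modulo p and modulo q means agreeing modulo n.
    assert x == m

print(pow(10, 2 * phi + 1, n))  # 10: m = 10 is divisible by p = 5 but not by q
```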

## Implementation

To test whether a number is (probably) prime, we use the Miller-Rabin primality test:

In [1]:
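# The random module is fine for this demo; real keys should come from a
# cryptographically secure source such as the secrets module.
from random import randint

def is_miller_rabin_witness(a, n):
    """Return True if a proves that the odd number n is composite."""
    # Fermat test: for a prime n, a^(n-1) = 1 (mod n).
    if pow(a, n - 1, n) != 1:
        return True
    # Take successive square roots of a^(n-1) = 1 by halving the exponent:
    # for a prime n, the first value different from 1 must be -1 (i.e. n - 1).
    k = n - 1
    while k % 2 == 0:
        k //= 2
        res = pow(a, k, n)
        if res == n - 1:
            return False
        if res != 1:
            return True
    return False

def is_probable_prime(n, rounds=50):
    """False means n is certainly composite; True means n is very likely prime."""
    for _ in range(rounds):
        a = randint(1, n - 1)
        if is_miller_rabin_witness(a, n):
            return False
    return True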

Generating a random modulus $n$:

In [2]:
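def random_6n_5_prime(lo, hi):
    """Return a random (probable) prime of the form 6k + 5 in [lo, hi]."""
    while True:
        n = randint(lo, hi)
        # 6 * (n // 6) + 5 lies in [n, n + 5]: it never drops below lo,
        # but it can exceed hi, in which case we draw again.
        n = 6 * (n // 6) + 5
        if n > hi:
            continue
        if is_probable_prime(n):
            return n

def generate_random_mod(lo, hi):
    """Return distinct primes p, q of the form 6k + 5; the modulus is n = p * q."""
    while True:
        p = random_6n_5_prime(lo, hi)
        q = random_6n_5_prime(lo, hi)
        if p != q:
            return p, q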

Encrypting and decrypting functions:

In [3]:
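def encrypt(m, n):
    """E_n(m) = m^3 mod n."""
    return pow(m, 3, n)

def decrypt(c, p, q):
    """D_n(c) = c^((2*phi(n) + 1) / 3) mod n, the cube root of c modulo n = p * q."""
    phi = (p - 1) * (q - 1)
    d = (2 * phi + 1) // 3  # exact division since p, q = 5 (mod 6)
    return pow(c, d, p * q)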
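Since Alice knows $p$ and $q$, she can also take the cube root modulo each prime separately and recombine the results with the CRT, a standard speedup usually called RSA-CRT. This is a swapped-in technique, not the notebook's `decrypt`; the names `cube_root_mod_prime` and `decrypt_crt` are hypothetical. It relies on $(2p - 1)/3$ being an integer for $p \equiv 2 \pmod 3$, and on Python 3.8+ for `pow(p, -1, q)`.

```python
def cube_root_mod_prime(c, p):
    """Cube root of c modulo a prime p = 2 (mod 3), computed as c^((2p - 1) / 3)."""
    # 2p - 1 is divisible by 3 when p = 2 (mod 3), and
    # (c^((2p-1)/3))^3 = c * (c^(p-1))^2 = c (mod p) by Fermat's Little Theorem
    # (trivially true when c = 0 mod p).
    return pow(c, (2 * p - 1) // 3, p)

def decrypt_crt(c, p, q):
    """Hypothetical CRT-based variant of decrypt(c, p, q)."""
    r_p = cube_root_mod_prime(c % p, p)
    r_q = cube_root_mod_prime(c % q, q)
    # Combine: m = r_p (mod p) and m = r_q (mod q), with 0 <= m < p * q.
    h = ((r_q - r_p) * pow(p, -1, q)) % q  # pow(x, -1, m) requires Python 3.8+
    return r_p + p * h

p, q = 101, 107
n = p * q
for m in range(n):
    assert decrypt_crt(pow(m, 3, n), p, q) == m
print("decrypt_crt recovers all", n, "messages")
```

The exponents $(2p-1)/3$ and $(2q-1)/3$ are about the size of $p$ and $q$ rather than $n$, which is where the speedup comes from. It also makes explicit that the factorization is the only secret needed.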

## $p, q = 5, 11$

In [4]:
p, q = 5, 11
n = p * q
print('        message          cypher         decoded')
print()
for m in range(n):
    c = encrypt(m, n)
    d = decrypt(c, p, q)
    assert m == d
    # Mark fixed points, i.e. messages that encrypt to themselves.
    marker = "=" if c == d else ""
    print("%15d %15d %15d %s" % (m, c, d, marker))

        message          cypher         decoded

              0               0               0 =
              1               1               1 =
              2               8               2 
              3              27               3 
              4               9               4 
              5              15               5 
              6              51               6 
              7              13               7 
              8              17               8 
              9              14               9 
             10              10              10 =
             11              11              11 =
             12              23              12 
             13              52              13 
             14              49              14 
             15              20              15 
             16              26              16 
             17              18              17 
             18               2              18 
             19 

## Fixed Points

In [5]:
def print_fixed_point(p, q):
    """Print the messages m that RSA-3 leaves unchanged (m^3 = m mod n)."""
    n = p * q
    print('        message          cypher         decoded')
    print()
    for m in range(n):
        c = encrypt(m, n)
        d = decrypt(c, p, q)
        assert m == d
        if c == d:
            print("%15d %15d %15d =" % (m, c, d))

p, q = 101, 107
print_fixed_point(p, q)

        message          cypher         decoded

              0               0               0 =
              1               1               1 =
           1818            1818            1818 =
           1819            1819            1819 =
           3637            3637            3637 =
           7170            7170            7170 =
           8988            8988            8988 =
           8989            8989            8989 =
          10806           10806           10806 =


In [6]:
p, q = 1013, 1019
print_fixed_point(p, q)

        message          cypher         decoded

              0               0               0 =
              1               1               1 =
         172210          172210          172210 =
         172211          172211          172211 =
         344421          344421          344421 =
         687826          687826          687826 =
         860036          860036          860036 =
         860037          860037          860037 =
        1032246         1032246         1032246 =
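Both tables above have exactly 9 fixed points, and they come in pairs $m, n - m$. The reason: $m^3 \equiv_n m$ means that $m(m-1)(m+1)$ is divisible by both $p$ and $q$, so $m \equiv 0, 1$ or $-1$ modulo $p$, and likewise modulo $q$. The CRT turns these $3 \times 3$ combinations into 9 fixed points modulo $n$, and $(-m)^3 = -m^3$ explains the symmetry. The sketch below (hypothetical helper `fixed_points_crt`, assuming Python 3.8+ for modular inverses) builds them directly and compares with brute force.

```python
from itertools import product

def fixed_points_crt(p, q):
    """All m in [0, n) with m**3 = m (mod n), from m = 0, 1, -1 mod p and mod q."""
    n = p * q
    q_inv = pow(q, -1, p)  # inverse of q modulo p (Python 3.8+)
    p_inv = pow(p, -1, q)  # inverse of p modulo q
    points = set()
    for a, b in product((0, 1, -1), repeat=2):
        # CRT: the unique m modulo n with m = a (mod p) and m = b (mod q).
        points.add((a * q * q_inv + b * p * p_inv) % n)
    return sorted(points)

p, q = 101, 107
brute = [m for m in range(p * q) if pow(m, 3, p * q) == m]
assert fixed_points_crt(p, q) == brute
print(fixed_points_crt(p, q))
# [0, 1, 1818, 1819, 3637, 7170, 8988, 8989, 10806]
```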


In [7]:
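# Messages must be smaller than n to be decoded exactly, so draw moduli
# until n exceeds every message used below.
while True:
    p, q = generate_random_mod(10**4, 10**5)
    n = p * q
    if n > 10**9 + 9:
        break
print('        message          cypher         decoded')
print()
for m in range(10**9, 10**9 + 10):
    c = encrypt(m, n)
    d = decrypt(c, p, q)
    print("%15d %15d %15d " % (m, c, d))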

        message          cypher         decoded

     1000000000      6497629787      1000000000 
     1000000001      6241139314      1000000001 
     1000000002      4652135864      1000000002 
     1000000003      1730619443      1000000003 
     1000000004      4809103040      1000000004 
     1000000005      6555073678      1000000005 
     1000000006      6968531363      1000000006 
     1000000007      6049476101      1000000007 
     1000000008      3797907898      1000000008 
     1000000009       213826760      1000000009 


In [8]:
p, q = generate_random_mod(10, 3000)
n = p * q
for m in range(n):
    c = encrypt(m, n)
    mm = decrypt(c, p, q)
    assert(m == mm)
    if m == c:
        print("%10d %10d %10d " % (m, c, mm))

         0          0          0 
         1          1          1 
      6594       6594       6594 
    225549     225549     225549 
    225550     225550     225550 
    232143     232143     232143 
    232144     232144     232144 
    451099     451099     451099 
    457692     457692     457692 
