In [38]:

## imports

import numpy as np
import random


%run implementations/shors.ipynb


In [41]:
""" RSA:

Keypairs are generated from two distinct n-bit primes 'p' and 'q', specifically:

N = pq
tot = lcm(p-1, q-1)   (Carmichael's totient of N)
e = random number in [2, tot) that is coprime to 'tot'
d = e ** -1 (mod tot), i.e. (d * e) % tot == 1

Now, the public key is:

(N, e)

And the private key is:

(N, d)

Messages are integers in the range 0 <= message < N.

encrypt:
  message -> (message ** e) % N

decrypt:
  message -> (message ** d) % N
  
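This works because encrypt->decrypt effectively is:

  message -> ((message ** e) ** d) % N
  = (message ** (e * d)) % N

Since (e * d) % tot == 1, we can write e * d = k * tot + 1 for some integer k >= 0.
For N = pq with distinct primes p and q, message ** (k * tot + 1) == message (mod N)
holds for every message (apply Fermat's little theorem mod p and mod q, then combine
the two results with the Chinese Remainder Theorem), so that is the same as:

  = (message ** 1) % N == message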
  
So, we recover the message


"""

## UTILITY FUNCS

# greatest common divisor (Euclid's algorithm)
def gcd(a, b):
    """Return the greatest common divisor of ``a`` and ``b``.

    Repeatedly replaces (a, b) with (b, a % b) until the second value is 0;
    the first value is then the gcd. gcd(a, 0) == a.
    """
    while b != 0:
        a, b = b, a % b
    return a

assert gcd(3, 9) == 3
    
# least common multiple
def lcm(a, b):
    """Return the least common multiple of ``a`` and ``b``, as a * b // gcd(a, b)."""
    divisor = gcd(a, b)
    product = a * b
    return product // divisor

# extended GCD
def egcd(a, b):
    """Extended Euclidean algorithm.

    Returns ``(g, x, y)`` such that ``a * x + b * y == g``, where ``g`` is the
    gcd of ``a`` and ``b`` (non-negative when both inputs are non-negative).
    """
    if a != 0:
        # b == quotient * a + remainder; recurse on the smaller pair
        quotient, remainder = divmod(b, a)
        g, y, x = egcd(remainder, a)
        # remainder*y + a*x == g  =>  a*(x - quotient*y) + b*y == g
        return (g, x - quotient * y, y)
    return (b, 0, 1)

# modular inverse of 'a' mod 'm'
def modinv(a, m):
    """Return the modular inverse of ``a`` modulo ``m``.

    When it exists (and m > 1), the result is the x in [0, m) with
    (a * x) % m == 1. It is found with the extended Euclidean algorithm
    (``egcd``), which takes O(log m) steps rather than trying every
    candidate in [1, m).

    Args:
        a: value to invert; may be negative or >= m (it is reduced mod m first).
        m: modulus; must be positive.

    Raises:
        ValueError: if ``m <= 0``, or if gcd(a, m) != 1 (no inverse exists).
    """
    if m <= 0:
        raise ValueError("modinv({}, {}): modulus must be positive".format(a, m))

    # Reducing 'a' into [0, m) keeps egcd's gcd non-negative, so a negative
    # 'a' (e.g. -1 mod 5) is not wrongly rejected because egcd returned g == -1.
    g, x, _ = egcd(a % m, m)
    if g != 1:
        # One formatted message (passing several args to the exception
        # would render the message as a tuple of fragments).
        raise ValueError("modinv({}, {}) doesn't exist!".format(a, m))
    return x % m

assert modinv(2, 5) == 3
    

# generate keypairs
# returns (public, private) tuples of numbers to evaluate
def gen_keypair(n_bits=12):
    """Generate a toy RSA keypair from two distinct ``n_bits``-bit primes.

    Args:
        n_bits: bit length of each prime factor. Must be >= 3 so that two
            distinct odd primes of exactly that length exist (5 and 7 for 3 bits).

    Returns:
        ``(pubkey, prikey)`` where ``pubkey == (N, e)`` and ``prikey == (N, d)``.

    Raises:
        ValueError: if ``n_bits < 3``.
    """
    if n_bits < 3:
        raise ValueError("n_bits must be >= 3 to get two distinct odd primes")

    # whether or not a given number is a prime
    def is_prime(n):
        # hard coded examples
        if n == 2 or n == 3: return True
        if n < 2 or n % 2 == 0: return False
        if n < 9: return True
        if n % 3 == 0: return False

        # floor(sqrt(n))
        r = int(n**0.5)

        # try (6n-1) and (6n+1), as a wheel
        # up to the sqrt(n)
        i = 5

        while i <= r:
            if n % i == 0: return False
            if n % (i + 2) == 0: return False
            i += 6

        # nothing was divisible
        return True

    # draw a random prime with exactly n_bits bits, different from 'exclude'
    def random_prime(exclude=None):
        while True:
            # Set the top bit so the prime really has n_bits bits (this also
            # keeps 2 and 3 out, which would make 'tot' tiny), and the low bit
            # because every candidate prime here is odd.
            candidate = random.getrandbits(n_bits) | (1 << (n_bits - 1)) | 1
            if candidate != exclude and is_prime(candidate):
                return candidate

    # p and q must be distinct: for p == q, N = p**2 and lcm(p-1, q-1) = p-1
    # is not the right exponent modulus, so decryption would fail.
    p = random_prime()
    q = random_prime(exclude=p)

    # modulo
    N = p * q

    # Carmichael totient lambda(N) = lcm(p-1, q-1)
    tot = lcm(p - 1, q - 1)

    # choose a public exponent in [2, tot - 1] coprime to tot
    e = random.randint(2, tot - 1)
    while gcd(e, tot) != 1:
        e = random.randint(2, tot - 1)

    # d * e == 1 (mod tot)
    d = modinv(e, tot)

    # assert it is a modular inverse
    assert (d * e) % tot == 1

    # export keys
    pubkey = N, e
    prikey = N, d

    return pubkey, prikey

# encrypt from RSA
def rsa_enc(pubkey, data):
    """Encrypt ``data`` with the public key ``(N, e)``.

    ``data`` is an int in [0, N) or a list/tuple of such values (handled
    recursively); lists and tuples are both returned as lists.
    Raises Exception if a value lies outside [0, N).
    """
    if type(data) not in (list, tuple):
        # unpack the public key
        modulus, exponent = pubkey
        if data < 0 or data >= modulus:
            raise Exception("Encrypting a larger data piece than public key can handle!")
        # (data) ^ e (mod N)
        return pow(data, exponent, modulus)
    return [rsa_enc(pubkey, item) for item in data]


# decrypt from RSA
def rsa_dec(prikey, data):
    """Decrypt ``data`` with the private key ``(N, d)``.

    ``data`` is an int in [0, N) or a list/tuple of such values (handled
    recursively); lists and tuples are both returned as lists.
    Raises Exception if a value lies outside [0, N).
    """
    if type(data) not in (list, tuple):
        # unpack the private key
        modulus, exponent = prikey
        if data < 0 or data >= modulus:
            raise Exception("Decrypting a larger data piece than public key can handle!")
        # (data) ^ d (mod N)
        return pow(data, exponent, modulus)
    return [rsa_dec(prikey, item) for item in data]

# Demo keypair built from small (10-bit) primes.
pub, pri = gen_keypair(10)

# Round-trip sanity check: decrypting an encrypted message recovers it.
msg = 5
msg_enc = rsa_enc(pub, msg)
msg_dec = rsa_dec(pri, msg_enc)
assert msg_dec == msg, "RSA round trip failed: %r -> %r -> %r" % (msg, msg_enc, msg_dec)


# shared by both cracking attempts below
def _private_key_from_factor(N, e, p):
    """Rebuild the private key ``(N, d)`` from a nontrivial factor ``p`` of ``N``.

    Computes q = N // p, tot = lcm(p - 1, q - 1) and d = e ** -1 (mod tot).
    """
    # generate other factor
    q = N // p

    # make sure it works
    assert p * q == N

    # calculate totient
    tot = lcm(p - 1, q - 1)

    # reconstruct private value
    d = modinv(e, tot)

    return N, d


# attempt #1 to crack an RSA public key:
# just brute forces the private key
def crack_rsa_1(pubkey):
    """Recover the private key for ``pubkey`` by classical trial division.

    Args:
        pubkey: ``(N, e)`` public key.

    Returns:
        ``(N, d)`` private key, or ``None`` if no nontrivial factor of N is found.
    """
    # unpack keys
    N, e = pubkey

    # Start at the smallest 6k - 1 that is >= floor(sqrt(N)) and walk downward,
    # testing 6k - 1 and 6k + 1 at each step (a 2-3 wheel).
    p = int(N ** 0.5)
    p = 6 * (p // 6 + 1) - 1

    while p > 1:
        for candidate in (p, p + 2):
            # skip N itself (prime or tiny N): it would give q == 1 and tot == 0
            if candidate < N and N % candidate == 0:
                return _private_key_from_factor(N, e, candidate)
        p -= 6

    # 2 and 3 are not of the form 6k +/- 1, so the wheel above never tries them.
    for small in (2, 3):
        if small < N and N % small == 0:
            return _private_key_from_factor(N, e, small)

    return None


# attempt #2: factor N with Shor's algorithm, then rebuild the private key
def crack_rsa_quantum(pubkey):
    """Recover the private key for ``pubkey`` using ``shor_factor``.

    NOTE(review): ``shor_factor`` is presumably defined by
    implementations/shors.ipynb (loaded via %run); this assumes it returns an
    iterable of candidate factors of N -- confirm against that notebook.

    Returns:
        ``(N, d)`` private key, or ``None`` if no nontrivial factor was returned.
    """
    # unpack
    N, e = pubkey
    facs = shor_factor(N)

    # Use the first candidate that is a genuine nontrivial divisor; 1, N, or a
    # non-divisor cannot be turned into a private key.
    for f in facs:
        if 1 < f < N and N % f == 0:
            return _private_key_from_factor(N, e, f)

    # no factors found
    return None
    

# The classical brute-force attack must recover exactly the generated private key.
cracked = crack_rsa_1(pub)
assert cracked == pri, "crack_rsa_1 recovered %r, expected %r" % (cracked, pri)

print("RSA Successful!")


RSA Successful!
