In [2]:
import math
import random
 

# Generate prime numbers
def generate_primes(nbits):
    while True:
        p = random.getrandbits(nbits)
        if is_prime(p):
            return p 
# Check if prime for prime number generation
def is_prime(n):
    if n == 2:
        return True
    if n % 2 == 0 or n == 1:
        return False
    s = 0
    d = n - 1
    while d % 2 == 0:
        s += 1
        d //= 2
    for _ in range(10):
        a = random.randint(2, n - 2)
        x = pow(a, d, n)
        if x == 1 or x == n - 1:
            continue
        for _ in range(s - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            return False
    return True

# Compute greatest common divisor

def gcd(a, b):
    while b != 0:
        a, b = b, a % b
    return a

 

# Compute modular inverse for finding private key (found this can also be done with: pow(pubkey, -1, phi))
def mod_inv(a, m):

    if gcd(a, m) != 1:
        return None
    u1, u2, u3 = 1, 0, a
    v1, v2, v3 = 0, 1, m
    while v3 != 0:
        q = u3 // v3
        v1, v2, v3, u1, u2, u3 = (u1 - q * v1), (u2 - q * v2), (u3 - q * v3), v1, v2, v3

    return u1 % m

# Other option for defining public key
def pubKeyGen(p, q):
    #define phi with (p,q)
    phi = (p-1)*(q-1)
    # Calculate/return a value for e (public key)
    e = 2
    while (e < phi):
        if(gcd(e, phi) == 1):
            break
        else:
            e=e+1
    return e

In [3]:
# Generate RSA key pair

def generate_key_pair(nbits):
    p = generate_primes(nbits // 2)
    q = generate_primes(nbits // 2)
    n = p * q
    phi = (p - 1) * (q - 1)
    #found another function for finding public key -> pubKeyGen()
    e = pubKeyGen(p,q)
    d = mod_inv(e, phi)
    return ((e, n), (d, n))

In [4]:
# Encrypt data using public key

def encrypt(message, public_key):
    e, n = public_key
    m = int.from_bytes(message, 'big')
    c = pow(m, e, n)
    return c.to_bytes((c.bit_length() + 7) // 8, 'big')

 

# Decrypt data using private key

def decrypt(ciphertext, private_key):
    d, n = private_key
    c = int.from_bytes(ciphertext, 'big')
    m = pow(c, d, n)
    return m.to_bytes((m.bit_length() + 7) // 8, 'big')

## Example Test Case

In [7]:
# Example usage
import time
message = b"Hello World"
start = time.time()
public_key, private_key = generate_key_pair(1024)
ciphertext = encrypt(message, public_key)
plaintext = decrypt(ciphertext, private_key)
end = time.time()
print("Ciphertext:", ciphertext)
print("Plaintext:", plaintext.decode('utf-8'))
print("Duration: ", end - start)

Ciphertext: b'v\x89\xcb\x14bV\xa3\xdc\xf4\xb9]\xf8P\xaa\xe6]\xd9\x824?Fa\xc9o\x1d-\x9d\x11\xf1%\x83\x9fw\x00\x8b\xc6\x81\xd7P4\xe4K\xc7|\xf8\x021\xb0\xceP\x1c\xa7\xe4\x00'
Plaintext: Hello World
Duration:  7.364403009414673
