## RSA

### imports

In [96]:
from math import gcd
from random import randint, sample

### utility functions

In [97]:
def generate_primes(pn:int = 1000) -> set:
	"""Return the set of all primes strictly less than ``pn``.

	Uses the Sieve of Eratosthenes. ``pn`` is an exclusive upper bound, so
	``generate_primes(1000)`` yields the 168 primes from 2 up to 997.
	For ``pn < 3`` there are no primes below the bound and an empty set is
	returned.
	"""
	if pn < 2:
		# Nothing below 2 is prime, and the flag list would be too short to
		# index positions 0 and 1.
		return set()

	prime_flags = [True] * pn
	prime_flags[0] = prime_flags[1] = False

	pi = 2  # current sieving candidate
	while pi * pi < pn:
		if prime_flags[pi]:
			# Multiples below pi*pi were already crossed off by smaller primes.
			for i in range(pi * pi, pn, pi):
				prime_flags[i] = False
		pi += 1

	return {i for i, flag in enumerate(prime_flags) if flag}

### RSA Class

In [98]:
class RSA:
	"""Textbook (unpadded) RSA built from two primes ``a`` and ``b``.

	The constructor derives every key component eagerly:
		n     = a * b                  (modulus)
		phi_n = (a - 1) * (b - 1)      (Euler's totient of n for distinct primes)
		e     = smallest integer in [2, phi_n) coprime to phi_n (public exponent)
		d     = inverse of e modulo phi_n, i.e. (e * d) % phi_n == 1 (private exponent)

	Primality of ``a`` and ``b`` is not checked here; the caller must supply
	primes. Messages are raw integers in [0, n) with no padding applied.

	Raises:
		ValueError: if ``a == b``, or if no valid public exponent exists
			(phi_n <= 2, e.g. primes 2 and 3).
	"""

	def __init__(self, a: int, b: int):
		if a == b:
			# (a - 1) * (b - 1) equals phi(n) only for two distinct primes; with
			# a == b the derived keys would not decrypt correctly.
			raise ValueError('Primes a and b must be distinct, got {} for both'.format(a))
		self.a = a
		self.b = b
		self.n = a * b
		self.phi_n = (a - 1) * (b - 1)
		self.e = self.generate_public_key()
		self.d = self.generate_private_key()

	def generate_public_key(self):
		"""Return the smallest e with 1 < e < phi_n and gcd(e, phi_n) == 1.

		Raises:
			ValueError: if no such e exists (phi_n <= 2). Without this check
				generate_private_key() would loop forever.
		"""
		e = 2
		while e < self.phi_n and gcd(e, self.phi_n) != 1:
			e += 1
		if e >= self.phi_n:
			raise ValueError(
				'No public exponent exists for primes {} and {} (phi(n) = {}); '
				'choose larger primes'.format(self.a, self.b, self.phi_n))
		return e

	def generate_private_key(self):
		"""Return the private exponent d, satisfying (e * d) % phi_n == 1.

		Searches i = 1, 2, ... for the first i where i * phi_n + 1 is divisible
		by e, then d = (i * phi_n + 1) // e. Because generate_public_key()
		guarantees gcd(e, phi_n) == 1, a solution exists with i < e, so the loop
		runs at most e - 1 times.
		"""
		i = 1
		while (i * self.phi_n + 1) % self.e != 0:
			i += 1
		return (i * self.phi_n + 1) // self.e

	def encrypt(self, plaintext: int):
		"""Return ``plaintext ** e mod n``.

		Raises:
			ValueError: if plaintext is outside [0, n). Such values cannot be
				recovered by decrypt(), which always returns a value in [0, n).
		"""
		if not 0 <= plaintext < self.n:
			raise ValueError("Plaintext must be in range [0, {}) for the chosen primes {} and {}".format(self.n, self.a, self.b))
		return pow(plaintext, self.e, self.n)

	def decrypt(self, ciphertext: int):
		"""Return ``ciphertext ** d mod n``, i.e. the recovered plaintext."""
		return pow(ciphertext, self.d, self.n)

### I/O

In [99]:
# Set PREDEFINED = True to skip the prompts and use the random primes below.
PREDEFINED = False
# User input is validated by membership in this table, so only primes below
# PRIME_LIMIT are accepted.
PRIME_LIMIT = 1000
primes = generate_primes(PRIME_LIMIT)
a,b = sample(list(primes), 2)

if not PREDEFINED:
	# Pressing Enter keeps the suggested value: `input() or a` falls back to
	# the default *before* int() parses it (int('') would raise ValueError).
	a = int(input("Enter prime number 'a' (random: {}): ".format(a)) or a)

	if a not in primes:
		raise ValueError('{} is not a prime number below {}'.format(a, PRIME_LIMIT))

	# Make sure the suggested default for b differs from the chosen a.
	if b == a:
		b = sample(list(primes - {a}), 1)[0]
	b = int(input("Enter prime number 'b' (random: {}): ".format(b)) or b)
	if b == a:
		raise ValueError('Prime number b: {} cannot be the same as a: {}'.format(b, a))

	if b not in primes:
		raise ValueError('{} is not a prime number below {}'.format(b, PRIME_LIMIT))


In [100]:
rsa = RSA(a, b)
# One output line per row; each row is unpacked into print(), so its items are
# joined by single spaces exactly as separate print() arguments would be.
key_report = (
	("Primes:\t\t", rsa.a, ",", rsa.b),
	("n:\t\t", rsa.n),
	("phi(n):\t\t", rsa.phi_n),
	("e:\t\t", rsa.e),
	("d:\t\t", rsa.d),
	("Public Key:\t", "[e, n] = ", [rsa.e, rsa.n]),
	("Private Key:\t", "[d, n] = ", [rsa.d, rsa.n]),
)
for row in key_report:
	print(*row)

Primes:		 709 , 97
n:		 68773
phi(n):		 67968
e:		 5
d:		 40781
Public Key:	 [e, n] =  [5, 68773]
Private Key:	 [d, n] =  [40781, 68773]


In [101]:
# Read an integer message, encrypt it with the public key, then show both.
plaintext = int(input("Enter plaintext: "))
ciphertext = rsa.encrypt(plaintext)
for label, value in (("Plaintext:", plaintext), ("Ciphertext:", ciphertext)):
	print(label, value)

Plaintext: 89
Ciphertext: 35714


In [102]:
# Decrypt with the private key and confirm the round trip recovers the input.
result = rsa.decrypt(ciphertext)
print("Decrypted ciphertext: ", result)
assert plaintext == result

Decrypted ciphertext:  89
