In [675]:
import math

def gcd_extended(a, b):
    if b == 0:
        return a, 1, 0
    else:
        gcd, x1, y1 = gcd_extended(b, a % b)
        x = y1
        y = x1 - (a // b) * y1
        return gcd, x, y


def chinese_remainder(primes, a):
    s, mod = 0, math.prod(primes)
    for p_i, a_i in zip(primes, a):
        n = mod // p_i
        s += a_i * gcd_extended(n, p_i)[1] * n
    return s % mod

def fast_pow(base: int, exp: int, mod: int) -> int:
    result = 1
    base = base % mod
    while exp > 0:
        if exp % 2 == 1:
            result = (result * base) % mod
        exp //= 2
        base = (base * base) % mod
    return result

In [676]:
from Crypto.Util import number
from dataclasses import dataclass


@dataclass
class KeyPair:
    p: int
    q: int
    public_key: tuple[int]
    private_key: tuple[int]


class KeyGenerator:
    def __init__(self, prime_length: int):
        self._prime_length = prime_length

    def generate(self) -> KeyPair:
        p, q = self._generate_primes()
        n = p * q
        d, e = self._calc_d_e(p, q)
        return KeyPair(p=p, q=q, public_key=(e, n), private_key=(d, n))

    def _generate_primes(self) -> tuple[int, int]:
        p, q = self._get_prime(self._prime_length), self._get_prime(self._prime_length)
        while p == q or self._are_too_close(p, q):
            p, q = self._get_prime(self._prime_length), self._get_prime(self._prime_length)
        return p, q

    def _calc_d_e(self, p: int, q: int) -> tuple[int, int]:
        fi = (p - 1) * (q - 1)
        e = self._get_prime(fi.bit_length() - 1)
        _, x, _ = gcd_extended(e, fi)
        d = x % fi
        return d, e

    @staticmethod
    def _get_prime(n_length: int) -> int:
        return number.getPrime(n_length)
    
    @staticmethod
    def _are_too_close(p: int, q: int) -> bool:
        return abs(p - q) < 2 ** (p.bit_length() // 2)

In [677]:
class RSA:
    def __init__(self, prime_length: int = 512):
        self._keys = KeyGenerator(prime_length).generate()

    def encrypt(self, plaintext: str) -> list[int]:
        ascii_plaintext = list(map(ord, plaintext))
        return [self._encrypt_char(c) for c in ascii_plaintext]

    def _encrypt_char(self, ascii_char: int) -> int:
        e, n = self._keys.public_key
        return fast_pow(ascii_char, e, n)

    def decrypt(self, ciphertext: list[int]) -> str:
        decrypted_chars = [self._decrypt_char(c) for c in ciphertext]
        return ''.join(map(chr, decrypted_chars))

    def _decrypt_char(self, encrypted_char: int) -> int:
        d, _ = self._keys.private_key
        p, q = self._keys.p, self._keys.q
        m1 = pow(encrypted_char, d, p)
        m2 = pow(encrypted_char, d, q)
        m = chinese_remainder([p, q], [m1, m2])
        return m

In [678]:
rsa = RSA()
plain_text = "FBIT 336959"
ciphertext = rsa.encrypt(plain_text)
print(ciphertext)

[34325392568281786678658393419578493361995203779330378730710935020457221028853525848128744054965207982454700743943817739346159409370975240948726133120047282215117755233239551324053203303594153318454122748182671982673611860759339045793799694199757958958651933780495567296151344400977949281359670221411218404284, 44055421317617174608456190724354894421309129367988771162745472376533199153658473543398602255729970940899346504271414098835832967042031051067239533335634705754735776188075108152991147893168193108032696819189518201626867409925262402901232493417594995556693752293494459144731709205512665636306203713395906876687, 28021480902551310383558218328396726909342842701108250406381391702377971000442650096631363667882470476079909238118755119268733358878506643773357298615557883485822301406037286885416780857502253478518500840101639114127866436038273012675019779845078251267856176024074936920107538263967416654796595537417061455014, 159109068580836793535845245285668687305019312979053612091103059294588

In [679]:
plain_text = rsa.decrypt(ciphertext)
plain_text

'FBIT 336959'