In [122]:
from math import floor
from typing import Tuple
from Crypto.Util.number import getPrime


def mod_inverse(a, m):
    """
    Return the modular multiplicative inverse of ``a`` modulo ``m``.

    The Bezout coefficient ``x`` returned by ``egcd(a, m)`` satisfies
    ``x * a + y * m == gcd``, so when the gcd is 1 it is an inverse of ``a``
    modulo ``m``. That coefficient can be negative, so it is reduced modulo
    ``m`` before being returned; for a positive ``m`` the result lies in
    ``[0, m)``.

    Raises:
        AssertionError: if ``egcd`` does not report a gcd of 1, i.e. the
            inverse does not exist.
    """
    gcd, x, _ = egcd(a, m)
    assert gcd == 1, f"the modular multiplicative inverse does not exist for a: {a}, m: {m}"
    # egcd may hand back a negative coefficient; return the canonical residue instead
    return x % m


def egcd(a, b):
    """
    Extended Euclidean algorithm.

    Algorithm outlined at https://en.m.wikipedia.org/wiki/Extended_Euclidean_algorithm

    Returns a tuple ``(g, s, t)`` such that ``s * a + t * b == g``, where ``g``
    is the last non-zero remainder of the Euclidean algorithm (the greatest
    common divisor for non-negative inputs).

    The arguments may be given in either order: when ``a < b`` the first
    iteration has a quotient of 0 and simply swaps the pair. If ``b`` is 0 the
    result is ``(a, 1, 0)``.

    All arithmetic is done on integers. The quotient is computed with floor
    division (``//``) rather than float division, so arbitrarily large inputs
    neither lose precision nor overflow.
    """
    r_prev, r = a, b
    s_prev, s = 1, 0
    t_prev, t = 0, 1
    while r != 0:
        # integer floor division keeps the quotient exact for big integers
        q = r_prev // r
        r_prev, r = r, r_prev - q * r
        s_prev, s = s, s_prev - q * s
        t_prev, t = t, t_prev - q * t
    # r_prev is the last non-zero remainder; s_prev and t_prev are its Bezout coefficients
    return r_prev, s_prev, t_prev


def generate_keys(prime_size: int = 512) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """
    Generate an RSA key pair with the fixed public exponent 65537.

    Args:
        prime_size: bit length passed to ``getPrime`` for each of the two primes.
            NOTE: the original author observed that ``prime_size > 512`` results
            in a float overflow in ``egcd()``; that limit applies for as long as
            ``egcd()`` relies on float division.

    Returns:
        ``((n, e), (n, d))`` - the public key followed by the private key.
    """
    e = 65537
    while True:
        p = getPrime(prime_size)
        q = getPrime(prime_size)
        if p == q:
            # identical primes would make (p - 1) * (q - 1) the wrong totient for n = p * p
            continue
        totient = (p - 1) * (q - 1)
        if egcd(e, totient)[0] == 1:  # ensure e is coprime with totient
            break

    n = p * q
    # modular multiplicative inverse of e mod totient, reduced so that 0 <= d < totient
    d = mod_inverse(e, totient) % totient
    return (n, e), (n, d)


def encrypt(message: str, public_key: Tuple[int, int]) -> int:
    """
    Encrypt ``message`` with textbook RSA (no padding).

    The message is encoded to bytes with ``str.encode()``, interpreted as a
    big-endian integer ``m`` and raised to the public exponent modulo ``n``.

    Args:
        message: text to encrypt.
        public_key: ``(n, e)`` tuple.

    Returns:
        The ciphertext ``m ** e mod n`` as an integer.

    Raises:
        ValueError: if the encoded message is not smaller than ``n``; such a
            message would be reduced modulo ``n`` and could not be recovered.
    """
    n, e = public_key
    # byteorder is spelled out so this also runs where from_bytes has no default for it
    encoded_message = int.from_bytes(message.encode(), "big")
    if encoded_message >= n:
        raise ValueError("message is too long to be encrypted with this modulus")
    return pow(encoded_message, e, n)


def decrypt(ciphertext: int, private_key: Tuple[int, int]) -> str:
    """
    Decrypt a textbook-RSA ciphertext back into text.

    Args:
        ciphertext: integer produced by ``encrypt``.
        private_key: ``(n, d)`` tuple.

    Returns:
        The decoded string obtained from ``ciphertext ** d mod n``, written out
        as the minimal number of big-endian bytes and decoded with
        ``bytes.decode()``.
    """
    n, d = private_key
    message = pow(ciphertext, d, n)
    # smallest number of whole bytes that can hold the recovered integer
    byte_length = (message.bit_length() + 7) // 8
    # byteorder is spelled out so this also runs where to_bytes has no default for it
    return message.to_bytes(byte_length, "big").decode()


# Round trip: generate a key pair, encrypt a short message, then decrypt it again.
message = "Hello, world!"
public_key, private_key = generate_keys()
ciphertext = encrypt(message, public_key)
recovered_message = decrypt(ciphertext, private_key)
# bare last expression so the notebook displays the decrypted text
recovered_message

'Hello, world!'

In [123]:
from random import randint


# Blinding demonstration: the ciphertext is multiplied by k^e before it is
# decrypted, and the factor k is divided out of the result afterwards.
# pick a random factor k and raise it to the public exponent e modulo n
k = randint(1, public_key[0] - 1)
k_e = pow(k, public_key[1], public_key[0])
# calculate the altered ciphertext: c' = c * k^e mod n
c_prime = (ciphertext * k_e) % public_key[0]
# raise the altered ciphertext to the private exponent: m' = c'^d mod n
m_prime = pow(c_prime, private_key[1], public_key[0])
# calculate the unaltered message: m = m' * k^-1 mod n
unaltered_message = (m_prime * mod_inverse(k, public_key[0])) % public_key[0]
# byteorder is spelled out so this also runs where to_bytes has no default for it
unaltered_message.to_bytes((unaltered_message.bit_length() + 7) // 8, "big").decode()

'Hello, world!'

In [125]:
def sign(message: int, private_key: Tuple[int, int]) -> int:
    """
    Produce a textbook-RSA signature for an integer message.

    Args:
        message: integer to sign.
        private_key: ``(n, d)`` tuple.

    Returns:
        ``message ** d mod n``.
    """
    modulus, private_exponent = private_key
    signature = pow(message, private_exponent, modulus)
    return signature


def verify(signature: int, message: int, public_key: Tuple[int, int]) -> bool:
    """
    Check a textbook-RSA signature against an integer message.

    Args:
        signature: integer produced by ``sign``.
        message: integer the signature is claimed to belong to.
        public_key: ``(n, e)`` tuple.

    Returns:
        True when ``signature ** e mod n`` equals ``message``.
    """
    modulus, public_exponent = public_key
    recovered = pow(signature, public_exponent, modulus)
    return message == recovered

# Demonstration: the product of two signatures verifies against the product of the two messages.
# message1 and message2 are encrypted with the public key; the resulting integers are then signed
message1 = encrypt("First message", public_key)
message2 = encrypt("Second message", public_key)
signed_message1 = sign(message1, private_key)
signed_message2 = sign(message2, private_key)

# message3 is the product of the two (encrypted) messages, reduced modulo n
message3 = (message1 * message2) % public_key[0]
# s3 = s1 * s2 mod n: built from the two existing signatures only, without using the private key
# (for comparison, signing directly would be: signed_message3 = sign(message1 * message2, private_key))
signed_message3 = signed_message1 * signed_message2 % public_key[0]

# verify the signed messages; the third check uses the signature that was derived by multiplication
verified_message1 = verify(signed_message1, message1, public_key)
verified_message2 = verify(signed_message2, message2, public_key)
verified_message3 = verify(signed_message3, message3, public_key)
print(verified_message1, verified_message2, verified_message3)

True True True
