In [92]:
import sys
from sympy import prime, sec
from sympy import primitive_root as pr
import secrets
from sympy import discrete_log



def primitive_root(p: int) -> int:
    """Return a primitive root modulo ``p`` using sympy's ``primitive_root``.

    Args:
        p: Modulus to find a primitive root for (a prime in this notebook).

    Returns:
        A primitive root ``g`` of ``p``.

    Raises:
        SystemExit: via ``sys.exit`` when sympy returns ``None`` (no primitive
            root found), which stops the running cell/script.
    """
    g = pr(p)
    if g is None:
        sys.exit(f"no primitive_root found for prime {p}")
    return g

In [93]:
def crypt_param(p):
    """Generate demo ElGamal parameters for the prime modulus ``p``.

    Args:
        p: Prime modulus.

    Returns:
        ``(g, x, y)``: a primitive root ``g`` of ``p``, a private exponent
        ``x`` and an ephemeral exponent ``y``. Both exponents are drawn
        uniformly from ``[1, p - 2]`` using the ``secrets`` CSPRNG. ``y`` is
        meant for a single encryption; draw a fresh one for every message.
    """
    g = primitive_root(p)
    # Exponents only matter modulo p - 1, and g**(p - 1) == 1 (mod p), so an
    # exponent of p - 1 acts like 0 (public key h == 1, or c1 == 1).
    # Sample from [1, p - 2] to exclude it.
    x = secrets.randbelow(p - 2) + 1  # private key
    y = secrets.randbelow(p - 2) + 1  # ephemeral (per-message) exponent
    return (g, x, y)


In [94]:
def encrypt(m: int, g: int, p: int, y: int, x: int) -> tuple[int, int, int, int, int]:
    """Encrypt ``m`` with exponential (additively homomorphic) ElGamal.

    The message is encoded in the exponent (``g**m``). Multiplying two
    ciphertexts component-wise therefore yields an encryption of the sum of
    their plaintexts. Recovering ``m`` requires a discrete logarithm, which
    only determines ``m`` modulo the order of ``g`` (``p - 1`` when ``g`` is a
    primitive root).

    Args:
        m: Plaintext; must satisfy ``0 <= m < p - 1`` to be recoverable.
        g: Base for the encoding; presumably a primitive root mod ``p``.
        p: Prime modulus.
        y: Ephemeral exponent; use a fresh value for every message.
        x: Private exponent; the public key is ``h = g**x mod p``.

    Returns:
        ``(c1, c2, p, g, x)`` with ``c1 = g**y`` and ``c2 = h**y * g**m``
        (all mod ``p``).

    Raises:
        ValueError: if ``m`` is outside ``[0, p - 1)``.
    """
    # A negative m or one >= p - 1 would silently decrypt to a different value
    # (m reduced modulo p - 1), so reject it instead.
    if not 0 <= m < p - 1:
        raise ValueError(f"message {m} must satisfy 0 <= m < p - 1 = {p - 1}")
    h = pow(g, x, p)  # public key
    c1 = pow(g, y, p)
    c2 = (pow(h, y, p) * pow(g, m, p)) % p  # the message must be in the exponent
    # NOTE(review): the returned tuple also carries the private key x, so
    # storing or printing the whole tuple discloses the key. The 5-tuple shape
    # is kept for compatibility with callers that expect it.
    return (c1, c2, p, g, x)

In [95]:
def decrypt(c1, c2, p, g, x) -> int:
    """Recover the exponent-encoded plaintext from an ElGamal ciphertext.

    Args:
        c1: First ciphertext component (``g**y mod p``, or a product of them).
        c2: Second ciphertext component (``h**y * g**m mod p``, or a product).
        p: Prime modulus.
        g: Base the message was encoded against (``g**m``).
        x: Private exponent.

    Returns:
        The ``m`` that sympy's ``discrete_log`` finds with
        ``g**m == c2 * c1**(-x) (mod p)``. For a component-wise product of
        ciphertexts this is the sum of their messages.
    """
    # c1**(-x) mod p is the modular inverse of the shared secret c1**x;
    # three-argument pow accepts a negative exponent on Python 3.8+.
    shared_secret_inv = pow(c1, -x, p)
    g_to_m = c2 * shared_secret_inv % p
    # The message sits in the exponent, so undo g**m with a discrete log.
    return discrete_log(p, g_to_m, g)

In [96]:
# Demo parameters: a fixed prime modulus and one key pair (g1, x1) shared by
# both messages, so their ciphertexts can be combined homomorphically.
p: int = 2760727302517
g1, x1, y1 = crypt_param(p)
# Each message needs its own ephemeral exponent. Reusing y1 would make both
# c1 values identical and reveal c2_1 / c2_2 = g^(m1 - m2), i.e. the
# difference of the plaintexts.
y2 = secrets.randbelow(p - 2) + 1

m1 = 345
m2 = 10

cipher1 = encrypt(m1, g1, p, y1, x1)
cipher2 = encrypt(m2, g1, p, y2, x1)

In [97]:
# Component-wise multiplication adds the plaintexts in the exponent:
# (g^y1 * g^y2, h^y1 g^m1 * h^y2 g^m2) = (g^(y1+y2), h^(y1+y2) g^(m1+m2)).
c1_combined = (cipher1[0] * cipher2[0]) % p
c2_combined = (cipher1[1] * cipher2[1]) % p

decrypted_sum = decrypt(c1_combined, c2_combined, p, g1, x1)

# Show only (c1, c2): the full tuples returned by encrypt also contain the
# private key x, which must not end up in the notebook output.
print(f"Original messages: m1 = {m1}, m2 = {m2}")
print(f"Ciphertext 1: {cipher1[:2]}")
print(f"Ciphertext 2: {cipher2[:2]}")
print(f"Combined ciphertext: (c1 = {c1_combined}, c2 = {c2_combined})")
print(f"Decrypted sum: {decrypted_sum}")

# The discrete log recovers the sum only modulo the order of g1 (p - 1 for a
# primitive root), so this holds only while m1 + m2 < p - 1.
assert decrypted_sum == m1 + m2, f"expected m1 + m2 = {m1 + m2}, got {decrypted_sum}"

Original messages: m1 = 345, m2 = 10
Ciphertext 1: (55, 55)
Ciphertext 2: (55, 400)
Combined ciphertext: (c1 = 320, c2 = 360)
Decrypted sum: 355
