In [34]:
import random
from hashlib import sha256

p = 6
g = 2
# NOTE(review): p = 6 is not prime, so this is not a real ElGamal group --
# powers of g = 2 mod 6 only ever take the values {2, 4}. Decryption with
# x_official / x_hidden (neither of which is private_r) only works because
# of that collapse, and can still mismatch for some key/ephemeral
# combinations. Toy parameters only; real keys would also need `secrets`
# rather than `random`.

# Secret keys, drawn from [2, p - 2]. The hidden key must differ from the
# official one: decrypt() tests `x == x_official` first, so equal keys would
# hand the hidden receiver the official message.
x_official = random.randint(2, p - 2)
x_hidden = random.choice([k for k in range(2, p - 1) if k != x_official])

# Public key generated once
private_r = random.randint(2, p - 2)
public_key = pow(g, private_r, p)

def derive_key(shared_secret):
    """Turn a shared secret into symmetric key material.

    The ``str()`` form of *shared_secret* is UTF-8 encoded and hashed with
    SHA-256; the raw 32-byte digest is returned.
    """
    digest_input = str(shared_secret).encode()
    hasher = sha256(digest_input)
    return hasher.digest()

def xor_encrypt(msg: bytes, key: bytes) -> bytes:
    """XOR *msg* byte-by-byte with *key*.

    XOR is its own inverse, so the same call both encrypts and decrypts.

    Args:
        msg: Bytes to transform.
        key: Keystream; must be at least as long as *msg*. Key bytes past
            the end of *msg* are ignored.

    Returns:
        Bytes of the same length as *msg*.

    Raises:
        ValueError: If *key* is shorter than *msg*. Without this check,
            zip() would silently truncate the output and drop message bytes.
    """
    if len(key) < len(msg):
        raise ValueError(
            f"key is shorter than message ({len(key)} < {len(msg)} bytes)"
        )
    return bytes(m ^ k for m, k in zip(msg, key))

def encrypt(message_official: str, message_hidden: str):
    """Encrypt an official and a hidden message into a single ciphertext.

    The messages are joined as ``official|hidden``, UTF-8 encoded, and
    XOR-ed with a key derived from the shared secret
    ``public_key ** r mod p`` for a fresh ephemeral ``r``.

    Args:
        message_official: Text intended for the official receiver.
        message_hidden: Text intended for the hidden receiver.

    Returns:
        dict with ``'c1'`` (``g ** r mod p``, which the receiver needs to
        recompute the shared secret) and ``'ciphertext'`` (bytes).

    Raises:
        ValueError: If either message contains the ``'|'`` separator. The
            receiver splits on that character, so an embedded ``'|'`` would
            corrupt the split (and could expose both messages together).
    """
    separator = "|"
    for label, text in (
        ("message_official", message_official),
        ("message_hidden", message_hidden),
    ):
        if separator in text:
            raise ValueError(f"{label} must not contain {separator!r}")

    # Combine messages separated by '|'
    combined_bytes = (message_official + separator + message_hidden).encode()

    # Ephemeral key: a fresh r per call, published as c1.
    r = random.randint(2, p - 2)
    c1 = pow(g, r, p)

    # Shared secret: g^(r * private_r) mod p
    shared_secret = pow(public_key, r, p)
    key = derive_key(shared_secret)

    # Repeat the derived key until it covers the whole message.
    # NOTE(review): this is repeating-key XOR -- plaintexts longer than the
    # derived key reuse keystream bytes.
    key = (key * ((len(combined_bytes) // len(key)) + 1))[:len(combined_bytes)]

    ciphertext = xor_encrypt(combined_bytes, key)

    return {
        'c1': c1,
        'ciphertext': ciphertext
    }

def derive_shared_secret(c1, x):
    """Recompute the shared secret on the receiving side as ``c1 ** x mod p``."""
    modulus = p
    secret = pow(c1, x, modulus)
    return secret

def decrypt(cipher, x):
    """Decrypt *cipher* with private key *x* and return that key's message.

    Args:
        cipher: dict as produced by ``encrypt`` with keys ``'c1'`` and
            ``'ciphertext'``.
        x: The receiver's private key.

    Returns:
        The official message (stripped) when ``x == x_official``, the hidden
        message (stripped) when ``x == x_hidden``. Otherwise, or when no
        ``'|'`` separator is found, the whole decrypted text. That text is
        arbitrary bytes if the derived shared secret does not match.
    """
    shared_secret = derive_shared_secret(cipher['c1'], x)
    key = derive_key(shared_secret)

    ciphertext = cipher['ciphertext']
    # Repeat the derived key to the ciphertext length, mirroring encrypt().
    key = (key * ((len(ciphertext) // len(key)) + 1))[:len(ciphertext)]

    decrypted_bytes = xor_encrypt(ciphertext, key)
    # A mismatched secret yields arbitrary bytes; drop undecodable ones
    # rather than raising.
    decrypted_text = decrypted_bytes.decode(errors='ignore')

    # encrypt() produces "official|hidden", so split on the FIRST separator
    # only. With a plain split('|'), a '|' inside the hidden message made
    # len(parts) != 2 and the full plaintext -- both messages -- was returned.
    official, separator, hidden = decrypted_text.partition('|')
    if separator:
        # Official key expects first part
        if x == x_official:
            return official.strip()
        # Hidden key expects second part
        if x == x_hidden:
            return hidden.strip()
    # NOTE(review): a key that is neither x_official nor x_hidden but happens
    # to derive the same shared secret receives both messages here -- confirm
    # this fallback is intended.
    return decrypted_text

# Demo: encrypt both messages once, then decrypt with each receiver's key.
message_official = "hail hitler"
message_hidden = "kill hitler"

cipher = encrypt(message_official, message_hidden)
print("Ciphertext:", cipher['ciphertext'])

# decrypt() picks the part of the plaintext associated with the key it gets.
for heading, receiver_key in (
    ("\nOfficial Receiver Decrypts:", x_official),
    ("\nHidden Receiver Decrypts:", x_hidden),
):
    print(heading)
    print(decrypt(cipher, receiver_key))

Ciphertext: b'#C\x1e\x1b\xf4\xb5v\xb2p\n\xfa3#\rqn\x94\xb9H\xa7\x91W\xfe'

Official Receiver Decrypts:
hail hitler

Hidden Receiver Decrypts:
kill hitler
