In [1]:
import pwn

from typing import Tuple
from Crypto.Util.number import *
from gmpy2 import powmod
import os

## From the challenge https://chal.ctf.tsj.tw/challenges

In [2]:
def pow(a: int, b: int, c: int) -> int:
    # gmpy2.powmod is much faster than pow
    return int(powmod(a, b, c))


def getPrimeOrderGroup(bits) -> Tuple[int, int, int]:
    """
    Generate a prime p with large prime factor q and a generator g
    """
    while True:
        q = getPrime(bits)
        for i in range(2, 257, 2):
            p = q * i + 1
            if isPrime(p):
                g = pow(getRandomRange(2, p), i, p)
                if g != 1:
                    assert pow(g, q, p) == 1
                    return p, q, g


class RSA:
    def __init__(self, bits):
        self.p = getPrime(bits // 2)
        self.q = getPrime(bits // 2)
        self.n = self.p * self.q
        self.e = 65537
        self.d = pow(self.e, -1, (self.p - 1) * (self.q - 1))

    def encrypt(self, m: int) -> int:
        return pow(m, self.e, self.n)

    def decrypt(self, c: int) -> int:
        return pow(c, self.d, self.n)

    def __str__(self) -> str:
        e = self.e
        n = self.n
        return f"RSA({n}, {e})"


class ElGamal:
    def __init__(self, bits):
        self.p, self.q, self.g = getPrimeOrderGroup(bits)
        self.x = getRandomRange(2, self.q)
        self.y = pow(self.g, self.x, self.p)

    def encrypt(self, m: int) -> Tuple[int, int]:
        r = getRandomRange(2, self.q)
        s = pow(self.y, r, self.p)
        c1 = pow(self.g, r, self.p)
        c2 = (s * m) % self.p
        return c1, c2

    def decrypt(self, c1: int, c2: int) -> int:
        s = pow(c1, self.x, self.p)
        m = (pow(s, -1, self.p) * c2) % self.p
        return m

    def __str__(self) -> str:
        p = self.p
        g = self.g
        y = self.y
        return f"ElGamal({p}, {g}, {y})"


elg = ElGamal(1024)
rsa = RSA(1024)

In [3]:
def egcd(a, b):
    if a == 0:
        return (b, 0, 1)
    else:
        g, y, x = egcd(b % a, a)
        return (g, x - (b // a) * y, y)

def modinv(a, m):
    g, x, y = egcd(a, m)
    if g != 1:
        raise ValueError("No modinv.")
    else:
        return x % m

## Need to get some info about the message

Observations:
* We can encrypt whatever we want as we have the public keys for both schemes.
* We can send updated values of the encrypted message to be decrypted.
* The modulus of the two algorithm is different.

For the ElGamal encryption, we note that the message is retrieved by multiplying `c2` by some secret get get the plain text out. Therefor we can reason that we can generate `message * k`.

Next we need some information leak. The different modulus between RSA and ElGamal means we can detect when we wrap around `p` the in the ElGamal decrypt.

In [4]:
sample_message = bytes_to_long(os.urandom(96))
rsa_base_message = rsa.encrypt(sample_message)
print(hex(rsa_base_message))

# k * sample_message < elg.p
k = 3214123545326
rsa_k_times_message = rsa.encrypt((sample_message * k) % elg.p)
rsa_k_times_message_divide_by_k = (rsa_k_times_message * pow(modinv(k, rsa.n),rsa.e,rsa.n)) % rsa.n
print(hex(rsa_k_times_message_divide_by_k))
assert sample_message * k < elg.p
assert rsa_base_message == rsa_k_times_message_divide_by_k

# k * sample_message > elg.p
k = 3214123545326 * (1<<270)
rsa_k_times_message = rsa.encrypt((sample_message * k) % elg.p)
rsa_k_times_message_divide_by_k = (rsa_k_times_message * pow(modinv(k, rsa.n),rsa.e,rsa.n)) % rsa.n
print(hex(rsa_k_times_message_divide_by_k))
assert sample_message * k > elg.p
assert rsa_base_message != rsa_k_times_message_divide_by_k

0x3eddec238e123fa86a464b893436e1feea63e84ddf8f4fe211e4a9f1d02e71b608cec86dd06ba3264d00eb3c4058ac956c63dd3854b1435ebcb8147a568924bd2b15eb638460430fd2d030168eaed90ca3f906abd3344cd0b8e111d76a9364b490529dda7fb60c537a4ba1855f332355e2a7a463274ae37379d8a9f915d57369
0x3eddec238e123fa86a464b893436e1feea63e84ddf8f4fe211e4a9f1d02e71b608cec86dd06ba3264d00eb3c4058ac956c63dd3854b1435ebcb8147a568924bd2b15eb638460430fd2d030168eaed90ca3f906abd3344cd0b8e111d76a9364b490529dda7fb60c537a4ba1855f332355e2a7a463274ae37379d8a9f915d57369
0x3bf04e321d5c6989ee0b6db9982d9250ebb31f558b0ba12a3fd92f4d9aa40a830fb306c9797c42a29fd2eca241f5ee76b87f38fcd4c3f3873794a28be40f65af8546e95fc6abdac23969dc551c681e57ae17423837744eb3933d43e6ab8940b81b99b1ea2b6154c3dbf1423e5218e491debd71f3771d05dee050976bd222918e


## Oracle leak

Now we have essentially a 1 bit leak from the oracle. We can now binray search for the message. Note this only works because the flag is only at the start of the message. This will not recover the last `ceil(log2(elg.p))//2` bits.

In [132]:
sample_message = bytes_to_long(os.urandom(96))
unknown_bits = len(bin(elg.p)[2:])//2
base_encrypted = rsa.encrypt(sample_message)

# PoC solver.
upper = elg.p // ((1 << 760) - 1)
lower = 0
for i in range(1000):
    k = (upper + lower) // 2
    inv_k = modinv(k,rsa.n)
    c = (rsa.encrypt((sample_message * k) % elg.p) * pow(inv_k,rsa.e,rsa.n)) % rsa.n
    if base_encrypted != c:
        upper = k
        print("+", end="")
    else:
        lower = k
        print("-", end="")
    if upper - 1 == lower:
        break
print()
guess = (elg.p // upper)
print(hex(sample_message)[:-unknown_bits//4])
print(hex(guess)[:-unknown_bits//4])

++++++-+--++-+---++----+-+-++++++-++-+--+---++++++++-+-+-+-++-+--+-++--+---++++--++----++++----+-----+++++-+-+--+-+-+-----++-+--++---+++-+-+++++--+-++--+--++-+++--++-+----++-+-+--+++--++++------+------++++--+-----+-+----+++---+--+-+----++++-+++-------+-+--+-+-+++++
0x5b9452087acb6ebb5f80950dbd86ed3c2f4169ee8673bc4be4aacf45e14858a
0x5b9452087acb6ebb5f80950dbd86ed3c2f4169ee8673bc4be4aacf45e14858a


## Final solve script

In [None]:
s = pwn.connect("34.81.158.137", 8763)
s.timeout=3
data = []
while True:
    d = s.readline()
    if d:
        data.append(d)
    else:
        break

# Parse out the numbers.
n, e = [int(i) for i in data[3][4:-2].split(b",")]
p, g, y = [int(i) for i in data[4][8:-2].split(b",")]
f = int(data[8].strip().split(b' = ')[-1])

# Get ElGamal encrypted base value.
s.timeout=2
s.send("\n".join(["1", f"{f}",""]))
while True:
    d = s.readline()
    if d:
        data.append(d)
    else:
        break
c1,c2 = [int(i) for i in data[13][7:-2].split(b",")]
print("f", hex(f))
print("c1", hex(c1))
print("c2", hex(c2))

def get(s, c1, c2):
    di = []
    s.send("\n".join(["2", f"{c1}", f"{c2}", ""]))
    while True:
        d = s.readline()
        if d:
            di.append(d)
        else:
            break
    assert len(di) > 3
    return int(di[-4].strip().split(b" ")[-1])

def check(s,c1,c2,f,k):
    inv_k = modinv(k,n)
    r = get(s, c1, c2 * k)
    c = r * pow(inv_k,e,n) % n
    return c == f

# Double check the check function.
assert (check(s,c1,c2,f,2))

# Build up the flag bit by bit.
upper = p // ((1 << 760) - 1)
lower = 0
for i in range(1000):
    k = (upper + lower) // 2
    if not check(s,c1,c2,f,k):
        upper = k
        print("+", end="")
    else:
        lower = k
        print("-", end="")
    print(hex(p//k))
    if upper - 1 == lower:
        break

print(long_to_bytes(upper))
    
s.close()