### RSA Cryptography Lab

In [14]:
# Helper Functions
def PowMod(a, n, modulo):
    result = 1
    base = a % modulo
    while n > 0:
        if n % 2 == 1:
            result = (result * base) % modulo
        base = (base * base) % modulo
        n = n // 2
    return result

def ConvertToInt(message):
    return int.from_bytes(message.encode(), 'big')

def ConvertToStr(m):
    byte_length = (m.bit_length() + 7) // 8
    return m.to_bytes(byte_length, 'big').decode('utf-8', errors='ignore')

def InvertModulo(a, n):
    def extended_gcd(a, b):
        if a < b:
            y, x = extended_gcd(b, a)
            return (x, y)
        if b == 0:
            return (1, 0)
        (x1, y1) = extended_gcd(b, a % b)
        x = y1
        y = x1 - (a // b) * y1
        return (x, y)
    (x, y) = extended_gcd(a, n)
    return x % n

def IntSqrt(n):
    if n == 0:
        return 0
    x = n
    y = (x + 1) // 2
    while y < x:
        x = y
        y = (x + n // x) // 2
    return x

def GCD(a, b):
    if b == 0:
        return a
    return GCD(b, a % b)

def ChineseRemainderTheorem(n1, r1, n2, r2):
    def extended_euclid(a, b):
        if a < b:
            y, x = extended_euclid(b, a)
            return (x, y)
        if b == 0:
            return (1, 0)
        (x1, y1) = extended_euclid(b, a % b)
        x = y1
        y = x1 - (a // b) * y1
        return (x, y)
    (x, y) = extended_euclid(n1, n2)
    return (r1 * n2 * y + r2 * n1 * x) % (n1 * n2)

### Q1: RSA Encryption


In [15]:
def Encrypt(message, modulo, exponent):
    return PowMod(ConvertToInt(message), exponent, modulo)

# Test
print(Encrypt("hi", 101, 3))

6


### Q2: Decrypt - RSA Decryption

In [16]:
def Decrypt(ciphertext, p, q, exponent):
    n = p * q
    phi = (p - 1) * (q - 1)
    d = InvertModulo(exponent, phi)
    return ConvertToStr(PowMod(ciphertext, d, n))

p, q = 101, 103
ciphertext = Encrypt("attack", p * q, 3)
print(Decrypt(ciphertext, p, q, 3))

%t


### Q3: DecipherSimple - Dictionary Attack

In [17]:
def DecipherSimple(ciphertext, modulo, exponent, potential_messages):
    for message in potential_messages:
        if ciphertext == Encrypt(message, modulo, exponent):
            return message
    return "don't know"

modulo = 101
exponent = 12
ciphertext = Encrypt("attack", modulo, exponent)
print(DecipherSimple(ciphertext, modulo, exponent, ["attack", "don't attack", "wait"]))

attack


### Q4: DecipherSmallPrime - Factor Small Prime

In [18]:
def DecipherSmallPrime(ciphertext, modulo, exponent):
    for potential_prime in range(2, 1000000):
        if modulo % potential_prime == 0:
            small_prime = potential_prime
            big_prime = modulo // potential_prime
            return Decrypt(ciphertext, small_prime, big_prime, exponent)
    return "don't know"

modulo = 101 * 1000000007
ciphertext = Encrypt("attack", modulo, 239)
print(DecipherSmallPrime(ciphertext, modulo, 239))




### Q5: DecipherSmallDiff - Fermat's Factorization

In [19]:
def DecipherSmallDiff(ciphertext, modulo, exponent):
    a = IntSqrt(modulo)
    if a * a < modulo:
        a += 1
    for _ in range(5000):
        b_squared = a * a - modulo
        if b_squared >= 0:
            b = IntSqrt(b_squared)
            if b * b == b_squared:
                p = a + b
                q = a - b
                return Decrypt(ciphertext, p, q, exponent)
        a += 1
    return "don't know"

p, q = 1000000007, 1000000009
ciphertext = Encrypt("attack", p * q, 239)
print(DecipherSmallDiff(ciphertext, p * q, 239))

attack


### Q6: DecipherCommonDivisor - Shared Prime Factor

In [20]:
def DecipherCommonDivisor(first_ciphertext, first_modulo, first_exponent, second_ciphertext, second_modulo, second_exponent):
    common_prime = GCD(first_modulo, second_modulo)
    if common_prime == 1:
        return ("unknown message 1", "unknown message 2")
    q1 = first_modulo // common_prime
    q2 = second_modulo // common_prime
    return (Decrypt(first_ciphertext, common_prime, q1, first_exponent), 
            Decrypt(second_ciphertext, common_prime, q2, second_exponent))

p, q1, q2 = 101, 1000000007, 1000000009
c1 = Encrypt("attack", p * q1, 239)
c2 = Encrypt("wait", p * q2, 17)
print(DecipherCommonDivisor(c1, p * q1, 239, c2, p * q2, 17))

('\x15\x10', 'wait')


### Q7: DecipherHastad - Broadcast Attack

In [22]:
def DecipherHastad(first_ciphertext, first_modulo, second_ciphertext, second_modulo):
    r = ChineseRemainderTheorem(first_modulo, first_ciphertext, second_modulo, second_ciphertext)
    m = IntSqrt(r)
    return ConvertToStr(m)

p1, q1 = 790383132652258876190399065097, 662503581792812531719955475509
p2, q2 = 656917682542437675078478868539, 1263581691331332127259083713503
c1 = Encrypt("attack", p1 * q1, 2)
c2 = Encrypt("attack", p2 * q2, 2)
print(DecipherHastad(c1, p1 * q1, c2, p2 * q2))

attack
