In [1]:
# Import libraries
import inspect
import warnings
warnings.filterwarnings('ignore')
import ipywidgets as widgets
from ipywidgets import Layout
from IPython.display import display, clear_output, Code

from Crypto.Util.number import getPrime, inverse, GCD, long_to_bytes
from decimal import *
from modint import chinese_remainder
from sympy.core.numbers import igcdex

# Set a sufficiently large Decimal precision. Note that `prec` is the number of
# significant digits (not decimal places) kept by Decimal arithmetic; the attack
# cells below apply Decimal to integers that are hundreds of digits long.
getcontext().prec = 2000

In [2]:
#!jupyter nbextension enable --py widgetsnbextension --sys-prefix
#!jupyter serverextension enable voila --sys-prefix

In [3]:
# Widgets for Low Public Exponent Attack

# Title and explanatory text for the LPE panel (HTMLMath renders the $...$ LaTeX).
heading_lpe = widgets.HTML(value="<h1>Low Public Exponent Attack</h1>")
text_lpe = widgets.HTMLMath(value=r"<br>This attack works effectively when the public exponent $e$ is small and the size of the plaintext $M$ is also small, such that $M^e$ will still be smaller than the modulus $N$. In this way, there is no need to perform a modulo $N$ when calculating $C$:<br><br>$C = M^e $ mod $ N = M^e$<br><br>$M$ can then be easily determined by finding $C^{1/e}$.<br>")

def RSA_LPE(e, bits=2048):
    """Generate a textbook RSA key for the given public exponent.

    Primes are redrawn until ``e`` is coprime to ``phi = (p-1)(q-1)``, so that
    the private exponent ``d = e^-1 mod phi`` exists.

    Args:
        e: Public exponent (the demo deliberately uses a small value).
        bits: Size in bits of each prime ``p`` and ``q`` (default 2048).

    Returns:
        Tuple ``(d, n)``: the private exponent and the modulus ``n = p*q``.
    """
    while True:
        p = getPrime(bits)
        q = getPrime(bits)
        phi = (p-1) * (q-1)
        # Check coprimality *before* inverting: no inverse exists when
        # gcd(e, phi) != 1, and recent PyCryptodome releases make inverse()
        # raise ValueError in that case, which would abort instead of retrying.
        if GCD(e, phi) == 1:
            return inverse(e, phi), p * q

def RSAEncryption_LPE(e, n, data):
    """Textbook (unpadded) RSA encryption: C = M^e mod n.

    Args:
        e: Public exponent.
        n: Modulus.
        data: Plaintext bytes, read as a big-endian integer M.

    Returns:
        The ciphertext C converted to bytes with ``long_to_bytes``.
    """
    # int.from_bytes is the direct bytes -> int conversion; unlike parsing
    # data.hex() it also accepts empty input (b"" maps to 0).
    plaintext = int.from_bytes(data, "big")
    ciphertext = pow(plaintext, e, n)
    return long_to_bytes(ciphertext)

def RSADecryption_LPE(d, n, data):
    """Textbook RSA decryption: M = C^d mod n.

    Args:
        d: Private exponent.
        n: Modulus.
        data: Ciphertext bytes, read as a big-endian integer C.

    Returns:
        The plaintext M converted to bytes with ``long_to_bytes``.
    """
    # Direct bytes -> int conversion (also accepts b"", which maps to 0).
    ciphertext = int.from_bytes(data, "big")
    plaintext = pow(ciphertext, d, n)
    return long_to_bytes(plaintext)

def LowPublicExponentAttack(ciphertext, e):
    """Recover M from C = M^e mod N when M^e < N (no modular wrap-around).

    In that case C is an exact e-th power over the integers, so M is simply
    the integer e-th root of C.

    Args:
        ciphertext: Ciphertext bytes (big-endian integer C).
        e: The small public exponent used for encryption.

    Returns:
        The recovered plaintext as bytes.
    """
    ciphertext_int = int.from_bytes(ciphertext, "big")
    print("M = C^(1/e)")
    # Use a local Decimal context with at least as many significant digits as
    # C has, so the root is accurate to far better than 1 no matter what the
    # notebook-wide precision is (with the default 28 digits the result would
    # be silently wrong).
    with localcontext() as ctx:
        ctx.prec = max(ctx.prec, len(str(ciphertext_int)) + 10)
        root = Decimal(ciphertext_int) ** (Decimal(1) / Decimal(e))
        plaintext = int(root.to_integral_exact())
    print("  =", plaintext)
    if pow(plaintext, int(e)) != ciphertext_int:
        # C is not a perfect e-th power: M^e must have wrapped around N, so the
        # attack's precondition fails and the result is not the message.
        print("  (warning: C is not a perfect e-th power; M^e was probably >= N)")
    return long_to_bytes(plaintext)

# Output panes that echo the syntax-highlighted source of the functions above,
# so readers see exactly the code the demo runs. Each pane starts with a
# caption framed by blank lines.
code_lpe = widgets.Output()
with code_lpe:
    print()
    print("The implemented RSA protocol for this attack is as follows:")
    print()
    display(*(Code(inspect.getsource(func))
              for func in (RSA_LPE, RSAEncryption_LPE, RSADecryption_LPE)))

code_lpe_attack = widgets.Output()
with code_lpe_attack:
    print()
    print("The code for the Low Public Exponent Attack is as follows:")
    print()
    display(Code(inspect.getsource(LowPublicExponentAttack)))

# Description panel order: heading, protocol source, explanation, attack source.
vbox_lpe_desc = widgets.VBox([heading_lpe, code_lpe, text_lpe, code_lpe_attack])

def LPEDemo(plaintext):
    """Run the Low Public Exponent demo end to end, printing each step.

    A fresh key is generated, ``plaintext`` is encrypted with the small
    exponent e = 3, and the message is then recovered from (e, C) alone.

    Args:
        plaintext: Message to encrypt, as a str (UTF-8 encoded before use).
    """
    e = 3
    print("Processing...")
    _, n = RSA_LPE(e)  # the attacker never needs the private exponent
    ciphertext = RSAEncryption_LPE(e, n, plaintext.encode()).hex()
    print("Provided message:", plaintext)
    print("Given e =", e)
    print("and C =", ciphertext)
    recovered = LowPublicExponentAttack(bytes.fromhex(ciphertext), e)
    # For a message that is too long, M^e wraps around N and the "recovered"
    # bytes are arbitrary. Decode leniently so the demo shows that outcome
    # instead of aborting with UnicodeDecodeError.
    msg = recovered.decode("utf-8", errors="replace")
    print()
    print("Recovered message using Low Public Exponent Attack:", msg)

# Input controls for the LPE demo: a prompt, a pre-filled plaintext box, a
# "Demo" button and an Output area that receives the demo's printed results.
text_demo_lpe = widgets.HTML(value="Please provide a plaintext input:")

input_lpe = widgets.Text(
    value='Did you put your name into the Goblet of Fire, Harry?',
    placeholder='Type something here',
    description='Plaintext:',
    disabled=False
)

vbox_lpe = widgets.VBox([vbox_lpe_desc, text_demo_lpe, input_lpe])

# NOTE(review): `description_width` is a description-style setting; it
# presumably has no effect on a Button's style — confirm and drop if unused.
button_lpe = widgets.Button(
                description='Demo',
                style={'description_width': 'initial'}
            )

output_lpe = widgets.Output()

def on_button_clicked_lpe(event):
    """Click handler for the LPE "Demo" button.

    Clears the previous run. Then, inside output_lpe, it either asks for a
    plaintext (when the box is empty) or runs LPEDemo on the current text.
    """
    output_lpe.clear_output()
    with output_lpe:
        message = input_lpe.value
        if not message:
            print("Please provide a plaintext input for the demo")
            return
        LPEDemo(message)

# Attach the click handler and assemble the complete LPE panel.
button_lpe.on_click(on_button_clicked_lpe)

vbox_lpe_demo = widgets.VBox([vbox_lpe, button_lpe, output_lpe])

In [4]:
# Widgets for Hastad Broadcasting Attack

# Title and explanatory text for the Hastad panel (HTMLMath renders the $...$
# LaTeX). The CRT inverse uses the same M_i defined on the previous line.
heading_hb = widgets.HTML(value="<h1>Hastad Broadcasting Attack</h1>")
text_hb = widgets.HTMLMath(value=r"<br>This is a form of low exponent attack using the Chinese Remainder Theorem (CRT). For the attack to work, the sender must first encrypt the same plaintext $M$ with the same public exponent $e$ but with different moduli $N_1, N_2, ..., N_i$ to obtain ciphertexts $C_1, C_2, ..., C_i$. In this specific scenario, 3 different $N$s are used:<br><br>$C_1 = M^e$ mod $N_1$<br>$C_2 = M^e$ mod $N_2$<br>$C_3 = M^e$ mod $N_3$<br>")
text_hb2 = widgets.HTMLMath(value=r"<br>Using CRT, the following can be computed:<br><br>$N = N_1 \times N_2 \times ... \times N_i$<br>$M_i = N/N_i$ for $i = 1,2,...$<br>$y_i = M_i^{-1}$ mod $N_i$ for $i = 1,2,...$<br>$X = (C_1 M_1 y_1 + C_2 M_2 y_2 + ... + C_i M_i y_i)$ mod $N$<br><br>Because $M$ is smaller than every $N_i$ and at least $e$ moduli are used, $M^e < N$. The CRT solution is therefore exactly $X = M^e$, and the plaintext can be recovered by computing $X^{1/e}$.")

def RSA_HB(e, bits=512):
    """Generate one textbook RSA key for the Hastad demo.

    Primes are redrawn until ``e`` is coprime to ``phi = (p-1)(q-1)``, so that
    the private exponent ``d = e^-1 mod phi`` exists.

    Args:
        e: Public exponent (shared by all keys in the demo).
        bits: Size in bits of each prime ``p`` and ``q`` (default 512).

    Returns:
        Tuple ``(d, n)``: the private exponent and the modulus ``n = p*q``.
    """
    while True:
        p = getPrime(bits)
        q = getPrime(bits)
        phi = (p-1) * (q-1)
        # Check coprimality *before* inverting: no inverse exists when
        # gcd(e, phi) != 1, and recent PyCryptodome releases make inverse()
        # raise ValueError in that case, which would abort instead of retrying.
        if GCD(e, phi) == 1:
            return inverse(e, phi), p * q

def RSAEncryption_HB(e, n, data):
    """Textbook (unpadded) RSA encryption returning C = M^e mod n as a number.

    Args:
        e: Public exponent.
        n: Modulus.
        data: Plaintext bytes, read as a big-endian integer M.

    Returns:
        ``pow(M, e, n)`` (an int when ``e`` and ``n`` are ints).
    """
    # Direct bytes -> int conversion (also accepts b"", which maps to 0).
    plaintext = int.from_bytes(data, "big")
    ciphertext = pow(plaintext, e, n)
    return ciphertext

def RSADecryption_HB(d, n, data):
    """Textbook RSA decryption: M = C^d mod n.

    Args:
        d: Private exponent.
        n: Modulus.
        data: Ciphertext bytes, read as a big-endian integer C.

    Returns:
        The plaintext M converted to bytes with ``long_to_bytes``.
    """
    # Direct bytes -> int conversion (also accepts b"", which maps to 0).
    ciphertext = int.from_bytes(data, "big")
    plaintext = pow(ciphertext, d, n)
    return long_to_bytes(plaintext)

def HastadBroadcastingAttack(e, n1, n2, n3, c1, c2, c3):
    """Recover M from three encryptions of it under the same small exponent e.

    CRT combines C_i = M^e mod N_i (i = 1, 2, 3) into X = M^e mod N1*N2*N3.
    Since M < N_i for every i, M^e < N1*N2*N3 whenever e <= 3. X then equals
    M^e over the integers, and M is its integer e-th root.

    Args:
        e: The shared public exponent.
        n1, n2, n3: The three (pairwise coprime) moduli.
        c1, c2, c3: The matching ciphertexts (int, or integral Decimal).

    Returns:
        The recovered plaintext as bytes.
    """
    # Hand exact Python ints to the CRT step, so its result does not depend on
    # how chinese_remainder treats Decimal operands.
    N = [int(n1), int(n2), int(n3)]
    C = [int(c1), int(c2), int(c3)]
    result = chinese_remainder(N, C)
    print("Chinese Remainder Theorem Result:", result)
    result_int = int(result)
    # Take the e-th root in a local Decimal context sized to the CRT result,
    # so the answer does not depend on the notebook-wide precision setting.
    with localcontext() as ctx:
        ctx.prec = max(ctx.prec, len(str(result_int)) + 10)
        root = Decimal(result_int) ** (Decimal(1) / Decimal(e))
        plaintext = int(root.to_integral_exact())
    if pow(plaintext, int(e)) != result_int:
        # The CRT result is not M^e exactly, so the attack's precondition fails.
        print("  (warning: CRT result is not a perfect e-th power; the message may be too long)")
    return long_to_bytes(plaintext)

# Output panes that echo the syntax-highlighted source of the functions above,
# so readers see exactly the code the demo runs. Each pane starts with a
# caption framed by blank lines.
code_hb = widgets.Output()
with code_hb:
    print()
    print("The implemented RSA protocol for this attack is as follows:")
    print()
    display(*(Code(inspect.getsource(func))
              for func in (RSA_HB, RSAEncryption_HB, RSADecryption_HB)))

code_hb_attack = widgets.Output()
with code_hb_attack:
    print()
    print("The code for the Hastad Broadcasting Attack is as follows:")
    print()
    display(Code(inspect.getsource(HastadBroadcastingAttack)))

# Description panel order: heading, protocol source, CRT explanation, attack source.
vbox_hb_desc = widgets.VBox([heading_hb, code_hb, text_hb, text_hb2, code_hb_attack])

def HBDemo(plaintext):
    """Run the Hastad Broadcasting demo end to end, printing each step.

    The same message is encrypted with e = 3 under three independently
    generated moduli. The attack then recovers it from (e, N_i, C_i) alone.

    Args:
        plaintext: Message to encrypt, as a str (UTF-8 encoded before use).
    """
    e = 3
    print("Processing...")
    # Only the moduli are needed; the private exponents are discarded.
    _, n1 = RSA_HB(e)
    _, n2 = RSA_HB(e)
    _, n3 = RSA_HB(e)
    data = plaintext.encode()
    # Pass plain ints. Wrapping e and n in Decimal made the ciphertexts Decimal
    # values whose later arithmetic depends on the global Decimal precision,
    # whereas int arithmetic is always exact.
    c1 = RSAEncryption_HB(e, n1, data)
    c2 = RSAEncryption_HB(e, n2, data)
    c3 = RSAEncryption_HB(e, n3, data)
    print("Provided message:", plaintext)
    print("Given e =", e)
    print("     N1 =", n1)
    print("     N2 =", n2)
    print("     N3 =", n3)
    print("     C1 =", c1)
    print("     C2 =", c2)
    print("     C3 =", c3)
    recovered = HastadBroadcastingAttack(e, n1, n2, n3, c1, c2, c3)
    # A message longer than a modulus cannot be recovered. Decode leniently so
    # the demo shows the garbled result instead of raising UnicodeDecodeError.
    msg = recovered.decode("utf-8", errors="replace")
    print()
    print("Recovered message using Hastad Broadcasting Attack with CRT:", msg)

# Input controls for the Hastad demo: a prompt, a pre-filled plaintext box, a
# "Demo" button and an Output area that receives the demo's printed results.
text_demo_hb = widgets.HTML(value="Please provide a plaintext input:")

input_hb = widgets.Text(
    value='Closing down for ever - all the best - goodbye.',
    placeholder='Type something here',
    description='Plaintext:',
    disabled=False
)

vbox_hb = widgets.VBox([vbox_hb_desc, text_demo_hb, input_hb])

# NOTE(review): `description_width` is a description-style setting; it
# presumably has no effect on a Button's style — confirm and drop if unused.
button_hb = widgets.Button(
                description='Demo',
                style={'description_width': 'initial'}
            )

output_hb = widgets.Output()

def on_button_clicked_hb(event):
    """Click handler for the Hastad "Demo" button.

    Clears the previous run. Then, inside output_hb, it either asks for a
    plaintext (when the box is empty) or runs HBDemo on the current text.
    """
    output_hb.clear_output()
    with output_hb:
        message = input_hb.value
        if not message:
            print("Please provide a plaintext input for the demo")
            return
        HBDemo(message)

# Attach the click handler and assemble the complete Hastad panel.
button_hb.on_click(on_button_clicked_hb)

vbox_hb_demo = widgets.VBox([vbox_hb, button_hb, output_hb])

In [5]:
# Widgets for Common Modulus Attack

# Title and explanatory text for the Common Modulus panel (HTMLMath renders
# the $...$ LaTeX). The derivation holds modulo N, and the negative Bezout
# coefficient has to be handled with a modular inverse.
heading_cm = widgets.HTML(value="<h1>Common Modulus Attack</h1>")
text_cm = widgets.HTMLMath(value=r"<br>This attack is effective for breaking RSA protocols that generate pairs of public and private keys from the same modulus $N$, usually for convenience. Suppose the same plaintext $M$ is encrypted under 2 public exponents $e_1$ and $e_2$ that are relatively prime to each other, i.e. gcd($e_1$,$e_2$) = 1, giving $C_1 = M^{e_1}$ mod $N$ and $C_2 = M^{e_2}$ mod $N$. Then, according to Bezout's Identity, there exist integers $x$ and $y$ such that:<br><br>$e_1x + e_2y = 1$<br><br>The values of $x$ and $y$ can then be determined through the use of the Extended Euclidean Algorithm. After that, to recover the original plaintext, the following computation can be performed:<br>")
text_cm2 = widgets.HTMLMath(value=r"<br>$C_1^x \times C_2^y$ mod $N$<br>$= M^{e_1x} \times M^{e_2y}$ mod $N$<br>$= M^{e_1x + e_2y}$ mod $N$<br>$= M$<br><br>Since $e_1, e_2 > 1$, exactly one of $x$ and $y$ is negative. That power is computed from the modular inverse of the ciphertext, e.g. $C_2^y = (C_2^{-1})^{-y}$ mod $N$ when $y < 0$.")

def RSA_CM(e1, e2, bits=1024):
    """Generate two textbook RSA keys that share one modulus.

    Primes are redrawn until both ``e1`` and ``e2`` are coprime to
    ``phi = (p-1)(q-1)``, so that both private exponents exist.

    Args:
        e1: First public exponent.
        e2: Second public exponent.
        bits: Size in bits of each prime ``p`` and ``q`` (default 1024).

    Returns:
        Tuple ``(d1, d2, n)``: the two private exponents and the shared modulus.
    """
    while True:
        p = getPrime(bits)
        q = getPrime(bits)
        phi = (p-1) * (q-1)
        # Check coprimality *before* inverting: no inverse exists otherwise,
        # and recent PyCryptodome releases make inverse() raise ValueError in
        # that case, which would abort instead of drawing new primes.
        if GCD(e1, phi) == 1 and GCD(e2, phi) == 1:
            return inverse(e1, phi), inverse(e2, phi), p * q

def RSAEncryption_CM(e, n, data):
    """Textbook (unpadded) RSA encryption returning C = M^e mod n as a number.

    Args:
        e: Public exponent.
        n: Modulus.
        data: Plaintext bytes, read as a big-endian integer M.

    Returns:
        ``pow(M, e, n)`` (an int when ``e`` and ``n`` are ints).
    """
    # Direct bytes -> int conversion (also accepts b"", which maps to 0).
    plaintext = int.from_bytes(data, "big")
    ciphertext = pow(plaintext, e, n)
    return ciphertext

def RSADecryption_CM(d, n, data):
    """Textbook RSA decryption: M = C^d mod n.

    Args:
        d: Private exponent.
        n: Modulus.
        data: Ciphertext bytes, read as a big-endian integer C.

    Returns:
        The plaintext M converted to bytes with ``long_to_bytes``.
    """
    # Direct bytes -> int conversion (also accepts b"", which maps to 0).
    ciphertext = int.from_bytes(data, "big")
    plaintext = pow(ciphertext, d, n)
    return long_to_bytes(plaintext)

def CommonModulusAttack(e1, e2, c1, c2, n):
    """Recover M from two encryptions of it under one shared modulus.

    Given C1 = M^e1 mod n and C2 = M^e2 mod n with gcd(e1, e2) = 1, the
    extended Euclidean algorithm yields x, y with e1*x + e2*y = 1. Hence
    C1^x * C2^y = M^(e1*x + e2*y) = M (mod n).

    Args:
        e1, e2: The two public exponents.
        c1, c2: The corresponding ciphertexts (int, or integral Decimal).
        n: The shared modulus.

    Returns:
        The recovered plaintext as bytes.
    """
    x, y, g = igcdex(e1, e2)
    print("Using Extended Euclidean Algorithm:")
    print("x =", x)
    print("y =", y)
    n = int(n)

    def power_mod(base, exponent):
        # A negative exponent means: raise the modular inverse of base.
        base, exponent = int(base), int(exponent)
        if exponent < 0:
            return pow(inverse(base, n), -exponent, n)
        return pow(base, exponent, n)

    # This must be modular arithmetic. Dividing C1^x by C2^-y as real numbers
    # gives a fraction far below 1, which rounds to 0 instead of M.
    plaintext = (power_mod(c1, x) * power_mod(c2, y)) % n
    print("M = C1^x * C2^y")
    print("  =", plaintext)
    if g != 1:
        print("  (warning: gcd(e1, e2) =", g, "so the result is M^gcd, not M)")
    return long_to_bytes(plaintext)

# Output panes that echo the syntax-highlighted source of the functions above,
# so readers see exactly the code the demo runs. Each pane starts with a
# caption framed by blank lines.
code_cm = widgets.Output()
with code_cm:
    print()
    print("The implemented RSA protocol for this attack is as follows:")
    print()
    display(*(Code(inspect.getsource(func))
              for func in (RSA_CM, RSAEncryption_CM, RSADecryption_CM)))

code_cm_attack = widgets.Output()
with code_cm_attack:
    print()
    print("The code for the Common Modulus Attack is as follows:")
    print()
    display(Code(inspect.getsource(CommonModulusAttack)))

# Description panel order: heading, protocol source, Bezout explanation,
# derivation, attack source.
vbox_cm_desc = widgets.VBox([heading_cm, code_cm, text_cm, text_cm2, code_cm_attack])

def CMDemo(plaintext):
    """Run the Common Modulus demo end to end, printing each step.

    The same message is encrypted with e1 = 17 and e2 = 11 under one shared
    modulus. The attack then recovers it from (e1, e2, C1, C2, N) alone.

    Args:
        plaintext: Message to encrypt, as a str (UTF-8 encoded before use).
    """
    e1 = 17
    e2 = 11
    print("Processing...")
    # The attacker never needs d1/d2; only the shared modulus is used.
    _, _, n = RSA_CM(e1, e2)
    data = plaintext.encode()
    # Pass plain ints so the ciphertexts are exact integers. Decimal wrappers
    # made them Decimal values whose arithmetic depends on the global precision.
    c1 = RSAEncryption_CM(e1, n, data)
    c2 = RSAEncryption_CM(e2, n, data)
    print("Provided message:", plaintext)
    print("Given e1 =", e1, ", e2 =", e2)
    print("N =", n)
    print("C1 =", c1)
    print("C2 =", c2)
    recovered = CommonModulusAttack(e1, e2, c1, c2, n)
    # Decode leniently so a failed recovery is shown rather than raising
    # UnicodeDecodeError inside the widget callback.
    msg = recovered.decode("utf-8", errors="replace")
    print("Recovered message using Common Modulus Attack:", msg)

# Input controls for the Common Modulus demo: a prompt, a pre-filled plaintext
# box, a "Demo" button and an Output area for the demo's printed results.
text_demo_cm = widgets.HTML(value="Please provide a plaintext input:")

input_cm = widgets.Text(
    value='Hello World!',
    placeholder='Type something here',
    description='Plaintext:',
    disabled=False
)

vbox_cm = widgets.VBox([vbox_cm_desc, text_demo_cm, input_cm])

# NOTE(review): `description_width` is a description-style setting; it
# presumably has no effect on a Button's style — confirm and drop if unused.
button_cm = widgets.Button(
                description='Demo',
                style={'description_width': 'initial'}
            )

output_cm = widgets.Output()

def on_button_clicked_cm(event):
    """Click handler for the Common Modulus "Demo" button.

    Clears the previous run. Then, inside output_cm, it either asks for a
    plaintext (when the box is empty) or runs CMDemo on the current text.
    """
    output_cm.clear_output()
    with output_cm:
        message = input_cm.value
        if not message:
            print("Please provide a plaintext input for the demo")
            return
        CMDemo(message)

# Attach the click handler and assemble the complete Common Modulus panel.
button_cm.on_click(on_button_clicked_cm)

vbox_cm_demo = widgets.VBox([vbox_cm, button_cm, output_cm])

In [6]:
# Menu Buttons
# One fixed-width button per attack, created in menu order.
button1, button2, button3 = (
    widgets.Button(description=label,
                   layout=Layout(width='200px'),
                   style={'description_width': 'initial'})
    for label in ('Low Public Exponent Attack',
                  'Hastad Broadcasting Attack',
                  'Common Modulus Attack')
)

# Shared area in which the selected attack's panel is displayed.
main_output = widgets.Output()

def on_button1_clicked(event):
    """Menu handler: replace the main area's content with the LPE demo panel."""
    main_output.clear_output()
    with main_output:
        display(vbox_lpe_demo)
button1.on_click(on_button1_clicked)

def on_button2_clicked(event):
    """Menu handler: replace the main area's content with the Hastad demo panel."""
    main_output.clear_output()
    with main_output:
        display(vbox_hb_demo)
button2.on_click(on_button2_clicked)

def on_button3_clicked(event):
    """Menu handler: replace the main area's content with the Common Modulus demo panel."""
    main_output.clear_output()
    with main_output:
        display(vbox_cm_demo)
button3.on_click(on_button3_clicked)

# Top-level layout: the row of menu buttons above the shared output area.
hbox_buttons = widgets.HBox([button1, button2, button3])
vbox_main = widgets.VBox([hbox_buttons, main_output])

# Demonstration of Attacks on Weak RSA
---
This demonstrator showcases 3 different attacks that recover the original plaintext from a simple, weakly implemented (textbook, unpadded) RSA protocol:<br>

<ul>
<li>Low Public Exponent Attack</li>
<li>Hastad Broadcasting Attack</li>
<li>Common Modulus Attack</li>
</ul>

In [7]:
# Render the menu (buttons + output area); as the cell's last expression it is displayed.
vbox_main

VBox(children=(HBox(children=(Button(description='Low Public Exponent Attack', layout=Layout(width='200px'), s…

<br><br><br>

In [8]:
#!pip freeze > requirements.txt
#!pip list --format=freeze > requirements.txt