# RSA cryptography method
## Introduction:

RSA (Rivest-Shamir-Adleman) is a widely-used method of public-key cryptography. It is named after its inventors, Ron Rivest, Adi Shamir, and Leonard Adleman, who first published it in 1977. RSA is used to encrypt and decrypt messages, as well as to create digital signatures.


## Using the rsa Python library


The rsa library (https://stuvel.eu/python-rsa-doc/usage.html) is a Python library for implementing the RSA algorithm. It provides a simple and easy-to-use implementation of RSA, including key generation, encryption, decryption, and signing/verifying. 

Here are a few examples of how you can use the rsa library:

### Key Generation:

The **"rsa.newkeys()"** function is part of the "rsa" library and generates a new RSA key pair. Its required first argument is the size of the key (the modulus n) in bits; with the optional argument accurate = True, n has exactly that many bits. The function returns a tuple containing the public and private keys. The public key can be shared with anyone for encryption, while the private key must be kept secure for decryption and signing. The 512-bit key generated below keeps the printed output short; it is far too small for real use.

In [609]:
import rsa
(pubkey, privkey) = rsa.newkeys(512,accurate = True)

In RSA cryptography, the public key is composed of two components:

- n: The modulus, which is the product of two large prime numbers and forms the foundation of the encryption and decryption process. 

- e: The public exponent, which is a small odd integer used to encrypt messages. It is typically chosen to be 65537.

These two components of the public key can be used by anyone to encrypt messages that can only be decrypted using the corresponding private key. Here the **public key** is printed:

In [610]:
print(pubkey)

# The key was generated with 512 bits, so n has about 155 decimal digits.
print(len(str(pubkey.n)))

# bin() prefixes its result with "0b", so the string is 2 characters longer
# than the bit length: 514 - 2 = 512 bits, as requested.
print(len(str(bin(pubkey.n))))

PublicKey(11323450162675592716169913898430970415027044043116977610960567435002867533609284838650531089266641014458686913254631339567747321009371781309619154842255789, 65537)
155
514
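
The bit length can also be read directly with int.bit_length(), which avoids the "0b" bookkeeping. The sketch below uses a made-up 512-bit integer as a stand-in for a modulus (an assumption for illustration, not a real RSA key):

```python
# Hypothetical stand-in for a 512-bit modulus (not a real RSA key)
n = 2**512 - 1

print(n.bit_length())   # number of bits in n
print(len(bin(n)) - 2)  # same value: bin() adds the 2-character "0b" prefix
print(len(str(n)))      # number of decimal digits
```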


The private key is kept secret and should not be shared with others. In RSA cryptography, the private key is composed of five components:

- n: The modulus, which is the product of two large prime numbers and forms the foundation of the encryption and decryption process. 

- e: The public exponent, which is a small odd integer used to encrypt messages. It is typically chosen to be 65537. 

- d: The private exponent, which is used to decrypt messages and create digital signatures. It is kept secret and derived from the public exponent and other information, including the modulus and prime factors used to generate the key. 

- p: One of the prime factors of n. 

- q: The other prime factor of n. 

The **private key** is printed here:

In [611]:
print(privkey)

# check whether p*q==n, since it returned True, correctly! 
print(privkey.p*privkey.q ==privkey.n)

PrivateKey(11323450162675592716169913898430970415027044043116977610960567435002867533609284838650531089266641014458686913254631339567747321009371781309619154842255789, 65537, 10744984393039541563176385483460085434794342768167626236294409695521511949350225187552418123139503184799968744700705444425279172379563559146986723483305473, 6802844660803364060099289597288979956119688717265906897764555003264242993623312001, 1664516937733277542011149147531152036968429030229340456255090597905087789)
True


### Encryption:

The **rsa.encrypt()** function is used to encrypt the plaintext message using the public key. The plaintext message must be passed as bytes, and the public key is passed as the second argument. The function returns the ciphertext, which is the encrypted version of the plaintext message.

In [612]:
# Define the message to be encrypted
message = "Hello World!"

# Encode the message as a byte string
message = message.encode()
print(message)

# Encrypt the message using the public key
ciphertext = rsa.encrypt(message, pubkey)

print(ciphertext)
# print(len(ciphertext))

b'Hello World!'
b'>\xbb\xc2\xf6\xb7\x86\xc5\x15\xac\x99\x9b\xa3\x1d\x94\xf6\xe5\xed\x85\xa9\xb0\xa0c\x92\xdc-\xba\x84 8\x0e\xb4\xc5\xbc\xa7:F\x03\xd8o\xcd\xa0\xaa<\xb5\x1c1}v\xfa\xcdB\xa96\x8d\xe3\x15\x88\xbf\xfa\x1c\xb5\x0b\x8fx'


The output is a byte string that represents the encrypted message. It is generated by applying the RSA encryption algorithm to the original message, using the public key of the recipient. The output is unreadable, as the encryption process has transformed it into a sequence of seemingly random bytes. To decrypt the message, the recipient uses their private key.

###  Decryption:

The **rsa.decrypt()** function is used to decrypt the ciphertext using the private key. The ciphertext and the private key are passed as arguments. The function returns the decrypted plaintext message as bytes, which can then be decoded to a string using the .decode() method.

In [613]:
plaintext = rsa.decrypt(ciphertext, privkey).decode()
print(plaintext)

Hello World!


###  Signing:

The rsa.sign() function is used to sign a message using the private key. The message must be passed as bytes, followed by the private key and the name of a hash function such as 'SHA-1'. The message is hashed with that algorithm, the resulting hash is signed using the private key, and the function returns the signature as bytes.

The signature shows that the message came from the owner of the private key and has not been tampered with. SHA-1 is used here only for illustration; it is no longer considered collision-resistant, and a stronger hash such as 'SHA-256' is preferable.

In [614]:
message = "Hello, World!"
signature = rsa.sign(message.encode(), privkey, 'SHA-1')

print(signature)

b'\xb9pb\xf1\x96\xe2\xaa#\xf2\x91\xcauS\x99\xc5Q\xbe\x8a\xaa\xebV\x94ZT\x1a_\x97y\x180\x91\xee\x87\xe4\xe43\xfc\x80ee\x12\xb5\xeb)0\xe3n\x08!@u;\xe8+Q\xdb\xaa\x93Q\xca\x88\x89\xf0\xd5'


### Verifying:

The rsa.verify() function is used to verify the signature of a message. The message, signature, and public key are passed as arguments. If the signature is valid, the function returns the name of the hash algorithm that was used (here 'SHA-1'); otherwise it raises an rsa.pkcs1.VerificationError exception.

In [615]:
rsa.verify(message.encode(), signature, pubkey)

'SHA-1'

In [616]:
import rsa

# Generate a new RSA key pair
(pubkey, privkey) = rsa.newkeys(1024)

# Define the message to be signed
message = "Hello, World!"
message = message.encode()

# Sign the message using the private key
signature = rsa.sign(message, privkey, 'SHA-256')

# Verify the signature using the public key
try:
    rsa.verify(message, signature, pubkey)
    print("Signature is valid.")
except rsa.pkcs1.VerificationError:
    print("Signature is invalid.")


Signature is valid.
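
Conceptually, signing hashes the message and applies the private exponent to the hash, while verifying applies the public exponent to the signature and compares the result with a fresh hash. The sketch below illustrates this with unpadded ("textbook") RSA. The names toy_sign and toy_verify and the fixed Mersenne primes are assumptions made for illustration; the rsa library additionally applies PKCS#1 padding, which is omitted here.

```python
import hashlib

# Toy key built from two known Mersenne primes (illustration only, not secure)
p, q = 2**127 - 1, 2**521 - 1
n = p * q
e = 65537
d = pow(e, -1, (p - 1) * (q - 1))  # modular inverse, Python 3.8+

def digest_as_int(message):
    """SHA-256 digest of the message, read as a big-endian integer (< 2**256 < n)."""
    return int.from_bytes(hashlib.sha256(message).digest(), "big")

def toy_sign(message, d, n):
    # "Sign" the hash by raising it to the private exponent
    return pow(digest_as_int(message), d, n)

def toy_verify(message, signature, e, n):
    # Undo the signature with the public exponent and compare with a fresh hash
    return pow(signature, e, n) == digest_as_int(message)

signature = toy_sign(b"Hello, World!", d, n)
print(toy_verify(b"Hello, World!", signature, e, n))  # True
print(toy_verify(b"Hello, World?", signature, e, n))  # False
```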


## Implementing RSA without using rsa library

In this part, we set the rsa library aside and implement the RSA algorithm step by step. We will cover the following topics:

- Generating RSA keys: We will learn how to generate a public and private key using Python.

- Encryption and Decryption: We will learn how to encrypt and decrypt a message using the RSA method.


Let's get started!

## Step 1: Generating RSA keys:

The process of generating RSA keys involves the following steps:

- Select two large prime numbers, p and q.
- Compute n = p * q. n is used as the modulus for both the public and private keys.
- Compute the totient of n, denoted as φ(n) = (p-1)*(q-1).
- Choose a public exponent, e, such that 1 < e < φ(n) and e is coprime to φ(n).
- Determine the private exponent, d, such that d * e ≡ 1 (mod φ(n)). d can be calculated using the extended Euclidean algorithm.
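
The steps above can be traced with very small numbers. The following sketch uses the toy primes 61 and 53 and the exponent 17, all chosen here purely for illustration; real keys use primes that are hundreds of digits long.

```python
import math

p, q = 61, 53                # toy primes, far too small for real use
n = p * q                    # modulus
phi = (p - 1) * (q - 1)      # totient of n
e = 17                       # public exponent
assert 1 < e < phi and math.gcd(e, phi) == 1
d = pow(e, -1, phi)          # private exponent: d * e ≡ 1 (mod phi), Python 3.8+

m = 65                       # a message, already encoded as an integer < n
c = pow(m, e, n)             # encryption
recovered = pow(c, d, n)     # decryption
print(n, phi, d, c, recovered)  # 3233 3120 2753 2790 65
```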

To generate RSA keys in Python, we will need to import the following libraries:

In [617]:
import random
import math
import sympy

###  generate_large_prime() function:

This function generates two large prime numbers, p and q.

## Question:
Complete the following function to generate p and q:

In [618]:
def isPrime(n):
    """
    Check whether a number is (probably) prime using the Miller-Rabin primality test.
    The test is probabilistic: it is much faster than a deterministic check, at the cost of a very small chance of accepting a composite number.
    see: https://en.wikipedia.org/wiki/Miller%E2%80%93Rabin_primality_test

    Minimum number of rounds of M-R testing when generating primes using an error probability of 2^(-100)
    see: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.186-4.pdf, page 71, Appendix C, Table C.3

    Args:
        n: the decimal number

    Returns:
        True/False
    """
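    if n <= 3:
        return n == 2 or n == 3
    # Even numbers greater than 2 are composite; rejecting them here also
    # guarantees s >= 1 below, so y is always assigned in the inner loop
    if n % 2 == 0:
        return False

    numberOfBits = n.bit_length()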
    if numberOfBits >= 1536:
        k = 3
    elif numberOfBits >= 1024:
        k = 4
    elif numberOfBits >= 512:
        k = 7
    else:
        k = 14
    
    # Miller_Rabin primality test
    # initiate d and s, make the equation n − 1 = 2^(s) * d hold
    d = n-1
    s = 0

    # while d is even, halve it and count the factor of 2, until d is odd
    while d % 2 == 0:
        d >>= 1
        s += 1

    for i in range(k):
        a = random.randint(2, n-2)
        # print("d now is {}".format(d))
        x = pow(a, int(d), n)

        for j in range(s):
            y = pow(x, 2, n)
            if y == 1 and x != 1 and x != n - 1:
                return False
            x = y
        if y != 1:
            return False
    return True

In [619]:
def generate_large_prime(nbits):
    """
    Generate two random prime numbers (p, q) whose product n has roughly nbits bits.
    To make factoring harder, p and q should be chosen at random, be similar in magnitude, but differ in length.

    Args:
        nbits: The number of bits (which we want for modulus - n).

    Returns:
        (p,q)
    """

    bitDifference = nbits // 32

    while True:
        # generate a random candidate of at most nbits//2 + bitDifference bits
        p = random.getrandbits(nbits//2 + bitDifference)
        # make sure p is an odd number, because every prime number except 2 (which we don't consider) is odd
        p = p|1
        if isPrime(p):
            break
    
    while True:
        q = random.getrandbits(nbits//2 - bitDifference)
        q = q|1
        if isPrime(q) and (q != p):
            break
    
    return (p,q)

In [620]:
(p,q) = generate_large_prime(1024)
print((p,q))
print(len(str(bin(p))))
print(len(str(bin(q))))
print(len(str(bin(p*q))))

# check if (p,q) are prime numbers
print(rsa.prime.is_prime(p))
print(rsa.prime.is_prime(q))

(32958237402471602930989377554484250548010289663611167200339133399237414135540043810111098228940508812182692029494596171430745957543931701350151837273478393125132721, 2780279616397420091156777483084361771004958714701533541508560642646224397230351009329238779276513706202398683166169702267608529338530573232180081)
546
482
1026
True
True
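
Two caveats apply to the approach above: random.getrandbits() does not guarantee that the top bit of a candidate is set, so a prime can come out shorter than requested, and the random module is not intended for cryptographic use. The sketch below is one possible alternative, assuming the standard-library secrets module as the randomness source. The names is_probable_prime and random_prime_exact_bits, and the choice of 20 rounds, are illustrative assumptions rather than tuned values.

```python
import secrets

def is_probable_prime(n, rounds=20):
    """Miller-Rabin test; `rounds` is an illustrative choice."""
    if n < 2:
        return False
    # Handle small primes and their multiples directly
    for small in (2, 3, 5, 7, 11, 13):
        if n % small == 0:
            return n == small
    # Write n - 1 as 2^s * d with d odd
    d, s = n - 1, 0
    while d % 2 == 0:
        d >>= 1
        s += 1
    for _ in range(rounds):
        a = secrets.randbelow(n - 3) + 2      # random base in [2, n - 2]
        x = pow(a, d, n)
        if x in (1, n - 1):
            continue
        for _ in range(s - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            return False                      # a is a witness: n is composite
    return True

def random_prime_exact_bits(nbits):
    """Return a probable prime with exactly nbits bits (nbits >= 2)."""
    while True:
        candidate = secrets.randbits(nbits)
        # Force the top bit (exact length) and the bottom bit (odd)
        candidate |= (1 << (nbits - 1)) | 1
        if is_probable_prime(candidate):
            return candidate

prime = random_prime_exact_bits(256)
print(prime.bit_length())  # 256
```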


### Step 2: Selecting the public exponent e using generate_public_exponent function 

The generate_public_exponent function generates a random number between 2 and phi-1, then checks whether that number is relatively prime to phi (i.e. their greatest common divisor is 1). If it is not, the function generates a new random number and repeats this process until it finds one that is.

## Question:
Based on the above description, complete the following function to generate the public key exponent e:

In [621]:
def is_relatively_prime(x, y):
    # Determining if two numbers are relatively prime
    gcd = math.gcd(x,y)
    return gcd == 1

def generate_public_exponent(phi):
    # modulus - n
    # publicExponent - e
    # privateExponent - d
    while True:
        e = random.randint(2, phi-1)
        # if math.gcd(e,phi) == 1:
        if is_relatively_prime(e,phi):
            return e
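
A random e is as large as phi, which makes encryption slow. As noted earlier, e is typically chosen to be 65537. The sketch below shows that alternative; the function name choose_public_exponent and its fallback rule (try the next odd number) are assumptions made for illustration.

```python
import math

def choose_public_exponent(phi, preferred=65537):
    """Return `preferred` if it is coprime to phi, else the next odd value that is."""
    e = preferred
    while e < phi:
        if math.gcd(e, phi) == 1:
            return e
        e += 2  # stay odd: phi is even, so an even e can never be coprime to it
    raise ValueError("no suitable exponent found below phi")

# Example with the totient of a toy modulus built from two known primes
phi = (2**31 - 2) * (2**61 - 2)
print(choose_public_exponent(phi))  # 65537
```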

### Step 3 : Selecting the private exponent d using  calculate_private_exponent(phi, e) function

This function calculates the modular inverse of e mod phi. The modular inverse of a number a modulo m is a number x such that a*x leaves a remainder of 1 when divided by m. In this case, d*e should be 1 mod phi.

The function mod_inverse(a,m) is used to find the modular inverse of a number. It returns the number x such that a*x = 1 mod m.

## Question:
Based on the above description, complete the following function by writing the code for the mod_inverse(e, phi) function in order to generate the private key exponent d:

In [622]:
def extended_euclidean_algorithm(a,b):
    # Use Extended Euclidean Algorithm
    # see: https://en.wikipedia.org/wiki/Modular_multiplicative_inverse
    # see: https://www.educative.io/answers/what-is-extended-euclidean-algorithm
    # a*x + b*y = gcd(a,b) = r, given a and b, the Extended Euclidean Algorithm can get x, y and r
    # in our case, ed + y*phi = 1 = gcd(e,phi), we need to calculate d. 

    if a == 0: 
        return b, 0, 1        
    gcd, x, y = extended_euclidean_algorithm(b%a, a)
    x,y = y - (b//a) * x, x 
    return gcd, x, y

def mod_inverse(a,b):
    gcd,x,y = extended_euclidean_algorithm(a,b)
    return x

In [623]:
def calculate_private_exponent(phi, e):
    # Calculating the modular inverse of e mod phi
    d = mod_inverse(e, phi)
    if(d < 0):
        d += phi
    return d

In [624]:
A = calculate_private_exponent(49,3)
print('A:{}'.format(A))

A:33
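
A recursive extended Euclidean algorithm can need thousands of nested calls for moduli of several thousand bits, which may exceed Python's default recursion limit outside environments that raise it. The sketch below is an iterative variant (the name mod_inverse_iterative is an assumption for illustration), cross-checked against the built-in pow(a, -1, m) available from Python 3.8.

```python
def mod_inverse_iterative(a, m):
    """Return x in [0, m) with a * x ≡ 1 (mod m); raise ValueError if none exists."""
    # Invariant: old_r ≡ old_x * a (mod m) and r ≡ x * a (mod m)
    old_r, r = a % m, m
    old_x, x = 1, 0
    while r != 0:
        quotient = old_r // r
        old_r, r = r, old_r - quotient * r
        old_x, x = x, old_x - quotient * x
    if old_r != 1:
        raise ValueError("a has no inverse modulo m")
    return old_x % m

print(mod_inverse_iterative(3, 49))  # 33
print(pow(3, -1, 49))                # 33, the built-in equivalent
```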


## Writing encryption and decryption functions



### Encryption

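The encrypt function takes the plaintext and the public key as input, and returns the ciphertext. The encryption is done by raising the plaintext to the power of the public exponent (e) and then taking the remainder when divided by n. The plaintext, interpreted as an integer, must be smaller than n, otherwise it cannot be recovered. Unlike the rsa library, which pads the message before encrypting, this is unpadded ("textbook") RSA and is meant for learning only.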

## Question:
Write a function to do the encryption:

In [625]:
def encrypt(message, e, n):
    # Number of bytes needed to hold any value smaller than n
    length_n = (n.bit_length() + 7) // 8
    # A bytes value can be converted to an int with int.from_bytes()
    message = int.from_bytes(message, "big", signed=False)
    # Textbook RSA: the message, as an integer, must be smaller than n
    ciphertext = pow(message, e, n)
    # Convert the ciphertext from a number back to bytes
    ciphertext = ciphertext.to_bytes(length_n, "big", signed=False)
    return ciphertext

### Decryption
The decrypt function takes the ciphertext and the private key as input, and returns the original plaintext. The decryption is done by raising the ciphertext to the power of the private exponent (d) and then taking the remainder when divided by n.

## Question:
Write a function to do the decryption:

In [626]:
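def decrypt(ciphertext,d,n):
    ciphertext = int.from_bytes(ciphertext, "big", signed=False)
    message = pow(ciphertext, d, n)
    # Use only as many bytes as the message needs, so the decoded string
    # does not start with a run of "\x00" padding characters
    length_m = (message.bit_length() + 7) // 8
    message = message.to_bytes(length_m, "big", signed=False)
    message = message.decode()
    return message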
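
Textbook RSA silently returns garbage when the message, read as an integer, is not smaller than n. The sketch below is a variant that checks this explicitly. The names byte_length, encrypt_checked and decrypt_checked, and the toy key built from two known Mersenne primes, are assumptions made for illustration.

```python
def byte_length(x):
    """Smallest number of bytes that can hold the non-negative integer x."""
    return max(1, (x.bit_length() + 7) // 8)

def encrypt_checked(message, e, n):
    m = int.from_bytes(message, "big")
    if m >= n:
        raise ValueError("message is too long for this modulus")
    return pow(m, e, n).to_bytes(byte_length(n), "big")

def decrypt_checked(ciphertext, d, n):
    m = pow(int.from_bytes(ciphertext, "big"), d, n)
    return m.to_bytes(byte_length(m), "big")

# Toy key (illustration only): n has about 92 bits, enough for an 11-byte message
p, q = 2**31 - 1, 2**61 - 1
n = p * q
e = 65537
d = pow(e, -1, (p - 1) * (q - 1))  # Python 3.8+

ciphertext = encrypt_checked(b"Hello", e, n)
print(decrypt_checked(ciphertext, d, n))  # b'Hello'
```

Longer messages would have to be split into blocks or, as is usual in practice, encrypted with a symmetric cipher whose key is protected by RSA.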

## Question:

Based on the functions you implemented above, calculate p, q, n, phi, e, and d for a rather small value of n. Then consider a specific message, perform the encryption and decryption, and provide the output. What will happen for large values of n, and why?

In [631]:
# calculate p, q, n, phi, e, and d for a given size of the modulus n (the expected number of bits of n).
(p,q) = generate_large_prime(4096)
n = p * q
phi = (p-1)*(q-1)
e = generate_public_exponent(phi)
d = calculate_private_exponent(phi, e)

print(p,q,n,phi,e,d)

# The message we want to transmit
# message = "Hello!!!"
message = "The Short Answer: Suppose you have a list of the first billion prime numbers. You pick 2 at random and multiply them together (a very fast process). Then you give me the result and ask me to find which 2 prime numbers you used. In the general case I have no choice but to try lots of guesswork (a very slow process)."
message = message.encode()
print(message)
# bytes_as_bits = ''.join(format(ord(byte), '08b') for byte in "Hello world!!!")
# print("length of message is :",len(bytes_as_bits))

# encrypt the message
ciphertext = encrypt(message, e, n)
print('ciphertext:',ciphertext)

# decrypt the message 
decrypt_message = decrypt(ciphertext,d,n)
print(decrypt_message)

772552372172443617434609215233491153522058141469100741669020508510084742051097363932362534850544561787986777528351905284542177220204293328826541728473760315072444679563340184724547308736467069558374821675403267949109521624321182515966455327032911020997044406200961735026027013807342538863608109223829833469925086147114182973195245813723613823278116221366977323170546271135967255973972690925218063410359296592884556560882864813534867930204785387945655820975495874402454410235269503938741045998561377616828060354124044471650958714711825959194687439782894998709527272663690149596312252532975060133508408411401038279265725412693642920359059499666577438477117 845376521354843231326963307502988406490469092954259877199390054525966764302021103582467674138330270712331363004569922556697390661380771957914967625964304211372145063708531732785173670495503366056932017687348002402172801580579479016543584970103574000667137962251815276687929036393275692504307118348112298903732012747002795100951710515945356385698

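**Answer:** For larger values of n, key generation, encryption and decryption all take more time. Key generation slows down because large primes are rarer, so more candidates must be tested, and each Miller-Rabin test is more expensive. Encryption and decryption slow down because the cost of the modular exponentiation grows quickly with the bit length of n and of the exponent (here both e and d are about as large as phi). Factoring large integers is what an attacker would have to do; its difficulty is the reason large values of n are secure, not the reason the legitimate operations are slower. In addition, the message, read as an integer, must stay smaller than n, so a larger n allows a longer message to be encrypted in one block.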
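
The growth in running time can be observed directly by timing a modular exponentiation for several operand sizes. The sketch below uses random odd numbers as stand-ins for n and the exponent (an assumption; they are not real RSA keys), and the helper name time_modexp is hypothetical. Absolute timings depend on the machine, so no expected output is shown.

```python
import random
import time

def time_modexp(nbits, repeats=5):
    """Average seconds for one pow(base, exponent, modulus) with nbits-sized operands."""
    modulus = random.getrandbits(nbits) | (1 << (nbits - 1)) | 1
    exponent = random.getrandbits(nbits) | 1
    base = random.getrandbits(nbits - 1)
    start = time.perf_counter()
    for _ in range(repeats):
        pow(base, exponent, modulus)
    return (time.perf_counter() - start) / repeats

for nbits in (512, 1024, 2048):
    print(nbits, "bits:", time_modexp(nbits), "seconds per exponentiation")
```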

# Denial of Service (DoS) attack

Suppose you have a Python web server that does the following:
- Receives a file over HTTP 
- Processes the file
- Sends back a response via HTTP

For each HTTP request, the server uses a dedicated thread to allow multiple users to access the server simultaneously. In addition, in order to reduce the size of data transmitted through the network, this server utilizes the latest compression technology  (that's a joke). 

Files contain substitution rules along with payloads. The rules are listed at the beginning of every file: each line that begins with "*" is treated as a rule of the form key = value. As an example, consider the following document:

"""

$*$ DoS = Denial of Service 

$*$ nt = network

$*$ ac = attack

A DoS ac is an ac meant to shut down a machine or nt, making it inaccessible to its intended users.

"""

This text can be decompressed to:

"""

A Denial of Service attack is an attack meant to shut down a machine or network, making it inaccessible to its intended users.

"""

Following is the decompression function responsible for decoding the original text from the compressed text: 


## Decompression

In [416]:
def decompress(compressed_text):
    # Initialize an empty dictionary to store the substitution rules
    substitution_rules = {}
    # Initialize an empty string to store the final result
    decompressed_text = ""
    # Iterate over each line of the input text
    for line in compressed_text.split("\n"):
        # Check if the line starts with "*"
        if line.startswith("*"):
            # If so, split the line into key and value, remove leading/trailing whitespaces, and store the rule in the dictionary
            key, value = map(str.strip, line[1:].split("="))
            substitution_rules[key] = value
            # Continue to the next line
            continue
        # If the line does not start with "*", add the line to the final result
        decompressed_text += line + "\n"
    # Keep looping until no more rules are applied
    while True:
        # Initialize a flag to check if a rule is applied
        rule_applied = False
        # Iterate over each key in the dictionary
        for key in substitution_rules:
            # Check if the key is present in the final result
            if key in decompressed_text:
                # If so, replace the key with the corresponding value and set the flag to True
                decompressed_text = decompressed_text.replace(key, substitution_rules[key])
                rule_applied = True
        # If no rule is applied, exit the loop
        if not rule_applied:
            break
    # Return the final result
    return decompressed_text

### Example for decompress function:

In [417]:
if __name__ == '__main__':
    res = decompress("""
* ds = Denial of service attack
* p_ = Denial of service attacks
* A = attacks
* U=users
* S=server
* N=network
* sm=system

a ds is an attack meant to shut down a machine or N, making it inaccessible to its intended U.

This can be accomplished by overwhelming the targeted sm with excessive traffic or requests,
causing it to crash or become unavailable to U.

There are several different types of p_, including network-level A,
application-level A, and distributed p_ . 
    """)
    
    print(res)



a Denial of service attack is an attack meant to shut down a machine or network, making it inaccessible to its intended users.

This can be accomplished by overwhelming the targeted system with excessive traffic or requests,
causing it to crash or become unavailable to users.

There are several different types of Denial of service attacks, including network-level attacks,
application-level attacks, and distributed Denial of service attacks . 
    



## Vulnerability in this system

### Issue:

An issue with the implemented decompress function is that it has no protection against infinite looping. An attacker could craft a file with a rule that replaces a string with itself, for example:

key = key

Here, the rule says that "key" should be replaced with "key". Whenever "key" occurs in the text, the replacement leaves the text unchanged but still sets rule_applied to True, so the while loop never terminates and the thread handling the request consumes processing power indefinitely. An attacker could craft a file with the following text to trigger this:

In [418]:
if __name__ == '__main__':
    res_ = decompress("""
* ds = Denial of service attack
* p_ = Denial of service attacks
* A = A
* U=users
* S=server
* N=network
* sm=system

a ds is an attack meant to shut down a machine or N, making it inaccessible to its intended U.

This can be accomplished by overwhelming the targeted sm with excessive traffic or requests,
causing it to crash or become unavailable to U.

There are several different types of p_, including network-level A,
application-level A, and distributed p_ . 
    """)
    
    print(res_)

Traceback (most recent call last):
  File "/Users/chenbingcheng/Library/Python/3.11/lib/python/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/06/gmp4cfpd1kn490jzvs864zh00000gn/T/ipykernel_11717/1874739930.py", line 2, in <module>
    res_ = decompress("""
           ^^^^^^^^^^^^^^
  File "/var/folders/06/gmp4cfpd1kn490jzvs864zh00000gn/T/ipykernel_11717/362531113.py", line -1, in decompress
KeyboardInterrupt


The substitution rule "A = A" replaces the string "A" with itself, so every pass reports that a rule was applied while the text never changes. The function keeps looping without making any progress until it is interrupted, consuming processing power and tying up the thread that handles the request.

## Question

Find the section of code within the provided decompression function that caused this vulnerability and explain your selection by detailing the rationale behind it.

**Answer:** The vulnerability lies in the following while loop of the decompress function. If substitution_rules contains a rule whose value equals its key (key == substitution_rules[key]) and that key occurs in decompressed_text, then every pass of the for loop sets rule_applied to True even though the text is unchanged. The check "if not rule_applied" therefore never succeeds, and execution cannot escape the while loop.

```python
while True:
        # Initialize a flag to check if a rule is applied
        rule_applied = False
        # Iterate over each key in the dictionary
        for key in substitution_rules:
            # Check if the key is present in the final result
            if key in decompressed_text:
                # If so, replace the key with the corresponding value and set the flag to True
                decompressed_text = decompressed_text.replace(key, substitution_rules[key])
                rule_applied = True
        # If no rule is applied, exit the loop
        if not rule_applied:
            break
```


# Solve the vulnerability:

## Question :

- Find a piece of code to be added to decompression function in order to prevent this vulnerability.

- Test the functionality of your code by testing it on the given example.

In [632]:
def decompress(compressed_text):
    # Initialize an empty dictionary to store the substitution rules
    substitution_rules = {}
    # Initialize an empty string to store the final result
    decompressed_text = ""
    # Iterate over each line of the input text
    for line in compressed_text.split("\n"):
        # Check if the line starts with "*"
        if line.startswith("*"):
            # If so, split the line into key and value, remove leading/trailing whitespaces, and store the rule in the dictionary
            key, value = map(str.strip, line[1:].split("="))
            substitution_rules[key] = value
            # Continue to the next line
            continue
        # If the line does not start with "*", add the line to the final result
        decompressed_text += line + "\n"
    # Apply the rules in a single pass instead of looping until nothing changes,
    # so the function always terminates
    for key in substitution_rules:
        # Skip self-referential rules such as "A = A" and keys that do not occur in the text
        if (key != substitution_rules[key]) and (key in decompressed_text):
            # Replace the key with the corresponding value
            decompressed_text = decompressed_text.replace(
                key, substitution_rules[key])
    # Return the final result
    return decompressed_text

In [633]:
if __name__ == '__main__':
    res_ = decompress("""
* ds = Denial of service attack
* p_ = Denial of service attacks
* A = A
* U=users
* S=server
* N=network
* sm=system

a ds is an attack meant to shut down a machine or N, making it inaccessible to its intended U.

This can be accomplished by overwhelming the targeted sm with excessive traffic or requests,
causing it to crash or become unavailable to U.

There are several different types of p_, including network-level A,
application-level A, and distributed p_ . 
    """)
    
    print(res_)



a Denial of service attack is an attack meant to shut down a machine or network, making it inaccessible to its intended users.

This can be accomplished by overwhelming the targeted system with excessive traffic or requests,
causing it to crash or become unavailable to users.

There are several different types of Denial of service attacks, including network-level A,
application-level A, and distributed Denial of service attacks . 
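
Checking for key == value stops the "A = A" case, and applying the rules only once removes the loop, but a single pass also means that rules whose values contain other keys are no longer expanded. If repeated passes are wanted, the loop needs explicit bounds, because rule pairs such as "A = B" and "B = A" never settle and a rule such as "a = aa" grows the text on every pass. The sketch below is one possible hardening; the function name decompress_bounded and the limits max_passes and max_length are assumptions chosen for illustration.

```python
def decompress_bounded(compressed_text, max_passes=10, max_length=100_000):
    """Apply substitution rules repeatedly, but give up on runaway inputs."""
    rules = {}
    body_lines = []
    for line in compressed_text.split("\n"):
        if line.startswith("*"):
            # Split on the first "=" only, so values may themselves contain "="
            key, separator, value = line[1:].partition("=")
            key, value = key.strip(), value.strip()
            if not separator or not key:
                raise ValueError("malformed rule: " + line)
            rules[key] = value
        else:
            body_lines.append(line)
    text = "".join(line + "\n" for line in body_lines)

    for _ in range(max_passes):
        changed = False
        for key, value in rules.items():
            occurrences = text.count(key)
            if occurrences == 0:
                continue
            # Check the size before building the new string
            projected = len(text) + occurrences * (len(value) - len(key))
            if projected > max_length:
                raise ValueError("decompressed output exceeds max_length")
            new_text = text.replace(key, value)
            if new_text != text:
                changed = True
                text = new_text
        if not changed:
            return text  # a full pass made no progress: done
    raise ValueError("rules did not converge within max_passes")

print(decompress_bounded("* N=network\n* A = A\nan N and A\n"))
```

Rejecting the request with an error, rather than returning partial output, keeps a malicious file from occupying a server thread for more than a bounded amount of work.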
    

