# Assignment 3 Question 1

### CO 487/687 Applied Cryptography Fall 2024

This Jupyter notebook contains Python 3 code for Assignment 3 Question 1 on "Symmetric Encryption in Python".

### Documentation

- [Python cryptography library](https://cryptography.io/en/latest/)

The following code imports all the required functions for the assignment.

In [2]:
import base64
import getpass
import json
import os
import sys
from cryptography.hazmat.primitives import hashes, hmac
from cryptography.hazmat.primitives.kdf.scrypt import Scrypt
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes    
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives import constant_time
from timeit import default_timer as timer

These two functions convert a byte string into a printable URL-safe base64 string and back. This is helpful because cryptographic routines work with bytes, while values you want to print or store in JSON need to be text.

In [3]:
def bytes2string(b):
    return base64.urlsafe_b64encode(b).decode('utf-8')

def string2bytes(s):
    return base64.urlsafe_b64decode(s.encode('utf-8'))
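A quick round trip shows what these helpers produce. This is a minimal sketch that redefines the two helpers so it runs on its own. Base64 turns every started group of 3 bytes into 4 characters, so a 16-byte salt or IV becomes a 24-character string. The URL-safe alphabet uses `-` and `_` in place of `+` and `/`, which keeps the output safe to embed in JSON or a URL.

```python
import base64
import os
import string

def bytes2string(b):
    return base64.urlsafe_b64encode(b).decode('utf-8')

def string2bytes(s):
    return base64.urlsafe_b64decode(s.encode('utf-8'))

# characters the URL-safe base64 encoder can emit (plus '=' padding)
URLSAFE_ALPHABET = set(string.ascii_letters + string.digits + '-_=')

salt = os.urandom(16)          # same size as the salt and IV in encrypt
encoded = bytes2string(salt)

print(len(encoded))            # 24, i.e. 4 * ceil(16 / 3)
assert set(encoded) <= URLSAFE_ALPHABET
assert string2bytes(encoded) == salt   # decoding recovers the exact bytes
```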

Implement the main encryption function below. Your function will take as input a string, and will output a string or dictionary containing all the values needed to decrypt (other than the password, of course). The code below will prompt the user to enter their password during encryption.

In [15]:
def encrypt(message):
    
    # encode the string as a byte string, since cryptographic functions usually work on bytes
    plaintext = message.encode('utf-8')

    # Use getpass to prompt the user for a password
    password = getpass.getpass("Enter password:")
    password2 = getpass.getpass("Enter password again:")

    # Do a quick check to make sure that the password is the same!
    if password != password2:
        sys.stderr.write("Passwords did not match")
        sys.exit()

    ### START: This is what you have to change

    salt = os.urandom(16)
    iv = os.urandom(16)

    kdf = Scrypt(salt=salt, length=64, n=2**14, r=8, p=1)
    key = kdf.derive(password.encode('utf-8'))
    
    aes_key = key[:32]
    hmac_key = key[32:]
    
    cipher = Cipher(algorithms.AES256(aes_key), modes.CFB(iv))
    enc = cipher.encryptor()
    ciphertext = enc.update(plaintext) + enc.finalize()
    
    h = hmac.HMAC(hmac_key, hashes.SHA3_512())
    h.update(iv + ciphertext)
    mac = h.finalize()
    
    return json.dumps({
        'salt': bytes2string(salt),
        'iv': bytes2string(iv),
        'ciphertext': bytes2string(ciphertext),
        'mac': bytes2string(mac)
    })

    ### END: This is what you have to change
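The key handling in `encrypt` derives one 64-byte scrypt output and splits it into a 32-byte AES-256 key and a 32-byte HMAC key. Below is a minimal sketch of just that step. It swaps in the standard library's `hashlib.scrypt` for `cryptography`'s `Scrypt`. Both implement the same RFC 7914 scrypt function, though `hashlib.scrypt` is only available when Python is built against OpenSSL 1.1 or later. `derive_keys` is a hypothetical helper name. The sketch shows that the same password and salt always yield the same keys, which is why the salt has to be stored next to the ciphertext.

```python
import hashlib
import os

def derive_keys(password, salt):
    # hypothetical helper: same parameters as encrypt (N=2**14, r=8, p=1, 64 bytes)
    key = hashlib.scrypt(password.encode('utf-8'), salt=salt,
                         n=2**14, r=8, p=1, dklen=64)
    # first half keys AES-256, second half keys HMAC-SHA3-512
    return key[:32], key[32:]

salt = os.urandom(16)
aes_key, hmac_key = derive_keys("correct horse", salt)
print(len(aes_key), len(hmac_key))     # 32 32

# re-deriving with the stored salt reproduces the keys (this is what decrypt relies on)
assert derive_keys("correct horse", salt) == (aes_key, hmac_key)
# a fresh salt gives unrelated keys, so the same password yields new keys per message
assert derive_keys("correct horse", os.urandom(16)) != (aes_key, hmac_key)
```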

Now we call the `encrypt` function with a message, and print out the ciphertext it generates.

In [17]:
mymessage = "Hello, world!"
ciphertext = encrypt(mymessage)
print(ciphertext)

{"salt": "POQwC2QvcfGttMm8pDdAPQ==", "iv": "JS4OWIB_O_djIMihv7oLPA==", "ciphertext": "xFAjLtLwJiXMAfNOlQ==", "mac": "JdlDLtoXeTrpbuP5b-XfNmx85dj6CYqTc0nj6_-90amxNh51TmuJ8ySSZI_47PaTyheauhXOAJ2osjyZMMv-og=="}
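Every field in this JSON record has a fixed size except the ciphertext. The salt and IV are 16 bytes each (24 base64 characters), and the SHA3-512 tag is 64 bytes (88 characters). Because CFB is a stream mode, the ciphertext is exactly as long as the plaintext, with no padding. The minimal sketch below predicts the serialized length from those sizes. It uses random bytes as stand-ins, since only their lengths matter, and `serialized_length` is a hypothetical helper.

```python
import base64
import json
import os

def bytes2string(b):
    return base64.urlsafe_b64encode(b).decode('utf-8')

def serialized_length(msg_len, salt_len=16, iv_len=16, mac_len=64):
    # hypothetical helper: random stand-ins, only the byte lengths matter
    # CFB is a stream mode, so the ciphertext is as long as the plaintext
    record = json.dumps({
        'salt': bytes2string(os.urandom(salt_len)),
        'iv': bytes2string(os.urandom(iv_len)),
        'ciphertext': bytes2string(os.urandom(msg_len)),
        'mac': bytes2string(os.urandom(mac_len)),  # SHA3-512 tag is 64 bytes
    })
    return len(record)

print(serialized_length(13))        # 207
print(serialized_length(13) - 13)   # 194 characters of overhead
```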


Implement the main decryption function below. Your function will take as input the string or dictionary output by `encrypt`, prompt the user to enter the password, and then do all the relevant cryptographic operations.

In [18]:
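from cryptography.exceptions import InvalidSignature

def decrypt(ciphertext):
    
    # prompt the user for the password
    password = getpass.getpass("Enter the password:")

    ### START: This is what you have to change

    # parse the JSON record produced by encrypt; malformed JSON, missing fields,
    # or invalid base64 surface as KeyError/TypeError/AttributeError/ValueError
    try:
        data = json.loads(ciphertext)
        salt = string2bytes(data['salt'])
        iv = string2bytes(data['iv'])
        ciphertext = string2bytes(data['ciphertext'])
        mac = string2bytes(data['mac'])
    except (KeyError, TypeError, AttributeError, ValueError):
        raise ValueError("Invalid ciphertext") from None

    # re-derive the same 64-byte key from the password and the stored salt
    kdf = Scrypt(salt=salt, length=64, n=2**14, r=8, p=1)
    key = kdf.derive(password.encode('utf-8'))
    aes_key = key[:32]
    hmac_key = key[32:]
    
    # verify the MAC over iv || ciphertext before decrypting (encrypt-then-MAC)
    h = hmac.HMAC(hmac_key, hashes.SHA3_512())
    h.update(iv + ciphertext)
    try:
        h.verify(mac)
    except InvalidSignature:
        raise ValueError("Invalid MAC") from None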
    
    cipher = Cipher(algorithms.AES256(aes_key), modes.CFB(iv))
    dec = cipher.decryptor()
    plaintext = dec.update(ciphertext) + dec.finalize()

    ### END: This is what you have to change

    # decode the byte string back to a string
    return plaintext.decode('utf-8')
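`decrypt` checks the MAC before it runs AES, which is the encrypt-then-MAC pattern. This matters because CFB on its own is malleable. Flipping a ciphertext bit flips the same bit of the recovered plaintext (and garbles the following block), and decryption still "succeeds". The sketch below shows the tag check in isolation, using the standard library's `hmac` and `hashlib` modules in place of `cryptography`'s `HMAC` class. `compute_tag` and `verify_tag` are hypothetical helpers, and random bytes stand in for the AES-CFB output.

```python
import hashlib
import hmac
import os

def compute_tag(hmac_key, iv, ciphertext):
    # HMAC-SHA3-512 over iv || ciphertext, matching the notebook's construction
    return hmac.new(hmac_key, iv + ciphertext, hashlib.sha3_512).digest()

def verify_tag(hmac_key, iv, ciphertext, tag):
    # compare_digest is designed not to leak, through timing, where the tags differ
    return hmac.compare_digest(compute_tag(hmac_key, iv, ciphertext), tag)

hmac_key = os.urandom(32)
iv = os.urandom(16)
ciphertext = os.urandom(13)            # stand-in for the AES-CFB output
tag = compute_tag(hmac_key, iv, ciphertext)

# flip the lowest bit of the first ciphertext byte, as an attacker might
tampered = bytes([ciphertext[0] ^ 0x01]) + ciphertext[1:]

print(verify_tag(hmac_key, iv, ciphertext, tag))   # True
print(verify_tag(hmac_key, iv, tampered, tag))     # False
```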

Now let's try decrypting the ciphertext you encrypted above by entering the same password as you used for encryption.

In [19]:
mymessagedecrypted = decrypt(ciphertext)
print(mymessagedecrypted)
assert mymessagedecrypted == mymessage

Hello, world!


Try again, but this time enter a different password when decrypting. Your function should fail, raising an error instead of returning a plaintext.

In [20]:
mymessagedecrypted = decrypt(ciphertext)
print(mymessagedecrypted)
assert mymessagedecrypted == mymessage

ValueError: Invalid MAC
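This failure is the intended behaviour. A different password produces a different scrypt output and therefore a different HMAC key, so the stored tag no longer verifies and `decrypt` raises before decrypting anything. Without the MAC, AES-CFB would simply return unrelated bytes, which may or may not happen to decode as UTF-8. The minimal sketch below traces that chain. It again uses `hashlib.scrypt` and `hmac` from the standard library in place of `cryptography`, with hypothetical helpers `hmac_key_from` and `tag_matches`.

```python
import hashlib
import hmac
import os

def hmac_key_from(password, salt):
    # hypothetical helper: second half of the 64-byte scrypt output, as in encrypt
    key = hashlib.scrypt(password.encode('utf-8'), salt=salt,
                         n=2**14, r=8, p=1, dklen=64)
    return key[32:]

salt, iv = os.urandom(16), os.urandom(16)
ciphertext = os.urandom(13)            # stand-in for the AES-CFB output
tag = hmac.new(hmac_key_from("right password", salt),
               iv + ciphertext, hashlib.sha3_512).digest()

def tag_matches(password):
    # recompute the tag with a key derived from the entered password
    candidate = hmac.new(hmac_key_from(password, salt),
                         iv + ciphertext, hashlib.sha3_512).digest()
    return hmac.compare_digest(candidate, tag)

print(tag_matches("right password"))   # True
print(tag_matches("wrong password"))   # False
```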

If you would like to compare the size of the encrypted output with the size of the original message, you can use the following snippet of code:

In [22]:
msg = "Hello, world!"

msg_len = len(msg.encode('utf-8'))
encrypted_len = len(encrypt(msg).encode('utf-8'))
print(f'{msg_len=}, {encrypted_len=}, {encrypted_len-msg_len=}')

msg_len=13, encrypted_len=207, encrypted_len-msg_len=194
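To measure runtime instead, the `timer` imported at the top of the notebook (`timeit.default_timer`) can wrap any operation. Timing `encrypt` or `decrypt` directly would mostly measure how long it takes to type the password. This minimal sketch therefore times only the key derivation, which dominates the cost of both functions for short messages. It uses `hashlib.scrypt` as a stand-in for `cryptography`'s `Scrypt` (same algorithm), and `time_scrypt` is a hypothetical helper. Actual timings depend on the machine.

```python
import hashlib
import os
from timeit import default_timer as timer

def time_scrypt(n, r=8, p=1, repeats=3):
    # hypothetical helper: best-of-`repeats` wall-clock time for one derivation
    salt = os.urandom(16)
    best = float('inf')
    for _ in range(repeats):
        start = timer()
        hashlib.scrypt(b"example password", salt=salt, n=n, r=r, p=p, dklen=64)
        best = min(best, timer() - start)
    return best

# scrypt memory grows as roughly 128 * r * n bytes, and time roughly linearly in n;
# larger n may exceed hashlib's default memory limit unless `maxmem` is raised
for log_n in (12, 13, 14):
    print(f"n=2**{log_n}: {time_scrypt(2**log_n) * 1000:.1f} ms")
```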
