## Note that some of the output are already printed below. You can just read them to save running time.
## Run the first 5 cells for the functions required to run the 3 parts.

In [1]:
import requests
import json
from datetime import datetime
import copy
import time
import base64
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric import dsa
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from Crypto.Util.strxor import strxor

In [2]:
def get_messages(username):
    r = requests.get(server_address + "getMessages/" + username, headers={"Accept" : "application/json"})
    json_r = json.loads(r.text)
    msg_qty = json_r["numMessages"]
    msg_qty_without_receipts = msg_qty
    encrypted_message_list = []
    if msg_qty == 0:
        return [], None
    else:
        for i in range(0, msg_qty):
            encrypted_msg = json_r["messages"][i]["message"]
            sender_id = json_r["messages"][i]["senderID"]
            message_id = json_r["messages"][i]["messageID"]
        return encrypted_message_list, json_r
def send_message(username, recipient_username, message_id, encrypted_message):
    data = {"recipient": recipient_username, "messageID": message_id, "message": encrypted_message}
    r = requests.post(server_address + "sendMessage/" + username, json=data, headers={"Accept" : "application/json"})
    return r
def intercept_message(username):
    encrypted_message_list = []
    json_r = None
    while encrypted_message_list==[]:
        time.sleep(0.05)
        encrypted_message_list, json_r = get_messages(username)
    for msg in encrypted_message_list:
        r = send_message(msg[2], username, msg[1], msg[0])
        #print r.text, "message sent from", msg[2], "to", username, msg[1], msg[0]
        print r.text
    return encrypted_message_list
def read_message(username):
    encrypted_message_list = []
    json_r = None
    while encrypted_message_list==[]:
        time.sleep(0.5)
        encrypted_message_list, json_r = get_messages(username)
    return encrypted_message_list

In [3]:
def rsa_keypair_gen():
    sk = rsa.generate_private_key(public_exponent=65537, key_size=1024, backend=default_backend())
    pk = sk.public_key()
    return pk, sk
def dsa_keypair_gen():
    sk = dsa.generate_private_key(key_size=1024, backend=default_backend())
    pk = sk.public_key()
    return pk, sk
def encode_pk(pk_rsa, pk_dsa):
    #1. DER encoding
    pk_rsa_bytes = pk_rsa.public_bytes(encoding=serialization.Encoding.DER, format=serialization.PublicFormat.SubjectPublicKeyInfo)
    pk_dsa_bytes = pk_dsa.public_bytes(encoding=serialization.Encoding.DER, format=serialization.PublicFormat.SubjectPublicKeyInfo)
    #2. Base64 encoding
    pk_rsa_bytes_encoded = base64.b64encode(pk_rsa_bytes)
    pk_dsa_bytes_encoded = base64.b64encode(pk_dsa_bytes)
    #3. Concatenation in format: Base64_RSA_PubKey||ASCII(0x25)||Base64_DSA_PubKey
    encoded_string = pk_rsa_bytes_encoded + "\x25" + pk_dsa_bytes_encoded
    return encoded_string
def register_key(username, key_data):
    data = {"keyData": key_data}
    r = requests.post(server_address + "registerKey/" + username, json=data, headers={"Accept" : "application/json"})
    return r
def dsa_sign(sk_dsa, msg):
    signer = sk_dsa.signer(hashes.SHA1())
    signer.update(msg)
    signature = signer.finalize()
    return signature

In [4]:
def maul_senderid_and_resign(c, new_sender_id, old_sender_id, new_sender_sk_dsa):
    c_elements = c.split("\x20")
    c1_base64 = c_elements[0]
    c2_base64 = c_elements[1]
    c1 = base64.b64decode(c1_base64)
    c2 = base64.b64decode(c2_base64)
    iv = c2[:16]
    c2_withoutiv = c2[16:]
    
    old_senderid_padded = str(old_sender_id) + str('\x00' * (len(c2_withoutiv)-len(old_sender_id)))
    new_senderid_padded = str(new_sender_id) + str('\x00' * (len(c2_withoutiv)-len(new_sender_id)))
    #print ' '.join(format(ord(x), 'b') for x in old_senderid_padded), ' '.join(format(ord(x), 'b') for x in new_senderid_padded)
    old_xor_new = strxor(old_senderid_padded, new_senderid_padded)
    #print ' '.join(format(ord(x), 'b') for x in old_xor_new)
    new_c2_withoutiv = strxor(str(c2_withoutiv), old_xor_new)
    #print new_c2_withoutiv
    new_c2 = iv + bytes(new_c2_withoutiv)
    new_c2_base64 = base64.b64encode(new_c2)
    
    new_c_without_sig_base64_utf8 = c1_base64.encode('utf-8')+"\x20".encode('utf-8')+new_c2_base64.encode('utf-8')
    #resign
    new_dsa_sig = dsa_sign(new_sender_sk_dsa, new_c_without_sig_base64_utf8)
    new_dsa_sig_base64_utf8 = base64.b64encode(new_dsa_sig).encode('utf-8')
    new_c = new_c_without_sig_base64_utf8 + "\x20".encode("utf-8") + new_dsa_sig_base64_utf8
    
    return new_c
    
def maul_senderid_and_replay(encrypted_message, new_sender_id):
    encrypted_string = encrypted_message[0]
    message_id = encrypted_message[1]
    old_sender_id = encrypted_message[2]
    recipient_id = encrypted_message[3]
    special_sender_id = new_sender_id
    
    if ":" in new_sender_id:
        special_sender_id = new_sender_id[:new_sender_id.index(":")]
    
    #generate new keys for new_sender_id
    pk_rsa, sk_rsa = rsa_keypair_gen()
    pk_dsa, sk_dsa = dsa_keypair_gen()
    encoded_str = encode_pk(pk_rsa, pk_dsa)
    #register the keys
    r = register_key(special_sender_id, encoded_str)
    if r.text == '{"result": true}':
        print special_sender_id, "keys registered successfully."
    
    new_encrypted_string = maul_senderid_and_resign(encrypted_string, new_sender_id, old_sender_id, sk_dsa)
    #new_sender_id has to be same length as old_sender_id
    new_encrypted_message = [new_encrypted_string, message_id, special_sender_id, recipient_id]
    r = send_message(new_encrypted_message[2], new_encrypted_message[3], new_encrypted_message[1], new_encrypted_message[0])
    return r.text, new_encrypted_message, sk_dsa, new_sender_id

In [5]:
def oracle_output(msg):
    x, r = get_messages(msg[2])
    r = send_message(msg[2], msg[3], msg[1], msg[0])
    time.sleep(8)
    enc_msg, json_r2 = get_messages(msg[2])
    if enc_msg != [] and enc_msg[0][2] == msg[3]:
        return True
    else:
        return False
def get_padding_length(msg, sk_dsa, mauled_senderid):
    new_msg = copy.deepcopy(msg)
    c = new_msg[0]
    sender = new_msg[2]
    sender_len = len(mauled_senderid)
    c_elements = c.split("\x20")
    c1_base64 = c_elements[0]
    c2_base64 = c_elements[1]
    c1 = base64.b64decode(c1_base64)
    c2 = base64.b64decode(c2_base64)
    iv = c2[:16]
    c2_withoutiv = c2[16:]
    
    padding_value = None
    str_to_xor = ""
    max_pad = 17
    if len(c2_withoutiv) < 16+sender_len+1:
        max_pad = 12
    for i in xrange(1, max_pad):
        str_to_xor = "\x00"*(len(c2_withoutiv)-i-1) + "\x01" + "\x00"*i
        print str_to_xor.encode('hex')
        
        new_c2_withoutiv = strxor(str(c2_withoutiv), str_to_xor)
        print new_c2_withoutiv.encode('hex')
        
        new_c2 = iv + bytes(new_c2_withoutiv)
        new_c2_base64 = base64.b64encode(new_c2)
        new_c_without_sig_base64_utf8 = c1_base64.encode('utf-8')+"\x20".encode('utf-8')+new_c2_base64.encode('utf-8')
        
        #resign and repack
        new_dsa_sig = dsa_sign(sk_dsa, new_c_without_sig_base64_utf8)
        new_dsa_sig_base64_utf8 = base64.b64encode(new_dsa_sig).encode('utf-8')
        new_c = new_c_without_sig_base64_utf8 + "\x20".encode("utf-8") + new_dsa_sig_base64_utf8
        new_msg[0] = new_c
        
        if oracle_output(new_msg) == True:
            padding_value = i
            break
    return padding_value
def padding_oracle_decrypter(msg, padding_value, sk_dsa, mauled_senderid):
    new_msg = copy.deepcopy(msg)
    c = new_msg[0]
    sender = new_msg[2]
    sender_len = len(mauled_senderid)
    c_elements = c.split("\x20")
    c1_base64 = c_elements[0]
    c2_base64 = c_elements[1]
    c1 = base64.b64decode(c1_base64)
    c2 = base64.b64decode(c2_base64)
    iv = c2[:16]
    c2_withoutiv = c2[16:]
    #print c2_withoutiv.encode('hex')
    
    str_to_xor = ""
    recovered_msg = ""
    mauled_c2 = str(c2_withoutiv)
    msg_len = len(c2_withoutiv)-sender_len-1-padding_value
    x = 0
    for y in xrange(1, msg_len+1):
        x += 1
        print "currently guessing", y, "of", msg_len, "bytes (message+CRC)"
        #to be used to XOR to neuralize the previous padding
        prev_padding_neutralizer = "\x00"*(len(mauled_c2)-(padding_value+x-1)) + str(bytearray([padding_value+x-1])*(padding_value+x-1))
        mauled_c2 = strxor(mauled_c2, str(prev_padding_neutralizer))
        
        #if all 16 bytes padding used up, cut off 15 bytes and start with padding_value of 1
        if x+padding_value > 16:
            mauled_c2 = mauled_c2[:-15]
            padding_value = 2
            x = 0
        
        #to be used to XOR to add a new padding
        new_padding_adder = "\x00"*(len(mauled_c2)-(padding_value+x-1)) + str(bytearray([padding_value+x])*(padding_value+x-1))
        
        mauled_c2 = strxor(mauled_c2, str(new_padding_adder))
        #print "new padding added with\t\t", new_padding_adder.encode('hex')
        #print mauled_c2
        
        for num in range(256):
            curr_guess = chr(num)
            curr_guess_xor_curr_pad = strxor(curr_guess, str(bytearray([padding_value+x])))
            #print num, curr_guess.encode('hex'), "xor", padding_value+x, "=", curr_guess_xor_curr_pad.encode('hex')
            str_to_xor = "\x00"*(len(mauled_c2)-(padding_value+x)) + curr_guess_xor_curr_pad + "\x00"*(padding_value+x-1)
            #print "Trying out with\t\t\t", str_to_xor.encode('hex')
            #print len(mauled_c2), len(str_to_xor)
            mauled_c2_test = strxor(mauled_c2, str_to_xor)
            #print "Mauled C2 is now\t\t", mauled_c2.encode('hex')
            
            #concat IV and C1 again
            new_c2 = iv + bytes(mauled_c2_test)
            new_c2_base64 = base64.b64encode(new_c2)
            new_c_without_sig_base64_utf8 = c1_base64.encode('utf-8')+"\x20".encode('utf-8')+new_c2_base64.encode('utf-8')
        
            #resign and repack
            new_dsa_sig = dsa_sign(sk_dsa, new_c_without_sig_base64_utf8)
            new_dsa_sig_base64_utf8 = base64.b64encode(new_dsa_sig).encode('utf-8')
            new_c = new_c_without_sig_base64_utf8 + "\x20".encode("utf-8") + new_dsa_sig_base64_utf8
            new_msg[0] = new_c
            
            if oracle_output(new_msg) == True:
                recovered_msg = str(curr_guess) + recovered_msg
                print "recovered", curr_guess.encode('hex')
                mauled_c2 = mauled_c2_test
                break
            elif num == 255:
                print "error: something went wrong.", recovered_msg.encode('hex')
    return recovered_msg

## Run this cell below this so that the program knows what server address to use.

In [6]:
server_address = "http://127.0.0.1/"

## Part 1:
With alice/bob clients already running, run this cell below to intercept a message. This will "catch" a message, store it, and send it back to the server without changing it. Print the variable "intercepted" at an empty cell below to see the message intercepted.

In [7]:
intercepted = intercept_message("bob")

{"message": "message delivered", "result": true}


## Part 2:
Now that we have a message intercepted from bob, run this cell below to maul the sender name from "alice" to "a:lice". This means that the client will interpret it as being sent from "a", because the running clients cuts out bytes for the sender id up until the ":" character. This step is important for the padding oracle attack to work, in order to obtain the first 4 bytes of the message.

In [8]:
rtext, mauled_msg, sk_dsa, mauled_senderid = maul_senderid_and_replay(intercepted[0], "a:ice")

a keys registered successfully.


Run this cell below to get the read receipt for the mauled message sent out above, as sender "a". This is also important for clearing the read receipts in the server for "a", because it might mess up the oracle later on.

In [9]:
read_message(mauled_msg[2])

[[u'KQ+vPbTSzTNmx2dHYUX3mtN2KFDVipBVCvh2cGdTADD+8SABnZwHydE4pHXvSPOrnAb49lqUrZwqHU6KAnAaUQvwuuwg62lJM9JObwjYWolsm7h1xeqEyoAD361EggQdmbm4oZePHMAlkMCD6eRxkTv/9hBVPbvV/E397uKdLus= AAAAAAAAAAAAAAAAAAAAAGs0FlQgzS2WaFrFFOn51VFpGqWRzjahXPYSGAx6Vz8G MCwCFD7EpF9VEr6ed0oNVL6Whbvm11VOAhQM5/AJSqB8C4M4c+SHQ5zzhI/Pjg==',
  9036,
  u'bob',
  'a']]

## Part 3:
This is the first part of the padding oracle attack. Run this cell below to find out the length of the padding currently in the encrypted message.

In [10]:
padding_val = get_padding_length(mauled_msg, sk_dsa, mauled_senderid)
print padding_val

000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000100
21b153dac6a9b2d2b0a6052b924a78891e8ccacd34f07bba41e7f64aec810071ca8fbeffb15d6a496aa8927816a42f7b11a8b4c79e961adc5431099407fa4617dfdf993c5a66c6f7893325fe3cc0d9bc9e8a9c07442a9c1e2b05dbfe41e4107a
000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010000
21b153dac6a9b2d2b0a6052b924a78891e8ccacd34f07bba41e7f64aec810071ca8fbeffb15d6a496aa8927816a42f7b11a8b4c79e961adc5431099407fa4617dfdf993c5a66c6f7893325fe3cc0d9bc9e8a9c07442a9c1e2b05dbfe41e5117a
000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000
21b153dac6a9b2d2b0a6052b924a78891e8

## Main part of padding oracle attack:
This is the main part of the padding oracle attack to decrypt the message that was intercepted and mauled earlier, given that we now know the length of the padding in the ciphertext. It will take several hours to recover the entire plaintext since the oracle takes about 5-8s to wait for a possible read receipt for each message. Part of the code's output is already printed below (last 20 bytes, including 4 bytes of CRC), showing the bytes that were recovered. Didn't have time to get the entire output below because it took too long.  
The entire plaintext in ASCII when fully decrypted should be:  
"Hey bob, what do you think about 09 F9 11 02 9D 74 E3 5B D8 41 56 C5 63 56 88 C0?"  
In hexadecimal:  
"48657920626f622c207768617420646f20796f75207468696e6b2061626f75742030392046392031312030322039442037342045332035422044382034312035362043352036332035362038382043303f"  
In hexadecimal with spaces between each character, for easier comparison with the output of my code:  
"48 65 79 20 62 6f 62 2c 20 77 68 61 74 20 64 6f 20 79 6f 75 20 74 68 69 6e 6b 20 61 62 6f 75 74 20 30 39 20 46 39 20 31 31 20 30 32 20 39 44 20 37 34 20 45 33 20 35 42 20 44 38 20 34 31 20 35 36 20 43 35 20 36 33 20 35 36 20 38 38 20 43 30 3f"  


In [None]:
print time.ctime()
padding_oracle_decrypter(mauled_msg, padding_val, sk_dsa, mauled_senderid)
print time.ctime()

Sun Dec 11 20:48:45 2016
currently guessing 1 of 85 bytes (message+CRC)
recovered f1
currently guessing 2 of 85 bytes (message+CRC)
recovered 8c
currently guessing 3 of 85 bytes (message+CRC)
recovered f4
currently guessing 4 of 85 bytes (message+CRC)
recovered c6
currently guessing 5 of 85 bytes (message+CRC)
recovered 3f
currently guessing 6 of 85 bytes (message+CRC)
recovered 30
currently guessing 7 of 85 bytes (message+CRC)
recovered 43
currently guessing 8 of 85 bytes (message+CRC)
recovered 20
currently guessing 9 of 85 bytes (message+CRC)
recovered 38
currently guessing 10 of 85 bytes (message+CRC)
recovered 38
currently guessing 11 of 85 bytes (message+CRC)
recovered 20
currently guessing 12 of 85 bytes (message+CRC)
recovered 36
currently guessing 13 of 85 bytes (message+CRC)
recovered 35
currently guessing 14 of 85 bytes (message+CRC)
recovered 20
currently guessing 15 of 85 bytes (message+CRC)
recovered 33
currently guessing 16 of 85 bytes (message+CRC)
recovered 36
currentl