# Set 1

### Challenge 1

In [1]:
from base64 import b64encode

In [2]:
hex_string = "49276d206b696c6c696e6720796f757220627261696e206c696b65206120706f69736f6e6f7573206d757368726f6f6d"
target_string = "SSdtIGtpbGxpbmcgeW91ciBicmFpbiBsaWtlIGEgcG9pc29ub3VzIG11c2hyb29t"

In [3]:
byte_array = bytes.fromhex(hex_string)

In [4]:
base64_byte_array = b64encode(byte_array)

In [5]:
assert base64_byte_array.decode("utf-8") == target_string

### Challenge 2

In [6]:
from base64 import b64encode

In [7]:
hex_string = "1c0111001f010100061a024b53535009181c"
key_string = "686974207468652062756c6c277320657965"
target_string = "746865206b696420646f6e277420706c6179"

In [8]:
byte_string = bytes.fromhex(hex_string)
key_byte_string = bytes.fromhex(key_string)

In [9]:
def xor_bytes(enc1, enc2):
    cipher = b"".join([bytes(b1^b2 for b1, b2 in zip(enc1, enc2))])
    return cipher

In [10]:
assert xor_bytes(byte_string, key_byte_string).hex() == target_string

### Challenge 3

In [11]:
from itertools import zip_longest

In [12]:
hex_string = "1b37373331363f78151b7f2b783431333d78397828372d363c78373e783a393b3736"

In [13]:
byte_string = bytes.fromhex(hex_string)

In [14]:
def calculate_score(text):
    
    # Block for frequency analysis
    frequency_chart = {
        'E': 12.70, 'T': 9.06, 'A': 8.17, 'O': 7.51, 'I': 6.97, 'N': 6.75, 'S': 6.33, 'H': 6.09, 
        'R': 5.99, 'D': 4.25, 'L': 4.03, 'C': 2.78, 'U': 2.76, 'M': 2.41, 'W': 2.36, 'F': 2.23,
        'G': 2.02, 'Y': 1.97, 'P': 1.93, 'B': 1.29, 'V': 0.98, 'K': 0.77, 'J': 0.15,
        'X': 0.15, 'Q': 0.10, 'Z': 0.07, ' ': 35
    }
    
    score = 0.0
    for letter in text.upper():
        score += frequency_chart.get(letter, 0)
    return score

In [15]:
def single_byte_xor_letters(ciphertext):
    
    ascii_text_chars = list(range(97, 122)) + [32]
    best_candidate = None
    
    for i in range(2**8): # for every possible key
        # converting the key from a number to a byte
        candidate_key = i.to_bytes(1, byteorder='little')
        keystream = candidate_key*len(ciphertext)
        candidate_message = bytes([x^y for (x, y) in zip(ciphertext, keystream)])
        nb_letters = sum([ x in ascii_text_chars for x in candidate_message])
        # if the obtained message has more letters than any other candidate before
        if best_candidate == None or nb_letters > best_candidate['nb_letters']:
            # store the current key and message as our best candidate so far
            best_candidate = {"message": candidate_message, 'nb_letters': nb_letters, 'key': candidate_key}
    
    return best_candidate


In [16]:
def single_byte_xor_score(ciphertext):
    max_score = 0
    key = ''
    plaintext = ""
    
    for testkey in range(256):
        testtext = ""
        for letter in ciphertext:
            testtext += chr(letter ^ testkey)
        
        cur_score = calculate_score(testtext)
        if cur_score > max_score:
            max_score = cur_score
            key = chr(testkey)
            plaintext = testtext
            
    return {'score' : max_score, 'key' : key, 'message' : plaintext}

In [17]:
single_byte_xor_letters(byte_string)['message'].decode("utf-8")

"Cooking MC's like a pound of bacon"

### Challenge 4

In [18]:
inf = open('4.txt', 'r')
hex_data = inf.read()
byte_strings = []

for line in hex_data.split():
    byte_line = bytes.fromhex(line)
    byte_strings.append(byte_line)

In [19]:
plaintext = ""
max_score = 0

for line in byte_strings:
    result = single_byte_xor_score(line)
    cur_score = result['score']
    testtext = result['message']
    if cur_score > max_score:
        max_score = cur_score
        plaintext = testtext

print(plaintext)

Now that the party is jumping



### Challenge 5

In [20]:
plaintext = "Burning 'em, if you ain't quick and nimble\nI go crazy when I hear a cymbal"
key = "ICE"
target_string = "0b3637272a2b2e63622c2e69692a23693a2a3c6324202d623d63343c2a26226324272765272a282b2f20430a652e2c652a3124333a653e2b2027630c692b20283165286326302e27282f"

In [21]:
byte_string = plaintext.encode()
bytekey = key.encode()

In [22]:
def repeated_xor(text, key):
    quotient, remainder = divmod(len(text), len(key))
    return bytes([x ^ y for x, y in zip(text, bytes(key * quotient + key[:remainder]))])

In [23]:
ciphertext = repeated_xor(byte_string, bytekey)

In [24]:
assert target_string == ciphertext.hex()

### Challenge 6

In [25]:
from base64 import b64decode

In [26]:
inf = open('6.txt', 'r')
b64_data = inf.read()
byte_data = b64decode(b64_data)

In [27]:
def hamming_distance(text1, text2):
    distance = 0
    
    dec_list = [b1 ^ b2 for b1, b2 in zip(text1, text2)]
    for decimal in dec_list:
        distance += bin(decimal).count("1")
        
    if len(text1) > len(text2):
        diff = len(text1) - len(text2)
        text = text1
    else:
        diff = len(text2) - len(text1)
        text = text2
        
    for i in range(1, diff+1):
        distance += bin(text[-i]).count("1")
        
    return distance

In [28]:
def break_repeated_xor(ciphertext):
    keysize = 0
    min_distance = 100000
    for key in range(2, 41):
        edit_distance = 0
        blocks = [ciphertext[i*key:(i+1)*key] for i in range(4)]
        for i in range(0, len(blocks)):
            for j in range(0, len(blocks)):
                edit_distance += hamming_distance(blocks[i], blocks[j])
        
        normalized_distance = edit_distance/key
        
        if normalized_distance < min_distance:
            min_distance = normalized_distance
            keysize = key
    
    return keysize

In [29]:
keysize = break_repeated_xor(byte_data)
cipher_blocks = [byte_data[i:i+keysize] for i in range(0, len(byte_data), keysize)]
#To remove the last block with less characters
cipher_blocks.pop()
cipher_block_size = len(cipher_blocks[0])

In [30]:
key = ""
for i in range(0, cipher_block_size):
    single_xor_block = b""
    for block in cipher_blocks:
        single_xor_block += (block[i]).to_bytes(1, "little")
        
    result = single_byte_xor_score(single_xor_block)
    testkey = result['key']
    key += testkey

In [31]:
print("Key:", key)
print("\nDeciphered text:\n", repeated_xor(byte_data, key.encode()).decode("utf-8").strip())

Key: Terminator X: Bring the noise

Deciphered text:
 I'm back and I'm ringin' the bell 
A rockin' on the mike while the fly girls yell 
In ecstasy in the back of me 
Well that's my DJ Deshay cuttin' all them Z's 
Hittin' hard and the girlies goin' crazy 
Vanilla's on the mike, man I'm not lazy. 

I'm lettin' my drug kick in 
It controls my mouth and I begin 
To just let it flow, let my concepts go 
My posse's to the side yellin', Go Vanilla Go! 

Smooth 'cause that's the way I will be 
And if you don't give a damn, then 
Why you starin' at me 
So get off 'cause I control the stage 
There's no dissin' allowed 
I'm in my own phase 
The girlies sa y they love me and that is ok 
And I can dance better than any kid n' play 

Stage 2 -- Yea the one ya' wanna listen to 
It's off my head so let the beat play through 
So I can funk it up and make it sound good 
1-2-3 Yo -- Knock on some wood 
For good luck, I like my rhymes atrocious 
Supercalafragilisticexpialidocious 
I'm an effect and that 

### Challenge 7

In [32]:
from base64 import b64decode
from Crypto.Cipher import AES

In [33]:
inf = open('7.txt', 'r')
b64_data = inf.read()
byte_data = b64decode(b64_data)

key = b"YELLOW SUBMARINE"

In [34]:
def AES_ECB_Decrypt(ciphertext, key):
    cipher = AES.new(key, AES.MODE_ECB)
    return cipher.decrypt(ciphertext)

In [35]:
byte_text = AES_ECB_Decrypt(byte_data, key)
#last 4 rubbish bytes is pkcs7 padding of \x04
print(byte_text.decode("utf-8").strip())

I'm back and I'm ringin' the bell 
A rockin' on the mike while the fly girls yell 
In ecstasy in the back of me 
Well that's my DJ Deshay cuttin' all them Z's 
Hittin' hard and the girlies goin' crazy 
Vanilla's on the mike, man I'm not lazy. 

I'm lettin' my drug kick in 
It controls my mouth and I begin 
To just let it flow, let my concepts go 
My posse's to the side yellin', Go Vanilla Go! 

Smooth 'cause that's the way I will be 
And if you don't give a damn, then 
Why you starin' at me 
So get off 'cause I control the stage 
There's no dissin' allowed 
I'm in my own phase 
The girlies sa y they love me and that is ok 
And I can dance better than any kid n' play 

Stage 2 -- Yea the one ya' wanna listen to 
It's off my head so let the beat play through 
So I can funk it up and make it sound good 
1-2-3 Yo -- Knock on some wood 
For good luck, I like my rhymes atrocious 
Supercalafragilisticexpialidocious 
I'm an effect and that you can bet 
I can take a fly girl and make her wet. 


### Challenge 8

In [36]:
from base64 import b64decode
from Crypto.Cipher import AES

In [37]:
ciphertext_list = [bytes.fromhex(line.strip()) for line in open('8.txt')]

In [38]:
def detectAES_ECB(ciphertext):
    blocks = [ciphertext[i:i+AES.block_size] for i in range(0, len(ciphertext), AES.block_size)]
    return len(blocks)-len(set(blocks))

In [39]:
max_score = 0
text_ECB = ""

for cipher in ciphertext_list:
    score = detectAES_ECB(cipher)
    if score > max_score:
        max_score = score
        text_ECB = cipher
        
print("Number of repitions: {}".format(max_score))
print("ECB ciphered text: {}".format(text_ECB))

Number of repitions: 3
ECB ciphered text: b'\xd8\x80a\x97@\xa8\xa1\x9bx@\xa8\xa3\x1c\x81\n=\x08d\x9a\xf7\r\xc0oO\xd5\xd2\xd6\x9ctL\xd2\x83\xe2\xdd\x05/kd\x1d\xbf\x9d\x11\xb04\x85B\xbbW\x08d\x9a\xf7\r\xc0oO\xd5\xd2\xd6\x9ctL\xd2\x83\x94u\xc9\xdf\xdb\xc1\xd4e\x97\x94\x9d\x9c~\x82\xbfZ\x08d\x9a\xf7\r\xc0oO\xd5\xd2\xd6\x9ctL\xd2\x83\x97\xa9>\xab\x8dj\xec\xd5fH\x91Tx\x9ak\x03\x08d\x9a\xf7\r\xc0oO\xd5\xd2\xd6\x9ctL\xd2\x83\xd4\x03\x18\x0c\x98\xc8\xf6\xdb\x1f*?\x9c@@\xde\xb0\xabQ\xb2\x993\xf2\xc1#\xc5\x83\x86\xb0o\xba\x18j'


# Set 2

### Challenge 9

In [74]:
plaintext = "YELLOW SUBMARINE"
target_bytes = b"YELLOW SUBMARINE\x04\x04\x04\x04"

In [75]:
block_size = 20

In [76]:
def pkcs7_pad(plaintext, block_size):
    if len(plaintext) == block_size:
        return plaintext
    pad = block_size - len(plaintext) % block_size
    plaintext += (pad.to_bytes(1,"big"))*pad
    return plaintext

In [77]:
assert pkcs7_pad(plaintext.encode(), block_size) == target_bytes

### Challenge 10

In [44]:
from base64 import b64decode
from Crypto.Cipher import AES

In [45]:
byte_string = b"".join([b64decode(line.strip()) for line in open("10.txt").readlines()])
key = b"YELLOW SUBMARINE"

In [46]:
def AES_ECB_Decrypt(ciphertext, key):
    cipher = AES.new(key, AES.MODE_ECB)
    return pkcs7_unpad(cipher.decrypt(ciphertext))

In [47]:
def pkcs7_padded(text):
    padding = text[-text[-1]:]
    # Check that all the bytes in the range indicated by the padding are equal to the padding value itself
    return all(padding[b] == len(padding) for b in range(0, len(padding)))

In [48]:
def pkcs7_unpad(paddedtext):
    if pkcs7_padded(paddedtext):
        pad_Length = paddedtext[len(paddedtext)-1]
        return paddedtext[:-pad_Length]
    else:
        return paddedtext

In [49]:
def AES_CBC_Decrypt(ciphertext, IV, key):
    previous = IV
    keysize = len(key)
    plaintext = b""
    cipher = ""
    
    for i in range(0, len(ciphertext), keysize):
        cipher = AES_ECB_Decrypt(ciphertext[i:i+keysize], key)
        xor_list = [chr(b1 ^ b2) for b1, b2 in zip(cipher, previous)]
        plaintext += "".join(xor_list).encode()
        previous = ciphertext[i:i+keysize]
        
    return plaintext

In [50]:
text = pkcs7_unpad(AES_CBC_Decrypt(byte_string, b'\x00'*AES.block_size, key))
print(text.decode("utf-8").strip('\n'))

I'm back and I'm ringin' the bell 
A rockin' on the mike while the fly girls yell 
In ecstasy in the back of me 
Well that's my DJ Deshay cuttin' all them Z's 
Hittin' hard and the girlies goin' crazy 
Vanilla's on the mike, man I'm not lazy. 

I'm lettin' my drug kick in 
It controls my mouth and I begin 
To just let it flow, let my concepts go 
My posse's to the side yellin', Go Vanilla Go! 

Smooth 'cause that's the way I will be 
And if you don't give a damn, then 
Why you starin' at me 
So get off 'cause I control the stage 
There's no dissin' allowed 
I'm in my own phase 
The girlies sa y they love me and that is ok 
And I can dance better than any kid n' play 

Stage 2 -- Yea the one ya' wanna listen to 
It's off my head so let the beat play through 
So I can funk it up and make it sound good 
1-2-3 Yo -- Knock on some wood 
For good luck, I like my rhymes atrocious 
Supercalafragilisticexpialidocious 
I'm an effect and that you can bet 
I can take a fly girl and make her wet. 


### Challenge 11

In [51]:
import random
import os
from Crypto.Cipher import AES

In [52]:
def AES_CBC_Encrypt(plaintext, IV, key):
    previous = IV
    keysize = len(key)
    ciphertext = b""
    xored = b""
    
    for i in range(0, len(plaintext), keysize):
        xor_list = [(b1 ^ b2).to_bytes(1, "little") for b1, b2 in zip(pkcs7_pad(plaintext[i:i+keysize], keysize), previous)]
        xored = b"".join(xor_list)
        cipher = AES_ECB_Encrypt(xored, key)
        ciphertext += cipher
        previous = cipher
        
    return ciphertext

In [53]:
def AES_ECB_encrypt(plaintext, key):
    cipher = AES.new(key, AES.MODE_ECB)
    text = pkcs7_pad(plaintext, len(key))
    return cipher.encrypt(pkcs7_pad(plaintext, len(key)))

In [54]:
key = os.urandom(16)

# Enter a repeating text
text = open('8.txt').read()
    
plaintext = os.urandom(random.randint(5,11))
plaintext += text.encode()
plaintext += os.urandom(random.randint(5,11))

flag = random.randint(0,1)
if flag == 1:
    print("Encrypting using AES ECB Encryption.")
    ciphertext = AES_ECB_Encrypt(plaintext, key)
else:
    print("Encrypting using AES CBC Encryption.")
    ciphertext = AES_CBC_Encrypt(plaintext, os.urandom(AES.block_size), key)
    
if detectAES_ECB(ciphertext):
    print("Ciphertext is AES ECB encrypted.")
else:
    print("Ciphertext is AES CBC encrypted.")

Encrypting using AES CBC Encryption.
Ciphertext is AES CBC encrypted.


### Challenge 12

In [55]:
import random
import os
from base64 import b64decode
from Crypto.Cipher import AES

In [56]:
#pseudo random key, cell to be run only once
random_key = os.urandom(16)

In [57]:
def AES128(text):
    b64_string = "Um9sbGluJyBpbiBteSA1LjAKV2l0aCBteSByYWctdG9wIGRvd24gc28gbXkgaGFpciBjYW4gYmxvdwpUaGUgZ2lybGllcyBvbiBzdGFuZGJ5IHdhdmluZyBqdXN0IHRvIHNheSBoaQpEaWQgeW91IHN0b3A/IE5vLCBJIGp1c3QgZHJvdmUgYnkK"
    secret_string = b64decode(b64_string)
    plaintext = text + secret_string
    cipher = AES_ECB_Encrypt(plaintext, random_key)
    return cipher

In [58]:
def AES_ECB_KeySize(ciphertext):
    text = "A random key long enough to decode the key size used in the encryption"
#robust way checks repetition of blocks. could check increase in cipher length because of padding
    for i in range(1, len(text)):
        plaintext = text[:i] + text[:i]
        cipher = AES128(plaintext.encode())
        if cipher[:i] == cipher[i:2*i]:                          
            print("Key size used for the given ciphertext is {}".format(i))
            return i
        
def AES_ECB_KeySize():
    text = "A random key long enough to decode the key size used in the encryption"
    plaintext = "X"
    cipher = AES128(plaintext.encode())
    initial_length = len(cipher)
    for i in range(2, len(text)):
        plaintext = text[:i]
        cipher = AES128(plaintext.encode())
        if len(cipher) > initial_length:
            return len(cipher) - initial_length

In [59]:
def breakAES_ECB(keysize, func):
    deciphered = b""
    
    # to get added string length since 0 len input is provided, all cipher is of added string
    ciphertext = func(deciphered)
    run = len(ciphertext)
    
    for i in range(1, run+1):
        template = b'A'*(run - i)
        cipher = func(template)
        
        for j in range(256):
            text = template + deciphered + j.to_bytes(1, "little")
            c = func(text)
    # keysize used to refer to the block whose last character is made to be the appended string's 1st char
            if c[run-keysize:run] == cipher[run-keysize:run]:
                deciphered += chr(j).encode()
                break
    
    return pkcs7_unpad(deciphered)

In [60]:
#get keysize to identify block size
#plaintext = b"Is it easier to stay is it easier to go"
#ciphertext = AES128(plaintext)
keysize = AES_ECB_KeySize()

# decipher appended input
deciphered = breakAES_ECB(keysize, AES128)
print("Given base64 encoded string was:\n\n{}".format(deciphered.decode("utf-8").strip('\n')))

Given base64 encoded string was:

Rollin' in my 5.0
With my rag-top down so my hair can blow
The girlies on standby waving just to say hi
Did you stop? No, I just drove by


### Challenge 13

In [61]:
string_set = "foo=bar&baz=qux&zap=zazzle"
dictionary = {key:val for key, val in (element.split('=') for element in string_set.split('&'))}
print(dictionary)

{'foo': 'bar', 'baz': 'qux', 'zap': 'zazzle'}


In [62]:
def parser(user, encode):
    if encode == True:
        parsed_string =  "&".join(key.strip(":")+"="+val for key, val in zip(user.keys(), user.values()))
        return parsed_string.encode()
    else:
        return {key:val for key, val in (element.split('=') for element in user.split('&'))}

In [63]:
def profile_for(val):
    val = val.decode("utf-8")
    user = {"email:": val, "uid:": "10", "role": "user"}
    return parser(user, True)

In [65]:
def oracle(email):
    encoded_profile = AES_ECB_Encrypt(profile_for(email), random_key)
    return encoded_profile

In [66]:
email = b"lol@gmail.com"
decoded_profile = AES_ECB_Decrypt(oracle(email), random_key)
profile = parser(decoded_profile.decode("utf-8"), False)
print(decoded_profile.decode("utf-8"))

email=lol@gmail.com&uid=10&role=user


In [67]:
keysize = 16
email = b"f"*(keysize-len("email=")) + pkcs7_pad(b"admin", keysize)
cipher = oracle(email)
encoded_admin_bytes = cipher[keysize:keysize*2]

num_blocks = int((len("&uid=10") + len("email=") + len("&role="))/keysize) + 1
email = b"f"*(num_blocks*keysize - (len("&uid=10") + len("email=") + len("&role=")-6))+b"@gmail.com"
cipher = oracle(email)
modified_cipher = cipher[:48] + encoded_admin_bytes

cracked_cipher_plaintext = parser(AES_ECB_Decrypt(modified_cipher, random_key).decode("utf-8"), False)
print(cracked_cipher_plaintext)

{'email': 'fffffffffffffffffff@gmail.com', 'uid': '10', 'role': 'admin'}


In [68]:
assert cracked_cipher_plaintext['role'] == 'admin'

### Challenge 14

In [69]:
from base64 import b64decode
import random
import math
import os

In [70]:
#pseudo random key and prefix string, cell to be run only once
random_key = os.urandom(16)
random_string = os.urandom(random.randint(0,255))

In [71]:
def AES128_harder(text):
    b64_string = "Um9sbGluJyBpbiBteSA1LjAKV2l0aCBteSByYWctdG9wIGRvd24gc28gbXkgaGFpciBjYW4gYmxvdwpUaGUgZ2lybGllcyBvbiBzdGFuZGJ5IHdhdmluZyBqdXN0IHRvIHNheSBoaQpEaWQgeW91IHN0b3A/IE5vLCBJIGp1c3QgZHJvdmUgYnkK"
    secret_string = b64decode(b64_string)
    plaintext = random_string + text + secret_string
    cipher = AES_ECB_Encrypt(plaintext, random_key)
    return cipher

In [72]:
def breakAES_ECB_harder(keysize, func):
    
    # padding required to bridge gap between randomstringlength and block
    padding = 0
    random_blocks = 0
    cipher_length = len(func(b''))
    prefix_length = len(os.path.commonprefix([func(b'AAAA'), func(b'')]))
    print("Prefix length: ", prefix_length)
    
    for i in range(int(cipher_length/keysize)):
        if prefix_length < i*keysize:
            random_blocks = i
            break
    print("Random blocks: ", random_blocks)
    
    base_cipher = func(b'')
    for i in range(1, keysize):
        new_cipher = func(b'A'*i)
        new_prefix_length = len(os.path.commonprefix([base_cipher, new_cipher]))
        if new_prefix_length > prefix_length:
            padding = i - 1
            break
        base_cipher = new_cipher
    print("Number of bytes of padding required: ", padding)
    
    # to get added string length since 0 len input is provided, all cipher is of added string
    deciphered = b""
    ciphertext = func(deciphered)
    # because of one block increase due to addition of padding
    run = len(ciphertext) + keysize
    
    # should start after prefix random_blocks because till then it value will be same for original cipher and templated cipehr since same prepended string will be comapred
    for i in range(157, run+1):
        template = b'A'*(run - i + padding)
        cipher = func(template)
        for j in range(256):
            #print(i, j)
            text = template + deciphered + j.to_bytes(1, "little")
            c = func(text)
    # keysize used to refer to the block whose last character is made to be the appended string's 1st char
            if c[run-keysize:run] == cipher[run-keysize:run]:
                deciphered += chr(j).encode()
                break
    return pkcs7_unpad(deciphered)

In [73]:
keysize = 16
byte_text = breakAES_ECB_harder(keysize, AES128_harder)
print("\nDeciphered string:\n")
print(byte_text.decode("utf-8").strip())

Prefix length:  32
Random blocks:  3
Number of bytes of padding required:  11


IndexError: index out of range

### Challenge 15

In [None]:
given_string = "ICE ICE BABY\x04\x04\x04\x04"
target_string = "ICE ICE BABY"

In [None]:
assert target_string.encode() == pkcs7_unpad(given_string.encode())

### Challenge 16

In [None]:
random_key = os.urandom(16)
IV = os.urandom(random.randint(0,255))
keysize = 16
prepend_string = "comment1=cooking%20MCs;userdata="
append_string = ";comment2=%20like%20a%20pound%20of%20bacon"

In [None]:
def encryptor(text, IV, key):
    plaintext =  (prepend_string.encode() + text + append_string.encode()).replace(b';', b'";"').replace(b'=', b'"="')
    ciphertext = AES_CBC_Encrypt(pkcs7_pad(plaintext, len(key)), IV, key)
    return ciphertext

In [None]:
def decryptor(byte_string, IV, key) -> bool:
    decrypted_string = pkcs7_unpad(AES_CBC_Decrypt(byte_string, IV, key))
    if b';admin=true;' in decrypted_string:
        return True
    else:
        return False

In [None]:
# padding required to bridge gap between randomstringlength and block
padding = 0
random_blocks = 0

cipher_length = len(encryptor(b'', IV, random_key))
prefix_length = len(os.path.commonprefix([encryptor(b'AAAA', IV, random_key), encryptor(b'', IV, random_key)]))
print("Prefix length: ", prefix_length)

for i in range(int(cipher_length/keysize)):
    if prefix_length < i*keysize:
        random_blocks = i
        break
print("Random blocks: ", random_blocks)

base_cipher = encryptor(b'', IV, random_key)
for i in range(1, keysize):
    new_cipher = encryptor(b'A'*i, IV, random_key)
    new_prefix_length = len(os.path.commonprefix([base_cipher, new_cipher]))
    if new_prefix_length > prefix_length:
        padding = i - 1
        break
    base_cipher = new_cipher
print("Number of bytes of padding required: ", padding)

input_text = b'A'*padding + b'heytheremama'
string = b";admin=true;"
modified_string = b""
ciphertext = encryptor(input_text, IV, random_key)
for i in range(len(string)):
    modified_string += (ciphertext[i+(random_blocks-1)*keysize]^(input_text[i+padding]^string[i])).to_bytes(1, "little")
    
modified_ciphertext = ciphertext[:(random_blocks-1)*keysize] + modified_string + ciphertext[(random_blocks-1)*keysize + len(modified_string):]

In [None]:
AES_CBC_Decrypt(modified_ciphertext, IV, random_key)

In [None]:
assert decryptor(modified_ciphertext, IV, random_key) == True

# Set 3

### Challenge 17

In [None]:
import random
from base64 import b64decode
from Crypto.Cipher import AES

In [None]:
b64_strings = [
    b'MDAwMDAwTm93IHRoYXQgdGhlIHBhcnR5IGlzIGp1bXBpbmc=',
    b'MDAwMDAxV2l0aCB0aGUgYmFzcyBraWNrZWQgaW4gYW5kIHRoZSBWZWdhJ3MgYXJlIHB1bXBpbic=',
    b'MDAwMDAyUXVpY2sgdG8gdGhlIHBvaW50LCB0byB0aGUgcG9pbnQsIG5vIGZha2luZw==',
    b'MDAwMDAzQ29va2luZyBNQydzIGxpa2UgYSBwb3VuZCBvZiBiYWNvbg==',
    b'MDAwMDA0QnVybmluZyAnZW0sIGlmIHlvdSBhaW4ndCBxdWljayBhbmQgbmltYmxl',
    b'MDAwMDA1SSBnbyBjcmF6eSB3aGVuIEkgaGVhciBhIGN5bWJhbA==',
    b'MDAwMDA2QW5kIGEgaGlnaCBoYXQgd2l0aCBhIHNvdXBlZCB1cCB0ZW1wbw==',
    b'MDAwMDA3SSdtIG9uIGEgcm9sbCwgaXQncyB0aW1lIHRvIGdvIHNvbG8=',
    b'MDAwMDA4b2xsaW4nIGluIG15IGZpdmUgcG9pbnQgb2g=',
    b'MDAwMDA5aXRoIG15IHJhZy10b3AgZG93biBzbyBteSBoYWlyIGNhbiBibG93',
]

random_key = os.urandom(16)
IV = os.urandom(16)
keysize = AES.block_size
block_size =keysize

In [None]:
def encryptor():
    index = random.randint(0, len(b64_strings)-1)
    selected_string = b64_strings[index]
    ciphertext = AES_CBC_Encrypt(selected_string, IV, random_key)
    return selected_string, ciphertext

In [None]:
def pkcs7_padding_validation(byte_string):
    last_byte = byte_string[-1]
    if last_byte > len(byte_string):
        return False
    for i in range(last_byte, 0, -1):
        if byte_string[-i] != last_byte:
            return False
    return True

In [None]:
def decryptor(ciphertext, IV):
    plaintext = AES_CBC_Decrypt(ciphertext, IV, random_key)
    if pkcs7_padding_validation(plaintext):
        return True
    else:
        return False

In [None]:
def modify_block(iv, guessed_byte, padding_len, found_plaintext):
    """Creates a forced block of the ciphertext, ideally to be given as IV to decrypt the following block.
    The forced IV will be used for the attack on the padding oracle CBC encryption.
    """

    # Get the index of the first character of the padding
    index_of_forced_char = len(iv) - padding_len

    # Using the guessed byte given as input, try to force the first character of the
    # padding to be equal to the length of the padding itself
    forced_character = iv[index_of_forced_char] ^ guessed_byte ^ padding_len

    # Form the forced ciphertext by adding to it the forced character...
    output = iv[:index_of_forced_char] + bytes([forced_character])

    # ...and the characters that were forced before (for which we already know the plaintext)
    m = 0
    for k in range(block_size - padding_len + 1, block_size):

        # Force each of the following characters of the IV so that the matching characters in
        # the following block will be decrypted to "padding_len"
        forced_character = iv[k] ^ ord(found_plaintext[m]) ^ padding_len
        output += bytes([forced_character])
        m += 1

    return output

In [None]:
def cbc_padding_attack(ciphertext, IV, func):

    plaintext = ""
    num_blocks = len(ciphertext)//keysize
    ciphertext_blocks = [IV] + [ciphertext[i:i+keysize] for i in range(0, len(ciphertext), keysize)]
    
    for i in range(1, num_blocks+1):
        plain_block = ""
        base_block = ciphertext_blocks[i-1] 
        target_block = ciphertext_blocks[i]
                
        for j in range(1, keysize+1):
            possible_last_bytes = []
            for k in range(256):
                
                mod_block = modify_block(base_block, k, j, plain_block)                
                check = decryptor(target_block, mod_block)
                if check == True:
                    possible_last_bytes += bytes([k])
             
            if len(possible_last_bytes) != 1:
                for byte in possible_last_bytes:
                    for k in range(256):
                        
                        mod_block = modify_block(base_block, k, j+1, chr(byte)+plain_block)                
                        
                        check = decryptor(target_block, mod_block)
                        if check == True:
                            possible_last_bytes = [byte]
                            break
                            
            plain_block = chr(possible_last_bytes[0]) + plain_block
        plaintext += plain_block
    
    return pkcs7_unpad(plaintext.encode())

In [None]:
selected_string, ciphertext = encryptor()
plaintext = cbc_padding_attack(ciphertext, IV, encryptor)
print(selected_string, plaintext)
assert selected_string == plaintext
b64decode(plaintext).decode("utf-8")

### Challenge 18

In [None]:
from base64 import b64decode

In [None]:
b64_string = "L77na/nrFsKvynd6HzOoG7GHTLXsTVu9qvY/2syLXzhPweyyMTJULu/6/kXX0KSvoOLSFQ=="
decoded_string = b64decode(b64_string)
key = "YELLOW SUBMARINE"

nonce = 0
# format=64 bit unsigned little endian nonce,
#        64 bit little endian block count (byte count / 16)

In [None]:
def CTR_Keystream(key, nonce):
    counter = 0
    # 8 byte because format says 64bit
    nonce_bytes = nonce.to_bytes(8, "little")
    
    while True:
        counter_bytes = counter.to_bytes(8, "little")
        # keep getting 16byte block from the encryption function
        keystream_block = AES_ECB_Encrypt(nonce_bytes + counter_bytes, key)
        yield from keystream_block
        counter += 1

In [None]:
def CTR(string, key, nonce):
    keystream = CTR_Keystream(key, nonce)
    if len(string) == 0:
        return b""
    else:
        return xor_bytes(string, keystream)

In [None]:
byte_text = CTR(decoded_string, key.encode(), 0)
print(byte_text.decode("utf-8"))

### Challenge 19 & 20

In [None]:
import os
from base64 import b64decode

In [None]:
random_key = os.urandom(16)
decoded_strings = [b64decode(line.strip()) for line in open('20.txt')]
ciphertext_list = [CTR(string, random_key, nonce) for string in decoded_strings]
min_ciphertext_length = min(map(len, ciphertext_list))

In [None]:
columns = []
for i in range(min_ciphertext_length):
    line = b""
    for cipher in ciphertext_list:
        line += cipher[i].to_bytes(1, "little")
    result = single_byte_xor_letters(line)
    columns.append(result['message'])

message = ""
for i in range(min_ciphertext_length):
    for c in columns:
        message += chr(c[i])
print(message)

# single_byte_xor_score gives a less accurate result

### Challenge 21

In [None]:
import time

In [None]:
def get_lowest_bits(n, number_of_bits):
    """Returns the lowest "number_of_bits" bits of n."""
    mask = (1 << number_of_bits) - 1
    return n & mask

class MT19937:
    """This implementation resembles the one of the Wikipedia pseudo-code."""
    W, N, M, R = 32, 624, 397, 31
    A = 0x9908B0DF
    U, D = 11, 0xFFFFFFFF
    S, B = 7, 0x9D2C5680
    T, C = 15, 0xEFC60000
    L = 18
    F = 1812433253
    LOWER_MASK = (1 << R) - 1
    UPPER_MASK = get_lowest_bits(not LOWER_MASK, W)
    
    def __init__(self, seed):
        self.mt = []

        self.index = self.N
        self.mt.append(seed)
        for i in range(1, self.index):
            self.mt.append(get_lowest_bits(self.F * (self.mt[i - 1] ^ (self.mt[i - 1] >> (self.W - 2))) + i, self.W))
            
    def extract_number(self):
        if self.index >= self.N:
            self.twist()

        y = self.mt[self.index]
        y ^= (y >> self.U) & self.D
        y ^= (y << self.S) & self.B
        y ^= (y << self.T) & self.C
        y ^= (y >> self.L)

        self.index += 1
        return get_lowest_bits(y, self.W)

    def twist(self):
        for i in range(self.N):
            x = (self.mt[i] & self.UPPER_MASK) + (self.mt[(i + 1) % self.N] & self.LOWER_MASK)
            x_a = x >> 1
            if x % 2 != 0:
                x_a ^= self.A

            self.mt[i] = self.mt[(i + self.M) % self.N] ^ x_a

        self.index = 0

In [None]:
# Check if the numbers look random
for i in range(10):
    print(MT19937(i).extract_number())

### Challenge 22

In [None]:
import time

In [None]:
def MT19937_timestamp_seed():
    time.sleep(random.randint(40, 100))
    seed = int(time.time())
    mt_rng = MT19937(seed)
    time.sleep(random.randint(40, 100))
    return mt_rng.extract_number(), seed

In [None]:
def break_MT19937_seed(rng_function):
    random_number, real_seed = rng_function()
    now = int(time.time())
    # Assuming nobody waits more than 220 seconds to get a random number
    before = now - 220
    for seed in range(before, now):
        rng = MT19937(seed)
        number = rng.extract_number()
        if number == random_number:
            return seed

In [None]:
break_MT19937_seed(MT19937_timestamp_seed)

### Challenge 23

In [None]:
import time
import random

In [None]:
def int_to_bit_list(x):
    return [int(b) for b in '{:032b}'.format(x)]

def bit_list_to_int(l):
    return int(''.join(str(x) for x in l), base=2)

def invert_shift_mask_xor(y, direction, shift, mask=0xFFFFFFFF):
    y = int_to_bit_list(y)
    mask = int_to_bit_list(mask)

    if direction == 'left':
        y.reverse()
        mask.reverse()
    else:
        assert direction == 'right'

    x = [None]*32
    for n in range(32):
        if n < shift:
            x[n] = y[n]
        else:
            x[n] = y[n] ^ (mask[n] & x[n-shift])

    if direction == 'left':
        x.reverse()

    return bit_list_to_int(x)

def untemper(y):
    (w, n, m, r) = (32, 624, 397, 31)
    a = 0x9908B0DF
    (u, d) = (11, 0xFFFFFFFF)
    (s, b) = (7, 0x9D2C5680)
    (t, c) = (15, 0xEFC60000)
    l = 18
    f = 1812433253

    xx = y
    xx = invert_shift_mask_xor(xx, direction='right', shift=l)
    xx = invert_shift_mask_xor(xx, direction='left', shift=t, mask=c)
    xx = invert_shift_mask_xor(xx, direction='left', shift=s, mask=b)
    xx = invert_shift_mask_xor(xx, direction='right', shift=u, mask=d)

    return xx

In [None]:
def get_cloned_rng(original_rng):
    """Taps the given rng for 624 outputs, untempers each of them to recreate the state of the generator,
    and splices that state into a new "cloned" instance of the MT19937 generator.
    """
    mt = []

    # Recreate the state mt of original_rng
    for i in range(MT19937.N):
        mt.append(untemper(original_rng.extract_number()))

    # Create a new generator and set it to have the same state
    cloned_rng = MT19937(0)
    cloned_rng.mt = mt

    return cloned_rng

In [None]:
seed = random.randint(0, 2**32 - 1)
rng = MT19937(seed)
cloned_rng = get_cloned_rng(rng)

# Check that the two PRNGs produce the same output now
for i in range(100):
    assert rng.extract_number() == cloned_rng.extract_number()

### Challenge 24

In [None]:
import os
import time
import math
import random

In [None]:
def MT19937_keystream_generator(seed):
    assert math.log2(seed) <= 16
    prng = MT19937(seed)
    while True:
        number = prng.extract_number()
        yield from number.to_bytes(4, "little")

In [None]:
def MT19937_CTR(string, seed):
    assert isinstance(seed, int)
    
    keystream = MT19937_keystream_generator(seed)
    if len(string) == 0:
        return b""
    else:
        return bytes([(b1 ^ b2) for b1, b2 in zip(string, keystream)])

In [None]:
plaintext = "Hello World!"

# append random characters before plainttext
string = b""
for _ in range(random.randint(0, 10)):
    i = random.randint(33, 126)
    string += chr(i).encode()
string += plaintext.encode()

seed = random.randint(1, 2**16)
print("> Seed value coded to be", seed)
cipher_bytes = MT19937_CTR(string, seed)
deciphered_bytes = MT19937_CTR(cipher_bytes, seed)

# verify if it can be decrypted
assert string == deciphered_bytes

#The number of possible keys is super small so you can just try them all. They even insist on it in the instructions: the cipher is using a 16-bits seed. It's kind of weird actually because from the specifications of MT19937 the seed seems to be 32 bits. Well even 32 bits should be small enough to crack, it would just take longer.
for seed in range(1, 2**16):
    deciphered_bytes = MT19937_CTR(cipher_bytes, seed)
    try:
        assert string == deciphered_bytes
        print("> Brute force successful.\nSeed:", seed)
        break
    except AssertionError:
        continue

# Set 4

### Challenge 25

In [None]:
import os
import itertools

In [None]:
def edit(ciphertext, key, offset, newtext):
    keystream = b""
    stream = CTR_Keystream(key, nonce)
    for i in itertools.islice(stream, offset, offset+len(newtext)):
        keystream += i.to_bytes(1, "little")
    append_cipher = xor_bytes(newtext, keystream)
    result = ciphertext[:offset] + append_cipher
    if len(result) < len(ciphertext):
        return result + ciphertext[len(result):]
    return result

In [None]:
random_key = os.urandom(16)
nonce = 0

In [None]:
# Testing the edit function
plaintext = b"hello there"
cipher = CTR(plaintext, random_key, nonce)
print("Original text:", CTR(cipher, random_key, nonce).decode("utf-8"))
edited_cipher = edit(cipher, random_key, 4, b"####")
print("Edited text:", CTR(edited_cipher, random_key, nonce).decode("utf-8"))

In [None]:
# if you give text as \x00 it gives out keystream, 
# xors keystream with 0 and thus can decode keystream 
# by using offset as 0
with open('25.txt') as f:
    recovered_bytes = b64decode(f.read())
    
ciphertext = CTR(recovered_bytes, random_key, nonce)
recovered_keystream = edit(ciphertext, random_key, 0, b'\x00'*len(ciphertext))
deciphered_bytes = xor_bytes(ciphertext, recovered_keystream)

In [None]:
assert deciphered_bytes == recovered_bytes

### Challenge 26

In [None]:
import os

In [None]:
random_key = os.urandom(16)
nonce = 0
prepend_string = "comment1=cooking%20MCs;userdata="
append_string = ";comment2=%20like%20a%20pound%20of%20bacon"

In [None]:
def encryptor(text, key, nonce):
    plaintext =  (prepend_string.encode() + text + append_string.encode()).replace(b';', b'";"').replace(b'=', b'"="')
    ciphertext = CTR(plaintext, key, nonce)
    return ciphertext

In [None]:
def admin_parser(byte_string, random_key, nonce) -> bool:
    decrypted_string = CTR(byte_string, random_key, nonce)
    if b';admin=true;' in decrypted_string:
        return True
    else:
        return False

In [None]:
target_bytes = b";admin=true;"
modified_string = b""

# we take out prefix length and then combine the recovered
# keystream from that offset onwards with inut text to produce
# the required string
prefix_length = len(os.path.commonprefix([encryptor(b'AAAA', random_key, nonce), encryptor(b'', random_key, nonce)]))
print("Prefix length: ", prefix_length)

dummy_input = b"heytheremama"
ciphertext = encryptor(dummy_input, random_key, nonce)
null_cipher = encryptor(b'\x00'*len(ciphertext), random_key, nonce)
recovered_keystream = null_cipher[prefix_length:len(ciphertext)]

injected_bytes = b""
for i in range(len(target_bytes)):
    injected_bytes += (target_bytes[i] ^ recovered_keystream[i]).to_bytes(1, "little")

modified_ciphertext = ciphertext[:prefix_length] + injected_bytes + ciphertext[prefix_length + len(injected_bytes):]

In [None]:
assert admin_parser(modified_ciphertext, random_key, nonce) == True

### Challenge 27

In [None]:
import os

In [None]:
random_key = os.urandom(16)
IV = random_key
keysize = 16
prepend_string = "comment1=cooking%20MCs;userdata="
append_string = ";comment2=%20like%20a%20pound%20of%20bacon"

In [None]:
def check_ascii_compliance(plaintext):
    """Returns true if all the characters of plaintext are ASCII compliant (ie are in the ASCII table)."""
    return all(c < 128 for c in plaintext)

In [None]:
def encryptor(text, IV, key):
    plaintext = text.replace(b';', b'";"').replace(b'=', b'"="')
    ciphertext = AES_CBC_Encrypt(pkcs7_pad(plaintext, len(key)), IV, key)
    return ciphertext

In [None]:
def decryptor(byte_string, IV, key) -> bool:
    decrypted_string = AES_CBC_Decrypt(byte_string, IV, key)
    print(len(decrypted_string), decrypted_string)
    if not check_ascii_compliance(decrypted_string):
        raise Exception(decrypted_string)

In [None]:
plaintext = b"lorem=ipsum;test=fun;padding=dull"
ciphertext = encryptor(plaintext, IV, random_key)
c1 = ciphertext[:keysize]
c2 = ciphertext[keysize:2*keysize]
c3 = ciphertext[2*keysize:]

try:
    decryptor(c1 + b'\x00'*16 + c1, IV, random_key)
except Exception as e:
    decrypted_string = str(e).encode()
    p1 = decrypted_string[:keysize]
    p3 = decrypted_string[2*keysize:]
    decrypted_key = xor_bytes(p1, p3)
    print("> Key found to be:", decrypted_key)

### Challenge 28

In [None]:
import os
import struct
import hashlib

In [None]:
def left_rotate(value, shift):
    """Returns value left-rotated by shift bits. In other words, performs a circular shift to the left."""
    return ((value << shift) & 0xffffffff) | (value >> (32 - shift))


def sha1(message, ml=None, h0=0x67452301, h1=0xEFCDAB89, h2=0x98BADCFE, h3=0x10325476, h4=0xC3D2E1F0):
    """Returns a string containing the SHA1 hash of the input message. This is a pure python 3 SHA1
    implementation, written starting from the SHA1 pseudo-code on Wikipedia.

    The parameters ml, h0, ..., h5 are for the next challenge.
    """
    # Pre-processing:
    if ml is None:
        ml = len(message) * 8

    message += b'\x80'
    while (len(message) * 8) % 512 != 448:
        message += b'\x00'

    message += struct.pack('>Q', ml)

    # Process the message in successive 512-bit chunks:
    for i in range(0, len(message), 64):

        # Break chunk into sixteen 32-bit big-endian integers w[i]
        w = [0] * 80
        for j in range(16):
            w[j] = struct.unpack('>I', message[i + j * 4:i + j * 4 + 4])[0]

        # Extend the sixteen 32-bit integers into eighty 32-bit integers:
        for j in range(16, 80):
            w[j] = left_rotate(w[j - 3] ^ w[j - 8] ^ w[j - 14] ^ w[j - 16], 1)

        # Initialize hash value for this chunk:
        a = h0
        b = h1
        c = h2
        d = h3
        e = h4

        # Main loop
        for j in range(80):
            if j <= 19:
                f = d ^ (b & (c ^ d))
                k = 0x5A827999
            elif 20 <= j <= 39:
                f = b ^ c ^ d
                k = 0x6ED9EBA1
            elif 40 <= j <= 59:
                f = (b & c) | (d & (b | c))
                k = 0x8F1BBCDC
            else:
                f = b ^ c ^ d
                k = 0xCA62C1D6

            temp = left_rotate(a, 5) + f + e + k + w[j] & 0xffffffff
            e = d
            d = c
            c = left_rotate(b, 30)
            b = a
            a = temp

        # Add this chunk's hash to result so far:
        h0 = (h0 + a) & 0xffffffff
        h1 = (h1 + b) & 0xffffffff
        h2 = (h2 + c) & 0xffffffff
        h3 = (h3 + d) & 0xffffffff
        h4 = (h4 + e) & 0xffffffff

    # Produce the final hash value (big-endian) as a 160 bit number, hex formatted:
    return '%08x%08x%08x%08x%08x' % (h0, h1, h2, h3, h4)

def sha1_mac(key, message):
    return sha1(key + message)

In [None]:
key = os.urandom(16)
message = b'This is a message to test that our implementation of the SHA1 MAC works properly.'

hashed = sha1_mac(key, message)

# Verify that I implemented SHA1 correctly
h = hashlib.sha1(key + message)

In [None]:
assert (hashed == h.hexdigest())

### Challenge 29

In [None]:
import os
import struct

In [None]:
message = "comment1=cooking%20MCs;userdata=foo;comment2=%20like%20a%20pound%20of%20bacon"
key = os.urandom(16)
payload = b";admin=true"

In [None]:
def md_pad(message):
    ml = len(message) * 8
    message += b'\x80'
    while (len(message) * 8) % 512 != 448:
        message += b'\x00'

    message += struct.pack('>Q', ml)
    return message
    
def validate(modified_message, new_md):
    if sha1_mac(key, modified_message) == new_md:
        return True
    return False
    
def sha1_length_extension_attack(message, original_md, payload):
    
    for key_length in range(20):
        h = struct.unpack('>5I', bytes.fromhex(original_md))
        modified_message = md_pad(b'A'*key_length + message)[key_length:] + payload
        new_md = sha1(payload, (len(modified_message) + key_length)*8, h[0], h[1], h[2], h[3], h[4])
        if validate(modified_message, new_md):
            print("> Length extension attack successful.")
            return modified_message, new_md
            break

In [None]:
original_md = sha1_mac(key, message.encode())
modified_message, new_md = sha1_length_extension_attack(message.encode(), original_md, payload)

In [None]:
assert payload in modified_message

### Challenge 30

In [None]:
import struct
from binascii import hexlify

In [None]:
message = "comment1=cooking%20MCs;userdata=foo;comment2=%20like%20a%20pound%20of%20bacon"
key = os.urandom(16)
payload = b";admin=true"

In [None]:
def md_pad(message):
    """Pads the given message the same way the pre-processing of the MD4 algorithm does."""
    ml = len(message) * 8

    message += b'\x80'
    message += bytes((56 - len(message) % 64) % 64)
    message += struct.pack('<Q', ml)

    return message
    
def validate(modified_message, new_md):
    if MD4(modified_message).hex_digest() == new_md:
        return True
    return False

In [None]:
class MD4:
    buf = [0x00] * 64

    _F = lambda self, x, y, z: ((x & y) | (~x & z))
    _G = lambda self, x, y, z: ((x & y) | (x & z) | (y & z))
    _H = lambda self, x, y, z: (x ^ y ^ z)

    def __init__(self, message, ml=None, A=0x67452301, B=0xefcdab89, C=0x98badcfe, D=0x10325476):
        self.A, self.B, self.C, self.D = A, B, C, D

        if ml is None:
            ml = len(message) * 8
        length = struct.pack('<Q', ml)

        while len(message) > 64:
            self._handle(message[:64])
            message = message[64:]

        message += b'\x80'
        message += bytes((56 - len(message) % 64) % 64)
        message += length

        while len(message):
            self._handle(message[:64])
            message = message[64:]

    def _handle(self, chunk):
        X = list(struct.unpack('<' + 'I' * 16, chunk))
        A, B, C, D = self.A, self.B, self.C, self.D

        for i in range(16):
            k = i
            if i % 4 == 0:
                A = left_rotate((A + self._F(B, C, D) + X[k]) & 0xffffffff, 3)
            elif i % 4 == 1:
                D = left_rotate((D + self._F(A, B, C) + X[k]) & 0xffffffff, 7)
            elif i % 4 == 2:
                C = left_rotate((C + self._F(D, A, B) + X[k]) & 0xffffffff, 11)
            elif i % 4 == 3:
                B = left_rotate((B + self._F(C, D, A) + X[k]) & 0xffffffff, 19)

        for i in range(16):
            k = (i // 4) + (i % 4) * 4
            if i % 4 == 0:
                A = left_rotate((A + self._G(B, C, D) + X[k] + 0x5a827999) & 0xffffffff, 3)
            elif i % 4 == 1:
                D = left_rotate((D + self._G(A, B, C) + X[k] + 0x5a827999) & 0xffffffff, 5)
            elif i % 4 == 2:
                C = left_rotate((C + self._G(D, A, B) + X[k] + 0x5a827999) & 0xffffffff, 9)
            elif i % 4 == 3:
                B = left_rotate((B + self._G(C, D, A) + X[k] + 0x5a827999) & 0xffffffff, 13)

        order = [0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15]
        for i in range(16):
            k = order[i]
            if i % 4 == 0:
                A = left_rotate((A + self._H(B, C, D) + X[k] + 0x6ed9eba1) & 0xffffffff, 3)
            elif i % 4 == 1:
                D = left_rotate((D + self._H(A, B, C) + X[k] + 0x6ed9eba1) & 0xffffffff, 9)
            elif i % 4 == 2:
                C = left_rotate((C + self._H(D, A, B) + X[k] + 0x6ed9eba1) & 0xffffffff, 11)
            elif i % 4 == 3:
                B = left_rotate((B + self._H(C, D, A) + X[k] + 0x6ed9eba1) & 0xffffffff, 15)

        self.A = (self.A + A) & 0xffffffff
        self.B = (self.B + B) & 0xffffffff
        self.C = (self.C + C) & 0xffffffff
        self.D = (self.D + D) & 0xffffffff

    def digest(self):
        return struct.pack('<4I', self.A, self.B, self.C, self.D)

    def hex_digest(self):
        return hexlify(self.digest()).decode()
    
def md4_length_extension_attack(message, original_md, payload):
    
    for key_length in range(20):
        h = struct.unpack('<4I', bytes.fromhex(original_md))
        modified_message = md_pad(b'A'*key_length + message)[key_length:] + payload
        new_md = MD4(payload, (len(modified_message) + key_length)*8, h[0], h[1], h[2], h[3]).hex_digest()
        if validate(modified_message, new_md):
            print("> Length extension attack successful.")
            return modified_message, new_md
            break

In [None]:
original_md = MD4(message.encode()).hex_digest()
modified_message, new_md = md4_length_extension_attack(message.encode(), original_md, payload)

In [None]:
assert payload in modified_message

### Challenge 31

In [None]:
class HMAC:
    
    def __init__(self, random_key, hash_func):
        self.hash_func = hash_func
        self.block_size = hash_func().block_size

        if len(random_key) > self.block_size:
            self.key = hash_func(random_key).digest()
        elif len(random_key) < self.block_size:
            self.key = random_key + b'\x00' * (self.block_size-len(random_key))

    def compute(self, message):
        o_key_pad = xor_bytes(self.key, b'\x5c' * self.block_size)
        i_key_pad = xor_bytes(self.key, b'\x36' * self.block_size)
        
        inner_hash = self.hash_func(i_key_pad + message).digest()
        
        return self.hash_func(o_key_pad + inner_hash).hexdigest()

In [None]:
import os
import web
import json
import time
import hashlib

urls = (
    '/hello', 'Hello',
    '/test', 'Hash'
)

app = web.application(urls, globals())

HMAC_obj = HMAC(b"YELLOW_SUBMARINE", hashlib.sha1)

class Hello:        
    
    def GET(self):
        params = web.input()
        name = params.name
        if not name:
            name = 'World'
            
        string = "Hello, " + name + "!"
        return {"name" : string}

class Hash:
    
    def _insecure_compare(self, hash1, hash2, delay):
        for b1, b2 in zip(hash1, hash2):
            if b1 != b2:
                return False
            time.sleep(delay)
        return True
    
    def GET(self):
        global HMAC_obj
        params = web.input()
        file = params.file
        signature = params.signature
        delay = params.delay
        
        hmac = HMAC_obj.compute(file.encode())
        if self._insecure_compare(hmac.encode(), signature.encode(), float(delay)):
            return web.HTTPError(200)
        else:
            return web.HTTPError(500)

In [None]:
response1 = app.request("/hello?name=")
print(response1.data)

response2 = app.request("/hello?name=hexterisk")
print(json.loads(response2.data.decode("utf-8").replace("'",'"')))

file = "foo"
signature = "274b7c4d98605fcf739a0bf9237551623f415fb8"
response = app.request("/test?delay=" + str(delay) + "&file=" + file + "&signature=" + signature)
print(response)

signature = "8c80a95a8e72b3e822a13924553351a433e267d8"
response = app.request("/test?delay=" + str(delay) + "&file=" + file + "&signature=" + signature)
print(response)

In [None]:
file = "foo"
delay = 0.05

In [None]:
signature = ""
for _ in range(hashlib.sha1().digest_size * 2):
    """produces a 160-bit (20-byte) hash value known as a message digest,
    typically rendered as a hexadecimal number, 40 digits long."""
    
    times = []
    for i in range(16):
        start = time.time()
        response = app.request("/test?delay=" + str(delay) + "&file=" + file + "&signature=" + signature + hex(i)[-1])
        finish = time.time()
        times.append(finish - start)
    signature += hex(times.index(max(times)))[-1]
    print("> Discovered signature:", signature)

response = app.request("/test?delay=" + str(delay) + "&file=" + file + "&signature=" + signature + hex(i)[-1])
if response.status == 200:
    print("> Brute force successful.\n> Signature:", signature)
else:
    print("Brute force failed.")

### Challenge 32

In [None]:
HMAC_obj = HMAC(b"YELLOW_SUBMARINE", hashlib.sha1)

file = "foo"
delay = 0.005

In [None]:
signature = ""
for _ in range(hashlib.sha1().digest_size * 2):
    """produces a 160-bit (20-byte) hash value known as a message digest,
    typically rendered as a hexadecimal number, 40 digits long."""
    
    times = []
    for i in range(16):
        runtime = 0
        # introduced more rounds so the difference is prominent
        for _ in range(20):
            start = time.time()
            response = app.request("/test?delay=" + str(delay) + "&file=" + file + "&signature=" + signature + hex(i)[-1])
            finish = time.time()
            runtime += finish - start
        times.append(runtime)
    signature += hex(times.index(max(times)))[-1]
    print("> Discovered signature:", signature)

response = app.request("/test?delay=" + str(delay) + "&file=" + file + "&signature=" + signature + hex(i)[-1])
if response.status == 200:
    print("> Brute force successful.\n> Signature:", signature)
else:
    print("Brute force failed.")

# Set 5

### Challenge 33

In [None]:
import random
import hashlib

In [None]:
p = 37
g = 5

In [None]:
# Alice
a = random.randint(0, 100)
A = (g**a) % p

In [None]:
# Bob
b = random.randint(0, 100)
B = (g**b) % p

In [None]:
session_key_Alice = (B**a) % p
session_key_Bob = (A**b) % p
assert session_key_Alice == session_key_Bob

In [None]:
key = hashlib.sha256(str(session_key_Alice).encode()).hexdigest()

In [None]:
print("> Key:", key)

In [None]:
class DiffieHellman():
    """Implements the Diffie-Helman key exchange. Each class is a party, which has his secret key (usually
    referred to as lowercase a or b) shares the public key (usually referred to as uppercase A or B) and can
    compute the shared secret key between itself and another party, given their public key, assuming that
    they are agreeing on the same p and g.
    """

    DEFAULT_G = 2
    DEFAULT_P = int('ffffffffffffffffc90fdaa22168c234c4c6628b80dc1cd129024e088a67cc74020bbea63b139b225'
                    '14a08798e3404ddef9519b3cd3a431b302b0a6df25f14374fe1356d6d51c245e485b576625e7ec6f4'
                    '4c42e9a637ed6b0bff5cb6f406b7edee386bfb5a899fa5ae9f24117c4b1fe649286651ece45b3dc20'
                    '07cb8a163bf0598da48361c55d39a69163fa8fd24cf5f83655d23dca3ad961c62f356208552bb9ed5'
                    '29077096966d670c354e4abc9804f1746c08ca237327ffffffffffffffff', 16)

    def __init__(self, g=DEFAULT_G, p=DEFAULT_P):
        self.g = g
        self.p = p
        self._secret_key = random.randint(0, p - 1)
        self.shared_key = None

    def gen_public_key(self):
        return pow(self.g, self._secret_key, self.p)

    def gen_shared_secret_key(self, other_party_public_key):
        if self.shared_key is None:
            self.shared_key = pow(other_party_public_key, self._secret_key, self.p)
        return self.shared_key

In [None]:
client1 = DiffieHellman()
client2 = DiffieHellman()

# Check that our DiffieHellman implementation works and two parties will agree on the same key
assert client1.gen_shared_secret_key(client2.gen_public_key()) == client2.gen_shared_secret_key(client1.gen_public_key())

### Challenge 34

In [None]:
import os
from Crypto.Cipher import AES
import hashlib

In [None]:
p = 37
g = 5

In [None]:
def parameter_injection_attack(alice, bob):
    block_size = AES.block_size
    # A -> M
    A = alice.gen_public_key()
    # M -> B
    A = alice.p
    # B -> M
    B = bob.gen_public_key()
    # M -> A
    B = bob.p
    
    # A -> M
    msg = b"Hello there!"
    s_a = hashlib.sha1(str(alice.gen_shared_secret_key(B)).encode()).digest()[:AES.block_size]
    iv = os.urandom(16)
    cipher_a = AES_CBC_Encrypt(msg, iv, s_a) + iv
    
    # M -> B
    
    # B -> M
    s_b = hashlib.sha1(str(bob.gen_shared_secret_key(A)).encode()).digest()[:16]
    a_iv = cipher[-AES.block_size:]
    a_msg = AES_CBC_Decrypt(cipher_a[:-AES.block_size], iv, s_b)
    print("A sent:", pkcs7_unpad(a_msg))
    iv = os.urandom(16)
    cipher_b = AES_CBC_Encrypt(a_msg, iv, s_b) + iv
    
    # M -> A
    
    # Finding the key after replacing A and B with p is, in fact, very easy.
    # Instead of (B^a % p) or (A^b % p), the shared secret key of the exercise became (p^a % p)
    # and (p^b % p), both equal to zero!
    mitm_key = hashlib.sha1(b'0').digest()[:AES.block_size]
    
    mitm_iv_a = cipher_a[-block_size:]
    mitm_msg_a_read = AES_CBC_Decrypt(cipher_a[:-block_size], mitm_iv_a, mitm_key)
    print("MITM MSG A:", pkcs7_unpad(mitm_msg_a_read))
    
    mitm_iv_b = cipher_b[-block_size:]
    mitm_msg_b_read = AES_CBC_Decrypt(cipher_b[:-block_size], mitm_iv_b, mitm_key)
    print("MITM MSG B:", pkcs7_unpad(mitm_msg_b_read))
    

In [None]:
alice = DiffieHellman(g, p)
bob = DiffieHellman(g, p)

parameter_injection_attack(alice, bob)

### Challenge 35

In [None]:
import os
from Crypto.Cipher import AES
import hashlib

In [None]:
def malicious_g_attack():
    """Simulates the break of Diffie-Hellman with negotiated groups by using malicious 'g' parameters."""
    p = DiffieHellman.DEFAULT_P

    for g in [1, p, p - 1]:

        # Step 1: the MITM changes the default g sent by Alice to Bob with a forced value
        alice = DiffieHellman()
        bob = DiffieHellman(g=g)

        # Step 2: Bob receives this forced g and sends an ACK to Alice

        # Step 3: Alice computes A and sends it to the MITM (thinking of Bob)
        A = alice.gen_public_key()

        # Step 4: Bob computes B and sends it to the MITM (thinking of Alice)
        B = bob.gen_public_key()

        # Step 5: Alice sends her encrypted message to Bob (without knowledge of MITM)
        _msg = b'Hello, how are you?'
        _a_key = hashlib.sha1(str(alice.gen_shared_secret_key(B)).encode()).digest()[:16]
        _a_iv = os.urandom(AES.block_size)
        a_question = AES_CBC_Encrypt(_msg, _a_iv, _a_key) + _a_iv

        # Step 6: Bob receives the message sent by Alice (without knowing of the attack)
        # However, this time Bob will not be able to decrypt it, because (if I understood the
        # challenge task correctly) Alice and Bob now use different values of g.

        # Step 7: the MITM decrypts the Alice's question
        mitm_a_iv = a_question[-AES.block_size:]

        # When g is 1, the secret key is also 1
        if g == 1:
            mitm_hacked_key = hashlib.sha1(b'1').digest()[:16]
            mitm_hacked_message = AES_CBC_Decrypt(a_question[:-AES.block_size], mitm_a_iv, mitm_hacked_key)

        # When g is equal to p, it works the same as in the S5C34 attack (the secret key is 0)
        elif g == p:
            mitm_hacked_key = hashlib.sha1(b'0').digest()[:16]
            mitm_hacked_message = AES_CBC_Decrypt(a_question[:-AES.block_size], mitm_a_iv, mitm_hacked_key)

        # When g is equal to p - 1, the secret key is (-1)^(ab), which is either (+1 % p) or (-1 % p).
        # We can try both and later check the padding to see which one is correct.
        else:

            for candidate in [str(1).encode(), str(p - 1).encode()]:
                mitm_hacked_key = hashlib.sha1(candidate).digest()[:16]
                mitm_hacked_message = AES_CBC_Decrypt(a_question[:-AES.block_size], mitm_a_iv, mitm_hacked_key)

                if pkcs7_padded(mitm_hacked_message):
                    mitm_hacked_message = pkcs7_unpad(mitm_hacked_message)
                    break
        print(mitm_hacked_message)
        # Check if the attack worked
        #assert _msg == mitm_hacked_message


In [None]:
malicious_g_attack()

### Challenge 36

In [None]:
import random
import hashlib

In [None]:
# Client and server agree on these values beforehand

# Generated using "openssl dhparam -text 1024".
N = int("008c5f8a80af99a7db03599f8dae8fb2f75b52501ef54a827b8a1a586f14dfb20d6b5e2ff878b9ad6bca0bb9"
        "18d30431fca1770760aa48be455cf5b949f3b86aa85a2573769e6c598f8d902cc1a0971a92e55b6e04c4d07e"
        "01ac1fa9bdefd1f04f95f197b000486c43917568ff58fafbffe12bde0c7e8f019fa1cb2b8e1bcb1f33", 16)
g = 2
k = 3
I = "hextersik@hexterisk.com"
P = "hexterisk"

In [None]:
class HMAC:
    
    def __init__(self, random_key, hash_func):
        self.hash_func = hash_func
        self.block_size = hash_func().block_size

        if len(random_key) > self.block_size:
            self.key = hash_func(random_key).digest()
        elif len(random_key) < self.block_size:
            self.key = random_key + b'\x00' * (self.block_size-len(random_key))

    def compute(self, message):
        o_key_pad = xor_bytes(self.key, b'\x5c' * self.block_size)
        i_key_pad = xor_bytes(self.key, b'\x36' * self.block_size)
        
        inner_hash = self.hash_func(i_key_pad + message).digest()
        
        return self.hash_func(o_key_pad + inner_hash).hexdigest()

In [None]:
import os
import web
import json
import time
import random
import hashlib

urls = (
    '/hello', 'Hello',
    '/init', 'Initiate',
    '/verify', 'Verify'
)

app = web.application(urls, globals())

K = None
salt = str(random.randint(0, 2**32 - 1))
# since we can't save x, xH
v = pow(g, int(hashlib.sha256(salt.encode()+P.encode()).hexdigest(), 16), N)

class Hello:        
    
    def GET(self):
        params = web.input()
        name = params.name
        if not name:
            name = 'World'
            
        string = "Hello, " + name + "!"
        return {"name" : string}
    
class Verify:

    def GET(self):
        
        global K, salt
        
        params = web.input()
        hmac_received = params.hmac
        
        HMAC_obj = HMAC(K, hashlib.sha256)
        hmac = HMAC_obj.compute(salt.encode())
        
        if hmac == hmac_received:
            return "OK"

class Initiate:
    
    def GET(self):
        
        global K, salt
        
        params = web.input()
        I = params.I
        A = int(params.A)
        
        b = random.randint(0, N - 1)
        B = k*v + pow(g, b, N)
        
        uH = hashlib.sha256(str(A).encode()+str(B).encode()).hexdigest()
        u = int(uH, 16)
        S = pow(A * pow(v, u, N), b, N)
        K = hashlib.sha256(str(S).encode()).digest()
        
        return {"salt":salt, "B":B}

In [None]:
response1 = app.request("/hello?name=")
print(response1.data)

response2 = app.request("/hello?name=hexterisk")
print(json.loads(response2.data.decode("utf-8").replace("'",'"')))

In [None]:
def implement_SRP():
    
    a = random.randint(0, N - 1)
    A = pow(g, a, N)
    
    response = app.request("/init?I=" + I + "&A=" + str(A))
    response_dict = json.loads(response.data.decode("utf-8").replace("'",'"'))
    salt = response_dict["salt"]
    B = int(response_dict["B"])

    uH = hashlib.sha256(str(A).encode()+str(B).encode()).hexdigest()
    u = int(uH, 16)
    
    xH = hashlib.sha256(salt.encode()+P.encode()).hexdigest()
    x = int(xH, 16)
    
    S = pow((B - k * pow(g, x, N)), (a + u * x), N)
   
    HMAC_obj = HMAC(K, hashlib.sha256)
    hmac = HMAC_obj.compute(salt.encode())
    
    response = app.request("/verify?hmac=" + hmac)
    assert response.data.decode("utf-8") == "OK"
    print("> Verification successful.")

In [None]:
implement_SRP()

### Challenge 37

In [None]:
import random
import hashlib

In [None]:
class HMAC:
    
    def __init__(self, random_key, hash_func):
        self.hash_func = hash_func
        self.block_size = hash_func().block_size

        if len(random_key) > self.block_size:
            self.key = hash_func(random_key).digest()
        elif len(random_key) < self.block_size:
            self.key = random_key + b'\x00' * (self.block_size-len(random_key))

    def compute(self, message):
        o_key_pad = xor_bytes(self.key, b'\x5c' * self.block_size)
        i_key_pad = xor_bytes(self.key, b'\x36' * self.block_size)
        
        inner_hash = self.hash_func(i_key_pad + message).digest()
        
        return self.hash_func(o_key_pad + inner_hash).hexdigest()

In [None]:
import os
import web
import json
import time
import random
import hashlib

urls = (
    '/hello', 'Hello',
    '/init', 'Initiate',
    '/verify', 'Verify'
)

app = web.application(urls, globals())

K = None
salt = str(random.randint(0, 2**32 - 1))
# since we can't save x, xH
v = pow(g, int(hashlib.sha256(salt.encode()+P.encode()).hexdigest(), 16), N)

class Hello:        
    
    def GET(self):
        params = web.input()
        name = params.name
        if not name:
            name = 'World'
            
        string = "Hello, " + name + "!"
        return {"name" : string}
    
class Verify:

    def GET(self):
        
        global K, salt
        
        params = web.input()
        hmac_received = params.hmac
        
        HMAC_obj = HMAC(K, hashlib.sha256)
        hmac = HMAC_obj.compute(salt.encode())
        
        if hmac == hmac_received:
            return "OK"

class Initiate:
    
    def GET(self):
        
        global K, salt
        
        params = web.input()
        I = params.I
        A = int(params.A)
        
        b = random.randint(0, N - 1)
        B = k*v + pow(g, b, N)
        
        uH = hashlib.sha256(str(A).encode()+str(B).encode()).hexdigest()
        u = int(uH, 16)
        # S will be zero since modulo N will be zero for 0 and multiples of N
        S = pow(A * pow(v, u, N), b, N)
        K = hashlib.sha256(str(S).encode()).digest()
        
        return {"salt":salt, "B":B}

In [None]:
response1 = app.request("/hello?name=")
print(response1.data)

response2 = app.request("/hello?name=hexterisk")
print(json.loads(response2.data.decode("utf-8").replace("'",'"')))

In [None]:
def implement_SRP():
    
    for A in [0, N, N*2]:
        a = random.randint(0, N - 1)

        response = app.request("/init?I=" + I + "&A=" + str(A))
        response_dict = json.loads(response.data.decode("utf-8").replace("'",'"'))
        salt = response_dict["salt"]
        B = int(response_dict["B"])

        uH = hashlib.sha256(str(A).encode()+str(B).encode()).hexdigest()
        u = int(uH, 16)

        xH = hashlib.sha256(salt.encode()+P.encode()).hexdigest()
        x = int(xH, 16)

        # S = modular_pow((B - k * modular_pow(g, x, N)), (a + u * x), N)
        # We put S=0 because we know it's going to be zero on the server side
        S = 0
        K = hashlib.sha256(str(S).encode()).digest()

        HMAC_obj = HMAC(K, hashlib.sha256)
        hmac = HMAC_obj.compute(salt.encode())

        response = app.request("/verify?hmac=" + hmac)
        assert response.data.decode("utf-8") == "OK"
        print("> Verification successful.")

In [None]:
implement_SRP()

### Challenge 38

In [None]:
import random
import hashlib

In [None]:
# Client and server agree on these values beforehand

# Generated using "openssl dhparam -text 1024".
N = int("008c5f8a80af99a7db03599f8dae8fb2f75b52501ef54a827b8a1a586f14dfb20d6b5e2ff878b9ad6bca0bb9"
        "18d30431fca1770760aa48be455cf5b949f3b86aa85a2573769e6c598f8d902cc1a0971a92e55b6e04c4d07e"
        "01ac1fa9bdefd1f04f95f197b000486c43917568ff58fafbffe12bde0c7e8f019fa1cb2b8e1bcb1f33", 16)
g = 2
k = 3
I = "hextersik@hexterisk.com"
P = "BackupU$r"

In [None]:
class HMAC:
    
    def __init__(self, random_key, hash_func):
        self.hash_func = hash_func
        self.block_size = hash_func().block_size

        if len(random_key) > self.block_size:
            self.key = hash_func(random_key).digest()
        elif len(random_key) < self.block_size:
            self.key = random_key + b'\x00' * (self.block_size-len(random_key))

    def compute(self, message):
        o_key_pad = xor_bytes(self.key, b'\x5c' * self.block_size)
        i_key_pad = xor_bytes(self.key, b'\x36' * self.block_size)
        
        inner_hash = self.hash_func(i_key_pad + message).digest()
        
        return self.hash_func(o_key_pad + inner_hash).hexdigest()

In [None]:
import os
import web
import json
import time
import random
import hashlib

urls = (
    '/hello', 'Hello',
    '/init', 'Initiate',
    '/verify', 'Verify'
)

app = web.application(urls, globals())

K = None
salt = str(random.randint(0, 2**32 - 1))
# since we can't save x, xH
v = pow(g, int(hashlib.sha256(salt.encode()+P.encode()).hexdigest(), 16), N)

class Hello:        
    
    def GET(self):
        params = web.input()
        name = params.name
        if not name:
            name = 'World'
            
        string = "Hello, " + name + "!"
        return {"name" : string}
    
class Verify:

    def GET(self):
        
        global K, salt
        
        params = web.input()
        hmac_received = params.hmac
        
        HMAC_obj = HMAC(K, hashlib.sha256)
        hmac = HMAC_obj.compute(salt.encode())
        
        if hmac == hmac_received:
            return "OK"

class Initiate:
    
    def GET(self):
        
        global K, salt
        
        params = web.input()
        I = params.I
        A = int(params.A)
        
        b = random.randint(0, N - 1)
        B = pow(g, b, N)
        
        u = random.getrandbits(128)
        S = pow(A * pow(v, u, N), b, N)
        K = hashlib.sha256(str(S).encode()).digest()
        
        return {"salt":salt, "B":B, "u":u}

In [None]:
response1 = app.request("/hello?name=")
print(response1.data)

response2 = app.request("/hello?name=hexterisk")
print(json.loads(response2.data.decode("utf-8").replace("'",'"')))

In [None]:
def implement_SRP():
    
    a = random.randint(0, N - 1)
    A = pow(g, a, N)
    
    response = app.request("/init?I=" + I + "&A=" + str(A))
    response_dict = json.loads(response.data.decode("utf-8").replace("'",'"'))
    salt = response_dict["salt"]
    B = int(response_dict["B"])
    u = int(response_dict["u"])
    
    xH = hashlib.sha256(salt.encode()+P.encode()).hexdigest()
    x = int(xH, 16)
    
    S = pow(B, (a + u * x), N)
    K = hashlib.sha256(str(S).encode()).digest()
    
    HMAC_obj = HMAC(K, hashlib.sha256)
    hmac = HMAC_obj.compute(salt.encode())
    
    response = app.request("/verify?hmac=" + hmac)
    assert response.data.decode("utf-8") == "OK"
    print("> Verification successful.")

In [None]:
implement_SRP()

In [None]:
def MITM():
    
    a = random.randint(0, N - 1)
    A = pow(g, a, N)
    
    response = app.request("/init?I=" + I + "&A=" + str(A))
    response_dict = json.loads(response.data.decode("utf-8").replace("'",'"'))
    salt = response_dict["salt"]
    B = int(response_dict["B"])
    u = int(response_dict["u"])
    
    data = open('dictionary.txt', 'r').read()
    passwords = data.split('\n')
    
    for password in passwords:
        xH = hashlib.sha256(salt.encode()+password.encode()).hexdigest()
        x = int(xH, 16)

        S = pow(B, (a + u * x), N)
        K = hashlib.sha256(str(S).encode()).digest()

        HMAC_obj = HMAC(K, hashlib.sha256)
        hmac = HMAC_obj.compute(salt.encode())

        response = app.request("/verify?hmac=" + hmac)
        
        if response.data.decode("utf-8") == "OK":
            print("> Brute force successful.")
            print("> Password found to be:", P)
            break

In [None]:
MITM()

### Challenge 39

In [None]:
import math
import random
from Crypto.Util.number import getPrime

In [None]:
def mod_inverse(a, n) : 
    """Computes the multiplicative inverse of a modulo n using the extended Euclidean algorithm."""
    t, r = 0, n
    new_t, new_r = 1, a

    while new_r != 0:
        quotient = r // new_r
        t, new_t = new_t, t - quotient * new_t
        r, new_r = new_r, r - quotient * new_r

    if r > 1:
        raise Exception("a is not invertible")
    if t < 0:
        t = t + n

    return t

In [None]:
class RSA:
    
    def __init__(self, keysize):
        e = 3
        et = 0
        n = 0

        while math.gcd(e, et) != 1:
            p, q = getPrime(keysize // 2), getPrime(keysize // 2)
            et = ((p - 1) * (q - 1)) // math.gcd(p - 1, q - 1)
            n = p * q

        d = mod_inverse(e, et)
        
        self.pub = (e, n)
        self.pvt = (d, n)

    def encrypt(self, message, byteorder="big"):
        (e, n) = self.pub
        data = int.from_bytes(message, byteorder)
        
        if data < 0 or data >= n:
            raise ValueError(str(data) + ' out of range')
            
        return pow(data, e, n)
    
    def encryptnum(self, m):
        (e, n) = self.pub
        if m < 0 or m >= n:
            raise ValueError(str(m) + ' out of range')
        return pow(m, e, n)
    
    def decrypt(self, ciphertext, byteorder="big"):
        (d, n) = self.pvt
        
        if ciphertext < 0 or ciphertext >= n:
            raise ValueError(str(ciphertext) + ' out of range')
        
        numeric_plain = pow(ciphertext, d, n)
        return numeric_plain.to_bytes((numeric_plain.bit_length() + 7) // 8, byteorder)
    
    def decryptnum(self, m):
        (d, n) = self.pvt
        if m < 0 or m >= n:
            raise ValueError(str(m) + ' out of range')
        return pow(m, d, n)

In [None]:
rsa = RSA(1024)
message = "Testing 1..2..3..."

In [None]:
ciphertext = rsa.encrypt(message.encode())
assert rsa.decrypt(ciphertext).decode("utf-8") == message

### Challenge 40

In [None]:
import math

In [None]:
def floorRoot(n, s):
    b = n.bit_length()
    p = math.ceil(b/s)
    x = 2**p
    while x > 1:
        y = (((s - 1) * x) + (n // (x**(s-1)))) // s
        if y >= x:
            return x
        x = y
    return 1

def mod_inverse(a, n) : 
    """Computes the multiplicative inverse of a modulo n using the extended Euclidean algorithm."""
    t, r = 0, n
    new_t, new_r = 1, a

    while new_r != 0:
        quotient = r // new_r
        t, new_t = new_t, t - quotient * new_t
        r, new_r = new_r, r - quotient * new_r

    if r > 1:
        raise Exception("a is not invertible")
    if t < 0:
        t = t + n

    return t

In [None]:
def RSA_Broadcast_Attack(message, rsa0, rsa1, rsa2):
    """Uses the Chinese Remainder Theorem (CRT) to break e=3 RSA given three ciphertexts of the same plaintext.
    This attack could be easily coded to work also when a different number of ciphertexts is provided.
    Check here for reference: https://crypto.stanford.edu/pbc/notes/numbertheory/crt.html
    """
    
    rsa0 = RSA(256)
    rsa1 = RSA(256)
    rsa2 = RSA(256)

    pub0 = rsa0.pub
    pub1 = rsa1.pub
    pub2 = rsa2.pub
    
    plainnum = int.from_bytes(message, "big")
    c0 = rsa0.encryptnum(plainnum)
    c1 = rsa1.encryptnum(plainnum)
    c2 = rsa2.encryptnum(plainnum)
    
    n0 = rsa0.pub[1]
    n1 = rsa1.pub[1]
    n2 = rsa2.pub[1]
    
    # can't do N/n0 for ms0 instead because floating point operations arent accurate
    N = n0 * n1 * n2
    ms0 = n1 * n2
    ms1 = n0 * n2
    ms2 = n0 * n1
    
    r0 = (c0 * ms0 * mod_inverse(ms0, n0))
    r1 = (c1 * ms1 * mod_inverse(ms1, n1))
    r2 = (c2 * ms2 * mod_inverse(ms2, n2))
    
    R = (r0 + r1 + r2) % N
    m = floorRoot(R, 3)
    
    return m.to_bytes((m.bit_length() + 7) // 8, "big")

In [None]:
message = "This is RSA Broadcast Attack"
assert RSA_Broadcast_Attack(message.encode(), RSA(256), RSA(256), RSA(256)).decode("utf-8") == message

# Set 6

### Challenge 41

In [None]:
import random

In [None]:
class RSA_server:
    
    def __init__(self, rsa):
        self.rsa = rsa
        self.decrypted = []
    
    def get_public_key(self):
        return self.rsa.pub
    
    def decrypt(self, ciphertext):
        if cipher in decrypted:
            raise Exception("This ciphertext has already been deciphered before!")
        self.decrypted.append(ciphertext)
        return rsa.decrypt(ciphertext)

In [None]:
def unpadded_message_recovery(ciphertext, rsa_server):
    
    (E, N) = rsa_server.get_public_key()
    S = random.randint(1, N)
    while True:
        if S % N > 1:
            break
    
    modified_ciphertext = (pow(S, E, N) * ciphertext) % N
    
    modified_plaintext = rsa.decrypt(modified_ciphertext)
    recovered_plaintext_int = (int.from_bytes(modified_plaintext, "big") * mod_inverse(S, N) % N)
    
    return (recovered_plaintext_int).to_bytes((recovered_plaintext_int.bit_length() + 7) // 8, "big")

In [None]:
rsa = RSA(256)
rsa_server = RSA_server(rsa)

plaintext = "Unpadded message"
ciphertext = rsa.encrypt(plaintext.encode())
assert unpadded_message_recovery(ciphertext, rsa_server).decode("utf-8") == plaintext

### Challenge 42

In [None]:
import re
import hashlib

In [None]:
# Shit has to be in big endian because cubing takes place, and interferes with the bit manipulation

In [None]:
class RSA_Digital_Signature(RSA):
    """Extends the RSA class coded before with the sign / verify functions."""

    def generate_signature(self, message):
        digest = hashlib.sha1(message).digest()
        block = b'\x00\x01' + (b'\xff' * (128 - len(digest) - 3 - 15)) + b'\x00' + ASN1_SHA1 + digest
        signature = rsa.decrypt(int.from_bytes(block, "big"), "big")
        return signature

    def verify_signature(self, message, signature):
        cipher = rsa.encrypt(signature, "big")
        block = b'\x00' + cipher.to_bytes((cipher.bit_length() + 7) // 8, "big")
        r = re.compile(b'\x00\x01\xff+?\x00.{15}(.{20})', re.DOTALL)
        m = r.match(block)
        if not m:
            return False
        digest = m.group(1)
        return digest == hashlib.sha1(message).digest()

In [None]:
message = "hi mom"
# 15-byte ASN.1 value for SHA1 (from rfc 3447)
ASN1_SHA1 = b"\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14"

rsa = RSA_Digital_Signature(1024)
signature = rsa.generate_signature(message.encode())
if not rsa.verify_signature(message.encode(), signature):
    raise Exception(message + b' has invalid signature ' + signature)
else:
    print("> Signature verified for message:", message)

In [None]:
def forge_signature(message):
    digest = hashlib.sha1(message).digest()
    block = b'\x00\x01\xff\x00' +  ASN1_SHA1 + digest + (b'\x00' * (128 - len(digest) - 4 - 15))
    block_int = int.from_bytes(block, "big")
    sig = floorRoot(block_int, 3) + 1
    return sig.to_bytes((sig.bit_length() + 7) // 8, "big")

In [None]:
forged_signature = forge_signature(message.encode())
if not rsa.verify_signature(message.encode(), forged_signature):
    raise Exception(message + b' has invalid signature ' + forged_signature)
else:
    print("> Signature verified for message:", message)

### Challenge 43

In [None]:
import random
import hashlib
from Crypto.Util.number import getPrime

In [None]:
class DSA:
    
    """Implements the DSA public key encryption / decryption."""
    DEFAULT_P = int("800000000000000089e1855218a0e7dac38136ffafa72eda7859f2171e25e65eac698c1702578b07dc2a1076da241c76"
                    "c62d374d8389ea5aeffd3226a0530cc565f3bf6b50929139ebeac04f48c3c84afb796d61e5a4f9a8fda812ab59494232"
                    "c7d2b4deb50aa18ee9e132bfa85ac4374d7f9091abc3d015efc871a584471bb1", 16)
    DEFAULT_Q = 0xf4f47f05794b256174bba6e9b396a7707e563c5b
    DEFAULT_G = int("5958c9d3898b224b12672c0b98e06c60df923cb8bc999d119458fef538b8fa4046c8db53039db620c094c9fa077ef389"
                    "b5322a559946a71903f990f1f7e0e025e2d7f7cf494aff1a0470f5b64c36b625a097f1651fe775323556fe00b3608c88"
                    "7892878480e99041be601a62166ca6894bdd41a7054ec89f756ba9fc95302291", 16)
        
    def __init__(self, p = DEFAULT_P, q = DEFAULT_Q, g = DEFAULT_G):
        self.p = p
        self.q = q
        self.g = g
        self.x, self. y = self._per_user_key()
        self.pvt, self.pub = self.x, self.y
        
    def _per_user_key(self):
        x = random.randint(1, self.q - 1)
        y = pow(self.g, x, self.p)
        return x, y
    
    def H(self, message):
        return int(hashlib.sha1(message).hexdigest(), 16)
    
    def key_distribution(self):
        return self.pub
    
    def generate_signature(self, message):
        
        while True:
            k = random.randint(1, self.q - 1)
            r = pow(self.g, k, self.p) % self.q
            if r == 0:
                continue
                
            s = (mod_inverse(k, self.q) * (self.H(message) + self.x * r)) % self.q
            if s != 0:
                break
        return (r, s)
    
    def verify_signature(self, r, s, message):
        if r < 0 or r > self.q:
            return False
        if s < 0 or s > self.q:
            return False
        
        w = mod_inverse(s, self.q)
        u1 = (self.H(message) * w) % self.q
        u2 = (r * w) % self.q
        
        v1 = pow(self.g, u1, self.p)
        v2 = pow(self.y, u2, self.p)
        
        v = ((v1 * v2) % self.p) % self.q
        return v == r

In [None]:
# Filter object is created, iter is used to fetch values and then tuple is accessed
modulo_list = [(1024, 160), (2048, 224), (2048, 256), (3072, 256)]    

def DSA_parameter_generation(key_length):
    N = filter(lambda x:key_length in x, modulo_list).__next__()[1]
    q = getPrime(N)
    
    p = 0
    while True:
        p = getPrime(key_length)
        if (p - 1) % q == 0:
            break

    g = 1
    h = 0
    
    while True:
        h = random.randint(2, p - 2)
        g = h**((p - 1) / q)
        if g != 1:
            break
    
    return p, q, g

# Takes a lot of time
# p, q, g = DSA_parameter_generation(1024)
# dsa = DSA(p, q, g)

In [None]:
dsa = DSA()
signature = dsa.generate_signature(b"Hello World!")
assert dsa.verify_signature(signature[0], signature[1], b"Hello World!")

In [None]:
def DSA_x_from_k(k, q, r, s, message_int):
    return (((s * k) - message_int) * mod_inverse(r, q)) % q

In [None]:
def key_recovery_from_nonce(q, r, s, y, message_int):
    for k in range(2**16):
        x = DSA_x_from_k(k, q, r, s, message_int)
        
        # given in question, [2:] tp skip the 0x
        if hashlib.sha1(hex(x)[2:].encode()).hexdigest() == "0954edd5e0afe5542a4adf012611a91912a3ec16":
            return x
    return 0

message = "For those that envy a MC it can be hazardous to your health\nSo be friendly, a matter of life and death, just like a etch-a-sketch\n"
r, s = dsa.generate_signature(message.encode())
assert dsa.verify_signature(r, s, message.encode())

# parameters given in the quesion
# used to verify if our implementation works correctly
q = 0xf4f47f05794b256174bba6e9b396a7707e563c5b
r = 548099063082341131477253921760299949438196259240
s = 857042759984254168557880549501802188789837994940
y = int("84ad4719d044495496a3201c8ff484feb45b962e7302e56a392aee4abab3e4bdebf2955b4736012f21a0808"
        "4056b19bcd7fee56048e004e44984e2f411788efdc837a0d2e5abb7b555039fd243ac01f0fb2ed1dec56828"
        "0ce678e931868d23eb095fde9d3779191b8c0299d6e07bbb283e6633451e535c45513b2d33c99ea17", 16)

key = key_recovery_from_nonce(q, r, s, y, dsa.H(message.encode()))
if key != 0:
    print("> Brute force successful.\nPrivate key:", key)

### Challenge 44

In [None]:
import hashlib

In [None]:
data = open('44.txt', 'r').read()
data_list = data.split('\n')

q = 0xf4f47f05794b256174bba6e9b396a7707e563c5b
y = int("2d026f4bf30195ede3a088da85e398ef869611d0f68f0713d51c9c1a3a26c95105d915e2d8cdf26d056b86b8a7b8"
    "5519b1c23cc3ecdc6062650462e3063bd179c2a6581519f674a61f1d89a1fff27171ebc1b93d4dc57bceb7ae2430f98a"
    "6a4d83d8279ee65d71c1203d2c96d65ebbf7cce9d32971c3de5084cce04a2e147821", 16)

target = "ca8f6f7c66fa362d40760d135b763eb8527d3d52"

message_dicts = []
for i in range(0, len(data_list)-4, 4):
    message_dicts.append({"msg":data_list[i][5:], "s":int(data_list[i + 1][3:]), "r":int(data_list[i + 2][3:]), "m":int(data_list[i + 3][3:], 16)})

In [None]:
def DSA_x_from_k(k, q, r, s, message_int):
    return (((s * k) - message_int) * mod_inverse(r, q)) % q

In [None]:
def nonce_recovery_from_repeated_nonce():
    
    # Find two pairs of signatures that used the same k
    # This is easy to find, because when the same k is used r will be the same, since r
    # depends only on (g, p, q and k), and (g, p, q) are fixed in our implementation.

    # find indices of matching r
    found = False
    r1, s1, s2, m1, m2 = 0, 0, 0, 0, 0
    for i in range(len(message_dicts)):
        for j in range(len(message_dicts[i:])):
            if message_dicts[i]["r"] == message_dicts[j]["r"]:
                m1 = message_dicts[i]["m"]
                m2 = message_dicts[j]["m"]
                if m1 == m2:
                    continue
                found = True
                r1 = message_dicts[i]["r"]
                s1 = message_dicts[i]["s"]
                s2 = message_dicts[j]["s"]                
                break
        if found:
            break
    k = (((m1 - m2) % q) * mod_inverse((s1 - s2) % q, q)) % q
    return DSA_x_from_k(k, q, r1, s1, m1)

In [None]:
recovered_x = nonce_recovery_from_repeated_nonce()
assert hashlib.sha1(hex(recovered_x)[2:].encode()).hexdigest() == target

### Challenge 45

In [None]:
class DSA_flawed:
    # allows r = 0
    """Implements the DSA public key encryption / decryption."""
    DEFAULT_P = int("800000000000000089e1855218a0e7dac38136ffafa72eda7859f2171e25e65eac698c1702578b07dc2a1076da241c76"
                    "c62d374d8389ea5aeffd3226a0530cc565f3bf6b50929139ebeac04f48c3c84afb796d61e5a4f9a8fda812ab59494232"
                    "c7d2b4deb50aa18ee9e132bfa85ac4374d7f9091abc3d015efc871a584471bb1", 16)
    DEFAULT_Q = 0xf4f47f05794b256174bba6e9b396a7707e563c5b
    DEFAULT_G = int("5958c9d3898b224b12672c0b98e06c60df923cb8bc999d119458fef538b8fa4046c8db53039db620c094c9fa077ef389"
                    "b5322a559946a71903f990f1f7e0e025e2d7f7cf494aff1a0470f5b64c36b625a097f1651fe775323556fe00b3608c88"
                    "7892878480e99041be601a62166ca6894bdd41a7054ec89f756ba9fc95302291", 16)
        
    def __init__(self, p = DEFAULT_P, q = DEFAULT_Q, g = DEFAULT_G):
        self.p = p
        self.q = q
        self.g = g
        self.x, self. y = self._per_user_key()
        self.pvt, self.pub = self.x, self.y
        
    def _per_user_key(self):
        x = random.randint(1, self.q - 1)
        y = pow(self.g, x, self.p)
        return x, y
    
    def H(self, message):
        return int(hashlib.sha1(message).hexdigest(), 16)
    
    def key_distribution(self):
        return self.pub
    
    def generate_signature(self, message):
        
        while True:
            k = random.randint(1, self.q - 1)
            r = pow(self.g, k, self.p) % self.q                
            s = (mod_inverse(k, self.q) * (self.H(message) + self.x * r)) % self.q
            if s != 0:
                break
        return (r, s)
    
    def verify_signature(self, r, s, message):
        if r < 0 or r > self.q:
            return False
        if s < 0 or s > self.q:
            return False
        
        w = mod_inverse(s, self.q)
        u1 = (self.H(message) * w) % self.q
        u2 = (r * w) % self.q
        
        v1 = pow(self.g, u1, self.p)
        v2 = pow(self.y, u2, self.p)
        
        v = ((v1 * v2) % self.p) % self.q
        return v == r

In [None]:
dsa = DSA_flawed(g = 0)
message = "Original message"

signature = dsa.generate_signature(message.encode())
print("> Message:", message)
print("> Signature generated for g = 0.\nr:", signature[0], "\ns:", signature[1])
check = dsa.verify_signature(signature[0], signature[1], message.encode())
if check:
    print("> Signature successfully verified.")
    
tampered_message = "Tampered message!"
print("> Trying to verify signature of initial message for message:", tampered_message)
print("> Values from previous signature:\nr:", signature[0], "\ns:", signature[1])
check = dsa.verify_signature(signature[0], signature[1], tampered_message.encode())
if check:
    print("> Signature successfully verified.")

In [None]:
def DSA_parameter_tampering():

    dsa = DSA_flawed(g = DSA.DEFAULT_P + 1)
    message = "g = (p + 1) DSA"
    signature = dsa.generate_signature(message.encode())
    print("> Message:", message)
    print("> Signature generated for g = (p + 1).\nr:", signature[0], "\ns:", signature[1])
    check = dsa.verify_signature(signature[0], signature[1], message.encode())
    if check:
        print("> Signature successfully verified for original message.")
    
    z = random.randint(1, 100)
    y = dsa.key_distribution()
    forged_r = pow(y, z, DSA_flawed.DEFAULT_P) % DSA_flawed.DEFAULT_Q
    forged_s = (forged_r * mod_inverse(z, dsa.DEFAULT_Q)) % dsa.DEFAULT_Q
    
    message1 = "Hello, world"
    message2 = "Goodbye, world"
    
    print("> Values from forged signature:\nr:", forged_r, "\ns:", forged_s)
    
    print("> Message 1:", message1)
    if dsa.verify_signature(forged_r, forged_s, message1.encode()):
        print("> Signature successfully verified for message 1.")
    print("> Message 2:", message2)
    if dsa.verify_signature(forged_r, forged_s, message2.encode()):
        print("> Signature successfully verified for message 2.")

In [None]:
DSA_parameter_tampering()

### Challenge 46

In [None]:
import math
import base64
import decimal

In [None]:
def check_parity(ciphertext_int, rsa):
    return rsa.decryptnum(ciphertext_int) & 1

In [None]:
rsa = RSA(1024)
ciphertext = rsa.encrypt(b"Hello")
ciphertext_int = int.from_bytes(ciphertext, "big")
print(check_parity(ciphertext_int, rsa))

In [None]:
def parity_attack(message, rsa):
    (_, n) = rsa.pub
    ciphertext = rsa.encryptnum(int.from_bytes(message, "big"))
    
    # encrypt multiplier
    multiplier = rsa.encryptnum(2)
    
    # Initialize lower and upper bound.
    # I need to use Decimal because it allows me to set the precision for the floating point
    # numbers, which we will need when doing the binary search divisions.
    lower_bound = decimal.Decimal(0)
    upper_bound = decimal.Decimal(n)
    
    # Compute the number of iterations that we have to do
    num_iter = int(math.ceil(math.log(n, 2)))
    # Set the precision of the floating point number to be enough
    decimal.getcontext().prec = num_iter

    for _ in range(num_iter):
        ciphertext = (ciphertext * multiplier) % n
        
        # checking parity
        if check_parity(ciphertext, rsa) & 1:
            lower_bound = (lower_bound + upper_bound) / 2
        else:
            upper_bound = (lower_bound + upper_bound) / 2

    # Return the binary version of the upper_bound (converted from Decimal to int)
    return int(upper_bound).to_bytes((int(upper_bound).bit_length() + 7) // 8, "big").decode("utf-8")

In [None]:
given_string = "VGhhdCdzIHdoeSBJIGZvdW5kIHlvdSBkb24ndCBwbGF5IGFyb3VuZCB3aXRoIHRoZSBGdW5reSBDb2xkIE1lZGluYQ=="
byte_string = base64.b64decode(given_string)
plaintext = parity_attack(byte_string, RSA(1024))

In [None]:
assert plaintext == byte_string.decode("utf-8")

### Challenge 47

In [None]:
import os
import random

In [None]:
class RSA_PKCS1_Oracle(RSA):
    """Extends the RSA class by making the decryption PKCS 1.5 compliant and by adding a method
    to verify the padding of data."""
    
    def PKCS1_Pad(self, message):
        """Pads the given binary data conforming to the PKCS 1.5 format."""
        (e, n) = self.pub
        byte_length = (n.bit_length() + 7) // 8
        padding_string = os.getrandom(byte_length - 3 - len(message))
        return b"\x00\x02" + padding_string + b'\x00' + message
  
    def PKCS1_check_padding(self, ciphertext):
        """Decrypts the input data and returns whether its padding is correct according to PKCS 1.5."""
        _, n = self.pub
        k = (n.bit_length() + 7) // 8
        pbytes = self.decrypt(ciphertext)
        pbytes = (b'\x00' * (k - len(pbytes))) + pbytes
        return pbytes[0:2] == b'\x00\x02'

In [None]:
def append_interval(M_narrow, lower_bound, upper_bound):
    # Check if there exist an interval which is overlapping with the lower_bound and
    # upper_bound of the new interval we want to append
    for i, (a, b) in enumerate(M_narrow):

        # If there is an overlap, then replace the boundaries of the overlapping
        # interval with the wider (or equal) boundaries of the new merged interval
        if not (b < lower_bound or a > upper_bound):
            new_a = min(lower_bound, a)
            new_b = max(upper_bound, b)
            M_narrow[i] = new_a, new_b
            return

    # If there was no interval overlapping with the one we want to add, add
    # the new interval as a standalone interval to the list
    M_narrow.append((lower_bound, upper_bound))
    return

def ceil(a, b):
    return (a + b - 1) // b

def padding_oracle_attack(ciphertext, rsa):
    
    # Setting initial values
    
    (e, n) = rsa.pub
    k = (n.bit_length() + 7) // 8 # byte length
    B = 2**(8 * (k - 2))
    M = [(2 * B, 3 * B - 1)]
    i = 1
    
    if not rsa.PKCS1_check_padding(ciphertext):
        #Step 1 Blinding
        while True:
            s = random.randint(0, n - 1)
            c0 = (ciphertext * pow(s, e, n)) % n
            if rsa.PKCS1_check_padding(c0):
                break

    else:
        c0 = ciphertext
        
    # Step 2 Searching for PKCS conforming messages
    while True:
        # 2a
        if i == 1:
            s = (n + 3 * B - 1) // (3 * B)
            while True:
                c = (c0 * pow(s, e, n)) % n
                if rsa.PKCS1_check_padding(c):
                    break
                s += 1
        
        #2c
        # Step 2.c: Searching with one interval left
        elif len(M) == 1:
            a, b = M[0]

            # Check if the interval contains the solution
            if a == b:

                # And if it does, return it as bytes
                return b'\x00' + (a).to_bytes((a.bit_length() +7) // 8, "big")

            r = ceil(2 * (b * s - 2 * B), n)
            s = ceil(2 * B + r * n, b)

            while True:
                c = (c0 * pow(s, e, n)) % n
                if rsa.PKCS1_check_padding(c):
                    break

                s += 1
                if s > (3 * B + r * n) // a:
                    r += 1
                    s = ceil((2 * B + r * n), b)

        # Step 3: Narrowing the set of solutions
        M_new = []

        for a, b in M:
            min_r = ceil(a * s - 3 * B + 1, n)
            max_r = (b * s - 2 * B) // n

            for r in range(min_r, max_r + 1):
                l = max(a, ceil(2 * B + r * n, s))
                u = min(b, (3 * B - 1 + r * n) // s)

                if l > u:
                    raise Exception('Unexpected error: l > u in step 3')

                append_interval(M_new, l, u)

        if len(M_new) == 0:
            raise Exception('Unexpected error: there are 0 intervals.')

        M = M_new
        i += 1

In [None]:
rsa = RSA_PKCS1_Oracle(256)
message = "kick it, CC"
m = rsa.PKCS1_Pad(message.encode())

c = rsa.encrypt(m)
assert rsa.PKCS1_check_padding(c)
print("> Ciphertext padding verified.")

In [None]:
recovered_plaintext = padding_oracle_attack(c, rsa)
assert recovered_plaintext == m

### Challenge 48

In [None]:
import os
import random

In [None]:
class RSA_PKCS1_Oracle(RSA):
    """Extends the RSA class by making the decryption PKCS 1.5 compliant and by adding a method
    to verify the padding of data."""
    
    def PKCS1_Pad(self, message):
        """Pads the given binary data conforming to the PKCS 1.5 format."""
        (e, n) = self.pub
        byte_length = (n.bit_length() + 7) // 8
        padding_string = os.getrandom(byte_length - 3 - len(message))
        return b"\x00\x02" + padding_string + b'\x00' + message
  
    def PKCS1_check_padding(self, ciphertext):
        """Decrypts the input data and returns whether its padding is correct according to PKCS 1.5."""
        _, n = self.pub
        k = (n.bit_length() + 7) // 8
        pbytes = self.decrypt(ciphertext)
        pbytes = (b'\x00' * (k - len(pbytes))) + pbytes
        return pbytes[0:2] == b'\x00\x02'

In [None]:
def append_interval(M_narrow, lower_bound, upper_bound):
    # Check if there exist an interval which is overlapping with the lower_bound and
    # upper_bound of the new interval we want to append
    for i, (a, b) in enumerate(M_narrow):

        # If there is an overlap, then replace the boundaries of the overlapping
        # interval with the wider (or equal) boundaries of the new merged interval
        if not (b < lower_bound or a > upper_bound):
            new_a = min(lower_bound, a)
            new_b = max(upper_bound, b)
            M_narrow[i] = new_a, new_b
            return

    # If there was no interval overlapping with the one we want to add, add
    # the new interval as a standalone interval to the list
    M_narrow.append((lower_bound, upper_bound))
    return

def ceil(a, b):
    return (a + b - 1) // b

def padding_oracle_attack(ciphertext, rsa):
    
    # Setting initial values
    
    (e, n) = rsa.pub
    k = (n.bit_length() + 7) // 8 # byte length
    B = 2**(8 * (k - 2))
    M = [(2 * B, 3 * B - 1)]
    i = 1
    
    if not rsa.PKCS1_check_padding(ciphertext):
        #Step 1 Blinding
        while True:
            s = random.randint(0, n - 1)
            c0 = (ciphertext * pow(s, e, n)) % n
            if rsa.PKCS1_check_padding(c0):
                break

    else:
        c0 = ciphertext
        
    # Step 2 Searching for PKCS conforming messages
    while True:
        # 2a
        if i == 1:
            s = (n + 3 * B - 1) // (3 * B)
            while True:
                c = (c0 * pow(s, e, n)) % n
                if rsa.PKCS1_check_padding(c):
                    break
                s += 1

        #2b
        elif len(M) >= 2:
            while True:
                s += 1
                c = (c0 * pow(s, e, n)) % n
                if rsa.PKCS1_check_padding(c):
                    break
        
        #2c
        # Step 2.c: Searching with one interval left
        elif len(M) == 1:
            a, b = M[0]

            # Check if the interval contains the solution
            if a == b:

                # And if it does, return it as bytes
                return b'\x00' + (a).to_bytes((a.bit_length() +7) // 8, "big")

            r = ceil(2 * (b * s - 2 * B), n)
            s = ceil(2 * B + r * n, b)

            while True:
                c = (c0 * pow(s, e, n)) % n
                if rsa.PKCS1_check_padding(c):
                    break

                s += 1
                if s > (3 * B + r * n) // a:
                    r += 1
                    s = ceil((2 * B + r * n), b)

        # Step 3: Narrowing the set of solutions
        M_new = []

        for a, b in M:
            min_r = ceil(a * s - 3 * B + 1, n)
            max_r = (b * s - 2 * B) // n

            for r in range(min_r, max_r + 1):
                l = max(a, ceil(2 * B + r * n, s))
                u = min(b, (3 * B - 1 + r * n) // s)

                if l > u:
                    raise Exception('Unexpected error: l > u in step 3')

                append_interval(M_new, l, u)

        if len(M_new) == 0:
            raise Exception('Unexpected error: there are 0 intervals.')

        M = M_new
        i += 1

In [None]:
rsa = RSA_PKCS1_Oracle(768)
message = "kick it, CC"
m = rsa.PKCS1_Pad(message.encode())

c = rsa.encrypt(m)
assert rsa.PKCS1_check_padding(c)
print("> Ciphertext padding verified.")

In [None]:
recovered_plaintext = padding_oracle_attack(c, rsa)
assert recovered_plaintext == m